Commit 0b9c595b authored by Matthieu Dorier's avatar Matthieu Dorier

added numpy support

parent 2acbe4b5
...@@ -128,6 +128,17 @@ class BakeProviderHandle(): ...@@ -128,6 +128,17 @@ class BakeProviderHandle():
""" """
return _pybakeclient.write(self._ph, rid._rid, offset, data) return _pybakeclient.write(self._ph, rid._rid, offset, data)
def write_numpy(self, rid, offset, array):
"""
Writes a numpy array in a region at a specified offset.
Args:
rid (BakeRegionID): region in which to write.
offset (int): offset at which to write.
data (numpy.ndarray): numpy array to write.
"""
return _pybakeclient.write_numpy(self._ph, rid._rid, offset, array)
def persist(self, rid): def persist(self, rid):
""" """
Make the changes to a given region persist. Make the changes to a given region persist.
...@@ -149,6 +160,24 @@ class BakeProviderHandle(): ...@@ -149,6 +160,24 @@ class BakeProviderHandle():
data (str): data to write. data (str): data to write.
""" """
rid = _pybakeclient.create_write_persist(self._ph, bti._tid, data) rid = _pybakeclient.create_write_persist(self._ph, bti._tid, data)
if(rid is None):
return None
return BakeRegionID(rid)
def create_write_persist_numpy(self, bti, array):
"""
Creates a new region, write the numpy array to it at a given offset,
and persist the region.
Args:
bti (BakeTargetID): target id in which to create the region.
size (int): size of the region to create.
offset (int): offset at which to write data in the region.
array (numpy.ndarray): numpy array to write.
"""
rid = _pybakeclient.create_write_persist_numpy(self._ph, bti._tid, array)
if(rid is None):
return None
return BakeRegionID(rid) return BakeRegionID(rid)
def get_size(self, rid): def get_size(self, rid):
...@@ -184,6 +213,25 @@ class BakeProviderHandle(): ...@@ -184,6 +213,25 @@ class BakeProviderHandle():
size = self.get_size(rid) - offset size = self.get_size(rid) - offset
return _pybakeclient.read(self._ph, rid._rid, offset, size) return _pybakeclient.read(self._ph, rid._rid, offset, size)
def read_numpy(self, rid, offset, shape, dtype):
"""
Reads the data contained in a given region, at a given offset,
and interpret it as a numpy array of a given shape and datatype.
This function will fail if the full array cannot be loaded
(e.g. the size of the region from the provided offset is too small
compared with the size of the numpy that should result from the call)
Args:
rid (BakeRegionID): region id.
offset (int): offset at which to read.
shape (tuple): shape of the resulting array.
dtype (numpy.dtype): datatype of the resuling array.
Returns:
A numpy array or None if it could not be read.
"""
return _pybakeclient.read_numpy(self._ph, rid._rid, offset, shape, dtype)
def remove(self, rid): def remove(self, rid):
""" """
Remove a region from its target. Remove a region from its target.
......
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
#include <margo.h> #include <margo.h>
#include <bake.h> #include <bake.h>
#include <bake-client.h> #include <bake-client.h>
#if HAS_NUMPY
#include <boost/python/numpy.hpp>
namespace np = boost::python::numpy;
#endif
BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(margo_instance) BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(margo_instance)
BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(bake_provider_handle) BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(bake_provider_handle)
...@@ -89,6 +93,28 @@ static bpl::object pybake_write( ...@@ -89,6 +93,28 @@ static bpl::object pybake_write(
else return bpl::object(false); else return bpl::object(false);
} }
#if HAS_NUMPY
static bpl::object pybake_write_numpy(
bake_provider_handle_t ph,
const bake_region_id_t& rid,
uint64_t offset,
const np::ndarray& data)
{
if(!(data.get_flags() & np::ndarray::bitflag::V_CONTIGUOUS)) {
std::cerr << "[pyBAKE error]: non-contiguous numpy arrays not yet supported" << std::endl;
return bpl::object(false);
}
size_t size = data.get_dtype().get_itemsize();
for(int i = 0; i < data.get_nd(); i++) {
size *= data.shape(i);
}
void* buffer = data.get_data();
int ret = bake_write(ph, rid, offset, buffer, size);
if(ret != 0) return bpl::object(false);
else return bpl::object(true);
}
#endif
static bpl::object pybake_persist( static bpl::object pybake_persist(
bake_provider_handle_t ph, bake_provider_handle_t ph,
const bake_region_id_t& rid) const bake_region_id_t& rid)
...@@ -110,6 +136,29 @@ static bpl::object pybake_create_write_persist( ...@@ -110,6 +136,29 @@ static bpl::object pybake_create_write_persist(
else return bpl::object(); else return bpl::object();
} }
#if HAS_NUMPY
static bpl::object pybake_create_write_persist_numpy(
bake_provider_handle_t ph,
bake_target_id_t tid,
const np::ndarray& data)
{
bake_region_id_t rid;
if(!(data.get_flags() & np::ndarray::bitflag::V_CONTIGUOUS)) {
std::cerr << "[pyBAKE error]: non-contiguous numpy arrays not yet supported" << std::endl;
return bpl::object();
}
size_t size = data.get_dtype().get_itemsize();
for(int i = 0; i < data.get_nd(); i++) {
size *= data.shape(i);
}
void* buffer = data.get_data();
int ret = bake_create_write_persist(ph, tid,
buffer, size, &rid);
if(ret == 0) return bpl::object(rid);
else return bpl::object();
}
#endif
static bpl::object pybake_get_size( static bpl::object pybake_get_size(
bake_provider_handle_t ph, bake_provider_handle_t ph,
const bake_region_id_t& rid) const bake_region_id_t& rid)
...@@ -134,14 +183,35 @@ static bpl::object pybake_read( ...@@ -134,14 +183,35 @@ static bpl::object pybake_read(
return bpl::object(result); return bpl::object(result);
} }
#if HAS_NUMPY
static bpl::object pybake_read_numpy(
bake_provider_handle_t ph,
const bake_region_id_t& rid,
uint64_t offset,
const bpl::tuple& shape,
const np::dtype& dtype)
{
np::ndarray result = np::empty(shape, dtype);
size_t size = dtype.get_itemsize();
for(int i=0; i < result.get_nd(); i++)
size *= result.shape(i);
uint64_t bytes_read;
int ret = bake_read(ph, rid, offset, (void*)result.get_data(), size, &bytes_read);
if(ret != 0) return bpl::object();
if(bytes_read != size) return bpl::object();
else return result;
}
#endif
BOOST_PYTHON_MODULE(_pybakeclient) BOOST_PYTHON_MODULE(_pybakeclient)
{ {
#define ret_policy_opaque bpl::return_value_policy<bpl::return_opaque_pointer>() #define ret_policy_opaque bpl::return_value_policy<bpl::return_opaque_pointer>()
#if HAS_NUMPY
np::initialize();
#endif
bpl::import("_pybaketarget"); bpl::import("_pybaketarget");
bpl::opaque<bake_client>(); bpl::opaque<bake_client>();
bpl::opaque<bake_provider_handle>(); bpl::opaque<bake_provider_handle>();
// bpl::class_<bake_region_id_t>("bake_region_id", bpl::no_init);
bpl::def("client_init", &pybake_client_init, ret_policy_opaque); bpl::def("client_init", &pybake_client_init, ret_policy_opaque);
bpl::def("client_finalize", &bake_client_finalize); bpl::def("client_finalize", &bake_client_finalize);
bpl::def("provider_handle_create", &pybake_provider_handle_create, ret_policy_opaque); bpl::def("provider_handle_create", &pybake_provider_handle_create, ret_policy_opaque);
...@@ -158,6 +228,11 @@ BOOST_PYTHON_MODULE(_pybakeclient) ...@@ -158,6 +228,11 @@ BOOST_PYTHON_MODULE(_pybakeclient)
bpl::def("read", &pybake_read); bpl::def("read", &pybake_read);
bpl::def("remove", &bake_remove); bpl::def("remove", &bake_remove);
bpl::def("shutdown_service", &bake_shutdown_service); bpl::def("shutdown_service", &bake_shutdown_service);
#if HAS_NUMPY
bpl::def("write_numpy", &pybake_write_numpy);
bpl::def("create_write_persist_numpy", &pybake_create_write_persist_numpy);
bpl::def("read_numpy", &pybake_read_numpy);
#endif
#undef ret_policy_opaque #undef ret_policy_opaque
} }
...@@ -10,15 +10,27 @@ os.environ['OPT'] = " ".join( ...@@ -10,15 +10,27 @@ os.environ['OPT'] = " ".join(
flag for flag in opt.split() if flag != '-Wstrict-prototypes' flag for flag in opt.split() if flag != '-Wstrict-prototypes'
) )
try:
import numpy
has_numpy = 1
except ImportError:
has_numpy = 0
if has_numpy == 1:
client_libs=['boost_python','margo','bake-client', 'boost_numpy']
else:
client_libs=['boost_python','margo','bake-client']
pybake_server_module = Extension('_pybakeserver', ["pybake/src/server.cpp"], pybake_server_module = Extension('_pybakeserver', ["pybake/src/server.cpp"],
libraries=['boost_python','margo','bake-server'], libraries=['boost_python','margo','bake-server'],
include_dirs=['.'], include_dirs=['.'],
depends=[]) depends=[])
pybake_client_module = Extension('_pybakeclient', ["pybake/src/client.cpp"], pybake_client_module = Extension('_pybakeclient', ["pybake/src/client.cpp"],
libraries=['boost_python','margo','bake-client'], libraries=client_libs,
include_dirs=['.'], include_dirs=['.'],
depends=[]) depends=[],
define_macros=[('HAS_NUMPY', has_numpy)])
pybake_target_module = Extension('_pybaketarget', ["pybake/src/target.cpp"], pybake_target_module = Extension('_pybaketarget', ["pybake/src/target.cpp"],
libraries=['boost_python', 'uuid' ], libraries=['boost_python', 'uuid' ],
......
# (C) 2018 The University of Chicago
# See COPYRIGHT in top-level directory.
import sys
from pymargo import MargoInstance
from pybake.target import BakeRegionID
from pybake.client import *
import numpy as np
mid = MargoInstance('tcp')
server_addr = sys.argv[1]
mplex_id = int(sys.argv[2])
client = BakeClient(mid)
addr = mid.lookup(server_addr)
ph = client.create_provider_handle(addr, mplex_id)
# Testing get_eager_limit
lim = ph.get_eager_limit()
print "Eager limit is: "+str(lim)
# probe the provider handle (for all targets)
targets = ph.probe()
print "Probe found the following targets:"
for t in targets:
print "===== "+str(t)
target = targets[0]
# write into a region
arr = np.random.randn(5,6)
print "Writing the following numpy array: "
print str(arr)
region = ph.create_write_persist_numpy(target, arr)
# get size of region
s = ph.get_size(region)
print "Region size is "+str(s)
# read region
result = ph.read_numpy(region, 0, shape=(5,6), dtype=arr.dtype)
# check for equalit
print "Reading region gave the following numpy array: "
print str(result)
if((result == arr).all()):
print "The two arrays are equal"
else:
print "The two arrays are NOT equal"
del ph
client.shutdown_service(addr)
del addr
client.finalize()
mid.finalize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment