From 0b9c595ba855f61fe88d2f0ad30056983a261a63 Mon Sep 17 00:00:00 2001 From: Matthieu Dorier Date: Sat, 12 May 2018 02:17:03 +0200 Subject: [PATCH] added numpy support --- pybake/client.py | 48 ++++++++++++++++++++++++++ pybake/src/client.cpp | 79 +++++++++++++++++++++++++++++++++++++++++-- setup.py | 16 +++++++-- test/numpy_client.py | 55 ++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 test/numpy_client.py diff --git a/pybake/client.py b/pybake/client.py index 9056074..e160a67 100644 --- a/pybake/client.py +++ b/pybake/client.py @@ -128,6 +128,17 @@ class BakeProviderHandle(): """ return _pybakeclient.write(self._ph, rid._rid, offset, data) + def write_numpy(self, rid, offset, array): + """ + Writes a numpy array in a region at a specified offset. + + Args: + rid (BakeRegionID): region in which to write. + offset (int): offset at which to write. + data (numpy.ndarray): numpy array to write. + """ + return _pybakeclient.write_numpy(self._ph, rid._rid, offset, array) + def persist(self, rid): """ Make the changes to a given region persist. @@ -149,6 +160,24 @@ class BakeProviderHandle(): data (str): data to write. """ rid = _pybakeclient.create_write_persist(self._ph, bti._tid, data) + if(rid is None): + return None + return BakeRegionID(rid) + + def create_write_persist_numpy(self, bti, array): + """ + Creates a new region, write the numpy array to it at a given offset, + and persist the region. + + Args: + bti (BakeTargetID): target id in which to create the region. + size (int): size of the region to create. + offset (int): offset at which to write data in the region. + array (numpy.ndarray): numpy array to write. + """ + rid = _pybakeclient.create_write_persist_numpy(self._ph, bti._tid, array) + if(rid is None): + return None return BakeRegionID(rid) def get_size(self, rid): @@ -184,6 +213,25 @@ class BakeProviderHandle(): size = self.get_size(rid) - offset return _pybakeclient.read(self._ph, rid._rid, offset, size) + def read_numpy(self, rid, offset, shape, dtype): + """ + Reads the data contained in a given region, at a given offset, + and interpret it as a numpy array of a given shape and datatype. + This function will fail if the full array cannot be loaded + (e.g. the size of the region from the provided offset is too small + compared with the size of the numpy that should result from the call) + + Args: + rid (BakeRegionID): region id. + offset (int): offset at which to read. + shape (tuple): shape of the resulting array. + dtype (numpy.dtype): datatype of the resuling array. + + Returns: + A numpy array or None if it could not be read. + """ + return _pybakeclient.read_numpy(self._ph, rid._rid, offset, shape, dtype) + def remove(self, rid): """ Remove a region from its target. diff --git a/pybake/src/client.cpp b/pybake/src/client.cpp index cf28228..e1fd19b 100644 --- a/pybake/src/client.cpp +++ b/pybake/src/client.cpp @@ -18,6 +18,10 @@ #include #include #include +#if HAS_NUMPY +#include +namespace np = boost::python::numpy; +#endif BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(margo_instance) BOOST_PYTHON_OPAQUE_SPECIALIZED_TYPE_ID(bake_provider_handle) @@ -89,6 +93,28 @@ static bpl::object pybake_write( else return bpl::object(false); } +#if HAS_NUMPY +static bpl::object pybake_write_numpy( + bake_provider_handle_t ph, + const bake_region_id_t& rid, + uint64_t offset, + const np::ndarray& data) +{ + if(!(data.get_flags() & np::ndarray::bitflag::V_CONTIGUOUS)) { + std::cerr << "[pyBAKE error]: non-contiguous numpy arrays not yet supported" << std::endl; + return bpl::object(false); + } + size_t size = data.get_dtype().get_itemsize(); + for(int i = 0; i < data.get_nd(); i++) { + size *= data.shape(i); + } + void* buffer = data.get_data(); + int ret = bake_write(ph, rid, offset, buffer, size); + if(ret != 0) return bpl::object(false); + else return bpl::object(true); +} +#endif + static bpl::object pybake_persist( bake_provider_handle_t ph, const bake_region_id_t& rid) @@ -110,6 +136,29 @@ static bpl::object pybake_create_write_persist( else return bpl::object(); } +#if HAS_NUMPY +static bpl::object pybake_create_write_persist_numpy( + bake_provider_handle_t ph, + bake_target_id_t tid, + const np::ndarray& data) +{ + bake_region_id_t rid; + if(!(data.get_flags() & np::ndarray::bitflag::V_CONTIGUOUS)) { + std::cerr << "[pyBAKE error]: non-contiguous numpy arrays not yet supported" << std::endl; + return bpl::object(); + } + size_t size = data.get_dtype().get_itemsize(); + for(int i = 0; i < data.get_nd(); i++) { + size *= data.shape(i); + } + void* buffer = data.get_data(); + int ret = bake_create_write_persist(ph, tid, + buffer, size, &rid); + if(ret == 0) return bpl::object(rid); + else return bpl::object(); +} +#endif + static bpl::object pybake_get_size( bake_provider_handle_t ph, const bake_region_id_t& rid) @@ -134,14 +183,35 @@ static bpl::object pybake_read( return bpl::object(result); } +#if HAS_NUMPY +static bpl::object pybake_read_numpy( + bake_provider_handle_t ph, + const bake_region_id_t& rid, + uint64_t offset, + const bpl::tuple& shape, + const np::dtype& dtype) +{ + np::ndarray result = np::empty(shape, dtype); + size_t size = dtype.get_itemsize(); + for(int i=0; i < result.get_nd(); i++) + size *= result.shape(i); + uint64_t bytes_read; + int ret = bake_read(ph, rid, offset, (void*)result.get_data(), size, &bytes_read); + if(ret != 0) return bpl::object(); + if(bytes_read != size) return bpl::object(); + else return result; +} +#endif + BOOST_PYTHON_MODULE(_pybakeclient) { #define ret_policy_opaque bpl::return_value_policy() - +#if HAS_NUMPY + np::initialize(); +#endif bpl::import("_pybaketarget"); bpl::opaque(); bpl::opaque(); -// bpl::class_("bake_region_id", bpl::no_init); bpl::def("client_init", &pybake_client_init, ret_policy_opaque); bpl::def("client_finalize", &bake_client_finalize); bpl::def("provider_handle_create", &pybake_provider_handle_create, ret_policy_opaque); @@ -158,6 +228,11 @@ BOOST_PYTHON_MODULE(_pybakeclient) bpl::def("read", &pybake_read); bpl::def("remove", &bake_remove); bpl::def("shutdown_service", &bake_shutdown_service); +#if HAS_NUMPY + bpl::def("write_numpy", &pybake_write_numpy); + bpl::def("create_write_persist_numpy", &pybake_create_write_persist_numpy); + bpl::def("read_numpy", &pybake_read_numpy); +#endif #undef ret_policy_opaque } diff --git a/setup.py b/setup.py index 75d55cb..856df67 100644 --- a/setup.py +++ b/setup.py @@ -10,15 +10,27 @@ os.environ['OPT'] = " ".join( flag for flag in opt.split() if flag != '-Wstrict-prototypes' ) +try: + import numpy + has_numpy = 1 +except ImportError: + has_numpy = 0 + +if has_numpy == 1: + client_libs=['boost_python','margo','bake-client', 'boost_numpy'] +else: + client_libs=['boost_python','margo','bake-client'] + pybake_server_module = Extension('_pybakeserver', ["pybake/src/server.cpp"], libraries=['boost_python','margo','bake-server'], include_dirs=['.'], depends=[]) pybake_client_module = Extension('_pybakeclient', ["pybake/src/client.cpp"], - libraries=['boost_python','margo','bake-client'], + libraries=client_libs, include_dirs=['.'], - depends=[]) + depends=[], + define_macros=[('HAS_NUMPY', has_numpy)]) pybake_target_module = Extension('_pybaketarget', ["pybake/src/target.cpp"], libraries=['boost_python', 'uuid' ], diff --git a/test/numpy_client.py b/test/numpy_client.py new file mode 100644 index 0000000..1673f87 --- /dev/null +++ b/test/numpy_client.py @@ -0,0 +1,55 @@ +# (C) 2018 The University of Chicago +# See COPYRIGHT in top-level directory. +import sys +from pymargo import MargoInstance +from pybake.target import BakeRegionID +from pybake.client import * +import numpy as np + +mid = MargoInstance('tcp') + +server_addr = sys.argv[1] +mplex_id = int(sys.argv[2]) + +client = BakeClient(mid) +addr = mid.lookup(server_addr) +ph = client.create_provider_handle(addr, mplex_id) + +# Testing get_eager_limit +lim = ph.get_eager_limit() +print "Eager limit is: "+str(lim) + +# probe the provider handle (for all targets) +targets = ph.probe() +print "Probe found the following targets:" +for t in targets: + print "===== "+str(t) + +target = targets[0] + +# write into a region +arr = np.random.randn(5,6) +print "Writing the following numpy array: " +print str(arr) +region = ph.create_write_persist_numpy(target, arr) + +# get size of region +s = ph.get_size(region) +print "Region size is "+str(s) + +# read region +result = ph.read_numpy(region, 0, shape=(5,6), dtype=arr.dtype) +# check for equalit +print "Reading region gave the following numpy array: " +print str(result) + +if((result == arr).all()): + print "The two arrays are equal" +else: + print "The two arrays are NOT equal" + +del ph +client.shutdown_service(addr) +del addr +client.finalize() +mid.finalize() -- 2.26.2