Commit 0b8767b7 authored by Matthieu Dorier's avatar Matthieu Dorier

added bulk

parent 520d24e6
#ifndef __THALLIUM_HPP #ifndef __THALLIUM_HPP
#define __THALLIUM_HPP #define __THALLIUM_HPP
#include <thallium/bulk_mode.hpp>
#include <thallium/bulk.hpp>
#include <thallium/engine.hpp> #include <thallium/engine.hpp>
#include <thallium/endpoint.hpp> #include <thallium/endpoint.hpp>
#include <thallium/remote_procedure.hpp> #include <thallium/remote_procedure.hpp>
#include <thallium/callable_remote_procedure.hpp> #include <thallium/callable_remote_procedure.hpp>
#include <thallium/resolved_bulk.hpp>
#endif #endif
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_BULK_HPP
#define __THALLIUM_BULK_HPP
#include <cstdint>
#include <string>
#include <vector>
#include <margo.h>
#include <thallium/endpoint.hpp>
namespace thallium {
class engine;
class resolved_bulk;
class bulk {
friend class engine;
friend class resolved_bulk;
private:
engine* m_engine;
hg_bulk_t m_bulk;
bool m_is_local;
bulk(engine& e, hg_bulk_t b, bool local)
: m_engine(&e), m_bulk(b), m_is_local(local) {}
class bulk_segment {
friend class resolved_bulk;
std::size_t m_offset;
std::size_t m_size;
const bulk& m_bulk;
public:
bulk_segment(const bulk& b)
: m_offset(0), m_size(b.size()), m_bulk(b) {}
bulk_segment(const bulk& b, std::size_t offset, std::size_t size)
: m_offset(offset), m_size(size), m_bulk(b) {}
bulk_segment(const bulk_segment&) = delete;
bulk_segment(bulk_segment&&) = default;
~bulk_segment() = default;
resolved_bulk on(const endpoint& ep) const;
};
public:
bulk()
: m_engine(nullptr), m_bulk(HG_BULK_NULL), m_is_local(false) {}
bulk(const bulk& other)
: m_engine(other.m_engine), m_bulk(other.m_bulk), m_is_local(other.m_is_local) {
margo_bulk_ref_incr(m_bulk);
}
bulk(bulk&& other)
: m_engine(other.m_engine), m_bulk(other.m_bulk), m_is_local(std::move(other.m_is_local)) {
other.m_bulk = HG_BULK_NULL;
}
bulk& operator=(const bulk& other) {
if(this == &other) return *this;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
m_bulk = other.m_bulk;
m_engine = other.m_engine;
m_is_local = other.m_is_local;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_ref_incr(m_bulk);
}
return *this;
}
bulk& operator=(bulk&& other) {
if(this == &other) return *this;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
m_engine = other.m_engine;
m_bulk = other.m_bulk;
m_is_local = other.m_is_local;
other.m_bulk = HG_BULK_NULL;
return *this;
}
~bulk() {
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
}
std::size_t size() const {
if(m_bulk != HG_BULK_NULL)
return margo_bulk_get_size(m_bulk);
else
return 0;
}
bool is_null() const {
return m_bulk == HG_BULK_NULL;
}
bulk_segment select(std::size_t offset, std::size_t size) const;
bulk_segment operator()(std::size_t offset, std::size_t size) const;
template<typename A>
void save(A& ar) {
hg_size_t s = margo_bulk_get_serialize_size(m_bulk, HG_TRUE);
std::vector<char> buf(s);
margo_bulk_serialize(&buf[0], s, HG_TRUE, m_bulk);
// XXX check return values
ar & buf;
}
template<typename A>
void load(A& ar);
};
}
#include <thallium/engine.hpp>
namespace thallium {
template<typename A>
void bulk::load(A& ar) {
std::vector<char> buf;
ar & buf;
m_engine = &(ar.get_engine());
margo_bulk_deserialize(m_engine->m_mid, &m_bulk, &buf[0], buf.size());
// XXX check return value
m_is_local = false;
}
}
#endif
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_BULK_MODE_HPP
#define __THALLIUM_BULK_MODE_HPP
#include <margo.h>
namespace thallium {
enum class bulk_mode : hg_uint32_t {
read_write = HG_BULK_READWRITE,
read_only = HG_BULK_READ_ONLY,
write_only = HG_BULK_WRITE_ONLY
};
}
#endif
...@@ -27,18 +27,19 @@ class callable_remote_procedure { ...@@ -27,18 +27,19 @@ class callable_remote_procedure {
friend class remote_procedure; friend class remote_procedure;
private: private:
engine* m_engine;
hg_handle_t m_handle; hg_handle_t m_handle;
bool m_ignore_response; bool m_ignore_response;
callable_remote_procedure(hg_id_t id, const endpoint& ep, bool ignore_resp); callable_remote_procedure(engine& e, hg_id_t id, const endpoint& ep, bool ignore_resp);
auto forward(const buffer& buf) const { auto forward(const buffer& buf) const {
margo_forward(m_handle, const_cast<void*>(static_cast<const void*>(&buf))); margo_forward(m_handle, const_cast<void*>(static_cast<const void*>(&buf)));
buffer output; buffer output;
if(m_ignore_response) return packed_response(std::move(output)); if(m_ignore_response) return packed_response(std::move(output), *m_engine);
margo_get_output(m_handle, &output); margo_get_output(m_handle, &output);
margo_free_output(m_handle, &output); // won't do anything on a buffer type margo_free_output(m_handle, &output); // won't do anything on a buffer type
return packed_response(std::move(output)); return packed_response(std::move(output), *m_engine);
} }
public: public:
...@@ -90,7 +91,7 @@ public: ...@@ -90,7 +91,7 @@ public:
template<typename ... T> template<typename ... T>
auto operator()(T&& ... t) const { auto operator()(T&& ... t) const {
buffer b; buffer b;
buffer_output_archive arch(b); buffer_output_archive arch(b, *m_engine);
serialize_many(arch, std::forward<T>(t)...); serialize_many(arch, std::forward<T>(t)...);
return forward(b); return forward(b);
} }
......
...@@ -13,11 +13,15 @@ ...@@ -13,11 +13,15 @@
namespace thallium { namespace thallium {
class engine; class engine;
class request;
class resolved_bulk;
class endpoint { class endpoint {
friend class engine; friend class engine;
friend class request;
friend class callable_remote_procedure; friend class callable_remote_procedure;
friend class resolved_bulk;
private: private:
...@@ -45,7 +49,11 @@ public: ...@@ -45,7 +49,11 @@ public:
~endpoint(); ~endpoint();
operator std::string() const; operator std::string() const;
bool is_null() const {
return m_addr == HG_ADDR_NULL;
}
}; };
} }
......
...@@ -15,15 +15,21 @@ ...@@ -15,15 +15,21 @@
#include <thallium/function_cast.hpp> #include <thallium/function_cast.hpp>
#include <thallium/buffer.hpp> #include <thallium/buffer.hpp>
#include <thallium/request.hpp> #include <thallium/request.hpp>
#include <thallium/bulk_mode.hpp>
namespace thallium { namespace thallium {
class bulk;
class endpoint; class endpoint;
class resolved_bulk;
class remote_procedure; class remote_procedure;
class engine { class engine {
friend class request;
friend class bulk;
friend class endpoint; friend class endpoint;
friend class resolved_bulk;
friend class remote_procedure; friend class remote_procedure;
friend class callable_remote_procedure; friend class callable_remote_procedure;
...@@ -35,14 +41,25 @@ private: ...@@ -35,14 +41,25 @@ private:
bool m_is_server; bool m_is_server;
std::unordered_map<hg_id_t, rpc_t> m_rpcs; std::unordered_map<hg_id_t, rpc_t> m_rpcs;
struct rpc_callback_data {
engine* m_engine;
void* m_function;
};
static void free_rpc_callback_data(void* data) {
rpc_callback_data* cb_data = (rpc_callback_data*)data;
delete cb_data;
}
template<typename F, bool disable_response> template<typename F, bool disable_response>
static void rpc_handler_ult(hg_handle_t handle) { static void rpc_handler_ult(hg_handle_t handle) {
using G = std::remove_reference_t<F>; using G = std::remove_reference_t<F>;
const struct hg_info* info = margo_get_info(handle); const struct hg_info* info = margo_get_info(handle);
margo_instance_id mid = margo_hg_handle_get_instance(handle); margo_instance_id mid = margo_hg_handle_get_instance(handle);
void* data = margo_registered_data(mid, info->id); void* data = margo_registered_data(mid, info->id);
auto f = function_cast<G>(data); auto cb_data = static_cast<rpc_callback_data*>(data);
request req(handle, disable_response); auto f = function_cast<G>(cb_data->m_function);
request req(*(cb_data->m_engine), handle, disable_response);
buffer input; buffer input;
margo_get_input(handle, &input); margo_get_input(handle, &input);
(*f)(req, input); (*f)(req, input);
...@@ -115,6 +132,8 @@ public: ...@@ -115,6 +132,8 @@ public:
endpoint lookup(const std::string& address) const; endpoint lookup(const std::string& address) const;
bulk expose(const std::vector<std::pair<void*,size_t>>& segments, bulk_mode flag);
operator std::string() const; operator std::string() const;
}; };
...@@ -138,19 +157,23 @@ remote_procedure engine::define(const std::string& name, ...@@ -138,19 +157,23 @@ remote_procedure engine::define(const std::string& name,
process_buffer, process_buffer,
rpc_callback<rpc_t, false>); rpc_callback<rpc_t, false>);
m_rpcs[id] = [fun](const request& r, const buffer& b) { m_rpcs[id] = [fun,this](const request& r, const buffer& b) {
std::function<void(Args...)> l = [&fun, &r](Args&&... args) { std::function<void(Args...)> l = [&fun, &r](Args&&... args) {
fun(r, std::forward<Args>(args)...); fun(r, std::forward<Args>(args)...);
}; };
std::tuple<std::decay_t<Args>...> iargs; std::tuple<std::decay_t<Args>...> iargs;
if(sizeof...(Args) > 0) { if(sizeof...(Args) > 0) {
buffer_input_archive iarch(b); buffer_input_archive iarch(b, *this);
iarch & iargs; iarch & iargs;
} }
apply_function_to_tuple(l,iargs); apply_function_to_tuple(l,iargs);
}; };
margo_register_data(m_mid, id, void_cast(&m_rpcs[id]), nullptr); rpc_callback_data* cb_data = new rpc_callback_data;
cb_data->m_engine = this;
cb_data->m_function = void_cast(&m_rpcs[id]);
margo_register_data(m_mid, id, (void*)cb_data, free_rpc_callback_data);
return remote_procedure(*this, id); return remote_procedure(*this, id);
} }
......
...@@ -20,17 +20,18 @@ class packed_response { ...@@ -20,17 +20,18 @@ class packed_response {
private: private:
buffer m_buffer; engine* m_engine;
buffer m_buffer;
packed_response(buffer&& b) packed_response(buffer&& b, engine& e)
: m_buffer(std::move(b)) {} : m_engine(&e), m_buffer(std::move(b)) {}
public: public:
template<typename T> template<typename T>
T as() const { T as() const {
T t; T t;
buffer_input_archive iarch(m_buffer); buffer_input_archive iarch(m_buffer, *m_engine);
iarch & t; iarch & t;
return t; return t;
} }
......
...@@ -19,7 +19,7 @@ class remote_procedure { ...@@ -19,7 +19,7 @@ class remote_procedure {
friend class engine; friend class engine;
private: private:
engine& m_engine; engine* m_engine;
hg_id_t m_id; hg_id_t m_id;
bool m_ignore_response; bool m_ignore_response;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
namespace thallium { namespace thallium {
class engine; class engine;
class endpoint;
class request { class request {
...@@ -20,28 +21,30 @@ class request { ...@@ -20,28 +21,30 @@ class request {
private: private:
engine* m_engine;
hg_handle_t m_handle; hg_handle_t m_handle;
bool m_disable_response; bool m_disable_response;
request(hg_handle_t h, bool disable_resp) request(engine& e, hg_handle_t h, bool disable_resp)
: m_handle(h), m_disable_response(disable_resp) {} : m_engine(&e), m_handle(h), m_disable_response(disable_resp) {}
public: public:
request(const request& other) request(const request& other)
: m_handle(other.m_handle), m_disable_response(other.m_disable_response) { : m_engine(other.m_engine), m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
margo_ref_incr(m_handle); margo_ref_incr(m_handle);
} }
request(request&& other) request(request&& other)
: m_handle(other.m_handle), m_disable_response(other.m_disable_response) { : m_engine(other.m_engine), m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
other.m_handle = HG_HANDLE_NULL; other.m_handle = HG_HANDLE_NULL;
} }
request& operator=(const request& other) { request& operator=(const request& other) {
if(m_handle == other.m_handle) return *this; if(m_handle == other.m_handle) return *this;
margo_destroy(m_handle); margo_destroy(m_handle);
m_handle = other.m_handle; m_engine = other.m_engine;
m_handle = other.m_handle;
m_disable_response = other.m_disable_response; m_disable_response = other.m_disable_response;
margo_ref_incr(m_handle); margo_ref_incr(m_handle);
return *this; return *this;
...@@ -50,7 +53,8 @@ public: ...@@ -50,7 +53,8 @@ public:
request& operator=(request&& other) { request& operator=(request&& other) {
if(m_handle == other.m_handle) return *this; if(m_handle == other.m_handle) return *this;
margo_destroy(m_handle); margo_destroy(m_handle);
m_handle = other.m_handle; m_engine = other.m_engine;
m_handle = other.m_handle;
m_disable_response = other.m_disable_response; m_disable_response = other.m_disable_response;
other.m_handle = HG_HANDLE_NULL; other.m_handle = HG_HANDLE_NULL;
return *this; return *this;
...@@ -65,27 +69,13 @@ public: ...@@ -65,27 +69,13 @@ public:
if(m_disable_response) return; // XXX throwing an exception? if(m_disable_response) return; // XXX throwing an exception?
if(m_handle != HG_HANDLE_NULL) { if(m_handle != HG_HANDLE_NULL) {
buffer b; buffer b;
buffer_output_archive arch(b); buffer_output_archive arch(b, *m_engine);
serialize_many(arch, std::forward<T>(t)...); serialize_many(arch, std::forward<T>(t)...);
margo_respond(m_handle, &b); margo_respond(m_handle, &b);
} }
} }
/*
void respond(const buffer& output) const { endpoint get_endpoint() const;
if(m_disable_response) return; // XXX throwing an exception?
if(m_handle != HG_HANDLE_NULL) {
margo_respond(m_handle, const_cast<void*>(static_cast<const void*>(&output)));
}
}
void respond(buffer& output) const {
respond((const buffer&)output);
}
void respond(buffer&& output) const {
respond((const buffer&)output);
}
*/
}; };
} }
......
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_RESOLVED_BULK_HPP
#define __THALLIUM_RESOLVED_BULK_HPP
#include <cstdint>
#include <string>
#include <vector>
#include <margo.h>
#include <thallium/bulk.hpp>
namespace thallium {
class resolved_bulk {
friend class bulk;
private:
const bulk::bulk_segment& m_segment;
endpoint m_endpoint;
resolved_bulk(const bulk::bulk_segment& b, const endpoint& ep)
: m_segment(b), m_endpoint(ep) {}
public:
std::size_t operator>>(const bulk::bulk_segment& dest) const;
std::size_t operator<<(const bulk::bulk_segment& src) const;
};
}
#endif
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
namespace thallium { namespace thallium {
class engine;
/** /**
* buffer_input_archive wraps a buffer object and * buffer_input_archive wraps a buffer object and
* offers the functionalities to deserialize its content * offers the functionalities to deserialize its content
...@@ -25,8 +27,9 @@ class buffer_input_archive : public input_archive { ...@@ -25,8 +27,9 @@ class buffer_input_archive : public input_archive {
private: private:
const buffer& buffer_; const buffer& m_buffer;
std::size_t pos; std::size_t m_pos;
engine* m_engine;
template<typename T, bool b> template<typename T, bool b>
inline void read_impl(T&& t, const std::integral_constant<bool, b>&) { inline void read_impl(T&& t, const std::integral_constant<bool, b>&) {
...@@ -48,7 +51,11 @@ public: ...@@ -48,7 +51,11 @@ public:
* the buffer_input_archive instance should be shorter than that * the buffer_input_archive instance should be shorter than that
* of the buffer. * of the buffer.
*/ */
buffer_input_archive(const buffer& b) : buffer_(b), pos(0) {} buffer_input_archive(const buffer& b, engine& e)
: m_buffer(b), m_pos(0), m_engine(&e) {}
buffer_input_archive(const buffer& b)
: m_buffer(b), m_pos(0), m_engine(nullptr) {}
/** /**
* Operator to get C++ objects of type T from the archive. * Operator to get C++ objects of type T from the archive.
...@@ -81,12 +88,16 @@ public: ...@@ -81,12 +88,16 @@ public:
*/ */
template<typename T> template<typename T>
inline void read(T* t, std::size_t count=1) { inline void read(T* t, std::size_t count=1) {
if(pos + count*sizeof(T) > buffer_.size()) { if(m_pos + count*sizeof(T) > m_buffer.size()) {
throw std::runtime_error("Reading beyond buffer size"); throw std::runtime_error("Reading beyond buffer size");
} }
std::memcpy((void*)t,(const void*)(&buffer_[pos]),count*sizeof(T)); std::memcpy((void*)t,(const void*)(&m_buffer[m_pos]),count*sizeof(T));
pos += count*sizeof(T); m_pos += count*sizeof(T);
} }
engine& get_engine() const {
return *m_engine;
}
}; };
} }
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
namespace thallium { namespace thallium {
class engine;
/** /**
* buffer_output_archive wraps and hg::buffer object and * buffer_output_archive wraps and hg::buffer object and
* offers the functionalities to serialize C++ objects into * offers the functionalities to serialize C++ objects into
...@@ -23,8 +25,9 @@ class buffer_output_archive : public output_archive { ...@@ -23,8 +25,9 @@ class buffer_output_archive : public output_archive {
private: private:
buffer& buffer_; buffer& m_buffer;
std::size_t pos; std::size_t m_pos;
engine* m_engine;
template<typename T, bool b> template<typename T, bool b>
inline void write_impl(T&& t, const std::integral_constant<bool, b>&) {