Commit 0b8767b7 authored by Matthieu Dorier's avatar Matthieu Dorier

added bulk

parent 520d24e6
#ifndef __THALLIUM_HPP
#define __THALLIUM_HPP
#include <thallium/bulk_mode.hpp>
#include <thallium/bulk.hpp>
#include <thallium/engine.hpp>
#include <thallium/endpoint.hpp>
#include <thallium/remote_procedure.hpp>
#include <thallium/callable_remote_procedure.hpp>
#include <thallium/resolved_bulk.hpp>
#endif
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_BULK_HPP
#define __THALLIUM_BULK_HPP
#include <cstdint>
#include <string>
#include <vector>
#include <margo.h>
#include <thallium/endpoint.hpp>
namespace thallium {
class engine;
class resolved_bulk;
class bulk {
friend class engine;
friend class resolved_bulk;
private:
engine* m_engine;
hg_bulk_t m_bulk;
bool m_is_local;
bulk(engine& e, hg_bulk_t b, bool local)
: m_engine(&e), m_bulk(b), m_is_local(local) {}
class bulk_segment {
friend class resolved_bulk;
std::size_t m_offset;
std::size_t m_size;
const bulk& m_bulk;
public:
bulk_segment(const bulk& b)
: m_offset(0), m_size(b.size()), m_bulk(b) {}
bulk_segment(const bulk& b, std::size_t offset, std::size_t size)
: m_offset(offset), m_size(size), m_bulk(b) {}
bulk_segment(const bulk_segment&) = delete;
bulk_segment(bulk_segment&&) = default;
~bulk_segment() = default;
resolved_bulk on(const endpoint& ep) const;
};
public:
bulk()
: m_engine(nullptr), m_bulk(HG_BULK_NULL), m_is_local(false) {}
bulk(const bulk& other)
: m_engine(other.m_engine), m_bulk(other.m_bulk), m_is_local(other.m_is_local) {
margo_bulk_ref_incr(m_bulk);
}
bulk(bulk&& other)
: m_engine(other.m_engine), m_bulk(other.m_bulk), m_is_local(std::move(other.m_is_local)) {
other.m_bulk = HG_BULK_NULL;
}
bulk& operator=(const bulk& other) {
if(this == &other) return *this;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
m_bulk = other.m_bulk;
m_engine = other.m_engine;
m_is_local = other.m_is_local;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_ref_incr(m_bulk);
}
return *this;
}
bulk& operator=(bulk&& other) {
if(this == &other) return *this;
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
m_engine = other.m_engine;
m_bulk = other.m_bulk;
m_is_local = other.m_is_local;
other.m_bulk = HG_BULK_NULL;
return *this;
}
~bulk() {
if(m_bulk != HG_BULK_NULL) {
margo_bulk_free(m_bulk);
}
}
std::size_t size() const {
if(m_bulk != HG_BULK_NULL)
return margo_bulk_get_size(m_bulk);
else
return 0;
}
bool is_null() const {
return m_bulk == HG_BULK_NULL;
}
bulk_segment select(std::size_t offset, std::size_t size) const;
bulk_segment operator()(std::size_t offset, std::size_t size) const;
template<typename A>
void save(A& ar) {
hg_size_t s = margo_bulk_get_serialize_size(m_bulk, HG_TRUE);
std::vector<char> buf(s);
margo_bulk_serialize(&buf[0], s, HG_TRUE, m_bulk);
// XXX check return values
ar & buf;
}
template<typename A>
void load(A& ar);
};
}
#include <thallium/engine.hpp>
namespace thallium {
template<typename A>
void bulk::load(A& ar) {
std::vector<char> buf;
ar & buf;
m_engine = &(ar.get_engine());
margo_bulk_deserialize(m_engine->m_mid, &m_bulk, &buf[0], buf.size());
// XXX check return value
m_is_local = false;
}
}
#endif
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_BULK_MODE_HPP
#define __THALLIUM_BULK_MODE_HPP
#include <margo.h>
namespace thallium {
enum class bulk_mode : hg_uint32_t {
read_write = HG_BULK_READWRITE,
read_only = HG_BULK_READ_ONLY,
write_only = HG_BULK_WRITE_ONLY
};
}
#endif
......@@ -27,18 +27,19 @@ class callable_remote_procedure {
friend class remote_procedure;
private:
engine* m_engine;
hg_handle_t m_handle;
bool m_ignore_response;
callable_remote_procedure(hg_id_t id, const endpoint& ep, bool ignore_resp);
callable_remote_procedure(engine& e, hg_id_t id, const endpoint& ep, bool ignore_resp);
auto forward(const buffer& buf) const {
margo_forward(m_handle, const_cast<void*>(static_cast<const void*>(&buf)));
buffer output;
if(m_ignore_response) return packed_response(std::move(output));
if(m_ignore_response) return packed_response(std::move(output), *m_engine);
margo_get_output(m_handle, &output);
margo_free_output(m_handle, &output); // won't do anything on a buffer type
return packed_response(std::move(output));
return packed_response(std::move(output), *m_engine);
}
public:
......@@ -90,7 +91,7 @@ public:
template<typename ... T>
auto operator()(T&& ... t) const {
buffer b;
buffer_output_archive arch(b);
buffer_output_archive arch(b, *m_engine);
serialize_many(arch, std::forward<T>(t)...);
return forward(b);
}
......
......@@ -13,11 +13,15 @@
namespace thallium {
class engine;
class request;
class resolved_bulk;
class endpoint {
friend class engine;
friend class request;
friend class callable_remote_procedure;
friend class resolved_bulk;
private:
......@@ -45,7 +49,11 @@ public:
~endpoint();
operator std::string() const;
operator std::string() const;
bool is_null() const {
return m_addr == HG_ADDR_NULL;
}
};
}
......
......@@ -15,15 +15,21 @@
#include <thallium/function_cast.hpp>
#include <thallium/buffer.hpp>
#include <thallium/request.hpp>
#include <thallium/bulk_mode.hpp>
namespace thallium {
class bulk;
class endpoint;
class resolved_bulk;
class remote_procedure;
class engine {
friend class request;
friend class bulk;
friend class endpoint;
friend class resolved_bulk;
friend class remote_procedure;
friend class callable_remote_procedure;
......@@ -35,14 +41,25 @@ private:
bool m_is_server;
std::unordered_map<hg_id_t, rpc_t> m_rpcs;
struct rpc_callback_data {
engine* m_engine;
void* m_function;
};
static void free_rpc_callback_data(void* data) {
rpc_callback_data* cb_data = (rpc_callback_data*)data;
delete cb_data;
}
template<typename F, bool disable_response>
static void rpc_handler_ult(hg_handle_t handle) {
using G = std::remove_reference_t<F>;
const struct hg_info* info = margo_get_info(handle);
margo_instance_id mid = margo_hg_handle_get_instance(handle);
void* data = margo_registered_data(mid, info->id);
auto f = function_cast<G>(data);
request req(handle, disable_response);
auto cb_data = static_cast<rpc_callback_data*>(data);
auto f = function_cast<G>(cb_data->m_function);
request req(*(cb_data->m_engine), handle, disable_response);
buffer input;
margo_get_input(handle, &input);
(*f)(req, input);
......@@ -115,6 +132,8 @@ public:
endpoint lookup(const std::string& address) const;
bulk expose(const std::vector<std::pair<void*,size_t>>& segments, bulk_mode flag);
operator std::string() const;
};
......@@ -138,19 +157,23 @@ remote_procedure engine::define(const std::string& name,
process_buffer,
rpc_callback<rpc_t, false>);
m_rpcs[id] = [fun](const request& r, const buffer& b) {
m_rpcs[id] = [fun,this](const request& r, const buffer& b) {
std::function<void(Args...)> l = [&fun, &r](Args&&... args) {
fun(r, std::forward<Args>(args)...);
};
std::tuple<std::decay_t<Args>...> iargs;
if(sizeof...(Args) > 0) {
buffer_input_archive iarch(b);
buffer_input_archive iarch(b, *this);
iarch & iargs;
}
apply_function_to_tuple(l,iargs);
};
margo_register_data(m_mid, id, void_cast(&m_rpcs[id]), nullptr);
rpc_callback_data* cb_data = new rpc_callback_data;
cb_data->m_engine = this;
cb_data->m_function = void_cast(&m_rpcs[id]);
margo_register_data(m_mid, id, (void*)cb_data, free_rpc_callback_data);
return remote_procedure(*this, id);
}
......
......@@ -20,17 +20,18 @@ class packed_response {
private:
buffer m_buffer;
engine* m_engine;
buffer m_buffer;
packed_response(buffer&& b)
: m_buffer(std::move(b)) {}
packed_response(buffer&& b, engine& e)
: m_engine(&e), m_buffer(std::move(b)) {}
public:
template<typename T>
T as() const {
T t;
buffer_input_archive iarch(m_buffer);
buffer_input_archive iarch(m_buffer, *m_engine);
iarch & t;
return t;
}
......
......@@ -19,7 +19,7 @@ class remote_procedure {
friend class engine;
private:
engine& m_engine;
engine* m_engine;
hg_id_t m_id;
bool m_ignore_response;
......
......@@ -13,6 +13,7 @@
namespace thallium {
class engine;
class endpoint;
class request {
......@@ -20,28 +21,30 @@ class request {
private:
engine* m_engine;
hg_handle_t m_handle;
bool m_disable_response;
request(hg_handle_t h, bool disable_resp)
: m_handle(h), m_disable_response(disable_resp) {}
request(engine& e, hg_handle_t h, bool disable_resp)
: m_engine(&e), m_handle(h), m_disable_response(disable_resp) {}
public:
request(const request& other)
: m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
: m_engine(other.m_engine), m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
margo_ref_incr(m_handle);
}
request(request&& other)
: m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
request(request&& other)
: m_engine(other.m_engine), m_handle(other.m_handle), m_disable_response(other.m_disable_response) {
other.m_handle = HG_HANDLE_NULL;
}
request& operator=(const request& other) {
if(m_handle == other.m_handle) return *this;
margo_destroy(m_handle);
m_handle = other.m_handle;
m_engine = other.m_engine;
m_handle = other.m_handle;
m_disable_response = other.m_disable_response;
margo_ref_incr(m_handle);
return *this;
......@@ -50,7 +53,8 @@ public:
request& operator=(request&& other) {
if(m_handle == other.m_handle) return *this;
margo_destroy(m_handle);
m_handle = other.m_handle;
m_engine = other.m_engine;
m_handle = other.m_handle;
m_disable_response = other.m_disable_response;
other.m_handle = HG_HANDLE_NULL;
return *this;
......@@ -65,27 +69,13 @@ public:
if(m_disable_response) return; // XXX throwing an exception?
if(m_handle != HG_HANDLE_NULL) {
buffer b;
buffer_output_archive arch(b);
buffer_output_archive arch(b, *m_engine);
serialize_many(arch, std::forward<T>(t)...);
margo_respond(m_handle, &b);
}
}
/*
void respond(const buffer& output) const {
if(m_disable_response) return; // XXX throwing an exception?
if(m_handle != HG_HANDLE_NULL) {
margo_respond(m_handle, const_cast<void*>(static_cast<const void*>(&output)));
}
}
void respond(buffer& output) const {
respond((const buffer&)output);
}
void respond(buffer&& output) const {
respond((const buffer&)output);
}
*/
endpoint get_endpoint() const;
};
}
......
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#ifndef __THALLIUM_RESOLVED_BULK_HPP
#define __THALLIUM_RESOLVED_BULK_HPP
#include <cstdint>
#include <string>
#include <vector>
#include <margo.h>
#include <thallium/bulk.hpp>
namespace thallium {
class resolved_bulk {
friend class bulk;
private:
const bulk::bulk_segment& m_segment;
endpoint m_endpoint;
resolved_bulk(const bulk::bulk_segment& b, const endpoint& ep)
: m_segment(b), m_endpoint(ep) {}
public:
std::size_t operator>>(const bulk::bulk_segment& dest) const;
std::size_t operator<<(const bulk::bulk_segment& src) const;
};
}
#endif
......@@ -14,6 +14,8 @@
namespace thallium {
class engine;
/**
* buffer_input_archive wraps a buffer object and
* offers the functionalities to deserialize its content
......@@ -25,8 +27,9 @@ class buffer_input_archive : public input_archive {
private:
const buffer& buffer_;
std::size_t pos;
const buffer& m_buffer;
std::size_t m_pos;
engine* m_engine;
template<typename T, bool b>
inline void read_impl(T&& t, const std::integral_constant<bool, b>&) {
......@@ -48,7 +51,11 @@ public:
* the buffer_input_archive instance should be shorter than that
* of the buffer.
*/
buffer_input_archive(const buffer& b) : buffer_(b), pos(0) {}
buffer_input_archive(const buffer& b, engine& e)
: m_buffer(b), m_pos(0), m_engine(&e) {}
buffer_input_archive(const buffer& b)
: m_buffer(b), m_pos(0), m_engine(nullptr) {}
/**
* Operator to get C++ objects of type T from the archive.
......@@ -81,12 +88,16 @@ public:
*/
template<typename T>
inline void read(T* t, std::size_t count=1) {
if(pos + count*sizeof(T) > buffer_.size()) {
if(m_pos + count*sizeof(T) > m_buffer.size()) {
throw std::runtime_error("Reading beyond buffer size");
}
std::memcpy((void*)t,(const void*)(&buffer_[pos]),count*sizeof(T));
pos += count*sizeof(T);
std::memcpy((void*)t,(const void*)(&m_buffer[m_pos]),count*sizeof(T));
m_pos += count*sizeof(T);
}
engine& get_engine() const {
return *m_engine;
}
};
}
......
......@@ -12,6 +12,8 @@
namespace thallium {
class engine;
/**
* buffer_output_archive wraps and hg::buffer object and
* offers the functionalities to serialize C++ objects into
......@@ -23,8 +25,9 @@ class buffer_output_archive : public output_archive {
private:
buffer& buffer_;
std::size_t pos;
buffer& m_buffer;
std::size_t m_pos;
engine* m_engine;
template<typename T, bool b>
inline void write_impl(T&& t, const std::integral_constant<bool, b>&) {
......@@ -46,10 +49,16 @@ public:
* of the buffer_output_archive instance should be shorter than
* that of the buffer itself.
*/
buffer_output_archive(buffer& b) : buffer_(b), pos(0) {
buffer_.resize(0);
buffer_output_archive(buffer& b, engine& e)
: m_buffer(b), m_pos(0), m_engine(&e) {
m_buffer.resize(0);
}
buffer_output_archive(buffer& b)
: m_buffer(b), m_pos(0), m_engine(nullptr) {
m_buffer.resize(0);
}
/**
* Operator to add a C++ object of type T into the archive.
* The object should either be a basic type, or an STL container
......@@ -82,14 +91,14 @@ public:
template<typename T>
inline void write(T* const t, size_t count=1) {
size_t s = count*sizeof(T);
if(pos+s > buffer_.size()) {
if(pos+s > buffer_.capacity()) {
buffer_.reserve(buffer_.capacity()*2);
if(m_pos+s > m_buffer.size()) {
if(m_pos+s > m_buffer.capacity()) {
m_buffer.reserve(m_buffer.capacity()*2);
}
buffer_.resize(pos+s);
m_buffer.resize(m_pos+s);
}
memcpy((void*)(&buffer_[pos]),(void*)t,s);
pos += s;
memcpy((void*)(&m_buffer[m_pos]),(void*)t,s);
m_pos += s;
}
};
......
......@@ -4,9 +4,12 @@
#
# list of source files
set(thallium-src endpoint.cpp
set(thallium-src bulk.cpp
endpoint.cpp
engine.cpp
remote_procedure.cpp
request.cpp
resolved_bulk.cpp
callable_remote_procedure.cpp
proc_buffer.cpp)
......
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#include <thallium/bulk.hpp>
#include <thallium/resolved_bulk.hpp>
namespace thallium {
bulk::bulk_segment bulk::select(std::size_t offset, std::size_t size) const {
return bulk_segment(*this, offset, size);
}
bulk::bulk_segment bulk::operator()(std::size_t offset, std::size_t size) const {
return select(offset, size);
}
resolved_bulk bulk::bulk_segment::on(const endpoint& ep) const {
return resolved_bulk(*this, ep);
}
}
......@@ -10,7 +10,8 @@
namespace thallium {
callable_remote_procedure::callable_remote_procedure(hg_id_t id, const endpoint& ep, bool ignore_resp) {
callable_remote_procedure::callable_remote_procedure(engine& e, hg_id_t id, const endpoint& ep, bool ignore_resp)
: m_engine(&e) {
m_ignore_response = ignore_resp;
// TODO throw exception if this call fails
margo_create(ep.m_engine->m_mid, ep.m_addr, id, &m_handle);
......
......@@ -8,6 +8,7 @@
#include <thallium/remote_procedure.hpp>
#include <thallium/engine.hpp>
#include <thallium/endpoint.hpp>
#include <thallium/bulk.hpp>
namespace thallium {
......@@ -35,5 +36,19 @@ remote_procedure engine::define(const std::string& name) {
return remote_procedure(*this, id);
}
bulk engine::expose(const std::vector<std::pair<void*,size_t>>& segments, bulk_mode flag) {
hg_bulk_t handle;
hg_uint32_t count = segments.size();
std::vector<void*> buf_ptrs(count);
std::vector<hg_size_t> buf_sizes(count);
for(unsigned i=0; i < segments.size(); i++) {
buf_ptrs[i] = segments[i].first;
buf_sizes[i] = segments[i].second;
}
hg_return_t ret = margo_bulk_create(m_mid, count, &buf_ptrs[0], &buf_sizes[0], (hg_uint32_t)flag, &handle);
// TODO throw an exception if ret != HG_SUCCESS
return bulk(*this, handle, true);
}
}
......@@ -10,15 +10,15 @@
namespace thallium {
remote_procedure::remote_procedure(engine& e, hg_id_t id)
: m_engine(e), m_id(id), m_ignore_response(false) { }
: m_engine(&e), m_id(id), m_ignore_response(false) { }
callable_remote_procedure remote_procedure::on(const endpoint& ep) const {
return callable_remote_procedure(m_id, ep, m_ignore_response);
return callable_remote_procedure(*m_engine, m_id, ep, m_ignore_response);
}
remote_procedure& remote_procedure::ignore_response() {
m_ignore_response = true;
margo_registered_disable_response(m_engine.m_mid, m_id, HG_TRUE);
margo_registered_disable_response(m_engine->m_mid, m_id, HG_TRUE);
return *this;
}
......
#include <thallium/engine.hpp>
#include <thallium/request.hpp>
#include <thallium/endpoint.hpp>
namespace thallium {
endpoint request::get_endpoint() const {
const struct hg_info* info = margo_get_info(m_handle);
hg_addr_t addr;
margo_addr_dup(m_engine->m_mid, info->addr, &addr);
return endpoint(*m_engine, addr);
}
}
/*
* (C) 2017 The University of Chicago
*
* See COPYRIGHT in top-level directory.
*/
#include <thallium/bulk.hpp>
#include <thallium/resolved_bulk.hpp>
namespace thallium {