Commit b926001a authored by Matthieu Dorier's avatar Matthieu Dorier
Browse files

fixed issue in mona_recv, added support for MONA_ANY_TAG and MONA_ANY_SOURCE

parent 03300054
......@@ -19,6 +19,9 @@ typedef struct mona_request* mona_request_t;
#define MONA_INSTANCE_NULL ((mona_instance_t)NULL)
#define MONA_REQUEST_NULL ((mona_request_t)NULL)
#define MONA_ANY_SOURCE NA_ADDR_NULL
#define MONA_ANY_TAG 0xFFFFFFFF
/**
* @brief Initialize a Mona instance.
*
......@@ -349,22 +352,31 @@ na_return_t mona_isend(
* (mona_msg_*) functions may lead to undefined behaviors and should
* be avoided.
*
* Because the called may use MONA_ANY_SOURCE and/or MONA_ANY_TAG,
* the actual_src and actual_tag can be used to get the actual sender
* and tag. These parameters, as well as actual_size, may be set to
* NULL to be ignored.
*
* @param mona [IN/OUT] Mona instance
* @param buf [OUT] buffer in which to place the received data
* @param buf_size [IN] buffer size
* @param size [IN] buffer size
* @param dest [IN] source address
* @param tag [IN] tag
* @param actual_size [OUT] actual received size
* @param actual_src [OUT] actual source
* @param actual_tag [OUT] actual tag
*
* @return NA_SUCCESS or corresponding NA error code
*/
na_return_t mona_recv(
mona_instance_t mona,
void* buf,
na_size_t buf_size,
na_size_t size,
na_addr_t src,
na_tag_t tag,
na_size_t* actual_size);
na_size_t* actual_size,
na_addr_t* actual_src,
na_tag_t* actual_tag);
/**
* @brief Non-blocking equivalent of mona_recv. The resulting mona_request_t
......@@ -376,6 +388,8 @@ na_return_t mona_recv(
* @param dest [IN] source address
* @param tag [IN] tag
* @param actual_size [OUT] actual received size
* @param actual_src [OUT] actual source
* @param actual_tag [OUT] actual tag
* @param req [OUT] request
*
* @return NA_SUCCESS or corresponding NA error code
......@@ -387,6 +401,8 @@ na_return_t mona_irecv(
na_addr_t src,
na_tag_t tag,
na_size_t* actual_size,
na_addr_t* actual_src,
na_tag_t* actual_tag,
mona_request_t* req);
/**
......
......@@ -17,11 +17,19 @@ typedef struct cached_op_id {
typedef struct cached_msg* cached_msg_t;
typedef struct cached_msg {
char* buffer;
void* plugin_data;
cached_msg_t next;
char* buffer;
void* plugin_data;
void* next; // may point to a cached_msg or to a pending_msg depending on context
} cached_msg;
typedef struct pending_msg* pending_msg_t;
typedef struct pending_msg {
cached_msg_t cached_msg;
na_size_t recv_size;
na_addr_t recv_addr;
na_tag_t recv_tag;
} pending_msg;
typedef struct mona_instance {
// NA structures
na_class_t* na_class;
......@@ -45,6 +53,12 @@ typedef struct mona_instance {
// message cache for high-level functions
cached_msg_t msg_cache;
ABT_mutex msg_cache_mtx;
// pending messages received in high-level mona_recv
pending_msg_t pending_msg_oldest; // head of the queue
pending_msg_t pending_msg_newest; // last of the queue
ABT_mutex pending_msg_mtx;
ABT_cond pending_msg_cv;
na_bool_t pending_msg_queue_active; // a thread is queuing messages
} mona_instance;
typedef struct mona_request {
......
......@@ -152,12 +152,18 @@ mona_instance_t mona_init_na_pool(
mona->op_id_cache_mtx = ABT_MUTEX_NULL;
mona->req_cache_mtx = ABT_MUTEX_NULL;
mona->msg_cache_mtx = ABT_MUTEX_NULL;
mona->pending_msg_mtx = ABT_MUTEX_NULL;
mona->pending_msg_cv = ABT_COND_NULL;
ret = ABT_mutex_create(&(mona->op_id_cache_mtx));
if(ret != ABT_SUCCESS) goto error;
ret = ABT_mutex_create(&(mona->req_cache_mtx));
if(ret != ABT_SUCCESS) goto error;
ret = ABT_mutex_create(&(mona->msg_cache_mtx));
if(ret != ABT_SUCCESS) goto error;
ret = ABT_mutex_create(&(mona->pending_msg_mtx));
if(ret != ABT_SUCCESS) goto error;
ret = ABT_cond_create(&(mona->pending_msg_cv));
if(ret != ABT_SUCCESS) goto error;
mona->op_id_cache = (cached_op_id_t)calloc(1, sizeof(*(mona->op_id_cache)));
mona->op_id_cache->op_id = NA_Op_create(na_class);
......@@ -181,6 +187,12 @@ error:
ABT_mutex_free(&(mona->op_id_cache_mtx));
if(mona->req_cache_mtx != ABT_MUTEX_NULL)
ABT_mutex_free(&(mona->req_cache_mtx));
if(mona->msg_cache_mtx != ABT_MUTEX_NULL)
ABT_mutex_free(&(mona->msg_cache_mtx));
if(mona->pending_msg_mtx != ABT_MUTEX_NULL)
ABT_mutex_free(&(mona->pending_msg_mtx));
if(mona->pending_msg_cv != ABT_COND_NULL)
ABT_cond_free(&(mona->pending_msg_cv));
free(mona);
mona = MONA_INSTANCE_NULL;
goto finish;
......@@ -207,6 +219,9 @@ na_return_t mona_finalize(mona_instance_t mona)
clear_msg_cache(mona);
ABT_mutex_free(&(mona->msg_cache_mtx));
ABT_mutex_free(&(mona->pending_msg_mtx));
ABT_cond_free(&(mona->pending_msg_cv));
if(mona->owns_na_class_and_context) {
NA_Context_destroy(
mona->na_class,
......@@ -614,81 +629,226 @@ na_return_t mona_isend(
return NA_SUCCESS;
}
static cached_msg_t wait_for_matching_unexpected_message(
mona_instance_t mona,
na_addr_t src,
na_tag_t tag,
na_size_t* actual_size,
na_addr_t* actual_src,
na_tag_t* actual_tag)
{
cached_msg_t msg = NULL; /* result */
na_size_t msg_size = mona_msg_get_max_unexpected_size(mona);
na_return_t na_ret = NA_SUCCESS;
// lock the queue of pending messages
ABT_mutex_lock(mona->pending_msg_mtx);
// search in the queue of pending messages for one matching
search_in_queue:
{
pending_msg_t p_msg = mona->pending_msg_oldest;
pending_msg_t p_prev_msg = NULL;
while(p_msg) {
if((tag == MONA_ANY_TAG || p_msg->recv_tag == tag)
&& (src == MONA_ANY_SOURCE || mona_addr_cmp(mona, src, p_msg->recv_addr))) {
break;
} else {
p_prev_msg = p_msg;
p_msg = p_msg->cached_msg->next;
}
}
if(p_msg) { // matching message was found
msg = p_msg->cached_msg;
// remove it from the queue of pending messages
if(p_prev_msg) p_prev_msg->cached_msg->next = p_msg->cached_msg->next;
if(p_msg == mona->pending_msg_oldest)
mona->pending_msg_oldest = p_msg->cached_msg->next;
if(p_msg == mona->pending_msg_newest)
mona->pending_msg_newest = p_prev_msg;
// unlock the queue
ABT_mutex_unlock(mona->pending_msg_mtx);
// copy size, source, and tag
if(actual_size) *actual_size = p_msg->recv_size;
if(actual_src) mona_addr_dup(mona, p_msg->recv_addr, actual_src);
if(actual_tag) *actual_tag = p_msg->recv_tag;
// free the pending message object
mona_addr_free(mona, p_msg->recv_addr);
free(p_msg);
// return the message
return msg;
}
}
// here the matching message wasn't found in the queue
{
// if another thread is actively issuing unexpected recv, wait for the queue to update
if(mona->pending_msg_queue_active) {
ABT_cond_wait(mona->pending_msg_cv, mona->pending_msg_mtx);
if(mona->pending_msg_queue_active)
goto search_in_queue;
}
}
// here no matching message was found and there isn't any other threads updating the queue
// so this thread will take the responsibility for actively listening for messages
mona->pending_msg_queue_active = NA_TRUE;
ABT_mutex_unlock(mona->pending_msg_mtx);
recv_new_message:
{
na_size_t recv_size = 0;
na_addr_t recv_addr = NA_ADDR_NULL;
na_tag_t recv_tag = 0;
// get message from cache
msg = get_msg_from_cache(mona);
// issue unexpected recv
na_ret = mona_msg_recv_unexpected(
mona, msg->buffer, msg_size, msg->plugin_data,
&recv_addr, &recv_tag, &recv_size);
if(na_ret != NA_SUCCESS)
goto error;
// check is received message is matching
if((tag == MONA_ANY_TAG || recv_tag == tag)
&& (src == MONA_ANY_SOURCE || mona_addr_cmp(mona, src, recv_addr))) {
// received message matches
// notify other threads that this thread won't be updating the queue anymore
ABT_mutex_lock(mona->pending_msg_mtx);
mona->pending_msg_queue_active = NA_FALSE;
ABT_mutex_unlock(mona->pending_msg_mtx);
ABT_cond_broadcast(mona->pending_msg_cv);
// copy size, source, and tag
if(actual_size) *actual_size = recv_size;
if(actual_src) *actual_src = recv_addr;
else mona_addr_free(mona, recv_addr);
if(actual_tag) *actual_tag = recv_tag;
// return the message
return msg;
} else {
// received message doesn't match, create a pending message...
pending_msg_t p_msg = (pending_msg_t)malloc(sizeof(*p_msg));
p_msg->cached_msg = msg;
p_msg->recv_size = recv_size;
p_msg->recv_addr = recv_addr;
p_msg->recv_tag = recv_tag;
msg->next = NULL;
// ... and put it in the queue
ABT_mutex_lock(mona->pending_msg_mtx);
if(mona->pending_msg_oldest == NULL) {
mona->pending_msg_oldest = p_msg;
mona->pending_msg_newest = p_msg;
} else {
mona->pending_msg_newest->cached_msg->next = p_msg;
mona->pending_msg_newest = p_msg;
}
// notify other threads that the queue has been updated
ABT_mutex_unlock(mona->pending_msg_mtx);
ABT_cond_broadcast(mona->pending_msg_cv);
goto recv_new_message;
}
}
// error handling
error:
if(msg) return_msg_to_cache(mona, msg);
ABT_mutex_unlock(mona->pending_msg_mtx);
ABT_cond_broadcast(mona->pending_msg_cv);
return NULL;
}
na_return_t mona_recv(
mona_instance_t mona,
void* buf,
na_size_t buf_size,
na_size_t size,
na_addr_t src,
na_tag_t tag,
na_size_t* actual_size)
na_size_t* actual_size,
na_addr_t* actual_src,
na_tag_t* actual_tag)
{
na_return_t na_ret = NA_SUCCESS;
na_mem_handle_t mem_handle = NA_MEM_HANDLE_NULL;
na_mem_handle_t remote_handle = NA_MEM_HANDLE_NULL;
cached_msg_t msg = get_msg_from_cache(mona);
na_size_t msg_size = mona_msg_get_max_unexpected_size(mona);
// Receive unexpected message
// XXX handle the case of receiving something that is not destined to us
na_size_t recv_size;
na_addr_t recv_addr;
na_tag_t recv_tag;
na_ret = mona_msg_recv_unexpected(
mona, msg->buffer, msg_size, msg->plugin_data,
&recv_addr, &recv_tag, &recv_size);
if(na_ret != NA_SUCCESS) goto finish;
na_size_t header_size = mona_msg_get_unexpected_header_size(mona);
cached_msg_t msg = NULL;
na_size_t recv_size = 0;
na_addr_t recv_addr = NA_ADDR_NULL;
na_tag_t recv_tag = 0;
// wait for a matching unexpected message to come around
msg = wait_for_matching_unexpected_message(mona, src, tag, &recv_size, &recv_addr, &recv_tag);
if(!msg) return NA_PROTOCOL_ERROR;
// At this point, we know msg is the message we are looking for
// and the attributes are recv_size, recv_tag, and recv_addr
char* p = msg->buffer + header_size;
char* p = msg->buffer + mona_msg_get_unexpected_header_size(mona);
if(*p == HL_MSG_SMALL) { // small message, embedded data
p += 1;
recv_size -= mona_msg_get_unexpected_header_size(mona)+1;
recv_size = recv_size < buf_size ? recv_size : buf_size;
memcpy(buf, p, recv_size);
if(actual_size)
*actual_size = recv_size;
recv_size -= header_size + 1;
recv_size = recv_size < size ? recv_size : size;
if(recv_size)
memcpy(buf, p, recv_size);
} else // large message, using RDMA transfer
if(*p == HL_MSG_LARGE) {
} else if(*p == HL_MSG_LARGE) { // large message, using RDMA transfer
p += 1;
na_size_t mem_handle_size;
na_size_t data_size;
// read the size of the serialize mem handle
memcpy(&mem_handle_size, p, sizeof(mem_handle_size));
p += sizeof(mem_handle_size);
// read the size of the data associated with the mem handle
memcpy(&data_size, p, sizeof(data_size));
p += sizeof(data_size);
// Expose user memory for RDMA
na_ret = mona_mem_handle_create(mona, (void*)buf, buf_size, NA_MEM_WRITE_ONLY, &mem_handle);
// expose user memory for RDMA
na_ret = mona_mem_handle_create(
mona, (void*)buf, size, NA_MEM_WRITE_ONLY, &mem_handle);
if(na_ret != NA_SUCCESS) goto finish;
na_ret = mona_mem_register(mona, mem_handle);
if(na_ret != NA_SUCCESS) goto finish;
// Deserialize remote handle
// Deserialize remote memory handle
na_ret = mona_mem_handle_deserialize(
mona, &remote_handle, p, mem_handle_size);
if(na_ret != NA_SUCCESS) goto finish;
// Issue RDMA operation
// XXX how do we support a source id different from 0 ?
data_size = data_size < buf_size ? data_size : buf_size;
na_ret = mona_get(mona, mem_handle, 0, remote_handle, 0, data_size, src, 0);
if(na_ret != NA_SUCCESS) goto finish;
data_size = data_size < size ? data_size : size;
if(data_size) {
na_ret = mona_get(mona, mem_handle, 0, remote_handle, 0, data_size, recv_addr, 0);
if(na_ret != NA_SUCCESS) goto finish;
}
// Send ACK
msg_size = mona_msg_get_expected_header_size(mona) + 1;
na_size_t msg_size = header_size + 1;
msg->buffer[msg_size-1] = 0;
na_ret = mona_msg_init_expected(mona, msg->buffer, msg_size);
if(na_ret != NA_SUCCESS) goto finish;
// XXX how do we support a source id different from 0 ?
na_ret = mona_msg_send_expected(mona, msg->buffer, msg_size, msg->plugin_data, src, 0, tag);
na_ret = mona_msg_send_expected(mona, msg->buffer, msg_size,
msg->plugin_data, recv_addr, 0, recv_tag);
if(na_ret != NA_SUCCESS) goto finish;
}
if(actual_size)
*actual_size = recv_size;
if(actual_tag)
*actual_tag = recv_tag;
if(actual_src)
*actual_src = recv_addr;
else
mona_addr_free(mona, recv_addr);
recv_addr = NA_ADDR_NULL;
finish:
if(recv_addr != NA_ADDR_NULL)
mona_addr_free(mona, recv_addr);
if(mem_handle != NA_MEM_HANDLE_NULL)
mona_mem_handle_free(mona, mem_handle);
if(remote_handle != NA_MEM_HANDLE_NULL)
......@@ -700,10 +860,12 @@ finish:
struct irecv_args {
mona_instance_t mona;
void* buf;
na_size_t buf_size;
na_size_t size;
na_addr_t src;
na_tag_t tag;
na_size_t* actual_size;
na_addr_t* actual_src;
na_tag_t* actual_tag;
mona_request_t req;
};
......@@ -713,10 +875,12 @@ static void irecv_thread(void* x)
na_return_t na_ret = mona_recv(
args->mona,
args->buf,
args->buf_size,
args->size,
args->src,
args->tag,
args->actual_size);
args->actual_size,
args->actual_src,
args->actual_tag);
ABT_eventual_set(args->req->eventual, &na_ret, sizeof(na_ret));
free(args);
}
......@@ -724,10 +888,12 @@ static void irecv_thread(void* x)
na_return_t mona_irecv(
mona_instance_t mona,
void* buf,
na_size_t buf_size,
na_size_t size,
na_addr_t src,
na_tag_t tag,
na_size_t* actual_size,
na_addr_t* actual_src,
na_tag_t* actual_tag,
mona_request_t* req)
{
ABT_eventual eventual;
......@@ -738,9 +904,11 @@ na_return_t mona_irecv(
struct irecv_args* args = (struct irecv_args*)malloc(sizeof(*args));
args->mona = mona;
args->buf = buf;
args->buf_size = buf_size;
args->size = size;
args->src = src;
args->actual_size = actual_size;
args->actual_src = actual_src;
args->actual_tag = actual_tag;
args->tag = tag;
mona_request_t tmp_req = get_req_from_cache(mona);
......
......@@ -16,6 +16,12 @@ target_include_directories (test-isend-irecv PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/../include)
target_link_libraries (test-isend-irecv mona)
add_executable (test-isend-irecv-multi test-isend-irecv-multi.c munit/munit.c)
target_include_directories (test-isend-irecv-multi PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/munit
${CMAKE_CURRENT_SOURCE_DIR}/../include)
target_link_libraries (test-isend-irecv-multi mona)
add_executable (test-send-recv-unexpected test-send-recv-unexpected.c munit/munit.c)
target_include_directories (test-send-recv-unexpected PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/munit
......@@ -44,6 +50,7 @@ target_link_libraries (test-na mona)
add_test (NAME TestInit COMMAND ./test-init)
add_test (NAME TestSendRecv COMMAND mpirun -np 2 ./test-send-recv)
add_test (NAME TestISendIRecv COMMAND mpirun -np 2 ./test-isend-irecv)
add_test (NAME TestISendIRecvMulti COMMAND mpirun -np 2 ./test-isend-irecv-multi)
add_test (NAME TestSendRecvUnexpected COMMAND mpirun -np 2 ./test-send-recv-unexpected)
add_test (NAME TestSendRecvExpected COMMAND mpirun -np 2 ./test-send-recv-expected)
add_test (NAME TestPutGet COMMAND mpirun -np 2 ./test-put-get)
......
#include <mpi.h>
#include "munit/munit.h"
#include "mona.h"
typedef struct {
mona_instance_t mona;
int rank;
na_addr_t self_addr;
na_addr_t other_addr;
} test_context;
static void* test_context_setup(const MunitParameter params[], void* user_data)
{
(void)params;
(void)user_data;
int ret;
MPI_Init(NULL, NULL);
ABT_init(0, NULL);
mona_instance_t mona = mona_init("ofi+tcp", NA_TRUE, NULL);
test_context* context = (test_context*)calloc(1, sizeof(*context));
context->mona = mona;
MPI_Comm_rank(MPI_COMM_WORLD, &(context->rank));
ret = mona_addr_self(mona, &(context->self_addr));
munit_assert_int(ret, ==, NA_SUCCESS);
char self_addr_str[128];
na_size_t self_addr_size = 128;
ret = mona_addr_to_string(mona, self_addr_str, &self_addr_size, context->self_addr);
munit_assert_int(ret, ==, NA_SUCCESS);
char other_addr_str[128];
MPI_Sendrecv(self_addr_str, 128, MPI_BYTE, (context->rank + 1) % 2, 0,
other_addr_str, 128, MPI_BYTE, (context->rank + 1) % 2, 0,
MPI_COMM_WORLD, MPI_STATUS_IGNORE);
ret = mona_addr_lookup(mona, other_addr_str, &(context->other_addr));
munit_assert_int(ret, ==, NA_SUCCESS);
return context;
}
static void test_context_tear_down(void* fixture)
{
MPI_Barrier(MPI_COMM_WORLD);
test_context* context = (test_context*)fixture;
mona_addr_free(context->mona, context->self_addr);
mona_addr_free(context->mona, context->other_addr);
mona_finalize(context->mona);
free(context);
ABT_finalize();
MPI_Finalize();
}
static MunitResult test_isend_irecv_multi(const MunitParameter params[], void* data)
{
(void)params;
test_context* context = (test_context*)data;
na_return_t ret;
mona_instance_t mona = context->mona;
na_size_t msg_len = 8192;
char* buf = malloc(msg_len*4);
if(context->rank == 0) { // sender
int i;
for(i = 0; i < (int)msg_len*4; i++) {
buf[i] = i % 32;
}
ret = mona_send(mona, buf, msg_len, context->other_addr, 0, 1);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_send(mona, buf+msg_len, msg_len, context->other_addr, 0, 2);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_send(mona, buf+msg_len*2, msg_len, context->other_addr, 0, 3);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_send(mona, buf+msg_len*3, msg_len, context->other_addr, 0, 4);
munit_assert_int(ret, ==, NA_SUCCESS);
} else { // receiver
int i;
mona_request_t req[4];
na_size_t recv_size[4];
ret = mona_irecv(mona, buf, msg_len, context->other_addr, 4, recv_size, NULL, NULL, req);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_irecv(mona, buf+msg_len*2, msg_len, context->other_addr, 2, recv_size+1, NULL, NULL, req+1);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_irecv(mona, buf+msg_len*3, msg_len, context->other_addr, 3, recv_size+2, NULL, NULL, req+2);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_irecv(mona, buf+msg_len, msg_len, context->other_addr, 1, recv_size+3, NULL, NULL, req+3);
munit_assert_int(ret, ==, NA_SUCCESS);
for(i=0 ; i < 4; i++) {
ret = mona_wait(*(req+i));
munit_assert_int(ret, ==, NA_SUCCESS);
}
}
return MUNIT_OK;
}
static MunitTest test_suite_tests[] = {
{ (char*) "/hl", test_isend_irecv_multi, test_context_setup, test_context_tear_down, MUNIT_TEST_OPTION_NONE, NULL },
{ NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL }
};
static const MunitSuite test_suite = {
(char*) "/mona/isend-irecv-multi", test_suite_tests, NULL, 1, MUNIT_SUITE_OPTION_NONE
};
int main(int argc, char* argv[MUNIT_ARRAY_PARAM(argc + 1)]) {
return munit_suite_main(&test_suite, (void*) "mona", argc, argv);
}
......@@ -83,7 +83,7 @@ static MunitResult test_isend_irecv(const MunitParameter params[], void* data)
munit_assert_int(ret, ==, NA_SUCCESS);
na_size_t recv_size;
ret = mona_irecv(mona, buf, 64, context->other_addr, 1234, &recv_size, &req);
ret = mona_irecv(mona, buf, 64, context->other_addr, 1234, &recv_size, NULL, NULL, &req);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_wait(req);
......@@ -98,7 +98,7 @@ static MunitResult test_isend_irecv(const MunitParameter params[], void* data)
mona_request_t req;
na_size_t recv_size;
ret = mona_irecv(mona, buf, msg_len, context->other_addr, 1234, &recv_size, &req);
ret = mona_irecv(mona, buf, msg_len, context->other_addr, 1234, &recv_size, NULL, NULL, &req);
munit_assert_int(ret, ==, NA_SUCCESS);
ret = mona_wait(req);
......
......@@ -78,7 +78,7 @@ static MunitResult test_send_recv(const MunitParameter params[], void* data)