Commit 4b07ac8a authored by Sameer Kumar's avatar Sameer Kumar Committed by Pavan Balaji
Browse files

Bug fix for strided datatypes.



Full fix

Fix for get accumulate that sends contig ack back and then scatters result buffer on the src node.

Remove unused params.
Signed-off-by: default avatarMichael Blocksome <blocksom@us.ibm.com>
parent c03e766f
......@@ -45,7 +45,9 @@ MPIDI_Win_DoneCB(pami_context_t context,
req->origin.count,
req->origin.datatype);
MPID_assert(mpi_errno == MPI_SUCCESS);
#ifndef USE_PAMI_RDMA
MPIDI_Win_datatype_unmap(&req->target.dt);
#endif
MPID_Datatype_release(req->origin.dt.pointer);
MPIU_Free(req->buffer);
MPIU_Free(req->user_buffer);
......
......@@ -38,28 +38,21 @@ MPIDI_Win_GetAccumSendAck(pami_context_t context,
pami_result_t rc = PAMI_SUCCESS;
//Copy from msginfo->addr to a contiguous buffer
MPIDI_Datatype result_dt;
char *buffer = NULL;
MPIDI_Win_datatype_basic(msginfo->result_count,
msginfo->result_datatype,
&result_dt);
int use_map = 0;
buffer = MPIU_Malloc(result_dt.size);
if (result_dt.contig)
memcpy(buffer, (msginfo->addr + result_dt.true_lb), result_dt.size);
buffer = MPIU_Malloc(msginfo->size);
MPID_assert(buffer != NULL);
if (msginfo->num_contig == 1)
memcpy(buffer, msginfo->addr, msginfo->size);
else
{
use_map = 1;
MPID_assert(buffer != NULL);
int mpi_errno = 0;
mpi_errno = MPIR_Localcopy(msginfo->addr,
msginfo->count,
msginfo->type,
buffer,
result_dt.size,
msginfo->size,
MPI_CHAR);
MPID_assert(mpi_errno == MPI_SUCCESS);
}
......@@ -69,6 +62,7 @@ MPIDI_Win_GetAccumSendAck(pami_context_t context,
pami_send_t params = {
.send = {
.header = {
.iov_base = msginfo,
.iov_len = sizeof(MPIDI_Win_GetAccMsgInfo),
},
.dispatch = MPIDI_Protocols_WinGetAccumAck,
......@@ -80,30 +74,14 @@ MPIDI_Win_GetAccumSendAck(pami_context_t context,
},
};
int index = 0;
size_t local_offset = 0;
//Set the map
MPIDI_Win_datatype_map(&result_dt);
MPID_assert(result_dt.num_contig == msginfo->result_num_contig);
while (index < result_dt.num_contig) {
params.send.header.iov_base = msginfo;
params.send.data.iov_len = result_dt.map[index].DLOOP_VECTOR_LEN;
params.send.data.iov_base = buffer + local_offset;
rc = PAMI_Send(context, &params);
MPID_assert(rc == PAMI_SUCCESS);
local_offset += params.send.data.iov_len;
++index;
}
/** Review PAMI_Send semantics and consider moving the free calls to
the completion callback*/
//free msginfo
MPIU_Free(msginfo);
params.send.data.iov_len = msginfo->size;
params.send.data.iov_base = buffer;
rc = PAMI_Send(context, &params);
MPID_assert(rc == PAMI_SUCCESS);
if (use_map && result_dt.map != &result_dt.__map)
MPIU_Free (result_dt.map);
//free msginfo
//MPIU_Free(msginfo);
}
void
......@@ -158,8 +136,7 @@ MPIDI_Win_GetAccDoneCB(pami_context_t context,
++req->win->mpid.sync.complete;
++req->origin.completed;
if (req->origin.completed ==
(req->result_num_contig + req->target.dt.num_contig))
if (req->origin.completed == req->target.dt.num_contig + 1)
{
if(req->req_handle)
MPID_cc_set(req->req_handle->cc_ptr, 0);
......@@ -177,6 +154,29 @@ MPIDI_Win_GetAccDoneCB(pami_context_t context,
MPIDI_Progress_signal();
}
void
MPIDI_Win_GetAccAckDoneCB(pami_context_t context,
void * _msginfo,
pami_result_t result)
{
MPIDI_Win_GetAccMsgInfo * msginfo =(MPIDI_Win_GetAccMsgInfo *)_msginfo;
MPIDI_Win_request *req = (MPIDI_Win_request *) msginfo->request;
if (req->result_num_contig > 1) {
MPIR_Localcopy(req->result.addr,
req->result.count,
req->result.datatype,
msginfo->result_addr,
msginfo->size,
MPI_CHAR);
MPIU_Free(msginfo->result_addr);
}
MPIU_Free(msginfo);
MPIDI_Win_GetAccDoneCB(context, req, result);
}
void
MPIDI_WinGetAccumAckCB(pami_context_t context,
void * cookie,
......@@ -189,24 +189,22 @@ MPIDI_WinGetAccumAckCB(pami_context_t context,
{
MPID_assert(recv != NULL);
MPID_assert(sndbuf == NULL);
MPID_assert(msginfo_size == sizeof(MPIDI_Win_GetAccMsgInfo));
MPID_assert(_msginfo != NULL);
const MPIDI_Win_GetAccMsgInfo * msginfo =(const MPIDI_Win_GetAccMsgInfo *)_msginfo;
int null=0;
pami_type_t pami_type;
pami_data_function pami_op;
MPI_Op op = msginfo->op;
MPIDI_Datatype_to_pami(msginfo->result_datatype, &pami_type, op, &pami_op, &null);
recv->addr = msginfo->result_addr;
recv->type = pami_type;
MPIDI_Win_GetAccMsgInfo * msginfo =MPIU_Malloc(sizeof(MPIDI_Win_GetAccMsgInfo));
*msginfo = *(const MPIDI_Win_GetAccMsgInfo *)_msginfo;
MPIDI_Win_request *req = (MPIDI_Win_request *) msginfo->request;
msginfo->result_addr = NULL;
recv->addr = req->result.addr;
if (req->result_num_contig > 1)
recv->addr = msginfo->result_addr = MPIU_Malloc(msginfo->size);
recv->type = PAMI_TYPE_BYTE;
recv->offset = 0;
recv->data_fn = PAMI_DATA_COPY;
recv->data_cookie = NULL;
recv->local_fn = MPIDI_Win_GetAccDoneCB;
recv->cookie = msginfo->req;
recv->local_fn = MPIDI_Win_GetAccAckDoneCB;
recv->cookie = msginfo;
}
static pami_result_t
......@@ -455,13 +453,15 @@ MPID_Get_accumulate(const void * origin_addr,
MPIDI_Win_datatype_map(&req->target.dt);
MPIDI_Datatype result_dt;
MPIDI_Win_datatype_basic(result_count, result_datatype, &result_dt);
req->result_num_contig = 1;
if (!result_dt.contig)
req->result_num_contig =result_dt.pointer->max_contig_blocks*result_count+1;
req->result.addr = result_addr;
req->result.count = result_count;
req->result.datatype = result_datatype;
MPIDI_Win_datatype_basic(result_count, result_datatype, &req->result.dt);
MPIDI_Win_datatype_map(&req->result.dt);
req->result_num_contig = req->result.dt.num_contig;
//We wait for #messages depending on target and result_datatype
win->mpid.sync.total += (req->result_num_contig + req->target.dt.num_contig);
win->mpid.sync.total += (1 + req->target.dt.num_contig);
{
MPI_Datatype basic_type = MPI_DATATYPE_NULL;
......@@ -478,20 +478,6 @@ MPID_Get_accumulate(const void * origin_addr,
}
MPID_assert(basic_type != MPI_DATATYPE_NULL);
MPI_Datatype result_basic_type = MPI_DATATYPE_NULL;
MPID_Datatype_get_basic_type(result_datatype, result_basic_type);
/* MPID_Datatype_get_basic_type() doesn't handle the struct types */
if ((result_datatype == MPI_FLOAT_INT) ||
(result_datatype == MPI_DOUBLE_INT) ||
(result_datatype == MPI_LONG_INT) ||
(result_datatype == MPI_SHORT_INT) ||
(result_datatype == MPI_LONG_DOUBLE_INT))
{
MPID_assert(result_basic_type == MPI_DATATYPE_NULL);
result_basic_type = result_datatype;
}
MPID_assert(result_basic_type != MPI_DATATYPE_NULL);
unsigned index;
MPIDI_Win_GetAccMsgInfo * headers = MPIU_Calloc0(req->target.dt.num_contig, MPIDI_Win_GetAccMsgInfo);
req->accum_headers = headers;
......@@ -505,11 +491,8 @@ MPID_Get_accumulate(const void * origin_addr,
headers[index].count = target_count;
headers[index].counter = index;
headers[index].num_contig = req->target.dt.num_contig;
headers[index].size = req->target.dt.size;
headers[index].request = req;
headers[index].result_addr = result_addr;
headers[index].result_count = result_count;
headers[index].result_datatype = result_basic_type;
headers[index].result_num_contig= req->result_num_contig;
}
}
......
......@@ -199,11 +199,9 @@ typedef struct
int count;
int counter;
int num_contig;
int size;
void * request;
void * result_addr;
int result_count;
MPI_Datatype result_datatype;
int result_num_contig;
pami_endpoint_t src_endpoint;
} MPIDI_Win_GetAccMsgInfo;
......@@ -245,13 +243,19 @@ typedef struct _mpidi_win_request
MPIDI_Datatype dt;
} target;
struct
{
void *addr;
int count;
MPI_Datatype datatype;
MPIDI_Datatype dt;
} result;
void *user_buffer;
void *compare_buffer; /* anchor of compare buffer for compare and swap */
void *compare_buffer; /* anchor of compare buffer for compare and swap */
uint32_t buffer_free;
void *buffer;
struct _mpidi_win_request *next;
void * compare_addr;
void * result_addr;
MPI_Op op;
int result_num_contig;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment