Commit c0094faa authored by Xin Zhao
Browse files

Split shared RMA packet structures.



Previously, several RMA packet types shared the same structure,
which made the code misleading. This change gives each of these
RMA packet types its own packet data structure.
Signed-off-by: Pavan Balaji <balaji@anl.gov>
parent bfbb1048
......@@ -204,9 +204,11 @@ MPIDI_CH3_PKT_DEFS
datatype_ = pkt_.get.datatype; \
break; \
case (MPIDI_CH3_PKT_ACCUMULATE): \
case (MPIDI_CH3_PKT_GET_ACCUM): \
datatype_ = pkt_.accum.datatype; \
break; \
case (MPIDI_CH3_PKT_GET_ACCUM): \
datatype_ = pkt_.get_accum.datatype; \
break; \
case (MPIDI_CH3_PKT_CAS): \
datatype_ = pkt_.cas.datatype; \
break; \
......@@ -271,7 +273,6 @@ typedef struct MPIDI_CH3_Pkt_get_resp {
typedef struct MPIDI_CH3_Pkt_accum {
MPIDI_CH3_Pkt_type_t type;
MPIDI_CH3_Pkt_flags_t flags;
MPI_Request request_handle; /* For get_accumulate response */
void *addr;
int count;
MPI_Datatype datatype;
......@@ -286,6 +287,24 @@ typedef struct MPIDI_CH3_Pkt_accum {
* with shared locks. Otherwise set to NULL*/
} MPIDI_CH3_Pkt_accum_t;
/* Packet header for MPI_Get_accumulate operations.  Split out of
 * MPIDI_CH3_Pkt_accum_t by this change so that accumulate and
 * get_accumulate no longer share one structure; the notable layout
 * difference is the request_handle used to route the fetched data
 * back to the origin. */
typedef struct MPIDI_CH3_Pkt_get_accum {
MPIDI_CH3_Pkt_type_t type;  /* set to MPIDI_CH3_PKT_GET_ACCUM */
MPIDI_CH3_Pkt_flags_t flags;
MPI_Request request_handle; /* For get_accumulate response */
void *addr;  /* target buffer address: base + disp_unit * target_disp */
int count;  /* number of target elements */
MPI_Datatype datatype;  /* target datatype */
int dataloop_size; /* for derived datatypes */
MPI_Op op;  /* reduction operation to apply at the target */
MPI_Win target_win_handle; /* Used in the last RMA operation in each
* epoch for decrementing rma op counter in
* active target rma and for unlocking window
* in passive target rma. Otherwise set to NULL*/
MPI_Win source_win_handle; /* Used in the last RMA operation in an
* epoch in the case of passive target rma
* with shared locks. Otherwise set to NULL*/
} MPIDI_CH3_Pkt_get_accum_t;
typedef struct MPIDI_CH3_Pkt_get_accum_resp {
MPIDI_CH3_Pkt_type_t type;
MPI_Request request_handle;
......@@ -362,6 +381,26 @@ typedef struct MPIDI_CH3_Pkt_lock {
int origin_rank;
} MPIDI_CH3_Pkt_lock_t;
/* Packet header for a passive-target unlock.  Previously an alias of
 * MPIDI_CH3_Pkt_lock_t; this change makes it a distinct structure
 * with the same layout. */
typedef struct MPIDI_CH3_Pkt_unlock {
MPIDI_CH3_Pkt_type_t type;
int lock_type;
MPI_Win target_win_handle;
MPI_Win source_win_handle;
int target_rank; /* Used in unlock/flush response to look up the
* target state at the origin. */
int origin_rank;
} MPIDI_CH3_Pkt_unlock_t;
/* Packet header for a passive-target flush.  Previously an alias of
 * MPIDI_CH3_Pkt_lock_t; this change makes it a distinct structure.
 * The layout (including lock_type) is retained from that shared
 * lock-packet layout. */
typedef struct MPIDI_CH3_Pkt_flush {
MPIDI_CH3_Pkt_type_t type;
int lock_type;
MPI_Win target_win_handle;
MPI_Win source_win_handle;
int target_rank; /* Used in unlock/flush response to look up the
* target state at the origin. */
int origin_rank;
} MPIDI_CH3_Pkt_flush_t;
typedef struct MPIDI_CH3_Pkt_lock_granted {
MPIDI_CH3_Pkt_type_t type;
MPI_Win source_win_handle;
......@@ -369,9 +408,12 @@ typedef struct MPIDI_CH3_Pkt_lock_granted {
* target state at the origin. */
} MPIDI_CH3_Pkt_lock_granted_t;
typedef MPIDI_CH3_Pkt_lock_granted_t MPIDI_CH3_Pkt_flush_ack_t;
typedef MPIDI_CH3_Pkt_lock_t MPIDI_CH3_Pkt_unlock_t;
typedef MPIDI_CH3_Pkt_lock_t MPIDI_CH3_Pkt_flush_t;
/* Acknowledgement packet for a flush.  Previously an alias of
 * MPIDI_CH3_Pkt_lock_granted_t; this change makes it a distinct
 * structure. */
typedef struct MPIDI_CH3_Pkt_flush_ack {
MPIDI_CH3_Pkt_type_t type;
MPI_Win source_win_handle;
int target_rank; /* Used in flush_ack response to look up the
* target state at the origin. */
} MPIDI_CH3_Pkt_flush_ack_t;
typedef struct MPIDI_CH3_Pkt_lock_put_unlock {
MPIDI_CH3_Pkt_type_t type;
......@@ -441,6 +483,7 @@ typedef union MPIDI_CH3_Pkt {
MPIDI_CH3_Pkt_get_resp_t get_resp;
MPIDI_CH3_Pkt_accum_t accum;
MPIDI_CH3_Pkt_accum_immed_t accum_immed;
MPIDI_CH3_Pkt_get_accum_t get_accum;
MPIDI_CH3_Pkt_lock_t lock;
MPIDI_CH3_Pkt_lock_granted_t lock_granted;
MPIDI_CH3_Pkt_unlock_t unlock;
......
......@@ -883,6 +883,7 @@ static int send_rma_msg(MPIDI_RMA_Op_t * rma_op, MPID_Win * win_ptr, MPIDI_CH3_P
{
MPIDI_CH3_Pkt_put_t *put_pkt = &rma_op->pkt.put;
MPIDI_CH3_Pkt_accum_t *accum_pkt = &rma_op->pkt.accum;
MPIDI_CH3_Pkt_get_accum_t *get_accum_pkt = &rma_op->pkt.get_accum;
MPID_IOV iov[MPID_IOV_LIMIT];
int mpi_errno = MPI_SUCCESS;
int origin_dt_derived, target_dt_derived, iovcnt;
......@@ -918,8 +919,8 @@ static int send_rma_msg(MPIDI_RMA_Op_t * rma_op, MPID_Win * win_ptr, MPIDI_CH3_P
resp_req->dev.user_buf = rma_op->result_addr;
resp_req->dev.user_count = rma_op->result_count;
resp_req->dev.datatype = rma_op->result_datatype;
resp_req->dev.target_win_handle = accum_pkt->target_win_handle;
resp_req->dev.source_win_handle = accum_pkt->source_win_handle;
resp_req->dev.target_win_handle = get_accum_pkt->target_win_handle;
resp_req->dev.source_win_handle = get_accum_pkt->source_win_handle;
if (!MPIR_DATATYPE_IS_PREDEFINED(resp_req->dev.datatype)) {
MPID_Datatype *result_dtp = NULL;
......@@ -929,12 +930,11 @@ static int send_rma_msg(MPIDI_RMA_Op_t * rma_op, MPID_Win * win_ptr, MPIDI_CH3_P
* request is freed. */
}
/* Note: Get_accumulate uses the same packet type as accumulate */
accum_pkt->request_handle = resp_req->handle;
get_accum_pkt->request_handle = resp_req->handle;
accum_pkt->flags = flags;
iov[0].MPID_IOV_BUF = (MPID_IOV_BUF_CAST) accum_pkt;
iov[0].MPID_IOV_LEN = sizeof(*accum_pkt);
get_accum_pkt->flags = flags;
iov[0].MPID_IOV_BUF = (MPID_IOV_BUF_CAST) get_accum_pkt;
iov[0].MPID_IOV_LEN = sizeof(*get_accum_pkt);
}
else {
accum_pkt->flags = flags;
......@@ -998,9 +998,12 @@ static int send_rma_msg(MPIDI_RMA_Op_t * rma_op, MPID_Win * win_ptr, MPIDI_CH3_P
if (rma_op->pkt.type == MPIDI_CH3_PKT_PUT) {
put_pkt->dataloop_size = target_dtp->dataloop_size;
}
else {
else if (rma_op->pkt.type == MPIDI_CH3_PKT_ACCUMULATE) {
accum_pkt->dataloop_size = target_dtp->dataloop_size;
}
else {
get_accum_pkt->dataloop_size = target_dtp->dataloop_size;
}
}
MPID_Datatype_get_size_macro(rma_op->origin_datatype, origin_type_size);
......
......@@ -569,16 +569,16 @@ int MPIDI_Get_accumulate(const void *origin_addr, int origin_count,
}
else {
MPIDI_CH3_Pkt_accum_t *accum_pkt = &(new_ptr->pkt.accum);
MPIDI_Pkt_init(accum_pkt, MPIDI_CH3_PKT_GET_ACCUM);
accum_pkt->addr = (char *) win_ptr->base_addrs[target_rank] +
MPIDI_CH3_Pkt_get_accum_t *get_accum_pkt = &(new_ptr->pkt.get_accum);
MPIDI_Pkt_init(get_accum_pkt, MPIDI_CH3_PKT_GET_ACCUM);
get_accum_pkt->addr = (char *) win_ptr->base_addrs[target_rank] +
win_ptr->disp_units[target_rank] * target_disp;
accum_pkt->count = target_count;
accum_pkt->datatype = target_datatype;
accum_pkt->dataloop_size = 0;
accum_pkt->op = op;
accum_pkt->target_win_handle = win_ptr->all_win_handles[target_rank];
accum_pkt->source_win_handle = win_ptr->handle;
get_accum_pkt->count = target_count;
get_accum_pkt->datatype = target_datatype;
get_accum_pkt->dataloop_size = 0;
get_accum_pkt->op = op;
get_accum_pkt->target_win_handle = win_ptr->all_win_handles[target_rank];
get_accum_pkt->source_win_handle = win_ptr->handle;
new_ptr->origin_addr = (void *) origin_addr;
new_ptr->origin_count = origin_count;
......
......@@ -456,7 +456,7 @@ int MPIDI_CH3_PktHandler_Accumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
int MPIDI_CH3_PktHandler_GetAccumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
MPIDI_msg_sz_t * buflen, MPID_Request ** rreqp)
{
MPIDI_CH3_Pkt_accum_t *accum_pkt = &pkt->accum;
MPIDI_CH3_Pkt_get_accum_t *get_accum_pkt = &pkt->get_accum;
MPID_Request *req = NULL;
MPI_Aint true_lb, true_extent, extent;
void *tmp_buf = NULL;
......@@ -472,9 +472,9 @@ int MPIDI_CH3_PktHandler_GetAccumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
MPIU_DBG_MSG(CH3_OTHER, VERBOSE, "received accumulate pkt");
MPIU_Assert(accum_pkt->target_win_handle != MPI_WIN_NULL);
MPID_Win_get_ptr(accum_pkt->target_win_handle, win_ptr);
mpi_errno = MPIDI_CH3_Start_rma_op_target(win_ptr, accum_pkt->flags);
MPIU_Assert(get_accum_pkt->target_win_handle != MPI_WIN_NULL);
MPID_Win_get_ptr(get_accum_pkt->target_win_handle, win_ptr);
mpi_errno = MPIDI_CH3_Start_rma_op_target(win_ptr, get_accum_pkt->flags);
data_len = *buflen - sizeof(MPIDI_CH3_Pkt_t);
data_buf = (char *) pkt + sizeof(MPIDI_CH3_Pkt_t);
......@@ -483,35 +483,35 @@ int MPIDI_CH3_PktHandler_GetAccumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
MPIU_Object_set_ref(req, 1);
*rreqp = req;
req->dev.user_count = accum_pkt->count;
req->dev.op = accum_pkt->op;
req->dev.real_user_buf = accum_pkt->addr;
req->dev.target_win_handle = accum_pkt->target_win_handle;
req->dev.source_win_handle = accum_pkt->source_win_handle;
req->dev.flags = accum_pkt->flags;
req->dev.user_count = get_accum_pkt->count;
req->dev.op = get_accum_pkt->op;
req->dev.real_user_buf = get_accum_pkt->addr;
req->dev.target_win_handle = get_accum_pkt->target_win_handle;
req->dev.source_win_handle = get_accum_pkt->source_win_handle;
req->dev.flags = get_accum_pkt->flags;
req->dev.resp_request_handle = accum_pkt->request_handle;
req->dev.resp_request_handle = get_accum_pkt->request_handle;
if (MPIR_DATATYPE_IS_PREDEFINED(accum_pkt->datatype)) {
if (MPIR_DATATYPE_IS_PREDEFINED(get_accum_pkt->datatype)) {
MPIDI_Request_set_type(req, MPIDI_REQUEST_TYPE_GET_ACCUM_RESP);
req->dev.datatype = accum_pkt->datatype;
req->dev.datatype = get_accum_pkt->datatype;
MPIR_Type_get_true_extent_impl(accum_pkt->datatype, &true_lb, &true_extent);
MPID_Datatype_get_extent_macro(accum_pkt->datatype, extent);
MPIR_Type_get_true_extent_impl(get_accum_pkt->datatype, &true_lb, &true_extent);
MPID_Datatype_get_extent_macro(get_accum_pkt->datatype, extent);
/* Predefined types should always have zero lb */
MPIU_Assert(true_lb == 0);
tmp_buf = MPIU_Malloc(accum_pkt->count * (MPIR_MAX(extent, true_extent)));
tmp_buf = MPIU_Malloc(get_accum_pkt->count * (MPIR_MAX(extent, true_extent)));
if (!tmp_buf) {
MPIU_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %d",
accum_pkt->count * MPIR_MAX(extent, true_extent));
get_accum_pkt->count * MPIR_MAX(extent, true_extent));
}
req->dev.user_buf = tmp_buf;
MPID_Datatype_get_size_macro(accum_pkt->datatype, type_size);
req->dev.recv_data_sz = type_size * accum_pkt->count;
MPID_Datatype_get_size_macro(get_accum_pkt->datatype, type_size);
req->dev.recv_data_sz = type_size * get_accum_pkt->count;
mpi_errno = MPIDI_CH3U_Receive_data_found(req, data_buf, &data_len, &complete);
MPIU_ERR_CHKANDJUMP1(mpi_errno, mpi_errno, MPI_ERR_OTHER, "**ch3|postrecv",
......@@ -548,20 +548,20 @@ int MPIDI_CH3_PktHandler_GetAccumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
"MPIDI_RMA_dtype_info");
}
req->dev.dataloop = MPIU_Malloc(accum_pkt->dataloop_size);
req->dev.dataloop = MPIU_Malloc(get_accum_pkt->dataloop_size);
if (!req->dev.dataloop) {
MPIU_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %d",
accum_pkt->dataloop_size);
get_accum_pkt->dataloop_size);
}
if (data_len >= sizeof(MPIDI_RMA_dtype_info) + accum_pkt->dataloop_size) {
if (data_len >= sizeof(MPIDI_RMA_dtype_info) + get_accum_pkt->dataloop_size) {
/* copy all of dtype_info and dataloop */
MPIU_Memcpy(req->dev.dtype_info, data_buf, sizeof(MPIDI_RMA_dtype_info));
MPIU_Memcpy(req->dev.dataloop, data_buf + sizeof(MPIDI_RMA_dtype_info),
accum_pkt->dataloop_size);
get_accum_pkt->dataloop_size);
*buflen =
sizeof(MPIDI_CH3_Pkt_t) + sizeof(MPIDI_RMA_dtype_info) + accum_pkt->dataloop_size;
sizeof(MPIDI_CH3_Pkt_t) + sizeof(MPIDI_RMA_dtype_info) + get_accum_pkt->dataloop_size;
/* All dtype data has been received, call req handler */
mpi_errno = MPIDI_CH3_ReqHandler_AccumRespDerivedDTComplete(vc, req, &complete);
......@@ -576,7 +576,7 @@ int MPIDI_CH3_PktHandler_GetAccumulate(MPIDI_VC_t * vc, MPIDI_CH3_Pkt_t * pkt,
req->dev.iov[0].MPID_IOV_BUF = (MPID_IOV_BUF_CAST) req->dev.dtype_info;
req->dev.iov[0].MPID_IOV_LEN = sizeof(MPIDI_RMA_dtype_info);
req->dev.iov[1].MPID_IOV_BUF = (MPID_IOV_BUF_CAST) req->dev.dataloop;
req->dev.iov[1].MPID_IOV_LEN = accum_pkt->dataloop_size;
req->dev.iov[1].MPID_IOV_LEN = get_accum_pkt->dataloop_size;
req->dev.iov_count = 2;
*buflen = sizeof(MPIDI_CH3_Pkt_t);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment