Commit b173d3ff authored by Darius Buntinas's avatar Darius Buntinas
Browse files

[svn-r6798] Replaced NMPI_Reduce with MPIR_Reduce_impl. Reviewed by goodell@

parent 11541a4e
......@@ -3351,12 +3351,14 @@ int MPIR_Reduce_scatter_block_intra(void *sendbuf, void *recvbuf, int recvcount,
int MPIR_Reduce_scatter_block_inter(void *sendbuf, void *recvbuf, int recvcount,
MPI_Datatype datatype, MPI_Op op, MPID_Comm
*comm_ptr);
/* Reduce entry point for internal callers: uses the communicator's
   coll_fns->Reduce override when one is installed, otherwise dispatches to
   MPIR_Reduce_intra or MPIR_Reduce_inter based on the communicator kind. */
int MPIR_Reduce_impl(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
MPI_Op op, int root, MPID_Comm *comm_ptr );
/* Reduce that dispatches on the communicator kind to MPIR_Reduce_intra or
   MPIR_Reduce_inter without consulting the coll_fns override. */
int MPIR_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
MPI_Op op, int root, MPID_Comm *comm_ptr );
/* Reduce implementation used for intracommunicators. */
int MPIR_Reduce_intra(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
MPI_Op op, int root, MPID_Comm *comm_ptr );
/* Calls the coll_fns->Reduce override if it exists, otherwise calls
   MPIR_Reduce_intra with the same arguments. */
int MPIR_Reduce_or_coll_fn(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
MPI_Op op, int root, MPID_Comm *comm_ptr );
int MPIR_Reduce_inter (void *sendbuf, void *recvbuf, int count, MPI_Datatype
datatype, MPI_Op op, int root, MPID_Comm *comm_ptr);
datatype, MPI_Op op, int root, MPID_Comm *comm_ptr);
int MPIR_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
MPI_Op op, MPID_Comm *comm_ptr);
int MPIR_Scatter_intra(void *sendbuf, int sendcnt, MPI_Datatype sendtype,
......
......@@ -27,7 +27,6 @@
#define NMPI_Get_count MPI_Get_count
#define NMPI_Pack MPI_Pack
#define NMPI_Pack_size MPI_Pack_size
#define NMPI_Reduce MPI_Reduce
#define NMPI_Reduce_scatter MPI_Reduce_scatter
#define NMPI_Reduce_scatter_block MPI_Reduce_scatter_block
#define NMPI_Unpack MPI_Unpack
......@@ -110,7 +109,6 @@
#define NMPI_Get_count PMPI_Get_count
#define NMPI_Pack PMPI_Pack
#define NMPI_Pack_size PMPI_Pack_size
#define NMPI_Reduce PMPI_Reduce
#define NMPI_Reduce_scatter PMPI_Reduce_scatter
#define NMPI_Reduce_scatter_block PMPI_Reduce_scatter_block
#define NMPI_Unpack PMPI_Unpack
......
......@@ -172,10 +172,10 @@ int MPIR_Allreduce_intra (
/* IN_PLACE and not root of reduce. Data supplied to this
allreduce is in recvbuf. Pass that as the sendbuf to reduce. */
mpi_errno = MPIR_Reduce_or_coll_fn(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm);
mpi_errno = MPIR_Reduce_impl(recvbuf, NULL, count, datatype, op, 0, comm_ptr->node_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
} else {
mpi_errno = MPIR_Reduce_or_coll_fn(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm);
mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, 0, comm_ptr->node_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
}
} else {
......@@ -212,8 +212,9 @@ int MPIR_Allreduce_intra (
if (!is_homogeneous) {
/* heterogeneous. To get the same result on all processes, we
do a reduce to 0 and then broadcast. */
mpi_errno = NMPI_Reduce ( sendbuf, recvbuf, count, datatype,
op, 0, comm );
mpi_errno = MPIR_Reduce_impl ( sendbuf, recvbuf, count, datatype,
op, 0, comm_ptr );
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
/* FIXME: mpi_errno is error CODE, not necessarily the error
class MPI_ERR_OP. In MPICH2, we can get the error class
with
......
......@@ -536,6 +536,7 @@ static int MPIR_Reduce_redscat_gather (
mpi_errno = MPIC_Recv(recvbuf, cnts[0], datatype,
0, MPIR_REDUCE_TAG, comm,
MPI_STATUS_IGNORE);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
newrank = 0;
send_idx = 0;
last_idx = 2;
......@@ -543,6 +544,7 @@ static int MPIR_Reduce_redscat_gather (
else if (newrank == 0) { /* send */
mpi_errno = MPIC_Send(recvbuf, cnts[0], datatype,
root, MPIR_REDUCE_TAG, comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
newrank = -1;
}
newroot = 0;
......@@ -722,12 +724,107 @@ int MPIR_Reduce_intra (
int mpi_errno = MPI_SUCCESS;
int comm_size, is_commutative, type_size, pof2;
MPID_Op *op_ptr;
#if defined(USE_SMP_COLLECTIVES)
MPIU_CHKLMEM_DECL(1);
#endif
MPIU_THREADPRIV_DECL;
if (count == 0) return MPI_SUCCESS;
/* check if multiple threads are calling this collective function */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
MPIU_THREADPRIV_GET;
MPIR_Nest_incr();
#if defined(USE_SMP_COLLECTIVES)
/* is the op commutative? We do SMP optimizations only if it is. */
if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN)
is_commutative = 1;
else {
MPID_Op_get_ptr(op, op_ptr);
is_commutative = (op_ptr->kind == MPID_OP_USER_NONCOMMUTE) ? 0 : 1;
}
if (MPIR_Comm_is_node_aware(comm_ptr) && is_commutative) {
void *tmp_buf = NULL;
MPI_Aint true_lb, true_extent, extent;
/* Create a temporary buffer on local roots of all nodes */
if (comm_ptr->node_roots_comm != NULL) {
mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
if (mpi_errno) { MPIU_ERR_POP(mpi_errno); }
MPID_Datatype_get_extent_macro(datatype, extent);
MPID_Ensure_Aint_fits_in_pointer(count * MPIR_MAX(extent, true_extent));
MPIU_CHKLMEM_MALLOC(tmp_buf, void *, count*(MPIR_MAX(extent,true_extent)),
mpi_errno, "temporary buffer");
/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *)((char*)tmp_buf - true_lb);
}
/* do the intranode reduce on all nodes other than the root's node */
if (comm_ptr->node_comm != NULL &&
MPIU_Get_intranode_rank(comm_ptr, root) == -1) {
mpi_errno = MPIR_Reduce_impl(sendbuf, tmp_buf, count, datatype,
op, 0, comm_ptr->node_comm);
if (mpi_errno) goto fn_fail;
}
/* do the internode reduce to the root's node */
if (comm_ptr->node_roots_comm != NULL) {
if (comm_ptr->node_roots_comm->rank != MPIU_Get_internode_rank(comm_ptr, root)) {
/* I am not on root's node. Use tmp_buf if we
participated in the first reduce, otherwise use sendbuf */
void *buf = (comm_ptr->node_comm == NULL ? sendbuf : tmp_buf);
mpi_errno = MPIR_Reduce_impl(buf, NULL, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
}
else { /* I am on root's node. I have not participated in the earlier reduce. */
if (comm_ptr->rank != root) {
/* I am not the root though. I don't have a valid recvbuf.
Use tmp_buf as recvbuf. */
mpi_errno = MPIR_Reduce_impl(sendbuf, tmp_buf, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
/* point sendbuf at tmp_buf to make final intranode reduce easy */
sendbuf = tmp_buf;
}
else {
/* I am the root. in_place is automatically handled. */
mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
/* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */
sendbuf = MPI_IN_PLACE;
}
}
}
/* do the intranode reduce on the root's node */
if (comm_ptr->node_comm != NULL &&
MPIU_Get_intranode_rank(comm_ptr, root) != -1) {
mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype,
op, MPIU_Get_intranode_rank(comm_ptr, root),
comm_ptr->node_comm);
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
}
goto fn_exit;
}
#endif
comm_size = comm_ptr->local_size;
......@@ -749,9 +846,6 @@ int MPIR_Reduce_intra (
while (pof2 <= comm_size) pof2 <<= 1;
pof2 >>=1;
/* check if multiple threads are calling this collective function */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
if ((count*type_size > MPIR_REDUCE_SHORT_MSG) &&
(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) && (count >= pof2)) {
/* do a reduce-scatter followed by gather to root. */
......@@ -764,50 +858,20 @@ int MPIR_Reduce_intra (
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
}
/* check if multiple threads are calling this collective function */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
fn_exit:
/* check if multiple threads are calling this collective function */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
MPIR_Nest_decr();
return (mpi_errno);
#if defined(USE_SMP_COLLECTIVES)
MPIU_CHKLMEM_FREEALL();
#endif
return mpi_errno;
fn_fail:
goto fn_exit;
}
/* end:nested */
/* A simple utility function that calls the comm_ptr->coll_fns->Reduce
override if it exists or else it calls MPIR_Reduce_intra with the same arguments. */
#undef FUNCNAME
#define FUNCNAME MPIR_Reduce_or_coll_fn
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
int MPIR_Reduce_or_coll_fn(
    void *sendbuf,
    void *recvbuf,
    int count,
    MPI_Datatype datatype,
    MPI_Op op,
    int root,
    MPID_Comm *comm_ptr )
{
    /* Forward a reduce to the communicator's Reduce override when one is
       installed; otherwise fall back to MPIR_Reduce_intra.  All arguments are
       passed through unchanged and the callee's error code is returned. */

    /* No override table, or no Reduce entry in it: use the built-in
       intracommunicator reduce. */
    if (comm_ptr->coll_fns == NULL || comm_ptr->coll_fns->Reduce == NULL) {
        return MPIR_Reduce_intra(sendbuf, recvbuf, count,
                                 datatype, op, root, comm_ptr);
    }

    /* --BEGIN USEREXTENSION-- */
    return comm_ptr->coll_fns->Reduce(sendbuf, recvbuf, count,
                                      datatype, op, root, comm_ptr);
    /* --END USEREXTENSION-- */
}
/* begin:nested */
/* Needed in intercommunicator allreduce */
#undef FUNCNAME
......@@ -844,17 +908,16 @@ int MPIR_Reduce_inter (
return MPI_SUCCESS;
}
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
MPIU_THREADPRIV_GET;
MPIR_Nest_incr();
comm = comm_ptr->handle;
if (root == MPI_ROOT) {
/* root receives data from rank 0 on remote group */
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
/* root receives data from rank 0 on remote group */
mpi_errno = MPIC_Recv(recvbuf, count, datatype, 0,
MPIR_REDUCE_TAG, comm, &status);
MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
if (mpi_errno) { MPIU_ERR_POP(mpi_errno); }
}
else {
......@@ -880,8 +943,10 @@ int MPIR_Reduce_inter (
}
/* Get the local intracommunicator */
if (!comm_ptr->local_comm)
MPIR_Setup_intercomm_localcomm( comm_ptr );
if (!comm_ptr->local_comm) {
mpi_errno = MPIR_Setup_intercomm_localcomm( comm_ptr );
if (mpi_errno) MPIU_ERR_POP(mpi_errno);
}
newcomm_ptr = comm_ptr->local_comm;
......@@ -892,15 +957,14 @@ int MPIR_Reduce_inter (
if (rank == 0)
{
MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
mpi_errno = MPIC_Send(tmp_buf, count, datatype, root,
MPIR_REDUCE_TAG, comm);
MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
if (mpi_errno) { MPIU_ERR_POP(mpi_errno); }
}
}
fn_exit:
MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );
MPIU_CHKLMEM_FREEALL();
MPIR_Nest_decr();
return mpi_errno;
......@@ -909,6 +973,76 @@ int MPIR_Reduce_inter (
goto fn_exit;
}
/* end:nested */
/* MPIR_Reduce performs a reduce using point-to-point messages.
This is intended to be used by device-specific implementations of
reduce. In all other cases MPIR_Reduce_impl should be
used. */
#undef FUNCNAME
#define FUNCNAME MPIR_Reduce
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
int MPIR_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
                MPI_Op op, int root, MPID_Comm *comm_ptr)
{
    /* Dispatch a reduce to the intra- or intercommunicator implementation
       according to comm_ptr->comm_kind.  The coll_fns override is not
       consulted here.  Returns MPI_SUCCESS or the callee's error code after
       it has been passed through MPIU_ERR_POP. */
    int mpi_errno = MPI_SUCCESS;

    if (comm_ptr->comm_kind != MPID_INTRACOMM) {
        /* intercommunicator */
        mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype,
                                      op, root, comm_ptr);
    } else {
        /* intracommunicator */
        mpi_errno = MPIR_Reduce_intra(sendbuf, recvbuf, count, datatype,
                                      op, root, comm_ptr);
    }

    /* Single error check shared by both branches. */
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

 fn_exit:
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}
/* MPIR_Reduce_impl should be called by any internal component that
would otherwise call MPI_Reduce. This differs from
MPIR_Reduce in that this will call the coll_fns version if it
exists. This function replaces NMPI_Reduce. */
#undef FUNCNAME
#define FUNCNAME MPIR_Reduce_impl
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
int MPIR_Reduce_impl(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
                     MPI_Op op, int root, MPID_Comm *comm_ptr)
{
    /* Internal reduce entry point.  Selection order:
         1. the communicator's coll_fns->Reduce override, if installed;
         2. MPIR_Reduce_intra for intracommunicators;
         3. MPIR_Reduce_inter otherwise.
       Returns MPI_SUCCESS or the selected routine's error code after it has
       been passed through MPIU_ERR_POP. */
    int mpi_errno = MPI_SUCCESS;
    const int has_override = (comm_ptr->coll_fns != NULL &&
                              comm_ptr->coll_fns->Reduce != NULL);

    if (has_override) {
        mpi_errno = comm_ptr->coll_fns->Reduce(sendbuf, recvbuf, count,
                                               datatype, op, root, comm_ptr);
    } else if (comm_ptr->comm_kind == MPID_INTRACOMM) {
        /* intracommunicator */
        mpi_errno = MPIR_Reduce_intra(sendbuf, recvbuf, count, datatype,
                                      op, root, comm_ptr);
    } else {
        /* intercommunicator */
        mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype,
                                      op, root, comm_ptr);
    }

    /* Single error check shared by all three paths. */
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

 fn_exit:
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}
#endif
......@@ -953,9 +1087,6 @@ int MPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
{
int mpi_errno = MPI_SUCCESS;
MPID_Comm *comm_ptr = NULL;
#if defined(USE_SMP_COLLECTIVES)
MPIU_CHKLMEM_DECL(1);
#endif
MPIU_THREADPRIV_DECL;
MPID_MPI_STATE_DECL(MPID_STATE_MPI_REDUCE);
......@@ -1064,128 +1195,12 @@ int MPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
/* ... body of routine ... */
if (comm_ptr->coll_fns != NULL && comm_ptr->coll_fns->Reduce != NULL)
{
mpi_errno = comm_ptr->coll_fns->Reduce(sendbuf, recvbuf, count,
datatype, op, root, comm_ptr);
}
else
{
if (comm_ptr->comm_kind == MPID_INTRACOMM) {
/* intracommunicator */
#if defined(USE_SMP_COLLECTIVES)
MPID_Op *op_ptr;
int is_commutative;
/* is the op commutative? We do SMP optimizations only if it is. */
if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN)
is_commutative = 1;
else {
MPID_Op_get_ptr(op, op_ptr);
is_commutative = (op_ptr->kind == MPID_OP_USER_NONCOMMUTE) ? 0 : 1;
}
if (MPIR_Comm_is_node_aware(comm_ptr) && is_commutative) {
void *tmp_buf = NULL;
MPI_Aint true_lb, true_extent, extent;
MPIU_THREADPRIV_GET;
/* Create a temporary buffer on local roots of all nodes */
if (comm_ptr->node_roots_comm != NULL) {
MPIR_Nest_incr();
mpi_errno = NMPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
MPIR_Nest_decr();
if (mpi_errno) { MPIU_ERR_POP(mpi_errno); }
MPID_Datatype_get_extent_macro(datatype, extent);
MPID_Ensure_Aint_fits_in_pointer(count * MPIR_MAX(extent, true_extent));
MPIU_CHKLMEM_MALLOC(tmp_buf, void *, count*(MPIR_MAX(extent,true_extent)),
mpi_errno, "temporary buffer");
/* adjust for potential negative lower bound in datatype */
tmp_buf = (void *)((char*)tmp_buf - true_lb);
}
/* do the intranode reduce on all nodes other than the root's node */
if (comm_ptr->node_comm != NULL &&
MPIU_Get_intranode_rank(comm_ptr, root) == -1) {
mpi_errno = MPIR_Reduce_or_coll_fn(sendbuf, tmp_buf, count, datatype,
op, 0, comm_ptr->node_comm);
if (mpi_errno) goto fn_fail;
}
/* do the internode reduce to the root's node */
if (comm_ptr->node_roots_comm != NULL) {
if (comm_ptr->node_roots_comm->rank != MPIU_Get_internode_rank(comm_ptr, root)) {
/* I am not on root's node. Use tmp_buf if we
participated in the first reduce, otherwise use sendbuf */
void *buf = (comm_ptr->node_comm == NULL ? sendbuf : tmp_buf);
mpi_errno = MPIR_Reduce_intra(buf, NULL, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
}
else { /* I am on root's node. I have not participated in the earlier reduce. */
if (comm_ptr->rank != root) {
/* I am not the root though. I don't have a valid recvbuf.
Use tmp_buf as recvbuf. */
mpi_errno = MPIR_Reduce_or_coll_fn(sendbuf, tmp_buf, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
/* point sendbuf at tmp_buf to make final intranode reduce easy */
sendbuf = tmp_buf;
}
else {
/* I am the root. in_place is automatically handled. */
mpi_errno = MPIR_Reduce_or_coll_fn(sendbuf, recvbuf, count, datatype,
op, MPIU_Get_internode_rank(comm_ptr, root),
comm_ptr->node_roots_comm);
/* set sendbuf to MPI_IN_PLACE to make final intranode reduce easy. */
sendbuf = MPI_IN_PLACE;
}
}
if (mpi_errno) goto fn_fail;
}
/* do the intranode reduce on the root's node */
if (comm_ptr->node_comm != NULL &&
MPIU_Get_intranode_rank(comm_ptr, root) != -1) {
mpi_errno = MPIR_Reduce_or_coll_fn(sendbuf, recvbuf, count, datatype,
op, MPIU_Get_intranode_rank(comm_ptr, root),
comm_ptr->node_comm);
}
}
else {
mpi_errno = MPIR_Reduce_intra(sendbuf, recvbuf, count, datatype,
op, root, comm_ptr);
}
#else
mpi_errno = MPIR_Reduce_intra(sendbuf, recvbuf, count, datatype,
op, root, comm_ptr);
#endif
}
else {
/* intercommunicator */
mpi_errno = MPIR_Reduce_inter(sendbuf, recvbuf, count, datatype,
op, root, comm_ptr);
}
}
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
mpi_errno = MPIR_Reduce_impl(sendbuf, recvbuf, count, datatype, op, root, comm_ptr);
if (mpi_errno) goto fn_exit;
/* ... end of body of routine ... */
fn_exit:
#if defined(USE_SMP_COLLECTIVES)
MPIU_CHKLMEM_FREEALL();
#endif
MPID_MPI_COLL_FUNC_EXIT(MPID_STATE_MPI_REDUCE);
MPIU_THREAD_CS_EXIT(ALLFUNC,);
return mpi_errno;
......
......@@ -1005,21 +1005,17 @@ void mpig_usage_finalize(void)
total_nbytes = (int64_t *) MPIU_Malloc(mpig_process.my_pg_size * sizeof(int64_t));
total_nbytesv = (int64_t *) MPIU_Malloc(mpig_process.my_pg_size * sizeof(int64_t));
}
MPIR_Nest_incr();
{
rc = MPIR_Gather_impl(&mpig_process.nbytes_sent, sizeof(int64_t), MPI_BYTE, total_nbytes, sizeof(int64_t), MPI_BYTE,
0, MPIR_Process.comm_world);
if (rc) goto err;
rc = MPIR_Gather_impl(&mpig_process.vmpi_nbytes_sent, sizeof(int64_t), MPI_BYTE, total_nbytesv, sizeof(int64_t), MPI_BYTE,
0, MPIR_Process.comm_world);
if (rc) goto err;
NMPI_Reduce(mpig_process.function_count, total_function_count, MPIG_FUNC_CNT_NUMFUNCS, MPI_INT, MPI_SUM,
0, MPI_COMM_WORLD);
}
MPIR_Nest_decr();
rc = MPIR_Gather_impl(&mpig_process.nbytes_sent, sizeof(int64_t), MPI_BYTE, total_nbytes, sizeof(int64_t), MPI_BYTE,
0, MPIR_Process.comm_world);
if (rc) goto err;
rc = MPIR_Gather_impl(&mpig_process.vmpi_nbytes_sent, sizeof(int64_t), MPI_BYTE, total_nbytesv, sizeof(int64_t), MPI_BYTE,
0, MPIR_Process.comm_world);
if (rc) goto err;
rc = MPIR_Reduce_impl(mpig_process.function_count, total_function_count, MPIG_FUNC_CNT_NUMFUNCS, MPI_INT,
MPI_SUM, 0, MPIR_Process.comm_world);
if (rc) goto err;
if(mpig_process.my_pg_rank == 0)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment