Commit 40862985 authored by Wesley Bland's avatar Wesley Bland Committed by Antonio J. Pena
Browse files

Improve error checking for buffer aliasing



If the user isn't using MPI_IN_PLACE when they should, this check will do a
better job of warning them about it.

See #2049
Signed-off-by: default avatarAntonio J. Pena <apenya@mcs.anl.gov>
parent 00f452cd
......@@ -932,8 +932,18 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
if (comm_ptr->comm_kind == MPID_INTERCOMM)
if (comm_ptr->comm_kind == MPID_INTERCOMM) {
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
} else {
/* catch common aliasing cases */
if (sendbuf != MPI_IN_PLACE && sendtype == recvtype &&
recvcount != 0 && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + comm_ptr->rank*recvcount*recvtype_size, mpi_errno);
}
}
if (sendbuf != MPI_IN_PLACE)
{
MPIR_ERRTEST_COUNT(sendcount, mpi_errno);
......@@ -961,10 +971,6 @@ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,recvtype,mpi_errno);
/* catch common aliasing cases */
if (sendbuf != MPI_IN_PLACE && sendtype == recvtype && recvcount != 0 && sendcount != 0)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
}
MPID_END_ERROR_CHECKS;
}
......
......@@ -1057,6 +1057,16 @@ int MPI_Allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
MPIR_ERRTEST_USERBUFFER(sendbuf,sendcount,sendtype,mpi_errno);
/* catch common aliasing cases */
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendtype == recvtype &&
recvcounts[comm_ptr->rank] != 0 &&
sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + displs[comm_ptr->rank]*recvtype_size, mpi_errno);
}
}
if (comm_ptr->comm_kind == MPID_INTRACOMM)
......
......@@ -868,8 +868,12 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
if (comm_ptr->comm_kind == MPID_INTERCOMM)
if (comm_ptr->comm_kind == MPID_INTERCOMM) {
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, count, mpi_errno);
} else {
if (count != 0 && sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
}
if (sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_USERBUFFER(sendbuf,count,datatype,mpi_errno);
......@@ -885,9 +889,6 @@ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
mpi_errno =
( * MPIR_OP_HDL_TO_DTYPE_FN(op) )(datatype);
}
if (count != 0) {
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
}
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
MPID_END_ERROR_CHECKS;
......
......@@ -845,6 +845,13 @@ int MPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
MPID_Datatype_committed_ptr( sendtype_ptr, mpi_errno );
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendbuf != MPI_IN_PLACE &&
sendcount == recvcount &&
sendtype == recvtype &&
sendcount != 0)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
}
MPIR_ERRTEST_COUNT(recvcount, mpi_errno);
......
......@@ -469,9 +469,12 @@ int MPI_Alltoallv(const void *sendbuf, const int *sendcounts,
MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
if (comm_ptr->comm_kind == MPID_INTRACOMM)
if (comm_ptr->comm_kind == MPID_INTRACOMM) {
comm_size = comm_ptr->local_size;
else
if (sendbuf != MPI_IN_PLACE && sendtype == recvtype && sendcounts == recvcounts)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
} else
comm_size = comm_ptr->remote_size;
if (comm_ptr->comm_kind == MPID_INTERCOMM && sendbuf == MPI_IN_PLACE) {
......
......@@ -471,9 +471,12 @@ int MPI_Alltoallw(const void *sendbuf, const int sendcounts[],
MPIU_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**sendbuf_inplace");
}
if (comm_ptr->comm_kind == MPID_INTRACOMM)
if (comm_ptr->comm_kind == MPID_INTRACOMM) {
comm_size = comm_ptr->local_size;
else
if (sendbuf != MPI_IN_PLACE && sendcounts == recvcounts && sendtypes == recvtypes)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
} else
comm_size = comm_ptr->remote_size;
for (i=0; i<comm_size; i++) {
......
......@@ -369,8 +369,10 @@ int MPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datat
mpi_errno =
( * MPIR_OP_HDL_TO_DTYPE_FN(op) )(datatype);
}
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
if (sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
}
MPID_END_ERROR_CHECKS;
}
......
......@@ -824,8 +824,11 @@ int MPI_Gather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,recvtype,mpi_errno);
/* catch common aliasing cases */
if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && sendcount != 0)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, ((char *)recvbuf) + comm_ptr->rank*recvcount*recvtype_size,mpi_errno);
}
}
else
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
......
......@@ -355,6 +355,13 @@ int MPI_Gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
break;
}
}
/* catch common aliasing cases */
if (sendbuf != MPI_IN_PLACE && sendtype == recvtype && recvcounts[comm_ptr->rank] != 0 && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + displs[comm_ptr->rank]*recvtype_size, mpi_errno);
}
}
else
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
......
......@@ -710,7 +710,15 @@ int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
}
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
/* catch common aliasing cases */
if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + comm_ptr->rank*recvcount*recvtype_size, mpi_errno);
}
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
......@@ -796,7 +796,8 @@ int MPI_Iallgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, v
MPID_Comm_valid_ptr(comm_ptr, mpi_errno);
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
if (sendbuf != MPI_IN_PLACE && HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
if (sendbuf != MPI_IN_PLACE) {
if (HANDLE_GET_KIND(sendtype) != HANDLE_KIND_BUILTIN) {
MPID_Datatype *sendtype_ptr = NULL;
MPID_Datatype_get_ptr(sendtype, sendtype_ptr);
MPID_Datatype_valid_ptr(sendtype_ptr, mpi_errno);
......@@ -805,6 +806,17 @@ int MPI_Iallgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, v
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
}
/* catch common aliasing cases */
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendtype == recvtype &&
recvcounts[comm_ptr->rank] != 0 &&
sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + displs[comm_ptr->rank]*recvtype_size, mpi_errno);
}
}
MPIR_ERRTEST_ARGNULL(recvcounts,"recvcounts", mpi_errno);
MPIR_ERRTEST_ARGNULL(displs,"displs", mpi_errno);
if (HANDLE_GET_KIND(recvtype) != HANDLE_KIND_BUILTIN) {
......
......@@ -772,6 +772,10 @@ int MPI_Iallreduce(const void *sendbuf, void *recvbuf, int count,
MPIR_ERRTEST_USERBUFFER(sendbuf,count,datatype,mpi_errno);
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
if (comm_ptr->comm_kind == MPID_INTRACOMM && count != 0 && sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
}
MPID_END_ERROR_CHECKS
......
......@@ -641,7 +641,14 @@ int MPI_Ialltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
}
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendbuf != MPI_IN_PLACE &&
sendcount == recvcount &&
sendtype == recvtype &&
sendcount != 0)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
......@@ -378,6 +378,12 @@ int MPI_Ialltoallv(const void *sendbuf, const int sendcounts[], const int sdispl
}
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendbuf != MPI_IN_PLACE &&
sendcounts == recvcounts &&
sendtype == recvtype)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
}
MPID_END_ERROR_CHECKS
......
......@@ -360,6 +360,11 @@ int MPI_Ialltoallw(const void *sendbuf, const int sendcounts[], const int sdispl
MPIR_ERRTEST_ARGNULL(sendcounts,"sendcounts", mpi_errno);
MPIR_ERRTEST_ARGNULL(sdispls,"sdispls", mpi_errno);
MPIR_ERRTEST_ARGNULL(sendtypes,"sendtypes", mpi_errno);
if (comm_ptr->comm_kind == MPID_INTRACOMM &&
sendcounts == recvcounts &&
sendtypes == recvtypes)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
}
MPIR_ERRTEST_ARGNULL(recvcounts,"recvcounts", mpi_errno);
MPIR_ERRTEST_ARGNULL(rdispls,"rdispls", mpi_errno);
......@@ -368,7 +373,7 @@ int MPI_Ialltoallw(const void *sendbuf, const int sendcounts[], const int sdispl
MPIU_ERR_SETANDJUMP(mpi_errno, MPI_ERR_OTHER, "**sendbuf_inplace");
}
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
......@@ -305,7 +305,10 @@ int MPI_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype data
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
if (sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
......@@ -649,8 +649,11 @@ int MPI_Igather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
MPIR_ERRTEST_USERBUFFER(recvbuf,recvcount,recvtype,mpi_errno);
/* catch common aliasing cases */
if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && sendcount != 0)
MPIR_ERRTEST_ALIAS_COLL(sendbuf,recvbuf,mpi_errno);
if (recvbuf != MPI_IN_PLACE && sendtype == recvtype && sendcount == recvcount && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + comm_ptr->rank*recvcount*recvtype_size, mpi_errno);
}
}
else
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
......
......@@ -256,6 +256,13 @@ int MPI_Igatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void
break;
}
}
/* catch common aliasing cases */
if (sendbuf != MPI_IN_PLACE && sendtype == recvtype && recvcounts[comm_ptr->rank] != 0 && sendcount != 0) {
int recvtype_size;
MPID_Datatype_get_size_macro(recvtype, recvtype_size);
MPIR_ERRTEST_ALIAS_COLL(sendbuf, (char*)recvbuf + displs[comm_ptr->rank]*recvtype_size, mpi_errno);
}
}
else
MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sendcount, mpi_errno);
......
......@@ -1133,7 +1133,10 @@ int MPI_Ireduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
if (comm_ptr->comm_kind == MPID_INTRACOMM && sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno)
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
......@@ -1034,7 +1034,10 @@ int MPI_Ireduce_scatter_block(const void *sendbuf, void *recvbuf,
if (mpi_errno != MPI_SUCCESS) goto fn_fail;
MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
/* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
if (comm_ptr->comm_kind == MPID_INTRACOMM && sendbuf != MPI_IN_PLACE)
MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno)
/* TODO more checks may be appropriate (counts, in_place, etc) */
}
MPID_END_ERROR_CHECKS
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment