Commit a1063bb7 authored by balaji's avatar balaji
Browse files

Fixes to the temporary buffer calculation and usage in the MPI_Gather code.

parent 5d9c9575
......@@ -59,7 +59,7 @@ int MPIR_Gather (
int curr_cnt=0, relative_rank, nbytes, is_homogeneous;
int mask, sendtype_size, recvtype_size, src, dst, relative_src;
int recvblks;
int tmp_buf_size, diff;
int tmp_buf_size, missing;
void *tmp_buf=NULL;
MPI_Status status;
MPI_Aint extent=0; /* Datatype extent */
......@@ -94,7 +94,7 @@ int MPIR_Gather (
/* Use binomial tree algorithm. */
relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
if (rank == root)
MPID_Datatype_get_extent_macro(recvtype, extent);
......@@ -114,39 +114,25 @@ int MPIR_Gather (
nbytes = sendtype_size * sendcnt;
}
/* Find the accurate size of the temporary buffer needed */
/* Find the number of missing nodes in my sub-tree compared to
* a balanced tree */
for (mask = 1; mask < comm_size; mask <<= 1);
diff = mask-- - comm_size;
/* Setup temporary buffer size needed by assuming a balanced tree */
for (tmp_buf_size = 1; mask; mask >>= 1)
if (!(relative_rank & mask)) {
tmp_buf_size = mask + 1;
break;
}
/* If my sub-tree is unbalanced, reduce my count by diff */
do {
if (relative_rank & 1) break;
--mask;
while (relative_rank & mask) mask >>= 1;
missing = (relative_rank | mask) - comm_size + 1;
if (missing < 0) missing = 0;
tmp_buf_size = (mask - missing);
mask = 1;
while (((mask | relative_rank) != relative_rank) && (mask < comm_size))
mask <<= 1;
/* If the message is smaller than the threshold, we will copy
* our message in there too */
if (nbytes < MPIR_GATHER_VSMALL_MSG) tmp_buf_size++;
if ((relative_rank | (mask - 1)) < comm_size) break;
tmp_buf_size -= diff;
} while (0);
tmp_buf_size *= nbytes;
/* For zero-ranked root, we don't need any temporary buffer */
if ((rank == root) && (!root || (nbytes >= MPIR_GATHER_VSMALL_MSG)))
tmp_buf_size = 0;
/* If there is only one element, we'll directly send it from
* the send buffer. We won't need the temporary buffer in this
* case. */
if (tmp_buf_size == 1) tmp_buf_size = 0;
else tmp_buf_size *= nbytes;
if (tmp_buf_size) {
tmp_buf = MPIU_Malloc(tmp_buf_size);
/* --BEGIN ERROR HANDLING-- */
......@@ -158,7 +144,6 @@ int MPIR_Gather (
/* --END ERROR HANDLING-- */
}
if (rank == root)
{
if (sendbuf != MPI_IN_PLACE)
......@@ -229,13 +214,18 @@ int MPIR_Gather (
}
else /* Intermediate nodes store in temporary buffer */
{
int offset;
/* Estimate the amount of data that is going to come in */
recvblks = mask;
relative_src = ((src - root) < 0) ? (src - root + comm_size) : (src - root);
if (relative_src + mask > comm_size)
recvblks -= (relative_src + mask - comm_size);
mpi_errno = MPIC_Recv(((char *)tmp_buf + mask * nbytes),
if (nbytes < MPIR_GATHER_VSMALL_MSG)
offset = mask * nbytes;
else offset = 0;
mpi_errno = MPIC_Recv(((char *)tmp_buf + offset),
recvblks * nbytes, MPI_BYTE, src,
MPIR_GATHER_TAG, comm,
&status);
......@@ -260,11 +250,17 @@ int MPIR_Gather (
MPIR_GATHER_TAG, comm);
}
else {
int offset;
if (nbytes < MPIR_GATHER_VSMALL_MSG)
offset = nbytes;
else offset = 0;
blocks[0] = sendcnt;
struct_displs[0] = (MPI_Aint) sendbuf;
types[0] = sendtype;
blocks[1] = curr_cnt - nbytes;
struct_displs[1] = (MPI_Aint) ((char*) tmp_buf + nbytes);
struct_displs[1] = (MPI_Aint) ((char*) tmp_buf + offset);
types[1] = MPI_BYTE;
NMPI_Type_create_struct(2, blocks, struct_displs, types, &tmp_type);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment