red_scat.c 49.5 KB
Newer Older
1
/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
2
3
4
5
6
7
8
/*
 *
 *  (C) 2001 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"
9
#include "collutil.h"
10

11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*
=== BEGIN_MPI_T_CVAR_INFO_BLOCK ===

cvars:
    - name        : MPIR_CVAR_REDSCAT_COMMUTATIVE_LONG_MSG_SIZE
      category    : COLLECTIVE
      type        : int
      default     : 524288
      class       : device
      verbosity   : MPI_T_VERBOSITY_USER_BASIC
      scope       : MPI_T_SCOPE_ALL_EQ
      description : >-
        the long message algorithm will be used if the operation is commutative
        and the send buffer size is >= this value (in bytes)

=== END_MPI_T_CVAR_INFO_BLOCK ===
*/

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/* -- Begin Profiling Symbol Block for routine MPI_Reduce_scatter */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Reduce_scatter = PMPI_Reduce_scatter
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Reduce_scatter  MPI_Reduce_scatter
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Reduce_scatter as PMPI_Reduce_scatter
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
#ifndef MPICH_MPI_FROM_PMPI
#undef MPI_Reduce_scatter
#define MPI_Reduce_scatter PMPI_Reduce_scatter

45
/* Implements the reduce-scatter butterfly algorithm described in J. L. Traff's
46
47
 * "An Improved Algorithm for (Non-commutative) Reduce-Scatter with an 
 * Application"
48
49
50
51
52
53
 * from EuroPVM/MPI 2005.  This function currently only implements support for
 * the power-of-2, block-regular case (all receive counts are equal). */
#undef FUNCNAME
#define FUNCNAME MPIR_Reduce_scatter_noncomm
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
54
55
56
static int MPIR_Reduce_scatter_noncomm(const void *sendbuf, void *recvbuf, const int recvcounts[],
                                       MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr,
                                       int *errflag)
57
58
{
    int mpi_errno = MPI_SUCCESS;
59
    int mpi_errno_ret = MPI_SUCCESS;
60
61
62
63
64
65
66
    int comm_size = comm_ptr->local_size;
    int rank = comm_ptr->rank;
    int pof2;
    int log2_comm_size;
    int i, k;
    int recv_offset, send_offset;
    int block_size, total_count, size;
67
    MPI_Aint true_extent, true_lb;
68
69
70
71
72
73
74
    int buf0_was_inout;
    void *tmp_buf0;
    void *tmp_buf1;
    void *result_ptr;
    MPI_Comm comm = comm_ptr->handle;
    MPIU_CHKLMEM_DECL(3);

75
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
76
77
78
79
80
81
82
83
84
85
86
87

    pof2 = 1;
    log2_comm_size = 0;
    while (pof2 < comm_size) {
        pof2 <<= 1;
        ++log2_comm_size;
    }

    /* begin error checking */
    MPIU_Assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */

    for (i = 0; i < (comm_size - 1); ++i) {
88
        MPIU_Assert(recvcounts[i] == recvcounts[i+1]);
89
90
91
92
    }
    /* end error checking */

    /* size of a block (count of datatype per block, NOT bytes per block) */
93
    block_size = recvcounts[0];
94
95
96
97
98
99
100
101
102
103
104
    total_count = block_size * comm_size;

    MPIU_CHKLMEM_MALLOC(tmp_buf0, void *, true_extent * total_count, mpi_errno, "tmp_buf0");
    MPIU_CHKLMEM_MALLOC(tmp_buf1, void *, true_extent * total_count, mpi_errno, "tmp_buf1");
    /* adjust for potential negative lower bound in datatype */
    tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
    tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);

    /* Copy our send data to tmp_buf0.  We do this one block at a time and
       permute the blocks as we go according to the mirror permutation. */
    for (i = 0; i < comm_size; ++i) {
105
106
        mpi_errno = MPIR_Localcopy((char *)(sendbuf == MPI_IN_PLACE ? (const void *)recvbuf : sendbuf) + (i * true_extent * block_size),
                                   block_size, datatype,
107
                                   (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
    }
    buf0_was_inout = 1;

    send_offset = 0;
    recv_offset = 0;
    size = total_count;
    for (k = 0; k < log2_comm_size; ++k) {
        /* use a double-buffering scheme to avoid local copies */
        char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
        char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
        int peer = rank ^ (0x1 << k);
        size /= 2;

        if (rank > peer) {
            /* we have the higher rank: send top half, recv bottom half */
            recv_offset += size;
        }
        else {
            /* we have the lower rank: recv top half, send bottom half */
            send_offset += size;
        }

131
        mpi_errno = MPIC_Sendrecv(outgoing_data + send_offset*true_extent,
132
133
134
135
                                     size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                     incoming_data + recv_offset*true_extent,
                                     size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
                                     comm, MPI_STATUS_IGNORE, errflag);
136
137
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
138
            *errflag = TRUE;
139
140
141
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
142
143
144
145
        /* always perform the reduction at recv_offset, the data at send_offset
           is now our peer's responsibility */
        if (rank > peer) {
            /* higher ranked value so need to call op(received_data, my_data) */
146
147
	    mpi_errno = MPIR_Reduce_local_impl( 
		     incoming_data + recv_offset*true_extent,
148
                     outgoing_data + recv_offset*true_extent,
149
                     size, datatype, op );
150
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
151
152
153
        }
        else {
            /* lower ranked value so need to call op(my_data, received_data) */
154
155
	    MPIR_Reduce_local_impl( 
		     outgoing_data + recv_offset*true_extent,
156
                     incoming_data + recv_offset*true_extent,
157
                     size, datatype, op);
158
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
159
160
161
162
163
164
165
166
            buf0_was_inout = !buf0_was_inout;
        }

        /* the next round of send/recv needs to happen within the block (of size
           "size") that we just received and reduced */
        send_offset = recv_offset;
    }

167
    MPIU_Assert(size == recvcounts[rank]);
168
169
170
171
172

    /* copy the reduced data to the recvbuf */
    result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
    mpi_errno = MPIR_Localcopy(result_ptr, size, datatype,
                               recvbuf, size, datatype);
173
174
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

175
176
fn_exit:
    MPIU_CHKLMEM_FREEALL();
177
178
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
179
180
    else if (*errflag)
        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**coll_fail");
181
182
183
184
185
    return mpi_errno;
fn_fail:
    goto fn_exit;
}

186
/* This is the default implementation of reduce_scatter. The algorithm is:

   Algorithm: MPI_Reduce_scatter

   If the operation is commutative, for short and medium-size
   messages (total size below MPIR_CVAR_REDSCAT_COMMUTATIVE_LONG_MSG_SIZE
   bytes), we use a recursive-halving
   algorithm in which the first p/2 processes send the second n/2 data
   to their counterparts in the other half and receive the first n/2
   data from them. This procedure continues recursively, halving the
   data communicated at each step, for a total of lgp steps. If the
   number of processes is not a power-of-two, we convert it to the
   nearest lower power-of-two by having the first few even-numbered
   processes send their data to the neighboring odd-numbered process
   at (rank+1). Those odd-numbered processes compute the result for
   their left neighbor as well in the recursive halving algorithm, and
   then at the end send the result back to the processes that didn't
   participate.
   Therefore, if p is a power-of-two,
   Cost = lgp.alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma
   If p is not a power-of-two,
   Cost = (floor(lgp)+2).alpha + n.(1+(p-1+n)/p).beta + n.(1+(p-1)/p).gamma
   The above cost in the non power-of-two case is approximate because
   there is some imbalance in the amount of work each process does
   because some processes do the work of their neighbors as well.

   For commutative operations and very long messages we use
   a pairwise exchange algorithm similar to
   the one used in MPI_Alltoall. At step i, each process sends n/p
   amount of data to (rank+i) and receives n/p amount of data from
   (rank-i).
   Cost = (p-1).alpha + n.((p-1)/p).beta + n.((p-1)/p).gamma


   If the operation is not commutative, we do the following:

   If p is a power-of-two and all recvcounts are equal, we use the
   butterfly algorithm implemented in MPIR_Reduce_scatter_noncomm
   (Traff, EuroPVM/MPI 2005). Otherwise we use a recursive doubling
   algorithm, which
   takes lgp steps. At step 1, processes exchange (n-n/p) amount of
   data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
   amount of data, and so forth.

   Cost = lgp.alpha + n.(lgp-(p-1)/p).beta + n.(lgp-(p-1)/p).gamma

   Possible improvements:

   End Algorithm: MPI_Reduce_scatter
*/

233
#undef FUNCNAME
234
#define FUNCNAME MPIR_Reduce_scatter_intra
235
236
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
237

238
/* not declared static because a machine-specific function may call this one in some cases */
239
240
int MPIR_Reduce_scatter_intra(const void *sendbuf, void *recvbuf, const int recvcounts[],
                              MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr, int *errflag)
241
242
243
244
245
{
    int   rank, comm_size, i;
    MPI_Aint extent, true_extent, true_lb; 
    int  *disps;
    void *tmp_recvbuf, *tmp_results;
246
247
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
248
249
250
251
252
253
254
255
256
257
    int type_size, dis[2], blklens[2], total_count, nbytes, src, dst;
    int mask, dst_tree_root, my_tree_root, j, k;
    int *newcnts, *newdisps, rem, newdst, send_idx, recv_idx,
        last_idx, send_cnt, recv_cnt;
    int pof2, old_i, newrank, received;
    MPI_Datatype sendtype, recvtype;
    int nprocs_completed, tmp_mask, tree_root, is_commutative;
    MPID_Op *op_ptr;
    MPI_Comm comm;
    MPIU_THREADPRIV_DECL;
258
259
    MPIU_CHKLMEM_DECL(5);

260
261
262
263
264
265
266
267
268
    comm = comm_ptr->handle;
    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    /* set op_errno to 0. stored in perthread structure */
    MPIU_THREADPRIV_GET;
    MPIU_THREADPRIV_FIELD(op_errno) = 0;

    MPID_Datatype_get_extent_macro(datatype, extent);
269
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
270
271
272
273
274
275
276
277
278
279
280
281
    
    if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
        is_commutative = 1;
    }
    else {
        MPID_Op_get_ptr(op, op_ptr);
        if (op_ptr->kind == MPID_OP_USER_NONCOMMUTE)
            is_commutative = 0;
        else
            is_commutative = 1;
    }

282
    MPIU_CHKLMEM_MALLOC(disps, int *, comm_size * sizeof(int), mpi_errno, "disps");
283
284
285
286

    total_count = 0;
    for (i=0; i<comm_size; i++) {
        disps[i] = total_count;
287
        total_count += recvcounts[i];
288
289
290
    }
    
    if (total_count == 0) {
291
        goto fn_exit;
292
293
294
295
296
297
298
    }

    MPID_Datatype_get_size_macro(datatype, type_size);
    nbytes = total_count * type_size;
    
    /* check if multiple threads are calling this collective function */
    MPIDU_ERR_CHECK_MULTIPLE_THREADS_ENTER( comm_ptr );
299
300
301
302
303

    /* total_count*extent eventually gets malloced. it isn't added to
     * a user-passed in buffer */
    MPID_Ensure_Aint_fits_in_pointer(total_count * MPIR_MAX(true_extent, extent));

304
    if ((is_commutative) && (nbytes < MPIR_CVAR_REDSCAT_COMMUTATIVE_LONG_MSG_SIZE)) {
305
306
307
        /* commutative and short. use recursive halving algorithm */

        /* allocate temp. buffer to receive incoming data */
308
        MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_recvbuf");
309
310
311
312
313
        /* adjust for potential negative lower bound in datatype */
        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
            
        /* need to allocate another temporary buffer to accumulate
           results because recvbuf may not be big enough */
314
        MPIU_CHKLMEM_MALLOC(tmp_results, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_results");
315
316
317
318
319
320
321
322
323
324
325
        /* adjust for potential negative lower bound in datatype */
        tmp_results = (void *)((char*)tmp_results - true_lb);
        
        /* copy sendbuf into tmp_results */
        if (sendbuf != MPI_IN_PLACE)
            mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype,
                                       tmp_results, total_count, datatype);
        else
            mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype,
                                       tmp_results, total_count, datatype);
        
326
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341

        pof2 = 1;
        while (pof2 <= comm_size) pof2 <<= 1;
        pof2 >>=1;

        rem = comm_size - pof2;

        /* In the non-power-of-two case, all even-numbered
           processes of rank < 2*rem send their data to
           (rank+1). These even-numbered processes no longer
           participate in the algorithm until the very end. The
           remaining processes form a nice power-of-two. */

        if (rank < 2*rem) {
            if (rank % 2 == 0) { /* even */
342
                mpi_errno = MPIC_Send(tmp_results, total_count,
343
344
                                         datatype, rank+1,
                                         MPIR_REDUCE_SCATTER_TAG, comm, errflag);
345
346
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
347
                    *errflag = TRUE;
348
349
350
                    MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                    MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                }
351
352
353
354
355
356
357
                
                /* temporarily set the rank to -1 so that this
                   process does not pariticipate in recursive
                   doubling */
                newrank = -1; 
            }
            else { /* odd */
358
                mpi_errno = MPIC_Recv(tmp_recvbuf, total_count,
359
360
361
                                         datatype, rank-1,
                                         MPIR_REDUCE_SCATTER_TAG, comm,
                                         MPI_STATUS_IGNORE, errflag);
362
363
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
364
                    *errflag = TRUE;
365
366
367
                    MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                    MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                }
368
369
370
371
                
                /* do the reduction on received data. since the
                   ordering is right, it doesn't matter whether
                   the operation is commutative or not. */
372
373
		mpi_errno = MPIR_Reduce_local_impl( 
		    tmp_recvbuf, tmp_results, total_count, datatype, op );
374
375
376
377
378
379
380
381
382
                
                /* change the rank */
                newrank = rank / 2;
            }
        }
        else  /* rank >= 2*rem */
            newrank = rank - rem;

        if (newrank != -1) {
383
            /* recalculate the recvcounts and disps arrays because the
384
385
386
387
               even-numbered processes who no longer participate will
               have their result calculated by the process to their
               right (rank+1). */

388
389
            MPIU_CHKLMEM_MALLOC(newcnts, int *, pof2*sizeof(int), mpi_errno, "newcnts");
            MPIU_CHKLMEM_MALLOC(newdisps, int *, pof2*sizeof(int), mpi_errno, "newdisps");
390
391
392
393
394
395
396
            
            for (i=0; i<pof2; i++) {
                /* what does i map to in the old ranking? */
                old_i = (i < rem) ? i*2 + 1 : i + rem;
                if (old_i < 2*rem) {
                    /* This process has to also do its left neighbor's
                       work */
397
                    newcnts[i] = recvcounts[old_i] + recvcounts[old_i-1];
398
399
                }
                else
400
                    newcnts[i] = recvcounts[old_i];
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
            }
            
            newdisps[0] = 0;
            for (i=1; i<pof2; i++)
                newdisps[i] = newdisps[i-1] + newcnts[i-1];

            mask = pof2 >> 1;
            send_idx = recv_idx = 0;
            last_idx = pof2;
            while (mask > 0) {
                newdst = newrank ^ mask;
                /* find real rank of dest */
                dst = (newdst < rem) ? newdst*2 + 1 : newdst + rem;
                
                send_cnt = recv_cnt = 0;
                if (newrank < newdst) {
                    send_idx = recv_idx + mask;
                    for (i=send_idx; i<last_idx; i++)
                        send_cnt += newcnts[i];
                    for (i=recv_idx; i<send_idx; i++)
                        recv_cnt += newcnts[i];
                }
                else {
                    recv_idx = send_idx + mask;
                    for (i=send_idx; i<recv_idx; i++)
                        send_cnt += newcnts[i];
                    for (i=recv_idx; i<last_idx; i++)
                        recv_cnt += newcnts[i];
                }
                
/*                    printf("Rank %d, send_idx %d, recv_idx %d, send_cnt %d, recv_cnt %d, last_idx %d\n", newrank, send_idx, recv_idx,
                      send_cnt, recv_cnt, last_idx);
*/
                /* Send data from tmp_results. Recv into tmp_recvbuf */ 
                if ((send_cnt != 0) && (recv_cnt != 0)) 
436
                    mpi_errno = MPIC_Sendrecv((char *) tmp_results +
437
438
439
440
441
442
443
444
                                                 newdisps[send_idx]*extent,
                                                 send_cnt, datatype,
                                                 dst, MPIR_REDUCE_SCATTER_TAG,
                                                 (char *) tmp_recvbuf +
                                                 newdisps[recv_idx]*extent,
                                                 recv_cnt, datatype, dst,
                                                 MPIR_REDUCE_SCATTER_TAG, comm,
                                                 MPI_STATUS_IGNORE, errflag);
445
                else if ((send_cnt == 0) && (recv_cnt != 0))
446
                    mpi_errno = MPIC_Recv((char *) tmp_recvbuf +
447
448
449
450
                                             newdisps[recv_idx]*extent,
                                             recv_cnt, datatype, dst,
                                             MPIR_REDUCE_SCATTER_TAG, comm,
                                             MPI_STATUS_IGNORE, errflag);
451
                else if ((recv_cnt == 0) && (send_cnt != 0))
452
                    mpi_errno = MPIC_Send((char *) tmp_results +
453
454
455
456
                                             newdisps[send_idx]*extent,
                                             send_cnt, datatype,
                                             dst, MPIR_REDUCE_SCATTER_TAG,
                                             comm, errflag);
457

458
459
                if (mpi_errno) {
                    /* for communication errors, just record the error but continue */
460
                    *errflag = TRUE;
461
462
463
                    MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                    MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                }
464
465
466
467
468
                
                /* tmp_recvbuf contains data received in this step.
                   tmp_results contains data accumulated so far */
                
                if (recv_cnt) {
469
470
		    mpi_errno = MPIR_Reduce_local_impl( 
			     (char *) tmp_recvbuf + newdisps[recv_idx]*extent,
471
                             (char *) tmp_results + newdisps[recv_idx]*extent, 
472
			     recv_cnt, datatype, op);
473
474
475
476
477
478
479
480
481
                }

                /* update send_idx for next iteration */
                send_idx = recv_idx;
                last_idx = recv_idx + mask;
                mask >>= 1;
            }

            /* copy this process's result from tmp_results to recvbuf */
482
            if (recvcounts[rank]) {
483
484
                mpi_errno = MPIR_Localcopy((char *)tmp_results +
                                           disps[rank]*extent, 
485
486
                                           recvcounts[rank], datatype, recvbuf,
                                           recvcounts[rank], datatype);
487
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
488
489
490
491
492
493
494
495
496
            }
            
        }

        /* In the non-power-of-two case, all odd-numbered
           processes of rank < 2*rem send to (rank-1) the result they
           calculated for that process */
        if (rank < 2*rem) {
            if (rank % 2) { /* odd */
497
                if (recvcounts[rank-1]) {
498
                    mpi_errno = MPIC_Send((char *) tmp_results +
499
                                             disps[rank-1]*extent, recvcounts[rank-1],
500
501
                                             datatype, rank-1,
                                             MPIR_REDUCE_SCATTER_TAG, comm, errflag);
502
503
                    if (mpi_errno) {
                        /* for communication errors, just record the error but continue */
504
                        *errflag = TRUE;
505
506
507
                        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                        MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                    }
508
                }
509
510
            }
            else  {   /* even */
511
                if (recvcounts[rank]) {
512
                    mpi_errno = MPIC_Recv(recvbuf, recvcounts[rank],
513
514
515
                                             datatype, rank+1,
                                             MPIR_REDUCE_SCATTER_TAG, comm,
                                             MPI_STATUS_IGNORE, errflag);
516
517
                    if (mpi_errno) {
                        /* for communication errors, just record the error but continue */
518
                        *errflag = TRUE;
519
520
521
                        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                        MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                    }
522
                }
523
524
525
526
            }
        }
    }
    
527
    if (is_commutative && (nbytes >= MPIR_CVAR_REDSCAT_COMMUTATIVE_LONG_MSG_SIZE)) {
528
529
530

        /* commutative and long message, or noncommutative and long message.
           use (p-1) pairwise exchanges */ 
531

532
533
534
        if (sendbuf != MPI_IN_PLACE) {
            /* copy local data into recvbuf */
            mpi_errno = MPIR_Localcopy(((char *)sendbuf+disps[rank]*extent),
535
536
                                       recvcounts[rank], datatype, recvbuf,
                                       recvcounts[rank], datatype);
537
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
538
539
540
        }
        
        /* allocate temporary buffer to store incoming data */
541
        MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, recvcounts[rank]*(MPIR_MAX(true_extent,extent))+1, mpi_errno, "tmp_recvbuf");
542
543
544
545
546
547
548
549
550
551
        /* adjust for potential negative lower bound in datatype */
        tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
        
        for (i=1; i<comm_size; i++) {
            src = (rank - i + comm_size) % comm_size;
            dst = (rank + i) % comm_size;
            
            /* send the data that dst needs. recv data that this process
               needs from src into tmp_recvbuf */
            if (sendbuf != MPI_IN_PLACE) 
552
                mpi_errno = MPIC_Sendrecv(((char *)sendbuf+disps[dst]*extent),
553
                                             recvcounts[dst], datatype, dst,
554
                                             MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
555
                                             recvcounts[rank], datatype, src,
556
557
                                             MPIR_REDUCE_SCATTER_TAG, comm,
                                             MPI_STATUS_IGNORE, errflag);
558
            else
559
                mpi_errno = MPIC_Sendrecv(((char *)recvbuf+disps[dst]*extent),
560
                                             recvcounts[dst], datatype, dst,
561
                                             MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
562
                                             recvcounts[rank], datatype, src,
563
564
                                             MPIR_REDUCE_SCATTER_TAG, comm,
                                             MPI_STATUS_IGNORE, errflag);
565
            
566
567
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
568
                *errflag = TRUE;
569
570
571
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
572
573
574
            
            if (is_commutative || (src < rank)) {
                if (sendbuf != MPI_IN_PLACE) {
575
		    mpi_errno = MPIR_Reduce_local_impl( 
576
			       tmp_recvbuf, recvbuf, recvcounts[rank],
577
                               datatype, op ); 
578
579
                }
                else {
580
581
		    mpi_errno = MPIR_Reduce_local_impl( 
			tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
582
			recvcounts[rank], datatype, op);
583
584
585
586
587
588
589
590
591
                    /* we can't store the result at the beginning of
                       recvbuf right here because there is useful data
                       there that other process/processes need. at the
                       end, we will copy back the result to the
                       beginning of recvbuf. */
                }
            }
            else {
                if (sendbuf != MPI_IN_PLACE) {
592
		    mpi_errno = MPIR_Reduce_local_impl( 
593
		       recvbuf, tmp_recvbuf, recvcounts[rank], datatype, op);
594
                    /* copy result back into recvbuf */
595
                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcounts[rank],
596
                                               datatype, recvbuf,
597
                                               recvcounts[rank], datatype);
598
                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
599
600
                }
                else {
601
602
		    mpi_errno = MPIR_Reduce_local_impl( 
                        ((char *)recvbuf+disps[rank]*extent),
603
			tmp_recvbuf, recvcounts[rank], datatype, op);
604
                    /* copy result back into recvbuf */
605
                    mpi_errno = MPIR_Localcopy(tmp_recvbuf, recvcounts[rank],
606
607
608
                                               datatype, 
                                               ((char *)recvbuf +
                                                disps[rank]*extent), 
609
                                               recvcounts[rank], datatype);
610
                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
611
612
613
614
615
616
617
618
619
                }
            }
        }
        
        /* if MPI_IN_PLACE, move output data to the beginning of
           recvbuf. already done for rank 0. */
        if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
            mpi_errno = MPIR_Localcopy(((char *)recvbuf +
                                        disps[rank]*extent),  
620
                                       recvcounts[rank], datatype,
621
                                       recvbuf, 
622
                                       recvcounts[rank], datatype);
623
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
624
625
626
        }
    }
    
627
    if (!is_commutative) {
628
629
        int is_block_regular = 1;
        for (i = 0; i < (comm_size - 1); ++i) {
630
            if (recvcounts[i] != recvcounts[i+1]) {
631
632
                is_block_regular = 0;
                break;
633
            }
634
635
636
637
638
639
640
641
        }

        /* slightly retask pof2 to mean pof2 equal or greater, not always greater as it is above */
        pof2 = 1;
        while (pof2 < comm_size) pof2 <<= 1;

        if (pof2 == comm_size && is_block_regular) {
            /* noncommutative, pof2 size, and block regular */
642
            mpi_errno = MPIR_Reduce_scatter_noncomm(sendbuf, recvbuf, recvcounts, datatype, op, comm_ptr, errflag);
643
644
645
646
647
648
            if (mpi_errno) {
                /* for communication errors, just record the error but continue */
                *errflag = TRUE;
                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
            }
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
        }
        else {
            /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */

            /* need to allocate temporary buffer to receive incoming data*/
            MPIU_CHKLMEM_MALLOC(tmp_recvbuf, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_recvbuf");
            /* adjust for potential negative lower bound in datatype */
            tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);

            /* need to allocate another temporary buffer to accumulate
               results */
            MPIU_CHKLMEM_MALLOC(tmp_results, void *, total_count*(MPIR_MAX(true_extent,extent)), mpi_errno, "tmp_results");
            /* adjust for potential negative lower bound in datatype */
            tmp_results = (void *)((char*)tmp_results - true_lb);

            /* copy sendbuf into tmp_results */
            if (sendbuf != MPI_IN_PLACE)
                mpi_errno = MPIR_Localcopy(sendbuf, total_count, datatype,
                                           tmp_results, total_count, datatype);
            else
                mpi_errno = MPIR_Localcopy(recvbuf, total_count, datatype,
                                           tmp_results, total_count, datatype);

            if (mpi_errno) MPIU_ERR_POP(mpi_errno);

            mask = 0x1;
            i = 0;
            while (mask < comm_size) {
                dst = rank ^ mask;

                dst_tree_root = dst >> i;
                dst_tree_root <<= i;

                my_tree_root = rank >> i;
                my_tree_root <<= i;

                /* At step 1, processes exchange (n-n/p) amount of
                   data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
                   amount of data, and so forth. We use derived datatypes for this.

                   At each step, a process does not need to send data
                   indexed from my_tree_root to
                   my_tree_root+mask-1. Similarly, a process won't receive
                   data indexed from dst_tree_root to dst_tree_root+mask-1. */

                /* calculate sendtype */
                blklens[0] = blklens[1] = 0;
                for (j=0; j<my_tree_root; j++)
697
                    blklens[0] += recvcounts[j];
698
                for (j=my_tree_root+mask; j<comm_size; j++)
699
                    blklens[1] += recvcounts[j];
700
701
702
703

                dis[0] = 0;
                dis[1] = blklens[0];
                for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
704
                    dis[1] += recvcounts[j];
705

706
707
708
709
710
                mpi_errno = MPIR_Type_indexed_impl(2, blklens, dis, datatype, &sendtype);
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                
                mpi_errno = MPIR_Type_commit_impl(&sendtype);
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
711
712
713
714

                /* calculate recvtype */
                blklens[0] = blklens[1] = 0;
                for (j=0; j<dst_tree_root && j<comm_size; j++)
715
                    blklens[0] += recvcounts[j];
716
                for (j=dst_tree_root+mask; j<comm_size; j++)
717
                    blklens[1] += recvcounts[j];
718
719
720
721

                dis[0] = 0;
                dis[1] = blklens[0];
                for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
722
                    dis[1] += recvcounts[j];
723

724
725
726
727
728
                mpi_errno = MPIR_Type_indexed_impl(2, blklens, dis, datatype, &recvtype);
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                
                mpi_errno = MPIR_Type_commit_impl(&recvtype);
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
729
730
731
732
733
734
735

                received = 0;
                if (dst < comm_size) {
                    /* tmp_results contains data to be sent in each step. Data is
                       received in tmp_recvbuf and then accumulated into
                       tmp_results. accumulation is done later below.   */ 

736
                    mpi_errno = MPIC_Sendrecv(tmp_results, 1, sendtype, dst,
737
738
739
740
                                                 MPIR_REDUCE_SCATTER_TAG, 
                                                 tmp_recvbuf, 1, recvtype, dst,
                                                 MPIR_REDUCE_SCATTER_TAG, comm,
                                                 MPI_STATUS_IGNORE, errflag);
741
                    received = 1;
742
743
                    if (mpi_errno) {
                        /* for communication errors, just record the error but continue */
744
                        *errflag = TRUE;
745
746
747
                        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                        MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                    }
748
                }
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768

                /* if some processes in this process's subtree in this step
                   did not have any destination process to communicate with
                   because of non-power-of-two, we need to send them the
                   result. We use a logarithmic recursive-halfing algorithm
                   for this. */

                if (dst_tree_root + mask > comm_size) {
                    nprocs_completed = comm_size - my_tree_root - mask;
                    /* nprocs_completed is the number of processes in this
                       subtree that have all the data. Send data to others
                       in a tree fashion. First find root of current tree
                       that is being divided into two. k is the number of
                       least-significant bits in this process's rank that
                       must be zeroed out to find the rank of the root */ 
                    j = mask;
                    k = 0;
                    while (j) {
                        j >>= 1;
                        k++;
769
770
                    }
                    k--;
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785

                    tmp_mask = mask >> 1;
                    while (tmp_mask) {
                        dst = rank ^ tmp_mask;

                        tree_root = rank >> k;
                        tree_root <<= k;

                        /* send only if this proc has data and destination
                           doesn't have data. at any step, multiple processes
                           can send if they have the data */
                        if ((dst > rank) && 
                            (rank < tree_root + nprocs_completed)
                            && (dst >= tree_root + nprocs_completed)) {
                            /* send the current result */
786
                            mpi_errno = MPIC_Send(tmp_recvbuf, 1, recvtype,
787
788
789
790
791
792
793
794
                                                     dst, MPIR_REDUCE_SCATTER_TAG,
                                                     comm, errflag);
                            if (mpi_errno) {
                                /* for communication errors, just record the error but continue */
                                *errflag = TRUE;
                                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                            }
795
796
797
798
799
800
                        }
                        /* recv only if this proc. doesn't have data and sender
                           has data */
                        else if ((dst < rank) && 
                                 (dst < tree_root + nprocs_completed) &&
                                 (rank >= tree_root + nprocs_completed)) {
801
                            mpi_errno = MPIC_Recv(tmp_recvbuf, 1, recvtype, dst,
802
803
                                                     MPIR_REDUCE_SCATTER_TAG,
                                                     comm, MPI_STATUS_IGNORE, errflag); 
804
                            received = 1;
805
806
                            if (mpi_errno) {
                                /* for communication errors, just record the error but continue */
807
                                *errflag = TRUE;
808
809
810
                                MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
                                MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
                            }
811
812
813
814
                        }
                        tmp_mask >>= 1;
                        k--;
                    }
815
                }
816
817

                /* The following reduction is done here instead of after 
818
                   the MPIC_Sendrecv or MPIC_Recv above. This is
819
820
821
822
823
824
825
826
827
                   because to do it above, in the noncommutative 
                   case, we would need an extra temp buffer so as not to
                   overwrite temp_recvbuf, because temp_recvbuf may have
                   to be communicated to other processes in the
                   non-power-of-two case. To avoid that extra allocation,
                   we do the reduce here. */
                if (received) {
                    if (is_commutative || (dst_tree_root < my_tree_root)) {
                        {
828
829
830
831
832
833
834
			    mpi_errno = MPIR_Reduce_local_impl( 
                               tmp_recvbuf, tmp_results, blklens[0],
			       datatype, op); 
			    mpi_errno = MPIR_Reduce_local_impl( 
                               ((char *)tmp_recvbuf + dis[1]*extent),
			       ((char *)tmp_results + dis[1]*extent),
			       blklens[1], datatype, op); 
835
                        }
836
                    }
837
838
                    else {
                        {
839
840
841
842
843
			    mpi_errno = MPIR_Reduce_local_impl(
                                   tmp_results, tmp_recvbuf, blklens[0],
                                   datatype, op); 
			    mpi_errno = MPIR_Reduce_local_impl(
                                   ((char *)tmp_results + dis[1]*extent),
844
                                   ((char *)tmp_recvbuf + dis[1]*extent),
845
                                   blklens[1], datatype, op); 
846
847
848
849
850
                        }
                        /* copy result back into tmp_results */
                        mpi_errno = MPIR_Localcopy(tmp_recvbuf, 1, recvtype, 
                                                   tmp_results, 1, recvtype);
                        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
851
852
                    }
                }
853

854
855
                MPIR_Type_free_impl(&sendtype);
                MPIR_Type_free_impl(&recvtype);
856
857
858

                mask <<= 1;
                i++;
859
            }
860
861
862

            /* now copy final results from tmp_results to recvbuf */
            mpi_errno = MPIR_Localcopy(((char *)tmp_results+disps[rank]*extent),
863
864
                                       recvcounts[rank], datatype, recvbuf,
                                       recvcounts[rank], datatype);
865
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
866
        }
867
868
869
870
871
    }

fn_exit:
    MPIU_CHKLMEM_FREEALL();

872
873
874
875
876
877
    /* check if multiple threads are calling this collective function */
    MPIDU_ERR_CHECK_MULTIPLE_THREADS_EXIT( comm_ptr );

    if (MPIU_THREADPRIV_FIELD(op_errno)) 
	mpi_errno = MPIU_THREADPRIV_FIELD(op_errno);

878
879
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
880
881
    else if (*errflag)
        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**coll_fail");
882
    return mpi_errno;
883
884
fn_fail:
    goto fn_exit;
885
}
886

887

888
#undef FUNCNAME
889
#define FUNCNAME MPIR_Reduce_scatter_inter
890
891
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
892

893
/* not declared static because a machine-specific function may call this one in some cases */
894
895
896
int MPIR_Reduce_scatter_inter(const void *sendbuf, void *recvbuf, const int recvcounts[],
                              MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr,
                              int *errflag)
897
898
899
900
901
902
903
904
{
/* Intercommunicator Reduce_scatter.
   We first do an intercommunicator reduce to rank 0 on left group,
   then an intercommunicator reduce to rank 0 on right group, followed
   by local intracommunicator scattervs in each group.
*/
    
    int rank, mpi_errno, root, local_size, total_count, i;
905
    int mpi_errno_ret = MPI_SUCCESS;
906
907
908
909
    MPI_Aint true_extent, true_lb = 0, extent;
    void *tmp_buf=NULL;
    int *disps=NULL;
    MPID_Comm *newcomm_ptr = NULL;
910
    MPIU_CHKLMEM_DECL(2);
911
912
913
914
915

    rank = comm_ptr->rank;
    local_size = comm_ptr->local_size;

    total_count = 0;
916
    for (i=0; i<local_size; i++) total_count += recvcounts[i];
917
918
919
920

    if (rank == 0) {
        /* In each group, rank 0 allocates a temp. buffer for the 
           reduce */
921
922
        
        MPIU_CHKLMEM_MALLOC(disps, int *, local_size*sizeof(int), mpi_errno, "disps");
923
924
925
926

        total_count = 0;
        for (i=0; i<local_size; i++) {
            disps[i] = total_count;
927
            total_count += recvcounts[i];
928
929
        }

930
        MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
931
932
        MPID_Datatype_get_extent_macro(datatype, extent);

933
934
        MPIU_CHKLMEM_MALLOC(tmp_buf, void *, total_count*(MPIR_MAX(extent,true_extent)), mpi_errno, "tmp_buf");

935
936
937
938
939
940
941
942
943
944
        /* adjust for potential negative lower bound in datatype */
        tmp_buf = (void *)((char*)tmp_buf - true_lb);
    }

    /* first do a reduce from right group to rank 0 in left group,
       then from left group to rank 0 in right group*/
    if (comm_ptr->is_low_group) {
        /* reduce from right group to rank 0*/
        root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
945
                                root, comm_ptr, errflag);
946
947
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
948
            *errflag = TRUE;
949
950
951
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
952
953
954
955

        /* reduce to rank 0 of right group */
        root = 0;
        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
956
                                root, comm_ptr, errflag);
957
958
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
959
            *errflag = TRUE;
960
961
962
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
963
964
965
966
967
    }
    else {
        /* reduce to rank 0 of left group */
        root = 0;
        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
968
                                root, comm_ptr, errflag);
969
970
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
971
            *errflag = TRUE;
972
973
974
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
975
976
977
978

        /* reduce from right group to rank 0 */
        root = (rank == 0) ? MPI_ROOT : MPI_PROC_NULL;
        mpi_errno = MPIR_Reduce_inter(sendbuf, tmp_buf, total_count, datatype, op,
979
                                root, comm_ptr, errflag);
980
981
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
982
            *errflag = TRUE;
983
984
985
            MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
            MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
986
987
988
    }

    /* Get the local intracommunicator */
989
990
991
992
    if (!comm_ptr->local_comm) {
	mpi_errno = MPIR_Setup_intercomm_localcomm( comm_ptr );
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
    }
993
994
995

    newcomm_ptr = comm_ptr->local_comm;

996
997
    mpi_errno = MPIR_Scatterv(tmp_buf, recvcounts, disps, datatype, recvbuf,
                              recvcounts[rank], datatype, 0, newcomm_ptr, errflag);
998
999
    if (mpi_errno) {
        /* for communication errors, just record the error but continue */
1000
        *errflag = TRUE;
1001
1002
1003
        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**fail");
        MPIU_ERR_ADD(mpi_errno_ret, mpi_errno);
    }
1004
    
1005
1006
 fn_exit:
    MPIU_CHKLMEM_FREEALL();
1007
1008
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
1009
1010
    else if (*errflag)
        MPIU_ERR_SET(mpi_errno, MPI_ERR_OTHER, "**coll_fail");
1011
1012
1013
1014
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}
1015

1016
1017
1018
1019
1020
1021

/* MPIR_Reduce_Scatter performs an reduce_scatter using point-to-point
   messages.  This is intended to be used by device-specific
   implementations of reduce_scatter.  In all other cases
   MPIR_Reduce_Scatter_impl should be used. */
#undef FUNCNAME
1022
#define FUNCNAME MPIR_Reduce_scatter
1023
1024
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
1025
int MPIR_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[],
1026
                        MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr, int *errflag)
1027
1028
1029
1030
1031
{
    int mpi_errno = MPI_SUCCESS;
        
    if (comm_ptr->comm_kind == MPID_INTRACOMM) {
        /* intracommunicator */
1032
        mpi_errno = MPIR_Reduce_scatter_intra(sendbuf, recvbuf, recvcounts,
1033
                                              datatype, op, comm_ptr, errflag);
1034
1035
1036
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
    } else {
        /* intercommunicator */
1037
        mpi_errno = MPIR_Reduce_scatter_inter(sendbuf, recvbuf, recvcounts,
1038
                                              datatype, op, comm_ptr, errflag);
1039
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
1040
1041
    }

1042
 fn_exit:
1043
    return mpi_errno;
1044
1045
 fn_fail:
    goto fn_exit;
1046
1047
}

1048
1049
1050
1051
/* MPIR_Reduce_Scatter_impl should be called by any internal component
   that would otherwise call MPI_Reduce_Scatter.  This differs from
   MPIR_Reduce_Scatter in that this will call the coll_fns version if
   it exists.  This function replaces NMPI_Reduce_Scatter. */
1052
#undef FUNCNAME
1053
#define FUNCNAME MPIR_Reduce_scatter_impl
1054
1055
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
1056
int MPIR_Reduce_scatter_impl(const void *sendbuf, void *recvbuf, const int recvcounts[],
1057
                             MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr, int *errflag)
1058
1059
1060
{
    int mpi_errno = MPI_SUCCESS;
        
1061
1062
1063
    if (comm_ptr->coll_fns != NULL && 
	comm_ptr->coll_fns->Reduce_scatter != NULL) {
	/* --BEGIN USEREXTENSION-- */
1064
	mpi_errno = comm_ptr->coll_fns->Reduce_scatter(sendbuf, recvbuf, recvcounts,
1065
                                                       datatype, op, comm_ptr, errflag);
1066
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
1067
	/* --END USEREXTENSION-- */
1068
    } else {
1069
        mpi_errno = MPIR_Reduce_scatter(sendbuf, recvbuf, recvcounts,
1070
                                        datatype, op, comm_ptr, errflag);
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
    }
    
 fn_exit:
    return mpi_errno;
 fn_fail:
    goto fn_exit;
}

#endif

1082
1083
1084
1085
#undef FUNCNAME
#define FUNCNAME MPI_Reduce_scatter
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
/*@

MPI_Reduce_scatter - Combines values and scatters the results

Input Parameters:
+ sendbuf - starting address of send buffer (choice) 
. recvcounts - integer array specifying the 
number of elements in result distributed to each process.
Array must be identical on all calling processes. 
. datatype - data type of elements of input buffer (handle) 
. op - operation (handle) 
- comm - communicator (handle) 

Output Parameters:
. recvbuf - starting address of receive buffer (choice)

.N ThreadSafe

.N Fortran

.N collops

.N Errors
.N MPI_SUCCESS
.N MPI_ERR_COMM
.N MPI_ERR_COUNT
.N MPI_ERR_TYPE
.N MPI_ERR_BUFFER
.N MPI_ERR_OP
.N MPI_ERR_BUFFER_ALIAS
@*/
1117
int MPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int recvcounts[],
1118
1119
1120
1121
		       MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
    int mpi_errno = MPI_SUCCESS;
    MPID_Comm *comm_ptr = NULL;
1122
    int errflag = FALSE;
1123
1124
1125
1126
    MPID_MPI_STATE_DECL(MPID_STATE_MPI_REDUCE_SCATTER);

    MPIR_ERRTEST_INITIALIZED_ORDIE();
    
1127
    MPIU_THREAD_CS_ENTER(ALLFUNC,);
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
    MPID_MPI_COLL_FUNC_ENTER(MPID_STATE_MPI_REDUCE_SCATTER);

    /* Validate parameters, especially handles needing to be converted */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
	    MPIR_ERRTEST_COMM(comm, mpi_errno);
	}
        MPID_END_ERROR_CHECKS;
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPID_Comm_get_ptr( comm, comm_ptr );

    /* Validate parameters and objects (post conversion) */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS;
        {
	    MPID_Datatype *datatype_ptr = NULL;
            MPID_Op *op_ptr = NULL;
            int i, size, sum;
	    
            MPID_Comm_valid_ptr( comm_ptr, mpi_errno );
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;

            size = comm_ptr->local_size; 
1157
            /* even in intercomm. case, recvcounts is of size local_size */
1158
1159
1160

            sum = 0;
	    for (i=0; i<size; i++) {
1161
1162
		MPIR_ERRTEST_COUNT(recvcounts[i],mpi_errno);
                sum += recvcounts[i];
1163
1164
1165
1166
1167
1168
	    }

	    MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
            if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) {
                MPID_Datatype_get_ptr(datatype, datatype_ptr);
                MPID_Datatype_valid_ptr( datatype_ptr, mpi_errno );
1169
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
1170
                MPID_Datatype_committed_ptr( datatype_ptr, mpi_errno );
1171
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
1172
1173
            }

1174
            MPIR_ERRTEST_RECVBUF_INPLACE(recvbuf, recvcounts[comm_ptr->rank], mpi_errno);
1175
            if (comm_ptr->comm_kind == MPID_INTERCOMM) {
1176
                MPIR_ERRTEST_SENDBUF_INPLACE(sendbuf, sum, mpi_errno);
1177
            } else if (sendbuf != MPI_IN_PLACE && sum != 0)
1178
                MPIR_ERRTEST_ALIAS_COLL(sendbuf, recvbuf, mpi_errno)
1179

1180
            MPIR_ERRTEST_USERBUFFER(recvbuf,recvcounts[comm_ptr->rank],datatype,mpi_errno);
1181
1182
1183
1184
1185
1186
1187
1188
1189
<