/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
/*
 *  (C) 2010 by Argonne National Laboratory.
 *      See COPYRIGHT in top-level directory.
 */

#include "mpiimpl.h"
8
#include "collutil.h"
/* -- Begin Profiling Symbol Block for routine MPI_Iexscan */
/* Make MPI_Iexscan an alias of PMPI_Iexscan using whichever weak/secondary
   symbol mechanism the compiler supports. */
#if defined(HAVE_PRAGMA_WEAK)
#pragma weak MPI_Iexscan = PMPI_Iexscan
#elif defined(HAVE_PRAGMA_HP_SEC_DEF)
#pragma _HP_SECONDARY_DEF PMPI_Iexscan  MPI_Iexscan
#elif defined(HAVE_PRAGMA_CRI_DUP)
#pragma _CRI duplicate MPI_Iexscan as PMPI_Iexscan
#endif
/* -- End Profiling Symbol Block */

/* Define MPICH_MPI_FROM_PMPI if weak symbols are not supported to build
   the MPI routines */
#ifndef MPICH_MPI_FROM_PMPI
/* Compile the MPI_Iexscan definition later in this file under the name
   PMPI_Iexscan. */
#undef MPI_Iexscan
#define MPI_Iexscan PMPI_Iexscan

/* any non-MPI functions go here, especially non-static ones */

/* This is the default implementation of exscan. The algorithm is:

30
   Algorithm: MPI_Iexscan
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

   We use a lgp recursive doubling algorithm. The basic algorithm is
   given below. (You can replace "+" with any other scan operator.)
   The result is stored in recvbuf.

 .vb
   partial_scan = sendbuf;
   mask = 0x1;
   flag = 0;
   while (mask < size) {
      dst = rank^mask;
      if (dst < size) {
         send partial_scan to dst;
         recv from dst into tmp_buf;
         if (rank > dst) {
            partial_scan = tmp_buf + partial_scan;
            if (rank != 0) {
               if (flag == 0) {
                   recv_buf = tmp_buf;
                   flag = 1;
               }
               else
                   recv_buf = tmp_buf + recvbuf;
            }
         }
         else {
            if (op is commutative)
               partial_scan = tmp_buf + partial_scan;
            else {
               tmp_buf = partial_scan + tmp_buf;
               partial_scan = tmp_buf;
            }
         }
      }
      mask <<= 1;
   }
.ve

   End Algorithm: MPI_Exscan
*/
#undef FUNCNAME
#define FUNCNAME MPIR_Iexscan
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
75
int MPIR_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr, MPID_Sched_t s)
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
{
    int mpi_errno = MPI_SUCCESS;
    int rank, comm_size;
    int mask, dst, is_commutative, flag;
    MPI_Aint true_extent, true_lb, extent;
    void *partial_scan, *tmp_buf;
    MPIR_SCHED_CHKPMEM_DECL(2);

    if (count == 0)
        goto fn_exit;

    comm_size = comm_ptr->local_size;
    rank = comm_ptr->rank;

    is_commutative = MPIR_Op_is_commutative(op);

    /* need to allocate temporary buffer to store partial scan*/
    MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
    MPID_Datatype_get_extent_macro(datatype, extent);

    MPIR_SCHED_CHKPMEM_MALLOC(partial_scan, void *, (count*(MPIR_MAX(true_extent,extent))), mpi_errno, "partial_scan");
    /* adjust for potential negative lower bound in datatype */
    partial_scan = (void *)((char*)partial_scan - true_lb);

    /* need to allocate temporary buffer to store incoming data*/
    MPIR_SCHED_CHKPMEM_MALLOC(tmp_buf, void *, (count*(MPIR_MAX(true_extent,extent))), mpi_errno, "tmp_buf");
    /* adjust for potential negative lower bound in datatype */
    tmp_buf = (void *)((char*)tmp_buf - true_lb);

    mpi_errno = MPID_Sched_copy((sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf), count, datatype,
                               partial_scan, count, datatype, s);
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

    flag = 0;
    mask = 0x1;
    while (mask < comm_size) {
        dst = rank ^ mask;
        if (dst < comm_size) {
            /* Send partial_scan to dst. Recv into tmp_buf */
            mpi_errno = MPID_Sched_send(partial_scan, count, datatype, dst, comm_ptr, s);
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
            /* sendrecv, no barrier here */
            mpi_errno = MPID_Sched_recv(tmp_buf, count, datatype, dst, comm_ptr, s);
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
            MPID_SCHED_BARRIER(s);

            if (rank > dst) {
                mpi_errno = MPID_Sched_reduce(tmp_buf, partial_scan, count, datatype, op, s);
                if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                MPID_SCHED_BARRIER(s);

                /* On rank 0, recvbuf is not defined.  For sendbuf==MPI_IN_PLACE
                   recvbuf must not change (per MPI-2.2).
                   On rank 1, recvbuf is to be set equal to the value
                   in sendbuf on rank 0.
                   On others, recvbuf is the scan of values in the
                   sendbufs on lower ranks. */
                if (rank != 0) {
                    if (flag == 0) {
                        /* simply copy data recd from rank 0 into recvbuf */
                        mpi_errno = MPID_Sched_copy(tmp_buf, count, datatype,
                                                    recvbuf, count, datatype, s);
                        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                        MPID_SCHED_BARRIER(s);

                        flag = 1;
                    }
                    else {
                        mpi_errno = MPID_Sched_reduce(tmp_buf, recvbuf, count, datatype, op, s);
                        if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                        MPID_SCHED_BARRIER(s);
                    }
                }
            }
            else {
                if (is_commutative) {
                    mpi_errno = MPID_Sched_reduce(tmp_buf, partial_scan, count, datatype, op, s);
                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                    MPID_SCHED_BARRIER(s);
                }
                else {
                    mpi_errno = MPID_Sched_reduce(partial_scan, tmp_buf, count, datatype, op, s);
                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                    MPID_SCHED_BARRIER(s);

                    mpi_errno = MPID_Sched_copy(tmp_buf, count, datatype,
                                                partial_scan, count, datatype, s);
                    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
                    MPID_SCHED_BARRIER(s);
                }
            }
        }
        mask <<= 1;
    }

    MPIR_SCHED_CHKPMEM_COMMIT(s);
fn_exit:
    return mpi_errno;
fn_fail:
    MPIR_SCHED_CHKPMEM_REAP(s);
    goto fn_exit;
}

179
180
181
182
#undef FUNCNAME
#define FUNCNAME MPIR_Iexscan_impl
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
183
int MPIR_Iexscan_impl(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPID_Comm *comm_ptr, MPI_Request *request)
184
185
{
    int mpi_errno = MPI_SUCCESS;
186
    MPID_Request *reqp = NULL;
187
188
    int tag = -1;
    MPID_Sched_t s = MPID_SCHED_NULL;
189

190
191
    *request = MPI_REQUEST_NULL;

192
    MPIU_Assert(comm_ptr->coll_fns != NULL);
193
    if (comm_ptr->coll_fns->Iexscan_req != NULL) {
194
        /* --BEGIN USEREXTENSION-- */
195
        mpi_errno = comm_ptr->coll_fns->Iexscan_req(sendbuf, recvbuf, count, datatype, op, comm_ptr, &reqp);
196
197
198
199
200
201
202
203
        if (reqp) {
            *request = reqp->handle;
            if (mpi_errno) MPIU_ERR_POP(mpi_errno);
            goto fn_exit;
        }
        /* --END USEREXTENSION-- */
    }

204
    mpi_errno = MPID_Sched_next_tag(comm_ptr, &tag);
205
206
207
208
209
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
    mpi_errno = MPID_Sched_create(&s);
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

    MPIU_Assert(comm_ptr->coll_fns != NULL);
210
211
    MPIU_Assert(comm_ptr->coll_fns->Iexscan_sched != NULL);
    mpi_errno = comm_ptr->coll_fns->Iexscan_sched(sendbuf, recvbuf, count, datatype, op, comm_ptr, s);
212
213
214
215
216
217
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

    mpi_errno = MPID_Sched_start(&s, comm_ptr, tag, &reqp);
    if (reqp)
        *request = reqp->handle;
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);
218
219
220
221
222
223
224
225
226
227

fn_exit:
    return mpi_errno;
fn_fail:
    goto fn_exit;
}

#endif /* MPICH_MPI_FROM_PMPI */

#undef FUNCNAME
228
#define FUNCNAME MPI_Iexscan
229
230
231
#undef FCNAME
#define FCNAME MPIU_QUOTE(FUNCNAME)
/*@
232
MPI_Iexscan - XXX description here
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250

Input Parameters:
+ sendbuf - starting address of the send buffer (choice)
. count - number of elements in input buffer (non-negative integer)
. datatype - data type of elements of input buffer (handle)
. op - operation (handle)
- comm - communicator (handle)

Output Parameters:
+ recvbuf - starting address of the receive buffer (choice)
- request - communication request (handle)

.N ThreadSafe

.N Fortran

.N Errors
@*/
251
int MPI_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request)
252
253
254
{
    int mpi_errno = MPI_SUCCESS;
    MPID_Comm *comm_ptr = NULL;
255
    MPID_MPI_STATE_DECL(MPID_STATE_MPI_IEXSCAN);
256
257

    MPIU_THREAD_CS_ENTER(ALLFUNC,);
258
    MPID_MPI_FUNC_ENTER(MPID_STATE_MPI_IEXSCAN);
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282

    /* Validate parameters, especially handles needing to be converted */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS
        {
            MPIR_ERRTEST_DATATYPE(datatype, "datatype", mpi_errno);
            MPIR_ERRTEST_OP(op, mpi_errno);
            MPIR_ERRTEST_COMM(comm, mpi_errno);

            /* TODO more checks may be appropriate */
        }
        MPID_END_ERROR_CHECKS
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* Convert MPI object handles to object pointers */
    MPID_Comm_get_ptr(comm, comm_ptr);

    /* Validate parameters and objects (post conversion) */
#   ifdef HAVE_ERROR_CHECKING
    {
        MPID_BEGIN_ERROR_CHECKS
        {
283
            MPID_Comm_valid_ptr(comm_ptr, mpi_errno);
284
            MPIR_ERRTEST_COMM_INTRA(comm_ptr, mpi_errno);
285
286
287
288
            if (HANDLE_GET_KIND(datatype) != HANDLE_KIND_BUILTIN) {
                MPID_Datatype *datatype_ptr = NULL;
                MPID_Datatype_get_ptr(datatype, datatype_ptr);
                MPID_Datatype_valid_ptr(datatype_ptr, mpi_errno);
289
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
290
                MPID_Datatype_committed_ptr(datatype_ptr, mpi_errno);
291
                if (mpi_errno != MPI_SUCCESS) goto fn_fail;
292
293
294
295
296
297
298
299
            }

            if (HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) {
                MPID_Op *op_ptr = NULL;
                MPID_Op_get_ptr(op, op_ptr);
                MPID_Op_valid_ptr(op_ptr, mpi_errno);
            }
            else if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) {
300
                mpi_errno = ( * MPIR_OP_HDL_TO_DTYPE_FN(op) )(datatype);
301
            }
302
            if (mpi_errno != MPI_SUCCESS) goto fn_fail;
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318

            MPIR_ERRTEST_ARGNULL(request,"request", mpi_errno);
            /* TODO more checks may be appropriate (counts, in_place, buffer aliasing, etc) */
        }
        MPID_END_ERROR_CHECKS
    }
#   endif /* HAVE_ERROR_CHECKING */

    /* ... body of routine ...  */

    mpi_errno = MPIR_Iexscan_impl(sendbuf, recvbuf, count, datatype, op, comm_ptr, request);
    if (mpi_errno) MPIU_ERR_POP(mpi_errno);

    /* ... end of body of routine ... */

fn_exit:
319
    MPID_MPI_FUNC_EXIT(MPID_STATE_MPI_IEXSCAN);
320
321
322
323
324
325
326
327
328
    MPIU_THREAD_CS_EXIT(ALLFUNC,);
    return mpi_errno;

fn_fail:
    /* --BEGIN ERROR HANDLING-- */
#   ifdef HAVE_ERROR_CHECKING
    {
        mpi_errno = MPIR_Err_create_code(
            mpi_errno, MPIR_ERR_RECOVERABLE, FCNAME, __LINE__, MPI_ERR_OTHER,
329
            "**mpi_iexscan", "**mpi_iexscan %p %p %d %D %O %C %p", sendbuf, recvbuf, count, datatype, op, comm, request);
330
331
332
333
334
335
336
    }
#   endif
    mpi_errno = MPIR_Err_return_comm(comm_ptr, FCNAME, mpi_errno);
    goto fn_exit;
    /* --END ERROR HANDLING-- */
    goto fn_exit;
}