Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' into hypervisor
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.c
1 #include "colls_private.h"
2 #define MPIR_REDUCE_SCATTER_TAG 222
3
4 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
5 {
6     /* a mask for the high order bits that should be copied as-is */
7     int high_mask = ~((0x1 << bits) - 1);
8     int retval = x & high_mask;
9     int i;
10
11     for (i = 0; i < bits; ++i) {
12         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
13         retval |= bitval << ((bits - i) - 1);
14     }
15
16     return retval;
17 }
18
19
20 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
21                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
22 {
23     int   rank, comm_size, i;
24     MPI_Aint extent, true_extent, true_lb; 
25     int  *disps;
26     void *tmp_recvbuf;
27     int mpi_errno = MPI_SUCCESS;
28     int total_count, dst, src;
29     int is_commutative;
30     comm_size = smpi_comm_size(comm);
31     rank = smpi_comm_rank(comm);
32
33     extent =smpi_datatype_get_extent(datatype);
34     smpi_datatype_extent(datatype, &true_lb, &true_extent);
35     
36     if (smpi_op_is_commute(op)) {
37         is_commutative = 1;
38     }
39
40     disps = (int*)xbt_malloc( comm_size * sizeof(int));
41
42     total_count = 0;
43     for (i=0; i<comm_size; i++) {
44         disps[i] = total_count;
45         total_count += recvcounts[i];
46     }
47     
48     if (total_count == 0) {
49         return MPI_ERR_COUNT;
50     }
51
52         if (sendbuf != MPI_IN_PLACE) {
53             /* copy local data into recvbuf */
54             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
55                                        recvcounts[rank], datatype, recvbuf,
56                                        recvcounts[rank], datatype);
57         }
58         
59         /* allocate temporary buffer to store incoming data */
60         tmp_recvbuf = (void*)xbt_malloc(recvcounts[rank]*(max(true_extent,extent))+1);
61         /* adjust for potential negative lower bound in datatype */
62         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
63         
64         for (i=1; i<comm_size; i++) {
65             src = (rank - i + comm_size) % comm_size;
66             dst = (rank + i) % comm_size;
67             
68             /* send the data that dst needs. recv data that this process
69                needs from src into tmp_recvbuf */
70             if (sendbuf != MPI_IN_PLACE) 
71                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
72                                              recvcounts[dst], datatype, dst,
73                                              MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
74                                              recvcounts[rank], datatype, src,
75                                              MPIR_REDUCE_SCATTER_TAG, comm,
76                                              MPI_STATUS_IGNORE);
77             else
78                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
79                                              recvcounts[dst], datatype, dst,
80                                              MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
81                                              recvcounts[rank], datatype, src,
82                                              MPIR_REDUCE_SCATTER_TAG, comm,
83                                              MPI_STATUS_IGNORE);
84             
85             if (is_commutative || (src < rank)) {
86                 if (sendbuf != MPI_IN_PLACE) {
87                      smpi_op_apply( op,
88                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
89                                &datatype); 
90                 }
91                 else {
92                     smpi_op_apply(op, 
93                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
94                         &recvcounts[rank], &datatype);
95                     /* we can't store the result at the beginning of
96                        recvbuf right here because there is useful data
97                        there that other process/processes need. at the
98                        end, we will copy back the result to the
99                        beginning of recvbuf. */
100                 }
101             }
102             else {
103                 if (sendbuf != MPI_IN_PLACE) {
104                     smpi_op_apply(op, 
105                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
106                     /* copy result back into recvbuf */
107                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
108                                                datatype, recvbuf,
109                                                recvcounts[rank], datatype);
110                     if (mpi_errno) return(mpi_errno);
111                 }
112                 else {
113                     smpi_op_apply(op, 
114                         ((char *)recvbuf+disps[rank]*extent),
115                         tmp_recvbuf, &recvcounts[rank], &datatype);
116                     /* copy result back into recvbuf */
117                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
118                                                datatype, 
119                                                ((char *)recvbuf +
120                                                 disps[rank]*extent), 
121                                                recvcounts[rank], datatype);
122                     if (mpi_errno) return(mpi_errno);
123                 }
124             }
125         }
126         
127         /* if MPI_IN_PLACE, move output data to the beginning of
128            recvbuf. already done for rank 0. */
129         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
130             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
131                                         disps[rank]*extent),  
132                                        recvcounts[rank], datatype,
133                                        recvbuf, 
134                                        recvcounts[rank], datatype );
135             if (mpi_errno) return(mpi_errno);
136         }
137     
138 return MPI_SUCCESS;
139 }
140     
141
142 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
143                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
144 {
145     int mpi_errno = MPI_SUCCESS;
146     int comm_size = smpi_comm_size(comm) ;
147     int rank = smpi_comm_rank(comm);
148     int pof2;
149     int log2_comm_size;
150     int i, k;
151     int recv_offset, send_offset;
152     int block_size, total_count, size;
153     MPI_Aint true_extent, true_lb;
154     int buf0_was_inout;
155     void *tmp_buf0;
156     void *tmp_buf1;
157     void *result_ptr;
158
159     smpi_datatype_extent(datatype, &true_lb, &true_extent);
160
161     pof2 = 1;
162     log2_comm_size = 0;
163     while (pof2 < comm_size) {
164         pof2 <<= 1;
165         ++log2_comm_size;
166     }
167
168     /* begin error checking */
169     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
170
171     for (i = 0; i < (comm_size - 1); ++i) {
172         xbt_assert(recvcounts[i] == recvcounts[i+1]);
173     }
174     /* end error checking */
175
176     /* size of a block (count of datatype per block, NOT bytes per block) */
177     block_size = recvcounts[0];
178     total_count = block_size * comm_size;
179
180     tmp_buf0=( void *)xbt_malloc( true_extent * total_count);
181     tmp_buf1=( void *)xbt_malloc( true_extent * total_count);
182     /* adjust for potential negative lower bound in datatype */
183     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
184     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
185
186     /* Copy our send data to tmp_buf0.  We do this one block at a time and
187        permute the blocks as we go according to the mirror permutation. */
188     for (i = 0; i < comm_size; ++i) {
189         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
190                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
191         if (mpi_errno) return(mpi_errno);
192     }
193     buf0_was_inout = 1;
194
195     send_offset = 0;
196     recv_offset = 0;
197     size = total_count;
198     for (k = 0; k < log2_comm_size; ++k) {
199         /* use a double-buffering scheme to avoid local copies */
200         char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
201         char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
202         int peer = rank ^ (0x1 << k);
203         size /= 2;
204
205         if (rank > peer) {
206             /* we have the higher rank: send top half, recv bottom half */
207             recv_offset += size;
208         }
209         else {
210             /* we have the lower rank: recv top half, send bottom half */
211             send_offset += size;
212         }
213
214         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
215                                      size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
216                                      incoming_data + recv_offset*true_extent,
217                                      size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
218                                      comm, MPI_STATUS_IGNORE);
219         /* always perform the reduction at recv_offset, the data at send_offset
220            is now our peer's responsibility */
221         if (rank > peer) {
222             /* higher ranked value so need to call op(received_data, my_data) */
223             smpi_op_apply(op, 
224                    incoming_data + recv_offset*true_extent,
225                      outgoing_data + recv_offset*true_extent,
226                      &size, &datatype );
227             /* buf0_was_inout = buf0_was_inout; */
228         }
229         else {
230             /* lower ranked value so need to call op(my_data, received_data) */
231             smpi_op_apply( op,
232                      outgoing_data + recv_offset*true_extent,
233                      incoming_data + recv_offset*true_extent,
234                      &size, &datatype);
235             buf0_was_inout = !buf0_was_inout;
236         }
237
238         /* the next round of send/recv needs to happen within the block (of size
239            "size") that we just received and reduced */
240         send_offset = recv_offset;
241     }
242
243     xbt_assert(size == recvcounts[rank]);
244
245     /* copy the reduced data to the recvbuf */
246     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
247     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
248                                recvbuf, size, datatype);
249     if (mpi_errno) return(mpi_errno);
250     return MPI_SUCCESS;
251 }
252
253
254
255 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
256                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
257 {
258     int   rank, comm_size, i;
259     MPI_Aint extent, true_extent, true_lb; 
260     int  *disps;
261     void *tmp_recvbuf, *tmp_results;
262     int mpi_errno = MPI_SUCCESS;
263     int dis[2], blklens[2], total_count, dst;
264     int mask, dst_tree_root, my_tree_root, j, k;
265     int received;
266     MPI_Datatype sendtype, recvtype;
267     int nprocs_completed, tmp_mask, tree_root, is_commutative;
268     comm_size = smpi_comm_size(comm);
269     rank = smpi_comm_rank(comm);
270
271     extent =smpi_datatype_get_extent(datatype);
272     smpi_datatype_extent(datatype, &true_lb, &true_extent);
273     
274     if (smpi_op_is_commute(op)) {
275         is_commutative = 1;
276     }
277
278     disps = (int*)xbt_malloc( comm_size * sizeof(int));
279
280     total_count = 0;
281     for (i=0; i<comm_size; i++) {
282         disps[i] = total_count;
283         total_count += recvcounts[i];
284     }
285     
286             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
287
288             /* need to allocate temporary buffer to receive incoming data*/
289             tmp_recvbuf= (void *) xbt_malloc( total_count*(max(true_extent,extent)));
290             /* adjust for potential negative lower bound in datatype */
291             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
292
293             /* need to allocate another temporary buffer to accumulate
294                results */
295             tmp_results = (void *)xbt_malloc( total_count*(max(true_extent,extent)));
296             /* adjust for potential negative lower bound in datatype */
297             tmp_results = (void *)((char*)tmp_results - true_lb);
298
299             /* copy sendbuf into tmp_results */
300             if (sendbuf != MPI_IN_PLACE)
301                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
302                                            tmp_results, total_count, datatype);
303             else
304                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
305                                            tmp_results, total_count, datatype);
306
307             if (mpi_errno) return(mpi_errno);
308
309             mask = 0x1;
310             i = 0;
311             while (mask < comm_size) {
312                 dst = rank ^ mask;
313
314                 dst_tree_root = dst >> i;
315                 dst_tree_root <<= i;
316
317                 my_tree_root = rank >> i;
318                 my_tree_root <<= i;
319
320                 /* At step 1, processes exchange (n-n/p) amount of
321                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
322                    amount of data, and so forth. We use derived datatypes for this.
323
324                    At each step, a process does not need to send data
325                    indexed from my_tree_root to
326                    my_tree_root+mask-1. Similarly, a process won't receive
327                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
328
329                 /* calculate sendtype */
330                 blklens[0] = blklens[1] = 0;
331                 for (j=0; j<my_tree_root; j++)
332                     blklens[0] += recvcounts[j];
333                 for (j=my_tree_root+mask; j<comm_size; j++)
334                     blklens[1] += recvcounts[j];
335
336                 dis[0] = 0;
337                 dis[1] = blklens[0];
338                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
339                     dis[1] += recvcounts[j];
340
341                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
342                 if (mpi_errno) return(mpi_errno);
343                 
344                 smpi_datatype_commit(&sendtype);
345
346                 /* calculate recvtype */
347                 blklens[0] = blklens[1] = 0;
348                 for (j=0; j<dst_tree_root && j<comm_size; j++)
349                     blklens[0] += recvcounts[j];
350                 for (j=dst_tree_root+mask; j<comm_size; j++)
351                     blklens[1] += recvcounts[j];
352
353                 dis[0] = 0;
354                 dis[1] = blklens[0];
355                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
356                     dis[1] += recvcounts[j];
357
358                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
359                 if (mpi_errno) return(mpi_errno);
360                 
361                 smpi_datatype_commit(&recvtype);
362
363                 received = 0;
364                 if (dst < comm_size) {
365                     /* tmp_results contains data to be sent in each step. Data is
366                        received in tmp_recvbuf and then accumulated into
367                        tmp_results. accumulation is done later below.   */ 
368
369                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
370                                                  MPIR_REDUCE_SCATTER_TAG, 
371                                                  tmp_recvbuf, 1, recvtype, dst,
372                                                  MPIR_REDUCE_SCATTER_TAG, comm,
373                                                  MPI_STATUS_IGNORE);
374                     received = 1;
375                 }
376
377                 /* if some processes in this process's subtree in this step
378                    did not have any destination process to communicate with
379                    because of non-power-of-two, we need to send them the
380                    result. We use a logarithmic recursive-halfing algorithm
381                    for this. */
382
383                 if (dst_tree_root + mask > comm_size) {
384                     nprocs_completed = comm_size - my_tree_root - mask;
385                     /* nprocs_completed is the number of processes in this
386                        subtree that have all the data. Send data to others
387                        in a tree fashion. First find root of current tree
388                        that is being divided into two. k is the number of
389                        least-significant bits in this process's rank that
390                        must be zeroed out to find the rank of the root */ 
391                     j = mask;
392                     k = 0;
393                     while (j) {
394                         j >>= 1;
395                         k++;
396                     }
397                     k--;
398
399                     tmp_mask = mask >> 1;
400                     while (tmp_mask) {
401                         dst = rank ^ tmp_mask;
402
403                         tree_root = rank >> k;
404                         tree_root <<= k;
405
406                         /* send only if this proc has data and destination
407                            doesn't have data. at any step, multiple processes
408                            can send if they have the data */
409                         if ((dst > rank) && 
410                             (rank < tree_root + nprocs_completed)
411                             && (dst >= tree_root + nprocs_completed)) {
412                             /* send the current result */
413                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
414                                                      dst, MPIR_REDUCE_SCATTER_TAG,
415                                                      comm);
416                         }
417                         /* recv only if this proc. doesn't have data and sender
418                            has data */
419                         else if ((dst < rank) && 
420                                  (dst < tree_root + nprocs_completed) &&
421                                  (rank >= tree_root + nprocs_completed)) {
422                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
423                                                      MPIR_REDUCE_SCATTER_TAG,
424                                                      comm, MPI_STATUS_IGNORE); 
425                             received = 1;
426                         }
427                         tmp_mask >>= 1;
428                         k--;
429                     }
430                 }
431
432                 /* The following reduction is done here instead of after 
433                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
434                    because to do it above, in the noncommutative 
435                    case, we would need an extra temp buffer so as not to
436                    overwrite temp_recvbuf, because temp_recvbuf may have
437                    to be communicated to other processes in the
438                    non-power-of-two case. To avoid that extra allocation,
439                    we do the reduce here. */
440                 if (received) {
441                     if (is_commutative || (dst_tree_root < my_tree_root)) {
442                         {
443                                  smpi_op_apply(op, 
444                                tmp_recvbuf, tmp_results, &blklens[0],
445                                &datatype); 
446                                 smpi_op_apply(op, 
447                                ((char *)tmp_recvbuf + dis[1]*extent),
448                                ((char *)tmp_results + dis[1]*extent),
449                                &blklens[1], &datatype); 
450                         }
451                     }
452                     else {
453                         {
454                                  smpi_op_apply(op,
455                                    tmp_results, tmp_recvbuf, &blklens[0],
456                                    &datatype); 
457                                  smpi_op_apply(op,
458                                    ((char *)tmp_results + dis[1]*extent),
459                                    ((char *)tmp_recvbuf + dis[1]*extent),
460                                    &blklens[1], &datatype); 
461                         }
462                         /* copy result back into tmp_results */
463                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
464                                                    tmp_results, 1, recvtype);
465                         if (mpi_errno) return(mpi_errno);
466                     }
467                 }
468
469                 //smpi_datatype_free(&sendtype);
470                 //smpi_datatype_free(&recvtype);
471
472                 mask <<= 1;
473                 i++;
474             }
475
476             /* now copy final results from tmp_results to recvbuf */
477             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
478                                        recvcounts[rank], datatype, recvbuf,
479                                        recvcounts[rank], datatype);
480             if (mpi_errno) return(mpi_errno);
481
482     return MPI_SUCCESS;
483         }
484
485