Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
b01fbd027969cd814ba2355dada7bfa6d75a4cfa
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
10 {
11     /* a mask for the high order bits that should be copied as-is */
12     int high_mask = ~((0x1 << bits) - 1);
13     int retval = x & high_mask;
14     int i;
15
16     for (i = 0; i < bits; ++i) {
17         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
18         retval |= bitval << ((bits - i) - 1);
19     }
20
21     return retval;
22 }
23
24
25 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
26                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
27 {
28     int   rank, comm_size, i;
29     MPI_Aint extent, true_extent, true_lb; 
30     int  *disps;
31     void *tmp_recvbuf;
32     int mpi_errno = MPI_SUCCESS;
33     int total_count, dst, src;
34     int is_commutative;
35     comm_size = smpi_comm_size(comm);
36     rank = smpi_comm_rank(comm);
37
38     extent =smpi_datatype_get_extent(datatype);
39     smpi_datatype_extent(datatype, &true_lb, &true_extent);
40     
41     if (smpi_op_is_commute(op)) {
42         is_commutative = 1;
43     }
44
45     disps = (int*)xbt_malloc( comm_size * sizeof(int));
46
47     total_count = 0;
48     for (i=0; i<comm_size; i++) {
49         disps[i] = total_count;
50         total_count += recvcounts[i];
51     }
52     
53     if (total_count == 0) {
54         xbt_free(disps);
55         return MPI_ERR_COUNT;
56     }
57
58         if (sendbuf != MPI_IN_PLACE) {
59             /* copy local data into recvbuf */
60             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
61                                        recvcounts[rank], datatype, recvbuf,
62                                        recvcounts[rank], datatype);
63         }
64         
65         /* allocate temporary buffer to store incoming data */
66         tmp_recvbuf = (void*)smpi_get_tmp_recvbuffer(recvcounts[rank]*(MAX(true_extent,extent))+1);
67         /* adjust for potential negative lower bound in datatype */
68         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
69         
70         for (i=1; i<comm_size; i++) {
71             src = (rank - i + comm_size) % comm_size;
72             dst = (rank + i) % comm_size;
73             
74             /* send the data that dst needs. recv data that this process
75                needs from src into tmp_recvbuf */
76             if (sendbuf != MPI_IN_PLACE) 
77                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
78                                              recvcounts[dst], datatype, dst,
79                                              COLL_TAG_SCATTER, tmp_recvbuf,
80                                              recvcounts[rank], datatype, src,
81                                              COLL_TAG_SCATTER, comm,
82                                              MPI_STATUS_IGNORE);
83             else
84                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
85                                              recvcounts[dst], datatype, dst,
86                                              COLL_TAG_SCATTER, tmp_recvbuf,
87                                              recvcounts[rank], datatype, src,
88                                              COLL_TAG_SCATTER, comm,
89                                              MPI_STATUS_IGNORE);
90             
91             if (is_commutative || (src < rank)) {
92                 if (sendbuf != MPI_IN_PLACE) {
93                      smpi_op_apply( op,
94                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
95                                &datatype); 
96                 }
97                 else {
98                     smpi_op_apply(op, 
99                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
100                         &recvcounts[rank], &datatype);
101                     /* we can't store the result at the beginning of
102                        recvbuf right here because there is useful data
103                        there that other process/processes need. at the
104                        end, we will copy back the result to the
105                        beginning of recvbuf. */
106                 }
107             }
108             else {
109                 if (sendbuf != MPI_IN_PLACE) {
110                     smpi_op_apply(op, 
111                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
112                     /* copy result back into recvbuf */
113                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
114                                                datatype, recvbuf,
115                                                recvcounts[rank], datatype);
116                     if (mpi_errno) return(mpi_errno);
117                 }
118                 else {
119                     smpi_op_apply(op, 
120                         ((char *)recvbuf+disps[rank]*extent),
121                         tmp_recvbuf, &recvcounts[rank], &datatype);
122                     /* copy result back into recvbuf */
123                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
124                                                datatype, 
125                                                ((char *)recvbuf +
126                                                 disps[rank]*extent), 
127                                                recvcounts[rank], datatype);
128                     if (mpi_errno) return(mpi_errno);
129                 }
130             }
131         }
132         
133         /* if MPI_IN_PLACE, move output data to the beginning of
134            recvbuf. already done for rank 0. */
135         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
136             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
137                                         disps[rank]*extent),  
138                                        recvcounts[rank], datatype,
139                                        recvbuf, 
140                                        recvcounts[rank], datatype );
141             if (mpi_errno) return(mpi_errno);
142         }
143     
144         xbt_free(disps);
145         smpi_free_tmp_buffer(tmp_recvbuf);
146
147         return MPI_SUCCESS;
148 }
149     
150
151 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
152                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
153 {
154     int mpi_errno = MPI_SUCCESS;
155     int comm_size = smpi_comm_size(comm) ;
156     int rank = smpi_comm_rank(comm);
157     int pof2;
158     int log2_comm_size;
159     int i, k;
160     int recv_offset, send_offset;
161     int block_size, total_count, size;
162     MPI_Aint true_extent, true_lb;
163     int buf0_was_inout;
164     void *tmp_buf0;
165     void *tmp_buf1;
166     void *result_ptr;
167
168     smpi_datatype_extent(datatype, &true_lb, &true_extent);
169
170     pof2 = 1;
171     log2_comm_size = 0;
172     while (pof2 < comm_size) {
173         pof2 <<= 1;
174         ++log2_comm_size;
175     }
176
177     /* begin error checking */
178     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
179
180     for (i = 0; i < (comm_size - 1); ++i) {
181         xbt_assert(recvcounts[i] == recvcounts[i+1]);
182     }
183     /* end error checking */
184
185     /* size of a block (count of datatype per block, NOT bytes per block) */
186     block_size = recvcounts[0];
187     total_count = block_size * comm_size;
188
189     tmp_buf0=( void *)smpi_get_tmp_sendbuffer( true_extent * total_count);
190     tmp_buf1=( void *)smpi_get_tmp_recvbuffer( true_extent * total_count);
191     void *tmp_buf0_save=tmp_buf0;
192     void *tmp_buf1_save=tmp_buf1;
193
194     /* adjust for potential negative lower bound in datatype */
195     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
196     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
197
198     /* Copy our send data to tmp_buf0.  We do this one block at a time and
199        permute the blocks as we go according to the mirror permutation. */
200     for (i = 0; i < comm_size; ++i) {
201         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
202                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
203         if (mpi_errno) return(mpi_errno);
204     }
205     buf0_was_inout = 1;
206
207     send_offset = 0;
208     recv_offset = 0;
209     size = total_count;
210     for (k = 0; k < log2_comm_size; ++k) {
211         /* use a double-buffering scheme to avoid local copies */
212         char *incoming_data = static_cast<char*>(buf0_was_inout ? tmp_buf1 : tmp_buf0);
213         char *outgoing_data = static_cast<char*>(buf0_was_inout ? tmp_buf0 : tmp_buf1);
214         int peer = rank ^ (0x1 << k);
215         size /= 2;
216
217         if (rank > peer) {
218             /* we have the higher rank: send top half, recv bottom half */
219             recv_offset += size;
220         }
221         else {
222             /* we have the lower rank: recv top half, send bottom half */
223             send_offset += size;
224         }
225
226         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
227                                      size, datatype, peer, COLL_TAG_SCATTER,
228                                      incoming_data + recv_offset*true_extent,
229                                      size, datatype, peer, COLL_TAG_SCATTER,
230                                      comm, MPI_STATUS_IGNORE);
231         /* always perform the reduction at recv_offset, the data at send_offset
232            is now our peer's responsibility */
233         if (rank > peer) {
234             /* higher ranked value so need to call op(received_data, my_data) */
235             smpi_op_apply(op, 
236                    incoming_data + recv_offset*true_extent,
237                      outgoing_data + recv_offset*true_extent,
238                      &size, &datatype );
239             /* buf0_was_inout = buf0_was_inout; */
240         }
241         else {
242             /* lower ranked value so need to call op(my_data, received_data) */
243             smpi_op_apply( op,
244                      outgoing_data + recv_offset*true_extent,
245                      incoming_data + recv_offset*true_extent,
246                      &size, &datatype);
247             buf0_was_inout = !buf0_was_inout;
248         }
249
250         /* the next round of send/recv needs to happen within the block (of size
251            "size") that we just received and reduced */
252         send_offset = recv_offset;
253     }
254
255     xbt_assert(size == recvcounts[rank]);
256
257     /* copy the reduced data to the recvbuf */
258     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
259     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
260                                recvbuf, size, datatype);
261     smpi_free_tmp_buffer(tmp_buf0_save);
262     smpi_free_tmp_buffer(tmp_buf1_save);
263     if (mpi_errno) return(mpi_errno);
264     return MPI_SUCCESS;
265 }
266
267
268
269 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
270                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
271 {
272     int   rank, comm_size, i;
273     MPI_Aint extent, true_extent, true_lb; 
274     int  *disps;
275     void *tmp_recvbuf, *tmp_results;
276     int mpi_errno = MPI_SUCCESS;
277     int dis[2], blklens[2], total_count, dst;
278     int mask, dst_tree_root, my_tree_root, j, k;
279     int received;
280     MPI_Datatype sendtype, recvtype;
281     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
282     comm_size = smpi_comm_size(comm);
283     rank = smpi_comm_rank(comm);
284
285     extent =smpi_datatype_get_extent(datatype);
286     smpi_datatype_extent(datatype, &true_lb, &true_extent);
287     
288     if (smpi_op_is_commute(op)) {
289         is_commutative = 1;
290     }
291
292     disps = (int*)xbt_malloc( comm_size * sizeof(int));
293
294     total_count = 0;
295     for (i=0; i<comm_size; i++) {
296         disps[i] = total_count;
297         total_count += recvcounts[i];
298     }
299     
300             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
301
302             /* need to allocate temporary buffer to receive incoming data*/
303             tmp_recvbuf= (void *) smpi_get_tmp_recvbuffer( total_count*(MAX(true_extent,extent)));
304             /* adjust for potential negative lower bound in datatype */
305             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
306
307             /* need to allocate another temporary buffer to accumulate
308                results */
309             tmp_results = (void *)smpi_get_tmp_sendbuffer( total_count*(MAX(true_extent,extent)));
310             /* adjust for potential negative lower bound in datatype */
311             tmp_results = (void *)((char*)tmp_results - true_lb);
312
313             /* copy sendbuf into tmp_results */
314             if (sendbuf != MPI_IN_PLACE)
315                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
316                                            tmp_results, total_count, datatype);
317             else
318                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
319                                            tmp_results, total_count, datatype);
320
321             if (mpi_errno) return(mpi_errno);
322
323             mask = 0x1;
324             i = 0;
325             while (mask < comm_size) {
326                 dst = rank ^ mask;
327
328                 dst_tree_root = dst >> i;
329                 dst_tree_root <<= i;
330
331                 my_tree_root = rank >> i;
332                 my_tree_root <<= i;
333
334                 /* At step 1, processes exchange (n-n/p) amount of
335                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
336                    amount of data, and so forth. We use derived datatypes for this.
337
338                    At each step, a process does not need to send data
339                    indexed from my_tree_root to
340                    my_tree_root+mask-1. Similarly, a process won't receive
341                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
342
343                 /* calculate sendtype */
344                 blklens[0] = blklens[1] = 0;
345                 for (j=0; j<my_tree_root; j++)
346                     blklens[0] += recvcounts[j];
347                 for (j=my_tree_root+mask; j<comm_size; j++)
348                     blklens[1] += recvcounts[j];
349
350                 dis[0] = 0;
351                 dis[1] = blklens[0];
352                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
353                     dis[1] += recvcounts[j];
354
355                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
356                 if (mpi_errno) return(mpi_errno);
357                 
358                 smpi_datatype_commit(&sendtype);
359
360                 /* calculate recvtype */
361                 blklens[0] = blklens[1] = 0;
362                 for (j=0; j<dst_tree_root && j<comm_size; j++)
363                     blklens[0] += recvcounts[j];
364                 for (j=dst_tree_root+mask; j<comm_size; j++)
365                     blklens[1] += recvcounts[j];
366
367                 dis[0] = 0;
368                 dis[1] = blklens[0];
369                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
370                     dis[1] += recvcounts[j];
371
372                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
373                 if (mpi_errno) return(mpi_errno);
374                 
375                 smpi_datatype_commit(&recvtype);
376
377                 received = 0;
378                 if (dst < comm_size) {
379                     /* tmp_results contains data to be sent in each step. Data is
380                        received in tmp_recvbuf and then accumulated into
381                        tmp_results. accumulation is done later below.   */ 
382
383                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
384                                                  COLL_TAG_SCATTER,
385                                                  tmp_recvbuf, 1, recvtype, dst,
386                                                  COLL_TAG_SCATTER, comm,
387                                                  MPI_STATUS_IGNORE);
388                     received = 1;
389                 }
390
391                 /* if some processes in this process's subtree in this step
392                    did not have any destination process to communicate with
393                    because of non-power-of-two, we need to send them the
394                    result. We use a logarithmic recursive-halfing algorithm
395                    for this. */
396
397                 if (dst_tree_root + mask > comm_size) {
398                     nprocs_completed = comm_size - my_tree_root - mask;
399                     /* nprocs_completed is the number of processes in this
400                        subtree that have all the data. Send data to others
401                        in a tree fashion. First find root of current tree
402                        that is being divided into two. k is the number of
403                        least-significant bits in this process's rank that
404                        must be zeroed out to find the rank of the root */ 
405                     j = mask;
406                     k = 0;
407                     while (j) {
408                         j >>= 1;
409                         k++;
410                     }
411                     k--;
412
413                     tmp_mask = mask >> 1;
414                     while (tmp_mask) {
415                         dst = rank ^ tmp_mask;
416
417                         tree_root = rank >> k;
418                         tree_root <<= k;
419
420                         /* send only if this proc has data and destination
421                            doesn't have data. at any step, multiple processes
422                            can send if they have the data */
423                         if ((dst > rank) && 
424                             (rank < tree_root + nprocs_completed)
425                             && (dst >= tree_root + nprocs_completed)) {
426                             /* send the current result */
427                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
428                                                      dst, COLL_TAG_SCATTER,
429                                                      comm);
430                         }
431                         /* recv only if this proc. doesn't have data and sender
432                            has data */
433                         else if ((dst < rank) && 
434                                  (dst < tree_root + nprocs_completed) &&
435                                  (rank >= tree_root + nprocs_completed)) {
436                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
437                                                      COLL_TAG_SCATTER,
438                                                      comm, MPI_STATUS_IGNORE); 
439                             received = 1;
440                         }
441                         tmp_mask >>= 1;
442                         k--;
443                     }
444                 }
445
446                 /* The following reduction is done here instead of after 
447                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
448                    because to do it above, in the noncommutative 
449                    case, we would need an extra temp buffer so as not to
450                    overwrite temp_recvbuf, because temp_recvbuf may have
451                    to be communicated to other processes in the
452                    non-power-of-two case. To avoid that extra allocation,
453                    we do the reduce here. */
454                 if (received) {
455                     if (is_commutative || (dst_tree_root < my_tree_root)) {
456                         {
457                                  smpi_op_apply(op, 
458                                tmp_recvbuf, tmp_results, &blklens[0],
459                                &datatype); 
460                                 smpi_op_apply(op, 
461                                ((char *)tmp_recvbuf + dis[1]*extent),
462                                ((char *)tmp_results + dis[1]*extent),
463                                &blklens[1], &datatype); 
464                         }
465                     }
466                     else {
467                         {
468                                  smpi_op_apply(op,
469                                    tmp_results, tmp_recvbuf, &blklens[0],
470                                    &datatype); 
471                                  smpi_op_apply(op,
472                                    ((char *)tmp_results + dis[1]*extent),
473                                    ((char *)tmp_recvbuf + dis[1]*extent),
474                                    &blklens[1], &datatype); 
475                         }
476                         /* copy result back into tmp_results */
477                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
478                                                    tmp_results, 1, recvtype);
479                         if (mpi_errno) return(mpi_errno);
480                     }
481                 }
482
483                 smpi_datatype_unuse(sendtype);
484                 smpi_datatype_unuse(recvtype);
485
486                 mask <<= 1;
487                 i++;
488             }
489
490             /* now copy final results from tmp_results to recvbuf */
491             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
492                                        recvcounts[rank], datatype, recvbuf,
493                                        recvcounts[rank], datatype);
494             if (mpi_errno) return(mpi_errno);
495
496     xbt_free(disps);
497     smpi_free_tmp_buffer(tmp_recvbuf);
498     smpi_free_tmp_buffer(tmp_results);
499     return MPI_SUCCESS;
500         }
501
502