Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
change algo chosen for pairwise, to work with non power of 2 number of procs
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.c
1 #include "colls_private.h"
2
3 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
4 {
5     /* a mask for the high order bits that should be copied as-is */
6     int high_mask = ~((0x1 << bits) - 1);
7     int retval = x & high_mask;
8     int i;
9
10     for (i = 0; i < bits; ++i) {
11         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
12         retval |= bitval << ((bits - i) - 1);
13     }
14
15     return retval;
16 }
17
18
19 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
20                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
21 {
22     int   rank, comm_size, i;
23     MPI_Aint extent, true_extent, true_lb; 
24     int  *disps;
25     void *tmp_recvbuf;
26     int mpi_errno = MPI_SUCCESS;
27     int total_count, dst, src;
28     int is_commutative;
29     comm_size = smpi_comm_size(comm);
30     rank = smpi_comm_rank(comm);
31
32     extent =smpi_datatype_get_extent(datatype);
33     smpi_datatype_extent(datatype, &true_lb, &true_extent);
34     
35     if (smpi_op_is_commute(op)) {
36         is_commutative = 1;
37     }
38
39     disps = (int*)xbt_malloc( comm_size * sizeof(int));
40
41     total_count = 0;
42     for (i=0; i<comm_size; i++) {
43         disps[i] = total_count;
44         total_count += recvcounts[i];
45     }
46     
47     if (total_count == 0) {
48         return MPI_ERR_COUNT;
49     }
50
51         if (sendbuf != MPI_IN_PLACE) {
52             /* copy local data into recvbuf */
53             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
54                                        recvcounts[rank], datatype, recvbuf,
55                                        recvcounts[rank], datatype);
56         }
57         
58         /* allocate temporary buffer to store incoming data */
59         tmp_recvbuf = (void*)xbt_malloc(recvcounts[rank]*(max(true_extent,extent))+1);
60         /* adjust for potential negative lower bound in datatype */
61         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
62         
63         for (i=1; i<comm_size; i++) {
64             src = (rank - i + comm_size) % comm_size;
65             dst = (rank + i) % comm_size;
66             
67             /* send the data that dst needs. recv data that this process
68                needs from src into tmp_recvbuf */
69             if (sendbuf != MPI_IN_PLACE) 
70                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
71                                              recvcounts[dst], datatype, dst,
72                                              COLL_TAG_SCATTER, tmp_recvbuf,
73                                              recvcounts[rank], datatype, src,
74                                              COLL_TAG_SCATTER, comm,
75                                              MPI_STATUS_IGNORE);
76             else
77                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
78                                              recvcounts[dst], datatype, dst,
79                                              COLL_TAG_SCATTER, tmp_recvbuf,
80                                              recvcounts[rank], datatype, src,
81                                              COLL_TAG_SCATTER, comm,
82                                              MPI_STATUS_IGNORE);
83             
84             if (is_commutative || (src < rank)) {
85                 if (sendbuf != MPI_IN_PLACE) {
86                      smpi_op_apply( op,
87                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
88                                &datatype); 
89                 }
90                 else {
91                     smpi_op_apply(op, 
92                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
93                         &recvcounts[rank], &datatype);
94                     /* we can't store the result at the beginning of
95                        recvbuf right here because there is useful data
96                        there that other process/processes need. at the
97                        end, we will copy back the result to the
98                        beginning of recvbuf. */
99                 }
100             }
101             else {
102                 if (sendbuf != MPI_IN_PLACE) {
103                     smpi_op_apply(op, 
104                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
105                     /* copy result back into recvbuf */
106                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
107                                                datatype, recvbuf,
108                                                recvcounts[rank], datatype);
109                     if (mpi_errno) return(mpi_errno);
110                 }
111                 else {
112                     smpi_op_apply(op, 
113                         ((char *)recvbuf+disps[rank]*extent),
114                         tmp_recvbuf, &recvcounts[rank], &datatype);
115                     /* copy result back into recvbuf */
116                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
117                                                datatype, 
118                                                ((char *)recvbuf +
119                                                 disps[rank]*extent), 
120                                                recvcounts[rank], datatype);
121                     if (mpi_errno) return(mpi_errno);
122                 }
123             }
124         }
125         
126         /* if MPI_IN_PLACE, move output data to the beginning of
127            recvbuf. already done for rank 0. */
128         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
129             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
130                                         disps[rank]*extent),  
131                                        recvcounts[rank], datatype,
132                                        recvbuf, 
133                                        recvcounts[rank], datatype );
134             if (mpi_errno) return(mpi_errno);
135         }
136     
137 return MPI_SUCCESS;
138 }
139     
140
141 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
142                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
143 {
144     int mpi_errno = MPI_SUCCESS;
145     int comm_size = smpi_comm_size(comm) ;
146     int rank = smpi_comm_rank(comm);
147     int pof2;
148     int log2_comm_size;
149     int i, k;
150     int recv_offset, send_offset;
151     int block_size, total_count, size;
152     MPI_Aint true_extent, true_lb;
153     int buf0_was_inout;
154     void *tmp_buf0;
155     void *tmp_buf1;
156     void *result_ptr;
157
158     smpi_datatype_extent(datatype, &true_lb, &true_extent);
159
160     pof2 = 1;
161     log2_comm_size = 0;
162     while (pof2 < comm_size) {
163         pof2 <<= 1;
164         ++log2_comm_size;
165     }
166
167     /* begin error checking */
168     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
169
170     for (i = 0; i < (comm_size - 1); ++i) {
171         xbt_assert(recvcounts[i] == recvcounts[i+1]);
172     }
173     /* end error checking */
174
175     /* size of a block (count of datatype per block, NOT bytes per block) */
176     block_size = recvcounts[0];
177     total_count = block_size * comm_size;
178
179     tmp_buf0=( void *)xbt_malloc( true_extent * total_count);
180     tmp_buf1=( void *)xbt_malloc( true_extent * total_count);
181     /* adjust for potential negative lower bound in datatype */
182     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
183     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
184
185     /* Copy our send data to tmp_buf0.  We do this one block at a time and
186        permute the blocks as we go according to the mirror permutation. */
187     for (i = 0; i < comm_size; ++i) {
188         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
189                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
190         if (mpi_errno) return(mpi_errno);
191     }
192     buf0_was_inout = 1;
193
194     send_offset = 0;
195     recv_offset = 0;
196     size = total_count;
197     for (k = 0; k < log2_comm_size; ++k) {
198         /* use a double-buffering scheme to avoid local copies */
199         char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
200         char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
201         int peer = rank ^ (0x1 << k);
202         size /= 2;
203
204         if (rank > peer) {
205             /* we have the higher rank: send top half, recv bottom half */
206             recv_offset += size;
207         }
208         else {
209             /* we have the lower rank: recv top half, send bottom half */
210             send_offset += size;
211         }
212
213         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
214                                      size, datatype, peer, COLL_TAG_SCATTER,
215                                      incoming_data + recv_offset*true_extent,
216                                      size, datatype, peer, COLL_TAG_SCATTER,
217                                      comm, MPI_STATUS_IGNORE);
218         /* always perform the reduction at recv_offset, the data at send_offset
219            is now our peer's responsibility */
220         if (rank > peer) {
221             /* higher ranked value so need to call op(received_data, my_data) */
222             smpi_op_apply(op, 
223                    incoming_data + recv_offset*true_extent,
224                      outgoing_data + recv_offset*true_extent,
225                      &size, &datatype );
226             /* buf0_was_inout = buf0_was_inout; */
227         }
228         else {
229             /* lower ranked value so need to call op(my_data, received_data) */
230             smpi_op_apply( op,
231                      outgoing_data + recv_offset*true_extent,
232                      incoming_data + recv_offset*true_extent,
233                      &size, &datatype);
234             buf0_was_inout = !buf0_was_inout;
235         }
236
237         /* the next round of send/recv needs to happen within the block (of size
238            "size") that we just received and reduced */
239         send_offset = recv_offset;
240     }
241
242     xbt_assert(size == recvcounts[rank]);
243
244     /* copy the reduced data to the recvbuf */
245     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
246     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
247                                recvbuf, size, datatype);
248     if (mpi_errno) return(mpi_errno);
249     return MPI_SUCCESS;
250 }
251
252
253
254 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
255                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
256 {
257     int   rank, comm_size, i;
258     MPI_Aint extent, true_extent, true_lb; 
259     int  *disps;
260     void *tmp_recvbuf, *tmp_results;
261     int mpi_errno = MPI_SUCCESS;
262     int dis[2], blklens[2], total_count, dst;
263     int mask, dst_tree_root, my_tree_root, j, k;
264     int received;
265     MPI_Datatype sendtype, recvtype;
266     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
267     comm_size = smpi_comm_size(comm);
268     rank = smpi_comm_rank(comm);
269
270     extent =smpi_datatype_get_extent(datatype);
271     smpi_datatype_extent(datatype, &true_lb, &true_extent);
272     
273     if (smpi_op_is_commute(op)) {
274         is_commutative = 1;
275     }
276
277     disps = (int*)xbt_malloc( comm_size * sizeof(int));
278
279     total_count = 0;
280     for (i=0; i<comm_size; i++) {
281         disps[i] = total_count;
282         total_count += recvcounts[i];
283     }
284     
285             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
286
287             /* need to allocate temporary buffer to receive incoming data*/
288             tmp_recvbuf= (void *) xbt_malloc( total_count*(max(true_extent,extent)));
289             /* adjust for potential negative lower bound in datatype */
290             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
291
292             /* need to allocate another temporary buffer to accumulate
293                results */
294             tmp_results = (void *)xbt_malloc( total_count*(max(true_extent,extent)));
295             /* adjust for potential negative lower bound in datatype */
296             tmp_results = (void *)((char*)tmp_results - true_lb);
297
298             /* copy sendbuf into tmp_results */
299             if (sendbuf != MPI_IN_PLACE)
300                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
301                                            tmp_results, total_count, datatype);
302             else
303                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
304                                            tmp_results, total_count, datatype);
305
306             if (mpi_errno) return(mpi_errno);
307
308             mask = 0x1;
309             i = 0;
310             while (mask < comm_size) {
311                 dst = rank ^ mask;
312
313                 dst_tree_root = dst >> i;
314                 dst_tree_root <<= i;
315
316                 my_tree_root = rank >> i;
317                 my_tree_root <<= i;
318
319                 /* At step 1, processes exchange (n-n/p) amount of
320                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
321                    amount of data, and so forth. We use derived datatypes for this.
322
323                    At each step, a process does not need to send data
324                    indexed from my_tree_root to
325                    my_tree_root+mask-1. Similarly, a process won't receive
326                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
327
328                 /* calculate sendtype */
329                 blklens[0] = blklens[1] = 0;
330                 for (j=0; j<my_tree_root; j++)
331                     blklens[0] += recvcounts[j];
332                 for (j=my_tree_root+mask; j<comm_size; j++)
333                     blklens[1] += recvcounts[j];
334
335                 dis[0] = 0;
336                 dis[1] = blklens[0];
337                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
338                     dis[1] += recvcounts[j];
339
340                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
341                 if (mpi_errno) return(mpi_errno);
342                 
343                 smpi_datatype_commit(&sendtype);
344
345                 /* calculate recvtype */
346                 blklens[0] = blklens[1] = 0;
347                 for (j=0; j<dst_tree_root && j<comm_size; j++)
348                     blklens[0] += recvcounts[j];
349                 for (j=dst_tree_root+mask; j<comm_size; j++)
350                     blklens[1] += recvcounts[j];
351
352                 dis[0] = 0;
353                 dis[1] = blklens[0];
354                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
355                     dis[1] += recvcounts[j];
356
357                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
358                 if (mpi_errno) return(mpi_errno);
359                 
360                 smpi_datatype_commit(&recvtype);
361
362                 received = 0;
363                 if (dst < comm_size) {
364                     /* tmp_results contains data to be sent in each step. Data is
365                        received in tmp_recvbuf and then accumulated into
366                        tmp_results. accumulation is done later below.   */ 
367
368                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
369                                                  COLL_TAG_SCATTER,
370                                                  tmp_recvbuf, 1, recvtype, dst,
371                                                  COLL_TAG_SCATTER, comm,
372                                                  MPI_STATUS_IGNORE);
373                     received = 1;
374                 }
375
376                 /* if some processes in this process's subtree in this step
377                    did not have any destination process to communicate with
378                    because of non-power-of-two, we need to send them the
379                    result. We use a logarithmic recursive-halfing algorithm
380                    for this. */
381
382                 if (dst_tree_root + mask > comm_size) {
383                     nprocs_completed = comm_size - my_tree_root - mask;
384                     /* nprocs_completed is the number of processes in this
385                        subtree that have all the data. Send data to others
386                        in a tree fashion. First find root of current tree
387                        that is being divided into two. k is the number of
388                        least-significant bits in this process's rank that
389                        must be zeroed out to find the rank of the root */ 
390                     j = mask;
391                     k = 0;
392                     while (j) {
393                         j >>= 1;
394                         k++;
395                     }
396                     k--;
397
398                     tmp_mask = mask >> 1;
399                     while (tmp_mask) {
400                         dst = rank ^ tmp_mask;
401
402                         tree_root = rank >> k;
403                         tree_root <<= k;
404
405                         /* send only if this proc has data and destination
406                            doesn't have data. at any step, multiple processes
407                            can send if they have the data */
408                         if ((dst > rank) && 
409                             (rank < tree_root + nprocs_completed)
410                             && (dst >= tree_root + nprocs_completed)) {
411                             /* send the current result */
412                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
413                                                      dst, COLL_TAG_SCATTER,
414                                                      comm);
415                         }
416                         /* recv only if this proc. doesn't have data and sender
417                            has data */
418                         else if ((dst < rank) && 
419                                  (dst < tree_root + nprocs_completed) &&
420                                  (rank >= tree_root + nprocs_completed)) {
421                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
422                                                      COLL_TAG_SCATTER,
423                                                      comm, MPI_STATUS_IGNORE); 
424                             received = 1;
425                         }
426                         tmp_mask >>= 1;
427                         k--;
428                     }
429                 }
430
431                 /* The following reduction is done here instead of after 
432                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
433                    because to do it above, in the noncommutative 
434                    case, we would need an extra temp buffer so as not to
435                    overwrite temp_recvbuf, because temp_recvbuf may have
436                    to be communicated to other processes in the
437                    non-power-of-two case. To avoid that extra allocation,
438                    we do the reduce here. */
439                 if (received) {
440                     if (is_commutative || (dst_tree_root < my_tree_root)) {
441                         {
442                                  smpi_op_apply(op, 
443                                tmp_recvbuf, tmp_results, &blklens[0],
444                                &datatype); 
445                                 smpi_op_apply(op, 
446                                ((char *)tmp_recvbuf + dis[1]*extent),
447                                ((char *)tmp_results + dis[1]*extent),
448                                &blklens[1], &datatype); 
449                         }
450                     }
451                     else {
452                         {
453                                  smpi_op_apply(op,
454                                    tmp_results, tmp_recvbuf, &blklens[0],
455                                    &datatype); 
456                                  smpi_op_apply(op,
457                                    ((char *)tmp_results + dis[1]*extent),
458                                    ((char *)tmp_recvbuf + dis[1]*extent),
459                                    &blklens[1], &datatype); 
460                         }
461                         /* copy result back into tmp_results */
462                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
463                                                    tmp_results, 1, recvtype);
464                         if (mpi_errno) return(mpi_errno);
465                     }
466                 }
467
468                 //smpi_datatype_free(&sendtype);
469                 //smpi_datatype_free(&recvtype);
470
471                 mask <<= 1;
472                 i++;
473             }
474
475             /* now copy final results from tmp_results to recvbuf */
476             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
477                                        recvcounts[rank], datatype, recvbuf,
478                                        recvcounts[rank], datatype);
479             if (mpi_errno) return(mpi_errno);
480     xbt_free(disps);
481     xbt_free(tmp_recvbuf);
482     xbt_free(tmp_results);
483     return MPI_SUCCESS;
484         }
485
486