Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
e5ff57aaca332bc8e925572bb3493a8b4c8ebb33
[simgrid.git] / src / smpi / colls / allreduce / allreduce-rab1.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 //#include <star-reduction.c>
9 namespace simgrid{
10 namespace smpi{
11 // NP pow of 2 for now
12 int Coll_allreduce_rab1::allreduce(const void *sbuff, void *rbuff,
13                                    int count, MPI_Datatype dtype,
14                                    MPI_Op op, MPI_Comm comm)
15 {
16   MPI_Status status;
17   MPI_Aint extent;
18   int tag = COLL_TAG_ALLREDUCE, send_size, newcnt, share;
19   unsigned int pof2 = 1, mask;
20   int send_idx, recv_idx, dst, send_cnt, recv_cnt;
21
22   int rank = comm->rank();
23   unsigned int nprocs = comm->size();
24
25   if((nprocs&(nprocs-1)))
26     THROWF(arg_error,0, "allreduce rab1 algorithm can't be used with non power of two number of processes ! ");
27
28   extent = dtype->get_extent();
29
30   pof2 = 1;
31   while (pof2 <= nprocs)
32     pof2 <<= 1;
33   pof2 >>= 1;
34
35   send_idx = recv_idx = 0;
36
37   // uneven count
38   if ((count % nprocs)) {
39     send_size = (count + nprocs) / nprocs;
40     newcnt = send_size * nprocs;
41
42     unsigned char* recv    = smpi_get_tmp_recvbuffer(extent * newcnt);
43     unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(extent * newcnt);
44     memcpy(recv, sbuff, extent * count);
45
46
47     mask = pof2 / 2;
48     share = newcnt / pof2;
49     while (mask > 0) {
50       dst = rank ^ mask;
51       send_cnt = recv_cnt = newcnt / (pof2 / mask);
52
53       if (rank < dst)
54         send_idx = recv_idx + (mask * share);
55       else
56         recv_idx = send_idx + (mask * share);
57
58       Request::sendrecv(recv + send_idx * extent, send_cnt, dtype, dst, tag, tmp_buf, recv_cnt, dtype, dst, tag, comm,
59                         &status);
60
61       if (op != MPI_OP_NULL)
62         op->apply(tmp_buf, recv + recv_idx * extent, &recv_cnt, dtype);
63
64       // update send_idx for next iteration
65       send_idx = recv_idx;
66       mask >>= 1;
67     }
68
69     memcpy(tmp_buf, recv + recv_idx * extent, recv_cnt * extent);
70     Colls::allgather(tmp_buf, recv_cnt, dtype, recv, recv_cnt, dtype, comm);
71
72     memcpy(rbuff, recv, count * extent);
73     smpi_free_tmp_buffer(recv);
74     smpi_free_tmp_buffer(tmp_buf);
75
76   }
77
78   else {
79     unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(extent * count);
80     memcpy(rbuff, sbuff, count * extent);
81     mask = pof2 / 2;
82     share = count / pof2;
83     while (mask > 0) {
84       dst = rank ^ mask;
85       send_cnt = recv_cnt = count / (pof2 / mask);
86
87       if (rank < dst)
88         send_idx = recv_idx + (mask * share);
89       else
90         recv_idx = send_idx + (mask * share);
91
92       Request::sendrecv((char *) rbuff + send_idx * extent, send_cnt, dtype, dst,
93                    tag, tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
94
95       if(op!=MPI_OP_NULL) op->apply( tmp_buf, (char *) rbuff + recv_idx * extent, &recv_cnt,
96                      dtype);
97
98       // update send_idx for next iteration
99       send_idx = recv_idx;
100       mask >>= 1;
101     }
102
103     memcpy(tmp_buf, (char *) rbuff + recv_idx * extent, recv_cnt * extent);
104     Colls::allgather(tmp_buf, recv_cnt, dtype, rbuff, recv_cnt, dtype, comm);
105     smpi_free_tmp_buffer(tmp_buf);
106   }
107
108   return MPI_SUCCESS;
109 }
110 }
111 }