Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
sanitize the OOP of kernel::profile
[simgrid.git] / src / smpi / colls / allreduce / allreduce-rab1.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 //#include <star-reduction.c>
9 namespace simgrid{
10 namespace smpi{
11 // NP pow of 2 for now
12 int Coll_allreduce_rab1::allreduce(void *sbuff, void *rbuff,
13                                    int count, MPI_Datatype dtype,
14                                    MPI_Op op, MPI_Comm comm)
15 {
16   MPI_Status status;
17   MPI_Aint extent;
18   int tag = COLL_TAG_ALLREDUCE, send_size, newcnt, share;
19   unsigned int pof2 = 1, mask;
20   int send_idx, recv_idx, dst, send_cnt, recv_cnt;
21
22   void *recv, *tmp_buf;
23
24   int rank = comm->rank();
25   unsigned int nprocs = comm->size();
26
27   if((nprocs&(nprocs-1)))
28     THROWF(arg_error,0, "allreduce rab1 algorithm can't be used with non power of two number of processes ! ");
29
30   extent = dtype->get_extent();
31
32   pof2 = 1;
33   while (pof2 <= nprocs)
34     pof2 <<= 1;
35   pof2 >>= 1;
36
37   send_idx = recv_idx = 0;
38
39   // uneven count
40   if ((count % nprocs)) {
41     send_size = (count + nprocs) / nprocs;
42     newcnt = send_size * nprocs;
43
44     recv = (void *) smpi_get_tmp_recvbuffer(extent * newcnt);
45     tmp_buf = (void *) smpi_get_tmp_sendbuffer(extent * newcnt);
46     memcpy(recv, sbuff, extent * count);
47
48
49     mask = pof2 / 2;
50     share = newcnt / pof2;
51     while (mask > 0) {
52       dst = rank ^ mask;
53       send_cnt = recv_cnt = newcnt / (pof2 / mask);
54
55       if (rank < dst)
56         send_idx = recv_idx + (mask * share);
57       else
58         recv_idx = send_idx + (mask * share);
59
60       Request::sendrecv((char *) recv + send_idx * extent, send_cnt, dtype, dst, tag,
61                    tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
62
63       if(op!=MPI_OP_NULL) op->apply( tmp_buf, (char *) recv + recv_idx * extent, &recv_cnt,
64                      dtype);
65
66       // update send_idx for next iteration
67       send_idx = recv_idx;
68       mask >>= 1;
69     }
70
71     memcpy(tmp_buf, (char *) recv + recv_idx * extent, recv_cnt * extent);
72     Colls::allgather(tmp_buf, recv_cnt, dtype, recv, recv_cnt, dtype, comm);
73
74     memcpy(rbuff, recv, count * extent);
75     smpi_free_tmp_buffer(recv);
76     smpi_free_tmp_buffer(tmp_buf);
77
78   }
79
80   else {
81     tmp_buf = (void *) smpi_get_tmp_sendbuffer(extent * count);
82     memcpy(rbuff, sbuff, count * extent);
83     mask = pof2 / 2;
84     share = count / pof2;
85     while (mask > 0) {
86       dst = rank ^ mask;
87       send_cnt = recv_cnt = count / (pof2 / mask);
88
89       if (rank < dst)
90         send_idx = recv_idx + (mask * share);
91       else
92         recv_idx = send_idx + (mask * share);
93
94       Request::sendrecv((char *) rbuff + send_idx * extent, send_cnt, dtype, dst,
95                    tag, tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
96
97       if(op!=MPI_OP_NULL) op->apply( tmp_buf, (char *) rbuff + recv_idx * extent, &recv_cnt,
98                      dtype);
99
100       // update send_idx for next iteration
101       send_idx = recv_idx;
102       mask >>= 1;
103     }
104
105     memcpy(tmp_buf, (char *) rbuff + recv_idx * extent, recv_cnt * extent);
106     Colls::allgather(tmp_buf, recv_cnt, dtype, rbuff, recv_cnt, dtype, comm);
107     smpi_free_tmp_buffer(tmp_buf);
108   }
109
110   return MPI_SUCCESS;
111 }
112 }
113 }