Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
protect (hopefully) collective communication algorithms from abuse.
[simgrid.git] / src / smpi / colls / allreduce-rab1.c
1 #include "colls_private.h"
2 //#include <star-reduction.c>
3
4 // NP pow of 2 for now
5 int smpi_coll_tuned_allreduce_rab1(void *sbuff, void *rbuff,
6                                    int count, MPI_Datatype dtype,
7                                    MPI_Op op, MPI_Comm comm)
8 {
9   MPI_Status status;
10   MPI_Aint extent;
11   int tag = COLL_TAG_ALLREDUCE, rank, nprocs, send_size, newcnt, share;
12   int pof2 = 1, mask, send_idx, recv_idx, dst, send_cnt, recv_cnt;
13
14   void *recv, *tmp_buf;
15
16   rank = smpi_comm_rank(comm);
17   nprocs = smpi_comm_size(comm);
18
19   if((nprocs&(nprocs-1)))
20     THROWF(arg_error,0, "allreduce rab1 algorithm can't be used with non power of two number of processes ! ");
21
22   extent = smpi_datatype_get_extent(dtype);
23
24   pof2 = 1;
25   while (pof2 <= nprocs)
26     pof2 <<= 1;
27   pof2 >>= 1;
28
29   mask = 1;
30   send_idx = recv_idx = 0;
31
32   // uneven count
33   if ((count % nprocs)) {
34     send_size = (count + nprocs) / nprocs;
35     newcnt = send_size * nprocs;
36
37     recv = (void *) xbt_malloc(extent * newcnt);
38     tmp_buf = (void *) xbt_malloc(extent * newcnt);
39     memcpy(recv, sbuff, extent * count);
40
41
42     mask = pof2 / 2;
43     share = newcnt / pof2;
44     while (mask > 0) {
45       dst = rank ^ mask;
46       send_cnt = recv_cnt = newcnt / (pof2 / mask);
47
48       if (rank < dst)
49         send_idx = recv_idx + (mask * share);
50       else
51         recv_idx = send_idx + (mask * share);
52
53       smpi_mpi_sendrecv((char *) recv + send_idx * extent, send_cnt, dtype, dst, tag,
54                    tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
55
56       smpi_op_apply(op, tmp_buf, (char *) recv + recv_idx * extent, &recv_cnt,
57                      &dtype);
58
59       // update send_idx for next iteration 
60       send_idx = recv_idx;
61       mask >>= 1;
62     }
63
64     memcpy(tmp_buf, (char *) recv + recv_idx * extent, recv_cnt * extent);
65     mpi_coll_allgather_fun(tmp_buf, recv_cnt, dtype, recv, recv_cnt, dtype, comm);
66
67     memcpy(rbuff, recv, count * extent);
68     free(recv);
69     free(tmp_buf);
70
71   }
72
73   else {
74     tmp_buf = (void *) xbt_malloc(extent * count);
75     memcpy(rbuff, sbuff, count * extent);
76     mask = pof2 / 2;
77     share = count / pof2;
78     while (mask > 0) {
79       dst = rank ^ mask;
80       send_cnt = recv_cnt = count / (pof2 / mask);
81
82       if (rank < dst)
83         send_idx = recv_idx + (mask * share);
84       else
85         recv_idx = send_idx + (mask * share);
86
87       smpi_mpi_sendrecv((char *) rbuff + send_idx * extent, send_cnt, dtype, dst,
88                    tag, tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
89
90       smpi_op_apply(op, tmp_buf, (char *) rbuff + recv_idx * extent, &recv_cnt,
91                      &dtype);
92
93       // update send_idx for next iteration 
94       send_idx = recv_idx;
95       mask >>= 1;
96     }
97
98     memcpy(tmp_buf, (char *) rbuff + recv_idx * extent, recv_cnt * extent);
99     mpi_coll_allgather_fun(tmp_buf, recv_cnt, dtype, rbuff, recv_cnt, dtype, comm);
100     free(tmp_buf);
101   }
102
103   return MPI_SUCCESS;
104 }