Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / colls / allreduce-rab1.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 //#include <star-reduction.c>
9
10 // NP pow of 2 for now
11 int smpi_coll_tuned_allreduce_rab1(void *sbuff, void *rbuff,
12                                    int count, MPI_Datatype dtype,
13                                    MPI_Op op, MPI_Comm comm)
14 {
15   MPI_Status status;
16   MPI_Aint extent;
17   int tag = COLL_TAG_ALLREDUCE, send_size, newcnt, share;
18   unsigned int pof2 = 1, mask;
19   int send_idx, recv_idx, dst, send_cnt, recv_cnt;
20
21   void *recv, *tmp_buf;
22
23   unsigned int rank = smpi_comm_rank(comm);
24   unsigned int nprocs = smpi_comm_size(comm);
25
26   if((nprocs&(nprocs-1)))
27     THROWF(arg_error,0, "allreduce rab1 algorithm can't be used with non power of two number of processes ! ");
28
29   extent = smpi_datatype_get_extent(dtype);
30
31   pof2 = 1;
32   while (pof2 <= nprocs)
33     pof2 <<= 1;
34   pof2 >>= 1;
35
36   send_idx = recv_idx = 0;
37
38   // uneven count
39   if ((count % nprocs)) {
40     send_size = (count + nprocs) / nprocs;
41     newcnt = send_size * nprocs;
42
43     recv = (void *) smpi_get_tmp_recvbuffer(extent * newcnt);
44     tmp_buf = (void *) smpi_get_tmp_sendbuffer(extent * newcnt);
45     memcpy(recv, sbuff, extent * count);
46
47
48     mask = pof2 / 2;
49     share = newcnt / pof2;
50     while (mask > 0) {
51       dst = rank ^ mask;
52       send_cnt = recv_cnt = newcnt / (pof2 / mask);
53
54       if (rank < dst)
55         send_idx = recv_idx + (mask * share);
56       else
57         recv_idx = send_idx + (mask * share);
58
59       smpi_mpi_sendrecv((char *) recv + send_idx * extent, send_cnt, dtype, dst, tag,
60                    tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
61
62       smpi_op_apply(op, tmp_buf, (char *) recv + recv_idx * extent, &recv_cnt,
63                      &dtype);
64
65       // update send_idx for next iteration 
66       send_idx = recv_idx;
67       mask >>= 1;
68     }
69
70     memcpy(tmp_buf, (char *) recv + recv_idx * extent, recv_cnt * extent);
71     mpi_coll_allgather_fun(tmp_buf, recv_cnt, dtype, recv, recv_cnt, dtype, comm);
72
73     memcpy(rbuff, recv, count * extent);
74     smpi_free_tmp_buffer(recv);
75     smpi_free_tmp_buffer(tmp_buf);
76
77   }
78
79   else {
80     tmp_buf = (void *) smpi_get_tmp_sendbuffer(extent * count);
81     memcpy(rbuff, sbuff, count * extent);
82     mask = pof2 / 2;
83     share = count / pof2;
84     while (mask > 0) {
85       dst = rank ^ mask;
86       send_cnt = recv_cnt = count / (pof2 / mask);
87
88       if (rank < dst)
89         send_idx = recv_idx + (mask * share);
90       else
91         recv_idx = send_idx + (mask * share);
92
93       smpi_mpi_sendrecv((char *) rbuff + send_idx * extent, send_cnt, dtype, dst,
94                    tag, tmp_buf, recv_cnt, dtype, dst, tag, comm, &status);
95
96       smpi_op_apply(op, tmp_buf, (char *) rbuff + recv_idx * extent, &recv_cnt,
97                      &dtype);
98
99       // update send_idx for next iteration 
100       send_idx = recv_idx;
101       mask >>= 1;
102     }
103
104     memcpy(tmp_buf, (char *) rbuff + recv_idx * extent, recv_cnt * extent);
105     mpi_coll_allgather_fun(tmp_buf, recv_cnt, dtype, rbuff, recv_cnt, dtype, comm);
106     smpi_free_tmp_buffer(tmp_buf);
107   }
108
109   return MPI_SUCCESS;
110 }