Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branches 'master' and 'master' of github.com:simgrid/simgrid
[simgrid.git] / src / smpi / colls / bcast / bcast-SMP-binary.cpp
1 /* Copyright (c) 2013-2017. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.h"
8
9
10 int bcast_SMP_binary_segment_byte = 8192;
11 namespace simgrid{
12 namespace smpi{
13 int Coll_bcast_SMP_binary::bcast(void *buf, int count,
14                                      MPI_Datatype datatype, int root,
15                                      MPI_Comm comm)
16 {
17   int tag = COLL_TAG_BCAST;
18   MPI_Status status;
19   MPI_Request request;
20   MPI_Request *request_array;
21   MPI_Status *status_array;
22   int rank, size;
23   int i;
24   MPI_Aint extent;
25   extent = datatype->get_extent();
26
27   rank = comm->rank();
28   size = comm->size();
29   if(comm->get_leaders_comm()==MPI_COMM_NULL){
30     comm->init_smp();
31   }
32   int host_num_core=1;
33   if (comm->is_uniform()){
34     host_num_core = comm->get_intra_comm()->size();
35   }else{
36     //implementation buggy in this case
37     return Coll_bcast_mpich::bcast( buf , count, datatype,
38               root, comm);
39   }
40
41   int segment = bcast_SMP_binary_segment_byte / extent;
42   int pipe_length = count / segment;
43   int remainder = count % segment;
44
45   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
46   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
47   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
48   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
49   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
50   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
51   int increment = segment * extent;
52
53   int base = (rank / host_num_core) * host_num_core;
54   int num_core = host_num_core;
55   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
56     num_core = size - (rank / host_num_core) * host_num_core;
57
58   // if root is not zero send to rank zero first
59   if (root != 0) {
60     if (rank == root)
61       Request::send(buf, count, datatype, 0, tag, comm);
62     else if (rank == 0)
63       Request::recv(buf, count, datatype, root, tag, comm, &status);
64   }
65   // when a message is smaller than a block size => no pipeline 
66   if (count <= segment) {
67     // case ROOT-of-each-SMP
68     if (rank % host_num_core == 0) {
69       // case ROOT
70       if (rank == 0) {
71         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
72         if (to_inter_left < size)
73           Request::send(buf, count, datatype, to_inter_left, tag, comm);
74         if (to_inter_right < size)
75           Request::send(buf, count, datatype, to_inter_right, tag, comm);
76         if ((to_intra_left - base) < num_core)
77           Request::send(buf, count, datatype, to_intra_left, tag, comm);
78         if ((to_intra_right - base) < num_core)
79           Request::send(buf, count, datatype, to_intra_right, tag, comm);
80       }
81       // case LEAVES ROOT-of-eash-SMP
82       else if (to_inter_left >= size) {
83         //printf("node %d from %d\n",rank,from_inter);
84         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
85         Request::wait(&request, &status);
86         if ((to_intra_left - base) < num_core)
87           Request::send(buf, count, datatype, to_intra_left, tag, comm);
88         if ((to_intra_right - base) < num_core)
89           Request::send(buf, count, datatype, to_intra_right, tag, comm);
90       }
91       // case INTERMEDIAT ROOT-of-each-SMP
92       else {
93         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
94         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
95         Request::wait(&request, &status);
96         Request::send(buf, count, datatype, to_inter_left, tag, comm);
97         if (to_inter_right < size)
98           Request::send(buf, count, datatype, to_inter_right, tag, comm);
99         if ((to_intra_left - base) < num_core)
100           Request::send(buf, count, datatype, to_intra_left, tag, comm);
101         if ((to_intra_right - base) < num_core)
102           Request::send(buf, count, datatype, to_intra_right, tag, comm);
103       }
104     }
105     // case non ROOT-of-each-SMP
106     else {
107       // case leaves
108       if ((to_intra_left - base) >= num_core) {
109         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
110         Request::wait(&request, &status);
111       }
112       // case intermediate
113       else {
114         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
115         Request::wait(&request, &status);
116         Request::send(buf, count, datatype, to_intra_left, tag, comm);
117         if ((to_intra_right - base) < num_core)
118           Request::send(buf, count, datatype, to_intra_right, tag, comm);
119       }
120     }
121
122     return MPI_SUCCESS;
123   }
124
125   // pipeline bcast
126   else {
127     request_array =
128         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
129     status_array =
130         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
131
132     // case ROOT-of-each-SMP
133     if (rank % host_num_core == 0) {
134       // case ROOT
135       if (rank == 0) {
136         for (i = 0; i < pipe_length; i++) {
137           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
138           if (to_inter_left < size)
139             Request::send((char *) buf + (i * increment), segment, datatype,
140                      to_inter_left, (tag + i), comm);
141           if (to_inter_right < size)
142             Request::send((char *) buf + (i * increment), segment, datatype,
143                      to_inter_right, (tag + i), comm);
144           if ((to_intra_left - base) < num_core)
145             Request::send((char *) buf + (i * increment), segment, datatype,
146                      to_intra_left, (tag + i), comm);
147           if ((to_intra_right - base) < num_core)
148             Request::send((char *) buf + (i * increment), segment, datatype,
149                      to_intra_right, (tag + i), comm);
150         }
151       }
152       // case LEAVES ROOT-of-eash-SMP
153       else if (to_inter_left >= size) {
154         //printf("node %d from %d\n",rank,from_inter);
155         for (i = 0; i < pipe_length; i++) {
156           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
157                     from_inter, (tag + i), comm);
158         }
159         for (i = 0; i < pipe_length; i++) {
160           Request::wait(&request_array[i], &status);
161           if ((to_intra_left - base) < num_core)
162             Request::send((char *) buf + (i * increment), segment, datatype,
163                      to_intra_left, (tag + i), comm);
164           if ((to_intra_right - base) < num_core)
165             Request::send((char *) buf + (i * increment), segment, datatype,
166                      to_intra_right, (tag + i), comm);
167         }
168       }
169       // case INTERMEDIAT ROOT-of-each-SMP
170       else {
171         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
172         for (i = 0; i < pipe_length; i++) {
173           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
174                     from_inter, (tag + i), comm);
175         }
176         for (i = 0; i < pipe_length; i++) {
177           Request::wait(&request_array[i], &status);
178           Request::send((char *) buf + (i * increment), segment, datatype,
179                    to_inter_left, (tag + i), comm);
180           if (to_inter_right < size)
181             Request::send((char *) buf + (i * increment), segment, datatype,
182                      to_inter_right, (tag + i), comm);
183           if ((to_intra_left - base) < num_core)
184             Request::send((char *) buf + (i * increment), segment, datatype,
185                      to_intra_left, (tag + i), comm);
186           if ((to_intra_right - base) < num_core)
187             Request::send((char *) buf + (i * increment), segment, datatype,
188                      to_intra_right, (tag + i), comm);
189         }
190       }
191     }
192     // case non-ROOT-of-each-SMP
193     else {
194       // case leaves
195       if ((to_intra_left - base) >= num_core) {
196         for (i = 0; i < pipe_length; i++) {
197           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
198                     from_intra, (tag + i), comm);
199         }
200         Request::waitall((pipe_length), request_array, status_array);
201       }
202       // case intermediate
203       else {
204         for (i = 0; i < pipe_length; i++) {
205           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
206                     from_intra, (tag + i), comm);
207         }
208         for (i = 0; i < pipe_length; i++) {
209           Request::wait(&request_array[i], &status);
210           Request::send((char *) buf + (i * increment), segment, datatype,
211                    to_intra_left, (tag + i), comm);
212           if ((to_intra_right - base) < num_core)
213             Request::send((char *) buf + (i * increment), segment, datatype,
214                      to_intra_right, (tag + i), comm);
215         }
216       }
217     }
218
219     free(request_array);
220     free(status_array);
221   }
222
223   // when count is not divisible by block size, use default BCAST for the remainder
224   if ((remainder != 0) && (count > segment)) {
225     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
226     Colls::bcast((char *) buf + (pipe_length * increment), remainder, datatype,
227               root, comm);
228   }
229
230   return 1;
231 }
232
233 }
234 }