Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines with new year.
[simgrid.git] / src / smpi / colls / bcast / bcast-SMP-binary.cpp
1 /* Copyright (c) 2013-2020. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8
9 int bcast_SMP_binary_segment_byte = 8192;
10 namespace simgrid{
11 namespace smpi{
12 int bcast__SMP_binary(void *buf, int count,
13                       MPI_Datatype datatype, int root,
14                       MPI_Comm comm)
15 {
16   int tag = COLL_TAG_BCAST;
17   MPI_Status status;
18   MPI_Request request;
19   int rank, size;
20   int i;
21   MPI_Aint extent;
22   extent = datatype->get_extent();
23
24   rank = comm->rank();
25   size = comm->size();
26   if(comm->get_leaders_comm()==MPI_COMM_NULL){
27     comm->init_smp();
28   }
29   int host_num_core=1;
30   if (comm->is_uniform()){
31     host_num_core = comm->get_intra_comm()->size();
32   }else{
33     //implementation buggy in this case
34     return bcast__mpich(buf , count, datatype, root, comm);
35   }
36
37   int segment = bcast_SMP_binary_segment_byte / extent;
38   int pipe_length = count / segment;
39   int remainder = count % segment;
40
41   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
42   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
43   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
44   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
45   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
46   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
47   int increment = segment * extent;
48
49   int base = (rank / host_num_core) * host_num_core;
50   int num_core = host_num_core;
51   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
52     num_core = size - (rank / host_num_core) * host_num_core;
53
54   // if root is not zero send to rank zero first
55   if (root != 0) {
56     if (rank == root)
57       Request::send(buf, count, datatype, 0, tag, comm);
58     else if (rank == 0)
59       Request::recv(buf, count, datatype, root, tag, comm, &status);
60   }
61   // when a message is smaller than a block size => no pipeline
62   if (count <= segment) {
63     // case ROOT-of-each-SMP
64     if (rank % host_num_core == 0) {
65       // case ROOT
66       if (rank == 0) {
67         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
68         if (to_inter_left < size)
69           Request::send(buf, count, datatype, to_inter_left, tag, comm);
70         if (to_inter_right < size)
71           Request::send(buf, count, datatype, to_inter_right, tag, comm);
72         if ((to_intra_left - base) < num_core)
73           Request::send(buf, count, datatype, to_intra_left, tag, comm);
74         if ((to_intra_right - base) < num_core)
75           Request::send(buf, count, datatype, to_intra_right, tag, comm);
76       }
77       // case LEAVES ROOT-of-eash-SMP
78       else if (to_inter_left >= size) {
79         //printf("node %d from %d\n",rank,from_inter);
80         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
81         Request::wait(&request, &status);
82         if ((to_intra_left - base) < num_core)
83           Request::send(buf, count, datatype, to_intra_left, tag, comm);
84         if ((to_intra_right - base) < num_core)
85           Request::send(buf, count, datatype, to_intra_right, tag, comm);
86       }
87       // case INTERMEDIAT ROOT-of-each-SMP
88       else {
89         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
90         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
91         Request::wait(&request, &status);
92         Request::send(buf, count, datatype, to_inter_left, tag, comm);
93         if (to_inter_right < size)
94           Request::send(buf, count, datatype, to_inter_right, tag, comm);
95         if ((to_intra_left - base) < num_core)
96           Request::send(buf, count, datatype, to_intra_left, tag, comm);
97         if ((to_intra_right - base) < num_core)
98           Request::send(buf, count, datatype, to_intra_right, tag, comm);
99       }
100     }
101     // case non ROOT-of-each-SMP
102     else {
103       // case leaves
104       if ((to_intra_left - base) >= num_core) {
105         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
106         Request::wait(&request, &status);
107       }
108       // case intermediate
109       else {
110         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
111         Request::wait(&request, &status);
112         Request::send(buf, count, datatype, to_intra_left, tag, comm);
113         if ((to_intra_right - base) < num_core)
114           Request::send(buf, count, datatype, to_intra_right, tag, comm);
115       }
116     }
117
118     return MPI_SUCCESS;
119   }
120
121   // pipeline bcast
122   else {
123     MPI_Request* request_array = new MPI_Request[size + pipe_length];
124     MPI_Status* status_array   = new MPI_Status[size + pipe_length];
125
126     // case ROOT-of-each-SMP
127     if (rank % host_num_core == 0) {
128       // case ROOT
129       if (rank == 0) {
130         for (i = 0; i < pipe_length; i++) {
131           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
132           if (to_inter_left < size)
133             Request::send((char *) buf + (i * increment), segment, datatype,
134                      to_inter_left, (tag + i), comm);
135           if (to_inter_right < size)
136             Request::send((char *) buf + (i * increment), segment, datatype,
137                      to_inter_right, (tag + i), comm);
138           if ((to_intra_left - base) < num_core)
139             Request::send((char *) buf + (i * increment), segment, datatype,
140                      to_intra_left, (tag + i), comm);
141           if ((to_intra_right - base) < num_core)
142             Request::send((char *) buf + (i * increment), segment, datatype,
143                      to_intra_right, (tag + i), comm);
144         }
145       }
146       // case LEAVES ROOT-of-eash-SMP
147       else if (to_inter_left >= size) {
148         //printf("node %d from %d\n",rank,from_inter);
149         for (i = 0; i < pipe_length; i++) {
150           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
151                     from_inter, (tag + i), comm);
152         }
153         for (i = 0; i < pipe_length; i++) {
154           Request::wait(&request_array[i], &status);
155           if ((to_intra_left - base) < num_core)
156             Request::send((char *) buf + (i * increment), segment, datatype,
157                      to_intra_left, (tag + i), comm);
158           if ((to_intra_right - base) < num_core)
159             Request::send((char *) buf + (i * increment), segment, datatype,
160                      to_intra_right, (tag + i), comm);
161         }
162       }
163       // case INTERMEDIAT ROOT-of-each-SMP
164       else {
165         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
166         for (i = 0; i < pipe_length; i++) {
167           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
168                     from_inter, (tag + i), comm);
169         }
170         for (i = 0; i < pipe_length; i++) {
171           Request::wait(&request_array[i], &status);
172           Request::send((char *) buf + (i * increment), segment, datatype,
173                    to_inter_left, (tag + i), comm);
174           if (to_inter_right < size)
175             Request::send((char *) buf + (i * increment), segment, datatype,
176                      to_inter_right, (tag + i), comm);
177           if ((to_intra_left - base) < num_core)
178             Request::send((char *) buf + (i * increment), segment, datatype,
179                      to_intra_left, (tag + i), comm);
180           if ((to_intra_right - base) < num_core)
181             Request::send((char *) buf + (i * increment), segment, datatype,
182                      to_intra_right, (tag + i), comm);
183         }
184       }
185     }
186     // case non-ROOT-of-each-SMP
187     else {
188       // case leaves
189       if ((to_intra_left - base) >= num_core) {
190         for (i = 0; i < pipe_length; i++) {
191           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
192                     from_intra, (tag + i), comm);
193         }
194         Request::waitall((pipe_length), request_array, status_array);
195       }
196       // case intermediate
197       else {
198         for (i = 0; i < pipe_length; i++) {
199           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
200                     from_intra, (tag + i), comm);
201         }
202         for (i = 0; i < pipe_length; i++) {
203           Request::wait(&request_array[i], &status);
204           Request::send((char *) buf + (i * increment), segment, datatype,
205                    to_intra_left, (tag + i), comm);
206           if ((to_intra_right - base) < num_core)
207             Request::send((char *) buf + (i * increment), segment, datatype,
208                      to_intra_right, (tag + i), comm);
209         }
210       }
211     }
212
213     delete[] request_array;
214     delete[] status_array;
215   }
216
217   // when count is not divisible by block size, use default BCAST for the remainder
218   if ((remainder != 0) && (count > segment)) {
219     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
220     colls::bcast((char*)buf + (pipe_length * increment), remainder, datatype, root, comm);
221   }
222
223   return 1;
224 }
225
226 }
227 }