Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Change malloc/free to new/delete.
[simgrid.git] / src / smpi / colls / bcast / bcast-SMP-binary.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8
9 int bcast_SMP_binary_segment_byte = 8192;
10 namespace simgrid{
11 namespace smpi{
12 int Coll_bcast_SMP_binary::bcast(void *buf, int count,
13                                      MPI_Datatype datatype, int root,
14                                      MPI_Comm comm)
15 {
16   int tag = COLL_TAG_BCAST;
17   MPI_Status status;
18   MPI_Request request;
19   int rank, size;
20   int i;
21   MPI_Aint extent;
22   extent = datatype->get_extent();
23
24   rank = comm->rank();
25   size = comm->size();
26   if(comm->get_leaders_comm()==MPI_COMM_NULL){
27     comm->init_smp();
28   }
29   int host_num_core=1;
30   if (comm->is_uniform()){
31     host_num_core = comm->get_intra_comm()->size();
32   }else{
33     //implementation buggy in this case
34     return Coll_bcast_mpich::bcast( buf , count, datatype,
35               root, comm);
36   }
37
38   int segment = bcast_SMP_binary_segment_byte / extent;
39   int pipe_length = count / segment;
40   int remainder = count % segment;
41
42   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
43   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
44   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
45   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
46   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
47   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
48   int increment = segment * extent;
49
50   int base = (rank / host_num_core) * host_num_core;
51   int num_core = host_num_core;
52   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
53     num_core = size - (rank / host_num_core) * host_num_core;
54
55   // if root is not zero send to rank zero first
56   if (root != 0) {
57     if (rank == root)
58       Request::send(buf, count, datatype, 0, tag, comm);
59     else if (rank == 0)
60       Request::recv(buf, count, datatype, root, tag, comm, &status);
61   }
62   // when a message is smaller than a block size => no pipeline
63   if (count <= segment) {
64     // case ROOT-of-each-SMP
65     if (rank % host_num_core == 0) {
66       // case ROOT
67       if (rank == 0) {
68         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
69         if (to_inter_left < size)
70           Request::send(buf, count, datatype, to_inter_left, tag, comm);
71         if (to_inter_right < size)
72           Request::send(buf, count, datatype, to_inter_right, tag, comm);
73         if ((to_intra_left - base) < num_core)
74           Request::send(buf, count, datatype, to_intra_left, tag, comm);
75         if ((to_intra_right - base) < num_core)
76           Request::send(buf, count, datatype, to_intra_right, tag, comm);
77       }
78       // case LEAVES ROOT-of-eash-SMP
79       else if (to_inter_left >= size) {
80         //printf("node %d from %d\n",rank,from_inter);
81         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
82         Request::wait(&request, &status);
83         if ((to_intra_left - base) < num_core)
84           Request::send(buf, count, datatype, to_intra_left, tag, comm);
85         if ((to_intra_right - base) < num_core)
86           Request::send(buf, count, datatype, to_intra_right, tag, comm);
87       }
88       // case INTERMEDIAT ROOT-of-each-SMP
89       else {
90         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
91         request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
92         Request::wait(&request, &status);
93         Request::send(buf, count, datatype, to_inter_left, tag, comm);
94         if (to_inter_right < size)
95           Request::send(buf, count, datatype, to_inter_right, tag, comm);
96         if ((to_intra_left - base) < num_core)
97           Request::send(buf, count, datatype, to_intra_left, tag, comm);
98         if ((to_intra_right - base) < num_core)
99           Request::send(buf, count, datatype, to_intra_right, tag, comm);
100       }
101     }
102     // case non ROOT-of-each-SMP
103     else {
104       // case leaves
105       if ((to_intra_left - base) >= num_core) {
106         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
107         Request::wait(&request, &status);
108       }
109       // case intermediate
110       else {
111         request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
112         Request::wait(&request, &status);
113         Request::send(buf, count, datatype, to_intra_left, tag, comm);
114         if ((to_intra_right - base) < num_core)
115           Request::send(buf, count, datatype, to_intra_right, tag, comm);
116       }
117     }
118
119     return MPI_SUCCESS;
120   }
121
122   // pipeline bcast
123   else {
124     MPI_Request* request_array = new MPI_Request[size + pipe_length];
125     MPI_Status* status_array   = new MPI_Status[size + pipe_length];
126
127     // case ROOT-of-each-SMP
128     if (rank % host_num_core == 0) {
129       // case ROOT
130       if (rank == 0) {
131         for (i = 0; i < pipe_length; i++) {
132           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
133           if (to_inter_left < size)
134             Request::send((char *) buf + (i * increment), segment, datatype,
135                      to_inter_left, (tag + i), comm);
136           if (to_inter_right < size)
137             Request::send((char *) buf + (i * increment), segment, datatype,
138                      to_inter_right, (tag + i), comm);
139           if ((to_intra_left - base) < num_core)
140             Request::send((char *) buf + (i * increment), segment, datatype,
141                      to_intra_left, (tag + i), comm);
142           if ((to_intra_right - base) < num_core)
143             Request::send((char *) buf + (i * increment), segment, datatype,
144                      to_intra_right, (tag + i), comm);
145         }
146       }
147       // case LEAVES ROOT-of-eash-SMP
148       else if (to_inter_left >= size) {
149         //printf("node %d from %d\n",rank,from_inter);
150         for (i = 0; i < pipe_length; i++) {
151           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
152                     from_inter, (tag + i), comm);
153         }
154         for (i = 0; i < pipe_length; i++) {
155           Request::wait(&request_array[i], &status);
156           if ((to_intra_left - base) < num_core)
157             Request::send((char *) buf + (i * increment), segment, datatype,
158                      to_intra_left, (tag + i), comm);
159           if ((to_intra_right - base) < num_core)
160             Request::send((char *) buf + (i * increment), segment, datatype,
161                      to_intra_right, (tag + i), comm);
162         }
163       }
164       // case INTERMEDIAT ROOT-of-each-SMP
165       else {
166         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
167         for (i = 0; i < pipe_length; i++) {
168           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
169                     from_inter, (tag + i), comm);
170         }
171         for (i = 0; i < pipe_length; i++) {
172           Request::wait(&request_array[i], &status);
173           Request::send((char *) buf + (i * increment), segment, datatype,
174                    to_inter_left, (tag + i), comm);
175           if (to_inter_right < size)
176             Request::send((char *) buf + (i * increment), segment, datatype,
177                      to_inter_right, (tag + i), comm);
178           if ((to_intra_left - base) < num_core)
179             Request::send((char *) buf + (i * increment), segment, datatype,
180                      to_intra_left, (tag + i), comm);
181           if ((to_intra_right - base) < num_core)
182             Request::send((char *) buf + (i * increment), segment, datatype,
183                      to_intra_right, (tag + i), comm);
184         }
185       }
186     }
187     // case non-ROOT-of-each-SMP
188     else {
189       // case leaves
190       if ((to_intra_left - base) >= num_core) {
191         for (i = 0; i < pipe_length; i++) {
192           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
193                     from_intra, (tag + i), comm);
194         }
195         Request::waitall((pipe_length), request_array, status_array);
196       }
197       // case intermediate
198       else {
199         for (i = 0; i < pipe_length; i++) {
200           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
201                     from_intra, (tag + i), comm);
202         }
203         for (i = 0; i < pipe_length; i++) {
204           Request::wait(&request_array[i], &status);
205           Request::send((char *) buf + (i * increment), segment, datatype,
206                    to_intra_left, (tag + i), comm);
207           if ((to_intra_right - base) < num_core)
208             Request::send((char *) buf + (i * increment), segment, datatype,
209                      to_intra_right, (tag + i), comm);
210         }
211       }
212     }
213
214     delete[] request_array;
215     delete[] status_array;
216   }
217
218   // when count is not divisible by block size, use default BCAST for the remainder
219   if ((remainder != 0) && (count > segment)) {
220     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
221     Colls::bcast((char *) buf + (pipe_length * increment), remainder, datatype,
222               root, comm);
223   }
224
225   return 1;
226 }
227
228 }
229 }