1 /* Copyright (c) 2013-2020. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
// Pipeline segment size, in bytes, used by bcast__SMP_binary.
// The broadcast divides this by the datatype extent to obtain the number of
// elements per segment. A message of at most one segment is sent without
// pipelining; otherwise whole segments are pipelined, and a leftover tail
// (count not divisible by the segment) is handed to the default colls::bcast.
// NOTE(review): if a datatype extent exceeds this value, the integer division
// yields a zero-element segment and the later `count / segment` divides by
// zero -- keep this at least as large as the biggest extent in use.
// Kept as a mutable global, presumably so it can be tuned elsewhere -- confirm.
int bcast_SMP_binary_segment_byte{8192};
12 int bcast__SMP_binary(void *buf, int count,
13 MPI_Datatype datatype, int root,
16 int tag = COLL_TAG_BCAST;
22 extent = datatype->get_extent();
26 if(comm->get_leaders_comm()==MPI_COMM_NULL){
30 if (comm->is_uniform()){
31 host_num_core = comm->get_intra_comm()->size();
33 //implementation buggy in this case
34 return bcast__mpich(buf , count, datatype, root, comm);
37 int segment = bcast_SMP_binary_segment_byte / extent;
38 int pipe_length = count / segment;
39 int remainder = count % segment;
41 int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
42 int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
43 int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
44 int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
45 int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
46 int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
47 int increment = segment * extent;
49 int base = (rank / host_num_core) * host_num_core;
50 int num_core = host_num_core;
51 if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
52 num_core = size - (rank / host_num_core) * host_num_core;
54 // if root is not zero send to rank zero first
57 Request::send(buf, count, datatype, 0, tag, comm);
59 Request::recv(buf, count, datatype, root, tag, comm, &status);
61 // when a message is smaller than a block size => no pipeline
62 if (count <= segment) {
63 // case ROOT-of-each-SMP
64 if (rank % host_num_core == 0) {
67 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
68 if (to_inter_left < size)
69 Request::send(buf, count, datatype, to_inter_left, tag, comm);
70 if (to_inter_right < size)
71 Request::send(buf, count, datatype, to_inter_right, tag, comm);
72 if ((to_intra_left - base) < num_core)
73 Request::send(buf, count, datatype, to_intra_left, tag, comm);
74 if ((to_intra_right - base) < num_core)
75 Request::send(buf, count, datatype, to_intra_right, tag, comm);
77 // case LEAVES ROOT-of-eash-SMP
78 else if (to_inter_left >= size) {
79 //printf("node %d from %d\n",rank,from_inter);
80 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
81 Request::wait(&request, &status);
82 if ((to_intra_left - base) < num_core)
83 Request::send(buf, count, datatype, to_intra_left, tag, comm);
84 if ((to_intra_right - base) < num_core)
85 Request::send(buf, count, datatype, to_intra_right, tag, comm);
87 // case INTERMEDIAT ROOT-of-each-SMP
89 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
90 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
91 Request::wait(&request, &status);
92 Request::send(buf, count, datatype, to_inter_left, tag, comm);
93 if (to_inter_right < size)
94 Request::send(buf, count, datatype, to_inter_right, tag, comm);
95 if ((to_intra_left - base) < num_core)
96 Request::send(buf, count, datatype, to_intra_left, tag, comm);
97 if ((to_intra_right - base) < num_core)
98 Request::send(buf, count, datatype, to_intra_right, tag, comm);
101 // case non ROOT-of-each-SMP
104 if ((to_intra_left - base) >= num_core) {
105 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
106 Request::wait(&request, &status);
110 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
111 Request::wait(&request, &status);
112 Request::send(buf, count, datatype, to_intra_left, tag, comm);
113 if ((to_intra_right - base) < num_core)
114 Request::send(buf, count, datatype, to_intra_right, tag, comm);
123 auto* request_array = new MPI_Request[size + pipe_length];
124 auto* status_array = new MPI_Status[size + pipe_length];
126 // case ROOT-of-each-SMP
127 if (rank % host_num_core == 0) {
130 for (i = 0; i < pipe_length; i++) {
131 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
132 if (to_inter_left < size)
133 Request::send((char *) buf + (i * increment), segment, datatype,
134 to_inter_left, (tag + i), comm);
135 if (to_inter_right < size)
136 Request::send((char *) buf + (i * increment), segment, datatype,
137 to_inter_right, (tag + i), comm);
138 if ((to_intra_left - base) < num_core)
139 Request::send((char *) buf + (i * increment), segment, datatype,
140 to_intra_left, (tag + i), comm);
141 if ((to_intra_right - base) < num_core)
142 Request::send((char *) buf + (i * increment), segment, datatype,
143 to_intra_right, (tag + i), comm);
146 // case LEAVES ROOT-of-eash-SMP
147 else if (to_inter_left >= size) {
148 //printf("node %d from %d\n",rank,from_inter);
149 for (i = 0; i < pipe_length; i++) {
150 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
151 from_inter, (tag + i), comm);
153 for (i = 0; i < pipe_length; i++) {
154 Request::wait(&request_array[i], &status);
155 if ((to_intra_left - base) < num_core)
156 Request::send((char *) buf + (i * increment), segment, datatype,
157 to_intra_left, (tag + i), comm);
158 if ((to_intra_right - base) < num_core)
159 Request::send((char *) buf + (i * increment), segment, datatype,
160 to_intra_right, (tag + i), comm);
163 // case INTERMEDIAT ROOT-of-each-SMP
165 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
166 for (i = 0; i < pipe_length; i++) {
167 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
168 from_inter, (tag + i), comm);
170 for (i = 0; i < pipe_length; i++) {
171 Request::wait(&request_array[i], &status);
172 Request::send((char *) buf + (i * increment), segment, datatype,
173 to_inter_left, (tag + i), comm);
174 if (to_inter_right < size)
175 Request::send((char *) buf + (i * increment), segment, datatype,
176 to_inter_right, (tag + i), comm);
177 if ((to_intra_left - base) < num_core)
178 Request::send((char *) buf + (i * increment), segment, datatype,
179 to_intra_left, (tag + i), comm);
180 if ((to_intra_right - base) < num_core)
181 Request::send((char *) buf + (i * increment), segment, datatype,
182 to_intra_right, (tag + i), comm);
186 // case non-ROOT-of-each-SMP
189 if ((to_intra_left - base) >= num_core) {
190 for (i = 0; i < pipe_length; i++) {
191 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
192 from_intra, (tag + i), comm);
194 Request::waitall((pipe_length), request_array, status_array);
198 for (i = 0; i < pipe_length; i++) {
199 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
200 from_intra, (tag + i), comm);
202 for (i = 0; i < pipe_length; i++) {
203 Request::wait(&request_array[i], &status);
204 Request::send((char *) buf + (i * increment), segment, datatype,
205 to_intra_left, (tag + i), comm);
206 if ((to_intra_right - base) < num_core)
207 Request::send((char *) buf + (i * increment), segment, datatype,
208 to_intra_right, (tag + i), comm);
213 delete[] request_array;
214 delete[] status_array;
217 // when count is not divisible by block size, use default BCAST for the remainder
218 if ((remainder != 0) && (count > segment)) {
219 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
220 colls::bcast((char*)buf + (pipe_length * increment), remainder, datatype, root, comm);