1 /* Copyright (c) 2013-2023. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
// Pipeline segment size, in bytes, used by bcast__SMP_binary: it is divided by
// the datatype extent to obtain the number of elements per pipeline segment.
// NOTE(review): declared as a mutable global rather than a constant, presumably
// so it can be tuned from elsewhere — confirm before making it const.
9 int bcast_SMP_binary_segment_byte = 8192;
10 namespace simgrid::smpi {
// Broadcast `count` elements of `datatype` held in `buf` on rank `root` to all
// ranks of `comm`, using an SMP-aware two-level binary tree:
//  - ranks are grouped into consecutive blocks of `host_num_core` ranks (the
//    size of comm's intra communicator); the first rank of each block
//    (rank % host_num_core == 0) acts as the block leader;
//  - one binary tree connects the block leaders, a second binary tree connects
//    the ranks inside each block;
//  - messages larger than one segment (bcast_SMP_binary_segment_byte bytes) are
//    pipelined segment by segment; leftover elements are broadcast afterwards
//    with the default colls::bcast.
// Non-uniform layouts are delegated to bcast__mpich.
//
// NOTE(review): this listing omits several original source lines (see the gaps
// in the embedded line numbers), including the declaration of the `comm`
// parameter and of locals such as rank, size, status, request, i and extent.
11 int bcast__SMP_binary(void *buf, int count,
12 MPI_Datatype datatype, int root,
// Base message tag; the pipelined path adds the segment index to it.
15 int tag = COLL_TAG_BCAST;
// Size of one element in bytes; used to size segments and to compute byte
// offsets into buf.
21 extent = datatype->get_extent();
// Build the SMP leader/intra communicators on first use (the initialisation call
// inside this branch is elided from this listing).
25 if(comm->get_leaders_comm()==MPI_COMM_NULL){
// Only uniform layouts are handled: every block is assumed to hold
// host_num_core ranks. is_uniform() presumably means every host runs the same
// number of ranks — TODO confirm.
29 if (comm->is_uniform()){
30 host_num_core = comm->get_intra_comm()->size();
32 // implementation buggy in this (non-uniform) case: delegate to bcast__mpich
33 return bcast__mpich(buf , count, datatype, root, comm);
36 int segment = bcast_SMP_binary_segment_byte / extent;
37 int pipe_length = count / segment;
38 int remainder = count % segment;
// Tree neighbours. Writing node = rank / host_num_core and
// core = rank % host_num_core:
//  - intra-block tree: children are base + 2*core + 1 and base + 2*core + 2,
//    parent is base + (core - 1) / 2;
//  - inter-block tree (leaders only): children are the leaders of blocks
//    2*node + 1 and 2*node + 2, parent is the leader of block (node - 1) / 2.
// Children may fall outside the communicator or the block; every use below is
// guarded against `size` / `num_core`. For core == 0 (resp. node == 0) the
// parent expression computes (-1) / 2 == 0 (C++ integer division truncates
// toward zero), i.e. the block base (resp. rank 0). Those values are only read
// by ranks that actually have such a parent.
40 int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
41 int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
42 int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
43 int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
44 int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
45 int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
// Byte stride between consecutive segments inside buf.
46 int increment = segment * extent;
// base: first rank of this block. num_core: number of ranks actually present in
// this block, which is smaller than host_num_core for the trailing block when
// size is not a multiple of host_num_core.
48 int base = (rank / host_num_core) * host_num_core;
49 int num_core = host_num_core;
50 if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
51 num_core = size - (rank / host_num_core) * host_num_core;
53 // if root is not zero, first move the data from root to rank 0, which roots the inter-block tree (the guarding conditions are elided from this listing)
56 Request::send(buf, count, datatype, 0, tag, comm);
58 Request::recv(buf, count, datatype, root, tag, comm, &status);
60 // message fits in one segment (count <= segment) => no pipeline: each hop forwards the whole buffer in one message
61 if (count <= segment) {
62 // case ROOT-of-each-SMP (block leader: rank % host_num_core == 0)
63 if (rank % host_num_core == 0) {
// Global root (rank 0 after the forwarding above; its branch condition is
// elided from this listing): receives nothing, sends to every existing child.
66 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
67 if (to_inter_left < size)
68 Request::send(buf, count, datatype, to_inter_left, tag, comm);
69 if (to_inter_right < size)
70 Request::send(buf, count, datatype, to_inter_right, tag, comm);
71 if ((to_intra_left - base) < num_core)
72 Request::send(buf, count, datatype, to_intra_left, tag, comm);
73 if ((to_intra_right - base) < num_core)
74 Request::send(buf, count, datatype, to_intra_right, tag, comm);
76 // case LEAVES ROOT-of-each-SMP: leader without child leaders; receive from the parent leader, then forward inside the block
77 else if (to_inter_left >= size) {
78 //printf("node %d from %d\n",rank,from_inter);
79 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
80 Request::wait(&request, &status);
81 if ((to_intra_left - base) < num_core)
82 Request::send(buf, count, datatype, to_intra_left, tag, comm);
83 if ((to_intra_right - base) < num_core)
84 Request::send(buf, count, datatype, to_intra_right, tag, comm);
86 // case INTERMEDIATE ROOT-of-each-SMP: receive from the parent leader, forward to child leaders, then inside the block
88 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
89 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
90 Request::wait(&request, &status);
// to_inter_left < size is known here (the previous branch was not taken), so
// this send needs no guard.
91 Request::send(buf, count, datatype, to_inter_left, tag, comm);
92 if (to_inter_right < size)
93 Request::send(buf, count, datatype, to_inter_right, tag, comm);
94 if ((to_intra_left - base) < num_core)
95 Request::send(buf, count, datatype, to_intra_left, tag, comm);
96 if ((to_intra_right - base) < num_core)
97 Request::send(buf, count, datatype, to_intra_right, tag, comm);
100 // case non ROOT-of-each-SMP
// Leaves of the intra-block tree only receive from their parent.
103 if ((to_intra_left - base) >= num_core) {
104 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
105 Request::wait(&request, &status);
// Intermediate ranks of the intra-block tree (the `else` is elided from this
// listing): receive from the parent, then forward. to_intra_left is known to be
// inside the block here, so that send needs no guard.
109 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
110 Request::wait(&request, &status);
111 Request::send(buf, count, datatype, to_intra_left, tag, comm);
112 if ((to_intra_right - base) < num_core)
113 Request::send(buf, count, datatype, to_intra_right, tag, comm);
// Pipelined path. The first pipe_length * segment elements travel as
// pipe_length segments; segment i starts at byte offset i * increment in buf
// and uses tag + i. Receiving ranks post all irecvs up front, then wait for the
// segments in order and forward each one as soon as it has arrived. The
// count % segment leftover elements are handled after this path.
// NOTE(review): only the first pipe_length entries of these arrays are used, so
// `size + pipe_length` over-allocates. The raw new[]/delete[] pair also leaks if
// anything in between throws; consider std::vector<MPI_Request> /
// std::vector<MPI_Status>.
122 auto* request_array = new MPI_Request[size + pipe_length];
123 auto* status_array = new MPI_Status[size + pipe_length];
125 // case ROOT-of-each-SMP
126 if (rank % host_num_core == 0) {
// Global root (its branch condition is elided from this listing): push every
// segment to each existing child.
129 for (i = 0; i < pipe_length; i++) {
130 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
131 if (to_inter_left < size)
132 Request::send((char *) buf + (i * increment), segment, datatype,
133 to_inter_left, (tag + i), comm);
134 if (to_inter_right < size)
135 Request::send((char *) buf + (i * increment), segment, datatype,
136 to_inter_right, (tag + i), comm);
137 if ((to_intra_left - base) < num_core)
138 Request::send((char *) buf + (i * increment), segment, datatype,
139 to_intra_left, (tag + i), comm);
140 if ((to_intra_right - base) < num_core)
141 Request::send((char *) buf + (i * increment), segment, datatype,
142 to_intra_right, (tag + i), comm);
145 // case LEAVES ROOT-of-each-SMP: receive all segments from the parent leader, forward each inside the block
146 else if (to_inter_left >= size) {
147 //printf("node %d from %d\n",rank,from_inter);
148 for (i = 0; i < pipe_length; i++) {
149 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
150 from_inter, (tag + i), comm);
152 for (i = 0; i < pipe_length; i++) {
153 Request::wait(&request_array[i], &status);
154 if ((to_intra_left - base) < num_core)
155 Request::send((char *) buf + (i * increment), segment, datatype,
156 to_intra_left, (tag + i), comm);
157 if ((to_intra_right - base) < num_core)
158 Request::send((char *) buf + (i * increment), segment, datatype,
159 to_intra_right, (tag + i), comm);
162 // case INTERMEDIATE ROOT-of-each-SMP: receive each segment, forward it to child leaders and inside the block
164 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
165 for (i = 0; i < pipe_length; i++) {
166 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
167 from_inter, (tag + i), comm);
169 for (i = 0; i < pipe_length; i++) {
170 Request::wait(&request_array[i], &status);
171 Request::send((char *) buf + (i * increment), segment, datatype,
172 to_inter_left, (tag + i), comm);
173 if (to_inter_right < size)
174 Request::send((char *) buf + (i * increment), segment, datatype,
175 to_inter_right, (tag + i), comm);
176 if ((to_intra_left - base) < num_core)
177 Request::send((char *) buf + (i * increment), segment, datatype,
178 to_intra_left, (tag + i), comm);
179 if ((to_intra_right - base) < num_core)
180 Request::send((char *) buf + (i * increment), segment, datatype,
181 to_intra_right, (tag + i), comm);
185 // case non-ROOT-of-each-SMP
// Leaves of the intra-block tree: receive all segments, nothing to forward.
188 if ((to_intra_left - base) >= num_core) {
189 for (i = 0; i < pipe_length; i++) {
190 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
191 from_intra, (tag + i), comm);
193 Request::waitall((pipe_length), request_array, status_array);
// Intermediate ranks of the intra-block tree (the `else` is elided from this
// listing): receive each segment, then forward it to the children.
197 for (i = 0; i < pipe_length; i++) {
198 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
199 from_intra, (tag + i), comm);
201 for (i = 0; i < pipe_length; i++) {
202 Request::wait(&request_array[i], &status);
203 Request::send((char *) buf + (i * increment), segment, datatype,
204 to_intra_left, (tag + i), comm);
205 if ((to_intra_right - base) < num_core)
206 Request::send((char *) buf + (i * increment), segment, datatype,
207 to_intra_right, (tag + i), comm);
212 delete[] request_array;
213 delete[] status_array;
// Elements past the last full segment (count % segment) are not covered by the
// pipeline; broadcast them from `root` with the default colls::bcast. The
// `count > segment` test excludes the non-pipelined case, in which the whole
// buffer was already sent.
// NOTE(review): this logs at INFO level on every such call, and the runtime
// message contains a typo ("remainer"); the string is left unchanged here
// because it is program output.
216 if ((remainder != 0) && (count > segment)) {
217 XBT_INFO("MPI_bcast_SMP_binary: count is not divisible by block size, use default MPI_bcast for remainer.");
218 colls::bcast((char*)buf + (pipe_length * increment), remainder, datatype, root, comm);
// NOTE(review): the function's closing brace(s) and final return statement are
// elided from this listing.
224 } // namespace simgrid::smpi