1 /* Copyright (c) 2013-2019. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
// Pipeline segment size, in bytes, for Coll_bcast_SMP_binary::bcast.
// The broadcast divides this by the datatype extent to get the number of
// elements per segment; messages larger than one segment are pipelined.
// NOTE(review): deliberately left non-const — presumably so it can be tuned
// or referenced via an extern declaration elsewhere; confirm before changing.
int bcast_SMP_binary_segment_byte{8192};
12 int Coll_bcast_SMP_binary::bcast(void *buf, int count,
13 MPI_Datatype datatype, int root,
16 int tag = COLL_TAG_BCAST;
22 extent = datatype->get_extent();
26 if(comm->get_leaders_comm()==MPI_COMM_NULL){
30 if (comm->is_uniform()){
31 host_num_core = comm->get_intra_comm()->size();
33 //implementation buggy in this case
34 return Coll_bcast_mpich::bcast( buf , count, datatype,
38 int segment = bcast_SMP_binary_segment_byte / extent;
39 int pipe_length = count / segment;
40 int remainder = count % segment;
42 int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
43 int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
44 int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
45 int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
46 int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
47 int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
48 int increment = segment * extent;
50 int base = (rank / host_num_core) * host_num_core;
51 int num_core = host_num_core;
52 if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
53 num_core = size - (rank / host_num_core) * host_num_core;
55 // if root is not zero send to rank zero first
58 Request::send(buf, count, datatype, 0, tag, comm);
60 Request::recv(buf, count, datatype, root, tag, comm, &status);
62 // when a message is smaller than a block size => no pipeline
63 if (count <= segment) {
64 // case ROOT-of-each-SMP
65 if (rank % host_num_core == 0) {
68 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
69 if (to_inter_left < size)
70 Request::send(buf, count, datatype, to_inter_left, tag, comm);
71 if (to_inter_right < size)
72 Request::send(buf, count, datatype, to_inter_right, tag, comm);
73 if ((to_intra_left - base) < num_core)
74 Request::send(buf, count, datatype, to_intra_left, tag, comm);
75 if ((to_intra_right - base) < num_core)
76 Request::send(buf, count, datatype, to_intra_right, tag, comm);
78 // case LEAVES ROOT-of-eash-SMP
79 else if (to_inter_left >= size) {
80 //printf("node %d from %d\n",rank,from_inter);
81 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
82 Request::wait(&request, &status);
83 if ((to_intra_left - base) < num_core)
84 Request::send(buf, count, datatype, to_intra_left, tag, comm);
85 if ((to_intra_right - base) < num_core)
86 Request::send(buf, count, datatype, to_intra_right, tag, comm);
88 // case INTERMEDIAT ROOT-of-each-SMP
90 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
91 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
92 Request::wait(&request, &status);
93 Request::send(buf, count, datatype, to_inter_left, tag, comm);
94 if (to_inter_right < size)
95 Request::send(buf, count, datatype, to_inter_right, tag, comm);
96 if ((to_intra_left - base) < num_core)
97 Request::send(buf, count, datatype, to_intra_left, tag, comm);
98 if ((to_intra_right - base) < num_core)
99 Request::send(buf, count, datatype, to_intra_right, tag, comm);
102 // case non ROOT-of-each-SMP
105 if ((to_intra_left - base) >= num_core) {
106 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
107 Request::wait(&request, &status);
111 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
112 Request::wait(&request, &status);
113 Request::send(buf, count, datatype, to_intra_left, tag, comm);
114 if ((to_intra_right - base) < num_core)
115 Request::send(buf, count, datatype, to_intra_right, tag, comm);
124 MPI_Request* request_array = new MPI_Request[size + pipe_length];
125 MPI_Status* status_array = new MPI_Status[size + pipe_length];
127 // case ROOT-of-each-SMP
128 if (rank % host_num_core == 0) {
131 for (i = 0; i < pipe_length; i++) {
132 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
133 if (to_inter_left < size)
134 Request::send((char *) buf + (i * increment), segment, datatype,
135 to_inter_left, (tag + i), comm);
136 if (to_inter_right < size)
137 Request::send((char *) buf + (i * increment), segment, datatype,
138 to_inter_right, (tag + i), comm);
139 if ((to_intra_left - base) < num_core)
140 Request::send((char *) buf + (i * increment), segment, datatype,
141 to_intra_left, (tag + i), comm);
142 if ((to_intra_right - base) < num_core)
143 Request::send((char *) buf + (i * increment), segment, datatype,
144 to_intra_right, (tag + i), comm);
147 // case LEAVES ROOT-of-eash-SMP
148 else if (to_inter_left >= size) {
149 //printf("node %d from %d\n",rank,from_inter);
150 for (i = 0; i < pipe_length; i++) {
151 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
152 from_inter, (tag + i), comm);
154 for (i = 0; i < pipe_length; i++) {
155 Request::wait(&request_array[i], &status);
156 if ((to_intra_left - base) < num_core)
157 Request::send((char *) buf + (i * increment), segment, datatype,
158 to_intra_left, (tag + i), comm);
159 if ((to_intra_right - base) < num_core)
160 Request::send((char *) buf + (i * increment), segment, datatype,
161 to_intra_right, (tag + i), comm);
164 // case INTERMEDIAT ROOT-of-each-SMP
166 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
167 for (i = 0; i < pipe_length; i++) {
168 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
169 from_inter, (tag + i), comm);
171 for (i = 0; i < pipe_length; i++) {
172 Request::wait(&request_array[i], &status);
173 Request::send((char *) buf + (i * increment), segment, datatype,
174 to_inter_left, (tag + i), comm);
175 if (to_inter_right < size)
176 Request::send((char *) buf + (i * increment), segment, datatype,
177 to_inter_right, (tag + i), comm);
178 if ((to_intra_left - base) < num_core)
179 Request::send((char *) buf + (i * increment), segment, datatype,
180 to_intra_left, (tag + i), comm);
181 if ((to_intra_right - base) < num_core)
182 Request::send((char *) buf + (i * increment), segment, datatype,
183 to_intra_right, (tag + i), comm);
187 // case non-ROOT-of-each-SMP
190 if ((to_intra_left - base) >= num_core) {
191 for (i = 0; i < pipe_length; i++) {
192 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
193 from_intra, (tag + i), comm);
195 Request::waitall((pipe_length), request_array, status_array);
199 for (i = 0; i < pipe_length; i++) {
200 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
201 from_intra, (tag + i), comm);
203 for (i = 0; i < pipe_length; i++) {
204 Request::wait(&request_array[i], &status);
205 Request::send((char *) buf + (i * increment), segment, datatype,
206 to_intra_left, (tag + i), comm);
207 if ((to_intra_right - base) < num_core)
208 Request::send((char *) buf + (i * increment), segment, datatype,
209 to_intra_right, (tag + i), comm);
214 delete[] request_array;
215 delete[] status_array;
218 // when count is not divisible by block size, use default BCAST for the remainder
219 if ((remainder != 0) && (count > segment)) {
220 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
221 Colls::bcast((char *) buf + (pipe_length * increment), remainder, datatype,