1 /* Copyright (c) 2013-2017. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.h"
/* Pipeline segment size, in bytes, for Coll_bcast_SMP_binary::bcast.
 * The broadcast divides this value by the datatype extent to get the number
 * of elements per segment, then sends count / segment full segments one after
 * another. Messages of at most one segment are sent in a single step without
 * pipelining. Left as a non-const global so it can be adjusted at run time.
 * NOTE(review): confirm whether anything outside this file writes to it. */
int bcast_SMP_binary_segment_byte{8192};
13 int Coll_bcast_SMP_binary::bcast(void *buf, int count,
14 MPI_Datatype datatype, int root,
17 int tag = COLL_TAG_BCAST;
20 MPI_Request *request_array;
21 MPI_Status *status_array;
25 extent = datatype->get_extent();
29 if(comm->get_leaders_comm()==MPI_COMM_NULL){
33 if (comm->is_uniform()){
34 host_num_core = comm->get_intra_comm()->size();
36 //implementation buggy in this case
37 return Coll_bcast_mpich::bcast( buf , count, datatype,
41 int segment = bcast_SMP_binary_segment_byte / extent;
42 int pipe_length = count / segment;
43 int remainder = count % segment;
45 int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
46 int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
47 int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
48 int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
49 int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
50 int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
51 int increment = segment * extent;
53 int base = (rank / host_num_core) * host_num_core;
54 int num_core = host_num_core;
55 if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
56 num_core = size - (rank / host_num_core) * host_num_core;
58 // if root is not zero send to rank zero first
61 Request::send(buf, count, datatype, 0, tag, comm);
63 Request::recv(buf, count, datatype, root, tag, comm, &status);
65 // when a message is smaller than a block size => no pipeline
66 if (count <= segment) {
67 // case ROOT-of-each-SMP
68 if (rank % host_num_core == 0) {
71 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
72 if (to_inter_left < size)
73 Request::send(buf, count, datatype, to_inter_left, tag, comm);
74 if (to_inter_right < size)
75 Request::send(buf, count, datatype, to_inter_right, tag, comm);
76 if ((to_intra_left - base) < num_core)
77 Request::send(buf, count, datatype, to_intra_left, tag, comm);
78 if ((to_intra_right - base) < num_core)
79 Request::send(buf, count, datatype, to_intra_right, tag, comm);
81 // case LEAVES ROOT-of-eash-SMP
82 else if (to_inter_left >= size) {
83 //printf("node %d from %d\n",rank,from_inter);
84 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
85 Request::wait(&request, &status);
86 if ((to_intra_left - base) < num_core)
87 Request::send(buf, count, datatype, to_intra_left, tag, comm);
88 if ((to_intra_right - base) < num_core)
89 Request::send(buf, count, datatype, to_intra_right, tag, comm);
91 // case INTERMEDIAT ROOT-of-each-SMP
93 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
94 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
95 Request::wait(&request, &status);
96 Request::send(buf, count, datatype, to_inter_left, tag, comm);
97 if (to_inter_right < size)
98 Request::send(buf, count, datatype, to_inter_right, tag, comm);
99 if ((to_intra_left - base) < num_core)
100 Request::send(buf, count, datatype, to_intra_left, tag, comm);
101 if ((to_intra_right - base) < num_core)
102 Request::send(buf, count, datatype, to_intra_right, tag, comm);
105 // case non ROOT-of-each-SMP
108 if ((to_intra_left - base) >= num_core) {
109 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
110 Request::wait(&request, &status);
114 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
115 Request::wait(&request, &status);
116 Request::send(buf, count, datatype, to_intra_left, tag, comm);
117 if ((to_intra_right - base) < num_core)
118 Request::send(buf, count, datatype, to_intra_right, tag, comm);
128 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
130 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
132 // case ROOT-of-each-SMP
133 if (rank % host_num_core == 0) {
136 for (i = 0; i < pipe_length; i++) {
137 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
138 if (to_inter_left < size)
139 Request::send((char *) buf + (i * increment), segment, datatype,
140 to_inter_left, (tag + i), comm);
141 if (to_inter_right < size)
142 Request::send((char *) buf + (i * increment), segment, datatype,
143 to_inter_right, (tag + i), comm);
144 if ((to_intra_left - base) < num_core)
145 Request::send((char *) buf + (i * increment), segment, datatype,
146 to_intra_left, (tag + i), comm);
147 if ((to_intra_right - base) < num_core)
148 Request::send((char *) buf + (i * increment), segment, datatype,
149 to_intra_right, (tag + i), comm);
152 // case LEAVES ROOT-of-eash-SMP
153 else if (to_inter_left >= size) {
154 //printf("node %d from %d\n",rank,from_inter);
155 for (i = 0; i < pipe_length; i++) {
156 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
157 from_inter, (tag + i), comm);
159 for (i = 0; i < pipe_length; i++) {
160 Request::wait(&request_array[i], &status);
161 if ((to_intra_left - base) < num_core)
162 Request::send((char *) buf + (i * increment), segment, datatype,
163 to_intra_left, (tag + i), comm);
164 if ((to_intra_right - base) < num_core)
165 Request::send((char *) buf + (i * increment), segment, datatype,
166 to_intra_right, (tag + i), comm);
169 // case INTERMEDIAT ROOT-of-each-SMP
171 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
172 for (i = 0; i < pipe_length; i++) {
173 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
174 from_inter, (tag + i), comm);
176 for (i = 0; i < pipe_length; i++) {
177 Request::wait(&request_array[i], &status);
178 Request::send((char *) buf + (i * increment), segment, datatype,
179 to_inter_left, (tag + i), comm);
180 if (to_inter_right < size)
181 Request::send((char *) buf + (i * increment), segment, datatype,
182 to_inter_right, (tag + i), comm);
183 if ((to_intra_left - base) < num_core)
184 Request::send((char *) buf + (i * increment), segment, datatype,
185 to_intra_left, (tag + i), comm);
186 if ((to_intra_right - base) < num_core)
187 Request::send((char *) buf + (i * increment), segment, datatype,
188 to_intra_right, (tag + i), comm);
192 // case non-ROOT-of-each-SMP
195 if ((to_intra_left - base) >= num_core) {
196 for (i = 0; i < pipe_length; i++) {
197 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
198 from_intra, (tag + i), comm);
200 Request::waitall((pipe_length), request_array, status_array);
204 for (i = 0; i < pipe_length; i++) {
205 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
206 from_intra, (tag + i), comm);
208 for (i = 0; i < pipe_length; i++) {
209 Request::wait(&request_array[i], &status);
210 Request::send((char *) buf + (i * increment), segment, datatype,
211 to_intra_left, (tag + i), comm);
212 if ((to_intra_right - base) < num_core)
213 Request::send((char *) buf + (i * increment), segment, datatype,
214 to_intra_right, (tag + i), comm);
223 // when count is not divisible by block size, use default BCAST for the remainder
224 if ((remainder != 0) && (count > segment)) {
225 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
226 Colls::bcast((char *) buf + (pipe_length * increment), remainder, datatype,