1 #include "colls_private.h"
6 int bcast_SMP_binary_segment_byte = 8192; /* pipeline segment size in bytes; divided by the datatype extent to get the element count of one segment, and messages no larger than one segment are sent without pipelining */
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9 MPI_Datatype datatype, int root,
12 int tag = COLL_TAG_BCAST;
15 MPI_Request *request_array;
16 MPI_Status *status_array;
20 extent = smpi_datatype_get_extent(datatype);
22 rank = smpi_comm_rank(comm);
23 size = smpi_comm_size(comm);
24 int host_num_core = simcall_host_get_core(SIMIX_host_self());
25 // Use the host's core count, unless it is one: in that case fall back to the default NUM_CORE,
26 // because a core count of one may mean the platform is simulated with 1 node = 1 core.
27 if (host_num_core == 1) host_num_core = NUM_CORE;
29 if(size%host_num_core)
30 THROWF(arg_error,0, "bcast SMP binary can't be used with non multiple of NUM_CORE=%d number of processes ! ",host_num_core);
32 int segment = bcast_SMP_binary_segment_byte / extent;
33 int pipe_length = count / segment;
34 int remainder = count % segment;
36 int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
37 int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
38 int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
39 int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
40 int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
41 int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
42 int increment = segment * extent;
44 int base = (rank / host_num_core) * host_num_core;
45 int num_core = host_num_core;
46 if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
47 num_core = size - (rank / host_num_core) * host_num_core;
49 // if root is not zero send to rank zero first
52 smpi_mpi_send(buf, count, datatype, 0, tag, comm);
54 smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
56 // when a message is smaller than a block size => no pipeline
57 if (count <= segment) {
58 // case ROOT-of-each-SMP
59 if (rank % host_num_core == 0) {
62 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
63 if (to_inter_left < size)
64 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
65 if (to_inter_right < size)
66 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
67 if ((to_intra_left - base) < num_core)
68 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
69 if ((to_intra_right - base) < num_core)
70 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
72 // case LEAVES ROOT-of-each-SMP
73 else if (to_inter_left >= size) {
74 //printf("node %d from %d\n",rank,from_inter);
75 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
76 smpi_mpi_wait(&request, &status);
77 if ((to_intra_left - base) < num_core)
78 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
79 if ((to_intra_right - base) < num_core)
80 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
82 // case INTERMEDIATE ROOT-of-each-SMP
84 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
85 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
86 smpi_mpi_wait(&request, &status);
87 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
88 if (to_inter_right < size)
89 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
90 if ((to_intra_left - base) < num_core)
91 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
92 if ((to_intra_right - base) < num_core)
93 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
96 // case non ROOT-of-each-SMP
99 if ((to_intra_left - base) >= num_core) {
100 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
101 smpi_mpi_wait(&request, &status);
105 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
106 smpi_mpi_wait(&request, &status);
107 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
108 if ((to_intra_right - base) < num_core)
109 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
119 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
121 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
123 // case ROOT-of-each-SMP
124 if (rank % host_num_core == 0) {
127 for (i = 0; i < pipe_length; i++) {
128 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
129 if (to_inter_left < size)
130 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
131 to_inter_left, (tag + i), comm);
132 if (to_inter_right < size)
133 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
134 to_inter_right, (tag + i), comm);
135 if ((to_intra_left - base) < num_core)
136 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
137 to_intra_left, (tag + i), comm);
138 if ((to_intra_right - base) < num_core)
139 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
140 to_intra_right, (tag + i), comm);
143 // case LEAVES ROOT-of-each-SMP
144 else if (to_inter_left >= size) {
145 //printf("node %d from %d\n",rank,from_inter);
146 for (i = 0; i < pipe_length; i++) {
147 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
148 from_inter, (tag + i), comm);
150 for (i = 0; i < pipe_length; i++) {
151 smpi_mpi_wait(&request_array[i], &status);
152 if ((to_intra_left - base) < num_core)
153 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
154 to_intra_left, (tag + i), comm);
155 if ((to_intra_right - base) < num_core)
156 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
157 to_intra_right, (tag + i), comm);
160 // case INTERMEDIATE ROOT-of-each-SMP
162 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
163 for (i = 0; i < pipe_length; i++) {
164 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
165 from_inter, (tag + i), comm);
167 for (i = 0; i < pipe_length; i++) {
168 smpi_mpi_wait(&request_array[i], &status);
169 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
170 to_inter_left, (tag + i), comm);
171 if (to_inter_right < size)
172 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
173 to_inter_right, (tag + i), comm);
174 if ((to_intra_left - base) < num_core)
175 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
176 to_intra_left, (tag + i), comm);
177 if ((to_intra_right - base) < num_core)
178 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
179 to_intra_right, (tag + i), comm);
183 // case non-ROOT-of-each-SMP
186 if ((to_intra_left - base) >= num_core) {
187 for (i = 0; i < pipe_length; i++) {
188 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
189 from_intra, (tag + i), comm);
191 smpi_mpi_waitall((pipe_length), request_array, status_array);
195 for (i = 0; i < pipe_length; i++) {
196 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
197 from_intra, (tag + i), comm);
199 for (i = 0; i < pipe_length; i++) {
200 smpi_mpi_wait(&request_array[i], &status);
201 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
202 to_intra_left, (tag + i), comm);
203 if ((to_intra_right - base) < num_core)
204 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
205 to_intra_right, (tag + i), comm);
214 // when count is not divisible by block size, use default BCAST for the remainder
215 if ((remainder != 0) && (count > segment)) {
216 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
217 smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,