1 #include "colls_private.h"
// Pipeline segment size, in bytes, for smpi_coll_tuned_bcast_SMP_binary.
// The broadcast divides this by the datatype extent to get the number of
// elements per segment, then sends count / segment full segments in a pipeline.
int bcast_SMP_binary_segment_byte = 8 * 1024;
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9 MPI_Datatype datatype, int root,
12 int tag = COLL_TAG_BCAST;
15 MPI_Request *request_array;
16 MPI_Status *status_array;
20 extent = smpi_datatype_get_extent(datatype);
22 rank = smpi_comm_rank(comm);
23 size = smpi_comm_size(comm);
26 THROWF(arg_error,0, "bcast SMP binary can't be used with non multiple of NUM_CORE=%d number of processes ! ",NUM_CORE);
28 int segment = bcast_SMP_binary_segment_byte / extent;
29 int pipe_length = count / segment;
30 int remainder = count % segment;
32 int to_intra_left = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 1;
33 int to_intra_right = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 2;
34 int to_inter_left = ((rank / NUM_CORE) * 2 + 1) * NUM_CORE;
35 int to_inter_right = ((rank / NUM_CORE) * 2 + 2) * NUM_CORE;
36 int from_inter = (((rank / NUM_CORE) - 1) / 2) * NUM_CORE;
37 int from_intra = (rank / NUM_CORE) * NUM_CORE + ((rank % NUM_CORE) - 1) / 2;
38 int increment = segment * extent;
40 int base = (rank / NUM_CORE) * NUM_CORE;
41 int num_core = NUM_CORE;
42 if (((rank / NUM_CORE) * NUM_CORE) == ((size / NUM_CORE) * NUM_CORE))
43 num_core = size - (rank / NUM_CORE) * NUM_CORE;
45 // if root is not zero send to rank zero first
48 smpi_mpi_send(buf, count, datatype, 0, tag, comm);
50 smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
52 // when a message is smaller than a block size => no pipeline
53 if (count <= segment) {
54 // case ROOT-of-each-SMP
55 if (rank % NUM_CORE == 0) {
58 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
59 if (to_inter_left < size)
60 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
61 if (to_inter_right < size)
62 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
63 if ((to_intra_left - base) < num_core)
64 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
65 if ((to_intra_right - base) < num_core)
66 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
68 // case LEAVES ROOT-of-eash-SMP
69 else if (to_inter_left >= size) {
70 //printf("node %d from %d\n",rank,from_inter);
71 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
72 smpi_mpi_wait(&request, &status);
73 if ((to_intra_left - base) < num_core)
74 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
75 if ((to_intra_right - base) < num_core)
76 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
78 // case INTERMEDIAT ROOT-of-each-SMP
80 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
81 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
82 smpi_mpi_wait(&request, &status);
83 smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
84 if (to_inter_right < size)
85 smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
86 if ((to_intra_left - base) < num_core)
87 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
88 if ((to_intra_right - base) < num_core)
89 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
92 // case non ROOT-of-each-SMP
95 if ((to_intra_left - base) >= num_core) {
96 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
97 smpi_mpi_wait(&request, &status);
101 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
102 smpi_mpi_wait(&request, &status);
103 smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
104 if ((to_intra_right - base) < num_core)
105 smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
115 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
117 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
119 // case ROOT-of-each-SMP
120 if (rank % NUM_CORE == 0) {
123 for (i = 0; i < pipe_length; i++) {
124 //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
125 if (to_inter_left < size)
126 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
127 to_inter_left, (tag + i), comm);
128 if (to_inter_right < size)
129 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
130 to_inter_right, (tag + i), comm);
131 if ((to_intra_left - base) < num_core)
132 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
133 to_intra_left, (tag + i), comm);
134 if ((to_intra_right - base) < num_core)
135 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
136 to_intra_right, (tag + i), comm);
139 // case LEAVES ROOT-of-eash-SMP
140 else if (to_inter_left >= size) {
141 //printf("node %d from %d\n",rank,from_inter);
142 for (i = 0; i < pipe_length; i++) {
143 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
144 from_inter, (tag + i), comm);
146 for (i = 0; i < pipe_length; i++) {
147 smpi_mpi_wait(&request_array[i], &status);
148 if ((to_intra_left - base) < num_core)
149 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
150 to_intra_left, (tag + i), comm);
151 if ((to_intra_right - base) < num_core)
152 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
153 to_intra_right, (tag + i), comm);
156 // case INTERMEDIAT ROOT-of-each-SMP
158 //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
159 for (i = 0; i < pipe_length; i++) {
160 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
161 from_inter, (tag + i), comm);
163 for (i = 0; i < pipe_length; i++) {
164 smpi_mpi_wait(&request_array[i], &status);
165 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
166 to_inter_left, (tag + i), comm);
167 if (to_inter_right < size)
168 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
169 to_inter_right, (tag + i), comm);
170 if ((to_intra_left - base) < num_core)
171 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
172 to_intra_left, (tag + i), comm);
173 if ((to_intra_right - base) < num_core)
174 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
175 to_intra_right, (tag + i), comm);
179 // case non-ROOT-of-each-SMP
182 if ((to_intra_left - base) >= num_core) {
183 for (i = 0; i < pipe_length; i++) {
184 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
185 from_intra, (tag + i), comm);
187 smpi_mpi_waitall((pipe_length), request_array, status_array);
191 for (i = 0; i < pipe_length; i++) {
192 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
193 from_intra, (tag + i), comm);
195 for (i = 0; i < pipe_length; i++) {
196 smpi_mpi_wait(&request_array[i], &status);
197 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
198 to_intra_left, (tag + i), comm);
199 if ((to_intra_right - base) < num_core)
200 smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
201 to_intra_right, (tag + i), comm);
210 // when count is not divisible by block size, use default BCAST for the remainder
211 if ((remainder != 0) && (count > segment)) {
212 XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");
213 smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,