Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Use DBL_MAX for values of type double.
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.c
1 #include "colls_private.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_binary_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = COLL_TAG_BCAST;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   extent = smpi_datatype_get_extent(datatype);
21
22   rank = smpi_comm_rank(comm);
23   size = smpi_comm_size(comm);
24
25   int segment = bcast_SMP_binary_segment_byte / extent;
26   int pipe_length = count / segment;
27   int remainder = count % segment;
28
29   int to_intra_left = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 1;
30   int to_intra_right = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 2;
31   int to_inter_left = ((rank / NUM_CORE) * 2 + 1) * NUM_CORE;
32   int to_inter_right = ((rank / NUM_CORE) * 2 + 2) * NUM_CORE;
33   int from_inter = (((rank / NUM_CORE) - 1) / 2) * NUM_CORE;
34   int from_intra = (rank / NUM_CORE) * NUM_CORE + ((rank % NUM_CORE) - 1) / 2;
35   int increment = segment * extent;
36
37   int base = (rank / NUM_CORE) * NUM_CORE;
38   int num_core = NUM_CORE;
39   if (((rank / NUM_CORE) * NUM_CORE) == ((size / NUM_CORE) * NUM_CORE))
40     num_core = size - (rank / NUM_CORE) * NUM_CORE;
41
42   // if root is not zero send to rank zero first
43   if (root != 0) {
44     if (rank == root)
45       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
46     else if (rank == 0)
47       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
48   }
49   // when a message is smaller than a block size => no pipeline 
50   if (count <= segment) {
51     // case ROOT-of-each-SMP
52     if (rank % NUM_CORE == 0) {
53       // case ROOT
54       if (rank == 0) {
55         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
56         if (to_inter_left < size)
57           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
58         if (to_inter_right < size)
59           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
60         if ((to_intra_left - base) < num_core)
61           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
62         if ((to_intra_right - base) < num_core)
63           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
64       }
65       // case LEAVES ROOT-of-eash-SMP
66       else if (to_inter_left >= size) {
67         //printf("node %d from %d\n",rank,from_inter);
68         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
69         smpi_mpi_wait(&request, &status);
70         if ((to_intra_left - base) < num_core)
71           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
72         if ((to_intra_right - base) < num_core)
73           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
74       }
75       // case INTERMEDIAT ROOT-of-each-SMP
76       else {
77         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
78         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
79         smpi_mpi_wait(&request, &status);
80         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
81         if (to_inter_right < size)
82           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
83         if ((to_intra_left - base) < num_core)
84           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
85         if ((to_intra_right - base) < num_core)
86           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
87       }
88     }
89     // case non ROOT-of-each-SMP
90     else {
91       // case leaves
92       if ((to_intra_left - base) >= num_core) {
93         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
94         smpi_mpi_wait(&request, &status);
95       }
96       // case intermediate
97       else {
98         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
99         smpi_mpi_wait(&request, &status);
100         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
101         if ((to_intra_right - base) < num_core)
102           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
103       }
104     }
105
106     return MPI_SUCCESS;
107   }
108
109   // pipeline bcast
110   else {
111     request_array =
112         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
113     status_array =
114         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
115
116     // case ROOT-of-each-SMP
117     if (rank % NUM_CORE == 0) {
118       // case ROOT
119       if (rank == 0) {
120         for (i = 0; i < pipe_length; i++) {
121           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
122           if (to_inter_left < size)
123             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
124                      to_inter_left, (tag + i), comm);
125           if (to_inter_right < size)
126             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
127                      to_inter_right, (tag + i), comm);
128           if ((to_intra_left - base) < num_core)
129             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
130                      to_intra_left, (tag + i), comm);
131           if ((to_intra_right - base) < num_core)
132             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
133                      to_intra_right, (tag + i), comm);
134         }
135       }
136       // case LEAVES ROOT-of-eash-SMP
137       else if (to_inter_left >= size) {
138         //printf("node %d from %d\n",rank,from_inter);
139         for (i = 0; i < pipe_length; i++) {
140           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
141                     from_inter, (tag + i), comm);
142         }
143         for (i = 0; i < pipe_length; i++) {
144           smpi_mpi_wait(&request_array[i], &status);
145           if ((to_intra_left - base) < num_core)
146             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
147                      to_intra_left, (tag + i), comm);
148           if ((to_intra_right - base) < num_core)
149             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
150                      to_intra_right, (tag + i), comm);
151         }
152       }
153       // case INTERMEDIAT ROOT-of-each-SMP
154       else {
155         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
156         for (i = 0; i < pipe_length; i++) {
157           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
158                     from_inter, (tag + i), comm);
159         }
160         for (i = 0; i < pipe_length; i++) {
161           smpi_mpi_wait(&request_array[i], &status);
162           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
163                    to_inter_left, (tag + i), comm);
164           if (to_inter_right < size)
165             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
166                      to_inter_right, (tag + i), comm);
167           if ((to_intra_left - base) < num_core)
168             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
169                      to_intra_left, (tag + i), comm);
170           if ((to_intra_right - base) < num_core)
171             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
172                      to_intra_right, (tag + i), comm);
173         }
174       }
175     }
176     // case non-ROOT-of-each-SMP
177     else {
178       // case leaves
179       if ((to_intra_left - base) >= num_core) {
180         for (i = 0; i < pipe_length; i++) {
181           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
182                     from_intra, (tag + i), comm);
183         }
184         smpi_mpi_waitall((pipe_length), request_array, status_array);
185       }
186       // case intermediate
187       else {
188         for (i = 0; i < pipe_length; i++) {
189           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
190                     from_intra, (tag + i), comm);
191         }
192         for (i = 0; i < pipe_length; i++) {
193           smpi_mpi_wait(&request_array[i], &status);
194           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
195                    to_intra_left, (tag + i), comm);
196           if ((to_intra_right - base) < num_core)
197             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
198                      to_intra_right, (tag + i), comm);
199         }
200       }
201     }
202
203     free(request_array);
204     free(status_array);
205   }
206
207   // when count is not divisible by block size, use default BCAST for the remainder
208   if ((remainder != 0) && (count > segment)) {
209     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
210     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
211               root, comm);
212   }
213
214   return 1;
215 }