Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
protect (hopefully) collective communication algorithms from abuse.
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.c
1 #include "colls_private.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_binary_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = COLL_TAG_BCAST;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   extent = smpi_datatype_get_extent(datatype);
21
22   rank = smpi_comm_rank(comm);
23   size = smpi_comm_size(comm);
24
25   if(size%NUM_CORE)
26     THROWF(arg_error,0, "bcast SMP binary can't be used with non multiple of NUM_CORE=%d number of processes ! ",NUM_CORE);
27
28   int segment = bcast_SMP_binary_segment_byte / extent;
29   int pipe_length = count / segment;
30   int remainder = count % segment;
31
32   int to_intra_left = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 1;
33   int to_intra_right = (rank / NUM_CORE) * NUM_CORE + (rank % NUM_CORE) * 2 + 2;
34   int to_inter_left = ((rank / NUM_CORE) * 2 + 1) * NUM_CORE;
35   int to_inter_right = ((rank / NUM_CORE) * 2 + 2) * NUM_CORE;
36   int from_inter = (((rank / NUM_CORE) - 1) / 2) * NUM_CORE;
37   int from_intra = (rank / NUM_CORE) * NUM_CORE + ((rank % NUM_CORE) - 1) / 2;
38   int increment = segment * extent;
39
40   int base = (rank / NUM_CORE) * NUM_CORE;
41   int num_core = NUM_CORE;
42   if (((rank / NUM_CORE) * NUM_CORE) == ((size / NUM_CORE) * NUM_CORE))
43     num_core = size - (rank / NUM_CORE) * NUM_CORE;
44
45   // if root is not zero send to rank zero first
46   if (root != 0) {
47     if (rank == root)
48       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
49     else if (rank == 0)
50       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
51   }
52   // when a message is smaller than a block size => no pipeline 
53   if (count <= segment) {
54     // case ROOT-of-each-SMP
55     if (rank % NUM_CORE == 0) {
56       // case ROOT
57       if (rank == 0) {
58         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
59         if (to_inter_left < size)
60           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
61         if (to_inter_right < size)
62           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
63         if ((to_intra_left - base) < num_core)
64           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
65         if ((to_intra_right - base) < num_core)
66           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
67       }
68       // case LEAVES ROOT-of-eash-SMP
69       else if (to_inter_left >= size) {
70         //printf("node %d from %d\n",rank,from_inter);
71         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
72         smpi_mpi_wait(&request, &status);
73         if ((to_intra_left - base) < num_core)
74           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
75         if ((to_intra_right - base) < num_core)
76           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
77       }
78       // case INTERMEDIAT ROOT-of-each-SMP
79       else {
80         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
81         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
82         smpi_mpi_wait(&request, &status);
83         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
84         if (to_inter_right < size)
85           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
86         if ((to_intra_left - base) < num_core)
87           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
88         if ((to_intra_right - base) < num_core)
89           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
90       }
91     }
92     // case non ROOT-of-each-SMP
93     else {
94       // case leaves
95       if ((to_intra_left - base) >= num_core) {
96         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
97         smpi_mpi_wait(&request, &status);
98       }
99       // case intermediate
100       else {
101         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
102         smpi_mpi_wait(&request, &status);
103         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
104         if ((to_intra_right - base) < num_core)
105           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
106       }
107     }
108
109     return MPI_SUCCESS;
110   }
111
112   // pipeline bcast
113   else {
114     request_array =
115         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
116     status_array =
117         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
118
119     // case ROOT-of-each-SMP
120     if (rank % NUM_CORE == 0) {
121       // case ROOT
122       if (rank == 0) {
123         for (i = 0; i < pipe_length; i++) {
124           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
125           if (to_inter_left < size)
126             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
127                      to_inter_left, (tag + i), comm);
128           if (to_inter_right < size)
129             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
130                      to_inter_right, (tag + i), comm);
131           if ((to_intra_left - base) < num_core)
132             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
133                      to_intra_left, (tag + i), comm);
134           if ((to_intra_right - base) < num_core)
135             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
136                      to_intra_right, (tag + i), comm);
137         }
138       }
139       // case LEAVES ROOT-of-eash-SMP
140       else if (to_inter_left >= size) {
141         //printf("node %d from %d\n",rank,from_inter);
142         for (i = 0; i < pipe_length; i++) {
143           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
144                     from_inter, (tag + i), comm);
145         }
146         for (i = 0; i < pipe_length; i++) {
147           smpi_mpi_wait(&request_array[i], &status);
148           if ((to_intra_left - base) < num_core)
149             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
150                      to_intra_left, (tag + i), comm);
151           if ((to_intra_right - base) < num_core)
152             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
153                      to_intra_right, (tag + i), comm);
154         }
155       }
156       // case INTERMEDIAT ROOT-of-each-SMP
157       else {
158         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
159         for (i = 0; i < pipe_length; i++) {
160           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
161                     from_inter, (tag + i), comm);
162         }
163         for (i = 0; i < pipe_length; i++) {
164           smpi_mpi_wait(&request_array[i], &status);
165           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
166                    to_inter_left, (tag + i), comm);
167           if (to_inter_right < size)
168             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
169                      to_inter_right, (tag + i), comm);
170           if ((to_intra_left - base) < num_core)
171             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
172                      to_intra_left, (tag + i), comm);
173           if ((to_intra_right - base) < num_core)
174             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
175                      to_intra_right, (tag + i), comm);
176         }
177       }
178     }
179     // case non-ROOT-of-each-SMP
180     else {
181       // case leaves
182       if ((to_intra_left - base) >= num_core) {
183         for (i = 0; i < pipe_length; i++) {
184           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
185                     from_intra, (tag + i), comm);
186         }
187         smpi_mpi_waitall((pipe_length), request_array, status_array);
188       }
189       // case intermediate
190       else {
191         for (i = 0; i < pipe_length; i++) {
192           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
193                     from_intra, (tag + i), comm);
194         }
195         for (i = 0; i < pipe_length; i++) {
196           smpi_mpi_wait(&request_array[i], &status);
197           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
198                    to_intra_left, (tag + i), comm);
199           if ((to_intra_right - base) < num_core)
200             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
201                      to_intra_right, (tag + i), comm);
202         }
203       }
204     }
205
206     free(request_array);
207     free(status_array);
208   }
209
210   // when count is not divisible by block size, use default BCAST for the remainder
211   if ((remainder != 0) && (count > segment)) {
212     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
213     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
214               root, comm);
215   }
216
217   return 1;
218 }