Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge 'master' into mc
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 #ifndef NUM_CORE
9 #define NUM_CORE 8
10 #endif
11
12 int bcast_SMP_binary_segment_byte = 8192;
13
14 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
15                                      MPI_Datatype datatype, int root,
16                                      MPI_Comm comm)
17 {
18   int tag = COLL_TAG_BCAST;
19   MPI_Status status;
20   MPI_Request request;
21   MPI_Request *request_array;
22   MPI_Status *status_array;
23   int rank, size;
24   int i;
25   MPI_Aint extent;
26   extent = smpi_datatype_get_extent(datatype);
27
28   rank = smpi_comm_rank(comm);
29   size = smpi_comm_size(comm);
30   int host_num_core = simcall_host_get_core(SIMIX_host_self());
31   // do we use the default one or the number of cores in the platform ?
32   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
33   if (host_num_core == 1) host_num_core = NUM_CORE;
34
35   if(size%host_num_core)
36     THROWF(arg_error,0, "bcast SMP binary can't be used with non multiple of NUM_CORE=%d number of processes ! ",host_num_core);
37
38   int segment = bcast_SMP_binary_segment_byte / extent;
39   int pipe_length = count / segment;
40   int remainder = count % segment;
41
42   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
43   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
44   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
45   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
46   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
47   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
48   int increment = segment * extent;
49
50   int base = (rank / host_num_core) * host_num_core;
51   int num_core = host_num_core;
52   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
53     num_core = size - (rank / host_num_core) * host_num_core;
54
55   // if root is not zero send to rank zero first
56   if (root != 0) {
57     if (rank == root)
58       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
59     else if (rank == 0)
60       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
61   }
62   // when a message is smaller than a block size => no pipeline 
63   if (count <= segment) {
64     // case ROOT-of-each-SMP
65     if (rank % host_num_core == 0) {
66       // case ROOT
67       if (rank == 0) {
68         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
69         if (to_inter_left < size)
70           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
71         if (to_inter_right < size)
72           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
73         if ((to_intra_left - base) < num_core)
74           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
75         if ((to_intra_right - base) < num_core)
76           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
77       }
78       // case LEAVES ROOT-of-eash-SMP
79       else if (to_inter_left >= size) {
80         //printf("node %d from %d\n",rank,from_inter);
81         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
82         smpi_mpi_wait(&request, &status);
83         if ((to_intra_left - base) < num_core)
84           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
85         if ((to_intra_right - base) < num_core)
86           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
87       }
88       // case INTERMEDIAT ROOT-of-each-SMP
89       else {
90         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
91         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
92         smpi_mpi_wait(&request, &status);
93         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
94         if (to_inter_right < size)
95           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
96         if ((to_intra_left - base) < num_core)
97           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
98         if ((to_intra_right - base) < num_core)
99           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
100       }
101     }
102     // case non ROOT-of-each-SMP
103     else {
104       // case leaves
105       if ((to_intra_left - base) >= num_core) {
106         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
107         smpi_mpi_wait(&request, &status);
108       }
109       // case intermediate
110       else {
111         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
112         smpi_mpi_wait(&request, &status);
113         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
114         if ((to_intra_right - base) < num_core)
115           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
116       }
117     }
118
119     return MPI_SUCCESS;
120   }
121
122   // pipeline bcast
123   else {
124     request_array =
125         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
126     status_array =
127         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
128
129     // case ROOT-of-each-SMP
130     if (rank % host_num_core == 0) {
131       // case ROOT
132       if (rank == 0) {
133         for (i = 0; i < pipe_length; i++) {
134           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
135           if (to_inter_left < size)
136             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
137                      to_inter_left, (tag + i), comm);
138           if (to_inter_right < size)
139             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
140                      to_inter_right, (tag + i), comm);
141           if ((to_intra_left - base) < num_core)
142             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
143                      to_intra_left, (tag + i), comm);
144           if ((to_intra_right - base) < num_core)
145             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
146                      to_intra_right, (tag + i), comm);
147         }
148       }
149       // case LEAVES ROOT-of-eash-SMP
150       else if (to_inter_left >= size) {
151         //printf("node %d from %d\n",rank,from_inter);
152         for (i = 0; i < pipe_length; i++) {
153           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
154                     from_inter, (tag + i), comm);
155         }
156         for (i = 0; i < pipe_length; i++) {
157           smpi_mpi_wait(&request_array[i], &status);
158           if ((to_intra_left - base) < num_core)
159             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
160                      to_intra_left, (tag + i), comm);
161           if ((to_intra_right - base) < num_core)
162             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
163                      to_intra_right, (tag + i), comm);
164         }
165       }
166       // case INTERMEDIAT ROOT-of-each-SMP
167       else {
168         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
169         for (i = 0; i < pipe_length; i++) {
170           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
171                     from_inter, (tag + i), comm);
172         }
173         for (i = 0; i < pipe_length; i++) {
174           smpi_mpi_wait(&request_array[i], &status);
175           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
176                    to_inter_left, (tag + i), comm);
177           if (to_inter_right < size)
178             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
179                      to_inter_right, (tag + i), comm);
180           if ((to_intra_left - base) < num_core)
181             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
182                      to_intra_left, (tag + i), comm);
183           if ((to_intra_right - base) < num_core)
184             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
185                      to_intra_right, (tag + i), comm);
186         }
187       }
188     }
189     // case non-ROOT-of-each-SMP
190     else {
191       // case leaves
192       if ((to_intra_left - base) >= num_core) {
193         for (i = 0; i < pipe_length; i++) {
194           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
195                     from_intra, (tag + i), comm);
196         }
197         smpi_mpi_waitall((pipe_length), request_array, status_array);
198       }
199       // case intermediate
200       else {
201         for (i = 0; i < pipe_length; i++) {
202           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
203                     from_intra, (tag + i), comm);
204         }
205         for (i = 0; i < pipe_length; i++) {
206           smpi_mpi_wait(&request_array[i], &status);
207           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
208                    to_intra_left, (tag + i), comm);
209           if ((to_intra_right - base) < num_core)
210             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
211                      to_intra_right, (tag + i), comm);
212         }
213       }
214     }
215
216     free(request_array);
217     free(status_array);
218   }
219
220   // when count is not divisible by block size, use default BCAST for the remainder
221   if ((remainder != 0) && (count > segment)) {
222     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
223     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
224               root, comm);
225   }
226
227   return 1;
228 }