Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
04b40b22400ffe738c6a0bce75b864c3dc4a4392
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 #ifndef NUM_CORE
9 #define NUM_CORE 8
10 #endif
11
12 int bcast_SMP_binary_segment_byte = 8192;
13
14 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
15                                      MPI_Datatype datatype, int root,
16                                      MPI_Comm comm)
17 {
18   int tag = COLL_TAG_BCAST;
19   MPI_Status status;
20   MPI_Request request;
21   MPI_Request *request_array;
22   MPI_Status *status_array;
23   int rank, size;
24   int i;
25   MPI_Aint extent;
26   extent = smpi_datatype_get_extent(datatype);
27
28   rank = smpi_comm_rank(comm);
29   size = smpi_comm_size(comm);
30   int host_num_core = simcall_host_get_core(SIMIX_host_self());
31   // do we use the default one or the number of cores in the platform ?
32   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
33   if (host_num_core == 1) host_num_core = NUM_CORE;
34
35   int segment = bcast_SMP_binary_segment_byte / extent;
36   int pipe_length = count / segment;
37   int remainder = count % segment;
38
39   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
40   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
41   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
42   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
43   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
44   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
45   int increment = segment * extent;
46
47   int base = (rank / host_num_core) * host_num_core;
48   int num_core = host_num_core;
49   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
50     num_core = size - (rank / host_num_core) * host_num_core;
51
52   // if root is not zero send to rank zero first
53   if (root != 0) {
54     if (rank == root)
55       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
56     else if (rank == 0)
57       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
58   }
59   // when a message is smaller than a block size => no pipeline 
60   if (count <= segment) {
61     // case ROOT-of-each-SMP
62     if (rank % host_num_core == 0) {
63       // case ROOT
64       if (rank == 0) {
65         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
66         if (to_inter_left < size)
67           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
68         if (to_inter_right < size)
69           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
70         if ((to_intra_left - base) < num_core)
71           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
72         if ((to_intra_right - base) < num_core)
73           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
74       }
75       // case LEAVES ROOT-of-eash-SMP
76       else if (to_inter_left >= size) {
77         //printf("node %d from %d\n",rank,from_inter);
78         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
79         smpi_mpi_wait(&request, &status);
80         if ((to_intra_left - base) < num_core)
81           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
82         if ((to_intra_right - base) < num_core)
83           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
84       }
85       // case INTERMEDIAT ROOT-of-each-SMP
86       else {
87         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
88         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
89         smpi_mpi_wait(&request, &status);
90         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
91         if (to_inter_right < size)
92           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
93         if ((to_intra_left - base) < num_core)
94           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
95         if ((to_intra_right - base) < num_core)
96           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
97       }
98     }
99     // case non ROOT-of-each-SMP
100     else {
101       // case leaves
102       if ((to_intra_left - base) >= num_core) {
103         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
104         smpi_mpi_wait(&request, &status);
105       }
106       // case intermediate
107       else {
108         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
109         smpi_mpi_wait(&request, &status);
110         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
111         if ((to_intra_right - base) < num_core)
112           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
113       }
114     }
115
116     return MPI_SUCCESS;
117   }
118
119   // pipeline bcast
120   else {
121     request_array =
122         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
123     status_array =
124         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
125
126     // case ROOT-of-each-SMP
127     if (rank % host_num_core == 0) {
128       // case ROOT
129       if (rank == 0) {
130         for (i = 0; i < pipe_length; i++) {
131           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
132           if (to_inter_left < size)
133             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
134                      to_inter_left, (tag + i), comm);
135           if (to_inter_right < size)
136             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
137                      to_inter_right, (tag + i), comm);
138           if ((to_intra_left - base) < num_core)
139             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
140                      to_intra_left, (tag + i), comm);
141           if ((to_intra_right - base) < num_core)
142             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
143                      to_intra_right, (tag + i), comm);
144         }
145       }
146       // case LEAVES ROOT-of-eash-SMP
147       else if (to_inter_left >= size) {
148         //printf("node %d from %d\n",rank,from_inter);
149         for (i = 0; i < pipe_length; i++) {
150           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
151                     from_inter, (tag + i), comm);
152         }
153         for (i = 0; i < pipe_length; i++) {
154           smpi_mpi_wait(&request_array[i], &status);
155           if ((to_intra_left - base) < num_core)
156             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
157                      to_intra_left, (tag + i), comm);
158           if ((to_intra_right - base) < num_core)
159             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
160                      to_intra_right, (tag + i), comm);
161         }
162       }
163       // case INTERMEDIAT ROOT-of-each-SMP
164       else {
165         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
166         for (i = 0; i < pipe_length; i++) {
167           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
168                     from_inter, (tag + i), comm);
169         }
170         for (i = 0; i < pipe_length; i++) {
171           smpi_mpi_wait(&request_array[i], &status);
172           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
173                    to_inter_left, (tag + i), comm);
174           if (to_inter_right < size)
175             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
176                      to_inter_right, (tag + i), comm);
177           if ((to_intra_left - base) < num_core)
178             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
179                      to_intra_left, (tag + i), comm);
180           if ((to_intra_right - base) < num_core)
181             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
182                      to_intra_right, (tag + i), comm);
183         }
184       }
185     }
186     // case non-ROOT-of-each-SMP
187     else {
188       // case leaves
189       if ((to_intra_left - base) >= num_core) {
190         for (i = 0; i < pipe_length; i++) {
191           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
192                     from_intra, (tag + i), comm);
193         }
194         smpi_mpi_waitall((pipe_length), request_array, status_array);
195       }
196       // case intermediate
197       else {
198         for (i = 0; i < pipe_length; i++) {
199           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
200                     from_intra, (tag + i), comm);
201         }
202         for (i = 0; i < pipe_length; i++) {
203           smpi_mpi_wait(&request_array[i], &status);
204           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
205                    to_intra_left, (tag + i), comm);
206           if ((to_intra_right - base) < num_core)
207             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
208                      to_intra_right, (tag + i), comm);
209         }
210       }
211     }
212
213     free(request_array);
214     free(status_array);
215   }
216
217   // when count is not divisible by block size, use default BCAST for the remainder
218   if ((remainder != 0) && (count > segment)) {
219     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
220     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
221               root, comm);
222   }
223
224   return 1;
225 }