Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Comm -> C++
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9
10 int bcast_SMP_binary_segment_byte = 8192;
11
12 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
13                                      MPI_Datatype datatype, int root,
14                                      MPI_Comm comm)
15 {
16   int tag = COLL_TAG_BCAST;
17   MPI_Status status;
18   MPI_Request request;
19   MPI_Request *request_array;
20   MPI_Status *status_array;
21   int rank, size;
22   int i;
23   MPI_Aint extent;
24   extent = smpi_datatype_get_extent(datatype);
25
26   rank = comm->rank();
27   size = comm->size();
28   if(comm->get_leaders_comm()==MPI_COMM_NULL){
29     comm->init_smp();
30   }
31   int host_num_core=1;
32   if (comm->is_uniform()){
33     host_num_core = comm->get_intra_comm()->size();
34   }else{
35     //implementation buggy in this case
36     return smpi_coll_tuned_bcast_mpich( buf , count, datatype,
37               root, comm);
38   }
39
40   int segment = bcast_SMP_binary_segment_byte / extent;
41   int pipe_length = count / segment;
42   int remainder = count % segment;
43
44   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
45   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
46   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
47   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
48   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
49   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
50   int increment = segment * extent;
51
52   int base = (rank / host_num_core) * host_num_core;
53   int num_core = host_num_core;
54   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
55     num_core = size - (rank / host_num_core) * host_num_core;
56
57   // if root is not zero send to rank zero first
58   if (root != 0) {
59     if (rank == root)
60       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
61     else if (rank == 0)
62       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
63   }
64   // when a message is smaller than a block size => no pipeline 
65   if (count <= segment) {
66     // case ROOT-of-each-SMP
67     if (rank % host_num_core == 0) {
68       // case ROOT
69       if (rank == 0) {
70         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
71         if (to_inter_left < size)
72           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
73         if (to_inter_right < size)
74           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
75         if ((to_intra_left - base) < num_core)
76           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
77         if ((to_intra_right - base) < num_core)
78           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
79       }
80       // case LEAVES ROOT-of-eash-SMP
81       else if (to_inter_left >= size) {
82         //printf("node %d from %d\n",rank,from_inter);
83         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
84         smpi_mpi_wait(&request, &status);
85         if ((to_intra_left - base) < num_core)
86           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
87         if ((to_intra_right - base) < num_core)
88           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
89       }
90       // case INTERMEDIAT ROOT-of-each-SMP
91       else {
92         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
93         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
94         smpi_mpi_wait(&request, &status);
95         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
96         if (to_inter_right < size)
97           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
98         if ((to_intra_left - base) < num_core)
99           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
100         if ((to_intra_right - base) < num_core)
101           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
102       }
103     }
104     // case non ROOT-of-each-SMP
105     else {
106       // case leaves
107       if ((to_intra_left - base) >= num_core) {
108         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
109         smpi_mpi_wait(&request, &status);
110       }
111       // case intermediate
112       else {
113         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
114         smpi_mpi_wait(&request, &status);
115         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
116         if ((to_intra_right - base) < num_core)
117           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
118       }
119     }
120
121     return MPI_SUCCESS;
122   }
123
124   // pipeline bcast
125   else {
126     request_array =
127         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
128     status_array =
129         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
130
131     // case ROOT-of-each-SMP
132     if (rank % host_num_core == 0) {
133       // case ROOT
134       if (rank == 0) {
135         for (i = 0; i < pipe_length; i++) {
136           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
137           if (to_inter_left < size)
138             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
139                      to_inter_left, (tag + i), comm);
140           if (to_inter_right < size)
141             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
142                      to_inter_right, (tag + i), comm);
143           if ((to_intra_left - base) < num_core)
144             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
145                      to_intra_left, (tag + i), comm);
146           if ((to_intra_right - base) < num_core)
147             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
148                      to_intra_right, (tag + i), comm);
149         }
150       }
151       // case LEAVES ROOT-of-eash-SMP
152       else if (to_inter_left >= size) {
153         //printf("node %d from %d\n",rank,from_inter);
154         for (i = 0; i < pipe_length; i++) {
155           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
156                     from_inter, (tag + i), comm);
157         }
158         for (i = 0; i < pipe_length; i++) {
159           smpi_mpi_wait(&request_array[i], &status);
160           if ((to_intra_left - base) < num_core)
161             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
162                      to_intra_left, (tag + i), comm);
163           if ((to_intra_right - base) < num_core)
164             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
165                      to_intra_right, (tag + i), comm);
166         }
167       }
168       // case INTERMEDIAT ROOT-of-each-SMP
169       else {
170         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
171         for (i = 0; i < pipe_length; i++) {
172           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
173                     from_inter, (tag + i), comm);
174         }
175         for (i = 0; i < pipe_length; i++) {
176           smpi_mpi_wait(&request_array[i], &status);
177           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
178                    to_inter_left, (tag + i), comm);
179           if (to_inter_right < size)
180             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
181                      to_inter_right, (tag + i), comm);
182           if ((to_intra_left - base) < num_core)
183             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
184                      to_intra_left, (tag + i), comm);
185           if ((to_intra_right - base) < num_core)
186             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
187                      to_intra_right, (tag + i), comm);
188         }
189       }
190     }
191     // case non-ROOT-of-each-SMP
192     else {
193       // case leaves
194       if ((to_intra_left - base) >= num_core) {
195         for (i = 0; i < pipe_length; i++) {
196           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
197                     from_intra, (tag + i), comm);
198         }
199         smpi_mpi_waitall((pipe_length), request_array, status_array);
200       }
201       // case intermediate
202       else {
203         for (i = 0; i < pipe_length; i++) {
204           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
205                     from_intra, (tag + i), comm);
206         }
207         for (i = 0; i < pipe_length; i++) {
208           smpi_mpi_wait(&request_array[i], &status);
209           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
210                    to_intra_left, (tag + i), comm);
211           if ((to_intra_right - base) < num_core)
212             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
213                      to_intra_right, (tag + i), comm);
214         }
215       }
216     }
217
218     free(request_array);
219     free(status_array);
220   }
221
222   // when count is not divisible by block size, use default BCAST for the remainder
223   if ((remainder != 0) && (count > segment)) {
224     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
225     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
226               root, comm);
227   }
228
229   return 1;
230 }