Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
have smp-aware algorithms use number of cores on the node as basis for their computat...
[simgrid.git] / src / smpi / colls / bcast-SMP-binary.c
1 #include "colls_private.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_binary_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_binary(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = COLL_TAG_BCAST;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   extent = smpi_datatype_get_extent(datatype);
21
22   rank = smpi_comm_rank(comm);
23   size = smpi_comm_size(comm);
24   int host_num_core = simcall_host_get_core(SIMIX_host_self());
25   // do we use the default one or the number of cores in the platform ?
26   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
27   if (host_num_core == 1) host_num_core = NUM_CORE;
28
29   if(size%host_num_core)
30     THROWF(arg_error,0, "bcast SMP binary can't be used with non multiple of NUM_CORE=%d number of processes ! ",host_num_core);
31
32   int segment = bcast_SMP_binary_segment_byte / extent;
33   int pipe_length = count / segment;
34   int remainder = count % segment;
35
36   int to_intra_left = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 1;
37   int to_intra_right = (rank / host_num_core) * host_num_core + (rank % host_num_core) * 2 + 2;
38   int to_inter_left = ((rank / host_num_core) * 2 + 1) * host_num_core;
39   int to_inter_right = ((rank / host_num_core) * 2 + 2) * host_num_core;
40   int from_inter = (((rank / host_num_core) - 1) / 2) * host_num_core;
41   int from_intra = (rank / host_num_core) * host_num_core + ((rank % host_num_core) - 1) / 2;
42   int increment = segment * extent;
43
44   int base = (rank / host_num_core) * host_num_core;
45   int num_core = host_num_core;
46   if (((rank / host_num_core) * host_num_core) == ((size / host_num_core) * host_num_core))
47     num_core = size - (rank / host_num_core) * host_num_core;
48
49   // if root is not zero send to rank zero first
50   if (root != 0) {
51     if (rank == root)
52       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
53     else if (rank == 0)
54       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
55   }
56   // when a message is smaller than a block size => no pipeline 
57   if (count <= segment) {
58     // case ROOT-of-each-SMP
59     if (rank % host_num_core == 0) {
60       // case ROOT
61       if (rank == 0) {
62         //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
63         if (to_inter_left < size)
64           smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
65         if (to_inter_right < size)
66           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
67         if ((to_intra_left - base) < num_core)
68           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
69         if ((to_intra_right - base) < num_core)
70           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
71       }
72       // case LEAVES ROOT-of-eash-SMP
73       else if (to_inter_left >= size) {
74         //printf("node %d from %d\n",rank,from_inter);
75         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
76         smpi_mpi_wait(&request, &status);
77         if ((to_intra_left - base) < num_core)
78           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
79         if ((to_intra_right - base) < num_core)
80           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
81       }
82       // case INTERMEDIAT ROOT-of-each-SMP
83       else {
84         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
85         request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
86         smpi_mpi_wait(&request, &status);
87         smpi_mpi_send(buf, count, datatype, to_inter_left, tag, comm);
88         if (to_inter_right < size)
89           smpi_mpi_send(buf, count, datatype, to_inter_right, tag, comm);
90         if ((to_intra_left - base) < num_core)
91           smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
92         if ((to_intra_right - base) < num_core)
93           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
94       }
95     }
96     // case non ROOT-of-each-SMP
97     else {
98       // case leaves
99       if ((to_intra_left - base) >= num_core) {
100         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
101         smpi_mpi_wait(&request, &status);
102       }
103       // case intermediate
104       else {
105         request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
106         smpi_mpi_wait(&request, &status);
107         smpi_mpi_send(buf, count, datatype, to_intra_left, tag, comm);
108         if ((to_intra_right - base) < num_core)
109           smpi_mpi_send(buf, count, datatype, to_intra_right, tag, comm);
110       }
111     }
112
113     return MPI_SUCCESS;
114   }
115
116   // pipeline bcast
117   else {
118     request_array =
119         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
120     status_array =
121         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
122
123     // case ROOT-of-each-SMP
124     if (rank % host_num_core == 0) {
125       // case ROOT
126       if (rank == 0) {
127         for (i = 0; i < pipe_length; i++) {
128           //printf("node %d left %d right %d\n",rank,to_inter_left,to_inter_right);
129           if (to_inter_left < size)
130             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
131                      to_inter_left, (tag + i), comm);
132           if (to_inter_right < size)
133             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
134                      to_inter_right, (tag + i), comm);
135           if ((to_intra_left - base) < num_core)
136             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
137                      to_intra_left, (tag + i), comm);
138           if ((to_intra_right - base) < num_core)
139             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
140                      to_intra_right, (tag + i), comm);
141         }
142       }
143       // case LEAVES ROOT-of-eash-SMP
144       else if (to_inter_left >= size) {
145         //printf("node %d from %d\n",rank,from_inter);
146         for (i = 0; i < pipe_length; i++) {
147           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
148                     from_inter, (tag + i), comm);
149         }
150         for (i = 0; i < pipe_length; i++) {
151           smpi_mpi_wait(&request_array[i], &status);
152           if ((to_intra_left - base) < num_core)
153             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
154                      to_intra_left, (tag + i), comm);
155           if ((to_intra_right - base) < num_core)
156             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
157                      to_intra_right, (tag + i), comm);
158         }
159       }
160       // case INTERMEDIAT ROOT-of-each-SMP
161       else {
162         //printf("node %d left %d right %d from %d\n",rank,to_inter_left,to_inter_right,from_inter);
163         for (i = 0; i < pipe_length; i++) {
164           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
165                     from_inter, (tag + i), comm);
166         }
167         for (i = 0; i < pipe_length; i++) {
168           smpi_mpi_wait(&request_array[i], &status);
169           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
170                    to_inter_left, (tag + i), comm);
171           if (to_inter_right < size)
172             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
173                      to_inter_right, (tag + i), comm);
174           if ((to_intra_left - base) < num_core)
175             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
176                      to_intra_left, (tag + i), comm);
177           if ((to_intra_right - base) < num_core)
178             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
179                      to_intra_right, (tag + i), comm);
180         }
181       }
182     }
183     // case non-ROOT-of-each-SMP
184     else {
185       // case leaves
186       if ((to_intra_left - base) >= num_core) {
187         for (i = 0; i < pipe_length; i++) {
188           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
189                     from_intra, (tag + i), comm);
190         }
191         smpi_mpi_waitall((pipe_length), request_array, status_array);
192       }
193       // case intermediate
194       else {
195         for (i = 0; i < pipe_length; i++) {
196           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
197                     from_intra, (tag + i), comm);
198         }
199         for (i = 0; i < pipe_length; i++) {
200           smpi_mpi_wait(&request_array[i], &status);
201           smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
202                    to_intra_left, (tag + i), comm);
203           if ((to_intra_right - base) < num_core)
204             smpi_mpi_send((char *) buf + (i * increment), segment, datatype,
205                      to_intra_right, (tag + i), comm);
206         }
207       }
208     }
209
210     free(request_array);
211     free(status_array);
212   }
213
214   // when count is not divisible by block size, use default BCAST for the remainder
215   if ((remainder != 0) && (count > segment)) {
216     XBT_WARN("MPI_bcast_SMP_binary use default MPI_bcast.");      
217     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
218               root, comm);
219   }
220
221   return 1;
222 }