Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'mc' into mc++
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 #ifndef NUM_CORE
9 #define NUM_CORE 8
10 #endif
11
12 int bcast_SMP_linear_segment_byte = 8192;
13
14 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
15                                      MPI_Datatype datatype, int root,
16                                      MPI_Comm comm)
17 {
18   int tag = COLL_TAG_BCAST;
19   MPI_Status status;
20   MPI_Request request;
21   MPI_Request *request_array;
22   MPI_Status *status_array;
23   int rank, size;
24   int i;
25   MPI_Aint extent;
26   extent = smpi_datatype_get_extent(datatype);
27
28   rank = smpi_comm_rank(comm);
29   size = smpi_comm_size(comm);
30   int num_core = simcall_host_get_core(SIMIX_host_self());
31   // do we use the default one or the number of cores in the platform ?
32   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
33   if (num_core == 1) num_core = NUM_CORE;
34
35   if(size%num_core)
36     THROWF(arg_error,0, "bcast SMP linear can't be used with non multiple of num_core=%d number of processes!",num_core);
37
38   int segment = bcast_SMP_linear_segment_byte / extent;
39   int pipe_length = count / segment;
40   int remainder = count % segment;
41   int increment = segment * extent;
42
43
44   /* leader of each SMP do inter-communication
45      and act as a root for intra-communication */
46   int to_inter = (rank + num_core) % size;
47   int to_intra = (rank + 1) % size;
48   int from_inter = (rank - num_core + size) % size;
49   int from_intra = (rank + size - 1) % size;
50
51   // call native when MPI communication size is too small
52   if (size <= num_core) {
53     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
54     smpi_mpi_bcast(buf, count, datatype, root, comm);
55     return MPI_SUCCESS;            
56   }
57   // if root is not zero send to rank zero first
58   if (root != 0) {
59     if (rank == root)
60       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
61     else if (rank == 0)
62       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
63   }
64   // when a message is smaller than a block size => no pipeline 
65   if (count <= segment) {
66     // case ROOT
67     if (rank == 0) {
68       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
69       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
70     }
71     // case last ROOT of each SMP
72     else if (rank == (((size - 1) / num_core) * num_core)) {
73       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
74       smpi_mpi_wait(&request, &status);
75       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
76     }
77     // case intermediate ROOT of each SMP
78     else if (rank % num_core == 0) {
79       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
80       smpi_mpi_wait(&request, &status);
81       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
82       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
83     }
84     // case last non-ROOT of each SMP
85     else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
86       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
87       smpi_mpi_wait(&request, &status);
88     }
89     // case intermediate non-ROOT of each SMP
90     else {
91       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
92       smpi_mpi_wait(&request, &status);
93       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
94     }
95     return MPI_SUCCESS;
96   }
97   // pipeline bcast
98   else {
99     request_array =
100         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
101     status_array =
102         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
103
104     // case ROOT of each SMP
105     if (rank % num_core == 0) {
106       // case real root
107       if (rank == 0) {
108         for (i = 0; i < pipe_length; i++) {
109           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
110                    (tag + i), comm);
111           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
112                    (tag + i), comm);
113         }
114       }
115       // case last ROOT of each SMP
116       else if (rank == (((size - 1) / num_core) * num_core)) {
117         for (i = 0; i < pipe_length; i++) {
118           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
119                     from_inter, (tag + i), comm);
120         }
121         for (i = 0; i < pipe_length; i++) {
122           smpi_mpi_wait(&request_array[i], &status);
123           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
124                    (tag + i), comm);
125         }
126       }
127       // case intermediate ROOT of each SMP
128       else {
129         for (i = 0; i < pipe_length; i++) {
130           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
131                     from_inter, (tag + i), comm);
132         }
133         for (i = 0; i < pipe_length; i++) {
134           smpi_mpi_wait(&request_array[i], &status);
135           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
136                    (tag + i), comm);
137           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
138                    (tag + i), comm);
139         }
140       }
141     } else {                    // case last non-ROOT of each SMP
142       if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
143         for (i = 0; i < pipe_length; i++) {
144           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
145                     from_intra, (tag + i), comm);
146         }
147         for (i = 0; i < pipe_length; i++) {
148           smpi_mpi_wait(&request_array[i], &status);
149         }
150       }
151       // case intermediate non-ROOT of each SMP
152       else {
153         for (i = 0; i < pipe_length; i++) {
154           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
155                     from_intra, (tag + i), comm);
156         }
157         for (i = 0; i < pipe_length; i++) {
158           smpi_mpi_wait(&request_array[i], &status);
159           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
160                    (tag + i), comm);
161         }
162       }
163     }
164     free(request_array);
165     free(status_array);
166   }
167
168   // when count is not divisible by block size, use default BCAST for the remainder
169   if ((remainder != 0) && (count > segment)) {
170     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
171     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
172               root, comm);
173   }
174
175   return MPI_SUCCESS;
176 }