Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.c
1 #include "colls_private.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_linear_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = COLL_TAG_BCAST;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   extent = smpi_datatype_get_extent(datatype);
21
22   rank = smpi_comm_rank(comm);
23   size = smpi_comm_size(comm);
24   int num_core = simcall_host_get_core(SIMIX_host_self());
25   // do we use the default one or the number of cores in the platform ?
26   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
27   if (num_core == 1) num_core = NUM_CORE;
28
29   if(size%num_core)
30     THROWF(arg_error,0, "bcast SMP linear can't be used with non multiple of num_core=%d number of processes!",num_core);
31
32   int segment = bcast_SMP_linear_segment_byte / extent;
33   int pipe_length = count / segment;
34   int remainder = count % segment;
35   int increment = segment * extent;
36
37
38   /* leader of each SMP do inter-communication
39      and act as a root for intra-communication */
40   int to_inter = (rank + num_core) % size;
41   int to_intra = (rank + 1) % size;
42   int from_inter = (rank - num_core + size) % size;
43   int from_intra = (rank + size - 1) % size;
44
45   // call native when MPI communication size is too small
46   if (size <= num_core) {
47     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
48     smpi_mpi_bcast(buf, count, datatype, root, comm);
49     return MPI_SUCCESS;            
50   }
51   // if root is not zero send to rank zero first
52   if (root != 0) {
53     if (rank == root)
54       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
55     else if (rank == 0)
56       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
57   }
58   // when a message is smaller than a block size => no pipeline 
59   if (count <= segment) {
60     // case ROOT
61     if (rank == 0) {
62       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
63       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
64     }
65     // case last ROOT of each SMP
66     else if (rank == (((size - 1) / num_core) * num_core)) {
67       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
68       smpi_mpi_wait(&request, &status);
69       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
70     }
71     // case intermediate ROOT of each SMP
72     else if (rank % num_core == 0) {
73       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
74       smpi_mpi_wait(&request, &status);
75       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
76       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
77     }
78     // case last non-ROOT of each SMP
79     else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
80       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
81       smpi_mpi_wait(&request, &status);
82     }
83     // case intermediate non-ROOT of each SMP
84     else {
85       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
86       smpi_mpi_wait(&request, &status);
87       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
88     }
89     return MPI_SUCCESS;
90   }
91   // pipeline bcast
92   else {
93     request_array =
94         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
95     status_array =
96         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
97
98     // case ROOT of each SMP
99     if (rank % num_core == 0) {
100       // case real root
101       if (rank == 0) {
102         for (i = 0; i < pipe_length; i++) {
103           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
104                    (tag + i), comm);
105           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
106                    (tag + i), comm);
107         }
108       }
109       // case last ROOT of each SMP
110       else if (rank == (((size - 1) / num_core) * num_core)) {
111         for (i = 0; i < pipe_length; i++) {
112           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
113                     from_inter, (tag + i), comm);
114         }
115         for (i = 0; i < pipe_length; i++) {
116           smpi_mpi_wait(&request_array[i], &status);
117           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
118                    (tag + i), comm);
119         }
120       }
121       // case intermediate ROOT of each SMP
122       else {
123         for (i = 0; i < pipe_length; i++) {
124           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
125                     from_inter, (tag + i), comm);
126         }
127         for (i = 0; i < pipe_length; i++) {
128           smpi_mpi_wait(&request_array[i], &status);
129           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
130                    (tag + i), comm);
131           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
132                    (tag + i), comm);
133         }
134       }
135     } else {                    // case last non-ROOT of each SMP
136       if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
137         for (i = 0; i < pipe_length; i++) {
138           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
139                     from_intra, (tag + i), comm);
140         }
141         for (i = 0; i < pipe_length; i++) {
142           smpi_mpi_wait(&request_array[i], &status);
143         }
144       }
145       // case intermediate non-ROOT of each SMP
146       else {
147         for (i = 0; i < pipe_length; i++) {
148           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
149                     from_intra, (tag + i), comm);
150         }
151         for (i = 0; i < pipe_length; i++) {
152           smpi_mpi_wait(&request_array[i], &status);
153           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
154                    (tag + i), comm);
155         }
156       }
157     }
158     free(request_array);
159     free(status_array);
160   }
161
162   // when count is not divisible by block size, use default BCAST for the remainder
163   if ((remainder != 0) && (count > segment)) {
164     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
165     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
166               root, comm);
167   }
168
169   return MPI_SUCCESS;
170 }