Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
SMPI colls in not really C++. But cleaner than before.
[simgrid.git] / src / smpi / colls / bcast / bcast-SMP-linear.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.h"
8
9 int bcast_SMP_linear_segment_byte = 8192;
10
11 int Coll_bcast_SMP_linear::bcast(void *buf, int count,
12                                      MPI_Datatype datatype, int root,
13                                      MPI_Comm comm)
14 {
15   int tag = COLL_TAG_BCAST;
16   MPI_Status status;
17   MPI_Request request;
18   MPI_Request *request_array;
19   MPI_Status *status_array;
20   int rank, size;
21   int i;
22   MPI_Aint extent;
23   extent = datatype->get_extent();
24
25   rank = comm->rank();
26   size = comm->size();
27   if(comm->get_leaders_comm()==MPI_COMM_NULL){
28     comm->init_smp();
29   }
30   int num_core=1;
31   if (comm->is_uniform()){
32     num_core = comm->get_intra_comm()->size();
33   }else{
34     //implementation buggy in this case
35     return Coll_bcast_mpich::bcast( buf , count, datatype,
36               root, comm);
37   }
38
39   int segment = bcast_SMP_linear_segment_byte / extent;
40   segment =  segment == 0 ? 1 :segment; 
41   int pipe_length = count / segment;
42   int remainder = count % segment;
43   int increment = segment * extent;
44
45
46   /* leader of each SMP do inter-communication
47      and act as a root for intra-communication */
48   int to_inter = (rank + num_core) % size;
49   int to_intra = (rank + 1) % size;
50   int from_inter = (rank - num_core + size) % size;
51   int from_intra = (rank + size - 1) % size;
52
53   // call native when MPI communication size is too small
54   if (size <= num_core) {
55     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
56     Coll_bcast_default::bcast(buf, count, datatype, root, comm);
57     return MPI_SUCCESS;            
58   }
59   // if root is not zero send to rank zero first
60   if (root != 0) {
61     if (rank == root)
62       Request::send(buf, count, datatype, 0, tag, comm);
63     else if (rank == 0)
64       Request::recv(buf, count, datatype, root, tag, comm, &status);
65   }
66   // when a message is smaller than a block size => no pipeline 
67   if (count <= segment) {
68     // case ROOT
69     if (rank == 0) {
70       Request::send(buf, count, datatype, to_inter, tag, comm);
71       Request::send(buf, count, datatype, to_intra, tag, comm);
72     }
73     // case last ROOT of each SMP
74     else if (rank == (((size - 1) / num_core) * num_core)) {
75       request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
76       Request::wait(&request, &status);
77       Request::send(buf, count, datatype, to_intra, tag, comm);
78     }
79     // case intermediate ROOT of each SMP
80     else if (rank % num_core == 0) {
81       request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
82       Request::wait(&request, &status);
83       Request::send(buf, count, datatype, to_inter, tag, comm);
84       Request::send(buf, count, datatype, to_intra, tag, comm);
85     }
86     // case last non-ROOT of each SMP
87     else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
88       request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
89       Request::wait(&request, &status);
90     }
91     // case intermediate non-ROOT of each SMP
92     else {
93       request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
94       Request::wait(&request, &status);
95       Request::send(buf, count, datatype, to_intra, tag, comm);
96     }
97     return MPI_SUCCESS;
98   }
99   // pipeline bcast
100   else {
101     request_array =
102         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
103     status_array =
104         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
105
106     // case ROOT of each SMP
107     if (rank % num_core == 0) {
108       // case real root
109       if (rank == 0) {
110         for (i = 0; i < pipe_length; i++) {
111           Request::send((char *) buf + (i * increment), segment, datatype, to_inter,
112                    (tag + i), comm);
113           Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
114                    (tag + i), comm);
115         }
116       }
117       // case last ROOT of each SMP
118       else if (rank == (((size - 1) / num_core) * num_core)) {
119         for (i = 0; i < pipe_length; i++) {
120           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
121                     from_inter, (tag + i), comm);
122         }
123         for (i = 0; i < pipe_length; i++) {
124           Request::wait(&request_array[i], &status);
125           Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
126                    (tag + i), comm);
127         }
128       }
129       // case intermediate ROOT of each SMP
130       else {
131         for (i = 0; i < pipe_length; i++) {
132           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
133                     from_inter, (tag + i), comm);
134         }
135         for (i = 0; i < pipe_length; i++) {
136           Request::wait(&request_array[i], &status);
137           Request::send((char *) buf + (i * increment), segment, datatype, to_inter,
138                    (tag + i), comm);
139           Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
140                    (tag + i), comm);
141         }
142       }
143     } else {                    // case last non-ROOT of each SMP
144       if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
145         for (i = 0; i < pipe_length; i++) {
146           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
147                     from_intra, (tag + i), comm);
148         }
149         for (i = 0; i < pipe_length; i++) {
150           Request::wait(&request_array[i], &status);
151         }
152       }
153       // case intermediate non-ROOT of each SMP
154       else {
155         for (i = 0; i < pipe_length; i++) {
156           request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
157                     from_intra, (tag + i), comm);
158         }
159         for (i = 0; i < pipe_length; i++) {
160           Request::wait(&request_array[i], &status);
161           Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
162                    (tag + i), comm);
163         }
164       }
165     }
166     free(request_array);
167     free(status_array);
168   }
169
170   // when count is not divisible by block size, use default BCAST for the remainder
171   if ((remainder != 0) && (count > segment)) {
172     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
173     Colls::bcast((char *) buf + (pipe_length * increment), remainder, datatype,
174               root, comm);
175   }
176
177   return MPI_SUCCESS;
178 }