Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add collectives for allgather, allreduce, bcast and reduce
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.c
1 #include "colls.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_linear_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = 5000;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   MPI_Type_extent(datatype, &extent);
21
22   MPI_Comm_rank(comm, &rank);
23   MPI_Comm_size(comm, &size);
24
25   int segment = bcast_SMP_linear_segment_byte / extent;
26   int pipe_length = count / segment;
27   int remainder = count % segment;
28   int increment = segment * extent;
29
30
31   /* leader of each SMP do inter-communication
32      and act as a root for intra-communication */
33   int to_inter = (rank + NUM_CORE) % size;
34   int to_intra = (rank + 1) % size;
35   int from_inter = (rank - NUM_CORE + size) % size;
36   int from_intra = (rank + size - 1) % size;
37
38   // call native when MPI communication size is too small
39   if (size <= NUM_CORE) {
40     return MPI_Bcast(buf, count, datatype, root, comm);
41   }
42   // if root is not zero send to rank zero first
43   if (root != 0) {
44     if (rank == root)
45       MPI_Send(buf, count, datatype, 0, tag, comm);
46     else if (rank == 0)
47       MPI_Recv(buf, count, datatype, root, tag, comm, &status);
48   }
49   // when a message is smaller than a block size => no pipeline 
50   if (count <= segment) {
51     // case ROOT
52     if (rank == 0) {
53       MPI_Send(buf, count, datatype, to_inter, tag, comm);
54       MPI_Send(buf, count, datatype, to_intra, tag, comm);
55     }
56     // case last ROOT of each SMP
57     else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
58       MPI_Irecv(buf, count, datatype, from_inter, tag, comm, &request);
59       MPI_Wait(&request, &status);
60       MPI_Send(buf, count, datatype, to_intra, tag, comm);
61     }
62     // case intermediate ROOT of each SMP
63     else if (rank % NUM_CORE == 0) {
64       MPI_Irecv(buf, count, datatype, from_inter, tag, comm, &request);
65       MPI_Wait(&request, &status);
66       MPI_Send(buf, count, datatype, to_inter, tag, comm);
67       MPI_Send(buf, count, datatype, to_intra, tag, comm);
68     }
69     // case last non-ROOT of each SMP
70     else if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
71       MPI_Irecv(buf, count, datatype, from_intra, tag, comm, &request);
72       MPI_Wait(&request, &status);
73     }
74     // case intermediate non-ROOT of each SMP
75     else {
76       MPI_Irecv(buf, count, datatype, from_intra, tag, comm, &request);
77       MPI_Wait(&request, &status);
78       MPI_Send(buf, count, datatype, to_intra, tag, comm);
79     }
80     return MPI_SUCCESS;
81   }
82   // pipeline bcast
83   else {
84     request_array =
85         (MPI_Request *) malloc((size + pipe_length) * sizeof(MPI_Request));
86     status_array =
87         (MPI_Status *) malloc((size + pipe_length) * sizeof(MPI_Status));
88
89     // case ROOT of each SMP
90     if (rank % NUM_CORE == 0) {
91       // case real root
92       if (rank == 0) {
93         for (i = 0; i < pipe_length; i++) {
94           MPI_Send((char *) buf + (i * increment), segment, datatype, to_inter,
95                    (tag + i), comm);
96           MPI_Send((char *) buf + (i * increment), segment, datatype, to_intra,
97                    (tag + i), comm);
98         }
99       }
100       // case last ROOT of each SMP
101       else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
102         for (i = 0; i < pipe_length; i++) {
103           MPI_Irecv((char *) buf + (i * increment), segment, datatype,
104                     from_inter, (tag + i), comm, &request_array[i]);
105         }
106         for (i = 0; i < pipe_length; i++) {
107           MPI_Wait(&request_array[i], &status);
108           MPI_Send((char *) buf + (i * increment), segment, datatype, to_intra,
109                    (tag + i), comm);
110         }
111       }
112       // case intermediate ROOT of each SMP
113       else {
114         for (i = 0; i < pipe_length; i++) {
115           MPI_Irecv((char *) buf + (i * increment), segment, datatype,
116                     from_inter, (tag + i), comm, &request_array[i]);
117         }
118         for (i = 0; i < pipe_length; i++) {
119           MPI_Wait(&request_array[i], &status);
120           MPI_Send((char *) buf + (i * increment), segment, datatype, to_inter,
121                    (tag + i), comm);
122           MPI_Send((char *) buf + (i * increment), segment, datatype, to_intra,
123                    (tag + i), comm);
124         }
125       }
126     } else {                    // case last non-ROOT of each SMP
127       if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
128         for (i = 0; i < pipe_length; i++) {
129           MPI_Irecv((char *) buf + (i * increment), segment, datatype,
130                     from_intra, (tag + i), comm, &request_array[i]);
131         }
132         for (i = 0; i < pipe_length; i++) {
133           MPI_Wait(&request_array[i], &status);
134         }
135       }
136       // case intermediate non-ROOT of each SMP
137       else {
138         for (i = 0; i < pipe_length; i++) {
139           MPI_Irecv((char *) buf + (i * increment), segment, datatype,
140                     from_intra, (tag + i), comm, &request_array[i]);
141         }
142         for (i = 0; i < pipe_length; i++) {
143           MPI_Wait(&request_array[i], &status);
144           MPI_Send((char *) buf + (i * increment), segment, datatype, to_intra,
145                    (tag + i), comm);
146         }
147       }
148     }
149     free(request_array);
150     free(status_array);
151   }
152
153   // when count is not divisible by block size, use default BCAST for the remainder
154   if ((remainder != 0) && (count > segment)) {
155     MPI_Bcast((char *) buf + (pipe_length * increment), remainder, datatype,
156               root, comm);
157   }
158
159   return 1;
160 }