Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
092ab26ed07f2e81998d9ff1da9a3d8cedda6315
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.c
1 #include "colls_private.h"
2 #ifndef NUM_CORE
3 #define NUM_CORE 8
4 #endif
5
6 int bcast_SMP_linear_segment_byte = 8192;
7
8 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
9                                      MPI_Datatype datatype, int root,
10                                      MPI_Comm comm)
11 {
12   int tag = COLL_TAG_BCAST;
13   MPI_Status status;
14   MPI_Request request;
15   MPI_Request *request_array;
16   MPI_Status *status_array;
17   int rank, size;
18   int i;
19   MPI_Aint extent;
20   extent = smpi_datatype_get_extent(datatype);
21
22   rank = smpi_comm_rank(comm);
23   size = smpi_comm_size(comm);
24
25   if(size%NUM_CORE)
26     THROWF(arg_error,0, "bcast SMP linear can't be used with non multiple of NUM_CORE=%d number of processes ! ",NUM_CORE);
27
28   int segment = bcast_SMP_linear_segment_byte / extent;
29   int pipe_length = count / segment;
30   int remainder = count % segment;
31   int increment = segment * extent;
32
33
34   /* leader of each SMP do inter-communication
35      and act as a root for intra-communication */
36   int to_inter = (rank + NUM_CORE) % size;
37   int to_intra = (rank + 1) % size;
38   int from_inter = (rank - NUM_CORE + size) % size;
39   int from_intra = (rank + size - 1) % size;
40
41   // call native when MPI communication size is too small
42   if (size <= NUM_CORE) {
43     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
44     smpi_mpi_bcast(buf, count, datatype, root, comm);
45     return MPI_SUCCESS;            
46   }
47   // if root is not zero send to rank zero first
48   if (root != 0) {
49     if (rank == root)
50       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
51     else if (rank == 0)
52       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
53   }
54   // when a message is smaller than a block size => no pipeline 
55   if (count <= segment) {
56     // case ROOT
57     if (rank == 0) {
58       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
59       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
60     }
61     // case last ROOT of each SMP
62     else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
63       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
64       smpi_mpi_wait(&request, &status);
65       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
66     }
67     // case intermediate ROOT of each SMP
68     else if (rank % NUM_CORE == 0) {
69       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
70       smpi_mpi_wait(&request, &status);
71       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
72       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
73     }
74     // case last non-ROOT of each SMP
75     else if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
76       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
77       smpi_mpi_wait(&request, &status);
78     }
79     // case intermediate non-ROOT of each SMP
80     else {
81       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
82       smpi_mpi_wait(&request, &status);
83       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
84     }
85     return MPI_SUCCESS;
86   }
87   // pipeline bcast
88   else {
89     request_array =
90         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
91     status_array =
92         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
93
94     // case ROOT of each SMP
95     if (rank % NUM_CORE == 0) {
96       // case real root
97       if (rank == 0) {
98         for (i = 0; i < pipe_length; i++) {
99           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
100                    (tag + i), comm);
101           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
102                    (tag + i), comm);
103         }
104       }
105       // case last ROOT of each SMP
106       else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
107         for (i = 0; i < pipe_length; i++) {
108           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
109                     from_inter, (tag + i), comm);
110         }
111         for (i = 0; i < pipe_length; i++) {
112           smpi_mpi_wait(&request_array[i], &status);
113           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
114                    (tag + i), comm);
115         }
116       }
117       // case intermediate ROOT of each SMP
118       else {
119         for (i = 0; i < pipe_length; i++) {
120           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
121                     from_inter, (tag + i), comm);
122         }
123         for (i = 0; i < pipe_length; i++) {
124           smpi_mpi_wait(&request_array[i], &status);
125           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
126                    (tag + i), comm);
127           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
128                    (tag + i), comm);
129         }
130       }
131     } else {                    // case last non-ROOT of each SMP
132       if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
133         for (i = 0; i < pipe_length; i++) {
134           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
135                     from_intra, (tag + i), comm);
136         }
137         for (i = 0; i < pipe_length; i++) {
138           smpi_mpi_wait(&request_array[i], &status);
139         }
140       }
141       // case intermediate non-ROOT of each SMP
142       else {
143         for (i = 0; i < pipe_length; i++) {
144           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
145                     from_intra, (tag + i), comm);
146         }
147         for (i = 0; i < pipe_length; i++) {
148           smpi_mpi_wait(&request_array[i], &status);
149           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
150                    (tag + i), comm);
151         }
152       }
153     }
154     free(request_array);
155     free(status_array);
156   }
157
158   // when count is not divisible by block size, use default BCAST for the remainder
159   if ((remainder != 0) && (count > segment)) {
160     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
161     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
162               root, comm);
163   }
164
165   return MPI_SUCCESS;
166 }