Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
b3f9b6a6306e4d66434c9e805b206b4ec4f6298f
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 int bcast_SMP_linear_segment_byte = 8192;
10
11 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
12                                      MPI_Datatype datatype, int root,
13                                      MPI_Comm comm)
14 {
15   int tag = COLL_TAG_BCAST;
16   MPI_Status status;
17   MPI_Request request;
18   MPI_Request *request_array;
19   MPI_Status *status_array;
20   int rank, size;
21   int i;
22   MPI_Aint extent;
23   extent = smpi_datatype_get_extent(datatype);
24
25   rank = smpi_comm_rank(comm);
26   size = smpi_comm_size(comm);
27   if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
28     smpi_comm_init_smp(comm);
29   }
30   int num_core=1;
31   if (smpi_comm_is_uniform(comm)){
32     num_core = smpi_comm_size(smpi_comm_get_intra_comm(comm));
33   }else{
34     //implementation buggy in this case
35     return smpi_coll_tuned_bcast_mpich( buf , count, datatype,
36               root, comm);
37   }
38
39   int segment = bcast_SMP_linear_segment_byte / extent;
40   segment =  segment == 0 ? 1 :segment; 
41   int pipe_length = count / segment;
42   int remainder = count % segment;
43   int increment = segment * extent;
44
45
46   /* leader of each SMP do inter-communication
47      and act as a root for intra-communication */
48   int to_inter = (rank + num_core) % size;
49   int to_intra = (rank + 1) % size;
50   int from_inter = (rank - num_core + size) % size;
51   int from_intra = (rank + size - 1) % size;
52
53   // call native when MPI communication size is too small
54   if (size <= num_core) {
55     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
56     smpi_mpi_bcast(buf, count, datatype, root, comm);
57     return MPI_SUCCESS;            
58   }
59   // if root is not zero send to rank zero first
60   if (root != 0) {
61     if (rank == root)
62       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
63     else if (rank == 0)
64       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
65   }
66   // when a message is smaller than a block size => no pipeline 
67   if (count <= segment) {
68     // case ROOT
69     if (rank == 0) {
70       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
71       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
72     }
73     // case last ROOT of each SMP
74     else if (rank == (((size - 1) / num_core) * num_core)) {
75       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
76       smpi_mpi_wait(&request, &status);
77       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
78     }
79     // case intermediate ROOT of each SMP
80     else if (rank % num_core == 0) {
81       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
82       smpi_mpi_wait(&request, &status);
83       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
84       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
85     }
86     // case last non-ROOT of each SMP
87     else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
88       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
89       smpi_mpi_wait(&request, &status);
90     }
91     // case intermediate non-ROOT of each SMP
92     else {
93       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
94       smpi_mpi_wait(&request, &status);
95       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
96     }
97     return MPI_SUCCESS;
98   }
99   // pipeline bcast
100   else {
101     request_array =
102         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
103     status_array =
104         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
105
106     // case ROOT of each SMP
107     if (rank % num_core == 0) {
108       // case real root
109       if (rank == 0) {
110         for (i = 0; i < pipe_length; i++) {
111           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
112                    (tag + i), comm);
113           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
114                    (tag + i), comm);
115         }
116       }
117       // case last ROOT of each SMP
118       else if (rank == (((size - 1) / num_core) * num_core)) {
119         for (i = 0; i < pipe_length; i++) {
120           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
121                     from_inter, (tag + i), comm);
122         }
123         for (i = 0; i < pipe_length; i++) {
124           smpi_mpi_wait(&request_array[i], &status);
125           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
126                    (tag + i), comm);
127         }
128       }
129       // case intermediate ROOT of each SMP
130       else {
131         for (i = 0; i < pipe_length; i++) {
132           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
133                     from_inter, (tag + i), comm);
134         }
135         for (i = 0; i < pipe_length; i++) {
136           smpi_mpi_wait(&request_array[i], &status);
137           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
138                    (tag + i), comm);
139           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
140                    (tag + i), comm);
141         }
142       }
143     } else {                    // case last non-ROOT of each SMP
144       if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
145         for (i = 0; i < pipe_length; i++) {
146           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
147                     from_intra, (tag + i), comm);
148         }
149         for (i = 0; i < pipe_length; i++) {
150           smpi_mpi_wait(&request_array[i], &status);
151         }
152       }
153       // case intermediate non-ROOT of each SMP
154       else {
155         for (i = 0; i < pipe_length; i++) {
156           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
157                     from_intra, (tag + i), comm);
158         }
159         for (i = 0; i < pipe_length; i++) {
160           smpi_mpi_wait(&request_array[i], &status);
161           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
162                    (tag + i), comm);
163         }
164       }
165     }
166     free(request_array);
167     free(status_array);
168   }
169
170   // when count is not divisible by block size, use default BCAST for the remainder
171   if ((remainder != 0) && (count > segment)) {
172     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
173     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
174               root, comm);
175   }
176
177   return MPI_SUCCESS;
178 }