Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines.
[simgrid.git] / src / smpi / colls / bcast / bcast-mvapich-smp.cpp
1 /* Copyright (c) 2013-2021. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "../colls_private.hpp"
39
40 extern int (*MV2_Bcast_function) (void *buffer, int count, MPI_Datatype datatype,
41                            int root, MPI_Comm comm_ptr);
42
43 extern int (*MV2_Bcast_intra_node_function) (void *buffer, int count, MPI_Datatype datatype,
44                                       int root, MPI_Comm comm_ptr);
45
46 extern int zcpy_knomial_factor;
47 extern int mv2_pipelined_zcpy_knomial_factor;
48 extern int bcast_segment_size;
49 extern int mv2_inter_node_knomial_factor;
50 extern int mv2_intra_node_knomial_factor;
51 extern int mv2_bcast_two_level_system_size;
52 #define INTRA_NODE_ROOT 0
53
54 #define MPIR_Pipelined_Bcast_Zcpy_MV2 bcast__mpich
55 #define MPIR_Pipelined_Bcast_MV2 bcast__mpich
56 #define MPIR_Bcast_binomial_MV2 bcast__binomial_tree
57 #define MPIR_Bcast_scatter_ring_allgather_shm_MV2 bcast__scatter_LR_allgather
58 #define MPIR_Bcast_scatter_doubling_allgather_MV2 bcast__scatter_rdb_allgather
59 #define MPIR_Bcast_scatter_ring_allgather_MV2 bcast__scatter_LR_allgather
60 #define MPIR_Shmem_Bcast_MV2 bcast__mpich
61 #define MPIR_Bcast_tune_inter_node_helper_MV2 bcast__mvapich2_inter_node
62 #define MPIR_Bcast_inter_node_helper_MV2 bcast__mvapich2_inter_node
63 #define MPIR_Knomial_Bcast_intra_node_MV2 bcast__mvapich2_knomial_intra_node
64 #define MPIR_Bcast_intra_MV2 bcast__mvapich2_intra_node
65
66 extern int zcpy_knomial_factor;
67 extern int mv2_pipelined_zcpy_knomial_factor;
68 extern int bcast_segment_size;
69 extern int mv2_inter_node_knomial_factor;
70 extern int mv2_intra_node_knomial_factor;
71 #define mv2_bcast_two_level_system_size  64
72 #define mv2_bcast_short_msg             16384
73 #define mv2_bcast_large_msg            512*1024
74 #define mv2_knomial_intra_node_threshold 131072
75 #define mv2_scatter_rd_inter_leader_bcast 1
76 namespace simgrid {
77 namespace smpi {
78 int bcast__mvapich2_inter_node(void *buffer,
79                                int count,
80                                MPI_Datatype datatype,
81                                int root,
82                                MPI_Comm  comm)
83 {
84     int rank;
85     int mpi_errno = MPI_SUCCESS;
86     MPI_Comm shmem_comm, leader_comm;
87     int local_rank, local_size, global_rank = -1;
88     int leader_root, leader_of_root;
89
90
91     rank = comm->rank();
92     //comm_size = comm->size();
93
94     if (MV2_Bcast_function == nullptr) {
95       MV2_Bcast_function = bcast__mpich;
96     }
97
98     if (MV2_Bcast_intra_node_function == nullptr) {
99       MV2_Bcast_intra_node_function = bcast__mpich;
100     }
101
102     if(comm->get_leaders_comm()==MPI_COMM_NULL){
103       comm->init_smp();
104     }
105
106     shmem_comm = comm->get_intra_comm();
107     local_rank = shmem_comm->rank();
108     local_size = shmem_comm->size();
109
110     leader_comm = comm->get_leaders_comm();
111
112     if ((local_rank == 0) && (local_size > 1)) {
113       global_rank = leader_comm->rank();
114     }
115
116     int* leaders_map = comm->get_leaders_map();
117     leader_of_root = comm->group()->rank(leaders_map[root]);
118     leader_root = leader_comm->group()->rank(leaders_map[root]);
119
120
121     if (local_size > 1) {
122         if ((local_rank == 0) && (root != rank) && (leader_root == global_rank)) {
123             Request::recv(buffer, count, datatype, root,
124                                      COLL_TAG_BCAST, comm, MPI_STATUS_IGNORE);
125         }
126         if ((local_rank != 0) && (root == rank)) {
127             Request::send(buffer, count, datatype,
128                                      leader_of_root, COLL_TAG_BCAST, comm);
129         }
130     }
131 #if defined(_MCST_SUPPORT_)
132     if (comm_ptr->ch.is_mcast_ok) {
133         mpi_errno = MPIR_Mcast_inter_node_MV2(buffer, count, datatype, root, comm_ptr,
134                                               errflag);
135         if (mpi_errno == MPI_SUCCESS) {
136             goto fn_exit;
137         } else {
138             goto fn_fail;
139         }
140     }
141 #endif
142 /*
143     if (local_rank == 0) {
144         leader_comm = comm->get_leaders_comm();
145         root = leader_root;
146     }
147
148     if (MV2_Bcast_function == &MPIR_Pipelined_Bcast_MV2) {
149         mpi_errno = MPIR_Pipelined_Bcast_MV2(buffer, count, datatype,
150                                              root, comm);
151     } else if (MV2_Bcast_function == &MPIR_Bcast_scatter_ring_allgather_shm_MV2) {
152         mpi_errno = MPIR_Bcast_scatter_ring_allgather_shm_MV2(buffer, count,
153                                                               datatype, root,
154                                                               comm);
155     } else */{
156         if (local_rank == 0) {
157       /*      if (MV2_Bcast_function == &MPIR_Knomial_Bcast_inter_node_wrapper_MV2) {
158                 mpi_errno = MPIR_Knomial_Bcast_inter_node_wrapper_MV2(buffer, count,
159                                                               datatype, root,
160                                                               comm);
161             } else {*/
162                 mpi_errno = MV2_Bcast_function(buffer, count, datatype,
163                                                leader_root, leader_comm);
164           //  }
165         }
166     }
167
168     return mpi_errno;
169 }
170
171
172 int bcast__mvapich2_knomial_intra_node(void *buffer,
173                                        int count,
174                                        MPI_Datatype datatype,
175                                        int root, MPI_Comm  comm)
176 {
177     int local_size = 0, rank;
178     int mpi_errno = MPI_SUCCESS;
179     int src, dst, mask, relative_rank;
180     int k;
181     if (MV2_Bcast_function == nullptr) {
182       MV2_Bcast_function = bcast__mpich;
183     }
184
185     if (MV2_Bcast_intra_node_function == nullptr) {
186       MV2_Bcast_intra_node_function = bcast__mpich;
187     }
188
189     if(comm->get_leaders_comm()==MPI_COMM_NULL){
190       comm->init_smp();
191     }
192
193     local_size = comm->size();
194     rank = comm->rank();
195
196     auto* reqarray = new MPI_Request[2 * mv2_intra_node_knomial_factor];
197
198     auto* starray = new MPI_Status[2 * mv2_intra_node_knomial_factor];
199
200     /* intra-node k-nomial bcast  */
201     if (local_size > 1) {
202         relative_rank = (rank >= root) ? rank - root : rank - root + local_size;
203         mask = 0x1;
204
205         while (mask < local_size) {
206             if (relative_rank % (mv2_intra_node_knomial_factor * mask)) {
207                 src = relative_rank / (mv2_intra_node_knomial_factor * mask) *
208                     (mv2_intra_node_knomial_factor * mask) + root;
209                 if (src >= local_size) {
210                     src -= local_size;
211                 }
212
213                 Request::recv(buffer, count, datatype, src,
214                                          COLL_TAG_BCAST, comm,
215                                          MPI_STATUS_IGNORE);
216                 break;
217             }
218             mask *= mv2_intra_node_knomial_factor;
219         }
220         mask /= mv2_intra_node_knomial_factor;
221
222         while (mask > 0) {
223             int reqs = 0;
224             for (k = 1; k < mv2_intra_node_knomial_factor; k++) {
225                 if (relative_rank + mask * k < local_size) {
226                     dst = rank + mask * k;
227                     if (dst >= local_size) {
228                         dst -= local_size;
229                     }
230                     reqarray[reqs++]=Request::isend(buffer, count, datatype, dst,
231                                               COLL_TAG_BCAST, comm);
232                 }
233             }
234             Request::waitall(reqs, reqarray, starray);
235
236             mask /= mv2_intra_node_knomial_factor;
237         }
238     }
239     delete[] reqarray;
240     delete[] starray;
241     return mpi_errno;
242 }
243
244
245 int bcast__mvapich2_intra_node(void *buffer,
246                                int count,
247                                MPI_Datatype datatype,
248                                int root, MPI_Comm  comm)
249 {
250     int mpi_errno = MPI_SUCCESS;
251     int comm_size;
252     bool two_level_bcast = true;
253     size_t nbytes = 0;
254     bool is_homogeneous, is_contig;
255     MPI_Aint type_size;
256     unsigned char* tmp_buf = nullptr;
257     MPI_Comm shmem_comm;
258
259     if (count == 0)
260         return MPI_SUCCESS;
261     if (MV2_Bcast_function == nullptr) {
262       MV2_Bcast_function = bcast__mpich;
263     }
264
265     if (MV2_Bcast_intra_node_function == nullptr) {
266       MV2_Bcast_intra_node_function = bcast__mpich;
267     }
268
269     if(comm->get_leaders_comm()==MPI_COMM_NULL){
270       comm->init_smp();
271     }
272
273     comm_size = comm->size();
274    // rank = comm->rank();
275 /*
276     if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
277         is_contig = true;
278 /*    else {
279         MPID_Datatype_get_ptr(datatype, dtp);
280         is_contig = dtp->is_contig;
281     }
282 */
283     is_homogeneous = true;
284 #ifdef MPID_HAS_HETERO
285     if (comm_ptr->is_hetero)
286       is_homogeneous = false;
287 #endif
288
289     /* MPI_Type_size() might not give the accurate size of the packed
290      * datatype for heterogeneous systems (because of padding, encoding,
291      * etc). On the other hand, MPI_Pack_size() can become very
292      * expensive, depending on the implementation, especially for
293      * heterogeneous systems. We want to use MPI_Type_size() wherever
294      * possible, and MPI_Pack_size() in other places.
295      */
296     //if (is_homogeneous) {
297         type_size=datatype->size();
298     //}
299 /*    else {*/
300 /*        MPIR_Pack_size_impl(1, datatype, &type_size);*/
301 /*    }*/
302     nbytes = (size_t) (count) * (type_size);
303     if (comm_size <= mv2_bcast_two_level_system_size) {
304         if (nbytes > mv2_bcast_short_msg && nbytes < mv2_bcast_large_msg) {
305           two_level_bcast = true;
306         } else {
307           two_level_bcast = false;
308         }
309     }
310
311     if (two_level_bcast
312 #if defined(_MCST_SUPPORT_)
313             || comm_ptr->ch.is_mcast_ok
314 #endif
315         ) {
316
317       if (not is_contig || not is_homogeneous) {
318         tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
319
320         /* TODO: Pipeline the packing and communication */
321         // position = 0;
322         /*            if (rank == root) {*/
323         /*                mpi_errno =*/
324         /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
325         /*                if (mpi_errno)*/
326         /*                    MPIU_ERR_POP(mpi_errno);*/
327         /*            }*/
328         }
329
330         shmem_comm = comm->get_intra_comm();
331         if (not is_contig || not is_homogeneous) {
332           mpi_errno = MPIR_Bcast_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
333         } else {
334             mpi_errno =
335                 MPIR_Bcast_inter_node_helper_MV2(buffer, count, datatype, root,
336                                                  comm);
337         }
338
339         /* We are now done with the inter-node phase */
340             if (nbytes <= mv2_knomial_intra_node_threshold) {
341               if (not is_contig || not is_homogeneous) {
342                 mpi_errno = MPIR_Shmem_Bcast_MV2(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
343                 } else {
344                     mpi_errno = MPIR_Shmem_Bcast_MV2(buffer, count, datatype,
345                                                      root, shmem_comm);
346                 }
347             } else {
348               if (not is_contig || not is_homogeneous) {
349                 mpi_errno = MPIR_Knomial_Bcast_intra_node_MV2(tmp_buf, nbytes, MPI_BYTE, INTRA_NODE_ROOT, shmem_comm);
350                 } else {
351                     mpi_errno =
352                         MPIR_Knomial_Bcast_intra_node_MV2(buffer, count,
353                                                           datatype,
354                                                           INTRA_NODE_ROOT,
355                                                           shmem_comm);
356                 }
357             }
358
359     } else {
360         if (nbytes <= mv2_bcast_short_msg) {
361             mpi_errno = MPIR_Bcast_binomial_MV2(buffer, count, datatype, root,
362                                                 comm);
363         } else {
364             if (mv2_scatter_rd_inter_leader_bcast) {
365                 mpi_errno = MPIR_Bcast_scatter_ring_allgather_MV2(buffer, count,
366                                                                   datatype,
367                                                                   root,
368                                                                   comm);
369             } else {
370                 mpi_errno =
371                     MPIR_Bcast_scatter_doubling_allgather_MV2(buffer, count,
372                                                               datatype, root,
373                                                               comm);
374             }
375         }
376     }
377
378
379     return mpi_errno;
380
381 }
382
383 }
384 }