Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2022.
[simgrid.git] / src / smpi / colls / scatter / scatter-mvapich-two-level.cpp
1 /* Copyright (c) 2013-2022. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37 #include "../colls_private.hpp"
38
/* Map the MVAPICH2 names used by the algorithms below onto the SMPI
 * implementations they are built on: the "binomial" inter-leader scatter is
 * scatter__ompi_binomial and the "direct" one is scatter__ompi_basic_linear. */
39 #define MPIR_Scatter_MV2_Binomial scatter__ompi_binomial
40 #define MPIR_Scatter_MV2_Direct scatter__ompi_basic_linear
41
/* Scatter implementation used for the intra-node phase of the two-level
 * algorithms. It is only declared here (defined in another translation unit);
 * the algorithms below fall back to scatter__mpich when it is still null,
 * i.e. when they are called directly rather than through the mvapich2
 * selector. */
42 extern int (*MV2_Scatter_intra_function) (const void *sendbuf, int sendcount, MPI_Datatype sendtype,
43     void *recvbuf, int recvcount, MPI_Datatype recvtype,
44     int root, MPI_Comm comm);
45
46 namespace simgrid{
47 namespace smpi{
48
49 int scatter__mvapich2_two_level_direct(const void *sendbuf,
50                                        int sendcnt,
51                                        MPI_Datatype sendtype,
52                                        void *recvbuf,
53                                        int recvcnt,
54                                        MPI_Datatype recvtype,
55                                        int root, MPI_Comm  comm)
56 {
57     int comm_size, rank;
58     int local_rank, local_size;
59     int leader_comm_rank = -1, leader_comm_size = -1;
60     int mpi_errno = MPI_SUCCESS;
61     int recvtype_size, sendtype_size, nbytes;
62     unsigned char* tmp_buf            = nullptr;
63     unsigned char* leader_scatter_buf = nullptr;
64     MPI_Status status;
65     int leader_root, leader_of_root = -1;
66     MPI_Comm shmem_comm, leader_comm;
67     //if not set (use of the algo directly, without mvapich2 selector)
68     if (MV2_Scatter_intra_function == nullptr)
69       MV2_Scatter_intra_function = scatter__mpich;
70
71     if(comm->get_leaders_comm()==MPI_COMM_NULL){
72       comm->init_smp();
73     }
74     comm_size = comm->size();
75     rank = comm->rank();
76
77     if (((rank == root) && (recvcnt == 0))
78         || ((rank != root) && (sendcnt == 0))) {
79         return MPI_SUCCESS;
80     }
81
82     /* extract the rank,size information for the intra-node
83      * communicator */
84     shmem_comm = comm->get_intra_comm();
85     local_rank = shmem_comm->rank();
86     local_size = shmem_comm->size();
87
88     if (local_rank == 0) {
89         /* Node leader. Extract the rank, size information for the leader
90          * communicator */
91         leader_comm = comm->get_leaders_comm();
92         leader_comm_size = leader_comm->size();
93         leader_comm_rank = leader_comm->rank();
94     }
95
96     if (local_size == comm_size) {
97         /* purely intra-node scatter. Just use the direct algorithm and we are done */
98         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
99                                             recvbuf, recvcnt, recvtype,
100                                             root, comm);
101
102     } else {
103         recvtype_size=recvtype->size();
104         sendtype_size=sendtype->size();
105
106         if (rank == root) {
107             nbytes = sendcnt * sendtype_size;
108         } else {
109             nbytes = recvcnt * recvtype_size;
110         }
111
112         if (local_rank == 0) {
113             /* Node leader, allocate tmp_buffer */
114             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
115         }
116
117         leader_comm = comm->get_leaders_comm();
118         int* leaders_map = comm->get_leaders_map();
119         leader_of_root = comm->group()->rank(leaders_map[root]);
120         leader_root = leader_comm->group()->rank(leaders_map[root]);
121         /* leader_root is the rank of the leader of the root in leader_comm.
122          * leader_root is to be used as the root of the inter-leader gather ops
123          */
124
125         if ((local_rank == 0) && (root != rank)
126             && (leader_of_root == rank)) {
127             /* The root of the scatter operation is not the node leader. Recv
128              * data from the node leader */
129             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
130             Request::recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
131                              root, COLL_TAG_SCATTER, comm, &status);
132
133         }
134
135         if (rank == root && local_rank != 0) {
136             /* The root of the scatter operation is not the node leader. Send
137              * data to the node leader */
138             Request::send(sendbuf, sendcnt * comm_size, sendtype,
139                                      leader_of_root, COLL_TAG_SCATTER, comm
140                                      );
141         }
142
143         if (leader_comm_size > 1 && local_rank == 0) {
144           if (not comm->is_uniform()) {
145             int* displs   = nullptr;
146             int* sendcnts = nullptr;
147             int* node_sizes;
148             int i      = 0;
149             node_sizes = comm->get_non_uniform_map();
150
151             if (root != leader_of_root) {
152               if (leader_comm_rank == leader_root) {
153                 displs      = new int[leader_comm_size];
154                 sendcnts    = new int[leader_comm_size];
155                 sendcnts[0] = node_sizes[0] * nbytes;
156                 displs[0]   = 0;
157
158                 for (i = 1; i < leader_comm_size; i++) {
159                   displs[i]   = displs[i - 1] + node_sizes[i - 1] * nbytes;
160                   sendcnts[i] = node_sizes[i] * nbytes;
161                 }
162               }
163               colls::scatterv(leader_scatter_buf, sendcnts, displs, MPI_BYTE, tmp_buf, nbytes * local_size, MPI_BYTE,
164                               leader_root, leader_comm);
165             } else {
166               if (leader_comm_rank == leader_root) {
167                 displs      = new int[leader_comm_size];
168                 sendcnts    = new int[leader_comm_size];
169                 sendcnts[0] = node_sizes[0] * sendcnt;
170                 displs[0]   = 0;
171
172                 for (i = 1; i < leader_comm_size; i++) {
173                   displs[i]   = displs[i - 1] + node_sizes[i - 1] * sendcnt;
174                   sendcnts[i] = node_sizes[i] * sendcnt;
175                 }
176               }
177               colls::scatterv(sendbuf, sendcnts, displs, sendtype, tmp_buf, nbytes * local_size, MPI_BYTE, leader_root,
178                               leader_comm);
179             }
180             if (leader_comm_rank == leader_root) {
181               delete[] displs;
182               delete[] sendcnts;
183             }
184             } else {
185                 if (leader_of_root != root) {
186                     mpi_errno =
187                         MPIR_Scatter_MV2_Direct(leader_scatter_buf,
188                                                 nbytes * local_size, MPI_BYTE,
189                                                 tmp_buf, nbytes * local_size,
190                                                 MPI_BYTE, leader_root,
191                                                 leader_comm);
192                 } else {
193                     mpi_errno =
194                         MPIR_Scatter_MV2_Direct(sendbuf, sendcnt * local_size,
195                                                 sendtype, tmp_buf,
196                                                 nbytes * local_size, MPI_BYTE,
197                                                 leader_root, leader_comm);
198
199                 }
200             }
201         }
202         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
203
204         if (rank == root && recvbuf == MPI_IN_PLACE) {
205             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
206                                                 (void *)sendbuf, sendcnt, sendtype,
207                                                 0, shmem_comm);
208         } else {
209             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
210                                                 recvbuf, recvcnt, recvtype,
211                                                 0, shmem_comm);
212         }
213     }
214
215     /* check if multiple threads are calling this collective function */
216     if (comm_size != local_size && local_rank == 0) {
217         smpi_free_tmp_buffer(tmp_buf);
218         if (leader_of_root == rank && root != rank) {
219             smpi_free_tmp_buffer(leader_scatter_buf);
220         }
221     }
222     return (mpi_errno);
223 }
224
225
226 int scatter__mvapich2_two_level_binomial(const void *sendbuf,
227                                          int sendcnt,
228                                          MPI_Datatype sendtype,
229                                          void *recvbuf,
230                                          int recvcnt,
231                                          MPI_Datatype recvtype,
232                                          int root, MPI_Comm comm)
233 {
234     int comm_size, rank;
235     int local_rank, local_size;
236     int leader_comm_rank = -1, leader_comm_size = -1;
237     int mpi_errno = MPI_SUCCESS;
238     int recvtype_size, sendtype_size, nbytes;
239     unsigned char* tmp_buf            = nullptr;
240     unsigned char* leader_scatter_buf = nullptr;
241     MPI_Status status;
242     int leader_root = -1, leader_of_root = -1;
243     MPI_Comm shmem_comm, leader_comm;
244
245
246     //if not set (use of the algo directly, without mvapich2 selector)
247     if (MV2_Scatter_intra_function == nullptr)
248       MV2_Scatter_intra_function = scatter__mpich;
249
250     if(comm->get_leaders_comm()==MPI_COMM_NULL){
251       comm->init_smp();
252     }
253     comm_size = comm->size();
254     rank = comm->rank();
255
256     if (((rank == root) && (recvcnt == 0))
257         || ((rank != root) && (sendcnt == 0))) {
258         return MPI_SUCCESS;
259     }
260
261     /* extract the rank,size information for the intra-node
262      * communicator */
263     shmem_comm = comm->get_intra_comm();
264     local_rank = shmem_comm->rank();
265     local_size = shmem_comm->size();
266
267     if (local_rank == 0) {
268         /* Node leader. Extract the rank, size information for the leader
269          * communicator */
270         leader_comm = comm->get_leaders_comm();
271         leader_comm_size = leader_comm->size();
272         leader_comm_rank = leader_comm->rank();
273     }
274
275     if (local_size == comm_size) {
276         /* purely intra-node scatter. Just use the direct algorithm and we are done */
277         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
278                                             recvbuf, recvcnt, recvtype,
279                                             root, comm);
280
281     } else {
282         recvtype_size=recvtype->size();
283         sendtype_size=sendtype->size();
284
285         if (rank == root) {
286             nbytes = sendcnt * sendtype_size;
287         } else {
288             nbytes = recvcnt * recvtype_size;
289         }
290
291         if (local_rank == 0) {
292             /* Node leader, allocate tmp_buffer */
293             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
294         }
295         leader_comm = comm->get_leaders_comm();
296         int* leaders_map = comm->get_leaders_map();
297         leader_of_root = comm->group()->rank(leaders_map[root]);
298         leader_root = leader_comm->group()->rank(leaders_map[root]);
299         /* leader_root is the rank of the leader of the root in leader_comm.
300          * leader_root is to be used as the root of the inter-leader gather ops
301          */
302
303         if ((local_rank == 0) && (root != rank)
304             && (leader_of_root == rank)) {
305             /* The root of the scatter operation is not the node leader. Recv
306              * data from the node leader */
307             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
308             Request::recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
309                              root, COLL_TAG_SCATTER, comm, &status);
310         }
311
312         if (rank == root && local_rank != 0) {
313             /* The root of the scatter operation is not the node leader. Send
314              * data to the node leader */
315             Request::send(sendbuf, sendcnt * comm_size, sendtype,
316                                      leader_of_root, COLL_TAG_SCATTER, comm);
317         }
318
319         if (leader_comm_size > 1 && local_rank == 0) {
320           if (not comm->is_uniform()) {
321             int* displs   = nullptr;
322             int* sendcnts = nullptr;
323             int* node_sizes;
324             int i      = 0;
325             node_sizes = comm->get_non_uniform_map();
326
327             if (root != leader_of_root) {
328               if (leader_comm_rank == leader_root) {
329                 displs      = new int[leader_comm_size];
330                 sendcnts    = new int[leader_comm_size];
331                 sendcnts[0] = node_sizes[0] * nbytes;
332                 displs[0]   = 0;
333
334                 for (i = 1; i < leader_comm_size; i++) {
335                   displs[i]   = displs[i - 1] + node_sizes[i - 1] * nbytes;
336                   sendcnts[i] = node_sizes[i] * nbytes;
337                 }
338               }
339               colls::scatterv(leader_scatter_buf, sendcnts, displs, MPI_BYTE, tmp_buf, nbytes * local_size, MPI_BYTE,
340                               leader_root, leader_comm);
341             } else {
342               if (leader_comm_rank == leader_root) {
343                 displs      = new int[leader_comm_size];
344                 sendcnts    = new int[leader_comm_size];
345                 sendcnts[0] = node_sizes[0] * sendcnt;
346                 displs[0]   = 0;
347
348                 for (i = 1; i < leader_comm_size; i++) {
349                   displs[i]   = displs[i - 1] + node_sizes[i - 1] * sendcnt;
350                   sendcnts[i] = node_sizes[i] * sendcnt;
351                 }
352               }
353               colls::scatterv(sendbuf, sendcnts, displs, sendtype, tmp_buf, nbytes * local_size, MPI_BYTE, leader_root,
354                               leader_comm);
355             }
356             if (leader_comm_rank == leader_root) {
357               delete[] displs;
358               delete[] sendcnts;
359             }
360             } else {
361                 if (leader_of_root != root) {
362                     mpi_errno =
363                         MPIR_Scatter_MV2_Binomial(leader_scatter_buf,
364                                                   nbytes * local_size, MPI_BYTE,
365                                                   tmp_buf, nbytes * local_size,
366                                                   MPI_BYTE, leader_root,
367                                                   leader_comm);
368                 } else {
369                     mpi_errno =
370                         MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt * local_size,
371                                                   sendtype, tmp_buf,
372                                                   nbytes * local_size, MPI_BYTE,
373                                                   leader_root, leader_comm);
374
375                 }
376             }
377         }
378         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
379
380         if (rank == root && recvbuf == MPI_IN_PLACE) {
381             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
382                                                 (void *)sendbuf, sendcnt, sendtype,
383                                                 0, shmem_comm);
384         } else {
385             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
386                                                 recvbuf, recvcnt, recvtype,
387                                                 0, shmem_comm);
388         }
389
390     }
391
392
393     /* check if multiple threads are calling this collective function */
394     if (comm_size != local_size && local_rank == 0) {
395         smpi_free_tmp_buffer(tmp_buf);
396         if (leader_of_root == rank && root != rank) {
397             smpi_free_tmp_buffer(leader_scatter_buf);
398         }
399     }
400
401     return (mpi_errno);
402 }
403
404 }
405 }
406