Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new entry in Release_Notes.
[simgrid.git] / src / smpi / colls / gather / gather-mvapich.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "../colls_private.hpp"
39 #include <algorithm>
40
41 #define MPIR_Gather_MV2_Direct gather__ompi_basic_linear
42 #define MPIR_Gather_MV2_two_level_Direct gather__ompi_basic_linear
43 #define MPIR_Gather_intra gather__mpich
44 typedef int (*MV2_Gather_function_ptr) (const void *sendbuf,
45     int sendcnt,
46     MPI_Datatype sendtype,
47     void *recvbuf,
48     int recvcnt,
49     MPI_Datatype recvtype,
50     int root, MPI_Comm comm);
51
52 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
53 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
54
55 #define TEMP_BUF_HAS_NO_DATA (false)
56 #define TEMP_BUF_HAS_DATA (true)
57
58 namespace simgrid::smpi {
59
60 /* sendbuf           - (in) sender's buffer
61  * sendcnt           - (in) sender's element count
62  * sendtype          - (in) sender's data type
63  * recvbuf           - (in) receiver's buffer
64  * recvcnt           - (in) receiver's element count
65  * recvtype          - (in) receiver's data type
66  * root              - (in)root for the gather operation
67  * rank              - (in) global rank(rank in the global comm)
68  * tmp_buf           - (out/in) tmp_buf into which intra node
69  *                     data is gathered
70  * is_data_avail     - (in) based on this, tmp_buf acts
71  *                     as in/out parameter.
72  *                     1 - tmp_buf acts as in parameter
73  *                     0 - tmp_buf acts as out parameter
74  * comm_ptr          - (in) pointer to the communicator
75  *                     (shmem_comm or intra_sock_comm or
76  *                     inter-sock_leader_comm)
77  * intra_node_fn_ptr - (in) Function ptr to choose the
78  *                      intra node gather function
79  * errflag           - (out) to record errors
80  */
81 static int MPIR_pt_pt_intra_gather(const void* sendbuf, int sendcnt, MPI_Datatype sendtype, void* recvbuf, int recvcnt,
82                                    MPI_Datatype recvtype, int root, int rank, void* tmp_buf, int nbytes,
83                                    bool is_data_avail, MPI_Comm comm, MV2_Gather_function_ptr intra_node_fn_ptr)
84 {
85     int mpi_errno = MPI_SUCCESS;
86     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
87     MPI_Aint true_lb, sendtype_true_extent, recvtype_true_extent;
88
89
90     if (sendtype != MPI_DATATYPE_NULL) {
91         sendtype->extent(&true_lb,
92                                        &sendtype_true_extent);
93     }
94     if (recvtype != MPI_DATATYPE_NULL) {
95         recvtype_extent=recvtype->get_extent();
96         recvtype->extent(&true_lb,
97                                        &recvtype_true_extent);
98     }
99
100     /* Special case, when tmp_buf itself has data */
101     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
102
103          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
104                                        sendcnt, sendtype, tmp_buf, nbytes,
105                                        MPI_BYTE, 0, comm);
106
107     } else if (rank == root && sendbuf == MPI_IN_PLACE) {
108          mpi_errno = intra_node_fn_ptr((char*)recvbuf +
109                                        rank * recvcnt * recvtype_extent,
110                                        recvcnt, recvtype, tmp_buf, nbytes,
111                                        MPI_BYTE, 0, comm);
112     } else {
113         mpi_errno = intra_node_fn_ptr(sendbuf, sendcnt, sendtype,
114                                       tmp_buf, nbytes, MPI_BYTE,
115                                       0, comm);
116     }
117
118     return mpi_errno;
119
120 }
121
122
123
124 int gather__mvapich2_two_level(const void *sendbuf,
125                                int sendcnt,
126                                MPI_Datatype sendtype,
127                                void *recvbuf,
128                                int recvcnt,
129                                MPI_Datatype recvtype,
130                                int root,
131                                MPI_Comm comm)
132 {
133   unsigned char* leader_gather_buf = nullptr;
134   int comm_size, rank;
135   int local_rank, local_size;
136   int leader_comm_rank = -1, leader_comm_size = 0;
137   int mpi_errno     = MPI_SUCCESS;
138   int recvtype_size = 0, sendtype_size = 0, nbytes = 0;
139   int leader_root, leader_of_root;
140   MPI_Status status;
141   MPI_Aint sendtype_extent = 0, recvtype_extent = 0; /* Datatype extent */
142   MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
143   MPI_Comm shmem_comm, leader_comm;
144   unsigned char* tmp_buf = nullptr;
145
146   // if not set (use of the algo directly, without mvapich2 selector)
147   if (MV2_Gather_intra_node_function == nullptr)
148     MV2_Gather_intra_node_function = gather__mpich;
149
150   if (comm->get_leaders_comm() == MPI_COMM_NULL) {
151     comm->init_smp();
152     }
153     comm_size = comm->size();
154     rank = comm->rank();
155
156     if (((rank == root) && (recvcnt == 0)) ||
157         ((rank != root) && (sendcnt == 0))) {
158         return MPI_SUCCESS;
159     }
160
161     if (sendtype != MPI_DATATYPE_NULL) {
162         sendtype_extent=sendtype->get_extent();
163         sendtype_size=sendtype->size();
164         sendtype->extent(&true_lb,
165                                        &sendtype_true_extent);
166     }
167     if (recvtype != MPI_DATATYPE_NULL) {
168         recvtype_extent=recvtype->get_extent();
169         recvtype_size=recvtype->size();
170         recvtype->extent(&true_lb,
171                                        &recvtype_true_extent);
172     }
173
174     /* extract the rank,size information for the intra-node
175      * communicator */
176     shmem_comm = comm->get_intra_comm();
177     local_rank = shmem_comm->rank();
178     local_size = shmem_comm->size();
179
180     if (local_rank == 0) {
181         /* Node leader. Extract the rank, size information for the leader
182          * communicator */
183         leader_comm = comm->get_leaders_comm();
184         if(leader_comm==MPI_COMM_NULL){
185           leader_comm = MPI_COMM_WORLD;
186         }
187         leader_comm_size = leader_comm->size();
188         leader_comm_rank = leader_comm->size();
189     }
190
191     if (rank == root) {
192         nbytes = recvcnt * recvtype_size;
193
194     } else {
195         nbytes = sendcnt * sendtype_size;
196     }
197
198 #if defined(_SMP_LIMIC_)
199      if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1)
200          && (use_limic_gather)
201          &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL)
202             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
203             ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL)
204             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
205             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
206             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
207             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
208             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
209
210             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
211                                                     recvbuf, recvcnt,recvtype,
212                                                     root, comm);
213      } else
214
215 #endif/*#if defined(_SMP_LIMIC_)*/
216     {
217         if (local_rank == 0) {
218             /* Node leader, allocate tmp_buffer */
219             if (rank == root) {
220               tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * local_size);
221             } else {
222               tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * local_size);
223             }
224             if (tmp_buf == nullptr) {
225               mpi_errno = MPI_ERR_OTHER;
226               return mpi_errno;
227             }
228         }
229          /*while testing mpich2 gather test, we see that
230          * which basically splits the comm, and we come to
231          * a point, where use_intra_sock_comm == 0, but if the
232          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
233          * it would use the intra sock comm. In such cases, we
234          * fallback to binomial as a default case.*/
235 #if defined(_SMP_LIMIC_)
236         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
237
238             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
239                                                  recvbuf, recvcnt, recvtype,
240                                                  root, rank,
241                                                  tmp_buf, nbytes,
242                                                  TEMP_BUF_HAS_NO_DATA,
243                                                  shmem_commptr,
244                                                  MPIR_Gather_intra);
245         } else
246 #endif
247         {
248             /*We are gathering the data into tmp_buf and the output
249              * will be of MPI_BYTE datatype. Since the tmp_buf has no
250              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
251             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
252                                                  recvbuf, recvcnt, recvtype,
253                                                  root, rank,
254                                                  tmp_buf, nbytes,
255                                                  TEMP_BUF_HAS_NO_DATA,
256                                                  shmem_comm,
257                                                  MV2_Gather_intra_node_function
258                                                  );
259         }
260     }
261     leader_comm = comm->get_leaders_comm();
262     int* leaders_map = comm->get_leaders_map();
263     leader_of_root = comm->group()->rank(leaders_map[root]);
264     leader_root = leader_comm->group()->rank(leaders_map[root]);
265     /* leader_root is the rank of the leader of the root in leader_comm.
266      * leader_root is to be used as the root of the inter-leader gather ops
267      */
268     if (not comm->is_uniform()) {
269       if (local_rank == 0) {
270         int* displs   = nullptr;
271         int* recvcnts = nullptr;
272         int* node_sizes;
273         int i = 0;
274         /* Node leaders have all the data. But, different nodes can have
275          * different number of processes. Do a Gather first to get the
276          * buffer lengths at each leader, followed by a Gatherv to move
277          * the actual data */
278
279         if (leader_comm_rank == leader_root && root != leader_of_root) {
280           /* The root of the Gather operation is not a node-level
281            * leader and this process's rank in the leader_comm
282            * is the same as leader_root */
283           if (rank == root) {
284             leader_gather_buf =
285                 smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * comm_size);
286           } else {
287             leader_gather_buf =
288                 smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * comm_size);
289           }
290           if (leader_gather_buf == nullptr) {
291             mpi_errno = MPI_ERR_OTHER;
292             return mpi_errno;
293           }
294         }
295
296         node_sizes = comm->get_non_uniform_map();
297
298         if (leader_comm_rank == leader_root) {
299           displs   = new int[leader_comm_size];
300           recvcnts = new int[leader_comm_size];
301         }
302
303         if (root == leader_of_root) {
304           /* The root of the gather operation is also the node
305            * leader. Receive into recvbuf and we are done */
306           if (leader_comm_rank == leader_root) {
307             recvcnts[0] = node_sizes[0] * recvcnt;
308             displs[0]   = 0;
309
310             for (i = 1; i < leader_comm_size; i++) {
311               displs[i]   = displs[i - 1] + node_sizes[i - 1] * recvcnt;
312               recvcnts[i] = node_sizes[i] * recvcnt;
313             }
314           }
315           colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, recvbuf, recvcnts, displs, recvtype, leader_root,
316                          leader_comm);
317         } else {
318           /* The root of the gather operation is not the node leader.
319            * Receive into leader_gather_buf and then send
320            * to the root */
321           if (leader_comm_rank == leader_root) {
322             recvcnts[0] = node_sizes[0] * nbytes;
323             displs[0]   = 0;
324
325             for (i = 1; i < leader_comm_size; i++) {
326               displs[i]   = displs[i - 1] + node_sizes[i - 1] * nbytes;
327               recvcnts[i] = node_sizes[i] * nbytes;
328             }
329           }
330           colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, leader_gather_buf, recvcnts, displs, MPI_BYTE,
331                          leader_root, leader_comm);
332         }
333         if (leader_comm_rank == leader_root) {
334           delete[] displs;
335           delete[] recvcnts;
336         }
337       }
338     } else {
339         /* All nodes have the same number of processes.
340          * Just do one Gather to get all
341          * the data at the leader of the root process */
342         if (local_rank == 0) {
343             if (leader_comm_rank == leader_root && root != leader_of_root) {
344                 /* The root of the Gather operation is not a node-level leader
345                  */
346                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
347                 if (leader_gather_buf == nullptr) {
348                   mpi_errno = MPI_ERR_OTHER;
349                   return mpi_errno;
350                 }
351             }
352             if (root == leader_of_root) {
353                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf,
354                                                    nbytes * local_size,
355                                                    MPI_BYTE, recvbuf,
356                                                    recvcnt * local_size,
357                                                    recvtype, leader_root,
358                                                    leader_comm);
359
360             } else {
361                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
362                                                    MPI_BYTE, leader_gather_buf,
363                                                    nbytes * local_size,
364                                                    MPI_BYTE, leader_root,
365                                                    leader_comm);
366             }
367         }
368     }
369     if ((local_rank == 0) && (root != rank)
370         && (leader_of_root == rank)) {
371         Request::send(leader_gather_buf,
372                                  nbytes * comm_size, MPI_BYTE,
373                                  root, COLL_TAG_GATHER, comm);
374     }
375
376     if (rank == root && local_rank != 0) {
377         /* The root of the gather operation is not the node leader. Receive
378          y* data from the node leader */
379         Request::recv(recvbuf, recvcnt * comm_size, recvtype,
380                                  leader_of_root, COLL_TAG_GATHER, comm,
381                                  &status);
382     }
383
384     /* check if multiple threads are calling this collective function */
385     if (local_rank == 0 ) {
386       if (tmp_buf != nullptr) {
387         smpi_free_tmp_buffer(tmp_buf);
388       }
389       if (leader_gather_buf != nullptr) {
390         smpi_free_tmp_buffer(leader_gather_buf);
391       }
392     }
393
394     return (mpi_errno);
395 }
396 } // namespace simgrid::smpi