Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Rename C++ only header files from .h to .hpp.
[simgrid.git] / src / smpi / colls / gather / gather-mvapich.cpp
1 /* Copyright (c) 2013-2017. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "../colls_private.hpp"
39
40 #define MPIR_Gather_MV2_Direct Coll_gather_ompi_basic_linear::gather
41 #define MPIR_Gather_MV2_two_level_Direct Coll_gather_ompi_basic_linear::gather
42 #define MPIR_Gather_intra Coll_gather_mpich::gather
43 typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
44     int sendcnt,
45     MPI_Datatype sendtype,
46     void *recvbuf,
47     int recvcnt,
48     MPI_Datatype recvtype,
49     int root, MPI_Comm comm);
50
51 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
52 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
53
54 #define TEMP_BUF_HAS_NO_DATA (0)
55 #define TEMP_BUF_HAS_DATA (1)
56
57
58 namespace simgrid{
59 namespace smpi{
60
61 /* sendbuf           - (in) sender's buffer
62  * sendcnt           - (in) sender's element count
63  * sendtype          - (in) sender's data type
64  * recvbuf           - (in) receiver's buffer
65  * recvcnt           - (in) receiver's element count
66  * recvtype          - (in) receiver's data type
67  * root              - (in)root for the gather operation
68  * rank              - (in) global rank(rank in the global comm)
69  * tmp_buf           - (out/in) tmp_buf into which intra node
70  *                     data is gathered
71  * is_data_avail     - (in) based on this, tmp_buf acts
72  *                     as in/out parameter.
73  *                     1 - tmp_buf acts as in parameter
74  *                     0 - tmp_buf acts as out parameter
75  * comm_ptr          - (in) pointer to the communicator
76  *                     (shmem_comm or intra_sock_comm or
77  *                     inter-sock_leader_comm)
78  * intra_node_fn_ptr - (in) Function ptr to choose the
79  *                      intra node gather function
80  * errflag           - (out) to record errors
81  */
82 static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
83                             void *recvbuf, int recvcnt, MPI_Datatype recvtype,
84                             int root, int rank,
85                             void *tmp_buf, int nbytes,
86                             int is_data_avail,
87                             MPI_Comm comm,
88                             MV2_Gather_function_ptr intra_node_fn_ptr)
89 {
90     int mpi_errno = MPI_SUCCESS;
91     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
92     MPI_Aint true_lb, sendtype_true_extent, recvtype_true_extent;
93
94
95     if (sendtype != MPI_DATATYPE_NULL) {
96         sendtype->extent(&true_lb,
97                                        &sendtype_true_extent);
98     }
99     if (recvtype != MPI_DATATYPE_NULL) {
100         recvtype_extent=recvtype->get_extent();
101         recvtype->extent(&true_lb,
102                                        &recvtype_true_extent);
103     }
104
105     /* Special case, when tmp_buf itself has data */
106     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
107
108          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
109                                        sendcnt, sendtype, tmp_buf, nbytes,
110                                        MPI_BYTE, 0, comm);
111
112     } else if (rank == root && sendbuf == MPI_IN_PLACE) {
113          mpi_errno = intra_node_fn_ptr((char*)recvbuf +
114                                        rank * recvcnt * recvtype_extent,
115                                        recvcnt, recvtype, tmp_buf, nbytes,
116                                        MPI_BYTE, 0, comm);
117     } else {
118         mpi_errno = intra_node_fn_ptr(sendbuf, sendcnt, sendtype,
119                                       tmp_buf, nbytes, MPI_BYTE,
120                                       0, comm);
121     }
122
123     return mpi_errno;
124
125 }
126
127
128
129 int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
130                                             int sendcnt,
131                                             MPI_Datatype sendtype,
132                                             void *recvbuf,
133                                             int recvcnt,
134                                             MPI_Datatype recvtype,
135                                             int root,
136                                             MPI_Comm comm)
137 {
138     void *leader_gather_buf = NULL;
139     int comm_size, rank;
140     int local_rank, local_size;
141     int leader_comm_rank = -1, leader_comm_size = 0;
142     int mpi_errno = MPI_SUCCESS;
143     int recvtype_size = 0, sendtype_size = 0, nbytes=0;
144     int leader_root, leader_of_root;
145     MPI_Status status;
146     MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
147     MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
148     MPI_Comm shmem_comm, leader_comm;
149     void* tmp_buf = NULL;
150
151
152     //if not set (use of the algo directly, without mvapich2 selector)
153     if(MV2_Gather_intra_node_function==NULL)
154       MV2_Gather_intra_node_function= Coll_gather_mpich::gather;
155
156     if(comm->get_leaders_comm()==MPI_COMM_NULL){
157       comm->init_smp();
158     }
159     comm_size = comm->size();
160     rank = comm->rank();
161
162     if (((rank == root) && (recvcnt == 0)) ||
163         ((rank != root) && (sendcnt == 0))) {
164         return MPI_SUCCESS;
165     }
166
167     if (sendtype != MPI_DATATYPE_NULL) {
168         sendtype_extent=sendtype->get_extent();
169         sendtype_size=sendtype->size();
170         sendtype->extent(&true_lb,
171                                        &sendtype_true_extent);
172     }
173     if (recvtype != MPI_DATATYPE_NULL) {
174         recvtype_extent=recvtype->get_extent();
175         recvtype_size=recvtype->size();
176         recvtype->extent(&true_lb,
177                                        &recvtype_true_extent);
178     }
179
180     /* extract the rank,size information for the intra-node
181      * communicator */
182     shmem_comm = comm->get_intra_comm();
183     local_rank = shmem_comm->rank();
184     local_size = shmem_comm->size();
185
186     if (local_rank == 0) {
187         /* Node leader. Extract the rank, size information for the leader
188          * communicator */
189         leader_comm = comm->get_leaders_comm();
190         if(leader_comm==MPI_COMM_NULL){
191           leader_comm = MPI_COMM_WORLD;
192         }
193         leader_comm_size = leader_comm->size();
194         leader_comm_rank = leader_comm->size();
195     }
196
197     if (rank == root) {
198         nbytes = recvcnt * recvtype_size;
199
200     } else {
201         nbytes = sendcnt * sendtype_size;
202     }
203
204 #if defined(_SMP_LIMIC_)
205      if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1)
206          && (use_limic_gather)
207          &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL)
208             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
209             ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL)
210             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
211             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
212             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
213             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
214             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
215
216             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
217                                                     recvbuf, recvcnt,recvtype,
218                                                     root, comm);
219      } else
220
221 #endif/*#if defined(_SMP_LIMIC_)*/
222     {
223         if (local_rank == 0) {
224             /* Node leader, allocate tmp_buffer */
225             if (rank == root) {
226                 tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
227                             recvtype_true_extent) * local_size);
228             } else {
229                 tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
230                             sendtype_true_extent) *
231                         local_size);
232             }
233             if (tmp_buf == NULL) {
234                 mpi_errno = MPI_ERR_OTHER;
235                 return mpi_errno;
236             }
237         }
238          /*while testing mpich2 gather test, we see that
239          * which basically splits the comm, and we come to
240          * a point, where use_intra_sock_comm == 0, but if the
241          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
242          * it would use the intra sock comm. In such cases, we
243          * fallback to binomial as a default case.*/
244 #if defined(_SMP_LIMIC_)
245         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
246
247             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
248                                                  recvbuf, recvcnt, recvtype,
249                                                  root, rank,
250                                                  tmp_buf, nbytes,
251                                                  TEMP_BUF_HAS_NO_DATA,
252                                                  shmem_commptr,
253                                                  MPIR_Gather_intra);
254         } else
255 #endif
256         {
257             /*We are gathering the data into tmp_buf and the output
258              * will be of MPI_BYTE datatype. Since the tmp_buf has no
259              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
260             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
261                                                  recvbuf, recvcnt, recvtype,
262                                                  root, rank,
263                                                  tmp_buf, nbytes,
264                                                  TEMP_BUF_HAS_NO_DATA,
265                                                  shmem_comm,
266                                                  MV2_Gather_intra_node_function
267                                                  );
268         }
269     }
270     leader_comm = comm->get_leaders_comm();
271     int* leaders_map = comm->get_leaders_map();
272     leader_of_root = comm->group()->rank(leaders_map[root]);
273     leader_root = leader_comm->group()->rank(leaders_map[root]);
274     /* leader_root is the rank of the leader of the root in leader_comm.
275      * leader_root is to be used as the root of the inter-leader gather ops
276      */
277     if (not comm->is_uniform()) {
278       if (local_rank == 0) {
279         int* displs   = NULL;
280         int* recvcnts = NULL;
281         int* node_sizes;
282         int i = 0;
283         /* Node leaders have all the data. But, different nodes can have
284          * different number of processes. Do a Gather first to get the
285          * buffer lengths at each leader, followed by a Gatherv to move
286          * the actual data */
287
288         if (leader_comm_rank == leader_root && root != leader_of_root) {
289           /* The root of the Gather operation is not a node-level
290            * leader and this process's rank in the leader_comm
291            * is the same as leader_root */
292           if (rank == root) {
293             leader_gather_buf =
294                 smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent, recvtype_true_extent) * comm_size);
295           } else {
296             leader_gather_buf =
297                 smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent, sendtype_true_extent) * comm_size);
298           }
299           if (leader_gather_buf == NULL) {
300             mpi_errno = MPI_ERR_OTHER;
301             return mpi_errno;
302           }
303         }
304
305         node_sizes = comm->get_non_uniform_map();
306
307         if (leader_comm_rank == leader_root) {
308           displs   = static_cast<int*>(xbt_malloc(sizeof(int) * leader_comm_size));
309           recvcnts = static_cast<int*>(xbt_malloc(sizeof(int) * leader_comm_size));
310           if (not displs || not recvcnts) {
311             mpi_errno = MPI_ERR_OTHER;
312             return mpi_errno;
313           }
314         }
315
316         if (root == leader_of_root) {
317           /* The root of the gather operation is also the node
318            * leader. Receive into recvbuf and we are done */
319           if (leader_comm_rank == leader_root) {
320             recvcnts[0] = node_sizes[0] * recvcnt;
321             displs[0]   = 0;
322
323             for (i = 1; i < leader_comm_size; i++) {
324               displs[i]   = displs[i - 1] + node_sizes[i - 1] * recvcnt;
325               recvcnts[i] = node_sizes[i] * recvcnt;
326             }
327           }
328           Colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, recvbuf, recvcnts, displs, recvtype, leader_root,
329                          leader_comm);
330         } else {
331           /* The root of the gather operation is not the node leader.
332            * Receive into leader_gather_buf and then send
333            * to the root */
334           if (leader_comm_rank == leader_root) {
335             recvcnts[0] = node_sizes[0] * nbytes;
336             displs[0]   = 0;
337
338             for (i = 1; i < leader_comm_size; i++) {
339               displs[i]   = displs[i - 1] + node_sizes[i - 1] * nbytes;
340               recvcnts[i] = node_sizes[i] * nbytes;
341             }
342           }
343           Colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, leader_gather_buf, recvcnts, displs, MPI_BYTE,
344                          leader_root, leader_comm);
345         }
346         if (leader_comm_rank == leader_root) {
347           xbt_free(displs);
348           xbt_free(recvcnts);
349         }
350       }
351     } else {
352         /* All nodes have the same number of processes.
353          * Just do one Gather to get all
354          * the data at the leader of the root process */
355         if (local_rank == 0) {
356             if (leader_comm_rank == leader_root && root != leader_of_root) {
357                 /* The root of the Gather operation is not a node-level leader
358                  */
359                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
360                 if (leader_gather_buf == NULL) {
361                     mpi_errno = MPI_ERR_OTHER;
362                     return mpi_errno;
363                 }
364             }
365             if (root == leader_of_root) {
366                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf,
367                                                    nbytes * local_size,
368                                                    MPI_BYTE, recvbuf,
369                                                    recvcnt * local_size,
370                                                    recvtype, leader_root,
371                                                    leader_comm);
372
373             } else {
374                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
375                                                    MPI_BYTE, leader_gather_buf,
376                                                    nbytes * local_size,
377                                                    MPI_BYTE, leader_root,
378                                                    leader_comm);
379             }
380         }
381     }
382     if ((local_rank == 0) && (root != rank)
383         && (leader_of_root == rank)) {
384         Request::send(leader_gather_buf,
385                                  nbytes * comm_size, MPI_BYTE,
386                                  root, COLL_TAG_GATHER, comm);
387     }
388
389     if (rank == root && local_rank != 0) {
390         /* The root of the gather operation is not the node leader. Receive
391          y* data from the node leader */
392         Request::recv(recvbuf, recvcnt * comm_size, recvtype,
393                                  leader_of_root, COLL_TAG_GATHER, comm,
394                                  &status);
395     }
396
397     /* check if multiple threads are calling this collective function */
398     if (local_rank == 0 ) {
399         if (tmp_buf != NULL) {
400             smpi_free_tmp_buffer(tmp_buf);
401         }
402         if (leader_gather_buf != NULL) {
403             smpi_free_tmp_buffer(leader_gather_buf);
404         }
405     }
406
407     return (mpi_errno);
408 }
409 }
410 }
411