Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
9936d250fe75a20c46d86e7e398f90faf7fca261
[simgrid.git] / src / smpi / colls / gather / gather-mvapich.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "../colls_private.h"
39
40
41
42
43
44 #define MPIR_Gather_MV2_Direct Coll_gather_ompi_basic_linear::gather
45 #define MPIR_Gather_MV2_two_level_Direct Coll_gather_ompi_basic_linear::gather
46 #define MPIR_Gather_intra Coll_gather_mpich::gather
47 typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
48     int sendcnt,
49     MPI_Datatype sendtype,
50     void *recvbuf,
51     int recvcnt,
52     MPI_Datatype recvtype,
53     int root, MPI_Comm comm);
54     
55 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
56 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
57
58 #define TEMP_BUF_HAS_NO_DATA (0)
59 #define TEMP_BUF_HAS_DATA (1)
60
61
62 namespace simgrid{
63 namespace smpi{
64
65 /* sendbuf           - (in) sender's buffer
66  * sendcnt           - (in) sender's element count
67  * sendtype          - (in) sender's data type
68  * recvbuf           - (in) receiver's buffer
69  * recvcnt           - (in) receiver's element count
70  * recvtype          - (in) receiver's data type
71  * root              - (in)root for the gather operation
72  * rank              - (in) global rank(rank in the global comm)
73  * tmp_buf           - (out/in) tmp_buf into which intra node
74  *                     data is gathered
75  * is_data_avail     - (in) based on this, tmp_buf acts
76  *                     as in/out parameter.
77  *                     1 - tmp_buf acts as in parameter
78  *                     0 - tmp_buf acts as out parameter
79  * comm_ptr          - (in) pointer to the communicator
80  *                     (shmem_comm or intra_sock_comm or
81  *                     inter-sock_leader_comm)
82  * intra_node_fn_ptr - (in) Function ptr to choose the
83  *                      intra node gather function  
84  * errflag           - (out) to record errors
85  */
86 static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
87                             void *recvbuf, int recvcnt, MPI_Datatype recvtype,
88                             int root, int rank, 
89                             void *tmp_buf, int nbytes,
90                             int is_data_avail,
91                             MPI_Comm comm,  
92                             MV2_Gather_function_ptr intra_node_fn_ptr)
93 {
94     int mpi_errno = MPI_SUCCESS;
95     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
96     MPI_Aint true_lb, sendtype_true_extent, recvtype_true_extent;
97
98
99     if (sendtype != MPI_DATATYPE_NULL) {
100         sendtype->extent(&true_lb,
101                                        &sendtype_true_extent);
102     }
103     if (recvtype != MPI_DATATYPE_NULL) {
104         recvtype_extent=recvtype->get_extent();
105         recvtype->extent(&true_lb,
106                                        &recvtype_true_extent);
107     }
108     
109     /* Special case, when tmp_buf itself has data */
110     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
111          
112          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
113                                        sendcnt, sendtype, tmp_buf, nbytes,
114                                        MPI_BYTE, 0, comm);
115
116     } else if (rank == root && sendbuf == MPI_IN_PLACE) {
117          mpi_errno = intra_node_fn_ptr((char*)recvbuf +
118                                        rank * recvcnt * recvtype_extent,
119                                        recvcnt, recvtype, tmp_buf, nbytes,
120                                        MPI_BYTE, 0, comm);
121     } else {
122         mpi_errno = intra_node_fn_ptr(sendbuf, sendcnt, sendtype,
123                                       tmp_buf, nbytes, MPI_BYTE,
124                                       0, comm);
125     }
126
127     return mpi_errno;
128
129 }
130
131
132
133 int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
134                                             int sendcnt,
135                                             MPI_Datatype sendtype,
136                                             void *recvbuf,
137                                             int recvcnt,
138                                             MPI_Datatype recvtype,
139                                             int root,
140                                             MPI_Comm comm)
141 {
142     void *leader_gather_buf = NULL;
143     int comm_size, rank;
144     int local_rank, local_size;
145     int leader_comm_rank = -1, leader_comm_size = 0;
146     int mpi_errno = MPI_SUCCESS;
147     int recvtype_size = 0, sendtype_size = 0, nbytes=0;
148     int leader_root, leader_of_root;
149     MPI_Status status;
150     MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
151     MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
152     MPI_Comm shmem_comm, leader_comm;
153     void* tmp_buf = NULL;
154     
155
156     //if not set (use of the algo directly, without mvapich2 selector)
157     if(MV2_Gather_intra_node_function==NULL)
158       MV2_Gather_intra_node_function= Coll_gather_mpich::gather;
159     
160     if(comm->get_leaders_comm()==MPI_COMM_NULL){
161       comm->init_smp();
162     }
163     comm_size = comm->size();
164     rank = comm->rank();
165
166     if (((rank == root) && (recvcnt == 0)) ||
167         ((rank != root) && (sendcnt == 0))) {
168         return MPI_SUCCESS;
169     }
170
171     if (sendtype != MPI_DATATYPE_NULL) {
172         sendtype_extent=sendtype->get_extent();
173         sendtype_size=sendtype->size();
174         sendtype->extent(&true_lb,
175                                        &sendtype_true_extent);
176     }
177     if (recvtype != MPI_DATATYPE_NULL) {
178         recvtype_extent=recvtype->get_extent();
179         recvtype_size=recvtype->size();
180         recvtype->extent(&true_lb,
181                                        &recvtype_true_extent);
182     }
183
184     /* extract the rank,size information for the intra-node
185      * communicator */
186     shmem_comm = comm->get_intra_comm();
187     local_rank = shmem_comm->rank();
188     local_size = shmem_comm->size();
189     
190     if (local_rank == 0) {
191         /* Node leader. Extract the rank, size information for the leader
192          * communicator */
193         leader_comm = comm->get_leaders_comm();
194         if(leader_comm==MPI_COMM_NULL){
195           leader_comm = MPI_COMM_WORLD;
196         }
197         leader_comm_size = leader_comm->size();
198         leader_comm_rank = leader_comm->size();
199     }
200
201     if (rank == root) {
202         nbytes = recvcnt * recvtype_size;
203
204     } else {
205         nbytes = sendcnt * sendtype_size;
206     }
207
208 #if defined(_SMP_LIMIC_)
209      if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1) 
210          && (use_limic_gather)
211          &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL) 
212             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
213             ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL) 
214             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
215             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
216             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
217             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
218             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
219             
220             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
221                                                     recvbuf, recvcnt,recvtype, 
222                                                     root, comm);
223      } else
224
225 #endif/*#if defined(_SMP_LIMIC_)*/    
226     {
227         if (local_rank == 0) {
228             /* Node leader, allocate tmp_buffer */
229             if (rank == root) {
230                 tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
231                             recvtype_true_extent) * local_size);
232             } else {
233                 tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
234                             sendtype_true_extent) *
235                         local_size);
236             }
237             if (tmp_buf == NULL) {
238                 mpi_errno = MPI_ERR_OTHER;
239                 return mpi_errno;
240             }
241         }
242          /*while testing mpich2 gather test, we see that
243          * which basically splits the comm, and we come to
244          * a point, where use_intra_sock_comm == 0, but if the 
245          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
246          * it would use the intra sock comm. In such cases, we 
247          * fallback to binomial as a default case.*/
248 #if defined(_SMP_LIMIC_)         
249         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
250
251             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
252                                                  recvbuf, recvcnt, recvtype,
253                                                  root, rank, 
254                                                  tmp_buf, nbytes, 
255                                                  TEMP_BUF_HAS_NO_DATA,
256                                                  shmem_commptr,
257                                                  MPIR_Gather_intra);
258         } else
259 #endif
260         {
261             /*We are gathering the data into tmp_buf and the output
262              * will be of MPI_BYTE datatype. Since the tmp_buf has no
263              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
264             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
265                                                  recvbuf, recvcnt, recvtype,
266                                                  root, rank, 
267                                                  tmp_buf, nbytes, 
268                                                  TEMP_BUF_HAS_NO_DATA,
269                                                  shmem_comm,
270                                                  MV2_Gather_intra_node_function
271                                                  );
272         }
273     }
274     leader_comm = comm->get_leaders_comm();
275     int* leaders_map = comm->get_leaders_map();
276     leader_of_root = comm->group()->rank(leaders_map[root]);
277     leader_root = leader_comm->group()->rank(leaders_map[root]);
278     /* leader_root is the rank of the leader of the root in leader_comm. 
279      * leader_root is to be used as the root of the inter-leader gather ops 
280      */
281     if (!comm->is_uniform()) {
282         if (local_rank == 0) {
283             int *displs = NULL;
284             int *recvcnts = NULL;
285             int *node_sizes;
286             int i = 0;
287             /* Node leaders have all the data. But, different nodes can have
288              * different number of processes. Do a Gather first to get the 
289              * buffer lengths at each leader, followed by a Gatherv to move
290              * the actual data */
291
292             if (leader_comm_rank == leader_root && root != leader_of_root) {
293                 /* The root of the Gather operation is not a node-level 
294                  * leader and this process's rank in the leader_comm 
295                  * is the same as leader_root */
296                 if(rank == root) { 
297                     leader_gather_buf = smpi_get_tmp_recvbuffer(recvcnt *
298                                                 MAX(recvtype_extent,
299                                                 recvtype_true_extent) *
300                                                 comm_size);
301                 } else { 
302                     leader_gather_buf = smpi_get_tmp_sendbuffer(sendcnt *
303                                                 MAX(sendtype_extent,
304                                                 sendtype_true_extent) *
305                                                 comm_size);
306                 } 
307                 if (leader_gather_buf == NULL) {
308                     mpi_errno =  MPI_ERR_OTHER;
309                     return mpi_errno;
310                 }
311             }
312
313             node_sizes = comm->get_non_uniform_map();
314
315             if (leader_comm_rank == leader_root) {
316                 displs =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
317                 recvcnts =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
318                 if (!displs || !recvcnts) {
319                     mpi_errno = MPI_ERR_OTHER;
320                     return mpi_errno;
321                 }
322             }
323
324             if (root == leader_of_root) {
325                 /* The root of the gather operation is also the node 
326                  * leader. Receive into recvbuf and we are done */
327                 if (leader_comm_rank == leader_root) {
328                     recvcnts[0] = node_sizes[0] * recvcnt;
329                     displs[0] = 0;
330
331                     for (i = 1; i < leader_comm_size; i++) {
332                         displs[i] = displs[i - 1] + node_sizes[i - 1] * recvcnt;
333                         recvcnts[i] = node_sizes[i] * recvcnt;
334                     }
335                 } 
336                 Colls::gatherv(tmp_buf,
337                                          local_size * nbytes,
338                                          MPI_BYTE, recvbuf, recvcnts,
339                                          displs, recvtype,
340                                          leader_root, leader_comm);
341             } else {
342                 /* The root of the gather operation is not the node leader. 
343                  * Receive into leader_gather_buf and then send 
344                  * to the root */
345                 if (leader_comm_rank == leader_root) {
346                     recvcnts[0] = node_sizes[0] * nbytes;
347                     displs[0] = 0;
348
349                     for (i = 1; i < leader_comm_size; i++) {
350                         displs[i] = displs[i - 1] + node_sizes[i - 1] * nbytes;
351                         recvcnts[i] = node_sizes[i] * nbytes;
352                     }
353                 } 
354                 Colls::gatherv(tmp_buf, local_size * nbytes,
355                                          MPI_BYTE, leader_gather_buf,
356                                          recvcnts, displs, MPI_BYTE,
357                                          leader_root, leader_comm);
358             }
359             if (leader_comm_rank == leader_root) {
360                 xbt_free(displs);
361                 xbt_free(recvcnts);
362             }
363         }
364     } else {
365         /* All nodes have the same number of processes. 
366          * Just do one Gather to get all 
367          * the data at the leader of the root process */
368         if (local_rank == 0) {
369             if (leader_comm_rank == leader_root && root != leader_of_root) {
370                 /* The root of the Gather operation is not a node-level leader
371                  */
372                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
373                 if (leader_gather_buf == NULL) {
374                     mpi_errno = MPI_ERR_OTHER;
375                     return mpi_errno;
376                 }
377             }
378             if (root == leader_of_root) {
379                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf,
380                                                    nbytes * local_size,
381                                                    MPI_BYTE, recvbuf,
382                                                    recvcnt * local_size,
383                                                    recvtype, leader_root,
384                                                    leader_comm);
385                  
386             } else {
387                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
388                                                    MPI_BYTE, leader_gather_buf,
389                                                    nbytes * local_size,
390                                                    MPI_BYTE, leader_root,
391                                                    leader_comm);
392             }
393         }
394     }
395     if ((local_rank == 0) && (root != rank)
396         && (leader_of_root == rank)) {
397         Request::send(leader_gather_buf,
398                                  nbytes * comm_size, MPI_BYTE,
399                                  root, COLL_TAG_GATHER, comm);
400     }
401
402     if (rank == root && local_rank != 0) {
403         /* The root of the gather operation is not the node leader. Receive
404          y* data from the node leader */
405         Request::recv(recvbuf, recvcnt * comm_size, recvtype,
406                                  leader_of_root, COLL_TAG_GATHER, comm,
407                                  &status);
408     }
409
410     /* check if multiple threads are calling this collective function */
411     if (local_rank == 0 ) {
412         if (tmp_buf != NULL) {
413             smpi_free_tmp_buffer(tmp_buf);
414         }
415         if (leader_gather_buf != NULL) {
416             smpi_free_tmp_buffer(leader_gather_buf);
417         }
418     }
419
420     return (mpi_errno);
421 }
422 }
423 }
424