Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Welcome to simgrid::smpi::Op
[simgrid.git] / src / smpi / colls / gather-mvapich.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "colls_private.h"
39
40 #define MPIR_Gather_MV2_Direct smpi_coll_tuned_gather_ompi_basic_linear
41 #define MPIR_Gather_MV2_two_level_Direct smpi_coll_tuned_gather_ompi_basic_linear
42 #define MPIR_Gather_intra smpi_coll_tuned_gather_mpich
43 typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
44     int sendcnt,
45     MPI_Datatype sendtype,
46     void *recvbuf,
47     int recvcnt,
48     MPI_Datatype recvtype,
49     int root, MPI_Comm comm);
50     
51 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
52 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
53
54 #define TEMP_BUF_HAS_NO_DATA (0)
55 #define TEMP_BUF_HAS_DATA (1)
56
57 /* sendbuf           - (in) sender's buffer
58  * sendcnt           - (in) sender's element count
59  * sendtype          - (in) sender's data type
60  * recvbuf           - (in) receiver's buffer
61  * recvcnt           - (in) receiver's element count
62  * recvtype          - (in) receiver's data type
63  * root              - (in)root for the gather operation
64  * rank              - (in) global rank(rank in the global comm)
65  * tmp_buf           - (out/in) tmp_buf into which intra node
66  *                     data is gathered
67  * is_data_avail     - (in) based on this, tmp_buf acts
68  *                     as in/out parameter.
69  *                     1 - tmp_buf acts as in parameter
70  *                     0 - tmp_buf acts as out parameter
71  * comm_ptr          - (in) pointer to the communicator
72  *                     (shmem_comm or intra_sock_comm or
73  *                     inter-sock_leader_comm)
74  * intra_node_fn_ptr - (in) Function ptr to choose the
75  *                      intra node gather function  
76  * errflag           - (out) to record errors
77  */
78 static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
79                             void *recvbuf, int recvcnt, MPI_Datatype recvtype,
80                             int root, int rank, 
81                             void *tmp_buf, int nbytes,
82                             int is_data_avail,
83                             MPI_Comm comm,  
84                             MV2_Gather_function_ptr intra_node_fn_ptr)
85 {
86     int mpi_errno = MPI_SUCCESS;
87     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
88     MPI_Aint true_lb, sendtype_true_extent, recvtype_true_extent;
89
90
91     if (sendtype != MPI_DATATYPE_NULL) {
92         smpi_datatype_extent(sendtype, &true_lb,
93                                        &sendtype_true_extent);
94     }
95     if (recvtype != MPI_DATATYPE_NULL) {
96         recvtype_extent=smpi_datatype_get_extent(recvtype);
97         smpi_datatype_extent(recvtype, &true_lb,
98                                        &recvtype_true_extent);
99     }
100     
101     /* Special case, when tmp_buf itself has data */
102     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
103          
104          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
105                                        sendcnt, sendtype, tmp_buf, nbytes,
106                                        MPI_BYTE, 0, comm);
107
108     } else if (rank == root && sendbuf == MPI_IN_PLACE) {
109          mpi_errno = intra_node_fn_ptr((char*)recvbuf +
110                                        rank * recvcnt * recvtype_extent,
111                                        recvcnt, recvtype, tmp_buf, nbytes,
112                                        MPI_BYTE, 0, comm);
113     } else {
114         mpi_errno = intra_node_fn_ptr(sendbuf, sendcnt, sendtype,
115                                       tmp_buf, nbytes, MPI_BYTE,
116                                       0, comm);
117     }
118
119     return mpi_errno;
120
121 }
122
123
124 int smpi_coll_tuned_gather_mvapich2_two_level(void *sendbuf,
125                                             int sendcnt,
126                                             MPI_Datatype sendtype,
127                                             void *recvbuf,
128                                             int recvcnt,
129                                             MPI_Datatype recvtype,
130                                             int root,
131                                             MPI_Comm comm)
132 {
133     void *leader_gather_buf = NULL;
134     int comm_size, rank;
135     int local_rank, local_size;
136     int leader_comm_rank = -1, leader_comm_size = 0;
137     int mpi_errno = MPI_SUCCESS;
138     int recvtype_size = 0, sendtype_size = 0, nbytes=0;
139     int leader_root, leader_of_root;
140     MPI_Status status;
141     MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
142     MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
143     MPI_Comm shmem_comm, leader_comm;
144     void* tmp_buf = NULL;
145     
146
147     //if not set (use of the algo directly, without mvapich2 selector)
148     if(MV2_Gather_intra_node_function==NULL)
149       MV2_Gather_intra_node_function=smpi_coll_tuned_gather_mpich;
150     
151     if(comm->get_leaders_comm()==MPI_COMM_NULL){
152       comm->init_smp();
153     }
154     comm_size = comm->size();
155     rank = comm->rank();
156
157     if (((rank == root) && (recvcnt == 0)) ||
158         ((rank != root) && (sendcnt == 0))) {
159         return MPI_SUCCESS;
160     }
161
162     if (sendtype != MPI_DATATYPE_NULL) {
163         sendtype_extent=smpi_datatype_get_extent(sendtype);
164         sendtype_size=smpi_datatype_size(sendtype);
165         smpi_datatype_extent(sendtype, &true_lb,
166                                        &sendtype_true_extent);
167     }
168     if (recvtype != MPI_DATATYPE_NULL) {
169         recvtype_extent=smpi_datatype_get_extent(recvtype);
170         recvtype_size=smpi_datatype_size(recvtype);
171         smpi_datatype_extent(recvtype, &true_lb,
172                                        &recvtype_true_extent);
173     }
174
175     /* extract the rank,size information for the intra-node
176      * communicator */
177     shmem_comm = comm->get_intra_comm();
178     local_rank = shmem_comm->rank();
179     local_size = shmem_comm->size();
180     
181     if (local_rank == 0) {
182         /* Node leader. Extract the rank, size information for the leader
183          * communicator */
184         leader_comm = comm->get_leaders_comm();
185         if(leader_comm==MPI_COMM_NULL){
186           leader_comm = MPI_COMM_WORLD;
187         }
188         leader_comm_size = leader_comm->size();
189         leader_comm_rank = leader_comm->size();
190     }
191
192     if (rank == root) {
193         nbytes = recvcnt * recvtype_size;
194
195     } else {
196         nbytes = sendcnt * sendtype_size;
197     }
198
199 #if defined(_SMP_LIMIC_)
200      if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1) 
201          && (use_limic_gather)
202          &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL) 
203             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
204             ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL) 
205             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
206             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
207             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
208             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
209             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
210             
211             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
212                                                     recvbuf, recvcnt,recvtype, 
213                                                     root, comm);
214      } else
215
216 #endif/*#if defined(_SMP_LIMIC_)*/    
217     {
218         if (local_rank == 0) {
219             /* Node leader, allocate tmp_buffer */
220             if (rank == root) {
221                 tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
222                             recvtype_true_extent) * local_size);
223             } else {
224                 tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
225                             sendtype_true_extent) *
226                         local_size);
227             }
228             if (tmp_buf == NULL) {
229                 mpi_errno = MPI_ERR_OTHER;
230                 return mpi_errno;
231             }
232         }
233          /*while testing mpich2 gather test, we see that
234          * which basically splits the comm, and we come to
235          * a point, where use_intra_sock_comm == 0, but if the 
236          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
237          * it would use the intra sock comm. In such cases, we 
238          * fallback to binomial as a default case.*/
239 #if defined(_SMP_LIMIC_)         
240         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
241
242             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
243                                                  recvbuf, recvcnt, recvtype,
244                                                  root, rank, 
245                                                  tmp_buf, nbytes, 
246                                                  TEMP_BUF_HAS_NO_DATA,
247                                                  shmem_commptr,
248                                                  MPIR_Gather_intra);
249         } else
250 #endif
251         {
252             /*We are gathering the data into tmp_buf and the output
253              * will be of MPI_BYTE datatype. Since the tmp_buf has no
254              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
255             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
256                                                  recvbuf, recvcnt, recvtype,
257                                                  root, rank, 
258                                                  tmp_buf, nbytes, 
259                                                  TEMP_BUF_HAS_NO_DATA,
260                                                  shmem_comm,
261                                                  MV2_Gather_intra_node_function
262                                                  );
263         }
264     }
265     leader_comm = comm->get_leaders_comm();
266     int* leaders_map = comm->get_leaders_map();
267     leader_of_root = comm->group()->rank(leaders_map[root]);
268     leader_root = leader_comm->group()->rank(leaders_map[root]);
269     /* leader_root is the rank of the leader of the root in leader_comm. 
270      * leader_root is to be used as the root of the inter-leader gather ops 
271      */
272     if (!comm->is_uniform()) {
273         if (local_rank == 0) {
274             int *displs = NULL;
275             int *recvcnts = NULL;
276             int *node_sizes;
277             int i = 0;
278             /* Node leaders have all the data. But, different nodes can have
279              * different number of processes. Do a Gather first to get the 
280              * buffer lengths at each leader, followed by a Gatherv to move
281              * the actual data */
282
283             if (leader_comm_rank == leader_root && root != leader_of_root) {
284                 /* The root of the Gather operation is not a node-level 
285                  * leader and this process's rank in the leader_comm 
286                  * is the same as leader_root */
287                 if(rank == root) { 
288                     leader_gather_buf = smpi_get_tmp_recvbuffer(recvcnt *
289                                                 MAX(recvtype_extent,
290                                                 recvtype_true_extent) *
291                                                 comm_size);
292                 } else { 
293                     leader_gather_buf = smpi_get_tmp_sendbuffer(sendcnt *
294                                                 MAX(sendtype_extent,
295                                                 sendtype_true_extent) *
296                                                 comm_size);
297                 } 
298                 if (leader_gather_buf == NULL) {
299                     mpi_errno =  MPI_ERR_OTHER;
300                     return mpi_errno;
301                 }
302             }
303
304             node_sizes = comm->get_non_uniform_map();
305
306             if (leader_comm_rank == leader_root) {
307                 displs =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
308                 recvcnts =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
309                 if (!displs || !recvcnts) {
310                     mpi_errno = MPI_ERR_OTHER;
311                     return mpi_errno;
312                 }
313             }
314
315             if (root == leader_of_root) {
316                 /* The root of the gather operation is also the node 
317                  * leader. Receive into recvbuf and we are done */
318                 if (leader_comm_rank == leader_root) {
319                     recvcnts[0] = node_sizes[0] * recvcnt;
320                     displs[0] = 0;
321
322                     for (i = 1; i < leader_comm_size; i++) {
323                         displs[i] = displs[i - 1] + node_sizes[i - 1] * recvcnt;
324                         recvcnts[i] = node_sizes[i] * recvcnt;
325                     }
326                 } 
327                 smpi_mpi_gatherv(tmp_buf,
328                                          local_size * nbytes,
329                                          MPI_BYTE, recvbuf, recvcnts,
330                                          displs, recvtype,
331                                          leader_root, leader_comm);
332             } else {
333                 /* The root of the gather operation is not the node leader. 
334                  * Receive into leader_gather_buf and then send 
335                  * to the root */
336                 if (leader_comm_rank == leader_root) {
337                     recvcnts[0] = node_sizes[0] * nbytes;
338                     displs[0] = 0;
339
340                     for (i = 1; i < leader_comm_size; i++) {
341                         displs[i] = displs[i - 1] + node_sizes[i - 1] * nbytes;
342                         recvcnts[i] = node_sizes[i] * nbytes;
343                     }
344                 } 
345                 smpi_mpi_gatherv(tmp_buf, local_size * nbytes,
346                                          MPI_BYTE, leader_gather_buf,
347                                          recvcnts, displs, MPI_BYTE,
348                                          leader_root, leader_comm);
349             }
350             if (leader_comm_rank == leader_root) {
351                 xbt_free(displs);
352                 xbt_free(recvcnts);
353             }
354         }
355     } else {
356         /* All nodes have the same number of processes. 
357          * Just do one Gather to get all 
358          * the data at the leader of the root process */
359         if (local_rank == 0) {
360             if (leader_comm_rank == leader_root && root != leader_of_root) {
361                 /* The root of the Gather operation is not a node-level leader
362                  */
363                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
364                 if (leader_gather_buf == NULL) {
365                     mpi_errno = MPI_ERR_OTHER;
366                     return mpi_errno;
367                 }
368             }
369             if (root == leader_of_root) {
370                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf,
371                                                    nbytes * local_size,
372                                                    MPI_BYTE, recvbuf,
373                                                    recvcnt * local_size,
374                                                    recvtype, leader_root,
375                                                    leader_comm);
376                  
377             } else {
378                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
379                                                    MPI_BYTE, leader_gather_buf,
380                                                    nbytes * local_size,
381                                                    MPI_BYTE, leader_root,
382                                                    leader_comm);
383             }
384         }
385     }
386     if ((local_rank == 0) && (root != rank)
387         && (leader_of_root == rank)) {
388         Request::send(leader_gather_buf,
389                                  nbytes * comm_size, MPI_BYTE,
390                                  root, COLL_TAG_GATHER, comm);
391     }
392
393     if (rank == root && local_rank != 0) {
394         /* The root of the gather operation is not the node leader. Receive
395          y* data from the node leader */
396         Request::recv(recvbuf, recvcnt * comm_size, recvtype,
397                                  leader_of_root, COLL_TAG_GATHER, comm,
398                                  &status);
399     }
400
401     /* check if multiple threads are calling this collective function */
402     if (local_rank == 0 ) {
403         if (tmp_buf != NULL) {
404             smpi_free_tmp_buffer(tmp_buf);
405         }
406         if (leader_gather_buf != NULL) {
407             smpi_free_tmp_buffer(leader_gather_buf);
408         }
409     }
410
411     return (mpi_errno);
412 }
413