Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
9ca1192305651ab308f74a66caab1ed707e2cf44
[simgrid.git] / src / smpi / colls / gather-mvapich.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "colls_private.h"
39 #include "src/smpi/smpi_group.hpp"
40
41 #define MPIR_Gather_MV2_Direct smpi_coll_tuned_gather_ompi_basic_linear
42 #define MPIR_Gather_MV2_two_level_Direct smpi_coll_tuned_gather_ompi_basic_linear
43 #define MPIR_Gather_intra smpi_coll_tuned_gather_mpich
44 typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
45     int sendcnt,
46     MPI_Datatype sendtype,
47     void *recvbuf,
48     int recvcnt,
49     MPI_Datatype recvtype,
50     int root, MPI_Comm comm);
51     
52 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
53 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
54
55 #define TEMP_BUF_HAS_NO_DATA (0)
56 #define TEMP_BUF_HAS_DATA (1)
57
58 /* sendbuf           - (in) sender's buffer
59  * sendcnt           - (in) sender's element count
60  * sendtype          - (in) sender's data type
61  * recvbuf           - (in) receiver's buffer
62  * recvcnt           - (in) receiver's element count
63  * recvtype          - (in) receiver's data type
64  * root              - (in)root for the gather operation
65  * rank              - (in) global rank(rank in the global comm)
66  * tmp_buf           - (out/in) tmp_buf into which intra node
67  *                     data is gathered
68  * is_data_avail     - (in) based on this, tmp_buf acts
69  *                     as in/out parameter.
70  *                     1 - tmp_buf acts as in parameter
71  *                     0 - tmp_buf acts as out parameter
72  * comm_ptr          - (in) pointer to the communicator
73  *                     (shmem_comm or intra_sock_comm or
74  *                     inter-sock_leader_comm)
75  * intra_node_fn_ptr - (in) Function ptr to choose the
76  *                      intra node gather function  
77  * errflag           - (out) to record errors
78  */
79 static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
80                             void *recvbuf, int recvcnt, MPI_Datatype recvtype,
81                             int root, int rank, 
82                             void *tmp_buf, int nbytes,
83                             int is_data_avail,
84                             MPI_Comm comm,  
85                             MV2_Gather_function_ptr intra_node_fn_ptr)
86 {
87     int mpi_errno = MPI_SUCCESS;
88     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
89     MPI_Aint true_lb, sendtype_true_extent, recvtype_true_extent;
90
91
92     if (sendtype != MPI_DATATYPE_NULL) {
93         smpi_datatype_extent(sendtype, &true_lb,
94                                        &sendtype_true_extent);
95     }
96     if (recvtype != MPI_DATATYPE_NULL) {
97         recvtype_extent=smpi_datatype_get_extent(recvtype);
98         smpi_datatype_extent(recvtype, &true_lb,
99                                        &recvtype_true_extent);
100     }
101     
102     /* Special case, when tmp_buf itself has data */
103     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
104          
105          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
106                                        sendcnt, sendtype, tmp_buf, nbytes,
107                                        MPI_BYTE, 0, comm);
108
109     } else if (rank == root && sendbuf == MPI_IN_PLACE) {
110          mpi_errno = intra_node_fn_ptr((char*)recvbuf +
111                                        rank * recvcnt * recvtype_extent,
112                                        recvcnt, recvtype, tmp_buf, nbytes,
113                                        MPI_BYTE, 0, comm);
114     } else {
115         mpi_errno = intra_node_fn_ptr(sendbuf, sendcnt, sendtype,
116                                       tmp_buf, nbytes, MPI_BYTE,
117                                       0, comm);
118     }
119
120     return mpi_errno;
121
122 }
123
124
125 int smpi_coll_tuned_gather_mvapich2_two_level(void *sendbuf,
126                                             int sendcnt,
127                                             MPI_Datatype sendtype,
128                                             void *recvbuf,
129                                             int recvcnt,
130                                             MPI_Datatype recvtype,
131                                             int root,
132                                             MPI_Comm comm)
133 {
134     void *leader_gather_buf = NULL;
135     int comm_size, rank;
136     int local_rank, local_size;
137     int leader_comm_rank = -1, leader_comm_size = 0;
138     int mpi_errno = MPI_SUCCESS;
139     int recvtype_size = 0, sendtype_size = 0, nbytes=0;
140     int leader_root, leader_of_root;
141     MPI_Status status;
142     MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
143     MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
144     MPI_Comm shmem_comm, leader_comm;
145     void* tmp_buf = NULL;
146     
147
148     //if not set (use of the algo directly, without mvapich2 selector)
149     if(MV2_Gather_intra_node_function==NULL)
150       MV2_Gather_intra_node_function=smpi_coll_tuned_gather_mpich;
151     
152     if(comm->get_leaders_comm()==MPI_COMM_NULL){
153       comm->init_smp();
154     }
155     comm_size = comm->size();
156     rank = comm->rank();
157
158     if (((rank == root) && (recvcnt == 0)) ||
159         ((rank != root) && (sendcnt == 0))) {
160         return MPI_SUCCESS;
161     }
162
163     if (sendtype != MPI_DATATYPE_NULL) {
164         sendtype_extent=smpi_datatype_get_extent(sendtype);
165         sendtype_size=smpi_datatype_size(sendtype);
166         smpi_datatype_extent(sendtype, &true_lb,
167                                        &sendtype_true_extent);
168     }
169     if (recvtype != MPI_DATATYPE_NULL) {
170         recvtype_extent=smpi_datatype_get_extent(recvtype);
171         recvtype_size=smpi_datatype_size(recvtype);
172         smpi_datatype_extent(recvtype, &true_lb,
173                                        &recvtype_true_extent);
174     }
175
176     /* extract the rank,size information for the intra-node
177      * communicator */
178     shmem_comm = comm->get_intra_comm();
179     local_rank = shmem_comm->rank();
180     local_size = shmem_comm->size();
181     
182     if (local_rank == 0) {
183         /* Node leader. Extract the rank, size information for the leader
184          * communicator */
185         leader_comm = comm->get_leaders_comm();
186         if(leader_comm==MPI_COMM_NULL){
187           leader_comm = MPI_COMM_WORLD;
188         }
189         leader_comm_size = leader_comm->size();
190         leader_comm_rank = leader_comm->size();
191     }
192
193     if (rank == root) {
194         nbytes = recvcnt * recvtype_size;
195
196     } else {
197         nbytes = sendcnt * sendtype_size;
198     }
199
200 #if defined(_SMP_LIMIC_)
201      if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1) 
202          && (use_limic_gather)
203          &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL) 
204             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
205             ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL) 
206             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
207             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
208             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
209             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
210             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
211             
212             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
213                                                     recvbuf, recvcnt,recvtype, 
214                                                     root, comm);
215      } else
216
217 #endif/*#if defined(_SMP_LIMIC_)*/    
218     {
219         if (local_rank == 0) {
220             /* Node leader, allocate tmp_buffer */
221             if (rank == root) {
222                 tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
223                             recvtype_true_extent) * local_size);
224             } else {
225                 tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
226                             sendtype_true_extent) *
227                         local_size);
228             }
229             if (tmp_buf == NULL) {
230                 mpi_errno = MPI_ERR_OTHER;
231                 return mpi_errno;
232             }
233         }
234          /*while testing mpich2 gather test, we see that
235          * which basically splits the comm, and we come to
236          * a point, where use_intra_sock_comm == 0, but if the 
237          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
238          * it would use the intra sock comm. In such cases, we 
239          * fallback to binomial as a default case.*/
240 #if defined(_SMP_LIMIC_)         
241         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
242
243             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
244                                                  recvbuf, recvcnt, recvtype,
245                                                  root, rank, 
246                                                  tmp_buf, nbytes, 
247                                                  TEMP_BUF_HAS_NO_DATA,
248                                                  shmem_commptr,
249                                                  MPIR_Gather_intra);
250         } else
251 #endif
252         {
253             /*We are gathering the data into tmp_buf and the output
254              * will be of MPI_BYTE datatype. Since the tmp_buf has no
255              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
256             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
257                                                  recvbuf, recvcnt, recvtype,
258                                                  root, rank, 
259                                                  tmp_buf, nbytes, 
260                                                  TEMP_BUF_HAS_NO_DATA,
261                                                  shmem_comm,
262                                                  MV2_Gather_intra_node_function
263                                                  );
264         }
265     }
266     leader_comm = comm->get_leaders_comm();
267     int* leaders_map = comm->get_leaders_map();
268     leader_of_root = comm->group()->rank(leaders_map[root]);
269     leader_root = leader_comm->group()->rank(leaders_map[root]);
270     /* leader_root is the rank of the leader of the root in leader_comm. 
271      * leader_root is to be used as the root of the inter-leader gather ops 
272      */
273     if (!comm->is_uniform()) {
274         if (local_rank == 0) {
275             int *displs = NULL;
276             int *recvcnts = NULL;
277             int *node_sizes;
278             int i = 0;
279             /* Node leaders have all the data. But, different nodes can have
280              * different number of processes. Do a Gather first to get the 
281              * buffer lengths at each leader, followed by a Gatherv to move
282              * the actual data */
283
284             if (leader_comm_rank == leader_root && root != leader_of_root) {
285                 /* The root of the Gather operation is not a node-level 
286                  * leader and this process's rank in the leader_comm 
287                  * is the same as leader_root */
288                 if(rank == root) { 
289                     leader_gather_buf = smpi_get_tmp_recvbuffer(recvcnt *
290                                                 MAX(recvtype_extent,
291                                                 recvtype_true_extent) *
292                                                 comm_size);
293                 } else { 
294                     leader_gather_buf = smpi_get_tmp_sendbuffer(sendcnt *
295                                                 MAX(sendtype_extent,
296                                                 sendtype_true_extent) *
297                                                 comm_size);
298                 } 
299                 if (leader_gather_buf == NULL) {
300                     mpi_errno =  MPI_ERR_OTHER;
301                     return mpi_errno;
302                 }
303             }
304
305             node_sizes = comm->get_non_uniform_map();
306
307             if (leader_comm_rank == leader_root) {
308                 displs =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
309                 recvcnts =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
310                 if (!displs || !recvcnts) {
311                     mpi_errno = MPI_ERR_OTHER;
312                     return mpi_errno;
313                 }
314             }
315
316             if (root == leader_of_root) {
317                 /* The root of the gather operation is also the node 
318                  * leader. Receive into recvbuf and we are done */
319                 if (leader_comm_rank == leader_root) {
320                     recvcnts[0] = node_sizes[0] * recvcnt;
321                     displs[0] = 0;
322
323                     for (i = 1; i < leader_comm_size; i++) {
324                         displs[i] = displs[i - 1] + node_sizes[i - 1] * recvcnt;
325                         recvcnts[i] = node_sizes[i] * recvcnt;
326                     }
327                 } 
328                 smpi_mpi_gatherv(tmp_buf,
329                                          local_size * nbytes,
330                                          MPI_BYTE, recvbuf, recvcnts,
331                                          displs, recvtype,
332                                          leader_root, leader_comm);
333             } else {
334                 /* The root of the gather operation is not the node leader. 
335                  * Receive into leader_gather_buf and then send 
336                  * to the root */
337                 if (leader_comm_rank == leader_root) {
338                     recvcnts[0] = node_sizes[0] * nbytes;
339                     displs[0] = 0;
340
341                     for (i = 1; i < leader_comm_size; i++) {
342                         displs[i] = displs[i - 1] + node_sizes[i - 1] * nbytes;
343                         recvcnts[i] = node_sizes[i] * nbytes;
344                     }
345                 } 
346                 smpi_mpi_gatherv(tmp_buf, local_size * nbytes,
347                                          MPI_BYTE, leader_gather_buf,
348                                          recvcnts, displs, MPI_BYTE,
349                                          leader_root, leader_comm);
350             }
351             if (leader_comm_rank == leader_root) {
352                 xbt_free(displs);
353                 xbt_free(recvcnts);
354             }
355         }
356     } else {
357         /* All nodes have the same number of processes. 
358          * Just do one Gather to get all 
359          * the data at the leader of the root process */
360         if (local_rank == 0) {
361             if (leader_comm_rank == leader_root && root != leader_of_root) {
362                 /* The root of the Gather operation is not a node-level leader
363                  */
364                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
365                 if (leader_gather_buf == NULL) {
366                     mpi_errno = MPI_ERR_OTHER;
367                     return mpi_errno;
368                 }
369             }
370             if (root == leader_of_root) {
371                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf,
372                                                    nbytes * local_size,
373                                                    MPI_BYTE, recvbuf,
374                                                    recvcnt * local_size,
375                                                    recvtype, leader_root,
376                                                    leader_comm);
377                  
378             } else {
379                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
380                                                    MPI_BYTE, leader_gather_buf,
381                                                    nbytes * local_size,
382                                                    MPI_BYTE, leader_root,
383                                                    leader_comm);
384             }
385         }
386     }
387     if ((local_rank == 0) && (root != rank)
388         && (leader_of_root == rank)) {
389         smpi_mpi_send(leader_gather_buf,
390                                  nbytes * comm_size, MPI_BYTE,
391                                  root, COLL_TAG_GATHER, comm);
392     }
393
394     if (rank == root && local_rank != 0) {
395         /* The root of the gather operation is not the node leader. Receive
396          y* data from the node leader */
397         smpi_mpi_recv(recvbuf, recvcnt * comm_size, recvtype,
398                                  leader_of_root, COLL_TAG_GATHER, comm,
399                                  &status);
400     }
401
402     /* check if multiple threads are calling this collective function */
403     if (local_rank == 0 ) {
404         if (tmp_buf != NULL) {
405             smpi_free_tmp_buffer(tmp_buf);
406         }
407         if (leader_gather_buf != NULL) {
408             smpi_free_tmp_buffer(leader_gather_buf);
409         }
410     }
411
412     return (mpi_errno);
413 }
414