Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Comm -> C++
[simgrid.git] / src / smpi / colls / scatter-mvapich-two-level.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37 #include "colls_private.h"
38 #include "src/smpi/smpi_group.hpp"
39
40 #define MPIR_Scatter_MV2_Binomial smpi_coll_tuned_scatter_ompi_binomial
41 #define MPIR_Scatter_MV2_Direct smpi_coll_tuned_scatter_ompi_basic_linear
42
43 extern int (*MV2_Scatter_intra_function) (void *sendbuf, int sendcount, MPI_Datatype sendtype,
44     void *recvbuf, int recvcount, MPI_Datatype recvtype,
45     int root, MPI_Comm comm);
46
47 int smpi_coll_tuned_scatter_mvapich2_two_level_direct(void *sendbuf,
48                                       int sendcnt,
49                                       MPI_Datatype sendtype,
50                                       void *recvbuf,
51                                       int recvcnt,
52                                       MPI_Datatype recvtype,
53                                       int root, MPI_Comm  comm)
54 {
55     int comm_size, rank;
56     int local_rank, local_size;
57     int leader_comm_rank = -1, leader_comm_size = -1;
58     int mpi_errno = MPI_SUCCESS;
59     int recvtype_size, sendtype_size, nbytes;
60     void *tmp_buf = NULL;
61     void *leader_scatter_buf = NULL;
62     MPI_Status status;
63     int leader_root, leader_of_root = -1;
64     MPI_Comm shmem_comm, leader_comm;
65     //if not set (use of the algo directly, without mvapich2 selector)
66     if(MV2_Scatter_intra_function==NULL)
67       MV2_Scatter_intra_function=smpi_coll_tuned_scatter_mpich;
68     
69     if(comm->get_leaders_comm()==MPI_COMM_NULL){
70       comm->init_smp();
71     }
72     comm_size = comm->size();
73     rank = comm->rank();
74
75     if (((rank == root) && (recvcnt == 0))
76         || ((rank != root) && (sendcnt == 0))) {
77         return MPI_SUCCESS;
78     }
79
80     /* extract the rank,size information for the intra-node
81      * communicator */
82     shmem_comm = comm->get_intra_comm();
83     local_rank = shmem_comm->rank();
84     local_size = shmem_comm->size();
85
86     if (local_rank == 0) {
87         /* Node leader. Extract the rank, size information for the leader
88          * communicator */
89         leader_comm = comm->get_leaders_comm();
90         leader_comm_size = leader_comm->size();
91         leader_comm_rank = leader_comm->rank();
92     }
93
94     if (local_size == comm_size) {
95         /* purely intra-node scatter. Just use the direct algorithm and we are done */
96         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
97                                             recvbuf, recvcnt, recvtype,
98                                             root, comm);
99
100     } else {
101         recvtype_size=smpi_datatype_size(recvtype);
102         sendtype_size=smpi_datatype_size(sendtype);
103
104         if (rank == root) {
105             nbytes = sendcnt * sendtype_size;
106         } else {
107             nbytes = recvcnt * recvtype_size;
108         }
109
110         if (local_rank == 0) {
111             /* Node leader, allocate tmp_buffer */
112             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
113         }
114
115         leader_comm = comm->get_leaders_comm();
116         int* leaders_map = comm->get_leaders_map();
117         leader_of_root = comm->group()->rank(leaders_map[root]);
118         leader_root = leader_comm->group()->rank(leaders_map[root]);
119         /* leader_root is the rank of the leader of the root in leader_comm.
120          * leader_root is to be used as the root of the inter-leader gather ops
121          */
122
123         if ((local_rank == 0) && (root != rank)
124             && (leader_of_root == rank)) {
125             /* The root of the scatter operation is not the node leader. Recv
126              * data from the node leader */
127             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
128             smpi_mpi_recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
129                              root, COLL_TAG_SCATTER, comm, &status);
130
131         }
132
133         if (rank == root && local_rank != 0) {
134             /* The root of the scatter operation is not the node leader. Send
135              * data to the node leader */
136             smpi_mpi_send(sendbuf, sendcnt * comm_size, sendtype,
137                                      leader_of_root, COLL_TAG_SCATTER, comm
138                                      );
139         }
140
141         if (leader_comm_size > 1 && local_rank == 0) {
142             if (!comm->is_uniform()) {
143                 int *displs = NULL;
144                 int *sendcnts = NULL;
145                 int *node_sizes;
146                 int i = 0;
147                 node_sizes = comm->get_non_uniform_map();
148
149                 if (root != leader_of_root) {
150                     if (leader_comm_rank == leader_root) {
151                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
152                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
153                         sendcnts[0] = node_sizes[0] * nbytes;
154                         displs[0] = 0;
155
156                         for (i = 1; i < leader_comm_size; i++) {
157                             displs[i] =
158                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
159                             sendcnts[i] = node_sizes[i] * nbytes;
160                         }
161                     }
162                         smpi_mpi_scatterv(leader_scatter_buf, sendcnts, displs,
163                                       MPI_BYTE, tmp_buf, nbytes * local_size,
164                                       MPI_BYTE, leader_root, leader_comm);
165                 } else {
166                     if (leader_comm_rank == leader_root) {
167                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
168                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
169                         sendcnts[0] = node_sizes[0] * sendcnt;
170                         displs[0] = 0;
171
172                         for (i = 1; i < leader_comm_size; i++) {
173                             displs[i] =
174                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
175                             sendcnts[i] = node_sizes[i] * sendcnt;
176                         }
177                     }
178                     smpi_mpi_scatterv(sendbuf, sendcnts, displs,
179                                               sendtype, tmp_buf,
180                                               nbytes * local_size, MPI_BYTE,
181                                               leader_root, leader_comm);
182                 }
183                 if (leader_comm_rank == leader_root) {
184                     xbt_free(displs);
185                     xbt_free(sendcnts);
186                 }
187             } else {
188                 if (leader_of_root != root) {
189                     mpi_errno =
190                         MPIR_Scatter_MV2_Direct(leader_scatter_buf,
191                                                 nbytes * local_size, MPI_BYTE,
192                                                 tmp_buf, nbytes * local_size,
193                                                 MPI_BYTE, leader_root,
194                                                 leader_comm);
195                 } else {
196                     mpi_errno =
197                         MPIR_Scatter_MV2_Direct(sendbuf, sendcnt * local_size,
198                                                 sendtype, tmp_buf,
199                                                 nbytes * local_size, MPI_BYTE,
200                                                 leader_root, leader_comm);
201
202                 }
203             }
204         }
205         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
206
207         if (rank == root && recvbuf == MPI_IN_PLACE) {
208             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
209                                                 (void *)sendbuf, sendcnt, sendtype,
210                                                 0, shmem_comm);
211         } else {
212             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
213                                                 recvbuf, recvcnt, recvtype,
214                                                 0, shmem_comm);
215         }
216     }
217
218     /* check if multiple threads are calling this collective function */
219     if (comm_size != local_size && local_rank == 0) {
220         smpi_free_tmp_buffer(tmp_buf);
221         if (leader_of_root == rank && root != rank) {
222             smpi_free_tmp_buffer(leader_scatter_buf);
223         }
224     }
225     return (mpi_errno);
226 }
227
228
229 int smpi_coll_tuned_scatter_mvapich2_two_level_binomial(void *sendbuf,
230                                         int sendcnt,
231                                         MPI_Datatype sendtype,
232                                         void *recvbuf,
233                                         int recvcnt,
234                                         MPI_Datatype recvtype,
235                                         int root, MPI_Comm comm)
236 {
237     int comm_size, rank;
238     int local_rank, local_size;
239     int leader_comm_rank = -1, leader_comm_size = -1;
240     int mpi_errno = MPI_SUCCESS;
241     int recvtype_size, sendtype_size, nbytes;
242     void *tmp_buf = NULL;
243     void *leader_scatter_buf = NULL;
244     MPI_Status status;
245     int leader_root = -1, leader_of_root = -1;
246     MPI_Comm shmem_comm, leader_comm;
247
248
249     //if not set (use of the algo directly, without mvapich2 selector)
250     if(MV2_Scatter_intra_function==NULL)
251       MV2_Scatter_intra_function=smpi_coll_tuned_scatter_mpich;
252     
253     if(comm->get_leaders_comm()==MPI_COMM_NULL){
254       comm->init_smp();
255     }
256     comm_size = comm->size();
257     rank = comm->rank();
258
259     if (((rank == root) && (recvcnt == 0))
260         || ((rank != root) && (sendcnt == 0))) {
261         return MPI_SUCCESS;
262     }
263
264     /* extract the rank,size information for the intra-node
265      * communicator */
266     shmem_comm = comm->get_intra_comm();
267     local_rank = shmem_comm->rank();
268     local_size = shmem_comm->size();
269
270     if (local_rank == 0) {
271         /* Node leader. Extract the rank, size information for the leader
272          * communicator */
273         leader_comm = comm->get_leaders_comm();
274         leader_comm_size = leader_comm->size();
275         leader_comm_rank = leader_comm->rank();
276     }
277
278     if (local_size == comm_size) {
279         /* purely intra-node scatter. Just use the direct algorithm and we are done */
280         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
281                                             recvbuf, recvcnt, recvtype,
282                                             root, comm);
283
284     } else {
285         recvtype_size=smpi_datatype_size(recvtype);
286         sendtype_size=smpi_datatype_size(sendtype);
287
288         if (rank == root) {
289             nbytes = sendcnt * sendtype_size;
290         } else {
291             nbytes = recvcnt * recvtype_size;
292         }
293
294         if (local_rank == 0) {
295             /* Node leader, allocate tmp_buffer */
296             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
297         }
298         leader_comm = comm->get_leaders_comm();
299         int* leaders_map = comm->get_leaders_map();
300         leader_of_root = comm->group()->rank(leaders_map[root]);
301         leader_root = leader_comm->group()->rank(leaders_map[root]);
302         /* leader_root is the rank of the leader of the root in leader_comm.
303          * leader_root is to be used as the root of the inter-leader gather ops
304          */
305
306         if ((local_rank == 0) && (root != rank)
307             && (leader_of_root == rank)) {
308             /* The root of the scatter operation is not the node leader. Recv
309              * data from the node leader */
310             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
311             smpi_mpi_recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
312                              root, COLL_TAG_SCATTER, comm, &status);
313         }
314
315         if (rank == root && local_rank != 0) {
316             /* The root of the scatter operation is not the node leader. Send
317              * data to the node leader */
318             smpi_mpi_send(sendbuf, sendcnt * comm_size, sendtype,
319                                      leader_of_root, COLL_TAG_SCATTER, comm);
320         }
321
322         if (leader_comm_size > 1 && local_rank == 0) {
323             if (!comm->is_uniform()) {
324                 int *displs = NULL;
325                 int *sendcnts = NULL;
326                 int *node_sizes;
327                 int i = 0;
328                 node_sizes = comm->get_non_uniform_map();
329
330                 if (root != leader_of_root) {
331                     if (leader_comm_rank == leader_root) {
332                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
333                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
334                         sendcnts[0] = node_sizes[0] * nbytes;
335                         displs[0] = 0;
336
337                         for (i = 1; i < leader_comm_size; i++) {
338                             displs[i] =
339                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
340                             sendcnts[i] = node_sizes[i] * nbytes;
341                         }
342                     }
343                         smpi_mpi_scatterv(leader_scatter_buf, sendcnts, displs,
344                                       MPI_BYTE, tmp_buf, nbytes * local_size,
345                                       MPI_BYTE, leader_root, leader_comm);
346                 } else {
347                     if (leader_comm_rank == leader_root) {
348                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
349                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
350                         sendcnts[0] = node_sizes[0] * sendcnt;
351                         displs[0] = 0;
352
353                         for (i = 1; i < leader_comm_size; i++) {
354                             displs[i] =
355                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
356                             sendcnts[i] = node_sizes[i] * sendcnt;
357                         }
358                     }
359                     smpi_mpi_scatterv(sendbuf, sendcnts, displs,
360                                               sendtype, tmp_buf,
361                                               nbytes * local_size, MPI_BYTE,
362                                               leader_root, leader_comm);
363                 }
364                 if (leader_comm_rank == leader_root) {
365                     xbt_free(displs);
366                     xbt_free(sendcnts);
367                 }
368             } else {
369                 if (leader_of_root != root) {
370                     mpi_errno =
371                         MPIR_Scatter_MV2_Binomial(leader_scatter_buf,
372                                                   nbytes * local_size, MPI_BYTE,
373                                                   tmp_buf, nbytes * local_size,
374                                                   MPI_BYTE, leader_root,
375                                                   leader_comm);
376                 } else {
377                     mpi_errno =
378                         MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt * local_size,
379                                                   sendtype, tmp_buf,
380                                                   nbytes * local_size, MPI_BYTE,
381                                                   leader_root, leader_comm);
382
383                 }
384             }
385         }
386         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
387
388         if (rank == root && recvbuf == MPI_IN_PLACE) {
389             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
390                                                 (void *)sendbuf, sendcnt, sendtype,
391                                                 0, shmem_comm);
392         } else {
393             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
394                                                 recvbuf, recvcnt, recvtype,
395                                                 0, shmem_comm);
396         }
397
398     }
399
400
401     /* check if multiple threads are calling this collective function */
402     if (comm_size != local_size && local_rank == 0) {
403         smpi_free_tmp_buffer(tmp_buf);
404         if (leader_of_root == rank && root != rank) {
405             smpi_free_tmp_buffer(leader_scatter_buf);
406         }
407     }
408
409     return (mpi_errno);
410 }
411