Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
40e9f788680b67b142435f4650298e619178126f
[simgrid.git] / src / smpi / colls / scatter / scatter-mvapich-two-level.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37 #include "../colls_private.h"
38
/* Aliases for the MVAPICH2 helper names used by the algorithms in this file.
 * They expand to the scatter() entry points of the `ompi_binomial` and
 * `ompi_basic_linear` collective classes; the names are resolved where the
 * macros are used. */
39 #define MPIR_Scatter_MV2_Binomial Coll_scatter_ompi_binomial::scatter
40 #define MPIR_Scatter_MV2_Direct Coll_scatter_ompi_basic_linear::scatter
41
/* Function pointer used for the intra-node scatter step. It is defined in
 * another translation unit; the algorithms in this file set it to
 * Coll_scatter_mpich::scatter when they find it NULL. */
42 extern int (*MV2_Scatter_intra_function) (void *sendbuf, int sendcount, MPI_Datatype sendtype,
43     void *recvbuf, int recvcount, MPI_Datatype recvtype,
44     int root, MPI_Comm comm);
45
46 namespace simgrid{
47 namespace smpi{
48
49 int Coll_scatter_mvapich2_two_level_direct::scatter(void *sendbuf,
50                                       int sendcnt,
51                                       MPI_Datatype sendtype,
52                                       void *recvbuf,
53                                       int recvcnt,
54                                       MPI_Datatype recvtype,
55                                       int root, MPI_Comm  comm)
56 {
57     int comm_size, rank;
58     int local_rank, local_size;
59     int leader_comm_rank = -1, leader_comm_size = -1;
60     int mpi_errno = MPI_SUCCESS;
61     int recvtype_size, sendtype_size, nbytes;
62     void *tmp_buf = NULL;
63     void *leader_scatter_buf = NULL;
64     MPI_Status status;
65     int leader_root, leader_of_root = -1;
66     MPI_Comm shmem_comm, leader_comm;
67     //if not set (use of the algo directly, without mvapich2 selector)
68     if(MV2_Scatter_intra_function==NULL)
69       MV2_Scatter_intra_function=Coll_scatter_mpich::scatter;
70     
71     if(comm->get_leaders_comm()==MPI_COMM_NULL){
72       comm->init_smp();
73     }
74     comm_size = comm->size();
75     rank = comm->rank();
76
77     if (((rank == root) && (recvcnt == 0))
78         || ((rank != root) && (sendcnt == 0))) {
79         return MPI_SUCCESS;
80     }
81
82     /* extract the rank,size information for the intra-node
83      * communicator */
84     shmem_comm = comm->get_intra_comm();
85     local_rank = shmem_comm->rank();
86     local_size = shmem_comm->size();
87
88     if (local_rank == 0) {
89         /* Node leader. Extract the rank, size information for the leader
90          * communicator */
91         leader_comm = comm->get_leaders_comm();
92         leader_comm_size = leader_comm->size();
93         leader_comm_rank = leader_comm->rank();
94     }
95
96     if (local_size == comm_size) {
97         /* purely intra-node scatter. Just use the direct algorithm and we are done */
98         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
99                                             recvbuf, recvcnt, recvtype,
100                                             root, comm);
101
102     } else {
103         recvtype_size=recvtype->size();
104         sendtype_size=sendtype->size();
105
106         if (rank == root) {
107             nbytes = sendcnt * sendtype_size;
108         } else {
109             nbytes = recvcnt * recvtype_size;
110         }
111
112         if (local_rank == 0) {
113             /* Node leader, allocate tmp_buffer */
114             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
115         }
116
117         leader_comm = comm->get_leaders_comm();
118         int* leaders_map = comm->get_leaders_map();
119         leader_of_root = comm->group()->rank(leaders_map[root]);
120         leader_root = leader_comm->group()->rank(leaders_map[root]);
121         /* leader_root is the rank of the leader of the root in leader_comm.
122          * leader_root is to be used as the root of the inter-leader gather ops
123          */
124
125         if ((local_rank == 0) && (root != rank)
126             && (leader_of_root == rank)) {
127             /* The root of the scatter operation is not the node leader. Recv
128              * data from the node leader */
129             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
130             Request::recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
131                              root, COLL_TAG_SCATTER, comm, &status);
132
133         }
134
135         if (rank == root && local_rank != 0) {
136             /* The root of the scatter operation is not the node leader. Send
137              * data to the node leader */
138             Request::send(sendbuf, sendcnt * comm_size, sendtype,
139                                      leader_of_root, COLL_TAG_SCATTER, comm
140                                      );
141         }
142
143         if (leader_comm_size > 1 && local_rank == 0) {
144             if (!comm->is_uniform()) {
145                 int *displs = NULL;
146                 int *sendcnts = NULL;
147                 int *node_sizes;
148                 int i = 0;
149                 node_sizes = comm->get_non_uniform_map();
150
151                 if (root != leader_of_root) {
152                     if (leader_comm_rank == leader_root) {
153                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
154                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
155                         sendcnts[0] = node_sizes[0] * nbytes;
156                         displs[0] = 0;
157
158                         for (i = 1; i < leader_comm_size; i++) {
159                             displs[i] =
160                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
161                             sendcnts[i] = node_sizes[i] * nbytes;
162                         }
163                     }
164                         Colls::scatterv(leader_scatter_buf, sendcnts, displs,
165                                       MPI_BYTE, tmp_buf, nbytes * local_size,
166                                       MPI_BYTE, leader_root, leader_comm);
167                 } else {
168                     if (leader_comm_rank == leader_root) {
169                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
170                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
171                         sendcnts[0] = node_sizes[0] * sendcnt;
172                         displs[0] = 0;
173
174                         for (i = 1; i < leader_comm_size; i++) {
175                             displs[i] =
176                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
177                             sendcnts[i] = node_sizes[i] * sendcnt;
178                         }
179                     }
180                     Colls::scatterv(sendbuf, sendcnts, displs,
181                                               sendtype, tmp_buf,
182                                               nbytes * local_size, MPI_BYTE,
183                                               leader_root, leader_comm);
184                 }
185                 if (leader_comm_rank == leader_root) {
186                     xbt_free(displs);
187                     xbt_free(sendcnts);
188                 }
189             } else {
190                 if (leader_of_root != root) {
191                     mpi_errno =
192                         MPIR_Scatter_MV2_Direct(leader_scatter_buf,
193                                                 nbytes * local_size, MPI_BYTE,
194                                                 tmp_buf, nbytes * local_size,
195                                                 MPI_BYTE, leader_root,
196                                                 leader_comm);
197                 } else {
198                     mpi_errno =
199                         MPIR_Scatter_MV2_Direct(sendbuf, sendcnt * local_size,
200                                                 sendtype, tmp_buf,
201                                                 nbytes * local_size, MPI_BYTE,
202                                                 leader_root, leader_comm);
203
204                 }
205             }
206         }
207         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
208
209         if (rank == root && recvbuf == MPI_IN_PLACE) {
210             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
211                                                 (void *)sendbuf, sendcnt, sendtype,
212                                                 0, shmem_comm);
213         } else {
214             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
215                                                 recvbuf, recvcnt, recvtype,
216                                                 0, shmem_comm);
217         }
218     }
219
220     /* check if multiple threads are calling this collective function */
221     if (comm_size != local_size && local_rank == 0) {
222         smpi_free_tmp_buffer(tmp_buf);
223         if (leader_of_root == rank && root != rank) {
224             smpi_free_tmp_buffer(leader_scatter_buf);
225         }
226     }
227     return (mpi_errno);
228 }
229
230
231 int Coll_scatter_mvapich2_two_level_binomial::scatter(void *sendbuf,
232                                         int sendcnt,
233                                         MPI_Datatype sendtype,
234                                         void *recvbuf,
235                                         int recvcnt,
236                                         MPI_Datatype recvtype,
237                                         int root, MPI_Comm comm)
238 {
239     int comm_size, rank;
240     int local_rank, local_size;
241     int leader_comm_rank = -1, leader_comm_size = -1;
242     int mpi_errno = MPI_SUCCESS;
243     int recvtype_size, sendtype_size, nbytes;
244     void *tmp_buf = NULL;
245     void *leader_scatter_buf = NULL;
246     MPI_Status status;
247     int leader_root = -1, leader_of_root = -1;
248     MPI_Comm shmem_comm, leader_comm;
249
250
251     //if not set (use of the algo directly, without mvapich2 selector)
252     if(MV2_Scatter_intra_function==NULL)
253       MV2_Scatter_intra_function=Coll_scatter_mpich::scatter;
254     
255     if(comm->get_leaders_comm()==MPI_COMM_NULL){
256       comm->init_smp();
257     }
258     comm_size = comm->size();
259     rank = comm->rank();
260
261     if (((rank == root) && (recvcnt == 0))
262         || ((rank != root) && (sendcnt == 0))) {
263         return MPI_SUCCESS;
264     }
265
266     /* extract the rank,size information for the intra-node
267      * communicator */
268     shmem_comm = comm->get_intra_comm();
269     local_rank = shmem_comm->rank();
270     local_size = shmem_comm->size();
271
272     if (local_rank == 0) {
273         /* Node leader. Extract the rank, size information for the leader
274          * communicator */
275         leader_comm = comm->get_leaders_comm();
276         leader_comm_size = leader_comm->size();
277         leader_comm_rank = leader_comm->rank();
278     }
279
280     if (local_size == comm_size) {
281         /* purely intra-node scatter. Just use the direct algorithm and we are done */
282         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
283                                             recvbuf, recvcnt, recvtype,
284                                             root, comm);
285
286     } else {
287         recvtype_size=recvtype->size();
288         sendtype_size=sendtype->size();
289
290         if (rank == root) {
291             nbytes = sendcnt * sendtype_size;
292         } else {
293             nbytes = recvcnt * recvtype_size;
294         }
295
296         if (local_rank == 0) {
297             /* Node leader, allocate tmp_buffer */
298             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
299         }
300         leader_comm = comm->get_leaders_comm();
301         int* leaders_map = comm->get_leaders_map();
302         leader_of_root = comm->group()->rank(leaders_map[root]);
303         leader_root = leader_comm->group()->rank(leaders_map[root]);
304         /* leader_root is the rank of the leader of the root in leader_comm.
305          * leader_root is to be used as the root of the inter-leader gather ops
306          */
307
308         if ((local_rank == 0) && (root != rank)
309             && (leader_of_root == rank)) {
310             /* The root of the scatter operation is not the node leader. Recv
311              * data from the node leader */
312             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
313             Request::recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
314                              root, COLL_TAG_SCATTER, comm, &status);
315         }
316
317         if (rank == root && local_rank != 0) {
318             /* The root of the scatter operation is not the node leader. Send
319              * data to the node leader */
320             Request::send(sendbuf, sendcnt * comm_size, sendtype,
321                                      leader_of_root, COLL_TAG_SCATTER, comm);
322         }
323
324         if (leader_comm_size > 1 && local_rank == 0) {
325             if (!comm->is_uniform()) {
326                 int *displs = NULL;
327                 int *sendcnts = NULL;
328                 int *node_sizes;
329                 int i = 0;
330                 node_sizes = comm->get_non_uniform_map();
331
332                 if (root != leader_of_root) {
333                     if (leader_comm_rank == leader_root) {
334                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
335                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
336                         sendcnts[0] = node_sizes[0] * nbytes;
337                         displs[0] = 0;
338
339                         for (i = 1; i < leader_comm_size; i++) {
340                             displs[i] =
341                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
342                             sendcnts[i] = node_sizes[i] * nbytes;
343                         }
344                     }
345                         Colls::scatterv(leader_scatter_buf, sendcnts, displs,
346                                       MPI_BYTE, tmp_buf, nbytes * local_size,
347                                       MPI_BYTE, leader_root, leader_comm);
348                 } else {
349                     if (leader_comm_rank == leader_root) {
350                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
351                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
352                         sendcnts[0] = node_sizes[0] * sendcnt;
353                         displs[0] = 0;
354
355                         for (i = 1; i < leader_comm_size; i++) {
356                             displs[i] =
357                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
358                             sendcnts[i] = node_sizes[i] * sendcnt;
359                         }
360                     }
361                     Colls::scatterv(sendbuf, sendcnts, displs,
362                                               sendtype, tmp_buf,
363                                               nbytes * local_size, MPI_BYTE,
364                                               leader_root, leader_comm);
365                 }
366                 if (leader_comm_rank == leader_root) {
367                     xbt_free(displs);
368                     xbt_free(sendcnts);
369                 }
370             } else {
371                 if (leader_of_root != root) {
372                     mpi_errno =
373                         MPIR_Scatter_MV2_Binomial(leader_scatter_buf,
374                                                   nbytes * local_size, MPI_BYTE,
375                                                   tmp_buf, nbytes * local_size,
376                                                   MPI_BYTE, leader_root,
377                                                   leader_comm);
378                 } else {
379                     mpi_errno =
380                         MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt * local_size,
381                                                   sendtype, tmp_buf,
382                                                   nbytes * local_size, MPI_BYTE,
383                                                   leader_root, leader_comm);
384
385                 }
386             }
387         }
388         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
389
390         if (rank == root && recvbuf == MPI_IN_PLACE) {
391             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
392                                                 (void *)sendbuf, sendcnt, sendtype,
393                                                 0, shmem_comm);
394         } else {
395             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
396                                                 recvbuf, recvcnt, recvtype,
397                                                 0, shmem_comm);
398         }
399
400     }
401
402
403     /* check if multiple threads are calling this collective function */
404     if (comm_size != local_size && local_rank == 0) {
405         smpi_free_tmp_buffer(tmp_buf);
406         if (leader_of_root == rank && root != rank) {
407             smpi_free_tmp_buffer(leader_scatter_buf);
408         }
409     }
410
411     return (mpi_errno);
412 }
413
414 }
415 }
416