Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Move all smpi colls to cpp.
[simgrid.git] / src / smpi / colls / scatter-mvapich-two-level.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37 #include "colls_private.h"
38
/* Aliases: the MVAPICH2 scatter primitive names used in this file resolve to
 * the binomial and basic-linear scatter implementations named on the right. */
39 #define MPIR_Scatter_MV2_Binomial smpi_coll_tuned_scatter_ompi_binomial
40 #define MPIR_Scatter_MV2_Direct smpi_coll_tuned_scatter_ompi_basic_linear
41
/* Scatter routine used for the intra-node step of the two-level algorithms
 * below. It is only declared here (defined in another translation unit); the
 * functions in this file set it to smpi_coll_tuned_scatter_mpich when they
 * find it NULL. */
42 extern int (*MV2_Scatter_intra_function) (void *sendbuf, int sendcount, MPI_Datatype sendtype,
43     void *recvbuf, int recvcount, MPI_Datatype recvtype,
44     int root, MPI_Comm comm);
45
46 int smpi_coll_tuned_scatter_mvapich2_two_level_direct(void *sendbuf,
47                                       int sendcnt,
48                                       MPI_Datatype sendtype,
49                                       void *recvbuf,
50                                       int recvcnt,
51                                       MPI_Datatype recvtype,
52                                       int root, MPI_Comm  comm)
53 {
54     int comm_size, rank;
55     int local_rank, local_size;
56     int leader_comm_rank = -1, leader_comm_size = -1;
57     int mpi_errno = MPI_SUCCESS;
58     int recvtype_size, sendtype_size, nbytes;
59     void *tmp_buf = NULL;
60     void *leader_scatter_buf = NULL;
61     MPI_Status status;
62     int leader_root, leader_of_root = -1;
63     MPI_Comm shmem_comm, leader_comm;
64     //if not set (use of the algo directly, without mvapich2 selector)
65     if(MV2_Scatter_intra_function==NULL)
66       MV2_Scatter_intra_function=smpi_coll_tuned_scatter_mpich;
67     
68     if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
69       smpi_comm_init_smp(comm);
70     }
71     comm_size = smpi_comm_size(comm);
72     rank = smpi_comm_rank(comm);
73
74     if (((rank == root) && (recvcnt == 0))
75         || ((rank != root) && (sendcnt == 0))) {
76         return MPI_SUCCESS;
77     }
78
79     /* extract the rank,size information for the intra-node
80      * communicator */
81     shmem_comm = smpi_comm_get_intra_comm(comm);
82     local_rank = smpi_comm_rank(shmem_comm);
83     local_size = smpi_comm_size(shmem_comm);
84
85     if (local_rank == 0) {
86         /* Node leader. Extract the rank, size information for the leader
87          * communicator */
88         leader_comm = smpi_comm_get_leaders_comm(comm);
89         leader_comm_size = smpi_comm_size(leader_comm);
90         leader_comm_rank = smpi_comm_rank(leader_comm);
91     }
92
93     if (local_size == comm_size) {
94         /* purely intra-node scatter. Just use the direct algorithm and we are done */
95         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
96                                             recvbuf, recvcnt, recvtype,
97                                             root, comm);
98
99     } else {
100         recvtype_size=smpi_datatype_size(recvtype);
101         sendtype_size=smpi_datatype_size(sendtype);
102
103         if (rank == root) {
104             nbytes = sendcnt * sendtype_size;
105         } else {
106             nbytes = recvcnt * recvtype_size;
107         }
108
109         if (local_rank == 0) {
110             /* Node leader, allocate tmp_buffer */
111             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
112         }
113
114         leader_comm = smpi_comm_get_leaders_comm(comm);
115         int* leaders_map = smpi_comm_get_leaders_map(comm);
116         leader_of_root = smpi_group_rank(smpi_comm_group(comm),leaders_map[root]);
117         leader_root = smpi_group_rank(smpi_comm_group(leader_comm),leaders_map[root]);
118         /* leader_root is the rank of the leader of the root in leader_comm.
119          * leader_root is to be used as the root of the inter-leader gather ops
120          */
121
122         if ((local_rank == 0) && (root != rank)
123             && (leader_of_root == rank)) {
124             /* The root of the scatter operation is not the node leader. Recv
125              * data from the node leader */
126             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
127             smpi_mpi_recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
128                              root, COLL_TAG_SCATTER, comm, &status);
129
130         }
131
132         if (rank == root && local_rank != 0) {
133             /* The root of the scatter operation is not the node leader. Send
134              * data to the node leader */
135             smpi_mpi_send(sendbuf, sendcnt * comm_size, sendtype,
136                                      leader_of_root, COLL_TAG_SCATTER, comm
137                                      );
138         }
139
140         if (leader_comm_size > 1 && local_rank == 0) {
141             if (!smpi_comm_is_uniform(comm)) {
142                 int *displs = NULL;
143                 int *sendcnts = NULL;
144                 int *node_sizes;
145                 int i = 0;
146                 node_sizes = smpi_comm_get_non_uniform_map(comm);
147
148                 if (root != leader_of_root) {
149                     if (leader_comm_rank == leader_root) {
150                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
151                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
152                         sendcnts[0] = node_sizes[0] * nbytes;
153                         displs[0] = 0;
154
155                         for (i = 1; i < leader_comm_size; i++) {
156                             displs[i] =
157                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
158                             sendcnts[i] = node_sizes[i] * nbytes;
159                         }
160                     }
161                         smpi_mpi_scatterv(leader_scatter_buf, sendcnts, displs,
162                                       MPI_BYTE, tmp_buf, nbytes * local_size,
163                                       MPI_BYTE, leader_root, leader_comm);
164                 } else {
165                     if (leader_comm_rank == leader_root) {
166                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
167                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
168                         sendcnts[0] = node_sizes[0] * sendcnt;
169                         displs[0] = 0;
170
171                         for (i = 1; i < leader_comm_size; i++) {
172                             displs[i] =
173                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
174                             sendcnts[i] = node_sizes[i] * sendcnt;
175                         }
176                     }
177                     smpi_mpi_scatterv(sendbuf, sendcnts, displs,
178                                               sendtype, tmp_buf,
179                                               nbytes * local_size, MPI_BYTE,
180                                               leader_root, leader_comm);
181                 }
182                 if (leader_comm_rank == leader_root) {
183                     xbt_free(displs);
184                     xbt_free(sendcnts);
185                 }
186             } else {
187                 if (leader_of_root != root) {
188                     mpi_errno =
189                         MPIR_Scatter_MV2_Direct(leader_scatter_buf,
190                                                 nbytes * local_size, MPI_BYTE,
191                                                 tmp_buf, nbytes * local_size,
192                                                 MPI_BYTE, leader_root,
193                                                 leader_comm);
194                 } else {
195                     mpi_errno =
196                         MPIR_Scatter_MV2_Direct(sendbuf, sendcnt * local_size,
197                                                 sendtype, tmp_buf,
198                                                 nbytes * local_size, MPI_BYTE,
199                                                 leader_root, leader_comm);
200
201                 }
202             }
203         }
204         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
205
206         if (rank == root && recvbuf == MPI_IN_PLACE) {
207             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
208                                                 (void *)sendbuf, sendcnt, sendtype,
209                                                 0, shmem_comm);
210         } else {
211             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
212                                                 recvbuf, recvcnt, recvtype,
213                                                 0, shmem_comm);
214         }
215     }
216
217     /* check if multiple threads are calling this collective function */
218     if (comm_size != local_size && local_rank == 0) {
219         smpi_free_tmp_buffer(tmp_buf);
220         if (leader_of_root == rank && root != rank) {
221             smpi_free_tmp_buffer(leader_scatter_buf);
222         }
223     }
224     return (mpi_errno);
225 }
226
227
228 int smpi_coll_tuned_scatter_mvapich2_two_level_binomial(void *sendbuf,
229                                         int sendcnt,
230                                         MPI_Datatype sendtype,
231                                         void *recvbuf,
232                                         int recvcnt,
233                                         MPI_Datatype recvtype,
234                                         int root, MPI_Comm comm)
235 {
236     int comm_size, rank;
237     int local_rank, local_size;
238     int leader_comm_rank = -1, leader_comm_size = -1;
239     int mpi_errno = MPI_SUCCESS;
240     int recvtype_size, sendtype_size, nbytes;
241     void *tmp_buf = NULL;
242     void *leader_scatter_buf = NULL;
243     MPI_Status status;
244     int leader_root = -1, leader_of_root = -1;
245     MPI_Comm shmem_comm, leader_comm;
246
247
248     //if not set (use of the algo directly, without mvapich2 selector)
249     if(MV2_Scatter_intra_function==NULL)
250       MV2_Scatter_intra_function=smpi_coll_tuned_scatter_mpich;
251     
252     if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
253       smpi_comm_init_smp(comm);
254     }
255     comm_size = smpi_comm_size(comm);
256     rank = smpi_comm_rank(comm);
257
258     if (((rank == root) && (recvcnt == 0))
259         || ((rank != root) && (sendcnt == 0))) {
260         return MPI_SUCCESS;
261     }
262
263     /* extract the rank,size information for the intra-node
264      * communicator */
265     shmem_comm = smpi_comm_get_intra_comm(comm);
266     local_rank = smpi_comm_rank(shmem_comm);
267     local_size = smpi_comm_size(shmem_comm);
268
269     if (local_rank == 0) {
270         /* Node leader. Extract the rank, size information for the leader
271          * communicator */
272         leader_comm = smpi_comm_get_leaders_comm(comm);
273         leader_comm_size = smpi_comm_size(leader_comm);
274         leader_comm_rank = smpi_comm_rank(leader_comm);
275     }
276
277     if (local_size == comm_size) {
278         /* purely intra-node scatter. Just use the direct algorithm and we are done */
279         mpi_errno = MPIR_Scatter_MV2_Direct(sendbuf, sendcnt, sendtype,
280                                             recvbuf, recvcnt, recvtype,
281                                             root, comm);
282
283     } else {
284         recvtype_size=smpi_datatype_size(recvtype);
285         sendtype_size=smpi_datatype_size(sendtype);
286
287         if (rank == root) {
288             nbytes = sendcnt * sendtype_size;
289         } else {
290             nbytes = recvcnt * recvtype_size;
291         }
292
293         if (local_rank == 0) {
294             /* Node leader, allocate tmp_buffer */
295             tmp_buf = smpi_get_tmp_sendbuffer(nbytes * local_size);
296         }
297         leader_comm = smpi_comm_get_leaders_comm(comm);
298         int* leaders_map = smpi_comm_get_leaders_map(comm);
299         leader_of_root = smpi_group_rank(smpi_comm_group(comm),leaders_map[root]);
300         leader_root = smpi_group_rank(smpi_comm_group(leader_comm),leaders_map[root]);
301         /* leader_root is the rank of the leader of the root in leader_comm.
302          * leader_root is to be used as the root of the inter-leader gather ops
303          */
304
305         if ((local_rank == 0) && (root != rank)
306             && (leader_of_root == rank)) {
307             /* The root of the scatter operation is not the node leader. Recv
308              * data from the node leader */
309             leader_scatter_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
310             smpi_mpi_recv(leader_scatter_buf, nbytes * comm_size, MPI_BYTE,
311                              root, COLL_TAG_SCATTER, comm, &status);
312         }
313
314         if (rank == root && local_rank != 0) {
315             /* The root of the scatter operation is not the node leader. Send
316              * data to the node leader */
317             smpi_mpi_send(sendbuf, sendcnt * comm_size, sendtype,
318                                      leader_of_root, COLL_TAG_SCATTER, comm);
319         }
320
321         if (leader_comm_size > 1 && local_rank == 0) {
322             if (!smpi_comm_is_uniform(comm)) {
323                 int *displs = NULL;
324                 int *sendcnts = NULL;
325                 int *node_sizes;
326                 int i = 0;
327                 node_sizes = smpi_comm_get_non_uniform_map(comm);
328
329                 if (root != leader_of_root) {
330                     if (leader_comm_rank == leader_root) {
331                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
332                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
333                         sendcnts[0] = node_sizes[0] * nbytes;
334                         displs[0] = 0;
335
336                         for (i = 1; i < leader_comm_size; i++) {
337                             displs[i] =
338                                 displs[i - 1] + node_sizes[i - 1] * nbytes;
339                             sendcnts[i] = node_sizes[i] * nbytes;
340                         }
341                     }
342                         smpi_mpi_scatterv(leader_scatter_buf, sendcnts, displs,
343                                       MPI_BYTE, tmp_buf, nbytes * local_size,
344                                       MPI_BYTE, leader_root, leader_comm);
345                 } else {
346                     if (leader_comm_rank == leader_root) {
347                         displs = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
348                         sendcnts = static_cast<int*>(xbt_malloc(sizeof (int) * leader_comm_size));
349                         sendcnts[0] = node_sizes[0] * sendcnt;
350                         displs[0] = 0;
351
352                         for (i = 1; i < leader_comm_size; i++) {
353                             displs[i] =
354                                 displs[i - 1] + node_sizes[i - 1] * sendcnt;
355                             sendcnts[i] = node_sizes[i] * sendcnt;
356                         }
357                     }
358                     smpi_mpi_scatterv(sendbuf, sendcnts, displs,
359                                               sendtype, tmp_buf,
360                                               nbytes * local_size, MPI_BYTE,
361                                               leader_root, leader_comm);
362                 }
363                 if (leader_comm_rank == leader_root) {
364                     xbt_free(displs);
365                     xbt_free(sendcnts);
366                 }
367             } else {
368                 if (leader_of_root != root) {
369                     mpi_errno =
370                         MPIR_Scatter_MV2_Binomial(leader_scatter_buf,
371                                                   nbytes * local_size, MPI_BYTE,
372                                                   tmp_buf, nbytes * local_size,
373                                                   MPI_BYTE, leader_root,
374                                                   leader_comm);
375                 } else {
376                     mpi_errno =
377                         MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt * local_size,
378                                                   sendtype, tmp_buf,
379                                                   nbytes * local_size, MPI_BYTE,
380                                                   leader_root, leader_comm);
381
382                 }
383             }
384         }
385         /* The leaders are now done with the inter-leader part. Scatter the data within the nodes */
386
387         if (rank == root && recvbuf == MPI_IN_PLACE) {
388             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
389                                                 (void *)sendbuf, sendcnt, sendtype,
390                                                 0, shmem_comm);
391         } else {
392             mpi_errno = MV2_Scatter_intra_function(tmp_buf, nbytes, MPI_BYTE,
393                                                 recvbuf, recvcnt, recvtype,
394                                                 0, shmem_comm);
395         }
396
397     }
398
399
400     /* check if multiple threads are calling this collective function */
401     if (comm_size != local_size && local_rank == 0) {
402         smpi_free_tmp_buffer(tmp_buf);
403         if (leader_of_root == rank && root != rank) {
404             smpi_free_tmp_buffer(leader_scatter_buf);
405         }
406     }
407
408     return (mpi_errno);
409 }
410