Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new entry in Release_Notes.
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-two-level.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2009 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  *
19  * Additional copyrights may follow
20  */
21  /* -*- Mode: C; c-basic-offset:4 ; -*- */
22 /* Copyright (c) 2001-2014, The Ohio State University. All rights
23  * reserved.
24  *
25  * This file is part of the MVAPICH2 software package developed by the
26  * team members of The Ohio State University's Network-Based Computing
27  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
28  *
29  * For detailed copyright and licensing information, please refer to the
30  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
31  */
32 /*
33  *
34  *  (C) 2001 by Argonne National Laboratory.
35  *      See COPYRIGHT in top-level directory.
36  */
37
38 #include "../colls_private.hpp"
39
40 #define MPIR_Allreduce_pt2pt_rd_MV2 allreduce__rdb
41 #define MPIR_Allreduce_pt2pt_rs_MV2 allreduce__mvapich2_rs
42
43 extern int (*MV2_Allreducection)(const void *sendbuf,
44     void *recvbuf,
45     int count,
46     MPI_Datatype datatype,
47     MPI_Op op, MPI_Comm comm);
48
49
50 extern int (*MV2_Allreduce_intra_function)(const void *sendbuf,
51     void *recvbuf,
52     int count,
53     MPI_Datatype datatype,
54     MPI_Op op, MPI_Comm comm);
55
56 namespace simgrid::smpi {
57 static  int MPIR_Allreduce_reduce_p2p_MV2(const void *sendbuf,
58     void *recvbuf,
59     int count,
60     MPI_Datatype datatype,
61     MPI_Op op, MPI_Comm  comm)
62 {
63   colls::reduce(sendbuf, recvbuf, count, datatype, op, 0, comm);
64   return MPI_SUCCESS;
65 }
66
67 static  int MPIR_Allreduce_reduce_shmem_MV2(const void *sendbuf,
68     void *recvbuf,
69     int count,
70     MPI_Datatype datatype,
71     MPI_Op op, MPI_Comm  comm)
72 {
73   colls::reduce(sendbuf, recvbuf, count, datatype, op, 0, comm);
74   return MPI_SUCCESS;
75 }
76
77
78 /* general two level allreduce helper function */
79 int allreduce__mvapich2_two_level(const void *sendbuf,
80                              void *recvbuf,
81                              int count,
82                              MPI_Datatype datatype,
83                              MPI_Op op, MPI_Comm comm)
84 {
85     int mpi_errno = MPI_SUCCESS;
86     int total_size = 0;
87     MPI_Aint true_lb, true_extent;
88     MPI_Comm shmem_comm = MPI_COMM_NULL, leader_comm = MPI_COMM_NULL;
89     int local_rank = -1, local_size = 0;
90
91     //if not set (use of the algo directly, without mvapich2 selector)
92     if (MV2_Allreduce_intra_function == nullptr)
93       MV2_Allreduce_intra_function = allreduce__mpich;
94     if (MV2_Allreducection == nullptr)
95       MV2_Allreducection = allreduce__rdb;
96
97     if(comm->get_leaders_comm()==MPI_COMM_NULL){
98       comm->init_smp();
99     }
100
101     if (count == 0) {
102         return MPI_SUCCESS;
103     }
104     datatype->extent(&true_lb,
105                                        &true_extent);
106
107     total_size = comm->size();
108     shmem_comm = comm->get_intra_comm();
109     local_rank = shmem_comm->rank();
110     local_size = shmem_comm->size();
111
112     leader_comm = comm->get_leaders_comm();
113
114     if (local_rank == 0) {
115         if (sendbuf != MPI_IN_PLACE) {
116             Datatype::copy(sendbuf, count, datatype, recvbuf,
117                                        count, datatype);
118         }
119     }
120
121     /* Doing the shared memory gather and reduction by the leader */
122     if (local_rank == 0) {
123         if ((MV2_Allreduce_intra_function == &MPIR_Allreduce_reduce_shmem_MV2) ||
124               (MV2_Allreduce_intra_function == &MPIR_Allreduce_reduce_p2p_MV2) ) {
125         mpi_errno =
126         MV2_Allreduce_intra_function(sendbuf, recvbuf, count, datatype,
127                                      op, comm);
128         }
129         else {
130         mpi_errno =
131         MV2_Allreduce_intra_function(sendbuf, recvbuf, count, datatype,
132                                      op, shmem_comm);
133         }
134
135         if (local_size != total_size) {
136           unsigned char* sendtmpbuf = smpi_get_tmp_sendbuffer(count * datatype->get_extent());
137           Datatype::copy(recvbuf, count, datatype,sendtmpbuf, count, datatype);
138             /* inter-node allreduce */
139             if(MV2_Allreducection == &MPIR_Allreduce_pt2pt_rd_MV2){
140                 mpi_errno =
141                     MPIR_Allreduce_pt2pt_rd_MV2(sendtmpbuf, recvbuf, count, datatype, op,
142                                       leader_comm);
143             } else {
144                 mpi_errno =
145                     MPIR_Allreduce_pt2pt_rs_MV2(sendtmpbuf, recvbuf, count, datatype, op,
146                                       leader_comm);
147             }
148             smpi_free_tmp_buffer(sendtmpbuf);
149         }
150     } else {
151         /* insert the first reduce here */
152         if ((MV2_Allreduce_intra_function == &MPIR_Allreduce_reduce_shmem_MV2) ||
153               (MV2_Allreduce_intra_function == &MPIR_Allreduce_reduce_p2p_MV2) ) {
154         mpi_errno =
155         MV2_Allreduce_intra_function(sendbuf, recvbuf, count, datatype,
156                                      op, comm);
157         }
158         else {
159         mpi_errno =
160         MV2_Allreduce_intra_function(sendbuf, recvbuf, count, datatype,
161                                      op, shmem_comm);
162         }
163     }
164
165     /* Broadcasting the message from leader to the rest */
166     /* Note: shared memory broadcast could improve the performance */
167     mpi_errno = colls::bcast(recvbuf, count, datatype, 0, shmem_comm);
168
169     return (mpi_errno);
170
171 }
172 } // namespace simgrid::smpi