Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add Allgather SMP collective from MVAPICH2
[simgrid.git] / src / smpi / colls / allgather-mvapich-smp.c
1 #include "colls_private.h"
2
3
4
5 int smpi_coll_tuned_allgather_mvapich2_smp(void *sendbuf,int sendcnt, MPI_Datatype sendtype,
6                             void *recvbuf, int recvcnt,MPI_Datatype recvtype,
7                             MPI_Comm  comm)
8 {
9     int rank, size;
10     int local_rank, local_size;
11     int leader_comm_size = 0; 
12     int mpi_errno = MPI_SUCCESS;
13     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
14     MPI_Comm shmem_comm, leader_comm;
15
16   if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
17     smpi_comm_init_smp(comm);
18   }
19   
20     if(!smpi_comm_is_uniform(comm) || !smpi_comm_is_blocked(comm))
21     THROWF(arg_error,0, "allgather MVAPICH2 smp algorithm can't be used with irregular deployment. Please insure that processes deployed on the same node are contiguous and that each node has the same number of processes");
22   
23     if (recvcnt == 0) {
24         return MPI_SUCCESS;
25     }
26
27     rank = smpi_comm_rank(comm);
28     size = smpi_comm_size(comm);
29
30     /* extract the rank,size information for the intra-node
31      * communicator */
32     recvtype_extent=smpi_datatype_get_extent(recvtype);
33     
34     shmem_comm = smpi_comm_get_intra_comm(comm);
35     local_rank = smpi_comm_rank(shmem_comm);
36     local_size = smpi_comm_size(shmem_comm);
37
38     if (local_rank == 0) {
39         /* Node leader. Extract the rank, size information for the leader
40          * communicator */
41         leader_comm = smpi_comm_get_leaders_comm(comm);
42         if(leader_comm==MPI_COMM_NULL){
43           leader_comm = MPI_COMM_WORLD;
44         }
45         leader_comm_size = smpi_comm_size(leader_comm);
46     }
47
48     /*If there is just one node, after gather itself,
49      * root has all the data and it can do bcast*/
50     if(local_rank == 0) {
51         mpi_errno = mpi_coll_gather_fun(sendbuf, sendcnt,sendtype, 
52                                     (void*)((char*)recvbuf + (rank * recvcnt * recvtype_extent)), 
53                                      recvcnt, recvtype,
54                                      0, shmem_comm);
55     } else {
56         /*Since in allgather all the processes could have 
57          * its own data in place*/
58         if(sendbuf == MPI_IN_PLACE) {
59             mpi_errno = mpi_coll_gather_fun((void*)((char*)recvbuf + (rank * recvcnt * recvtype_extent)), 
60                                          recvcnt , recvtype, 
61                                          recvbuf, recvcnt, recvtype,
62                                          0, shmem_comm);
63         } else {
64             mpi_errno = mpi_coll_gather_fun(sendbuf, sendcnt,sendtype, 
65                                          recvbuf, recvcnt, recvtype,
66                                          0, shmem_comm);
67         }
68     }
69     /* Exchange the data between the node leaders*/
70     if (local_rank == 0 && (leader_comm_size > 1)) {
71         /*When data in each socket is different*/
72         if (smpi_comm_is_uniform(comm) != 1) {
73
74             int *displs = NULL;
75             int *recvcnts = NULL;
76             int *node_sizes = NULL;
77             int i = 0;
78
79             node_sizes = smpi_comm_get_non_uniform_map(comm);
80
81             displs = xbt_malloc(sizeof (int) * leader_comm_size);
82             recvcnts = xbt_malloc(sizeof (int) * leader_comm_size);
83             if (!displs || !recvcnts) {
84                 return MPI_ERR_OTHER;
85             }
86             recvcnts[0] = node_sizes[0] * recvcnt;
87             displs[0] = 0;
88
89             for (i = 1; i < leader_comm_size; i++) {
90                 displs[i] = displs[i - 1] + node_sizes[i - 1] * recvcnt;
91                 recvcnts[i] = node_sizes[i] * recvcnt;
92             }
93
94
95             void* sendbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*displs[smpi_comm_rank(leader_comm)];
96
97             mpi_errno = mpi_coll_allgatherv_fun(sendbuf,
98                                        (recvcnt*local_size),
99                                        recvtype, 
100                                        recvbuf, recvcnts,
101                                        displs, recvtype,
102                                        leader_comm);
103             xbt_free(displs);
104             xbt_free(recvcnts);
105         } else {
106         void* sendtmpbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*(recvcnt*local_size)*smpi_comm_rank(leader_comm);
107         
108           
109
110             mpi_errno = smpi_coll_tuned_allgather_mpich(sendtmpbuf, 
111                                                (recvcnt*local_size),
112                                                recvtype,
113                                                recvbuf, (recvcnt*local_size), recvtype,
114                                              leader_comm);
115
116         }
117     }
118
119     /*Bcast the entire data from node leaders to all other cores*/
120     mpi_errno = mpi_coll_bcast_fun (recvbuf, recvcnt * size, recvtype, 0, shmem_comm);
121     return mpi_errno;
122 }