Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2022.
[simgrid.git] / src / smpi / colls / alltoall / alltoall-2dmesh.cpp
1 /* Copyright (c) 2013-2022. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 #include <cmath>
9
10 /*****************************************************************************
11
12  * Function: alltoall_2dmesh_shoot
13
14  * Return: int
15
16  * Inputs:
17     send_buff: send input buffer
18     send_count: number of elements to send
19     send_type: data type of elements being sent
20     recv_buff: receive output buffer
21     recv_count: number of elements to received
22     recv_type: data type of elements being received
23     comm: communicator
24
25  * Descrp: Function realizes the alltoall operation using the 2dmesh
26            algorithm. It actually performs allgather operation in x dimension
27            then in the y dimension. Each node then extracts the needed data.
28            The communication in each dimension follows "simple."
29
30  * Author: Ahmad Faraj
31
32 ****************************************************************************/
33 static int alltoall_check_is_2dmesh(int num, int *i, int *j)
34 {
35   int x, max = num / 2;
36   x = sqrt(double(num));
37
38   while (x <= max) {
39     if ((num % x) == 0) {
40       *i = x;
41       *j = num / x;
42
43       if (*i > *j) {
44         x = *i;
45         *i = *j;
46         *j = x;
47       }
48
49       return 1;
50     }
51     x++;
52   }
53   return 0;
54 }
55 namespace simgrid{
56 namespace smpi{
57
58 int alltoall__2dmesh(const void *send_buff, int send_count,
59                      MPI_Datatype send_type,
60                      void *recv_buff, int recv_count,
61                      MPI_Datatype recv_type, MPI_Comm comm)
62 {
63   MPI_Status s;
64   MPI_Aint extent;
65
66   int i, j, src, dst, rank, num_procs, count, num_reqs;
67   int X, Y, send_offset, recv_offset;
68   int my_row_base, my_col_base, src_row_base, block_size;
69   int tag = COLL_TAG_ALLTOALL;
70
71   rank = comm->rank();
72   num_procs = comm->size();
73   extent = send_type->get_extent();
74
75   if (not alltoall_check_is_2dmesh(num_procs, &X, &Y))
76     return MPI_ERR_OTHER;
77
78   my_row_base = (rank / Y) * Y;
79   my_col_base = rank % Y;
80
81   block_size = extent * send_count;
82
83   unsigned char* tmp_buff1 = smpi_get_tmp_sendbuffer(block_size * num_procs * Y);
84   unsigned char* tmp_buff2 = smpi_get_tmp_recvbuffer(block_size * Y);
85
86   num_reqs = X;
87   if (Y > X)
88     num_reqs = Y;
89
90   auto* statuses       = new MPI_Status[num_reqs];
91   auto* reqs           = new MPI_Request[num_reqs];
92   MPI_Request* req_ptr = reqs;
93
94   count = send_count * num_procs;
95
96   for (i = 0; i < Y; i++) {
97     src = i + my_row_base;
98     if (src == rank)
99       continue;
100
101     recv_offset = (src % Y) * block_size * num_procs;
102     *(req_ptr++) = Request::irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm);
103   }
104
105   for (i = 0; i < Y; i++) {
106     dst = i + my_row_base;
107     if (dst == rank)
108       continue;
109     Request::send(send_buff, count, send_type, dst, tag, comm);
110   }
111
112   Request::waitall(Y - 1, reqs, statuses);
113   req_ptr = reqs;
114
115   for (i = 0; i < Y; i++) {
116     send_offset = (rank * block_size) + (i * block_size * num_procs);
117     recv_offset = (my_row_base * block_size) + (i * block_size);
118
119     if (i + my_row_base == rank)
120       Request::sendrecv((char *) send_buff + recv_offset, send_count, send_type,
121                    rank, tag,
122                    (char *) recv_buff + recv_offset, recv_count, recv_type,
123                    rank, tag, comm, &s);
124
125     else
126       Request::sendrecv(tmp_buff1 + send_offset, send_count, send_type,
127                    rank, tag,
128                    (char *) recv_buff + recv_offset, recv_count, recv_type,
129                    rank, tag, comm, &s);
130   }
131
132
133   for (i = 0; i < X; i++) {
134     src = (i * Y + my_col_base);
135     if (src == rank)
136       continue;
137     src_row_base = (src / Y) * Y;
138
139     *(req_ptr++) = Request::irecv((char *) recv_buff + src_row_base * block_size, recv_count * Y,
140               recv_type, src, tag, comm);
141   }
142
143   for (i = 0; i < X; i++) {
144     dst = (i * Y + my_col_base);
145     if (dst == rank)
146       continue;
147
148     recv_offset = 0;
149     for (j = 0; j < Y; j++) {
150       send_offset = (dst + j * num_procs) * block_size;
151
152       if (j + my_row_base == rank)
153         Request::sendrecv((char *) send_buff + dst * block_size, send_count,
154                      send_type, rank, tag, tmp_buff2 + recv_offset, recv_count,
155                      recv_type, rank, tag, comm, &s);
156       else
157         Request::sendrecv(tmp_buff1 + send_offset, send_count, send_type,
158                      rank, tag,
159                      tmp_buff2 + recv_offset, recv_count, recv_type,
160                      rank, tag, comm, &s);
161
162       recv_offset += block_size;
163     }
164
165     Request::send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
166   }
167   Request::waitall(X - 1, reqs, statuses);
168   delete[] reqs;
169   delete[] statuses;
170   smpi_free_tmp_buffer(tmp_buff1);
171   smpi_free_tmp_buffer(tmp_buff2);
172   return MPI_SUCCESS;
173 }
174 }
175 }