Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
948957cf8169844976e736720cc594f687ce6a68
[simgrid.git] / src / smpi / colls / alltoall-2dmesh.c
1 #include "colls.h"
2 #include <math.h>
3
4 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_colls, smpi,
5                                 "Logging specific to SMPI collectives");
6
7 /*****************************************************************************
8
9  * Function: alltoall_2dmesh_shoot
10
11  * Return: int
12
13  * Inputs:
14     send_buff: send input buffer
15     send_count: number of elements to send
16     send_type: data type of elements being sent
17     recv_buff: receive output buffer
18     recv_count: number of elements to received
19     recv_type: data type of elements being received
20     comm: communicator
21
22  * Descrp: Function realizes the alltoall operation using the 2dmesh
23            algorithm. It actually performs allgather operation in x dimension
24            then in the y dimension. Each node then extracts the needed data.
25            The communication in each dimension follows "simple."
26  
27  * Auther: Ahmad Faraj
28
29 ****************************************************************************/
30 static int alltoall_check_is_2dmesh(int num, int *i, int *j)
31 {
32   int x, max = num / 2;
33   x = sqrt(num);
34
35   while (x <= max) {
36     if ((num % x) == 0) {
37       *i = x;
38       *j = num / x;
39
40       if (*i > *j) {
41         x = *i;
42         *i = *j;
43         *j = x;
44       }
45
46       return 1;
47     }
48     x++;
49   }
50   return 0;
51 }
52
53 int smpi_coll_tuned_alltoall_2dmesh(void *send_buff, int send_count,
54                                     MPI_Datatype send_type,
55                                     void *recv_buff, int recv_count,
56                                     MPI_Datatype recv_type,
57                                     MPI_Comm comm)
58 {
59   MPI_Status *statuses, s;
60   MPI_Request *reqs, *req_ptr;;
61   MPI_Aint extent;
62
63   char *tmp_buff1, *tmp_buff2;
64   int i, j, src, dst, rank, num_procs, count, num_reqs;
65   int X, Y, send_offset, recv_offset;
66   int my_row_base, my_col_base, src_row_base, block_size;
67   int tag = 1, failure = 0, success = 1;
68
69   MPI_Comm_rank(comm, &rank);
70   MPI_Comm_size(comm, &num_procs);
71   MPI_Type_extent(send_type, &extent);
72
73   if (!alltoall_check_is_2dmesh(num_procs, &X, &Y))
74     return failure;
75
76   my_row_base = (rank / Y) * Y;
77   my_col_base = rank % Y;
78
79   block_size = extent * send_count;
80
81   tmp_buff1 = (char *) malloc(block_size * num_procs * Y);
82   if (!tmp_buff1) {
83     XBT_DEBUG("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
84     MPI_Finalize();
85     exit(failure);
86   }
87
88   tmp_buff2 = (char *) malloc(block_size * Y);
89   if (!tmp_buff2) {
90     XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
91     MPI_Finalize();
92     exit(failure);
93   }
94
95
96
97   num_reqs = X;
98   if (Y > X)
99     num_reqs = Y;
100
101   statuses = (MPI_Status *) malloc(num_reqs * sizeof(MPI_Status));
102   reqs = (MPI_Request *) malloc(num_reqs * sizeof(MPI_Request));
103   if (!reqs) {
104     XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
105     MPI_Finalize();
106     exit(failure);
107   }
108
109   req_ptr = reqs;
110
111   send_offset = recv_offset = (rank % Y) * block_size * num_procs;
112
113   count = send_count * num_procs;
114
115   for (i = 0; i < Y; i++) {
116     src = i + my_row_base;
117     if (src == rank)
118       continue;
119
120     recv_offset = (src % Y) * block_size * num_procs;
121     MPI_Irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm,
122               req_ptr++);
123   }
124
125   for (i = 0; i < Y; i++) {
126     dst = i + my_row_base;
127     if (dst == rank)
128       continue;
129     MPI_Send(send_buff, count, send_type, dst, tag, comm);
130   }
131
132   MPI_Waitall(Y - 1, reqs, statuses);
133   req_ptr = reqs;
134
135   for (i = 0; i < Y; i++) {
136     send_offset = (rank * block_size) + (i * block_size * num_procs);
137     recv_offset = (my_row_base * block_size) + (i * block_size);
138
139     if (i + my_row_base == rank)
140       MPI_Sendrecv((char *)send_buff + recv_offset, send_count, send_type,
141                    rank, tag,
142                    (char*)recv_buff + recv_offset, recv_count, recv_type,
143                    rank, tag, comm, &s);
144
145     else
146       MPI_Sendrecv(tmp_buff1 + send_offset, send_count, send_type,
147                    rank, tag,
148                    (char *)recv_buff + recv_offset, recv_count, recv_type,
149                    rank, tag, comm, &s);
150   }
151
152
153   for (i = 0; i < X; i++) {
154     src = (i * Y + my_col_base);
155     if (src == rank)
156       continue;
157     src_row_base = (src / Y) * Y;
158
159     MPI_Irecv((char *)recv_buff + src_row_base * block_size, recv_count * Y,
160               recv_type, src, tag, comm, req_ptr++);
161   }
162
163   for (i = 0; i < X; i++) {
164     dst = (i * Y + my_col_base);
165     if (dst == rank)
166       continue;
167
168     recv_offset = 0;
169     for (j = 0; j < Y; j++) {
170       send_offset = (dst + j * num_procs) * block_size;
171
172       if (j + my_row_base == rank)
173         MPI_Sendrecv((char *)send_buff + dst * block_size, send_count, send_type,
174                      rank, tag,
175                      tmp_buff2 + recv_offset, recv_count, recv_type,
176                      rank, tag, comm, &s);
177       else
178         MPI_Sendrecv(tmp_buff1 + send_offset, send_count, send_type,
179                      rank, tag,
180                      tmp_buff2 + recv_offset, recv_count, recv_type,
181                      rank, tag, comm, &s);
182
183       recv_offset += block_size;
184     }
185
186     MPI_Send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
187   }
188   MPI_Waitall(X - 1, reqs, statuses);
189   free(reqs);
190   free(statuses);
191   free(tmp_buff1);
192   free(tmp_buff2);
193   return success;
194 }