Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' into hypervisor
[simgrid.git] / src / smpi / colls / alltoall-2dmesh.c
1 #include "colls.h"
2 #include <math.h>
3
4 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_colls, smpi,
5                                 "Logging specific to SMPI collectives");
6
7 /*****************************************************************************
8
9  * Function: alltoall_2dmesh_shoot
10
11  * Return: int
12
13  * Inputs:
14     send_buff: send input buffer
15     send_count: number of elements to send
16     send_type: data type of elements being sent
17     recv_buff: receive output buffer
18     recv_count: number of elements to received
19     recv_type: data type of elements being received
20     comm: communicator
21
22  * Descrp: Function realizes the alltoall operation using the 2dmesh
23            algorithm. It actually performs allgather operation in x dimension
24            then in the y dimension. Each node then extracts the needed data.
25            The communication in each dimension follows "simple."
26  
27  * Auther: Ahmad Faraj
28
29 ****************************************************************************/
30 static int alltoall_check_is_2dmesh(int num, int *i, int *j)
31 {
32   int x, max = num / 2;
33   x = sqrt(num);
34
35   while (x <= max) {
36     if ((num % x) == 0) {
37       *i = x;
38       *j = num / x;
39
40       if (*i > *j) {
41         x = *i;
42         *i = *j;
43         *j = x;
44       }
45
46       return 1;
47     }
48     x++;
49   }
50   return 0;
51 }
52
53 int smpi_coll_tuned_alltoall_2dmesh(void *send_buff, int send_count,
54                                     MPI_Datatype send_type,
55                                     void *recv_buff, int recv_count,
56                                     MPI_Datatype recv_type, MPI_Comm comm)
57 {
58   MPI_Status *statuses, s;
59   MPI_Request *reqs, *req_ptr;;
60   MPI_Aint extent;
61
62   char *tmp_buff1, *tmp_buff2;
63   int i, j, src, dst, rank, num_procs, count, num_reqs;
64   int X, Y, send_offset, recv_offset;
65   int my_row_base, my_col_base, src_row_base, block_size;
66   int tag = 1, failure = 0, success = 1;
67
68   MPI_Comm_rank(comm, &rank);
69   MPI_Comm_size(comm, &num_procs);
70   MPI_Type_extent(send_type, &extent);
71
72   if (!alltoall_check_is_2dmesh(num_procs, &X, &Y))
73     return failure;
74
75   my_row_base = (rank / Y) * Y;
76   my_col_base = rank % Y;
77
78   block_size = extent * send_count;
79
80   tmp_buff1 = (char *) malloc(block_size * num_procs * Y);
81   if (!tmp_buff1) {
82     XBT_DEBUG("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
83     MPI_Finalize();
84     exit(failure);
85   }
86
87   tmp_buff2 = (char *) malloc(block_size * Y);
88   if (!tmp_buff2) {
89     XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
90     MPI_Finalize();
91     exit(failure);
92   }
93
94
95
96   num_reqs = X;
97   if (Y > X)
98     num_reqs = Y;
99
100   statuses = (MPI_Status *) malloc(num_reqs * sizeof(MPI_Status));
101   reqs = (MPI_Request *) malloc(num_reqs * sizeof(MPI_Request));
102   if (!reqs) {
103     XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
104     MPI_Finalize();
105     exit(failure);
106   }
107
108   req_ptr = reqs;
109
110   send_offset = recv_offset = (rank % Y) * block_size * num_procs;
111
112   count = send_count * num_procs;
113
114   for (i = 0; i < Y; i++) {
115     src = i + my_row_base;
116     if (src == rank)
117       continue;
118
119     recv_offset = (src % Y) * block_size * num_procs;
120     MPI_Irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm,
121               req_ptr++);
122   }
123
124   for (i = 0; i < Y; i++) {
125     dst = i + my_row_base;
126     if (dst == rank)
127       continue;
128     MPI_Send(send_buff, count, send_type, dst, tag, comm);
129   }
130
131   MPI_Waitall(Y - 1, reqs, statuses);
132   req_ptr = reqs;
133
134   for (i = 0; i < Y; i++) {
135     send_offset = (rank * block_size) + (i * block_size * num_procs);
136     recv_offset = (my_row_base * block_size) + (i * block_size);
137
138     if (i + my_row_base == rank)
139       MPI_Sendrecv((char *) send_buff + recv_offset, send_count, send_type,
140                    rank, tag,
141                    (char *) recv_buff + recv_offset, recv_count, recv_type,
142                    rank, tag, comm, &s);
143
144     else
145       MPI_Sendrecv(tmp_buff1 + send_offset, send_count, send_type,
146                    rank, tag,
147                    (char *) recv_buff + recv_offset, recv_count, recv_type,
148                    rank, tag, comm, &s);
149   }
150
151
152   for (i = 0; i < X; i++) {
153     src = (i * Y + my_col_base);
154     if (src == rank)
155       continue;
156     src_row_base = (src / Y) * Y;
157
158     MPI_Irecv((char *) recv_buff + src_row_base * block_size, recv_count * Y,
159               recv_type, src, tag, comm, req_ptr++);
160   }
161
162   for (i = 0; i < X; i++) {
163     dst = (i * Y + my_col_base);
164     if (dst == rank)
165       continue;
166
167     recv_offset = 0;
168     for (j = 0; j < Y; j++) {
169       send_offset = (dst + j * num_procs) * block_size;
170
171       if (j + my_row_base == rank)
172         MPI_Sendrecv((char *) send_buff + dst * block_size, send_count,
173                      send_type, rank, tag, tmp_buff2 + recv_offset, recv_count,
174                      recv_type, rank, tag, comm, &s);
175       else
176         MPI_Sendrecv(tmp_buff1 + send_offset, send_count, send_type,
177                      rank, tag,
178                      tmp_buff2 + recv_offset, recv_count, recv_type,
179                      rank, tag, comm, &s);
180
181       recv_offset += block_size;
182     }
183
184     MPI_Send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
185   }
186   MPI_Waitall(X - 1, reqs, statuses);
187   free(reqs);
188   free(statuses);
189   free(tmp_buff1);
190   free(tmp_buff2);
191   return success;
192 }