Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new tracing options to reduce the size of the traces
[simgrid.git] / src / smpi / colls / alltoall-2dmesh.c
1 #include "colls.h"
2 #include <math.h>
3 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_colls, smpi,
4                                 "Logging specific to SMPI collectives");
5
6 /*****************************************************************************
7
8  * Function: alltoall_2dmesh_shoot
9
10  * Return: int
11
12  * Inputs:
13     send_buff: send input buffer
14     send_count: number of elements to send
15     send_type: data type of elements being sent
16     recv_buff: receive output buffer
17     recv_count: number of elements to received
18     recv_type: data type of elements being received
19     comm: communicator
20
21  * Descrp: Function realizes the alltoall operation using the 2dmesh
22            algorithm. It actually performs allgather operation in x dimension
23            then in the y dimension. Each node then extracts the needed data.
24            The communication in each dimension follows "simple."
25  
26  * Auther: Ahmad Faraj
27
28 ****************************************************************************/
29 int alltoall_check_is_2dmesh(int num, int * i, int * j)
30 {
31   int x, max = num / 2;
32   x = sqrt(num);  
33
34   while (x <= max)
35     {
36       if ((num % x) == 0)
37         {
38           * i = x;
39           * j = num / x;
40        
41           if (* i > * j)
42             {
43               x = * i;
44               * i = * j;
45               * j = x;
46             }
47        
48           return 1;
49         }
50       x++;
51     }
52   return 0;
53 }
54
55 int
56 smpi_coll_tuned_alltoall_2dmesh(void * send_buff, int send_count, MPI_Datatype send_type,
57                       void * recv_buff, int recv_count, MPI_Datatype recv_type,
58                       MPI_Comm comm)
59 {
60   MPI_Status * statuses, s;
61   MPI_Request * reqs, * req_ptr;;
62   MPI_Aint extent;
63
64   char * tmp_buff1, * tmp_buff2;
65   int i, j, src, dst, rank, num_procs, count, num_reqs;
66   int rows, cols, my_row, my_col, X, Y, send_offset, recv_offset;
67   int two_dsize, my_row_base, my_col_base, src_row_base, block_size;
68   int tag = 1, failure = 0, success = 1;
69   
70   MPI_Comm_rank(comm, &rank);
71   MPI_Comm_size(comm, &num_procs);
72   MPI_Type_extent(send_type, &extent);
73
74   if (!alltoall_check_is_2dmesh(num_procs, &X, &Y))
75     return failure;
76
77   two_dsize = X * Y;
78
79   my_row_base = (rank / Y) * Y;
80   my_col_base = rank % Y;
81
82   block_size = extent * send_count;
83   
84   tmp_buff1 =(char *) malloc(block_size * num_procs * Y);
85   if (!tmp_buff1)
86     {
87       XBT_DEBUG("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
88       MPI_Finalize();
89       exit(failure);
90     }  
91
92   tmp_buff2 =(char *) malloc(block_size *  Y);  
93   if (!tmp_buff2)
94     {
95       XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
96       MPI_Finalize();
97       exit(failure);
98     }
99   
100
101
102   num_reqs = X;
103   if (Y > X) num_reqs = Y;
104
105   statuses = (MPI_Status *) malloc(num_reqs * sizeof(MPI_Status));  
106   reqs = (MPI_Request *) malloc(num_reqs * sizeof(MPI_Request));  
107   if (!reqs)
108     {
109       XBT_WARN("alltoall-2dmesh_shoot.c:88: cannot allocate memory");
110       MPI_Finalize();
111       exit(failure);
112     }
113   
114   req_ptr = reqs;
115
116   send_offset = recv_offset = (rank % Y) * block_size * num_procs;
117
118   count = send_count * num_procs;
119   
120   for (i = 0; i < Y; i++)
121     {
122       src = i + my_row_base;
123       if (src == rank)
124         continue;
125
126       recv_offset = (src % Y) * block_size * num_procs;
127        MPI_Irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm,
128                  req_ptr++);
129     }
130   
131   for (i = 0; i < Y; i++)
132     {
133       dst = i + my_row_base;
134       if (dst == rank)
135         continue;
136        MPI_Send(send_buff, count, send_type, dst, tag, comm);
137     }
138   
139   MPI_Waitall(Y - 1, reqs, statuses);
140   req_ptr = reqs;
141   
142   for (i = 0; i < Y; i++)
143     {
144       send_offset = (rank * block_size) + (i * block_size * num_procs);
145       recv_offset = (my_row_base * block_size) + (i * block_size);
146      
147       if (i + my_row_base == rank) 
148         MPI_Sendrecv (send_buff + recv_offset, send_count, send_type,
149                       rank, tag, recv_buff + recv_offset, recv_count,
150                       recv_type, rank, tag, comm, &s);
151
152      else 
153          MPI_Sendrecv (tmp_buff1 + send_offset, send_count, send_type,
154                        rank, tag, 
155                        recv_buff + recv_offset, recv_count, recv_type,
156                        rank, tag, comm, &s);
157     }
158
159   
160   for (i = 0; i < X; i++)
161   {
162      src = (i * Y + my_col_base);
163      if (src == rank)
164        continue;
165      src_row_base = (src / Y) * Y;
166
167       MPI_Irecv(recv_buff + src_row_base * block_size, recv_count * Y,
168                 recv_type, src, tag, comm, req_ptr++);
169   }
170   
171   for (i = 0; i < X; i++)
172     {
173       dst = (i * Y + my_col_base);
174       if (dst == rank)
175         continue;
176
177       recv_offset = 0;
178       for (j = 0; j < Y; j++)
179         {
180           send_offset = (dst + j * num_procs) * block_size;
181
182           if (j + my_row_base == rank) 
183              MPI_Sendrecv (send_buff + dst * block_size, send_count, send_type,
184                            rank, tag, 
185                            tmp_buff2 + recv_offset, recv_count, recv_type,
186                            rank, tag, comm, &s);
187           else
188              MPI_Sendrecv (tmp_buff1 + send_offset, send_count, send_type,
189                            rank, tag, 
190                            tmp_buff2 + recv_offset, recv_count, recv_type,
191                            rank, tag, comm, &s);
192           
193           recv_offset += block_size;
194         }
195
196        MPI_Send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
197     }
198   MPI_Waitall(X - 1, reqs, statuses);
199   free(reqs);
200   free(statuses);
201   free(tmp_buff1);
202   free(tmp_buff2);  
203   return success;
204 }