Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
memleaks --
[simgrid.git] / src / smpi / colls / alltoall-2dmesh.c
1 #include "colls_private.h"
2 #include <math.h>
3
4 /*****************************************************************************
5
6  * Function: alltoall_2dmesh_shoot
7
8  * Return: int
9
10  * Inputs:
11     send_buff: send input buffer
12     send_count: number of elements to send
13     send_type: data type of elements being sent
14     recv_buff: receive output buffer
15     recv_count: number of elements to received
16     recv_type: data type of elements being received
17     comm: communicator
18
19  * Descrp: Function realizes the alltoall operation using the 2dmesh
20            algorithm. It actually performs allgather operation in x dimension
21            then in the y dimension. Each node then extracts the needed data.
22            The communication in each dimension follows "simple."
23  
24  * Auther: Ahmad Faraj
25
26 ****************************************************************************/
27 static int alltoall_check_is_2dmesh(int num, int *i, int *j)
28 {
29   int x, max = num / 2;
30   x = sqrt(num);
31
32   while (x <= max) {
33     if ((num % x) == 0) {
34       *i = x;
35       *j = num / x;
36
37       if (*i > *j) {
38         x = *i;
39         *i = *j;
40         *j = x;
41       }
42
43       return 1;
44     }
45     x++;
46   }
47   return 0;
48 }
49
50 int smpi_coll_tuned_alltoall_2dmesh(void *send_buff, int send_count,
51                                     MPI_Datatype send_type,
52                                     void *recv_buff, int recv_count,
53                                     MPI_Datatype recv_type, MPI_Comm comm)
54 {
55   MPI_Status *statuses, s;
56   MPI_Request *reqs, *req_ptr;;
57   MPI_Aint extent;
58
59   char *tmp_buff1, *tmp_buff2;
60   int i, j, src, dst, rank, num_procs, count, num_reqs;
61   int X, Y, send_offset, recv_offset;
62   int my_row_base, my_col_base, src_row_base, block_size;
63   int tag = COLL_TAG_ALLTOALL;
64
65   rank = smpi_comm_rank(comm);
66   num_procs = smpi_comm_size(comm);
67   extent = smpi_datatype_get_extent(send_type);
68
69   if (!alltoall_check_is_2dmesh(num_procs, &X, &Y))
70     return MPI_ERR_OTHER;
71
72   my_row_base = (rank / Y) * Y;
73   my_col_base = rank % Y;
74
75   block_size = extent * send_count;
76
77   tmp_buff1 = (char *) xbt_malloc(block_size * num_procs * Y);
78   tmp_buff2 = (char *) xbt_malloc(block_size * Y);
79
80   num_reqs = X;
81   if (Y > X)
82     num_reqs = Y;
83
84   statuses = (MPI_Status *) xbt_malloc(num_reqs * sizeof(MPI_Status));
85   reqs = (MPI_Request *) xbt_malloc(num_reqs * sizeof(MPI_Request));
86
87   req_ptr = reqs;
88
89   send_offset = recv_offset = (rank % Y) * block_size * num_procs;
90
91   count = send_count * num_procs;
92
93   for (i = 0; i < Y; i++) {
94     src = i + my_row_base;
95     if (src == rank)
96       continue;
97
98     recv_offset = (src % Y) * block_size * num_procs;
99     *(req_ptr++) = smpi_mpi_irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm);
100   }
101
102   for (i = 0; i < Y; i++) {
103     dst = i + my_row_base;
104     if (dst == rank)
105       continue;
106     smpi_mpi_send(send_buff, count, send_type, dst, tag, comm);
107   }
108
109   smpi_mpi_waitall(Y - 1, reqs, statuses);
110   req_ptr = reqs;
111
112   for (i = 0; i < Y; i++) {
113     send_offset = (rank * block_size) + (i * block_size * num_procs);
114     recv_offset = (my_row_base * block_size) + (i * block_size);
115
116     if (i + my_row_base == rank)
117       smpi_mpi_sendrecv((char *) send_buff + recv_offset, send_count, send_type,
118                    rank, tag,
119                    (char *) recv_buff + recv_offset, recv_count, recv_type,
120                    rank, tag, comm, &s);
121
122     else
123       smpi_mpi_sendrecv(tmp_buff1 + send_offset, send_count, send_type,
124                    rank, tag,
125                    (char *) recv_buff + recv_offset, recv_count, recv_type,
126                    rank, tag, comm, &s);
127   }
128
129
130   for (i = 0; i < X; i++) {
131     src = (i * Y + my_col_base);
132     if (src == rank)
133       continue;
134     src_row_base = (src / Y) * Y;
135
136     *(req_ptr++) = smpi_mpi_irecv((char *) recv_buff + src_row_base * block_size, recv_count * Y,
137               recv_type, src, tag, comm);
138   }
139
140   for (i = 0; i < X; i++) {
141     dst = (i * Y + my_col_base);
142     if (dst == rank)
143       continue;
144
145     recv_offset = 0;
146     for (j = 0; j < Y; j++) {
147       send_offset = (dst + j * num_procs) * block_size;
148
149       if (j + my_row_base == rank)
150         smpi_mpi_sendrecv((char *) send_buff + dst * block_size, send_count,
151                      send_type, rank, tag, tmp_buff2 + recv_offset, recv_count,
152                      recv_type, rank, tag, comm, &s);
153       else
154         smpi_mpi_sendrecv(tmp_buff1 + send_offset, send_count, send_type,
155                      rank, tag,
156                      tmp_buff2 + recv_offset, recv_count, recv_type,
157                      rank, tag, comm, &s);
158
159       recv_offset += block_size;
160     }
161
162     smpi_mpi_send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
163   }
164   smpi_mpi_waitall(X - 1, reqs, statuses);
165   free(reqs);
166   free(statuses);
167   free(tmp_buff1);
168   free(tmp_buff2);
169   return MPI_SUCCESS;
170 }