Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Comm -> C++
[simgrid.git] / src / smpi / colls / alltoall-2dmesh.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 #include <math.h>
9
10 /*****************************************************************************
11
12  * Function: alltoall_2dmesh_shoot
13
14  * Return: int
15
16  * Inputs:
17     send_buff: send input buffer
18     send_count: number of elements to send
19     send_type: data type of elements being sent
20     recv_buff: receive output buffer
21     recv_count: number of elements to received
22     recv_type: data type of elements being received
23     comm: communicator
24
25  * Descrp: Function realizes the alltoall operation using the 2dmesh
26            algorithm. It actually performs allgather operation in x dimension
27            then in the y dimension. Each node then extracts the needed data.
28            The communication in each dimension follows "simple."
29  
30  * Auther: Ahmad Faraj
31
32 ****************************************************************************/
33 static int alltoall_check_is_2dmesh(int num, int *i, int *j)
34 {
35   int x, max = num / 2;
36   x = sqrt(num);
37
38   while (x <= max) {
39     if ((num % x) == 0) {
40       *i = x;
41       *j = num / x;
42
43       if (*i > *j) {
44         x = *i;
45         *i = *j;
46         *j = x;
47       }
48
49       return 1;
50     }
51     x++;
52   }
53   return 0;
54 }
55
56 int smpi_coll_tuned_alltoall_2dmesh(void *send_buff, int send_count,
57                                     MPI_Datatype send_type,
58                                     void *recv_buff, int recv_count,
59                                     MPI_Datatype recv_type, MPI_Comm comm)
60 {
61   MPI_Status *statuses, s;
62   MPI_Request *reqs, *req_ptr;;
63   MPI_Aint extent;
64
65   char *tmp_buff1, *tmp_buff2;
66   int i, j, src, dst, rank, num_procs, count, num_reqs;
67   int X, Y, send_offset, recv_offset;
68   int my_row_base, my_col_base, src_row_base, block_size;
69   int tag = COLL_TAG_ALLTOALL;
70
71   rank = comm->rank();
72   num_procs = comm->size();
73   extent = smpi_datatype_get_extent(send_type);
74
75   if (!alltoall_check_is_2dmesh(num_procs, &X, &Y))
76     return MPI_ERR_OTHER;
77
78   my_row_base = (rank / Y) * Y;
79   my_col_base = rank % Y;
80
81   block_size = extent * send_count;
82
83   tmp_buff1 = (char *) smpi_get_tmp_sendbuffer(block_size * num_procs * Y);
84   tmp_buff2 = (char *) smpi_get_tmp_recvbuffer(block_size * Y);
85
86   num_reqs = X;
87   if (Y > X)
88     num_reqs = Y;
89
90   statuses = (MPI_Status *) xbt_malloc(num_reqs * sizeof(MPI_Status));
91   reqs = (MPI_Request *) xbt_malloc(num_reqs * sizeof(MPI_Request));
92
93   req_ptr = reqs;
94
95   count = send_count * num_procs;
96
97   for (i = 0; i < Y; i++) {
98     src = i + my_row_base;
99     if (src == rank)
100       continue;
101
102     recv_offset = (src % Y) * block_size * num_procs;
103     *(req_ptr++) = smpi_mpi_irecv(tmp_buff1 + recv_offset, count, recv_type, src, tag, comm);
104   }
105
106   for (i = 0; i < Y; i++) {
107     dst = i + my_row_base;
108     if (dst == rank)
109       continue;
110     smpi_mpi_send(send_buff, count, send_type, dst, tag, comm);
111   }
112
113   smpi_mpi_waitall(Y - 1, reqs, statuses);
114   req_ptr = reqs;
115
116   for (i = 0; i < Y; i++) {
117     send_offset = (rank * block_size) + (i * block_size * num_procs);
118     recv_offset = (my_row_base * block_size) + (i * block_size);
119
120     if (i + my_row_base == rank)
121       smpi_mpi_sendrecv((char *) send_buff + recv_offset, send_count, send_type,
122                    rank, tag,
123                    (char *) recv_buff + recv_offset, recv_count, recv_type,
124                    rank, tag, comm, &s);
125
126     else
127       smpi_mpi_sendrecv(tmp_buff1 + send_offset, send_count, send_type,
128                    rank, tag,
129                    (char *) recv_buff + recv_offset, recv_count, recv_type,
130                    rank, tag, comm, &s);
131   }
132
133
134   for (i = 0; i < X; i++) {
135     src = (i * Y + my_col_base);
136     if (src == rank)
137       continue;
138     src_row_base = (src / Y) * Y;
139
140     *(req_ptr++) = smpi_mpi_irecv((char *) recv_buff + src_row_base * block_size, recv_count * Y,
141               recv_type, src, tag, comm);
142   }
143
144   for (i = 0; i < X; i++) {
145     dst = (i * Y + my_col_base);
146     if (dst == rank)
147       continue;
148
149     recv_offset = 0;
150     for (j = 0; j < Y; j++) {
151       send_offset = (dst + j * num_procs) * block_size;
152
153       if (j + my_row_base == rank)
154         smpi_mpi_sendrecv((char *) send_buff + dst * block_size, send_count,
155                      send_type, rank, tag, tmp_buff2 + recv_offset, recv_count,
156                      recv_type, rank, tag, comm, &s);
157       else
158         smpi_mpi_sendrecv(tmp_buff1 + send_offset, send_count, send_type,
159                      rank, tag,
160                      tmp_buff2 + recv_offset, recv_count, recv_type,
161                      rank, tag, comm, &s);
162
163       recv_offset += block_size;
164     }
165
166     smpi_mpi_send(tmp_buff2, send_count * Y, send_type, dst, tag, comm);
167   }
168   smpi_mpi_waitall(X - 1, reqs, statuses);
169   free(reqs);
170   free(statuses);
171   smpi_free_tmp_buffer(tmp_buff1);
172   smpi_free_tmp_buffer(tmp_buff2);
173   return MPI_SUCCESS;
174 }