[simgrid.git] src/smpi/colls/bcast-ompi-pipeline.c
#include "colls_private.h"

#define MAXTREEFANOUT 32

#define COLL_TUNED_COMPUTED_SEGCOUNT(SEGSIZE, TYPELNG, SEGCOUNT)        \
    if( ((SEGSIZE) >= (TYPELNG)) &&                                     \
        ((SEGSIZE) < ((TYPELNG) * (SEGCOUNT))) ) {                      \
        size_t residual;                                                \
        (SEGCOUNT) = (int)((SEGSIZE) / (TYPELNG));                      \
        residual = (SEGSIZE) - (SEGCOUNT) * (TYPELNG);                  \
        if( residual > ((TYPELNG) >> 1) )                               \
            (SEGCOUNT)++;                                               \
    }

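/*
 * Illustrative sketch (not part of the original file): how the macro above
 * converts a segment size in bytes into a per-segment element count.  The
 * disabled helper below uses assumed values, not SimGrid data.
 */
#if 0
static void coll_tuned_segcount_example(void)
{
    size_t segsize  = 128 * 1024;  /* 131072-byte segments                  */
    size_t typelng  = 8;           /* e.g. sizeof(double)                   */
    int    segcount = 100000;      /* starts as the full element count      */

    COLL_TUNED_COMPUTED_SEGCOUNT(segsize, typelng, segcount);
    /* segcount is now 131072 / 8 = 16384; the residual of 0 bytes is not
       larger than half an element, so no rounding up takes place.          */
}
#endif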
typedef struct ompi_coll_tree_t {
    int32_t tree_root;
    int32_t tree_fanout;
    int32_t tree_bmtree;
    int32_t tree_prev;
    int32_t tree_next[MAXTREEFANOUT];
    int32_t tree_nextsize;
} ompi_coll_tree_t;

ompi_coll_tree_t*
ompi_coll_tuned_topo_build_chain( int fanout,
                                  MPI_Comm comm,
                                  int root );

ompi_coll_tree_t*
ompi_coll_tuned_topo_build_chain( int fanout,
                                  MPI_Comm comm,
                                  int root )
{
    int rank, size;
    int srank; /* shifted rank */
    int i, maxchainlen;
    int mark, head, len;
    ompi_coll_tree_t *chain;

    XBT_DEBUG("coll:tuned:topo:build_chain fo %d rt %d", fanout, root);

    /*
     * Get size and rank of the process in this communicator
     */
    size = smpi_comm_size(comm);
    rank = smpi_comm_rank(comm);

    if( fanout < 1 ) {
        XBT_DEBUG("coll:tuned:topo:build_chain WARNING invalid fanout of ZERO, forcing to 1 (pipeline)!");
        fanout = 1;
    }
    if (fanout > MAXTREEFANOUT) {
        XBT_DEBUG("coll:tuned:topo:build_chain WARNING invalid fanout %d bigger than max %d, forcing to max!", fanout, MAXTREEFANOUT);
        fanout = MAXTREEFANOUT;
    }

    /*
     * Allocate space for topology arrays if needed
     */
    chain = (ompi_coll_tree_t*)malloc( sizeof(ompi_coll_tree_t) );
    if (!chain) {
        XBT_DEBUG("coll:tuned:topo:build_chain PANIC out of memory");
        fflush(stdout);
        return NULL;
    }
    chain->tree_root     = MPI_UNDEFINED;
    chain->tree_nextsize = -1;
    for( i = 0; i < fanout; i++ ) chain->tree_next[i] = -1;

    /*
     * Set root & numchain
     */
    chain->tree_root = root;
    if( (size - 1) < fanout ) {
        chain->tree_nextsize = size - 1;
        fanout = size - 1;
    } else {
        chain->tree_nextsize = fanout;
    }

    /*
     * Shift ranks
     */
    srank = rank - root;
    if (srank < 0) srank += size;

    /*
     * Special case - fanout == 1
     */
    if( fanout == 1 ) {
        if( srank == 0 ) chain->tree_prev = -1;
        else chain->tree_prev = (srank - 1 + root) % size;

        if( (srank + 1) >= size ) {
            chain->tree_next[0] = -1;
            chain->tree_nextsize = 0;
        } else {
            chain->tree_next[0] = (srank + 1 + root) % size;
            chain->tree_nextsize = 1;
        }
        return chain;
    }

    /* Let's handle the case where there is just one node in the communicator */
    if( size == 1 ) {
        chain->tree_next[0] = -1;
        chain->tree_nextsize = 0;
        chain->tree_prev = -1;
        return chain;
    }

    /*
     * Calculate maximum chain length
     */
    maxchainlen = (size - 1) / fanout;
    if( (size - 1) % fanout != 0 ) {
        maxchainlen++;
        mark = (size - 1) % fanout;
    } else {
        mark = fanout + 1;
    }

    /*
     * Find your own place in the list of shifted ranks
     */
    if( srank != 0 ) {
        int column;
        if( srank - 1 < (mark * maxchainlen) ) {
            column = (srank - 1) / maxchainlen;
            head = 1 + column * maxchainlen;
            len = maxchainlen;
        } else {
            column = mark + (srank - 1 - mark * maxchainlen) / (maxchainlen - 1);
            head = mark * maxchainlen + 1 + (column - mark) * (maxchainlen - 1);
            len = maxchainlen - 1;
        }

        if( srank == head ) {
            chain->tree_prev = 0; /* root */
        } else {
            chain->tree_prev = srank - 1; /* rank - 1 */
        }
        if( srank == (head + len - 1) ) {
            chain->tree_next[0] = -1;
            chain->tree_nextsize = 0;
        } else {
            if( (srank + 1) < size ) {
                chain->tree_next[0] = srank + 1;
                chain->tree_nextsize = 1;
            } else {
                chain->tree_next[0] = -1;
                chain->tree_nextsize = 0;
            }
        }
    }

    /*
     * Unshift values
     */
    if( rank == root ) {
        chain->tree_prev = -1;
        chain->tree_next[0] = (root + 1) % size;
        for( i = 1; i < fanout; i++ ) {
            chain->tree_next[i] = chain->tree_next[i-1] + maxchainlen;
            if( i > mark ) {
                chain->tree_next[i]--;
            }
            chain->tree_next[i] %= size;
        }
        chain->tree_nextsize = fanout;
    } else {
        chain->tree_prev = (chain->tree_prev + root) % size;
        if( chain->tree_next[0] != -1 ) {
            chain->tree_next[0] = (chain->tree_next[0] + root) % size;
        }
    }

    return chain;
}
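
/*
 * Illustrative sketch (assumption, not in the original file): with fanout 1
 * the routine above degenerates into a simple pipeline.  On a 4-process
 * communicator rooted at 0 the chain is 0 -> 1 -> 2 -> 3, so every rank gets
 * one predecessor and at most one successor.  The disabled helper below only
 * shows how the result would be queried.
 */
#if 0
static void topo_chain_example(MPI_Comm comm)
{
    ompi_coll_tree_t *chain = ompi_coll_tuned_topo_build_chain(1, comm, 0);
    if (chain != NULL) {
        XBT_DEBUG("rank %d: prev %d, next %d, nextsize %d",
                  smpi_comm_rank(comm), chain->tree_prev,
                  chain->tree_next[0], chain->tree_nextsize);
        free(chain); /* the chain is heap-allocated by the builder */
    }
}
#endif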

int smpi_coll_tuned_bcast_ompi_pipeline( void* buffer,
                                         int original_count,
                                         MPI_Datatype datatype,
                                         int root,
                                         MPI_Comm comm)
{
    int count_by_segment = original_count;
    size_t type_size;
    int segsize = 0;   /* 0 means "no segmentation": keep the full count */
    //mca_coll_tuned_module_t *tuned_module = (mca_coll_tuned_module_t*) module;
    //mca_coll_tuned_comm_t *data = tuned_module->tuned_data;

//    return ompi_coll_tuned_bcast_intra_generic( buffer, count, datatype, root, comm, module,
//                                                count_by_segment, data->cached_pipeline );
    ompi_coll_tree_t *tree = ompi_coll_tuned_topo_build_chain( 1, comm, root );
    int i;
    int rank, size;
    int segindex;
    int num_segments; /* Number of segments */
    int sendcount;    /* number of elements sent in this segment */
    size_t realsegsize;
    char *tmpbuf;
    ptrdiff_t extent;
    MPI_Request recv_reqs[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL};
    MPI_Request *send_reqs = NULL;
    int req_index;

    /**
     * Determine number of elements sent per operation.
     */
    type_size = smpi_datatype_size(datatype);

    size = smpi_comm_size(comm);
    rank = smpi_comm_rank(comm);
    xbt_assert( size > 1 );

    const double a_p16  = 3.2118e-6; /* [1 / byte] */
    const double b_p16  = 8.7936;
    const double a_p64  = 2.3679e-6; /* [1 / byte] */
    const double b_p64  = 1.1787;
    const double a_p128 = 1.6134e-6; /* [1 / byte] */
    const double b_p128 = 2.1102;
    size_t message_size;

    /* else we need data size for decision function */
    message_size = type_size * (unsigned long)original_count;   /* needed for decision */

    if (size < (a_p128 * message_size + b_p128)) {
        // Pipeline with 128KB segments
        segsize = 1024 << 7;
    } else if (size < (a_p64 * message_size + b_p64)) {
        // Pipeline with 64KB segments
        segsize = 1024 << 6;
    } else if (size < (a_p16 * message_size + b_p16)) {
        // Pipeline with 16KB segments
        segsize = 1024 << 4;
    }
    /* otherwise segsize stays 0 and the message is not segmented */

    COLL_TUNED_COMPUTED_SEGCOUNT( segsize, type_size, count_by_segment );

    XBT_DEBUG("coll:tuned:bcast_intra_pipeline rank %d ss %5d type_size %lu count_by_segment %d",
              smpi_comm_rank(comm), segsize, (unsigned long)type_size, count_by_segment);

    extent = smpi_datatype_get_extent (datatype);
    num_segments = (original_count + count_by_segment - 1) / count_by_segment;
    realsegsize = count_by_segment * extent;
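
    /*
     * Worked example with assumed numbers (mine, not from the original
     * source): broadcasting 2^20 MPI_DOUBLEs (message_size = 8 MB) on 16
     * processes: the 128KB test gives a_p128 * 8388608 + b_p128 ~= 15.6,
     * which 16 is not below, while the 64KB test gives ~21.0, so the 64KB
     * branch is taken.  Then count_by_segment = 65536 / 8 = 8192 doubles,
     * num_segments = 1048576 / 8192 = 128 and realsegsize = 65536 bytes.
     */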

    /* Set the buffer pointers */
    tmpbuf = (char *) buffer;

    if( tree->tree_nextsize != 0 ) {
        send_reqs = xbt_new(MPI_Request, tree->tree_nextsize);
    }

    /* Root code */
    if( rank == root ) {
        /*
           For each segment:
           - send segment to all children.
             The last segment may have fewer elements than the other segments.
        */
        sendcount = count_by_segment;
        for( segindex = 0; segindex < num_segments; segindex++ ) {
            if( segindex == (num_segments - 1) ) {
                sendcount = original_count - segindex * count_by_segment;
            }
            for( i = 0; i < tree->tree_nextsize; i++ ) {
                send_reqs[i] = smpi_mpi_isend(tmpbuf, sendcount, datatype,
                                              tree->tree_next[i],
                                              777, comm);
            }

            /* complete the sends before starting the next sends */
            smpi_mpi_waitall( tree->tree_nextsize, send_reqs,
                              MPI_STATUSES_IGNORE );

            /* update tmp buffer */
            tmpbuf += realsegsize;
        }
    }

    /* Intermediate nodes code */
    else if( tree->tree_nextsize > 0 ) {
        /*
           Create the pipeline.
           1) Post the first receive
           2) For segments 1 .. num_segments
              - post new receive
              - wait on the previous receive to complete
              - send this data to children
           3) Wait on the last segment
           4) Compute number of elements in last segment.
           5) Send the last segment to children
         */
        req_index = 0;
        recv_reqs[req_index] = smpi_mpi_irecv(tmpbuf, count_by_segment, datatype,
                                              tree->tree_prev, 777,
                                              comm);

        for( segindex = 1; segindex < num_segments; segindex++ ) {

            req_index = req_index ^ 0x1;

            /* post new irecv */
            recv_reqs[req_index] = smpi_mpi_irecv( tmpbuf + realsegsize, count_by_segment,
                                                   datatype, tree->tree_prev,
                                                   777,
                                                   comm);

            /* wait for and forward the previous segment to children */
            smpi_mpi_wait( &recv_reqs[req_index ^ 0x1],
                           MPI_STATUS_IGNORE );

            for( i = 0; i < tree->tree_nextsize; i++ ) {
                send_reqs[i] = smpi_mpi_isend(tmpbuf, count_by_segment, datatype,
                                              tree->tree_next[i],
                                              777, comm );
            }

            /* complete the sends before starting the next iteration */
            smpi_mpi_waitall( tree->tree_nextsize, send_reqs,
                              MPI_STATUSES_IGNORE );

            /* Update the receive buffer */
            tmpbuf += realsegsize;
        }

        /* Process the last segment */
        smpi_mpi_wait( &recv_reqs[req_index], MPI_STATUS_IGNORE );
        sendcount = original_count - (num_segments - 1) * count_by_segment;
        for( i = 0; i < tree->tree_nextsize; i++ ) {
            send_reqs[i] = smpi_mpi_isend(tmpbuf, sendcount, datatype,
                                          tree->tree_next[i],
                                          777, comm);
        }

        smpi_mpi_waitall( tree->tree_nextsize, send_reqs,
                          MPI_STATUSES_IGNORE );
    }

    /* Leaf nodes */
    else {
        /*
           Receive all segments from parent in a loop:
           1) post irecv for the first segment
           2) for segments 1 .. num_segments
              - post irecv for the next segment
              - wait on the previous segment to arrive
           3) wait for the last segment
        */
        req_index = 0;
        recv_reqs[req_index] = smpi_mpi_irecv(tmpbuf, count_by_segment, datatype,
                                              tree->tree_prev, 777,
                                              comm);

        for( segindex = 1; segindex < num_segments; segindex++ ) {
            req_index = req_index ^ 0x1;
            tmpbuf += realsegsize;
            /* post receive for the next segment */
            recv_reqs[req_index] = smpi_mpi_irecv(tmpbuf, count_by_segment, datatype,
                                                  tree->tree_prev, 777,
                                                  comm);
            /* wait on the previous segment */
            smpi_mpi_wait( &recv_reqs[req_index ^ 0x1],
                           MPI_STATUS_IGNORE );
        }

        smpi_mpi_wait( &recv_reqs[req_index], MPI_STATUS_IGNORE );
    }

    if( NULL != send_reqs ) free(send_reqs);
    /* the chain topology was heap-allocated above; release it */
    free(tree);

    return (MPI_SUCCESS);
}
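
/*
 * Illustrative usage sketch (assumption, not part of the original file):
 * broadcast one million doubles from rank 0 with the pipelined algorithm
 * above.  The disabled helper assumes colls_private.h already provides the
 * SMPI declarations used here.
 */
#if 0
static void bcast_pipeline_example(MPI_Comm comm)
{
    int count   = 1000000;
    double *buf = xbt_new(double, count);

    if (smpi_comm_rank(comm) == 0) {
        int i;
        for (i = 0; i < count; i++) buf[i] = (double)i; /* root fills the data */
    }

    /* every rank calls the collective; non-roots receive into buf */
    smpi_coll_tuned_bcast_ompi_pipeline(buf, count, MPI_DOUBLE, 0, comm);

    xbt_free(buf);
}
#endif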