Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
add scatter algos from ompi
[simgrid.git] / src / smpi / colls / allreduce-rab-reduce-scatter.c
1 #include "colls_private.h"
2 #ifndef REDUCE_STUFF
3 #define REDUCE_STUFF
4 /*****************************************************************************
5
6 Copyright (c) 2006, Ahmad Faraj & Xin Yuan,
7 All rights reserved.
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are met:
11
12   * Redistributions of source code must retain the above copyright notice,
13     this list of conditions and the following disclaimer.
14
15   * Redistributions in binary form must reproduce the above copyright notice,
16     this list of conditions and the following disclaimer in the documentation
17     and/or other materials provided with the distribution.
18
19   * Neither the name of the Florida State University nor the names of its
20     contributors may be used to endorse or promote products derived from this
21     software without specific prior written permission.
22
23 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34   *************************************************************************
35   *     Any results obtained from executing this software require the     *
36   *     acknowledgment and citation of the software and its owners.       *
37   *     The full citation is given below:                                 *
38   *                                                                       *
39   *     A. Faraj and X. Yuan. "Automatic Generation and Tuning of MPI     *
40   *     Collective Communication Routines." The 19th ACM International    *
41   *     Conference on Supercomputing (ICS), Cambridge, Massachusetts,     *
42   *     June 20-22, 2005.                                                 *
43   *************************************************************************
44
45 *****************************************************************************/
46
47 extern MPI_User_function *MPIR_Op_table[];
48
49
50 /* -*- Mode: C; c-basic-offset:4 ; -*- */
51 /*  $Id: mpich-stuff.h,v 1.1 2005/08/22 19:50:21 faraj Exp $
52  *
53  *  (C) 2001 by Argonne National Laboratory.
54  *      See COPYRIGHT in top-level directory.
55  */
56 #ifndef _MPICH_STUFF_H
57 #define _MPICH_STUFF_H
58
59 /*TOpaqOverview.tex
60   MPI Opaque Objects:
61
62   MPI Opaque objects such as 'MPI_Comm' or 'MPI_Datatype' are specified by 
63   integers (in the MPICH2 implementation); the MPI standard calls these
64   handles.  
65   Out of range values are invalid; the value 0 is reserved.
66   For most (with the possible exception of 
67   'MPI_Request' for performance reasons) MPI Opaque objects, the integer
68   encodes both the kind of object (allowing runtime tests to detect a datatype
69   passed where a communicator is expected) and important properties of the 
70   object.  Even the 'MPI_xxx_NULL' values should be encoded so that 
71   different null handles can be distinguished.  The details of the encoding
72   of the handles is covered in more detail in the MPICH2 Design Document.
73   For the most part, the ADI uses pointers to the underlying structures
74   rather than the handles themselves.  However, each structure contains an 
75   'handle' field that is the corresponding integer handle for the MPI object.
76
77   MPID objects (objects used within the implementation of MPI) are not opaque.
78
79   T*/
80
81 /* Known MPI object types.  These are used for both the error handlers 
82    and for the handles.  This is a 4 bit value.  0 is reserved for so 
83    that all-zero handles can be flagged as an error. */
84 /*E
85   MPID_Object_kind - Object kind (communicator, window, or file)
86
87   Notes:
88   This enum is used by keyvals and errhandlers to indicate the type of
89   object for which MPI opaque types the data is valid.  These are defined
90   as bits to allow future expansion to the case where an object is value for
91   multiple types (for example, we may want a universal error handler for 
92   errors return).  This is also used to indicate the type of MPI object a 
93   MPI handle represents.  It is an enum because only this applies only the
94   the MPI objects.
95
96   Module:
97   Attribute-DS
98   E*/
99 typedef enum MPID_Object_kind {
100   MPID_COMM = 0x1,
101   MPID_GROUP = 0x2,
102   MPID_DATATYPE = 0x3,
103   MPID_FILE = 0x4,
104   MPID_ERRHANDLER = 0x5,
105   MPID_OP = 0x6,
106   MPID_INFO = 0x7,
107   MPID_WIN = 0x8,
108   MPID_KEYVAL = 0x9,
109   MPID_ATTR = 0xa,
110   MPID_REQUEST = 0xb
111 } MPID_Object_kind;
112 /* The above objects should correspond to MPI objects only. */
113 #define HANDLE_MPI_KIND_SHIFT 26
114 #define HANDLE_GET_MPI_KIND(a) ( ((a)&0x3c000000) >> HANDLE_MPI_KIND_SHIFT )
115
116 /* Handle types.  These are really 2 bits */
117 #define HANDLE_KIND_INVALID  0x0
118 #define HANDLE_KIND_BUILTIN  0x1
119 #define HANDLE_KIND_DIRECT   0x2
120 #define HANDLE_KIND_INDIRECT 0x3
121 /* Mask assumes that ints are at least 4 bytes */
122 #define HANDLE_KIND_MASK 0xc0000000
123 #define HANDLE_KIND_SHIFT 30
124 #define HANDLE_GET_KIND(a) (((a)&HANDLE_KIND_MASK)>>HANDLE_KIND_SHIFT)
125 #define HANDLE_SET_KIND(a,kind) ((a)|((kind)<<HANDLE_KIND_SHIFT))
126
127 /* For indirect, the remainder of the handle has a block and index */
128 #define HANDLE_INDIRECT_SHIFT 16
129 #define HANDLE_BLOCK(a) (((a)& 0x03FF0000) >> HANDLE_INDIRECT_SHIFT)
130 #define HANDLE_BLOCK_INDEX(a) ((a) & 0x0000FFFF)
131
132 /* Handle block is between 1 and 1024 *elements* */
133 #define HANDLE_BLOCK_SIZE 256
134 /* Index size is bewtween 1 and 65536 *elements* */
135 #define HANDLE_BLOCK_INDEX_SIZE 1024
136
137 /* For direct, the remainder of the handle is the index into a predefined 
138    block */
139 #define HANDLE_MASK 0x03FFFFFF
140 #define HANDLE_INDEX(a) ((a)& HANDLE_MASK)
141
142 /* ALL objects have the handle as the first value. */
143 /* Inactive (unused and stored on the appropriate avail list) objects 
144    have MPIU_Handle_common as the head */
145 typedef struct MPIU_Handle_common {
146   int handle;
147   volatile int ref_count;       /* This field is used to indicate that the
148                                    object is not in use (see, e.g., 
149                                    MPID_Comm_valid_ptr) */
150   void *next;                   /* Free handles use this field to point to the next
151                                    free object */
152 } MPIU_Handle_common;
153
154 /* All *active* (in use) objects have the handle as the first value; objects
155    with referene counts have the reference count as the second value.
156    See MPIU_Object_add_ref and MPIU_Object_release_ref. */
157 typedef struct MPIU_Handle_head {
158   int handle;
159   volatile int ref_count;
160 } MPIU_Handle_head;
161
162 /* This type contains all of the data, except for the direct array,
163    used by the object allocators. */
164 typedef struct MPIU_Object_alloc_t {
165   MPIU_Handle_common *avail;    /* Next available object */
166   int initialized;              /* */
167   void *(*indirect)[];          /* Pointer to indirect object blocks */
168   int indirect_size;            /* Number of allocated indirect blocks */
169   MPID_Object_kind kind;        /* Kind of object this is for */
170   int size;                     /* Size of an individual object */
171   void *direct;                 /* Pointer to direct block, used 
172                                    for allocation */
173   int direct_size;              /* Size of direct block */
174 } MPIU_Object_alloc_t;
175 extern void *MPIU_Handle_obj_alloc(MPIU_Object_alloc_t *);
176 extern void MPIU_Handle_obj_alloc_start(MPIU_Object_alloc_t *);
177 extern void MPIU_Handle_obj_alloc_complete(MPIU_Object_alloc_t *, int init);
178 extern void MPIU_Handle_obj_free(MPIU_Object_alloc_t *, void *);
179 void *MPIU_Handle_get_ptr_indirect(int, MPIU_Object_alloc_t *);
180 extern void *MPIU_Handle_direct_init(void *direct, int direct_size,
181                                      int obj_size, int handle_type);
182 #endif
183 #define MPID_Getb_ptr(kind,a,bmsk,ptr)                                  \
184 {                                                                       \
185    switch (HANDLE_GET_KIND(a)) {                                        \
186       case HANDLE_KIND_BUILTIN:                                         \
187           ptr=MPID_##kind##_builtin+((a)&(bmsk));                       \
188           break;                                                        \
189       case HANDLE_KIND_DIRECT:                                          \
190           ptr=MPID_##kind##_direct+HANDLE_INDEX(a);                     \
191           break;                                                        \
192       case HANDLE_KIND_INDIRECT:                                        \
193           ptr=((MPID_##kind*)                                           \
194                MPIU_Handle_get_ptr_indirect(a,&MPID_##kind##_mem));     \
195           break;                                                        \
196       case HANDLE_KIND_INVALID:                                         \
197       default:                                                          \
198           ptr=0;                                                        \
199           break;                                                        \
200     }                                                                   \
201 }
202
203
204
205 #define MPID_Op_get_ptr(a,ptr)         MPID_Getb_ptr(Op,a,0x000000ff,ptr)
206 typedef enum MPID_Lang_t { MPID_LANG_C
207 #ifdef HAVE_FORTRAN_BINDING
208       , MPID_LANG_FORTRAN, MPID_LANG_FORTRAN90
209 #endif
210 #ifdef HAVE_CXX_BINDING
211       , MPID_LANG_CXX
212 #endif
213 } MPID_Lang_t;
214 /* Reduction and accumulate operations */
215 /*E
216   MPID_Op_kind - Enumerates types of MPI_Op types
217
218   Notes:
219   These are needed for implementing 'MPI_Accumulate', since only predefined
220   operations are allowed for that operation.  
221
222   A gap in the enum values was made allow additional predefined operations
223   to be inserted.  This might include future additions to MPI or experimental
224   extensions (such as a Read-Modify-Write operation).
225
226   Module:
227   Collective-DS
228   E*/
229 typedef enum MPID_Op_kind { MPID_OP_MAX = 1, MPID_OP_MIN = 2,
230   MPID_OP_SUM = 3, MPID_OP_PROD = 4,
231   MPID_OP_LAND = 5, MPID_OP_BAND = 6, MPID_OP_LOR = 7, MPID_OP_BOR = 8,
232   MPID_OP_LXOR = 9, MPID_OP_BXOR = 10, MPID_OP_MAXLOC = 11,
233   MPID_OP_MINLOC = 12, MPID_OP_REPLACE = 13,
234   MPID_OP_USER_NONCOMMUTE = 32, MPID_OP_USER = 33
235 } MPID_Op_kind;
236
237 /*S
238   MPID_User_function - Definition of a user function for MPI_Op types.
239
240   Notes:
241   This includes a 'const' to make clear which is the 'in' argument and 
242   which the 'inout' argument, and to indicate that the 'count' and 'datatype'
243   arguments are unchanged (they are addresses in an attempt to allow 
244   interoperation with Fortran).  It includes 'restrict' to emphasize that 
245   no overlapping operations are allowed.
246
247   We need to include a Fortran version, since those arguments will
248   have type 'MPI_Fint *' instead.  We also need to add a test to the
249   test suite for this case; in fact, we need tests for each of the handle
250   types to ensure that the transfered handle works correctly.
251
252   This is part of the collective module because user-defined operations
253   are valid only for the collective computation routines and not for 
254   RMA accumulate.
255
256   Yes, the 'restrict' is in the correct location.  C compilers that 
257   support 'restrict' should be able to generate code that is as good as a
258   Fortran compiler would for these functions.
259
260   We should note on the manual pages for user-defined operations that
261   'restrict' should be used when available, and that a cast may be 
262   required when passing such a function to 'MPI_Op_create'.
263
264   Question:
265   Should each of these function types have an associated typedef?
266
267   Should there be a C++ function here?
268
269   Module:
270   Collective-DS
271   S*/
272 typedef union MPID_User_function {
273   void (*c_function) (const void *, void *, const int *, const MPI_Datatype *);
274   void (*f77_function) (const void *, void *,
275                         const MPI_Fint *, const MPI_Fint *);
276 } MPID_User_function;
277 /* FIXME: Should there be "restrict" in the definitions above, e.g., 
278    (*c_function)( const void restrict * , void restrict *, ... )? */
279
280 /*S
281   MPID_Op - MPI_Op structure
282
283   Notes:
284   All of the predefined functions are commutative.  Only user functions may 
285   be noncummutative, so there are two separate op types for commutative and
286   non-commutative user-defined operations.
287
288   Operations do not require reference counts because there are no nonblocking
289   operations that accept user-defined operations.  Thus, there is no way that
290   a valid program can free an 'MPI_Op' while it is in use.
291
292   Module:
293   Collective-DS
294   S*/
295 typedef struct MPID_Op {
296   int handle;                   /* value of MPI_Op for this structure */
297   volatile int ref_count;
298   MPID_Op_kind kind;
299   MPID_Lang_t language;
300   MPID_User_function function;
301 } MPID_Op;
302 #define MPID_OP_N_BUILTIN 14
303 extern MPID_Op MPID_Op_builtin[MPID_OP_N_BUILTIN];
304 extern MPID_Op MPID_Op_direct[];
305 extern MPIU_Object_alloc_t MPID_Op_mem;
306
307 /*****************************************************************************
308
309 * Function: get_op_func
310
311 * return: Pointer to MPI_User_function
312
313 * inputs:
314    op: operator (max, min, etc)
315
316    * Descrp: Function returns the function associated with current operator
317    * op.
318
319    * Auther: AHMAD FARAJ
320
321 ****************************************************************************/
322 MPI_User_function *get_op_func(MPI_Op op)
323 {
324
325   if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN)
326     return MPIR_Op_table[op % 16 - 1];
327   return NULL;
328 }
329
330 #endif
331
332
333 int smpi_coll_tuned_allreduce_rab_reduce_scatter(void *sbuff, void *rbuff,
334                                                  int count, MPI_Datatype dtype,
335                                                  MPI_Op op, MPI_Comm comm)
336 {
337   int nprocs, rank, type_size, tag = 543;
338   int mask, dst, pof2, newrank, rem, newdst, i,
339       send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
340   MPI_Aint lb, extent;
341   MPI_Status status;
342   void *tmp_buf = NULL;
343   MPI_User_function *func = get_op_func(op);
344   nprocs = smpi_comm_size(comm);
345   rank = smpi_comm_rank(comm);
346
347   extent = smpi_datatype_get_extent(dtype);
348   tmp_buf = (void *) xbt_malloc(count * extent);
349
350   MPIR_Localcopy(sbuff, count, dtype, rbuff, count, dtype);
351
352   type_size = smpi_datatype_size(dtype);
353
354   // find nearest power-of-two less than or equal to comm_size
355   pof2 = 1;
356   while (pof2 <= nprocs)
357     pof2 <<= 1;
358   pof2 >>= 1;
359
360   rem = nprocs - pof2;
361
362   // In the non-power-of-two case, all even-numbered
363   // processes of rank < 2*rem send their data to
364   // (rank+1). These even-numbered processes no longer
365   // participate in the algorithm until the very end. The
366   // remaining processes form a nice power-of-two. 
367
368   if (rank < 2 * rem) {
369     // even       
370     if (rank % 2 == 0) {
371
372       MPIC_Send(rbuff, count, dtype, rank + 1, tag, comm);
373
374       // temporarily set the rank to -1 so that this
375       // process does not pariticipate in recursive
376       // doubling
377       newrank = -1;
378     } else                      // odd
379     {
380       MPIC_Recv(tmp_buf, count, dtype, rank - 1, tag, comm, &status);
381       // do the reduction on received data. since the
382       // ordering is right, it doesn't matter whether
383       // the operation is commutative or not.
384       (*func) (tmp_buf, rbuff, &count, &dtype);
385
386       // change the rank 
387       newrank = rank / 2;
388     }
389   }
390
391   else                          // rank >= 2 * rem 
392     newrank = rank - rem;
393
394   // If op is user-defined or count is less than pof2, use
395   // recursive doubling algorithm. Otherwise do a reduce-scatter
396   // followed by allgather. (If op is user-defined,
397   // derived datatypes are allowed and the user could pass basic
398   // datatypes on one process and derived on another as long as
399   // the type maps are the same. Breaking up derived
400   // datatypes to do the reduce-scatter is tricky, therefore
401   // using recursive doubling in that case.) 
402
403   if (newrank != -1) {
404     // do a reduce-scatter followed by allgather. for the
405     // reduce-scatter, calculate the count that each process receives
406     // and the displacement within the buffer 
407
408     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
409     disps = (int *) xbt_malloc(pof2 * sizeof(int));
410
411     for (i = 0; i < (pof2 - 1); i++)
412       cnts[i] = count / pof2;
413     cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
414
415     disps[0] = 0;
416     for (i = 1; i < pof2; i++)
417       disps[i] = disps[i - 1] + cnts[i - 1];
418
419     mask = 0x1;
420     send_idx = recv_idx = 0;
421     last_idx = pof2;
422     while (mask < pof2) {
423       newdst = newrank ^ mask;
424       // find real rank of dest 
425       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
426
427       send_cnt = recv_cnt = 0;
428       if (newrank < newdst) {
429         send_idx = recv_idx + pof2 / (mask * 2);
430         for (i = send_idx; i < last_idx; i++)
431           send_cnt += cnts[i];
432         for (i = recv_idx; i < send_idx; i++)
433           recv_cnt += cnts[i];
434       } else {
435         recv_idx = send_idx + pof2 / (mask * 2);
436         for (i = send_idx; i < recv_idx; i++)
437           send_cnt += cnts[i];
438         for (i = recv_idx; i < last_idx; i++)
439           recv_cnt += cnts[i];
440       }
441
442       // Send data from recvbuf. Recv into tmp_buf 
443       MPIC_Sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
444                     dtype, dst, tag,
445                     (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt,
446                     dtype, dst, tag, comm, &status);
447
448       // tmp_buf contains data received in this step.
449       // recvbuf contains data accumulated so far 
450
451       // This algorithm is used only for predefined ops
452       // and predefined ops are always commutative.
453       (*func) ((char *) tmp_buf + disps[recv_idx] * extent,
454                (char *) rbuff + disps[recv_idx] * extent, &recv_cnt, &dtype);
455
456       // update send_idx for next iteration 
457       send_idx = recv_idx;
458       mask <<= 1;
459
460       // update last_idx, but not in last iteration because the value
461       // is needed in the allgather step below. 
462       if (mask < pof2)
463         last_idx = recv_idx + pof2 / mask;
464     }
465
466     // now do the allgather 
467
468     mask >>= 1;
469     while (mask > 0) {
470       newdst = newrank ^ mask;
471       // find real rank of dest
472       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
473
474       send_cnt = recv_cnt = 0;
475       if (newrank < newdst) {
476         // update last_idx except on first iteration 
477         if (mask != pof2 / 2)
478           last_idx = last_idx + pof2 / (mask * 2);
479
480         recv_idx = send_idx + pof2 / (mask * 2);
481         for (i = send_idx; i < recv_idx; i++)
482           send_cnt += cnts[i];
483         for (i = recv_idx; i < last_idx; i++)
484           recv_cnt += cnts[i];
485       } else {
486         recv_idx = send_idx - pof2 / (mask * 2);
487         for (i = send_idx; i < last_idx; i++)
488           send_cnt += cnts[i];
489         for (i = recv_idx; i < send_idx; i++)
490           recv_cnt += cnts[i];
491       }
492
493       MPIC_Sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
494                     dtype, dst, tag,
495                     (char *) rbuff + disps[recv_idx] * extent, recv_cnt,
496                     dtype, dst, tag, comm, &status);
497
498       if (newrank > newdst)
499         send_idx = recv_idx;
500
501       mask >>= 1;
502     }
503
504     free(cnts);
505     free(disps);
506
507   }
508   // In the non-power-of-two case, all odd-numbered processes of
509   // rank < 2 * rem send the result to (rank-1), the ranks who didn't
510   // participate above.
511
512   if (rank < 2 * rem) {
513     if (rank % 2)               // odd 
514       MPIC_Send(rbuff, count, dtype, rank - 1, tag, comm);
515     else                        // even 
516       MPIC_Recv(rbuff, count, dtype, rank + 1, tag, comm, &status);
517   }
518
519   free(tmp_buf);
520   return MPI_SUCCESS;
521 }