Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
7a9eb2a96efe67d4bcc5b766469183be2c0ca6f8
[simgrid.git] / src / smpi / colls / allreduce-rab-reduce-scatter.c
1 #include "colls.h"
2 #ifndef REDUCE_STUFF
3 #define REDUCE_STUFF
4 /*****************************************************************************
5
6 Copyright (c) 2006, Ahmad Faraj & Xin Yuan,
7 All rights reserved.
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are met:
11
12   * Redistributions of source code must retain the above copyright notice,
13     this list of conditions and the following disclaimer.
14
15   * Redistributions in binary form must reproduce the above copyright notice,
16     this list of conditions and the following disclaimer in the documentation
17     and/or other materials provided with the distribution.
18
19   * Neither the name of the Florida State University nor the names of its
20     contributors may be used to endorse or promote products derived from this
21     software without specific prior written permission.
22
23 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
27 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
30 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34   *************************************************************************
35   *     Any results obtained from executing this software require the     *
36   *     acknowledgment and citation of the software and its owners.       *
37   *     The full citation is given below:                                 *
38   *                                                                       *
39   *     A. Faraj and X. Yuan. "Automatic Generation and Tuning of MPI     *
40   *     Collective Communication Routines." The 19th ACM International    *
41   *     Conference on Supercomputing (ICS), Cambridge, Massachusetts,     *
42   *     June 20-22, 2005.                                                 *
43   *************************************************************************
44
45 *****************************************************************************/
46
47 extern MPI_User_function *MPIR_Op_table[];
48
49
50 /* -*- Mode: C; c-basic-offset:4 ; -*- */
51 /*  $Id: mpich-stuff.h,v 1.1 2005/08/22 19:50:21 faraj Exp $
52  *
53  *  (C) 2001 by Argonne National Laboratory.
54  *      See COPYRIGHT in top-level directory.
55  */
56 #ifndef _MPICH_STUFF_H
57 #define _MPICH_STUFF_H
58
59 /*TOpaqOverview.tex
60   MPI Opaque Objects:
61
62   MPI Opaque objects such as 'MPI_Comm' or 'MPI_Datatype' are specified by 
63   integers (in the MPICH2 implementation); the MPI standard calls these
64   handles.  
65   Out of range values are invalid; the value 0 is reserved.
66   For most (with the possible exception of 
67   'MPI_Request' for performance reasons) MPI Opaque objects, the integer
68   encodes both the kind of object (allowing runtime tests to detect a datatype
69   passed where a communicator is expected) and important properties of the 
70   object.  Even the 'MPI_xxx_NULL' values should be encoded so that 
71   different null handles can be distinguished.  The details of the encoding
72   of the handles is covered in more detail in the MPICH2 Design Document.
73   For the most part, the ADI uses pointers to the underlying structures
74   rather than the handles themselves.  However, each structure contains an 
75   'handle' field that is the corresponding integer handle for the MPI object.
76
77   MPID objects (objects used within the implementation of MPI) are not opaque.
78
79   T*/
80
81 /* Known MPI object types.  These are used for both the error handlers 
82    and for the handles.  This is a 4 bit value.  0 is reserved for so 
83    that all-zero handles can be flagged as an error. */
84 /*E
85   MPID_Object_kind - Object kind (communicator, window, or file)
86
87   Notes:
88   This enum is used by keyvals and errhandlers to indicate the type of
89   object for which MPI opaque types the data is valid.  These are defined
90   as bits to allow future expansion to the case where an object is value for
91   multiple types (for example, we may want a universal error handler for 
92   errors return).  This is also used to indicate the type of MPI object a 
93   MPI handle represents.  It is an enum because only this applies only the
94   the MPI objects.
95
96   Module:
97   Attribute-DS
98   E*/
99 typedef enum MPID_Object_kind {
100   MPID_COMM = 0x1,
101   MPID_GROUP = 0x2,
102   MPID_DATATYPE = 0x3,
103   MPID_FILE = 0x4,
104   MPID_ERRHANDLER = 0x5,
105   MPID_OP = 0x6,
106   MPID_INFO = 0x7,
107   MPID_WIN = 0x8,
108   MPID_KEYVAL = 0x9,
109   MPID_ATTR = 0xa,
110   MPID_REQUEST = 0xb
111 } MPID_Object_kind;
112 /* The above objects should correspond to MPI objects only. */
113 #define HANDLE_MPI_KIND_SHIFT 26
114 #define HANDLE_GET_MPI_KIND(a) ( ((a)&0x3c000000) >> HANDLE_MPI_KIND_SHIFT )
115
116 /* Handle types.  These are really 2 bits */
117 #define HANDLE_KIND_INVALID  0x0
118 #define HANDLE_KIND_BUILTIN  0x1
119 #define HANDLE_KIND_DIRECT   0x2
120 #define HANDLE_KIND_INDIRECT 0x3
121 /* Mask assumes that ints are at least 4 bytes */
122 #define HANDLE_KIND_MASK 0xc0000000
123 #define HANDLE_KIND_SHIFT 30
124 #define HANDLE_GET_KIND(a) (((a)&HANDLE_KIND_MASK)>>HANDLE_KIND_SHIFT)
125 #define HANDLE_SET_KIND(a,kind) ((a)|((kind)<<HANDLE_KIND_SHIFT))
126
127 /* For indirect, the remainder of the handle has a block and index */
128 #define HANDLE_INDIRECT_SHIFT 16
129 #define HANDLE_BLOCK(a) (((a)& 0x03FF0000) >> HANDLE_INDIRECT_SHIFT)
130 #define HANDLE_BLOCK_INDEX(a) ((a) & 0x0000FFFF)
131
132 /* Handle block is between 1 and 1024 *elements* */
133 #define HANDLE_BLOCK_SIZE 256
134 /* Index size is bewtween 1 and 65536 *elements* */
135 #define HANDLE_BLOCK_INDEX_SIZE 1024
136
137 /* For direct, the remainder of the handle is the index into a predefined 
138    block */
139 #define HANDLE_MASK 0x03FFFFFF
140 #define HANDLE_INDEX(a) ((a)& HANDLE_MASK)
141
142 /* ALL objects have the handle as the first value. */
143 /* Inactive (unused and stored on the appropriate avail list) objects 
144    have MPIU_Handle_common as the head */
145 typedef struct MPIU_Handle_common {
146   int handle;
147   volatile int ref_count;       /* This field is used to indicate that the
148                                    object is not in use (see, e.g., 
149                                    MPID_Comm_valid_ptr) */
150   void *next;                   /* Free handles use this field to point to the next
151                                    free object */
152 } MPIU_Handle_common;
153
154 /* All *active* (in use) objects have the handle as the first value; objects
155    with referene counts have the reference count as the second value.
156    See MPIU_Object_add_ref and MPIU_Object_release_ref. */
157 typedef struct MPIU_Handle_head {
158   int handle;
159   volatile int ref_count;
160 } MPIU_Handle_head;
161
162 /* This type contains all of the data, except for the direct array,
163    used by the object allocators. */
164 typedef struct MPIU_Object_alloc_t {
165   MPIU_Handle_common *avail;    /* Next available object */
166   int initialized;              /* */
167   void *(*indirect)[];          /* Pointer to indirect object blocks */
168   int indirect_size;            /* Number of allocated indirect blocks */
169   MPID_Object_kind kind;        /* Kind of object this is for */
170   int size;                     /* Size of an individual object */
171   void *direct;                 /* Pointer to direct block, used 
172                                    for allocation */
173   int direct_size;              /* Size of direct block */
174 } MPIU_Object_alloc_t;
175 extern void *MPIU_Handle_obj_alloc(MPIU_Object_alloc_t *);
176 extern void MPIU_Handle_obj_alloc_start(MPIU_Object_alloc_t *);
177 extern void MPIU_Handle_obj_alloc_complete(MPIU_Object_alloc_t *, int init);
178 extern void MPIU_Handle_obj_free(MPIU_Object_alloc_t *, void *);
179 void *MPIU_Handle_get_ptr_indirect(int, MPIU_Object_alloc_t *);
180 extern void *MPIU_Handle_direct_init(void *direct, int direct_size,
181                                      int obj_size, int handle_type);
182 #endif
183 #define MPID_Getb_ptr(kind,a,bmsk,ptr)                                  \
184 {                                                                       \
185    switch (HANDLE_GET_KIND(a)) {                                        \
186       case HANDLE_KIND_BUILTIN:                                         \
187           ptr=MPID_##kind##_builtin+((a)&(bmsk));                       \
188           break;                                                        \
189       case HANDLE_KIND_DIRECT:                                          \
190           ptr=MPID_##kind##_direct+HANDLE_INDEX(a);                     \
191           break;                                                        \
192       case HANDLE_KIND_INDIRECT:                                        \
193           ptr=((MPID_##kind*)                                           \
194                MPIU_Handle_get_ptr_indirect(a,&MPID_##kind##_mem));     \
195           break;                                                        \
196       case HANDLE_KIND_INVALID:                                         \
197       default:                                                          \
198           ptr=0;                                                        \
199           break;                                                        \
200     }                                                                   \
201 }
202
203
204
205 #define MPID_Op_get_ptr(a,ptr)         MPID_Getb_ptr(Op,a,0x000000ff,ptr)
206 typedef enum MPID_Lang_t { MPID_LANG_C
207 #ifdef HAVE_FORTRAN_BINDING
208       , MPID_LANG_FORTRAN, MPID_LANG_FORTRAN90
209 #endif
210 #ifdef HAVE_CXX_BINDING
211       , MPID_LANG_CXX
212 #endif
213 } MPID_Lang_t;
214 /* Reduction and accumulate operations */
215 /*E
216   MPID_Op_kind - Enumerates types of MPI_Op types
217
218   Notes:
219   These are needed for implementing 'MPI_Accumulate', since only predefined
220   operations are allowed for that operation.  
221
222   A gap in the enum values was made allow additional predefined operations
223   to be inserted.  This might include future additions to MPI or experimental
224   extensions (such as a Read-Modify-Write operation).
225
226   Module:
227   Collective-DS
228   E*/
229 typedef enum MPID_Op_kind { MPID_OP_MAX = 1, MPID_OP_MIN = 2,
230   MPID_OP_SUM = 3, MPID_OP_PROD = 4,
231   MPID_OP_LAND = 5, MPID_OP_BAND = 6, MPID_OP_LOR = 7, MPID_OP_BOR = 8,
232   MPID_OP_LXOR = 9, MPID_OP_BXOR = 10, MPID_OP_MAXLOC = 11,
233   MPID_OP_MINLOC = 12, MPID_OP_REPLACE = 13,
234   MPID_OP_USER_NONCOMMUTE = 32, MPID_OP_USER = 33
235 } MPID_Op_kind;
236
237 /*S
238   MPID_User_function - Definition of a user function for MPI_Op types.
239
240   Notes:
241   This includes a 'const' to make clear which is the 'in' argument and 
242   which the 'inout' argument, and to indicate that the 'count' and 'datatype'
243   arguments are unchanged (they are addresses in an attempt to allow 
244   interoperation with Fortran).  It includes 'restrict' to emphasize that 
245   no overlapping operations are allowed.
246
247   We need to include a Fortran version, since those arguments will
248   have type 'MPI_Fint *' instead.  We also need to add a test to the
249   test suite for this case; in fact, we need tests for each of the handle
250   types to ensure that the transfered handle works correctly.
251
252   This is part of the collective module because user-defined operations
253   are valid only for the collective computation routines and not for 
254   RMA accumulate.
255
256   Yes, the 'restrict' is in the correct location.  C compilers that 
257   support 'restrict' should be able to generate code that is as good as a
258   Fortran compiler would for these functions.
259
260   We should note on the manual pages for user-defined operations that
261   'restrict' should be used when available, and that a cast may be 
262   required when passing such a function to 'MPI_Op_create'.
263
264   Question:
265   Should each of these function types have an associated typedef?
266
267   Should there be a C++ function here?
268
269   Module:
270   Collective-DS
271   S*/
272 typedef union MPID_User_function {
273   void (*c_function) (const void *, void *, const int *, const MPI_Datatype *);
274   void (*f77_function) (const void *, void *,
275                         const MPI_Fint *, const MPI_Fint *);
276 } MPID_User_function;
277 /* FIXME: Should there be "restrict" in the definitions above, e.g., 
278    (*c_function)( const void restrict * , void restrict *, ... )? */
279
280 /*S
281   MPID_Op - MPI_Op structure
282
283   Notes:
284   All of the predefined functions are commutative.  Only user functions may 
285   be noncummutative, so there are two separate op types for commutative and
286   non-commutative user-defined operations.
287
288   Operations do not require reference counts because there are no nonblocking
289   operations that accept user-defined operations.  Thus, there is no way that
290   a valid program can free an 'MPI_Op' while it is in use.
291
292   Module:
293   Collective-DS
294   S*/
295 typedef struct MPID_Op {
296   int handle;                   /* value of MPI_Op for this structure */
297   volatile int ref_count;
298   MPID_Op_kind kind;
299   MPID_Lang_t language;
300   MPID_User_function function;
301 } MPID_Op;
302 #define MPID_OP_N_BUILTIN 14
303 extern MPID_Op MPID_Op_builtin[MPID_OP_N_BUILTIN];
304 extern MPID_Op MPID_Op_direct[];
305 extern MPIU_Object_alloc_t MPID_Op_mem;
306
307 /*****************************************************************************
308
309 * Function: get_op_func
310
311 * return: Pointer to MPI_User_function
312
313 * inputs:
314    op: operator (max, min, etc)
315
316    * Descrp: Function returns the function associated with current operator
317    * op.
318
319    * Auther: AHMAD FARAJ
320
321 ****************************************************************************/
322 MPI_User_function *get_op_func(MPI_Op op)
323 {
324
325   if (HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN)
326     return MPIR_Op_table[op % 16 - 1];
327   return NULL;
328 }
329
330 #endif
331
332
333 int smpi_coll_tuned_allreduce_rab_reduce_scatter(void *sbuff, void *rbuff,
334                                                  int count, MPI_Datatype dtype,
335                                                  MPI_Op op, MPI_Comm comm)
336 {
337   int nprocs, rank, type_size, tag = 543;
338   int mask, dst, pof2, newrank, rem, newdst, i,
339       send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
340   MPI_Aint lb, extent;
341   MPI_Status status;
342   void *tmp_buf = NULL;
343   MPI_User_function *func = get_op_func(op);
344   MPI_Comm_size(comm, &nprocs);
345   MPI_Comm_rank(comm, &rank);
346
347   MPI_Type_extent(dtype, &extent);
348   tmp_buf = (void *) malloc(count * extent);
349   if (!tmp_buf) {
350     printf("Could not allocate memory for tmp_buf\n");
351     return 1;
352   }
353
354   MPIR_Localcopy(sbuff, count, dtype, rbuff, count, dtype);
355
356   MPI_Type_size(dtype, &type_size);
357
358   // find nearest power-of-two less than or equal to comm_size
359   pof2 = 1;
360   while (pof2 <= nprocs)
361     pof2 <<= 1;
362   pof2 >>= 1;
363
364   rem = nprocs - pof2;
365
366   // In the non-power-of-two case, all even-numbered
367   // processes of rank < 2*rem send their data to
368   // (rank+1). These even-numbered processes no longer
369   // participate in the algorithm until the very end. The
370   // remaining processes form a nice power-of-two. 
371
372   if (rank < 2 * rem) {
373     // even       
374     if (rank % 2 == 0) {
375
376       MPIC_Send(rbuff, count, dtype, rank + 1, tag, comm);
377
378       // temporarily set the rank to -1 so that this
379       // process does not pariticipate in recursive
380       // doubling
381       newrank = -1;
382     } else                      // odd
383     {
384       MPIC_Recv(tmp_buf, count, dtype, rank - 1, tag, comm, &status);
385       // do the reduction on received data. since the
386       // ordering is right, it doesn't matter whether
387       // the operation is commutative or not.
388       (*func) (tmp_buf, rbuff, &count, &dtype);
389
390       // change the rank 
391       newrank = rank / 2;
392     }
393   }
394
395   else                          // rank >= 2 * rem 
396     newrank = rank - rem;
397
398   // If op is user-defined or count is less than pof2, use
399   // recursive doubling algorithm. Otherwise do a reduce-scatter
400   // followed by allgather. (If op is user-defined,
401   // derived datatypes are allowed and the user could pass basic
402   // datatypes on one process and derived on another as long as
403   // the type maps are the same. Breaking up derived
404   // datatypes to do the reduce-scatter is tricky, therefore
405   // using recursive doubling in that case.) 
406
407   if (newrank != -1) {
408     // do a reduce-scatter followed by allgather. for the
409     // reduce-scatter, calculate the count that each process receives
410     // and the displacement within the buffer 
411
412     cnts = (int *) malloc(pof2 * sizeof(int));
413     disps = (int *) malloc(pof2 * sizeof(int));
414
415     for (i = 0; i < (pof2 - 1); i++)
416       cnts[i] = count / pof2;
417     cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
418
419     disps[0] = 0;
420     for (i = 1; i < pof2; i++)
421       disps[i] = disps[i - 1] + cnts[i - 1];
422
423     mask = 0x1;
424     send_idx = recv_idx = 0;
425     last_idx = pof2;
426     while (mask < pof2) {
427       newdst = newrank ^ mask;
428       // find real rank of dest 
429       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
430
431       send_cnt = recv_cnt = 0;
432       if (newrank < newdst) {
433         send_idx = recv_idx + pof2 / (mask * 2);
434         for (i = send_idx; i < last_idx; i++)
435           send_cnt += cnts[i];
436         for (i = recv_idx; i < send_idx; i++)
437           recv_cnt += cnts[i];
438       } else {
439         recv_idx = send_idx + pof2 / (mask * 2);
440         for (i = send_idx; i < recv_idx; i++)
441           send_cnt += cnts[i];
442         for (i = recv_idx; i < last_idx; i++)
443           recv_cnt += cnts[i];
444       }
445
446       // Send data from recvbuf. Recv into tmp_buf 
447       MPIC_Sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
448                     dtype, dst, tag,
449                     (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt,
450                     dtype, dst, tag, comm, &status);
451
452       // tmp_buf contains data received in this step.
453       // recvbuf contains data accumulated so far 
454
455       // This algorithm is used only for predefined ops
456       // and predefined ops are always commutative.
457       (*func) ((char *) tmp_buf + disps[recv_idx] * extent,
458                (char *) rbuff + disps[recv_idx] * extent, &recv_cnt, &dtype);
459
460       // update send_idx for next iteration 
461       send_idx = recv_idx;
462       mask <<= 1;
463
464       // update last_idx, but not in last iteration because the value
465       // is needed in the allgather step below. 
466       if (mask < pof2)
467         last_idx = recv_idx + pof2 / mask;
468     }
469
470     // now do the allgather 
471
472     mask >>= 1;
473     while (mask > 0) {
474       newdst = newrank ^ mask;
475       // find real rank of dest
476       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
477
478       send_cnt = recv_cnt = 0;
479       if (newrank < newdst) {
480         // update last_idx except on first iteration 
481         if (mask != pof2 / 2)
482           last_idx = last_idx + pof2 / (mask * 2);
483
484         recv_idx = send_idx + pof2 / (mask * 2);
485         for (i = send_idx; i < recv_idx; i++)
486           send_cnt += cnts[i];
487         for (i = recv_idx; i < last_idx; i++)
488           recv_cnt += cnts[i];
489       } else {
490         recv_idx = send_idx - pof2 / (mask * 2);
491         for (i = send_idx; i < last_idx; i++)
492           send_cnt += cnts[i];
493         for (i = recv_idx; i < send_idx; i++)
494           recv_cnt += cnts[i];
495       }
496
497       MPIC_Sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
498                     dtype, dst, tag,
499                     (char *) rbuff + disps[recv_idx] * extent, recv_cnt,
500                     dtype, dst, tag, comm, &status);
501
502       if (newrank > newdst)
503         send_idx = recv_idx;
504
505       mask >>= 1;
506     }
507
508     free(cnts);
509     free(disps);
510
511   }
512   // In the non-power-of-two case, all odd-numbered processes of
513   // rank < 2 * rem send the result to (rank-1), the ranks who didn't
514   // participate above.
515
516   if (rank < 2 * rem) {
517     if (rank % 2)               // odd 
518       MPIC_Send(rbuff, count, dtype, rank - 1, tag, comm);
519     else                        // even 
520       MPIC_Recv(rbuff, count, dtype, rank + 1, tag, comm, &status);
521   }
522
523   free(tmp_buf);
524   return 0;
525 }