Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Use C++ style includes when available.
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* extracted from mpig_myreduce.c with
2    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
3
4 /* Copyright: Rolf Rabenseifner, 1997
5  *            Computing Center University of Stuttgart
6  *            rabenseifner@rus.uni-stuttgart.de
7  *
8  * The usage of this software is free,
9  * but this header must not be removed.
10  */
11
12 #include "../colls_private.h"
13 #include <cstdio>
14 #include <cstdlib>
15
16 #define REDUCE_NEW_ALWAYS 1
17
18 #ifdef CRAY
19 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
20                                  /* =  16 + multiple of 32 doubles*/
21
22 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
23        /*  routine =    reduce                allreduce             */ \
24        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
25  static int Lsh[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
26  static int Lin[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
27  static int Llg[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
28  static int Lfp[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
29  static int Ldb[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
30  static int Lby[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};
31 #endif
32
33 #ifdef REDUCE_NEW_ALWAYS
34 # undef  REDUCE_LIMITS
35 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
36        /*  routine =    reduce                allreduce             */ \
37        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
38  static int Lsh[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
39  static int Lin[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
40  static int Llg[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
41  static int Lfp[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
42  static int Ldb[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
43  static int Lby[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};
44 #endif
45
46 /* Fast reduce and allreduce algorithm for longer buffers and predefined
47    operations.
48
49    This algorithm is explaned with the example of 13 nodes.
50    The nodes are numbered   0, 1, 2, ... 12.
51    The sendbuf content is   a, b, c, ...  m.
52    The buffer array is notated with ABCDEFGH, this means that
53    e.g. 'C' is the third 1/8 of the buffer.
54
55    The algorithm computes
56
57    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
58
59    This is equivalent to the mostly used binary tree algorithm
60    (e.g. used in mpich).
61
62    size := number of nodes in the communicator.
63    2**n := the power of 2 that is next smaller or equal to the size.
64    r    := size - 2**n
65
66    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
67
68    The algoritm needs for the execution of one Colls::reduce
69
70    - for r==0
71      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
72
73    - for r>0
74      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
75                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
76
77    with L1 = latency of a message transfer e.g. by send+recv
78         L2 = latency of a message exchange e.g. by sendrecv
79         T1 = additional time to transfer 1 byte (= 1/bandwidth)
80         T2 = additional time to exchange 1 byte in both directions.
81         O  = time for one operation
82         d  = size of the datatype
83
84         On a MPP system it is expected that T1==T2.
85
86    In Comparison with the binary tree algorithm that needs:
87
88    - for r==0
89      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
90
91    - for r>0
92      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
93
94    the new algorithm is faster if (assuming T1=T2)
95
96     for n>2:
97      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
98      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
99
100     for size = 5, 6, and 7:
101                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
102
103     for size = 4:
104                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
105     for size = 2 and 3:
106                 buf_lng  >  L2 / [  0.5 * O/d ]
107
108    and for O/d >> T1 we can summarize:
109
110     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
111
112    Example L1 = L2 = 50 us,
113            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
114            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
115            for double,        i.e. d  =   8 byte/FLOP
116
117            ==> the new algorithm is faster for about buf_lng>8000 bytes
118                i.e count > 1000 doubles !
119 Step 1)
120
121    compute n and r
122
123 Step 2)
124
125    if  myrank < 2*r
126
127      split the buffer into ABCD and EFGH
128
129      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
130                   receive buffer ABCD from myrank+1
131                   compute op for ABCD                 (count/2 ops)
132                   receive result EFGH               (buf_lng/2 transfer)
133      odd  myrank: send    buffer ABCD to   myrank+1
134                   receive buffer EFGH from myrank+1
135                   compute op for EFGH
136                   send    result EFGH
137
138    Result: node:     0    2    4    6    8  10  11  12
139            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
140
141 Step 3)
142
143    The following algorithm uses only the nodes with
144
145    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
146
147 Step 4)
148
149    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
150    define OLDRANK(new) := (new < r   ? new*2 : new+r)
151
152    Result:
153        old ranks:    0    2    4    6    8  10  11  12
154        new ranks:    0    1    2    3    4   5   6   7
155            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
156
157 Step 5.1)
158
159    Split the buffer (ABCDEFGH) in the middle,
160    the lower half (ABCD) is computed on even (new) ranks,
161    the opper half (EFGH) is computed on odd  (new) ranks.
162
163    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
164              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
165                                                (i.e. buf_lng/2 exchange)
166    compute op in each node on its half:        (i.e.   count/2 ops)
167
168    Result: node 0: (a+b)+(c+d)   for  ABCD
169            node 1: (a+b)+(c+d)   for  EFGH
170            node 2: (e+f)+(g+h)   for  ABCD
171            node 3: (e+f)+(g+h)   for  EFGH
172            node 4: (i+j)+ k      for  ABCD
173            node 5: (i+j)+ k      for  EFGH
174            node 6:    l + m      for  ABCD
175            node 7:    l + m      for  EFGH
176
177 Step 5.2)
178
179    Same with double distance and oncemore the half of the buffer.
180                                                (i.e. buf_lng/4 exchange)
181                                                (i.e.   count/4 ops)
182
183    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
184            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
185            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
186            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
187            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
188            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
189            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
190            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
191
192 ...
193 Step 5.n)
194
195    Same with double distance and oncemore the half of the buffer.
196                                             (i.e. buf_lng/2**n exchange)
197                                             (i.e.   count/2**n ops)
198
199    Result:
200      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
201      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
202      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
203      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
204      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
205      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
206      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
207      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
208
209
210 For Colls::allreduce:
211 ------------------
212
213 Step 6.1)
214
215    Exchange on the last distance (2*n) the last result
216                                             (i.e. buf_lng/2**n exchange)
217
218 Step 6.2)
219
220    Same with half the distance and double of the result
221                                         (i.e. buf_lng/2**(n-1) exchange)
222
223 ...
224 Step 6.n)
225
226    Same with distance 1 and double of the result (== half of the
227    original buffer)                            (i.e. buf_lng/2 exchange)
228
229    Result after 6.1     6.2         6.n
230
231     on node 0:   AB    ABCD    ABCDEFGH
232     on node 1:   EF    EFGH    ABCDEFGH
233     on node 2:   CD    ABCD    ABCDEFGH
234     on node 3:   GH    EFGH    ABCDEFGH
235     on node 4:   AB    ABCD    ABCDEFGH
236     on node 5:   EF    EFGH    ABCDEFGH
237     on node 6:   CD    ABCD    ABCDEFGH
238     on node 7:   GH    EFGH    ABCDEFGH
239
240 Step 7)
241
242    If r > 0
243      transfer the result from the even nodes with old rank < 2*r
244                            to the odd  nodes with old rank < 2*r
245                                                  (i.e. buf_lng transfer)
246    Result:
247      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
248      for ABCDEFGH
249      on all nodes 0..12
250
251
252 For Colls::reduce:
253 ---------------
254
255 Step 6.0)
256
257    If root node not in the list of the nodes with newranks
258    (see steps 3+4) then
259
260      send last result from node 0 to the root    (buf_lng/2**n transfer)
261      and replace the role of node 0 by the root node.
262
263 Step 6.1)
264
265    Send on the last distance (2**(n-1)) the last result
266    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
267    root's new rank  to the node with same '2**(n-1)' bit.
268                                             (i.e. buf_lng/2**n transfer)
269
270 Step 6.2)
271
272    Same with half the distance and double of the result
273    and bit '2**(n-2)'
274                                         (i.e. buf_lng/2**(n-1) transfer)
275
276 ...
277 Step 6.n)
278
279    Same with distance 1 and double of the result (== half of the
280    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
281
282    Example: roots old rank: 10
283             roots new rank:  5
284
285    Results:          6.1               6.2                 6.n
286                 action result     action result       action result
287
288     on node 0:  send A
289     on node 1:  send E
290     on node 2:  send C
291     on node 3:  send G
292     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
293     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
294     on node 6:  recv C => CD     send CD
295     on node 7:  recv G => GH     send GH
296
297 Benchmark results on CRAY T3E
298 -----------------------------
299
300    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
301    MPI:      /opt/ctl/mpt/1.1.0.3
302    datatype: MPI_DOUBLE
303    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
304    env: export MPI_BUFFER_MAX=4099
305    compiled with: cc -c -O3 -h restrict=f
306
307    old = binary tree protocol of the vendor
308    new = the new protocol and its implementation
309    mixed = 'new' is used if count > limit(datatype, communicator size)
310
311    REDUCE:
312                                           communicator size
313    measurement count prot. unit    2    3    4    6    8   12   16   24
314    --------------------------------------------------------------------
315    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
316                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
317    --------------------------------------------------------------------
318    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
319    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
320    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
321    --------------------------------------------------------------------
322    limit                doubles  896 1536  576  736  576  736  576  736
323    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
324                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
325    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
326
327                                           communicator size
328    measurement count prot. unit   32   48   64   96  128  192  256
329    ---------------------------------------------------------------
330    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
331                        old  us  78.6 87.2 90.1 99.7  108  119  120
332    ---------------------------------------------------------------
333    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
334    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
335    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
336    ---------------------------------------------------------------
337    limit                doubles  576  736  576  736  576  736  576  2)
338    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
339                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
340    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
341
342    ALLREDUCE:
343                                           communicator size
344    measurement count prot. unit    2    3    4    6    8   12   16   24
345    --------------------------------------------------------------------
346    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
347                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
348    --------------------------------------------------------------------
349    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
350    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
351    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
352    --------------------------------------------------------------------
353    limit                doubles  448 1280  512  512  512  512  512  512
354    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
355                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
356    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
357
358                                           communicator size
359    measurement count prot. unit   32   48   64   96  128  192  256
360    ---------------------------------------------------------------
361    latency         1 mixed  us  90.3  162  109  190
362                        old  us  92.7  152  104  179  122  225  135
363    ---------------------------------------------------------------
364    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
365    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
366    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
367    ---------------------------------------------------------------
368    limit                doubles  512  512  512  512  512  512  512
369    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
370                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
371    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
372
373    Footnotes:
374    1) This line shows that the overhead to decide which protocol
375       should be used can be ignored.
376    2) This line shows the limit for the count argument.
377       If count < limit then the vendor protocol is used,
378       otherwise the new protocol is used (see variable Ldb).
379    3) These lines show the bandwidth (=bufer length / execution time)
380       for both protocols.
381    4) This line shows that the limit is choosen well if the ratio is
382       between 0.95 (loosing 5% for buffer length near and >=limit)
383       and 1.10 (not gaining 10% for buffer length near and <limit).
384    5) This line shows that the new protocol is 2..7 times faster
385       for long counts.
386
387 */
388
389 #ifdef REDUCE_LIMITS
390
391 #ifdef USE_Irecv
392 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
393            { MPI_Request req;                                          \
394              req=Request::irecv(rb,rc,rd,source,rt,comm);                  \
395              Request::send(sb,sc,sd,dest,st,comm);                          \
396              Request::wait(&req,stat);                                      \
397            }
398 #else
399 #ifdef USE_Isend
400 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
401            { MPI_Request req;                                          \
402              req=mpi_mpi_isend(sb,sc,sd,dest,st,comm);                    \
403              Request::recv(rb,rc,rd,source,rt,comm,stat);                   \
404              Request::wait(&req,stat);                                      \
405            }
406 #else
407 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
408            Request::sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat)
409 #endif
410 #endif
411
412 typedef enum {MPIM_SHORT, MPIM_INT, MPIM_LONG, MPIM_UNSIGNED_SHORT,
413               MPIM_UNSIGNED, MPIM_UNSIGNED_LONG, MPIM_UNSIGNED_LONG_LONG, MPIM_FLOAT,
414               MPIM_DOUBLE, MPIM_BYTE} MPIM_Datatype;
415
416 typedef enum {MPIM_MAX, MPIM_MIN, MPIM_SUM, MPIM_PROD,
417               MPIM_LAND, MPIM_BAND, MPIM_LOR, MPIM_BOR,
418               MPIM_LXOR, MPIM_BXOR} MPIM_Op;
419 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE,TYPE)                   \
420 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
421 { int i;                                                               \
422   switch (op) {                                                        \
423    case MPIM_MAX :                                                     \
424       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
425    case MPIM_MIN :                                                     \
426       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
427    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
428    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
429    case MPIM_LAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] && b2[i]; break;  \
430    case MPIM_LOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] || b2[i]; break;  \
431    case MPIM_LXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] != b2[i]; break;  \
432    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
433    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
434    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
435    default: break;                                                     \
436   }                                                                    \
437 }
438
439 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE,TYPE)                          \
440 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op) \
441 { int i;                                                               \
442   switch (op) {                                                        \
443    case MPIM_MAX :                                                     \
444       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
445    case MPIM_MIN :                                                     \
446       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
447    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
448    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
449    default: break;                                                     \
450   }                                                                    \
451 }
452
453 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE,TYPE)                        \
454 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
455 { int i;                                                               \
456   switch (op) {                                                        \
457    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
458    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
459    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
460    default: break;                                                     \
461   }                                                                    \
462 }
463
464 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_short,  short)
465 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_int,    int)
466 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_long,   long)
467 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ushort, unsigned short)
468 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_uint,   unsigned int)
469 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulong,  unsigned long)
470 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulonglong,  unsigned long long)
471 MPI_I_DO_OP_FP(        MPI_I_do_op_float,  float)
472 MPI_I_DO_OP_FP(        MPI_I_do_op_double, double)
473 MPI_I_DO_OP_BYTE(      MPI_I_do_op_byte,   char)
474
475 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE,TYPE)                        \
476    MPI_I_do_op_TYPE ((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op); break;
477
478 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt,
479                  MPIM_Datatype datatype, MPIM_Op op)
480 {
481  switch (datatype) {
482   case MPIM_SHORT : MPI_I_DO_OP_CALL(MPI_I_do_op_short,  short)
483   case MPIM_INT   : MPI_I_DO_OP_CALL(MPI_I_do_op_int,    int)
484   case MPIM_LONG  : MPI_I_DO_OP_CALL(MPI_I_do_op_long,   long)
485   case MPIM_UNSIGNED_SHORT:
486                    MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
487   case MPIM_UNSIGNED:
488                    MPI_I_DO_OP_CALL(MPI_I_do_op_uint,   unsigned int)
489   case MPIM_UNSIGNED_LONG:
490                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulong,  unsigned long)
491   case MPIM_UNSIGNED_LONG_LONG:
492                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong,  unsigned long long)
493   case MPIM_FLOAT : MPI_I_DO_OP_CALL(MPI_I_do_op_float,  float)
494   case MPIM_DOUBLE: MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
495   case MPIM_BYTE  : MPI_I_DO_OP_CALL(MPI_I_do_op_byte,   char)
496  }
497 }
498
499 REDUCE_LIMITS
500 namespace simgrid{
501 namespace smpi{
502 static int MPI_I_anyReduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op, int root, MPI_Comm comm, int is_all)
503 {
504   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
505   int myrank, size, x_base, x_size, computed, idx;
506   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
507   int start_even[20], start_odd[20], count_even[20], count_odd[20];
508   MPI_Aint typelng;
509   MPI_Status status;
510   size_t scrlng;
511   int new_prot;
512   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
513
514   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
515   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
516   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
517   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
518   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
519   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
520   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
521   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
522   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
523   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
524   else
525    THROWF(arg_error,0, "reduce rab algorithm can't be used with this datatype ! ");
526
527   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
528   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
529   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
530   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
531   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
532   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
533   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
534   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
535   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
536   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
537
538   new_prot = 0;
539   MPI_Comm_size(comm, &size);
540   if (size > 1) /*otherwise no balancing_protocol*/
541   { int ss;
542     if      (size==2) ss=0;
543     else if (size==3) ss=1;
544     else { int s = size; while (!(s & 1)) s = s >> 1;
545            if (s==1) /* size == power of 2 */ ss = 2;
546            else      /* size != power of 2 */ ss = 3; }
547     switch(op) {
548      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
549      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
550      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
551       switch(datatype) {
552         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
553          new_prot = count >= Lsh[is_all][ss]; break;
554         case MPIM_INT:    case MPIM_UNSIGNED:
555          new_prot = count >= Lin[is_all][ss]; break;
556         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
557          new_prot = count >= Llg[is_all][ss]; break;
558      default:
559         break;
560     } default:
561         break;}
562     switch(op) {
563      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
564       switch(datatype) {
565         case MPIM_FLOAT:
566          new_prot = count >= Lfp[is_all][ss]; break;
567         case MPIM_DOUBLE:
568          new_prot = count >= Ldb[is_all][ss]; break;
569               default:
570         break;
571     } default:
572         break;}
573     switch(op) {
574      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
575       switch(datatype) {
576         case MPIM_BYTE:
577          new_prot = count >= Lby[is_all][ss]; break;
578               default:
579         break;
580     } default:
581         break;}
582 #   ifdef DEBUG
583     { char *ss_str[]={"two","three","power of 2","no power of 2"};
584       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
585              is_all, size, ss_str[ss], new_prot); fflush(stdout);
586     }
587 #   endif
588   }
589
590   if (new_prot)
591   {
592     sendbuf = (char*) Sendbuf;
593     recvbuf = (char*) Recvbuf;
594     MPI_Comm_rank(comm, &myrank);
595     MPI_Type_extent(mpi_datatype, &typelng);
596     scrlng  = typelng * count;
597 #ifdef NO_CACHE_OPTIMIZATION
598     scr1buf = static_cast<char*>(xbt_malloc(scrlng));
599     scr2buf = static_cast<char*>(xbt_malloc(scrlng));
600     scr3buf = static_cast<char*>(xbt_malloc(scrlng));
601 #else
602 #  ifdef SCR_LNG_OPTIM
603     scrlng = SCR_LNG_OPTIM(scrlng);
604 #  endif
605     scr2buf = static_cast<char*>(xbt_malloc(3*scrlng));   /* To test cache problems.     */
606     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
607     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
608                                   /* they are interchanged below.*/
609 #endif
610     computed = 0;
611     if (is_all) root = myrank; /* for correct recvbuf handling */
612
613   /*...step 1 */
614
615 #   ifdef DEBUG
616      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
617 #   endif
618     n = 0; x_size = 1;
619     while (2*x_size <= size) { n++; x_size = x_size * 2; }
620     /* x_sixe == 2**n */
621     r = size - x_size;
622
623   /*...step 2 */
624
625 #   ifdef DEBUG
626     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
627 #   endif
628     if (myrank < 2*r)
629     {
630       if ((myrank % 2) == 0 /*even*/)
631       {
632         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
633                        count - count/2, mpi_datatype, myrank+1, 1220,
634                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
635                        comm, &status);
636         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
637                     count/2, datatype, op);
638         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
639                  mpi_datatype, myrank+1, 1223, comm, &status);
640         computed = 1;
641 #       ifdef DEBUG
642         { int i; printf("[%2d] after step 2: val=",
643                         myrank);
644           for (i=0; i<count; i++)
645            printf(" %5.0lf",((double*)scr1buf)[i] );
646           printf("\n"); fflush(stdout);
647         }
648 #       endif
649       }
650       else /*odd*/
651       {
652         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
653                        scr2buf + (count/2)*typelng,
654                        count - count/2, mpi_datatype, myrank-1, 1220,
655                        comm, &status);
656         MPI_I_do_op(scr2buf + (count/2)*typelng,
657                     sendbuf + (count/2)*typelng,
658                     scr1buf + (count/2)*typelng,
659                     count - count/2, datatype, op);
660         Request::send(scr1buf + (count/2)*typelng, count - count/2,
661                  mpi_datatype, myrank-1, 1223, comm);
662       }
663     }
664
665   /*...step 3+4 */
666
667 #   ifdef DEBUG
668      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
669 #   endif
670     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
671          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
672     else mynewrank = -1;
673
674     if (mynewrank >= 0)
675     { /* begin -- only for nodes with new rank */
676
677 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
678
679   /*...step 5 */
680
681       x_start = 0;
682       x_count = count;
683       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
684       {
685         start_even[idx] = x_start;
686         count_even[idx] = x_count / 2;
687         start_odd [idx] = x_start + count_even[idx];
688         count_odd [idx] = x_count - count_even[idx];
689         if (((mynewrank/x_base) % 2) == 0 /*even*/)
690         {
691 #         ifdef DEBUG
692             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
693                    myrank,mynewrank,idx+1,computed); fflush(stdout);
694 #         endif
695           x_start = start_even[idx];
696           x_count = count_even[idx];
697           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
698                          + start_odd[idx]*typelng, count_odd[idx],
699                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
700                          scr2buf + x_start*typelng, x_count,
701                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
702                          comm, &status);
703           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
704                       scr2buf                    + x_start*typelng,
705                       ((root==myrank) && (idx==(n-1))
706                         ? recvbuf + x_start*typelng
707                         : scr3buf + x_start*typelng),
708                       x_count, datatype, op);
709         }
710         else /*odd*/
711         {
712 #         ifdef DEBUG
713             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
714                    myrank,mynewrank,idx+1,computed); fflush(stdout);
715 #         endif
716           x_start = start_odd[idx];
717           x_count = count_odd[idx];
718           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
719                          +start_even[idx]*typelng, count_even[idx],
720                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
721                          scr2buf + x_start*typelng, x_count,
722                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
723                          comm, &status);
724           MPI_I_do_op(scr2buf                    + x_start*typelng,
725                       (computed?scr1buf:sendbuf) + x_start*typelng,
726                       ((root==myrank) && (idx==(n-1))
727                         ? recvbuf + x_start*typelng
728                         : scr3buf + x_start*typelng),
729                       x_count, datatype, op);
730         }
731         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
732         computed = 1;
733 #       ifdef DEBUG
734         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
735                         myrank,mynewrank,idx+1,x_start,x_count);
736           for (i=0; i<x_count; i++)
737            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
738                              ? recvbuf + x_start*typelng
739                              : scr1buf + x_start*typelng))[i] );
740           printf("\n"); fflush(stdout);
741         }
742 #       endif
743       } /*for*/
744
745 #     undef OLDRANK
746
747
748     } /* end -- only for nodes with new rank */
749
750     if (is_all)
751     {
752       /*...steps 6.1 to 6.n */
753
754       if (mynewrank >= 0)
755       { /* begin -- only for nodes with new rank */
756
757 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
758
759         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
760         {
761 #         ifdef DEBUG
762             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
763 #         endif
764           if (((mynewrank/x_base) % 2) == 0 /*even*/)
765           {
766             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
767                                      count_even[idx],
768                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
769                            recvbuf + start_odd[idx]*typelng,
770                                      count_odd[idx],
771                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
772                            comm, &status);
773 #           ifdef DEBUG
774               x_start = start_odd[idx];
775               x_count = count_odd[idx];
776 #           endif
777           }
778           else /*odd*/
779           {
780             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
781                                      count_odd[idx],
782                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
783                            recvbuf + start_even[idx]*typelng,
784                                      count_even[idx],
785                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
786                            comm, &status);
787 #           ifdef DEBUG
788               x_start = start_even[idx];
789               x_count = count_even[idx];
790 #           endif
791           }
792 #         ifdef DEBUG
793           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
794                           myrank,mynewrank,n-idx,x_start,x_count);
795             for (i=0; i<x_count; i++)
796              printf(" %5.0lf",((double*)(root==myrank
797                                ? recvbuf + x_start*typelng
798                                : scr1buf + x_start*typelng))[i] );
799             printf("\n"); fflush(stdout);
800           }
801 #         endif
802         } /*for*/
803
804 #       undef OLDRANK
805
806       } /* end -- only for nodes with new rank */
807
808       /*...step 7 */
809
810       if (myrank < 2*r)
811       {
812 #       ifdef DEBUG
813           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
814 #       endif
815         if (myrank%2 == 0 /*even*/)
816           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
817         else /*odd*/
818           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
819       }
820
821     }
822     else /* not is_all, i.e. Reduce */
823     {
824
825     /*...step 6.0 */
826
827       if ((root < 2*r) && (root%2 == 1))
828       {
829 #       ifdef DEBUG
830           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
831 #       endif
832         if (myrank == 0) /* then mynewrank==0, x_start==0
833                                  x_count == count/x_size  */
834         {
835           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
836           mynewrank = -1;
837         }
838
839         if (myrank == root)
840         {
841           mynewrank = 0;
842           x_start = 0;
843           x_count = count;
844           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
845           {
846             start_even[idx] = x_start;
847             count_even[idx] = x_count / 2;
848             start_odd [idx] = x_start + count_even[idx];
849             count_odd [idx] = x_count - count_even[idx];
850             /* code for always even in each bit of mynewrank: */
851             x_start = start_even[idx];
852             x_count = count_even[idx];
853           }
854           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
855         }
856         newroot = 0;
857       }
858       else
859       {
860         newroot = (root < 2*r ? root/2 : root-r);
861       }
862
863     /*...steps 6.1 to 6.n */
864
865       if (mynewrank >= 0)
866       { /* begin -- only for nodes with new rank */
867
868 #       define OLDRANK(new) ((new)==newroot ? root                     \
869                              : ((new)<r ? (new)*2 : (new)+r) )
870
871         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
872         {
873 #         ifdef DEBUG
874             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
875 #         endif
876           if ((mynewrank & x_base) != (newroot & x_base))
877           {
878             if (((mynewrank/x_base) % 2) == 0 /*even*/)
879             { x_start = start_even[idx]; x_count = count_even[idx];
880               partner = mynewrank+x_base; }
881             else
882             { x_start = start_odd[idx]; x_count = count_odd[idx];
883               partner = mynewrank-x_base; }
884             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
885                      OLDRANK(partner), 1244, comm);
886           }
887           else /*odd*/
888           {
889             if (((mynewrank/x_base) % 2) == 0 /*even*/)
890             { x_start = start_odd[idx]; x_count = count_odd[idx];
891               partner = mynewrank+x_base; }
892             else
893             { x_start = start_even[idx]; x_count = count_even[idx];
894               partner = mynewrank-x_base; }
895             Request::recv((myrank==root ? recvbuf : scr1buf)
896                      + x_start*typelng, x_count, mpi_datatype,
897                      OLDRANK(partner), 1244, comm, &status);
898 #           ifdef DEBUG
899             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
900                             myrank,mynewrank,n-idx,x_start,x_count);
901               for (i=0; i<x_count; i++)
902                printf(" %5.0lf",((double*)(root==myrank
903                                  ? recvbuf + x_start*typelng
904                                  : scr1buf + x_start*typelng))[i] );
905               printf("\n"); fflush(stdout);
906             }
907 #           endif
908           }
909         } /*for*/
910
911 #       undef OLDRANK
912
913       } /* end -- only for nodes with new rank */
914     }
915
916 #   ifdef NO_CACHE_TESTING
917      xbt_free(scr1buf); xbt_free(scr2buf); xbt_free(scr3buf);
918 #   else
919      xbt_free(scr2buf); /* scr1buf and scr3buf are part of scr2buf */
920 #   endif
921     return(MPI_SUCCESS);
922   } /* new_prot */
923   /*otherwise:*/
924   if (is_all)
925    return( Colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm) );
926   else
927    return( Colls::reduce(Sendbuf,Recvbuf, count,mpi_datatype,mpi_op, root, comm) );
928 }
929 #endif /*REDUCE_LIMITS*/
930
931
932 int Coll_reduce_rab::reduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
933 {
934   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, 0) );
935 }
936
937 int Coll_allreduce_rab::allreduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
938 {
939   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op,   -1, comm, 1) );
940 }
941 }
942 }