Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix most of spelling mistakes in src/
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* extracted from mpig_myreduce.c with
2    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
3
4 /* Copyright: Rolf Rabenseifner, 1997
5  *            Computing Center University of Stuttgart
6  *            rabenseifner@rus.uni-stuttgart.de
7  *
8  * The usage of this software is free,
9  * but this header must not be removed.
10  */
11
12 #include "../colls_private.hpp"
13 #include <cstdio>
14 #include <cstdlib>
15
16 #define REDUCE_NEW_ALWAYS 1
17
18 #ifdef CRAY
19 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
20                                  /* =  16 + multiple of 32 doubles*/
21
22 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
23        /*  routine =    reduce                allreduce             */ \
24        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
25  static int Lsh[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
26  static int Lin[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
27  static int Llg[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
28  static int Lfp[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
29  static int Ldb[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
30  static int Lby[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};
31 #endif
32
33 #ifdef REDUCE_NEW_ALWAYS
34 # undef  REDUCE_LIMITS
35 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
36        /*  routine =    reduce                allreduce             */ \
37        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
38  static int Lsh[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
39  static int Lin[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
40  static int Llg[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
41  static int Lfp[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
42  static int Ldb[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
43  static int Lby[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};
44 #endif
45
46 /* Fast reduce and allreduce algorithm for longer buffers and predefined
47    operations.
48
49    This algorithm is explained with the example of 13 nodes.
50    The nodes are numbered   0, 1, 2, ... 12.
51    The sendbuf content is   a, b, c, ...  m.
52    The buffer array is notated with ABCDEFGH, this means that
53    e.g. 'C' is the third 1/8 of the buffer.
54
55    The algorithm computes
56
57    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
58
59    This is equivalent to the mostly used binary tree algorithm
60    (e.g. used in mpich).
61
62    size := number of nodes in the communicator.
63    2**n := the power of 2 that is next smaller or equal to the size.
64    r    := size - 2**n
65
66    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
67
68    The algorithm needs for the execution of one colls::reduce
69
70    - for r==0
71      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
72
73    - for r>0
74      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
75                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
76
77    with L1 = latency of a message transfer e.g. by send+recv
78         L2 = latency of a message exchange e.g. by sendrecv
79         T1 = additional time to transfer 1 byte (= 1/bandwidth)
80         T2 = additional time to exchange 1 byte in both directions.
81         O  = time for one operation
82         d  = size of the datatype
83
84         On a MPP system it is expected that T1==T2.
85
86    In Comparison with the binary tree algorithm that needs:
87
88    - for r==0
89      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
90
91    - for r>0
92      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
93
94    the new algorithm is faster if (assuming T1=T2)
95
96     for n>2:
97      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
98      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
99
100     for size = 5, 6, and 7:
101                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
102
103     for size = 4:
104                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
105     for size = 2 and 3:
106                 buf_lng  >  L2 / [  0.5 * O/d ]
107
108    and for O/d >> T1 we can summarize:
109
110     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
111
112    Example L1 = L2 = 50 us,
113            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
114            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
115            for double,        i.e. d  =   8 byte/FLOP
116
117            ==> the new algorithm is faster for about buf_lng>8000 bytes
118                i.e count > 1000 doubles !
119 Step 1)
120
121    compute n and r
122
123 Step 2)
124
125    if  myrank < 2*r
126
127      split the buffer into ABCD and EFGH
128
129      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
130                   receive buffer ABCD from myrank+1
131                   compute op for ABCD                 (count/2 ops)
132                   receive result EFGH               (buf_lng/2 transfer)
133      odd  myrank: send    buffer ABCD to   myrank+1
134                   receive buffer EFGH from myrank+1
135                   compute op for EFGH
136                   send    result EFGH
137
138    Result: node:     0    2    4    6    8  10  11  12
139            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
140
141 Step 3)
142
143    The following algorithm uses only the nodes with
144
145    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
146
147 Step 4)
148
149    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
150    define OLDRANK(new) := (new < r   ? new*2 : new+r)
151
152    Result:
153        old ranks:    0    2    4    6    8  10  11  12
154        new ranks:    0    1    2    3    4   5   6   7
155            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
156
157 Step 5.1)
158
159    Split the buffer (ABCDEFGH) in the middle,
160    the lower half (ABCD) is computed on even (new) ranks,
161    the opper half (EFGH) is computed on odd  (new) ranks.
162
163    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
164              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
165                                                (i.e. buf_lng/2 exchange)
166    compute op in each node on its half:        (i.e.   count/2 ops)
167
168    Result: node 0: (a+b)+(c+d)   for  ABCD
169            node 1: (a+b)+(c+d)   for  EFGH
170            node 2: (e+f)+(g+h)   for  ABCD
171            node 3: (e+f)+(g+h)   for  EFGH
172            node 4: (i+j)+ k      for  ABCD
173            node 5: (i+j)+ k      for  EFGH
174            node 6:    l + m      for  ABCD
175            node 7:    l + m      for  EFGH
176
177 Step 5.2)
178
179    Same with double distance and oncemore the half of the buffer.
180                                                (i.e. buf_lng/4 exchange)
181                                                (i.e.   count/4 ops)
182
183    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
184            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
185            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
186            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
187            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
188            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
189            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
190            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
191
192 ...
193 Step 5.n)
194
195    Same with double distance and oncemore the half of the buffer.
196                                             (i.e. buf_lng/2**n exchange)
197                                             (i.e.   count/2**n ops)
198
199    Result:
200      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
201      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
202      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
203      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
204      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
205      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
206      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
207      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
208
209
210 For colls::allreduce:
211 ------------------
212
213 Step 6.1)
214
215    Exchange on the last distance (2*n) the last result
216                                             (i.e. buf_lng/2**n exchange)
217
218 Step 6.2)
219
220    Same with half the distance and double of the result
221                                         (i.e. buf_lng/2**(n-1) exchange)
222
223 ...
224 Step 6.n)
225
226    Same with distance 1 and double of the result (== half of the
227    original buffer)                            (i.e. buf_lng/2 exchange)
228
229    Result after 6.1     6.2         6.n
230
231     on node 0:   AB    ABCD    ABCDEFGH
232     on node 1:   EF    EFGH    ABCDEFGH
233     on node 2:   CD    ABCD    ABCDEFGH
234     on node 3:   GH    EFGH    ABCDEFGH
235     on node 4:   AB    ABCD    ABCDEFGH
236     on node 5:   EF    EFGH    ABCDEFGH
237     on node 6:   CD    ABCD    ABCDEFGH
238     on node 7:   GH    EFGH    ABCDEFGH
239
240 Step 7)
241
242    If r > 0
243      transfer the result from the even nodes with old rank < 2*r
244                            to the odd  nodes with old rank < 2*r
245                                                  (i.e. buf_lng transfer)
246    Result:
247      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
248      for ABCDEFGH
249      on all nodes 0..12
250
251
252 For colls::reduce:
253 ---------------
254
255 Step 6.0)
256
257    If root node not in the list of the nodes with newranks
258    (see steps 3+4) then
259
260      send last result from node 0 to the root    (buf_lng/2**n transfer)
261      and replace the role of node 0 by the root node.
262
263 Step 6.1)
264
265    Send on the last distance (2**(n-1)) the last result
266    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
267    root's new rank  to the node with same '2**(n-1)' bit.
268                                             (i.e. buf_lng/2**n transfer)
269
270 Step 6.2)
271
272    Same with half the distance and double of the result
273    and bit '2**(n-2)'
274                                         (i.e. buf_lng/2**(n-1) transfer)
275
276 ...
277 Step 6.n)
278
279    Same with distance 1 and double of the result (== half of the
280    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
281
282    Example: roots old rank: 10
283             roots new rank:  5
284
285    Results:          6.1               6.2                 6.n
286                 action result     action result       action result
287
288     on node 0:  send A
289     on node 1:  send E
290     on node 2:  send C
291     on node 3:  send G
292     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
293     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
294     on node 6:  recv C => CD     send CD
295     on node 7:  recv G => GH     send GH
296
297 Benchmark results on CRAY T3E
298 -----------------------------
299
300    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
301    MPI:      /opt/ctl/mpt/1.1.0.3
302    datatype: MPI_DOUBLE
303    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
304    env: export MPI_BUFFER_MAX=4099
305    compiled with: cc -c -O3 -h restrict=f
306
307    old = binary tree protocol of the vendor
308    new = the new protocol and its implementation
309    mixed = 'new' is used if count > limit(datatype, communicator size)
310
311    REDUCE:
312                                           communicator size
313    measurement count prot. unit    2    3    4    6    8   12   16   24
314    --------------------------------------------------------------------
315    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
316                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
317    --------------------------------------------------------------------
318    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
319    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
320    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
321    --------------------------------------------------------------------
322    limit                doubles  896 1536  576  736  576  736  576  736
323    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
324                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
325    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
326
327                                           communicator size
328    measurement count prot. unit   32   48   64   96  128  192  256
329    ---------------------------------------------------------------
330    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
331                        old  us  78.6 87.2 90.1 99.7  108  119  120
332    ---------------------------------------------------------------
333    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
334    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
335    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
336    ---------------------------------------------------------------
337    limit                doubles  576  736  576  736  576  736  576  2)
338    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
339                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
340    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
341
342    ALLREDUCE:
343                                           communicator size
344    measurement count prot. unit    2    3    4    6    8   12   16   24
345    --------------------------------------------------------------------
346    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
347                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
348    --------------------------------------------------------------------
349    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
350    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
351    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
352    --------------------------------------------------------------------
353    limit                doubles  448 1280  512  512  512  512  512  512
354    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
355                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
356    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
357
358                                           communicator size
359    measurement count prot. unit   32   48   64   96  128  192  256
360    ---------------------------------------------------------------
361    latency         1 mixed  us  90.3  162  109  190
362                        old  us  92.7  152  104  179  122  225  135
363    ---------------------------------------------------------------
364    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
365    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
366    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
367    ---------------------------------------------------------------
368    limit                doubles  512  512  512  512  512  512  512
369    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
370                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
371    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
372
373    Footnotes:
374    1) This line shows that the overhead to decide which protocol
375       should be used can be ignored.
376    2) This line shows the limit for the count argument.
377       If count < limit then the vendor protocol is used,
378       otherwise the new protocol is used (see variable Ldb).
379    3) These lines show the bandwidth (= buffer length / execution time)
380       for both protocols.
381    4) This line shows that the limit is chosen well if the ratio is
382       between 0.95 (losing 5% for buffer length near and >=limit)
383       and 1.10 (not gaining 10% for buffer length near and <limit).
384    5) This line shows that the new protocol is 2..7 times faster
385       for long counts.
386
387 */
388
389 #ifdef REDUCE_LIMITS
390
391 #ifdef USE_Irecv
392 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
393            { MPI_Request req;                                          \
394              req=Request::irecv(rb,rc,rd,source,rt,comm);                  \
395              Request::send(sb,sc,sd,dest,st,comm);                          \
396              Request::wait(&req,stat);                                      \
397            }
398 #else
399 #ifdef USE_Isend
400 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
401            { MPI_Request req;                                          \
402              req=mpi_mpi_isend(sb,sc,sd,dest,st,comm);                    \
403              Request::recv(rb,rc,rd,source,rt,comm,stat);                   \
404              Request::wait(&req,stat);                                      \
405            }
406 #else
407 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
408            Request::sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat)
409 #endif
410 #endif
411
412 enum MPIM_Datatype {
413   MPIM_SHORT,
414   MPIM_INT,
415   MPIM_LONG,
416   MPIM_UNSIGNED_SHORT,
417   MPIM_UNSIGNED,
418   MPIM_UNSIGNED_LONG,
419   MPIM_UNSIGNED_LONG_LONG,
420   MPIM_FLOAT,
421   MPIM_DOUBLE,
422   MPIM_BYTE
423 };
424
425 enum MPIM_Op {
426   MPIM_MAX,
427   MPIM_MIN,
428   MPIM_SUM,
429   MPIM_PROD,
430   MPIM_LAND,
431   MPIM_BAND,
432   MPIM_LOR,
433   MPIM_BOR,
434   MPIM_LXOR,
435   MPIM_BXOR
436 };
437 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE,TYPE)                   \
438 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
439 { int i;                                                               \
440   switch (op) {                                                        \
441    case MPIM_MAX :                                                     \
442       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
443    case MPIM_MIN :                                                     \
444       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
445    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
446    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
447    case MPIM_LAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] && b2[i]; break;  \
448    case MPIM_LOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] || b2[i]; break;  \
449    case MPIM_LXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] != b2[i]; break;  \
450    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
451    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
452    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
453    default: break;                                                     \
454   }                                                                    \
455 }
456
457 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE,TYPE)                          \
458 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op) \
459 { int i;                                                               \
460   switch (op) {                                                        \
461    case MPIM_MAX :                                                     \
462       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
463    case MPIM_MIN :                                                     \
464       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
465    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
466    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
467    default: break;                                                     \
468   }                                                                    \
469 }
470
471 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE,TYPE)                        \
472 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
473 { int i;                                                               \
474   switch (op) {                                                        \
475    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
476    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
477    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
478    default: break;                                                     \
479   }                                                                    \
480 }
481
482 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_short,  short)
483 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_int,    int)
484 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_long,   long)
485 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ushort, unsigned short)
486 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_uint,   unsigned int)
487 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulong,  unsigned long)
488 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulonglong,  unsigned long long)
489 MPI_I_DO_OP_FP(        MPI_I_do_op_float,  float)
490 MPI_I_DO_OP_FP(        MPI_I_do_op_double, double)
491 MPI_I_DO_OP_BYTE(      MPI_I_do_op_byte,   char)
492
493 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE,TYPE)                        \
494    MPI_I_do_op_TYPE ((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op); break;
495
496 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt,
497                  MPIM_Datatype datatype, MPIM_Op op)
498 {
499  switch (datatype) {
500   case MPIM_SHORT : MPI_I_DO_OP_CALL(MPI_I_do_op_short,  short)
501   case MPIM_INT   : MPI_I_DO_OP_CALL(MPI_I_do_op_int,    int)
502   case MPIM_LONG  : MPI_I_DO_OP_CALL(MPI_I_do_op_long,   long)
503   case MPIM_UNSIGNED_SHORT:
504                    MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
505   case MPIM_UNSIGNED:
506                    MPI_I_DO_OP_CALL(MPI_I_do_op_uint,   unsigned int)
507   case MPIM_UNSIGNED_LONG:
508                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulong,  unsigned long)
509   case MPIM_UNSIGNED_LONG_LONG:
510                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong,  unsigned long long)
511   case MPIM_FLOAT : MPI_I_DO_OP_CALL(MPI_I_do_op_float,  float)
512   case MPIM_DOUBLE: MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
513   case MPIM_BYTE  : MPI_I_DO_OP_CALL(MPI_I_do_op_byte,   char)
514  }
515 }
516
517 REDUCE_LIMITS
518 namespace simgrid{
519 namespace smpi{
520 static int MPI_I_anyReduce(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op,
521                            int root, MPI_Comm comm, bool is_all)
522 {
523   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
524   int myrank, size, x_base, x_size, computed, idx;
525   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
526   int start_even[20], start_odd[20], count_even[20], count_odd[20];
527   MPI_Aint typelng;
528   MPI_Status status;
529   size_t scrlng;
530   int new_prot;
531   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
532
533   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
534   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
535   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
536   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
537   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
538   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
539   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
540   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
541   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
542   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
543   else
544     throw std::invalid_argument("reduce rab algorithm can't be used with this datatype!");
545
546   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
547   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
548   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
549   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
550   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
551   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
552   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
553   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
554   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
555   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
556
557   new_prot = 0;
558   MPI_Comm_size(comm, &size);
559   if (size > 1) /*otherwise no balancing_protocol*/
560   { int ss;
561     if      (size==2) ss=0;
562     else if (size==3) ss=1;
563     else { int s = size; while (!(s & 1)) s = s >> 1;
564            if (s==1) /* size == power of 2 */ ss = 2;
565            else      /* size != power of 2 */ ss = 3; }
566     switch(op) {
567      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
568      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
569      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
570       switch(datatype) {
571         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
572          new_prot = count >= Lsh[is_all][ss]; break;
573         case MPIM_INT:    case MPIM_UNSIGNED:
574          new_prot = count >= Lin[is_all][ss]; break;
575         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
576          new_prot = count >= Llg[is_all][ss]; break;
577      default:
578         break;
579     } default:
580         break;}
581     switch(op) {
582      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
583       switch(datatype) {
584         case MPIM_FLOAT:
585          new_prot = count >= Lfp[is_all][ss]; break;
586         case MPIM_DOUBLE:
587          new_prot = count >= Ldb[is_all][ss]; break;
588               default:
589         break;
590     } default:
591         break;}
592     switch(op) {
593      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
594       switch(datatype) {
595         case MPIM_BYTE:
596          new_prot = count >= Lby[is_all][ss]; break;
597               default:
598         break;
599     } default:
600         break;}
601 #   ifdef DEBUG
602     { char *ss_str[]={"two","three","power of 2","no power of 2"};
603       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
604              is_all, size, ss_str[ss], new_prot); fflush(stdout);
605     }
606 #   endif
607   }
608
609   if (new_prot)
610   {
611     sendbuf = (char*) Sendbuf;
612     recvbuf = (char*) Recvbuf;
613     MPI_Comm_rank(comm, &myrank);
614     MPI_Type_extent(mpi_datatype, &typelng);
615     scrlng  = typelng * count;
616 #ifdef NO_CACHE_OPTIMIZATION
617     scr1buf = new char[scrlng];
618     scr2buf = new char[scrlng];
619     scr3buf = new char[scrlng];
620 #else
621 #  ifdef SCR_LNG_OPTIM
622     scrlng = SCR_LNG_OPTIM(scrlng);
623 #  endif
624     scr2buf = new char[3 * scrlng]; /* To test cache problems.     */
625     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
626     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
627                                   /* they are interchanged below.*/
628 #endif
629     computed = 0;
630     if (is_all) root = myrank; /* for correct recvbuf handling */
631
632   /*...step 1 */
633
634 #   ifdef DEBUG
635      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
636 #   endif
637     n = 0; x_size = 1;
638     while (2*x_size <= size) { n++; x_size = x_size * 2; }
639     /* x_size == 2**n */
640     r = size - x_size;
641
642   /*...step 2 */
643
644 #   ifdef DEBUG
645     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
646 #   endif
647     if (myrank < 2*r)
648     {
649       if ((myrank % 2) == 0 /*even*/)
650       {
651         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
652                        count - count/2, mpi_datatype, myrank+1, 1220,
653                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
654                        comm, &status);
655         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
656                     count/2, datatype, op);
657         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
658                  mpi_datatype, myrank+1, 1223, comm, &status);
659         computed = 1;
660 #       ifdef DEBUG
661         { int i; printf("[%2d] after step 2: val=",
662                         myrank);
663           for (i=0; i<count; i++)
664            printf(" %5.0lf",((double*)scr1buf)[i] );
665           printf("\n"); fflush(stdout);
666         }
667 #       endif
668       }
669       else /*odd*/
670       {
671         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
672                        scr2buf + (count/2)*typelng,
673                        count - count/2, mpi_datatype, myrank-1, 1220,
674                        comm, &status);
675         MPI_I_do_op(scr2buf + (count/2)*typelng,
676                     sendbuf + (count/2)*typelng,
677                     scr1buf + (count/2)*typelng,
678                     count - count/2, datatype, op);
679         Request::send(scr1buf + (count/2)*typelng, count - count/2,
680                  mpi_datatype, myrank-1, 1223, comm);
681       }
682     }
683
684   /*...step 3+4 */
685
686 #   ifdef DEBUG
687      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
688 #   endif
689     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
690          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
691     else mynewrank = -1;
692
693     if (mynewrank >= 0)
694     { /* begin -- only for nodes with new rank */
695
696 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
697
698   /*...step 5 */
699
700       x_start = 0;
701       x_count = count;
702       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
703       {
704         start_even[idx] = x_start;
705         count_even[idx] = x_count / 2;
706         start_odd [idx] = x_start + count_even[idx];
707         count_odd [idx] = x_count - count_even[idx];
708         if (((mynewrank/x_base) % 2) == 0 /*even*/)
709         {
710 #         ifdef DEBUG
711             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
712                    myrank,mynewrank,idx+1,computed); fflush(stdout);
713 #         endif
714           x_start = start_even[idx];
715           x_count = count_even[idx];
716           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
717                          + start_odd[idx]*typelng, count_odd[idx],
718                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
719                          scr2buf + x_start*typelng, x_count,
720                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
721                          comm, &status);
722           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
723                       scr2buf                    + x_start*typelng,
724                       ((root==myrank) && (idx==(n-1))
725                         ? recvbuf + x_start*typelng
726                         : scr3buf + x_start*typelng),
727                       x_count, datatype, op);
728         }
729         else /*odd*/
730         {
731 #         ifdef DEBUG
732             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
733                    myrank,mynewrank,idx+1,computed); fflush(stdout);
734 #         endif
735           x_start = start_odd[idx];
736           x_count = count_odd[idx];
737           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
738                          +start_even[idx]*typelng, count_even[idx],
739                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
740                          scr2buf + x_start*typelng, x_count,
741                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
742                          comm, &status);
743           MPI_I_do_op(scr2buf                    + x_start*typelng,
744                       (computed?scr1buf:sendbuf) + x_start*typelng,
745                       ((root==myrank) && (idx==(n-1))
746                         ? recvbuf + x_start*typelng
747                         : scr3buf + x_start*typelng),
748                       x_count, datatype, op);
749         }
750         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
751         computed = 1;
752 #       ifdef DEBUG
753         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
754                         myrank,mynewrank,idx+1,x_start,x_count);
755           for (i=0; i<x_count; i++)
756            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
757                              ? recvbuf + x_start*typelng
758                              : scr1buf + x_start*typelng))[i] );
759           printf("\n"); fflush(stdout);
760         }
761 #       endif
762       } /*for*/
763
764 #     undef OLDRANK
765
766
767     } /* end -- only for nodes with new rank */
768
769     if (is_all)
770     {
771       /*...steps 6.1 to 6.n */
772
773       if (mynewrank >= 0)
774       { /* begin -- only for nodes with new rank */
775
776 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
777
778         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
779         {
780 #         ifdef DEBUG
781             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
782 #         endif
783           if (((mynewrank/x_base) % 2) == 0 /*even*/)
784           {
785             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
786                                      count_even[idx],
787                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
788                            recvbuf + start_odd[idx]*typelng,
789                                      count_odd[idx],
790                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
791                            comm, &status);
792 #           ifdef DEBUG
793               x_start = start_odd[idx];
794               x_count = count_odd[idx];
795 #           endif
796           }
797           else /*odd*/
798           {
799             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
800                                      count_odd[idx],
801                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
802                            recvbuf + start_even[idx]*typelng,
803                                      count_even[idx],
804                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
805                            comm, &status);
806 #           ifdef DEBUG
807               x_start = start_even[idx];
808               x_count = count_even[idx];
809 #           endif
810           }
811 #         ifdef DEBUG
812           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
813                           myrank,mynewrank,n-idx,x_start,x_count);
814             for (i=0; i<x_count; i++)
815              printf(" %5.0lf",((double*)(root==myrank
816                                ? recvbuf + x_start*typelng
817                                : scr1buf + x_start*typelng))[i] );
818             printf("\n"); fflush(stdout);
819           }
820 #         endif
821         } /*for*/
822
823 #       undef OLDRANK
824
825       } /* end -- only for nodes with new rank */
826
827       /*...step 7 */
828
829       if (myrank < 2*r)
830       {
831 #       ifdef DEBUG
832           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
833 #       endif
834         if (myrank%2 == 0 /*even*/)
835           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
836         else /*odd*/
837           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
838       }
839
840     }
841     else /* not is_all, i.e. Reduce */
842     {
843
844     /*...step 6.0 */
845
846       if ((root < 2*r) && (root%2 == 1))
847       {
848 #       ifdef DEBUG
849           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
850 #       endif
851         if (myrank == 0) /* then mynewrank==0, x_start==0
852                                  x_count == count/x_size  */
853         {
854           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
855           mynewrank = -1;
856         }
857
858         if (myrank == root)
859         {
860           mynewrank = 0;
861           x_start = 0;
862           x_count = count;
863           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
864           {
865             start_even[idx] = x_start;
866             count_even[idx] = x_count / 2;
867             start_odd [idx] = x_start + count_even[idx];
868             count_odd [idx] = x_count - count_even[idx];
869             /* code for always even in each bit of mynewrank: */
870             x_start = start_even[idx];
871             x_count = count_even[idx];
872           }
873           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
874         }
875         newroot = 0;
876       }
877       else
878       {
879         newroot = (root < 2*r ? root/2 : root-r);
880       }
881
882     /*...steps 6.1 to 6.n */
883
884       if (mynewrank >= 0)
885       { /* begin -- only for nodes with new rank */
886
887 #       define OLDRANK(new) ((new)==newroot ? root                     \
888                              : ((new)<r ? (new)*2 : (new)+r) )
889
890         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
891         {
892 #         ifdef DEBUG
893             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
894 #         endif
895           if ((mynewrank & x_base) != (newroot & x_base))
896           {
897             if (((mynewrank/x_base) % 2) == 0 /*even*/)
898             { x_start = start_even[idx]; x_count = count_even[idx];
899               partner = mynewrank+x_base; }
900             else
901             { x_start = start_odd[idx]; x_count = count_odd[idx];
902               partner = mynewrank-x_base; }
903             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
904                      OLDRANK(partner), 1244, comm);
905           }
906           else /*odd*/
907           {
908             if (((mynewrank/x_base) % 2) == 0 /*even*/)
909             { x_start = start_odd[idx]; x_count = count_odd[idx];
910               partner = mynewrank+x_base; }
911             else
912             { x_start = start_even[idx]; x_count = count_even[idx];
913               partner = mynewrank-x_base; }
914             Request::recv((myrank==root ? recvbuf : scr1buf)
915                      + x_start*typelng, x_count, mpi_datatype,
916                      OLDRANK(partner), 1244, comm, &status);
917 #           ifdef DEBUG
918             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
919                             myrank,mynewrank,n-idx,x_start,x_count);
920               for (i=0; i<x_count; i++)
921                printf(" %5.0lf",((double*)(root==myrank
922                                  ? recvbuf + x_start*typelng
923                                  : scr1buf + x_start*typelng))[i] );
924               printf("\n"); fflush(stdout);
925             }
926 #           endif
927           }
928         } /*for*/
929
930 #       undef OLDRANK
931
932       } /* end -- only for nodes with new rank */
933     }
934
935 #   ifdef NO_CACHE_TESTING
936     delete[] scr1buf;
937     delete[] scr2buf;
938     delete[] scr3buf;
939 #   else
940     delete[] scr2buf;             /* scr1buf and scr3buf are part of scr2buf */
941 #   endif
942     return(MPI_SUCCESS);
943   } /* new_prot */
944   /*otherwise:*/
945   if (is_all)
946     return (colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm));
947   else
948     return (colls::reduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, root, comm));
949 }
950 #endif /*REDUCE_LIMITS*/
951
952
953 int reduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
954 {
955   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, false);
956 }
957
958 int allreduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
959 {
960   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, -1, comm, true);
961 }
962 }
963 }