Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
simplify
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* Copyright (c) 2013-2021. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 /* extracted from mpig_myreduce.c with
7    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
8
9 /* Copyright: Rolf Rabenseifner, 1997
10  *            Computing Center University of Stuttgart
11  *            rabenseifner@rus.uni-stuttgart.de
12  *
13  * The usage of this software is free,
14  * but this header must not be removed.
15  */
16
17 #include "../colls_private.hpp"
18 #include <cstdio>
19 #include <cstdlib>
20
21 #define REDUCE_NEW_ALWAYS 1
22
23 #ifdef CRAY
24 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
25                                  /* =  16 + multiple of 32 doubles*/
26
27 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
28                       /*  routine =    reduce                allreduce             */                                  \
29                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
30   static int Lsh[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
31   static int Lin[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
32   static int Llg[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
33   static int Lfp[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
34   static int Ldb[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
35   static int Lby[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};
36 #endif
37
38 #ifdef REDUCE_NEW_ALWAYS
39 # undef  REDUCE_LIMITS
40 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
41                       /*  routine =    reduce                allreduce             */                                  \
42                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
43   static int Lsh[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
44   static int Lin[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
45   static int Llg[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
46   static int Lfp[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
47   static int Ldb[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
48   static int Lby[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};
49 #endif
50
51 /* Fast reduce and allreduce algorithm for longer buffers and predefined
52    operations.
53
54    This algorithm is explained with the example of 13 nodes.
55    The nodes are numbered   0, 1, 2, ... 12.
56    The sendbuf content is   a, b, c, ...  m.
57    The buffer array is notated with ABCDEFGH, this means that
58    e.g. 'C' is the third 1/8 of the buffer.
59
60    The algorithm computes
61
62    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
63
64    This is equivalent to the mostly used binary tree algorithm
65    (e.g. used in mpich).
66
67    size := number of nodes in the communicator.
68    2**n := the power of 2 that is next smaller or equal to the size.
69    r    := size - 2**n
70
71    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
72
73    The algorithm needs for the execution of one colls::reduce
74
75    - for r==0
76      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
77
78    - for r>0
79      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
80                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
81
82    with L1 = latency of a message transfer e.g. by send+recv
83         L2 = latency of a message exchange e.g. by sendrecv
84         T1 = additional time to transfer 1 byte (= 1/bandwidth)
85         T2 = additional time to exchange 1 byte in both directions.
86         O  = time for one operation
87         d  = size of the datatype
88
89         On a MPP system it is expected that T1==T2.
90
91    In Comparison with the binary tree algorithm that needs:
92
93    - for r==0
94      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
95
96    - for r>0
97      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
98
99    the new algorithm is faster if (assuming T1=T2)
100
101     for n>2:
102      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
103      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
104
105     for size = 5, 6, and 7:
106                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
107
108     for size = 4:
109                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
110     for size = 2 and 3:
111                 buf_lng  >  L2 / [  0.5 * O/d ]
112
113    and for O/d >> T1 we can summarize:
114
115     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
116
117    Example L1 = L2 = 50 us,
118            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
119            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
120            for double,        i.e. d  =   8 byte/FLOP
121
122            ==> the new algorithm is faster for about buf_lng>8000 bytes
123                i.e count > 1000 doubles !
124 Step 1)
125
126    compute n and r
127
128 Step 2)
129
130    if  myrank < 2*r
131
132      split the buffer into ABCD and EFGH
133
134      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
135                   receive buffer ABCD from myrank+1
136                   compute op for ABCD                 (count/2 ops)
137                   receive result EFGH               (buf_lng/2 transfer)
138      odd  myrank: send    buffer ABCD to   myrank+1
139                   receive buffer EFGH from myrank+1
140                   compute op for EFGH
141                   send    result EFGH
142
143    Result: node:     0    2    4    6    8  10  11  12
144            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
145
146 Step 3)
147
148    The following algorithm uses only the nodes with
149
150    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
151
152 Step 4)
153
154    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
155    define OLDRANK(new) := (new < r   ? new*2 : new+r)
156
157    Result:
158        old ranks:    0    2    4    6    8  10  11  12
159        new ranks:    0    1    2    3    4   5   6   7
160            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
161
162 Step 5.1)
163
164    Split the buffer (ABCDEFGH) in the middle,
165    the lower half (ABCD) is computed on even (new) ranks,
166    the opper half (EFGH) is computed on odd  (new) ranks.
167
168    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
169              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
170                                                (i.e. buf_lng/2 exchange)
171    compute op in each node on its half:        (i.e.   count/2 ops)
172
173    Result: node 0: (a+b)+(c+d)   for  ABCD
174            node 1: (a+b)+(c+d)   for  EFGH
175            node 2: (e+f)+(g+h)   for  ABCD
176            node 3: (e+f)+(g+h)   for  EFGH
177            node 4: (i+j)+ k      for  ABCD
178            node 5: (i+j)+ k      for  EFGH
179            node 6:    l + m      for  ABCD
180            node 7:    l + m      for  EFGH
181
182 Step 5.2)
183
184    Same with double distance and oncemore the half of the buffer.
185                                                (i.e. buf_lng/4 exchange)
186                                                (i.e.   count/4 ops)
187
188    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
189            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
190            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
191            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
192            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
193            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
194            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
195            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
196
197 ...
198 Step 5.n)
199
200    Same with double distance and oncemore the half of the buffer.
201                                             (i.e. buf_lng/2**n exchange)
202                                             (i.e.   count/2**n ops)
203
204    Result:
205      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
206      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
207      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
208      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
209      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
210      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
211      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
212      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
213
214
215 For colls::allreduce:
216 ------------------
217
218 Step 6.1)
219
220    Exchange on the last distance (2*n) the last result
221                                             (i.e. buf_lng/2**n exchange)
222
223 Step 6.2)
224
225    Same with half the distance and double of the result
226                                         (i.e. buf_lng/2**(n-1) exchange)
227
228 ...
229 Step 6.n)
230
231    Same with distance 1 and double of the result (== half of the
232    original buffer)                            (i.e. buf_lng/2 exchange)
233
234    Result after 6.1     6.2         6.n
235
236     on node 0:   AB    ABCD    ABCDEFGH
237     on node 1:   EF    EFGH    ABCDEFGH
238     on node 2:   CD    ABCD    ABCDEFGH
239     on node 3:   GH    EFGH    ABCDEFGH
240     on node 4:   AB    ABCD    ABCDEFGH
241     on node 5:   EF    EFGH    ABCDEFGH
242     on node 6:   CD    ABCD    ABCDEFGH
243     on node 7:   GH    EFGH    ABCDEFGH
244
245 Step 7)
246
247    If r > 0
248      transfer the result from the even nodes with old rank < 2*r
249                            to the odd  nodes with old rank < 2*r
250                                                  (i.e. buf_lng transfer)
251    Result:
252      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
253      for ABCDEFGH
254      on all nodes 0..12
255
256
257 For colls::reduce:
258 ---------------
259
260 Step 6.0)
261
262    If root node not in the list of the nodes with newranks
263    (see steps 3+4) then
264
265      send last result from node 0 to the root    (buf_lng/2**n transfer)
266      and replace the role of node 0 by the root node.
267
268 Step 6.1)
269
270    Send on the last distance (2**(n-1)) the last result
271    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
272    root's new rank  to the node with same '2**(n-1)' bit.
273                                             (i.e. buf_lng/2**n transfer)
274
275 Step 6.2)
276
277    Same with half the distance and double of the result
278    and bit '2**(n-2)'
279                                         (i.e. buf_lng/2**(n-1) transfer)
280
281 ...
282 Step 6.n)
283
284    Same with distance 1 and double of the result (== half of the
285    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
286
287    Example: roots old rank: 10
288             roots new rank:  5
289
290    Results:          6.1               6.2                 6.n
291                 action result     action result       action result
292
293     on node 0:  send A
294     on node 1:  send E
295     on node 2:  send C
296     on node 3:  send G
297     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
298     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
299     on node 6:  recv C => CD     send CD
300     on node 7:  recv G => GH     send GH
301
302 Benchmark results on CRAY T3E
303 -----------------------------
304
305    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
306    MPI:      /opt/ctl/mpt/1.1.0.3
307    datatype: MPI_DOUBLE
308    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
309    env: export MPI_BUFFER_MAX=4099
310    compiled with: cc -c -O3 -h restrict=f
311
312    old = binary tree protocol of the vendor
313    new = the new protocol and its implementation
314    mixed = 'new' is used if count > limit(datatype, communicator size)
315
316    REDUCE:
317                                           communicator size
318    measurement count prot. unit    2    3    4    6    8   12   16   24
319    --------------------------------------------------------------------
320    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
321                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
322    --------------------------------------------------------------------
323    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
324    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
325    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
326    --------------------------------------------------------------------
327    limit                doubles  896 1536  576  736  576  736  576  736
328    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
329                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
330    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
331
332                                           communicator size
333    measurement count prot. unit   32   48   64   96  128  192  256
334    ---------------------------------------------------------------
335    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
336                        old  us  78.6 87.2 90.1 99.7  108  119  120
337    ---------------------------------------------------------------
338    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
339    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
340    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
341    ---------------------------------------------------------------
342    limit                doubles  576  736  576  736  576  736  576  2)
343    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
344                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
345    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
346
347    ALLREDUCE:
348                                           communicator size
349    measurement count prot. unit    2    3    4    6    8   12   16   24
350    --------------------------------------------------------------------
351    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
352                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
353    --------------------------------------------------------------------
354    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
355    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
356    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
357    --------------------------------------------------------------------
358    limit                doubles  448 1280  512  512  512  512  512  512
359    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
360                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
361    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
362
363                                           communicator size
364    measurement count prot. unit   32   48   64   96  128  192  256
365    ---------------------------------------------------------------
366    latency         1 mixed  us  90.3  162  109  190
367                        old  us  92.7  152  104  179  122  225  135
368    ---------------------------------------------------------------
369    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
370    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
371    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
372    ---------------------------------------------------------------
373    limit                doubles  512  512  512  512  512  512  512
374    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
375                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
376    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
377
378    Footnotes:
379    1) This line shows that the overhead to decide which protocol
380       should be used can be ignored.
381    2) This line shows the limit for the count argument.
382       If count < limit then the vendor protocol is used,
383       otherwise the new protocol is used (see variable Ldb).
384    3) These lines show the bandwidth (= buffer length / execution time)
385       for both protocols.
386    4) This line shows that the limit is chosen well if the ratio is
387       between 0.95 (losing 5% for buffer length near and >=limit)
388       and 1.10 (not gaining 10% for buffer length near and <limit).
389    5) This line shows that the new protocol is 2..7 times faster
390       for long counts.
391
392 */
393
394 #ifdef REDUCE_LIMITS
395
396 #ifdef USE_Irecv
397 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
398   {                                                                                                                    \
399     MPI_Request req;                                                                                                   \
400     req = Request::irecv(rb, rc, rd, source, rt, comm);                                                                \
401     Request::send(sb, sc, sd, dest, st, comm);                                                                         \
402     Request::wait(&req, stat);                                                                                         \
403   }
404 #else
405 #ifdef USE_Isend
406 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
407   {                                                                                                                    \
408     MPI_Request req;                                                                                                   \
409     req = mpi_mpi_isend(sb, sc, sd, dest, st, comm);                                                                   \
410     Request::recv(rb, rc, rd, source, rt, comm, stat);                                                                 \
411     Request::wait(&req, stat);                                                                                         \
412   }
413 #else
414 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
415   Request::sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)
416 #endif
417 #endif
418
419 enum MPIM_Datatype {
420   MPIM_SHORT,
421   MPIM_INT,
422   MPIM_LONG,
423   MPIM_UNSIGNED_SHORT,
424   MPIM_UNSIGNED,
425   MPIM_UNSIGNED_LONG,
426   MPIM_UNSIGNED_LONG_LONG,
427   MPIM_FLOAT,
428   MPIM_DOUBLE,
429   MPIM_BYTE
430 };
431
432 enum MPIM_Op {
433   MPIM_MAX,
434   MPIM_MIN,
435   MPIM_SUM,
436   MPIM_PROD,
437   MPIM_LAND,
438   MPIM_BAND,
439   MPIM_LOR,
440   MPIM_BOR,
441   MPIM_LXOR,
442   MPIM_BXOR
443 };
444 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE, TYPE)                                                                  \
445   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
446   {                                                                                                                    \
447     int i;                                                                                                             \
448     switch (op) {                                                                                                      \
449       case MPIM_MAX:                                                                                                   \
450         for (i = 0; i < cnt; i++)                                                                                      \
451           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
452         break;                                                                                                         \
453       case MPIM_MIN:                                                                                                   \
454         for (i = 0; i < cnt; i++)                                                                                      \
455           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
456         break;                                                                                                         \
457       case MPIM_SUM:                                                                                                   \
458         for (i = 0; i < cnt; i++)                                                                                      \
459           rslt[i] = b1[i] + b2[i];                                                                                     \
460         break;                                                                                                         \
461       case MPIM_PROD:                                                                                                  \
462         for (i = 0; i < cnt; i++)                                                                                      \
463           rslt[i] = b1[i] * b2[i];                                                                                     \
464         break;                                                                                                         \
465       case MPIM_LAND:                                                                                                  \
466         for (i = 0; i < cnt; i++)                                                                                      \
467           rslt[i] = b1[i] && b2[i];                                                                                    \
468         break;                                                                                                         \
469       case MPIM_LOR:                                                                                                   \
470         for (i = 0; i < cnt; i++)                                                                                      \
471           rslt[i] = b1[i] || b2[i];                                                                                    \
472         break;                                                                                                         \
473       case MPIM_LXOR:                                                                                                  \
474         for (i = 0; i < cnt; i++)                                                                                      \
475           rslt[i] = b1[i] != b2[i];                                                                                    \
476         break;                                                                                                         \
477       case MPIM_BAND:                                                                                                  \
478         for (i = 0; i < cnt; i++)                                                                                      \
479           rslt[i] = b1[i] & b2[i];                                                                                     \
480         break;                                                                                                         \
481       case MPIM_BOR:                                                                                                   \
482         for (i = 0; i < cnt; i++)                                                                                      \
483           rslt[i] = b1[i] | b2[i];                                                                                     \
484         break;                                                                                                         \
485       case MPIM_BXOR:                                                                                                  \
486         for (i = 0; i < cnt; i++)                                                                                      \
487           rslt[i] = b1[i] ^ b2[i];                                                                                     \
488         break;                                                                                                         \
489       default:                                                                                                         \
490         break;                                                                                                         \
491     }                                                                                                                  \
492   }
493
494 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE, TYPE)                                                                         \
495   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
496   {                                                                                                                    \
497     int i;                                                                                                             \
498     switch (op) {                                                                                                      \
499       case MPIM_MAX:                                                                                                   \
500         for (i = 0; i < cnt; i++)                                                                                      \
501           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
502         break;                                                                                                         \
503       case MPIM_MIN:                                                                                                   \
504         for (i = 0; i < cnt; i++)                                                                                      \
505           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
506         break;                                                                                                         \
507       case MPIM_SUM:                                                                                                   \
508         for (i = 0; i < cnt; i++)                                                                                      \
509           rslt[i] = b1[i] + b2[i];                                                                                     \
510         break;                                                                                                         \
511       case MPIM_PROD:                                                                                                  \
512         for (i = 0; i < cnt; i++)                                                                                      \
513           rslt[i] = b1[i] * b2[i];                                                                                     \
514         break;                                                                                                         \
515       default:                                                                                                         \
516         break;                                                                                                         \
517     }                                                                                                                  \
518   }
519
520 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE, TYPE)                                                                       \
521   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
522   {                                                                                                                    \
523     int i;                                                                                                             \
524     switch (op) {                                                                                                      \
525       case MPIM_BAND:                                                                                                  \
526         for (i = 0; i < cnt; i++)                                                                                      \
527           rslt[i] = b1[i] & b2[i];                                                                                     \
528         break;                                                                                                         \
529       case MPIM_BOR:                                                                                                   \
530         for (i = 0; i < cnt; i++)                                                                                      \
531           rslt[i] = b1[i] | b2[i];                                                                                     \
532         break;                                                                                                         \
533       case MPIM_BXOR:                                                                                                  \
534         for (i = 0; i < cnt; i++)                                                                                      \
535           rslt[i] = b1[i] ^ b2[i];                                                                                     \
536         break;                                                                                                         \
537       default:                                                                                                         \
538         break;                                                                                                         \
539     }                                                                                                                  \
540   }
541
542 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_short, short)
543 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_int, int)
544 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_long, long)
545 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ushort, unsigned short)
546 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_uint, unsigned int)
547 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulong, unsigned long)
548 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulonglong, unsigned long long)
549 MPI_I_DO_OP_FP(MPI_I_do_op_float, float)
550 MPI_I_DO_OP_FP(MPI_I_do_op_double, double)
551 MPI_I_DO_OP_BYTE(MPI_I_do_op_byte, char)
552
553 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE, TYPE)                                                                       \
554   MPI_I_do_op_TYPE((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op);                                                        \
555   break;
556
557 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt, MPIM_Datatype datatype, MPIM_Op op)
558 {
559   switch (datatype) {
560     case MPIM_SHORT:
561       MPI_I_DO_OP_CALL(MPI_I_do_op_short, short)
562     case MPIM_INT:
563       MPI_I_DO_OP_CALL(MPI_I_do_op_int, int)
564     case MPIM_LONG:
565       MPI_I_DO_OP_CALL(MPI_I_do_op_long, long)
566     case MPIM_UNSIGNED_SHORT:
567       MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
568     case MPIM_UNSIGNED:
569       MPI_I_DO_OP_CALL(MPI_I_do_op_uint, unsigned int)
570     case MPIM_UNSIGNED_LONG:
571       MPI_I_DO_OP_CALL(MPI_I_do_op_ulong, unsigned long)
572     case MPIM_UNSIGNED_LONG_LONG:
573       MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong, unsigned long long)
574     case MPIM_FLOAT:
575       MPI_I_DO_OP_CALL(MPI_I_do_op_float, float)
576     case MPIM_DOUBLE:
577       MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
578     case MPIM_BYTE:
579       MPI_I_DO_OP_CALL(MPI_I_do_op_byte, char)
580   }
581 }
582
583 REDUCE_LIMITS
584 namespace simgrid {
585 namespace smpi {
586 static int MPI_I_anyReduce(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op,
587                            int root, MPI_Comm comm, bool is_all)
588 {
589   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
590   int myrank, size, x_base, x_size, computed, idx;
591   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
592   int start_even[20], start_odd[20], count_even[20], count_odd[20];
593   MPI_Aint typelng;
594   MPI_Status status;
595   size_t scrlng;
596   int new_prot;
597   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
598
599   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
600   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
601   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
602   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
603   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
604   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
605   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
606   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
607   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
608   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
609   else
610     throw std::invalid_argument("reduce rab algorithm can't be used with this datatype!");
611
612   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
613   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
614   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
615   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
616   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
617   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
618   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
619   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
620   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
621   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
622
623   new_prot = 0;
624   MPI_Comm_size(comm, &size);
625   if (size > 1) /*otherwise no balancing_protocol*/
626   { int ss;
627     if      (size==2) ss=0;
628     else if (size==3) ss=1;
629     else { int s = size; while (!(s & 1)) s = s >> 1;
630            if (s==1) /* size == power of 2 */ ss = 2;
631            else      /* size != power of 2 */ ss = 3; }
632     switch(op) {
633      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
634      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
635      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
636       switch(datatype) {
637         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
638          new_prot = count >= Lsh[is_all][ss]; break;
639         case MPIM_INT:    case MPIM_UNSIGNED:
640          new_prot = count >= Lin[is_all][ss]; break;
641         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
642          new_prot = count >= Llg[is_all][ss]; break;
643      default:
644         break;
645     } default:
646         break;}
647     switch(op) {
648      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
649       switch(datatype) {
650         case MPIM_FLOAT:
651          new_prot = count >= Lfp[is_all][ss]; break;
652         case MPIM_DOUBLE:
653          new_prot = count >= Ldb[is_all][ss]; break;
654               default:
655         break;
656     } default:
657         break;}
658     switch(op) {
659      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
660       switch(datatype) {
661         case MPIM_BYTE:
662          new_prot = count >= Lby[is_all][ss]; break;
663               default:
664         break;
665     } default:
666         break;}
667 #   ifdef DEBUG
668     { char *ss_str[]={"two","three","power of 2","no power of 2"};
669       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
670              is_all, size, ss_str[ss], new_prot); fflush(stdout);
671     }
672 #   endif
673   }
674
675   if (new_prot)
676   {
677     sendbuf = (char*) Sendbuf;
678     recvbuf = (char*) Recvbuf;
679     MPI_Comm_rank(comm, &myrank);
680     MPI_Type_extent(mpi_datatype, &typelng);
681     scrlng  = typelng * count;
682 #ifdef NO_CACHE_OPTIMIZATION
683     scr1buf = new char[scrlng];
684     scr2buf = new char[scrlng];
685     scr3buf = new char[scrlng];
686 #else
687 #  ifdef SCR_LNG_OPTIM
688     scrlng = SCR_LNG_OPTIM(scrlng);
689 #  endif
690     scr2buf = new char[3 * scrlng]; /* To test cache problems.     */
691     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
692     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
693                                   /* they are interchanged below.*/
694 #endif
695     computed = 0;
696     if (is_all) root = myrank; /* for correct recvbuf handling */
697
698   /*...step 1 */
699
700 #   ifdef DEBUG
701      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
702 #   endif
703     n = 0; x_size = 1;
704     while (2*x_size <= size) { n++; x_size = x_size * 2; }
705     /* x_size == 2**n */
706     r = size - x_size;
707
708   /*...step 2 */
709
710 #   ifdef DEBUG
711     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
712 #   endif
713     if (myrank < 2*r)
714     {
715       if ((myrank % 2) == 0 /*even*/)
716       {
717         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
718                        count - count/2, mpi_datatype, myrank+1, 1220,
719                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
720                        comm, &status);
721         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
722                     count/2, datatype, op);
723         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
724                  mpi_datatype, myrank+1, 1223, comm, &status);
725         computed = 1;
726 #       ifdef DEBUG
727         { int i; printf("[%2d] after step 2: val=",
728                         myrank);
729           for (i=0; i<count; i++)
730            printf(" %5.0lf",((double*)scr1buf)[i] );
731           printf("\n"); fflush(stdout);
732         }
733 #       endif
734       }
735       else /*odd*/
736       {
737         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
738                        scr2buf + (count/2)*typelng,
739                        count - count/2, mpi_datatype, myrank-1, 1220,
740                        comm, &status);
741         MPI_I_do_op(scr2buf + (count/2)*typelng,
742                     sendbuf + (count/2)*typelng,
743                     scr1buf + (count/2)*typelng,
744                     count - count/2, datatype, op);
745         Request::send(scr1buf + (count/2)*typelng, count - count/2,
746                  mpi_datatype, myrank-1, 1223, comm);
747       }
748     }
749
750   /*...step 3+4 */
751
752 #   ifdef DEBUG
753      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
754 #   endif
755     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
756          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
757     else mynewrank = -1;
758
759     if (mynewrank >= 0)
760     { /* begin -- only for nodes with new rank */
761
762 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
763
764   /*...step 5 */
765
766       x_start = 0;
767       x_count = count;
768       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
769       {
770         start_even[idx] = x_start;
771         count_even[idx] = x_count / 2;
772         start_odd [idx] = x_start + count_even[idx];
773         count_odd [idx] = x_count - count_even[idx];
774         if (((mynewrank/x_base) % 2) == 0 /*even*/)
775         {
776 #         ifdef DEBUG
777             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
778                    myrank,mynewrank,idx+1,computed); fflush(stdout);
779 #         endif
780           x_start = start_even[idx];
781           x_count = count_even[idx];
782           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
783                          + start_odd[idx]*typelng, count_odd[idx],
784                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
785                          scr2buf + x_start*typelng, x_count,
786                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
787                          comm, &status);
788           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
789                       scr2buf                    + x_start*typelng,
790                       ((root==myrank) && (idx==(n-1))
791                         ? recvbuf + x_start*typelng
792                         : scr3buf + x_start*typelng),
793                       x_count, datatype, op);
794         }
795         else /*odd*/
796         {
797 #         ifdef DEBUG
798             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
799                    myrank,mynewrank,idx+1,computed); fflush(stdout);
800 #         endif
801           x_start = start_odd[idx];
802           x_count = count_odd[idx];
803           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
804                          +start_even[idx]*typelng, count_even[idx],
805                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
806                          scr2buf + x_start*typelng, x_count,
807                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
808                          comm, &status);
809           MPI_I_do_op(scr2buf                    + x_start*typelng,
810                       (computed?scr1buf:sendbuf) + x_start*typelng,
811                       ((root==myrank) && (idx==(n-1))
812                         ? recvbuf + x_start*typelng
813                         : scr3buf + x_start*typelng),
814                       x_count, datatype, op);
815         }
816         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
817         computed = 1;
818 #       ifdef DEBUG
819         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
820                         myrank,mynewrank,idx+1,x_start,x_count);
821           for (i=0; i<x_count; i++)
822            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
823                              ? recvbuf + x_start*typelng
824                              : scr1buf + x_start*typelng))[i] );
825           printf("\n"); fflush(stdout);
826         }
827 #       endif
828       } /*for*/
829
830 #     undef OLDRANK
831
832
833     } /* end -- only for nodes with new rank */
834
835     if (is_all)
836     {
837       /*...steps 6.1 to 6.n */
838
839       if (mynewrank >= 0)
840       { /* begin -- only for nodes with new rank */
841
842 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
843
844         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
845         {
846 #         ifdef DEBUG
847             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
848 #         endif
849           if (((mynewrank/x_base) % 2) == 0 /*even*/)
850           {
851             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
852                                      count_even[idx],
853                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
854                            recvbuf + start_odd[idx]*typelng,
855                                      count_odd[idx],
856                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
857                            comm, &status);
858 #           ifdef DEBUG
859               x_start = start_odd[idx];
860               x_count = count_odd[idx];
861 #           endif
862           }
863           else /*odd*/
864           {
865             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
866                                      count_odd[idx],
867                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
868                            recvbuf + start_even[idx]*typelng,
869                                      count_even[idx],
870                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
871                            comm, &status);
872 #           ifdef DEBUG
873               x_start = start_even[idx];
874               x_count = count_even[idx];
875 #           endif
876           }
877 #         ifdef DEBUG
878           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
879                           myrank,mynewrank,n-idx,x_start,x_count);
880             for (i=0; i<x_count; i++)
881              printf(" %5.0lf",((double*)(root==myrank
882                                ? recvbuf + x_start*typelng
883                                : scr1buf + x_start*typelng))[i] );
884             printf("\n"); fflush(stdout);
885           }
886 #         endif
887         } /*for*/
888
889 #       undef OLDRANK
890
891       } /* end -- only for nodes with new rank */
892
893       /*...step 7 */
894
895       if (myrank < 2*r)
896       {
897 #       ifdef DEBUG
898           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
899 #       endif
900         if (myrank%2 == 0 /*even*/)
901           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
902         else /*odd*/
903           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
904       }
905
906     }
907     else /* not is_all, i.e. Reduce */
908     {
909
910     /*...step 6.0 */
911
912       if ((root < 2*r) && (root%2 == 1))
913       {
914 #       ifdef DEBUG
915           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
916 #       endif
917         if (myrank == 0) /* then mynewrank==0, x_start==0
918                                  x_count == count/x_size  */
919         {
920           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
921           mynewrank = -1;
922         }
923
924         if (myrank == root)
925         {
926           mynewrank = 0;
927           x_start = 0;
928           x_count = count;
929           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
930           {
931             start_even[idx] = x_start;
932             count_even[idx] = x_count / 2;
933             start_odd [idx] = x_start + count_even[idx];
934             count_odd [idx] = x_count - count_even[idx];
935             /* code for always even in each bit of mynewrank: */
936             x_start = start_even[idx];
937             x_count = count_even[idx];
938           }
939           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
940         }
941         newroot = 0;
942       }
943       else
944       {
945         newroot = (root < 2*r ? root/2 : root-r);
946       }
947
948     /*...steps 6.1 to 6.n */
949
950       if (mynewrank >= 0)
951       { /* begin -- only for nodes with new rank */
952
953 #define OLDRANK(new) ((new) == newroot ? root : ((new) < r ? (new) * 2 : (new) + r))
954
955         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
956         {
957 #         ifdef DEBUG
958             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
959 #         endif
960           if ((mynewrank & x_base) != (newroot & x_base))
961           {
962             if (((mynewrank/x_base) % 2) == 0 /*even*/)
963             { x_start = start_even[idx]; x_count = count_even[idx];
964               partner = mynewrank+x_base; }
965             else
966             { x_start = start_odd[idx]; x_count = count_odd[idx];
967               partner = mynewrank-x_base; }
968             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
969                      OLDRANK(partner), 1244, comm);
970           }
971           else /*odd*/
972           {
973             if (((mynewrank/x_base) % 2) == 0 /*even*/)
974             { x_start = start_odd[idx]; x_count = count_odd[idx];
975               partner = mynewrank+x_base; }
976             else
977             { x_start = start_even[idx]; x_count = count_even[idx];
978               partner = mynewrank-x_base; }
979             Request::recv((myrank==root ? recvbuf : scr1buf)
980                      + x_start*typelng, x_count, mpi_datatype,
981                      OLDRANK(partner), 1244, comm, &status);
982 #           ifdef DEBUG
983             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
984                             myrank,mynewrank,n-idx,x_start,x_count);
985               for (i=0; i<x_count; i++)
986                printf(" %5.0lf",((double*)(root==myrank
987                                  ? recvbuf + x_start*typelng
988                                  : scr1buf + x_start*typelng))[i] );
989               printf("\n"); fflush(stdout);
990             }
991 #           endif
992           }
993         } /*for*/
994
995 #       undef OLDRANK
996
997       } /* end -- only for nodes with new rank */
998     }
999
1000 #   ifdef NO_CACHE_TESTING
1001     delete[] scr1buf;
1002     delete[] scr2buf;
1003     delete[] scr3buf;
1004 #   else
1005     delete[] scr2buf;             /* scr1buf and scr3buf are part of scr2buf */
1006 #   endif
1007     return(MPI_SUCCESS);
1008   } /* new_prot */
1009   /*otherwise:*/
1010   if (is_all)
1011     return (colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm));
1012   else
1013     return (colls::reduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, root, comm));
1014 }
1015 #endif /*REDUCE_LIMITS*/
1016
1017 int reduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root,
1018                 MPI_Comm comm)
1019 {
1020   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, false);
1021 }
1022
1023 int allreduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
1024 {
1025   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, -1, comm, true);
1026 }
1027 }
1028 }