Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new entry in Release_Notes.
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 /* extracted from mpig_myreduce.c with
7    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
8
9 /* Copyright: Rolf Rabenseifner, 1997
10  *            Computing Center University of Stuttgart
11  *            rabenseifner@rus.uni-stuttgart.de
12  *
13  * The usage of this software is free,
14  * but this header must not be removed.
15  */
16
17 #include "../colls_private.hpp"
18 #include <cstdio>
19 #include <cstdlib>
20
21 #define REDUCE_NEW_ALWAYS 1
22
23 #ifdef CRAY
24 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
25                                  /* =  16 + multiple of 32 doubles*/
26
27 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
28                       /*  routine =    reduce                allreduce             */                                  \
29                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
30   static int Lsh[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
31   static int Lin[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
32   static int Llg[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
33   static int Lfp[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
34   static int Ldb[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
35   static int Lby[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};
36 #endif
37
38 #ifdef REDUCE_NEW_ALWAYS
39 # undef  REDUCE_LIMITS
40 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
41                       /*  routine =    reduce                allreduce             */                                  \
42                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
43   static int Lsh[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
44   static int Lin[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
45   static int Llg[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
46   static int Lfp[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
47   static int Ldb[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
48   static int Lby[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};
49 #endif
50
51 /* Fast reduce and allreduce algorithm for longer buffers and predefined
52    operations.
53
54    This algorithm is explained with the example of 13 nodes.
55    The nodes are numbered   0, 1, 2, ... 12.
56    The sendbuf content is   a, b, c, ...  m.
57    The buffer array is notated with ABCDEFGH, this means that
58    e.g. 'C' is the third 1/8 of the buffer.
59
60    The algorithm computes
61
62    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
63
64    This is equivalent to the mostly used binary tree algorithm
65    (e.g. used in mpich).
66
67    size := number of nodes in the communicator.
68    2**n := the power of 2 that is next smaller or equal to the size.
69    r    := size - 2**n
70
71    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
72
73    The algorithm needs for the execution of one colls::reduce
74
75    - for r==0
76      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
77
78    - for r>0
79      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
80                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
81
82    with L1 = latency of a message transfer e.g. by send+recv
83         L2 = latency of a message exchange e.g. by sendrecv
84         T1 = additional time to transfer 1 byte (= 1/bandwidth)
85         T2 = additional time to exchange 1 byte in both directions.
86         O  = time for one operation
87         d  = size of the datatype
88
89         On a MPP system it is expected that T1==T2.
90
91    In Comparison with the binary tree algorithm that needs:
92
93    - for r==0
94      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
95
96    - for r>0
97      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
98
99    the new algorithm is faster if (assuming T1=T2)
100
101     for n>2:
102      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
103      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
104
105     for size = 5, 6, and 7:
106                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
107
108     for size = 4:
109                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
110     for size = 2 and 3:
111                 buf_lng  >  L2 / [  0.5 * O/d ]
112
113    and for O/d >> T1 we can summarize:
114
115     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
116
117    Example L1 = L2 = 50 us,
118            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
119            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
120            for double,        i.e. d  =   8 byte/FLOP
121
122            ==> the new algorithm is faster for about buf_lng>8000 bytes
123                i.e count > 1000 doubles !
124 Step 1)
125
126    compute n and r
127
128 Step 2)
129
130    if  myrank < 2*r
131
132      split the buffer into ABCD and EFGH
133
134      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
135                   receive buffer ABCD from myrank+1
136                   compute op for ABCD                 (count/2 ops)
137                   receive result EFGH               (buf_lng/2 transfer)
138      odd  myrank: send    buffer ABCD to   myrank+1
139                   receive buffer EFGH from myrank+1
140                   compute op for EFGH
141                   send    result EFGH
142
143    Result: node:     0    2    4    6    8  10  11  12
144            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
145
146 Step 3)
147
148    The following algorithm uses only the nodes with
149
150    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
151
152 Step 4)
153
154    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
155    define OLDRANK(new) := (new < r   ? new*2 : new+r)
156
157    Result:
158        old ranks:    0    2    4    6    8  10  11  12
159        new ranks:    0    1    2    3    4   5   6   7
160            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
161
162 Step 5.1)
163
164    Split the buffer (ABCDEFGH) in the middle,
165    the lower half (ABCD) is computed on even (new) ranks,
166    the opper half (EFGH) is computed on odd  (new) ranks.
167
168    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
169              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
170                                                (i.e. buf_lng/2 exchange)
171    compute op in each node on its half:        (i.e.   count/2 ops)
172
173    Result: node 0: (a+b)+(c+d)   for  ABCD
174            node 1: (a+b)+(c+d)   for  EFGH
175            node 2: (e+f)+(g+h)   for  ABCD
176            node 3: (e+f)+(g+h)   for  EFGH
177            node 4: (i+j)+ k      for  ABCD
178            node 5: (i+j)+ k      for  EFGH
179            node 6:    l + m      for  ABCD
180            node 7:    l + m      for  EFGH
181
182 Step 5.2)
183
184    Same with double distance and oncemore the half of the buffer.
185                                                (i.e. buf_lng/4 exchange)
186                                                (i.e.   count/4 ops)
187
188    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
189            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
190            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
191            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
192            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
193            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
194            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
195            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
196
197 ...
198 Step 5.n)
199
200    Same with double distance and oncemore the half of the buffer.
201                                             (i.e. buf_lng/2**n exchange)
202                                             (i.e.   count/2**n ops)
203
204    Result:
205      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
206      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
207      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
208      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
209      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
210      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
211      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
212      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
213
214
215 For colls::allreduce:
216 ------------------
217
218 Step 6.1)
219
220    Exchange on the last distance (2*n) the last result
221                                             (i.e. buf_lng/2**n exchange)
222
223 Step 6.2)
224
225    Same with half the distance and double of the result
226                                         (i.e. buf_lng/2**(n-1) exchange)
227
228 ...
229 Step 6.n)
230
231    Same with distance 1 and double of the result (== half of the
232    original buffer)                            (i.e. buf_lng/2 exchange)
233
234    Result after 6.1     6.2         6.n
235
236     on node 0:   AB    ABCD    ABCDEFGH
237     on node 1:   EF    EFGH    ABCDEFGH
238     on node 2:   CD    ABCD    ABCDEFGH
239     on node 3:   GH    EFGH    ABCDEFGH
240     on node 4:   AB    ABCD    ABCDEFGH
241     on node 5:   EF    EFGH    ABCDEFGH
242     on node 6:   CD    ABCD    ABCDEFGH
243     on node 7:   GH    EFGH    ABCDEFGH
244
245 Step 7)
246
247    If r > 0
248      transfer the result from the even nodes with old rank < 2*r
249                            to the odd  nodes with old rank < 2*r
250                                                  (i.e. buf_lng transfer)
251    Result:
252      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
253      for ABCDEFGH
254      on all nodes 0..12
255
256
257 For colls::reduce:
258 ---------------
259
260 Step 6.0)
261
262    If root node not in the list of the nodes with newranks
263    (see steps 3+4) then
264
265      send last result from node 0 to the root    (buf_lng/2**n transfer)
266      and replace the role of node 0 by the root node.
267
268 Step 6.1)
269
270    Send on the last distance (2**(n-1)) the last result
271    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
272    root's new rank  to the node with same '2**(n-1)' bit.
273                                             (i.e. buf_lng/2**n transfer)
274
275 Step 6.2)
276
277    Same with half the distance and double of the result
278    and bit '2**(n-2)'
279                                         (i.e. buf_lng/2**(n-1) transfer)
280
281 ...
282 Step 6.n)
283
284    Same with distance 1 and double of the result (== half of the
285    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
286
287    Example: roots old rank: 10
288             roots new rank:  5
289
290    Results:          6.1               6.2                 6.n
291                 action result     action result       action result
292
293     on node 0:  send A
294     on node 1:  send E
295     on node 2:  send C
296     on node 3:  send G
297     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
298     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
299     on node 6:  recv C => CD     send CD
300     on node 7:  recv G => GH     send GH
301
302 Benchmark results on CRAY T3E
303 -----------------------------
304
305    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
306    MPI:      /opt/ctl/mpt/1.1.0.3
307    datatype: MPI_DOUBLE
308    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
309    env: export MPI_BUFFER_MAX=4099
310    compiled with: cc -c -O3 -h restrict=f
311
312    old = binary tree protocol of the vendor
313    new = the new protocol and its implementation
314    mixed = 'new' is used if count > limit(datatype, communicator size)
315
316    REDUCE:
317                                           communicator size
318    measurement count prot. unit    2    3    4    6    8   12   16   24
319    --------------------------------------------------------------------
320    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
321                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
322    --------------------------------------------------------------------
323    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
324    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
325    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
326    --------------------------------------------------------------------
327    limit                doubles  896 1536  576  736  576  736  576  736
328    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
329                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
330    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
331
332                                           communicator size
333    measurement count prot. unit   32   48   64   96  128  192  256
334    ---------------------------------------------------------------
335    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
336                        old  us  78.6 87.2 90.1 99.7  108  119  120
337    ---------------------------------------------------------------
338    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
339    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
340    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
341    ---------------------------------------------------------------
342    limit                doubles  576  736  576  736  576  736  576  2)
343    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
344                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
345    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
346
347    ALLREDUCE:
348                                           communicator size
349    measurement count prot. unit    2    3    4    6    8   12   16   24
350    --------------------------------------------------------------------
351    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
352                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
353    --------------------------------------------------------------------
354    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
355    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
356    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
357    --------------------------------------------------------------------
358    limit                doubles  448 1280  512  512  512  512  512  512
359    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
360                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
361    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
362
363                                           communicator size
364    measurement count prot. unit   32   48   64   96  128  192  256
365    ---------------------------------------------------------------
366    latency         1 mixed  us  90.3  162  109  190
367                        old  us  92.7  152  104  179  122  225  135
368    ---------------------------------------------------------------
369    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
370    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
371    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
372    ---------------------------------------------------------------
373    limit                doubles  512  512  512  512  512  512  512
374    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
375                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
376    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
377
378    Footnotes:
379    1) This line shows that the overhead to decide which protocol
380       should be used can be ignored.
381    2) This line shows the limit for the count argument.
382       If count < limit then the vendor protocol is used,
383       otherwise the new protocol is used (see variable Ldb).
384    3) These lines show the bandwidth (= buffer length / execution time)
385       for both protocols.
386    4) This line shows that the limit is chosen well if the ratio is
387       between 0.95 (losing 5% for buffer length near and >=limit)
388       and 1.10 (not gaining 10% for buffer length near and <limit).
389    5) This line shows that the new protocol is 2..7 times faster
390       for long counts.
391
392 */
393
394 #ifdef REDUCE_LIMITS
395
396 #ifdef USE_Irecv
397 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
398   {                                                                                                                    \
399     MPI_Request req;                                                                                                   \
400     req = Request::irecv(rb, rc, rd, source, rt, comm);                                                                \
401     Request::send(sb, sc, sd, dest, st, comm);                                                                         \
402     Request::wait(&req, stat);                                                                                         \
403   }
404 #else
405 #ifdef USE_Isend
406 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
407   {                                                                                                                    \
408     MPI_Request req;                                                                                                   \
409     req = mpi_mpi_isend(sb, sc, sd, dest, st, comm);                                                                   \
410     Request::recv(rb, rc, rd, source, rt, comm, stat);                                                                 \
411     Request::wait(&req, stat);                                                                                         \
412   }
413 #else
414 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
415   Request::sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)
416 #endif
417 #endif
418
419 enum MPIM_Datatype {
420   MPIM_SHORT,
421   MPIM_INT,
422   MPIM_LONG,
423   MPIM_UNSIGNED_SHORT,
424   MPIM_UNSIGNED,
425   MPIM_UNSIGNED_LONG,
426   MPIM_UNSIGNED_LONG_LONG,
427   MPIM_FLOAT,
428   MPIM_DOUBLE,
429   MPIM_BYTE
430 };
431
432 enum MPIM_Op {
433   MPIM_MAX,
434   MPIM_MIN,
435   MPIM_SUM,
436   MPIM_PROD,
437   MPIM_LAND,
438   MPIM_BAND,
439   MPIM_LOR,
440   MPIM_BOR,
441   MPIM_LXOR,
442   MPIM_BXOR
443 };
444 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE, TYPE)                                                                  \
445   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
446   {                                                                                                                    \
447     int i;                                                                                                             \
448     switch (op) {                                                                                                      \
449       case MPIM_MAX:                                                                                                   \
450         for (i = 0; i < cnt; i++)                                                                                      \
451           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
452         break;                                                                                                         \
453       case MPIM_MIN:                                                                                                   \
454         for (i = 0; i < cnt; i++)                                                                                      \
455           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
456         break;                                                                                                         \
457       case MPIM_SUM:                                                                                                   \
458         for (i = 0; i < cnt; i++)                                                                                      \
459           rslt[i] = b1[i] + b2[i];                                                                                     \
460         break;                                                                                                         \
461       case MPIM_PROD:                                                                                                  \
462         for (i = 0; i < cnt; i++)                                                                                      \
463           rslt[i] = b1[i] * b2[i];                                                                                     \
464         break;                                                                                                         \
465       case MPIM_LAND:                                                                                                  \
466         for (i = 0; i < cnt; i++)                                                                                      \
467           rslt[i] = b1[i] && b2[i];                                                                                    \
468         break;                                                                                                         \
469       case MPIM_LOR:                                                                                                   \
470         for (i = 0; i < cnt; i++)                                                                                      \
471           rslt[i] = b1[i] || b2[i];                                                                                    \
472         break;                                                                                                         \
473       case MPIM_LXOR:                                                                                                  \
474         for (i = 0; i < cnt; i++)                                                                                      \
475           rslt[i] = b1[i] != b2[i];                                                                                    \
476         break;                                                                                                         \
477       case MPIM_BAND:                                                                                                  \
478         for (i = 0; i < cnt; i++)                                                                                      \
479           rslt[i] = b1[i] & b2[i];                                                                                     \
480         break;                                                                                                         \
481       case MPIM_BOR:                                                                                                   \
482         for (i = 0; i < cnt; i++)                                                                                      \
483           rslt[i] = b1[i] | b2[i];                                                                                     \
484         break;                                                                                                         \
485       case MPIM_BXOR:                                                                                                  \
486         for (i = 0; i < cnt; i++)                                                                                      \
487           rslt[i] = b1[i] ^ b2[i];                                                                                     \
488         break;                                                                                                         \
489       default:                                                                                                         \
490         break;                                                                                                         \
491     }                                                                                                                  \
492   }
493
494 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE, TYPE)                                                                         \
495   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
496   {                                                                                                                    \
497     int i;                                                                                                             \
498     switch (op) {                                                                                                      \
499       case MPIM_MAX:                                                                                                   \
500         for (i = 0; i < cnt; i++)                                                                                      \
501           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
502         break;                                                                                                         \
503       case MPIM_MIN:                                                                                                   \
504         for (i = 0; i < cnt; i++)                                                                                      \
505           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
506         break;                                                                                                         \
507       case MPIM_SUM:                                                                                                   \
508         for (i = 0; i < cnt; i++)                                                                                      \
509           rslt[i] = b1[i] + b2[i];                                                                                     \
510         break;                                                                                                         \
511       case MPIM_PROD:                                                                                                  \
512         for (i = 0; i < cnt; i++)                                                                                      \
513           rslt[i] = b1[i] * b2[i];                                                                                     \
514         break;                                                                                                         \
515       default:                                                                                                         \
516         break;                                                                                                         \
517     }                                                                                                                  \
518   }
519
520 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE, TYPE)                                                                       \
521   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
522   {                                                                                                                    \
523     int i;                                                                                                             \
524     switch (op) {                                                                                                      \
525       case MPIM_BAND:                                                                                                  \
526         for (i = 0; i < cnt; i++)                                                                                      \
527           rslt[i] = b1[i] & b2[i];                                                                                     \
528         break;                                                                                                         \
529       case MPIM_BOR:                                                                                                   \
530         for (i = 0; i < cnt; i++)                                                                                      \
531           rslt[i] = b1[i] | b2[i];                                                                                     \
532         break;                                                                                                         \
533       case MPIM_BXOR:                                                                                                  \
534         for (i = 0; i < cnt; i++)                                                                                      \
535           rslt[i] = b1[i] ^ b2[i];                                                                                     \
536         break;                                                                                                         \
537       default:                                                                                                         \
538         break;                                                                                                         \
539     }                                                                                                                  \
540   }
541
542 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_short, short)
543 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_int, int)
544 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_long, long)
545 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ushort, unsigned short)
546 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_uint, unsigned int)
547 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulong, unsigned long)
548 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulonglong, unsigned long long)
549 MPI_I_DO_OP_FP(MPI_I_do_op_float, float)
550 MPI_I_DO_OP_FP(MPI_I_do_op_double, double)
551 MPI_I_DO_OP_BYTE(MPI_I_do_op_byte, char)
552
553 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE, TYPE)                                                                       \
554   MPI_I_do_op_TYPE((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op);                                                        \
555   break;
556
557 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt, MPIM_Datatype datatype, MPIM_Op op)
558 {
559   switch (datatype) {
560     case MPIM_SHORT:
561       MPI_I_DO_OP_CALL(MPI_I_do_op_short, short)
562     case MPIM_INT:
563       MPI_I_DO_OP_CALL(MPI_I_do_op_int, int)
564     case MPIM_LONG:
565       MPI_I_DO_OP_CALL(MPI_I_do_op_long, long)
566     case MPIM_UNSIGNED_SHORT:
567       MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
568     case MPIM_UNSIGNED:
569       MPI_I_DO_OP_CALL(MPI_I_do_op_uint, unsigned int)
570     case MPIM_UNSIGNED_LONG:
571       MPI_I_DO_OP_CALL(MPI_I_do_op_ulong, unsigned long)
572     case MPIM_UNSIGNED_LONG_LONG:
573       MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong, unsigned long long)
574     case MPIM_FLOAT:
575       MPI_I_DO_OP_CALL(MPI_I_do_op_float, float)
576     case MPIM_DOUBLE:
577       MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
578     case MPIM_BYTE:
579       MPI_I_DO_OP_CALL(MPI_I_do_op_byte, char)
580   }
581 }
582
583 REDUCE_LIMITS
584 namespace simgrid::smpi {
585 static int MPI_I_anyReduce(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op,
586                            int root, MPI_Comm comm, bool is_all)
587 {
588   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
589   int myrank, size, x_base, x_size, computed, idx;
590   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
591   int start_even[20], start_odd[20], count_even[20], count_odd[20];
592   MPI_Aint typelng;
593   MPI_Status status;
594   size_t scrlng;
595   int new_prot;
596   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
597
598   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
599   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
600   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
601   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
602   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
603   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
604   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
605   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
606   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
607   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
608   else
609     throw std::invalid_argument("reduce rab algorithm can't be used with this datatype!");
610
611   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
612   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
613   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
614   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
615   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
616   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
617   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
618   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
619   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
620   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
621
622   new_prot = 0;
623   MPI_Comm_size(comm, &size);
624   if (size > 1) /*otherwise no balancing_protocol*/
625   { int ss;
626     if      (size==2) ss=0;
627     else if (size==3) ss=1;
628     else { int s = size; while (!(s & 1)) s = s >> 1;
629            if (s==1) /* size == power of 2 */ ss = 2;
630            else      /* size != power of 2 */ ss = 3; }
631     switch(op) {
632      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
633      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
634      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
635       switch(datatype) {
636         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
637          new_prot = count >= Lsh[is_all][ss]; break;
638         case MPIM_INT:    case MPIM_UNSIGNED:
639          new_prot = count >= Lin[is_all][ss]; break;
640         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
641          new_prot = count >= Llg[is_all][ss]; break;
642      default:
643         break;
644     } default:
645         break;}
646     switch(op) {
647      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
648       switch(datatype) {
649         case MPIM_FLOAT:
650          new_prot = count >= Lfp[is_all][ss]; break;
651         case MPIM_DOUBLE:
652          new_prot = count >= Ldb[is_all][ss]; break;
653               default:
654         break;
655     } default:
656         break;}
657     switch(op) {
658      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
659       switch(datatype) {
660         case MPIM_BYTE:
661          new_prot = count >= Lby[is_all][ss]; break;
662               default:
663         break;
664     } default:
665         break;}
666 #   ifdef DEBUG
667     { char *ss_str[]={"two","three","power of 2","no power of 2"};
668       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
669              is_all, size, ss_str[ss], new_prot); fflush(stdout);
670     }
671 #   endif
672   }
673
674   if (new_prot)
675   {
676     sendbuf = (char*) Sendbuf;
677     recvbuf = (char*) Recvbuf;
678     MPI_Comm_rank(comm, &myrank);
679     MPI_Type_extent(mpi_datatype, &typelng);
680     scrlng  = typelng * count;
681 #ifdef NO_CACHE_OPTIMIZATION
682     scr1buf = new char[scrlng];
683     scr2buf = new char[scrlng];
684     scr3buf = new char[scrlng];
685 #else
686 #  ifdef SCR_LNG_OPTIM
687     scrlng = SCR_LNG_OPTIM(scrlng);
688 #  endif
689     scr2buf = new char[3 * scrlng]; /* To test cache problems.     */
690     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
691     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
692                                   /* they are interchanged below.*/
693 #endif
694     computed = 0;
695     if (is_all) root = myrank; /* for correct recvbuf handling */
696
697   /*...step 1 */
698
699 #   ifdef DEBUG
700      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
701 #   endif
702     n = 0; x_size = 1;
703     while (2*x_size <= size) { n++; x_size = x_size * 2; }
704     /* x_size == 2**n */
705     r = size - x_size;
706
707   /*...step 2 */
708
709 #   ifdef DEBUG
710     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
711 #   endif
712     if (myrank < 2*r)
713     {
714       if ((myrank % 2) == 0 /*even*/)
715       {
716         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
717                        count - count/2, mpi_datatype, myrank+1, 1220,
718                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
719                        comm, &status);
720         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
721                     count/2, datatype, op);
722         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
723                  mpi_datatype, myrank+1, 1223, comm, &status);
724         computed = 1;
725 #       ifdef DEBUG
726         { int i; printf("[%2d] after step 2: val=",
727                         myrank);
728           for (i=0; i<count; i++)
729            printf(" %5.0lf",((double*)scr1buf)[i] );
730           printf("\n"); fflush(stdout);
731         }
732 #       endif
733       }
734       else /*odd*/
735       {
736         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
737                        scr2buf + (count/2)*typelng,
738                        count - count/2, mpi_datatype, myrank-1, 1220,
739                        comm, &status);
740         MPI_I_do_op(scr2buf + (count/2)*typelng,
741                     sendbuf + (count/2)*typelng,
742                     scr1buf + (count/2)*typelng,
743                     count - count/2, datatype, op);
744         Request::send(scr1buf + (count/2)*typelng, count - count/2,
745                  mpi_datatype, myrank-1, 1223, comm);
746       }
747     }
748
749   /*...step 3+4 */
750
751 #   ifdef DEBUG
752      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
753 #   endif
754     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
755          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
756     else mynewrank = -1;
757
758     if (mynewrank >= 0)
759     { /* begin -- only for nodes with new rank */
760
761 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
762
763   /*...step 5 */
764
765       x_start = 0;
766       x_count = count;
767       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
768       {
769         start_even[idx] = x_start;
770         count_even[idx] = x_count / 2;
771         start_odd [idx] = x_start + count_even[idx];
772         count_odd [idx] = x_count - count_even[idx];
773         if (((mynewrank/x_base) % 2) == 0 /*even*/)
774         {
775 #         ifdef DEBUG
776             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
777                    myrank,mynewrank,idx+1,computed); fflush(stdout);
778 #         endif
779           x_start = start_even[idx];
780           x_count = count_even[idx];
781           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
782                          + start_odd[idx]*typelng, count_odd[idx],
783                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
784                          scr2buf + x_start*typelng, x_count,
785                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
786                          comm, &status);
787           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
788                       scr2buf                    + x_start*typelng,
789                       ((root==myrank) && (idx==(n-1))
790                         ? recvbuf + x_start*typelng
791                         : scr3buf + x_start*typelng),
792                       x_count, datatype, op);
793         }
794         else /*odd*/
795         {
796 #         ifdef DEBUG
797             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
798                    myrank,mynewrank,idx+1,computed); fflush(stdout);
799 #         endif
800           x_start = start_odd[idx];
801           x_count = count_odd[idx];
802           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
803                          +start_even[idx]*typelng, count_even[idx],
804                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
805                          scr2buf + x_start*typelng, x_count,
806                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
807                          comm, &status);
808           MPI_I_do_op(scr2buf                    + x_start*typelng,
809                       (computed?scr1buf:sendbuf) + x_start*typelng,
810                       ((root==myrank) && (idx==(n-1))
811                         ? recvbuf + x_start*typelng
812                         : scr3buf + x_start*typelng),
813                       x_count, datatype, op);
814         }
815         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
816         computed = 1;
817 #       ifdef DEBUG
818         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
819                         myrank,mynewrank,idx+1,x_start,x_count);
820           for (i=0; i<x_count; i++)
821            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
822                              ? recvbuf + x_start*typelng
823                              : scr1buf + x_start*typelng))[i] );
824           printf("\n"); fflush(stdout);
825         }
826 #       endif
827       } /*for*/
828
829 #     undef OLDRANK
830
831
832     } /* end -- only for nodes with new rank */
833
834     if (is_all)
835     {
836       /*...steps 6.1 to 6.n */
837
838       if (mynewrank >= 0)
839       { /* begin -- only for nodes with new rank */
840
841 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
842
843         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
844         {
845 #         ifdef DEBUG
846             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
847 #         endif
848           if (((mynewrank/x_base) % 2) == 0 /*even*/)
849           {
850             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
851                                      count_even[idx],
852                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
853                            recvbuf + start_odd[idx]*typelng,
854                                      count_odd[idx],
855                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
856                            comm, &status);
857 #           ifdef DEBUG
858               x_start = start_odd[idx];
859               x_count = count_odd[idx];
860 #           endif
861           }
862           else /*odd*/
863           {
864             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
865                                      count_odd[idx],
866                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
867                            recvbuf + start_even[idx]*typelng,
868                                      count_even[idx],
869                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
870                            comm, &status);
871 #           ifdef DEBUG
872               x_start = start_even[idx];
873               x_count = count_even[idx];
874 #           endif
875           }
876 #         ifdef DEBUG
877           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
878                           myrank,mynewrank,n-idx,x_start,x_count);
879             for (i=0; i<x_count; i++)
880              printf(" %5.0lf",((double*)(root==myrank
881                                ? recvbuf + x_start*typelng
882                                : scr1buf + x_start*typelng))[i] );
883             printf("\n"); fflush(stdout);
884           }
885 #         endif
886         } /*for*/
887
888 #       undef OLDRANK
889
890       } /* end -- only for nodes with new rank */
891
892       /*...step 7 */
893
894       if (myrank < 2*r)
895       {
896 #       ifdef DEBUG
897           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
898 #       endif
899         if (myrank%2 == 0 /*even*/)
900           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
901         else /*odd*/
902           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
903       }
904
905     }
906     else /* not is_all, i.e. Reduce */
907     {
908
909     /*...step 6.0 */
910
911       if ((root < 2*r) && (root%2 == 1))
912       {
913 #       ifdef DEBUG
914           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
915 #       endif
916         if (myrank == 0) /* then mynewrank==0, x_start==0
917                                  x_count == count/x_size  */
918         {
919           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
920           mynewrank = -1;
921         }
922
923         if (myrank == root)
924         {
925           mynewrank = 0;
926           x_start = 0;
927           x_count = count;
928           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
929           {
930             start_even[idx] = x_start;
931             count_even[idx] = x_count / 2;
932             start_odd [idx] = x_start + count_even[idx];
933             count_odd [idx] = x_count - count_even[idx];
934             /* code for always even in each bit of mynewrank: */
935             x_start = start_even[idx];
936             x_count = count_even[idx];
937           }
938           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
939         }
940         newroot = 0;
941       }
942       else
943       {
944         newroot = (root < 2*r ? root/2 : root-r);
945       }
946
947     /*...steps 6.1 to 6.n */
948
949       if (mynewrank >= 0)
950       { /* begin -- only for nodes with new rank */
951
952 #define OLDRANK(new) ((new) == newroot ? root : ((new) < r ? (new) * 2 : (new) + r))
953
954         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
955         {
956 #         ifdef DEBUG
957             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
958 #         endif
959           if ((mynewrank & x_base) != (newroot & x_base))
960           {
961             if (((mynewrank/x_base) % 2) == 0 /*even*/)
962             { x_start = start_even[idx]; x_count = count_even[idx];
963               partner = mynewrank+x_base; }
964             else
965             { x_start = start_odd[idx]; x_count = count_odd[idx];
966               partner = mynewrank-x_base; }
967             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
968                      OLDRANK(partner), 1244, comm);
969           }
970           else /*odd*/
971           {
972             if (((mynewrank/x_base) % 2) == 0 /*even*/)
973             { x_start = start_odd[idx]; x_count = count_odd[idx];
974               partner = mynewrank+x_base; }
975             else
976             { x_start = start_even[idx]; x_count = count_even[idx];
977               partner = mynewrank-x_base; }
978             Request::recv((myrank==root ? recvbuf : scr1buf)
979                      + x_start*typelng, x_count, mpi_datatype,
980                      OLDRANK(partner), 1244, comm, &status);
981 #           ifdef DEBUG
982             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
983                             myrank,mynewrank,n-idx,x_start,x_count);
984               for (i=0; i<x_count; i++)
985                printf(" %5.0lf",((double*)(root==myrank
986                                  ? recvbuf + x_start*typelng
987                                  : scr1buf + x_start*typelng))[i] );
988               printf("\n"); fflush(stdout);
989             }
990 #           endif
991           }
992         } /*for*/
993
994 #       undef OLDRANK
995
996       } /* end -- only for nodes with new rank */
997     }
998
999 #   ifdef NO_CACHE_TESTING
1000     delete[] scr1buf;
1001     delete[] scr2buf;
1002     delete[] scr3buf;
1003 #   else
1004     delete[] scr2buf;             /* scr1buf and scr3buf are part of scr2buf */
1005 #   endif
1006     return(MPI_SUCCESS);
1007   } /* new_prot */
1008   /*otherwise:*/
1009   if (is_all)
1010     return (colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm));
1011   else
1012     return (colls::reduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, root, comm));
1013 }
1014 #endif /*REDUCE_LIMITS*/
1015
1016 int reduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root,
1017                 MPI_Comm comm)
1018 {
1019   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, false);
1020 }
1021
1022 int allreduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
1023 {
1024   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, -1, comm, true);
1025 }
1026 } // namespace simgrid::smpi