Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix indentation.
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* extracted from mpig_myreduce.c with
2    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
3
4 /* Copyright: Rolf Rabenseifner, 1997
5  *            Computing Center University of Stuttgart
6  *            rabenseifner@rus.uni-stuttgart.de
7  *
8  * The usage of this software is free,
9  * but this header must not be removed.
10  */
11
12 #include "../colls_private.hpp"
13 #include <cstdio>
14 #include <cstdlib>
15
16 #define REDUCE_NEW_ALWAYS 1
17
18 #ifdef CRAY
19 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
20                                  /* =  16 + multiple of 32 doubles*/
21
22 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
23                       /*  routine =    reduce                allreduce             */                                  \
24                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
25   static int Lsh[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
26   static int Lin[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
27   static int Llg[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
28   static int Lfp[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
29   static int Ldb[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};                                               \
30   static int Lby[2][4] = {{896, 1728, 576, 736}, {448, 1280, 512, 512}};
31 #endif
32
33 #ifdef REDUCE_NEW_ALWAYS
34 # undef  REDUCE_LIMITS
35 #define REDUCE_LIMITS /*  values are lower limits for count arg.  */                                                   \
36                       /*  routine =    reduce                allreduce             */                                  \
37                       /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */                                  \
38   static int Lsh[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
39   static int Lin[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
40   static int Llg[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
41   static int Lfp[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
42   static int Ldb[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};                                                                 \
43   static int Lby[2][4] = {{1, 1, 1, 1}, {1, 1, 1, 1}};
44 #endif
45
46 /* Fast reduce and allreduce algorithm for longer buffers and predefined
47    operations.
48
49    This algorithm is explained with the example of 13 nodes.
50    The nodes are numbered   0, 1, 2, ... 12.
51    The sendbuf content is   a, b, c, ...  m.
52    The buffer array is notated with ABCDEFGH, this means that
53    e.g. 'C' is the third 1/8 of the buffer.
54
55    The algorithm computes
56
57    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
58
59    This is equivalent to the mostly used binary tree algorithm
60    (e.g. used in mpich).
61
62    size := number of nodes in the communicator.
63    2**n := the power of 2 that is next smaller or equal to the size.
64    r    := size - 2**n
65
66    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
67
68    The algorithm needs for the execution of one colls::reduce
69
70    - for r==0
71      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
72
73    - for r>0
74      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
75                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
76
77    with L1 = latency of a message transfer e.g. by send+recv
78         L2 = latency of a message exchange e.g. by sendrecv
79         T1 = additional time to transfer 1 byte (= 1/bandwidth)
80         T2 = additional time to exchange 1 byte in both directions.
81         O  = time for one operation
82         d  = size of the datatype
83
84         On a MPP system it is expected that T1==T2.
85
86    In Comparison with the binary tree algorithm that needs:
87
88    - for r==0
89      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
90
91    - for r>0
92      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
93
94    the new algorithm is faster if (assuming T1=T2)
95
96     for n>2:
97      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
98      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
99
100     for size = 5, 6, and 7:
101                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
102
103     for size = 4:
104                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
105     for size = 2 and 3:
106                 buf_lng  >  L2 / [  0.5 * O/d ]
107
108    and for O/d >> T1 we can summarize:
109
110     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
111
112    Example L1 = L2 = 50 us,
113            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
114            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
115            for double,        i.e. d  =   8 byte/FLOP
116
117            ==> the new algorithm is faster for about buf_lng>8000 bytes
118                i.e count > 1000 doubles !
119 Step 1)
120
121    compute n and r
122
123 Step 2)
124
125    if  myrank < 2*r
126
127      split the buffer into ABCD and EFGH
128
129      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
130                   receive buffer ABCD from myrank+1
131                   compute op for ABCD                 (count/2 ops)
132                   receive result EFGH               (buf_lng/2 transfer)
133      odd  myrank: send    buffer ABCD to   myrank+1
134                   receive buffer EFGH from myrank+1
135                   compute op for EFGH
136                   send    result EFGH
137
138    Result: node:     0    2    4    6    8  10  11  12
139            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
140
141 Step 3)
142
143    The following algorithm uses only the nodes with
144
145    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
146
147 Step 4)
148
149    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
150    define OLDRANK(new) := (new < r   ? new*2 : new+r)
151
152    Result:
153        old ranks:    0    2    4    6    8  10  11  12
154        new ranks:    0    1    2    3    4   5   6   7
155            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
156
157 Step 5.1)
158
159    Split the buffer (ABCDEFGH) in the middle,
160    the lower half (ABCD) is computed on even (new) ranks,
161    the opper half (EFGH) is computed on odd  (new) ranks.
162
163    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
164              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
165                                                (i.e. buf_lng/2 exchange)
166    compute op in each node on its half:        (i.e.   count/2 ops)
167
168    Result: node 0: (a+b)+(c+d)   for  ABCD
169            node 1: (a+b)+(c+d)   for  EFGH
170            node 2: (e+f)+(g+h)   for  ABCD
171            node 3: (e+f)+(g+h)   for  EFGH
172            node 4: (i+j)+ k      for  ABCD
173            node 5: (i+j)+ k      for  EFGH
174            node 6:    l + m      for  ABCD
175            node 7:    l + m      for  EFGH
176
177 Step 5.2)
178
179    Same with double distance and oncemore the half of the buffer.
180                                                (i.e. buf_lng/4 exchange)
181                                                (i.e.   count/4 ops)
182
183    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
184            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
185            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
186            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
187            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
188            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
189            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
190            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
191
192 ...
193 Step 5.n)
194
195    Same with double distance and oncemore the half of the buffer.
196                                             (i.e. buf_lng/2**n exchange)
197                                             (i.e.   count/2**n ops)
198
199    Result:
200      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
201      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
202      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
203      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
204      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
205      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
206      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
207      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
208
209
210 For colls::allreduce:
211 ------------------
212
213 Step 6.1)
214
215    Exchange on the last distance (2*n) the last result
216                                             (i.e. buf_lng/2**n exchange)
217
218 Step 6.2)
219
220    Same with half the distance and double of the result
221                                         (i.e. buf_lng/2**(n-1) exchange)
222
223 ...
224 Step 6.n)
225
226    Same with distance 1 and double of the result (== half of the
227    original buffer)                            (i.e. buf_lng/2 exchange)
228
229    Result after 6.1     6.2         6.n
230
231     on node 0:   AB    ABCD    ABCDEFGH
232     on node 1:   EF    EFGH    ABCDEFGH
233     on node 2:   CD    ABCD    ABCDEFGH
234     on node 3:   GH    EFGH    ABCDEFGH
235     on node 4:   AB    ABCD    ABCDEFGH
236     on node 5:   EF    EFGH    ABCDEFGH
237     on node 6:   CD    ABCD    ABCDEFGH
238     on node 7:   GH    EFGH    ABCDEFGH
239
240 Step 7)
241
242    If r > 0
243      transfer the result from the even nodes with old rank < 2*r
244                            to the odd  nodes with old rank < 2*r
245                                                  (i.e. buf_lng transfer)
246    Result:
247      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
248      for ABCDEFGH
249      on all nodes 0..12
250
251
252 For colls::reduce:
253 ---------------
254
255 Step 6.0)
256
257    If root node not in the list of the nodes with newranks
258    (see steps 3+4) then
259
260      send last result from node 0 to the root    (buf_lng/2**n transfer)
261      and replace the role of node 0 by the root node.
262
263 Step 6.1)
264
265    Send on the last distance (2**(n-1)) the last result
266    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
267    root's new rank  to the node with same '2**(n-1)' bit.
268                                             (i.e. buf_lng/2**n transfer)
269
270 Step 6.2)
271
272    Same with half the distance and double of the result
273    and bit '2**(n-2)'
274                                         (i.e. buf_lng/2**(n-1) transfer)
275
276 ...
277 Step 6.n)
278
279    Same with distance 1 and double of the result (== half of the
280    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
281
282    Example: roots old rank: 10
283             roots new rank:  5
284
285    Results:          6.1               6.2                 6.n
286                 action result     action result       action result
287
288     on node 0:  send A
289     on node 1:  send E
290     on node 2:  send C
291     on node 3:  send G
292     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
293     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
294     on node 6:  recv C => CD     send CD
295     on node 7:  recv G => GH     send GH
296
297 Benchmark results on CRAY T3E
298 -----------------------------
299
300    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
301    MPI:      /opt/ctl/mpt/1.1.0.3
302    datatype: MPI_DOUBLE
303    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
304    env: export MPI_BUFFER_MAX=4099
305    compiled with: cc -c -O3 -h restrict=f
306
307    old = binary tree protocol of the vendor
308    new = the new protocol and its implementation
309    mixed = 'new' is used if count > limit(datatype, communicator size)
310
311    REDUCE:
312                                           communicator size
313    measurement count prot. unit    2    3    4    6    8   12   16   24
314    --------------------------------------------------------------------
315    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
316                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
317    --------------------------------------------------------------------
318    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
319    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
320    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
321    --------------------------------------------------------------------
322    limit                doubles  896 1536  576  736  576  736  576  736
323    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
324                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
325    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
326
327                                           communicator size
328    measurement count prot. unit   32   48   64   96  128  192  256
329    ---------------------------------------------------------------
330    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
331                        old  us  78.6 87.2 90.1 99.7  108  119  120
332    ---------------------------------------------------------------
333    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
334    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
335    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
336    ---------------------------------------------------------------
337    limit                doubles  576  736  576  736  576  736  576  2)
338    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
339                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
340    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
341
342    ALLREDUCE:
343                                           communicator size
344    measurement count prot. unit    2    3    4    6    8   12   16   24
345    --------------------------------------------------------------------
346    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
347                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
348    --------------------------------------------------------------------
349    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
350    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
351    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
352    --------------------------------------------------------------------
353    limit                doubles  448 1280  512  512  512  512  512  512
354    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
355                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
356    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
357
358                                           communicator size
359    measurement count prot. unit   32   48   64   96  128  192  256
360    ---------------------------------------------------------------
361    latency         1 mixed  us  90.3  162  109  190
362                        old  us  92.7  152  104  179  122  225  135
363    ---------------------------------------------------------------
364    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
365    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
366    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
367    ---------------------------------------------------------------
368    limit                doubles  512  512  512  512  512  512  512
369    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
370                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
371    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
372
373    Footnotes:
374    1) This line shows that the overhead to decide which protocol
375       should be used can be ignored.
376    2) This line shows the limit for the count argument.
377       If count < limit then the vendor protocol is used,
378       otherwise the new protocol is used (see variable Ldb).
379    3) These lines show the bandwidth (= buffer length / execution time)
380       for both protocols.
381    4) This line shows that the limit is chosen well if the ratio is
382       between 0.95 (losing 5% for buffer length near and >=limit)
383       and 1.10 (not gaining 10% for buffer length near and <limit).
384    5) This line shows that the new protocol is 2..7 times faster
385       for long counts.
386
387 */
388
389 #ifdef REDUCE_LIMITS
390
391 #ifdef USE_Irecv
392 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
393   {                                                                                                                    \
394     MPI_Request req;                                                                                                   \
395     req = Request::irecv(rb, rc, rd, source, rt, comm);                                                                \
396     Request::send(sb, sc, sd, dest, st, comm);                                                                         \
397     Request::wait(&req, stat);                                                                                         \
398   }
399 #else
400 #ifdef USE_Isend
401 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
402   {                                                                                                                    \
403     MPI_Request req;                                                                                                   \
404     req = mpi_mpi_isend(sb, sc, sd, dest, st, comm);                                                                   \
405     Request::recv(rb, rc, rd, source, rt, comm, stat);                                                                 \
406     Request::wait(&req, stat);                                                                                         \
407   }
408 #else
409 #define MPI_I_Sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)                                       \
410   Request::sendrecv(sb, sc, sd, dest, st, rb, rc, rd, source, rt, comm, stat)
411 #endif
412 #endif
413
414 enum MPIM_Datatype {
415   MPIM_SHORT,
416   MPIM_INT,
417   MPIM_LONG,
418   MPIM_UNSIGNED_SHORT,
419   MPIM_UNSIGNED,
420   MPIM_UNSIGNED_LONG,
421   MPIM_UNSIGNED_LONG_LONG,
422   MPIM_FLOAT,
423   MPIM_DOUBLE,
424   MPIM_BYTE
425 };
426
427 enum MPIM_Op {
428   MPIM_MAX,
429   MPIM_MIN,
430   MPIM_SUM,
431   MPIM_PROD,
432   MPIM_LAND,
433   MPIM_BAND,
434   MPIM_LOR,
435   MPIM_BOR,
436   MPIM_LXOR,
437   MPIM_BXOR
438 };
439 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE, TYPE)                                                                  \
440   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
441   {                                                                                                                    \
442     int i;                                                                                                             \
443     switch (op) {                                                                                                      \
444       case MPIM_MAX:                                                                                                   \
445         for (i = 0; i < cnt; i++)                                                                                      \
446           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
447         break;                                                                                                         \
448       case MPIM_MIN:                                                                                                   \
449         for (i = 0; i < cnt; i++)                                                                                      \
450           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
451         break;                                                                                                         \
452       case MPIM_SUM:                                                                                                   \
453         for (i = 0; i < cnt; i++)                                                                                      \
454           rslt[i] = b1[i] + b2[i];                                                                                     \
455         break;                                                                                                         \
456       case MPIM_PROD:                                                                                                  \
457         for (i = 0; i < cnt; i++)                                                                                      \
458           rslt[i] = b1[i] * b2[i];                                                                                     \
459         break;                                                                                                         \
460       case MPIM_LAND:                                                                                                  \
461         for (i = 0; i < cnt; i++)                                                                                      \
462           rslt[i] = b1[i] && b2[i];                                                                                    \
463         break;                                                                                                         \
464       case MPIM_LOR:                                                                                                   \
465         for (i = 0; i < cnt; i++)                                                                                      \
466           rslt[i] = b1[i] || b2[i];                                                                                    \
467         break;                                                                                                         \
468       case MPIM_LXOR:                                                                                                  \
469         for (i = 0; i < cnt; i++)                                                                                      \
470           rslt[i] = b1[i] != b2[i];                                                                                    \
471         break;                                                                                                         \
472       case MPIM_BAND:                                                                                                  \
473         for (i = 0; i < cnt; i++)                                                                                      \
474           rslt[i] = b1[i] & b2[i];                                                                                     \
475         break;                                                                                                         \
476       case MPIM_BOR:                                                                                                   \
477         for (i = 0; i < cnt; i++)                                                                                      \
478           rslt[i] = b1[i] | b2[i];                                                                                     \
479         break;                                                                                                         \
480       case MPIM_BXOR:                                                                                                  \
481         for (i = 0; i < cnt; i++)                                                                                      \
482           rslt[i] = b1[i] ^ b2[i];                                                                                     \
483         break;                                                                                                         \
484       default:                                                                                                         \
485         break;                                                                                                         \
486     }                                                                                                                  \
487   }
488
489 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE, TYPE)                                                                         \
490   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
491   {                                                                                                                    \
492     int i;                                                                                                             \
493     switch (op) {                                                                                                      \
494       case MPIM_MAX:                                                                                                   \
495         for (i = 0; i < cnt; i++)                                                                                      \
496           rslt[i] = (b1[i] > b2[i] ? b1[i] : b2[i]);                                                                   \
497         break;                                                                                                         \
498       case MPIM_MIN:                                                                                                   \
499         for (i = 0; i < cnt; i++)                                                                                      \
500           rslt[i] = (b1[i] < b2[i] ? b1[i] : b2[i]);                                                                   \
501         break;                                                                                                         \
502       case MPIM_SUM:                                                                                                   \
503         for (i = 0; i < cnt; i++)                                                                                      \
504           rslt[i] = b1[i] + b2[i];                                                                                     \
505         break;                                                                                                         \
506       case MPIM_PROD:                                                                                                  \
507         for (i = 0; i < cnt; i++)                                                                                      \
508           rslt[i] = b1[i] * b2[i];                                                                                     \
509         break;                                                                                                         \
510       default:                                                                                                         \
511         break;                                                                                                         \
512     }                                                                                                                  \
513   }
514
515 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE, TYPE)                                                                       \
516   static void MPI_I_do_op_TYPE(TYPE* b1, TYPE* b2, TYPE* rslt, int cnt, MPIM_Op op)                                    \
517   {                                                                                                                    \
518     int i;                                                                                                             \
519     switch (op) {                                                                                                      \
520       case MPIM_BAND:                                                                                                  \
521         for (i = 0; i < cnt; i++)                                                                                      \
522           rslt[i] = b1[i] & b2[i];                                                                                     \
523         break;                                                                                                         \
524       case MPIM_BOR:                                                                                                   \
525         for (i = 0; i < cnt; i++)                                                                                      \
526           rslt[i] = b1[i] | b2[i];                                                                                     \
527         break;                                                                                                         \
528       case MPIM_BXOR:                                                                                                  \
529         for (i = 0; i < cnt; i++)                                                                                      \
530           rslt[i] = b1[i] ^ b2[i];                                                                                     \
531         break;                                                                                                         \
532       default:                                                                                                         \
533         break;                                                                                                         \
534     }                                                                                                                  \
535   }
536
537 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_short, short)
538 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_int, int)
539 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_long, long)
540 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ushort, unsigned short)
541 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_uint, unsigned int)
542 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulong, unsigned long)
543 MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_ulonglong, unsigned long long)
544 MPI_I_DO_OP_FP(MPI_I_do_op_float, float)
545 MPI_I_DO_OP_FP(MPI_I_do_op_double, double)
546 MPI_I_DO_OP_BYTE(MPI_I_do_op_byte, char)
547
548 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE, TYPE)                                                                       \
549   MPI_I_do_op_TYPE((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op);                                                        \
550   break;
551
552 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt, MPIM_Datatype datatype, MPIM_Op op)
553 {
554   switch (datatype) {
555     case MPIM_SHORT:
556       MPI_I_DO_OP_CALL(MPI_I_do_op_short, short)
557     case MPIM_INT:
558       MPI_I_DO_OP_CALL(MPI_I_do_op_int, int)
559     case MPIM_LONG:
560       MPI_I_DO_OP_CALL(MPI_I_do_op_long, long)
561     case MPIM_UNSIGNED_SHORT:
562       MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
563     case MPIM_UNSIGNED:
564       MPI_I_DO_OP_CALL(MPI_I_do_op_uint, unsigned int)
565     case MPIM_UNSIGNED_LONG:
566       MPI_I_DO_OP_CALL(MPI_I_do_op_ulong, unsigned long)
567     case MPIM_UNSIGNED_LONG_LONG:
568       MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong, unsigned long long)
569     case MPIM_FLOAT:
570       MPI_I_DO_OP_CALL(MPI_I_do_op_float, float)
571     case MPIM_DOUBLE:
572       MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
573     case MPIM_BYTE:
574       MPI_I_DO_OP_CALL(MPI_I_do_op_byte, char)
575   }
576 }
577
578 REDUCE_LIMITS
579 namespace simgrid {
580 namespace smpi {
581 static int MPI_I_anyReduce(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op,
582                            int root, MPI_Comm comm, bool is_all)
583 {
584   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
585   int myrank, size, x_base, x_size, computed, idx;
586   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
587   int start_even[20], start_odd[20], count_even[20], count_odd[20];
588   MPI_Aint typelng;
589   MPI_Status status;
590   size_t scrlng;
591   int new_prot;
592   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
593
594   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
595   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
596   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
597   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
598   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
599   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
600   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
601   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
602   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
603   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
604   else
605     throw std::invalid_argument("reduce rab algorithm can't be used with this datatype!");
606
607   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
608   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
609   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
610   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
611   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
612   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
613   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
614   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
615   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
616   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
617
618   new_prot = 0;
619   MPI_Comm_size(comm, &size);
620   if (size > 1) /*otherwise no balancing_protocol*/
621   { int ss;
622     if      (size==2) ss=0;
623     else if (size==3) ss=1;
624     else { int s = size; while (!(s & 1)) s = s >> 1;
625            if (s==1) /* size == power of 2 */ ss = 2;
626            else      /* size != power of 2 */ ss = 3; }
627     switch(op) {
628      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
629      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
630      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
631       switch(datatype) {
632         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
633          new_prot = count >= Lsh[is_all][ss]; break;
634         case MPIM_INT:    case MPIM_UNSIGNED:
635          new_prot = count >= Lin[is_all][ss]; break;
636         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
637          new_prot = count >= Llg[is_all][ss]; break;
638      default:
639         break;
640     } default:
641         break;}
642     switch(op) {
643      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
644       switch(datatype) {
645         case MPIM_FLOAT:
646          new_prot = count >= Lfp[is_all][ss]; break;
647         case MPIM_DOUBLE:
648          new_prot = count >= Ldb[is_all][ss]; break;
649               default:
650         break;
651     } default:
652         break;}
653     switch(op) {
654      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
655       switch(datatype) {
656         case MPIM_BYTE:
657          new_prot = count >= Lby[is_all][ss]; break;
658               default:
659         break;
660     } default:
661         break;}
662 #   ifdef DEBUG
663     { char *ss_str[]={"two","three","power of 2","no power of 2"};
664       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
665              is_all, size, ss_str[ss], new_prot); fflush(stdout);
666     }
667 #   endif
668   }
669
670   if (new_prot)
671   {
672     sendbuf = (char*) Sendbuf;
673     recvbuf = (char*) Recvbuf;
674     MPI_Comm_rank(comm, &myrank);
675     MPI_Type_extent(mpi_datatype, &typelng);
676     scrlng  = typelng * count;
677 #ifdef NO_CACHE_OPTIMIZATION
678     scr1buf = new char[scrlng];
679     scr2buf = new char[scrlng];
680     scr3buf = new char[scrlng];
681 #else
682 #  ifdef SCR_LNG_OPTIM
683     scrlng = SCR_LNG_OPTIM(scrlng);
684 #  endif
685     scr2buf = new char[3 * scrlng]; /* To test cache problems.     */
686     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
687     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
688                                   /* they are interchanged below.*/
689 #endif
690     computed = 0;
691     if (is_all) root = myrank; /* for correct recvbuf handling */
692
693   /*...step 1 */
694
695 #   ifdef DEBUG
696      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
697 #   endif
698     n = 0; x_size = 1;
699     while (2*x_size <= size) { n++; x_size = x_size * 2; }
700     /* x_size == 2**n */
701     r = size - x_size;
702
703   /*...step 2 */
704
705 #   ifdef DEBUG
706     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
707 #   endif
708     if (myrank < 2*r)
709     {
710       if ((myrank % 2) == 0 /*even*/)
711       {
712         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
713                        count - count/2, mpi_datatype, myrank+1, 1220,
714                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
715                        comm, &status);
716         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
717                     count/2, datatype, op);
718         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
719                  mpi_datatype, myrank+1, 1223, comm, &status);
720         computed = 1;
721 #       ifdef DEBUG
722         { int i; printf("[%2d] after step 2: val=",
723                         myrank);
724           for (i=0; i<count; i++)
725            printf(" %5.0lf",((double*)scr1buf)[i] );
726           printf("\n"); fflush(stdout);
727         }
728 #       endif
729       }
730       else /*odd*/
731       {
732         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
733                        scr2buf + (count/2)*typelng,
734                        count - count/2, mpi_datatype, myrank-1, 1220,
735                        comm, &status);
736         MPI_I_do_op(scr2buf + (count/2)*typelng,
737                     sendbuf + (count/2)*typelng,
738                     scr1buf + (count/2)*typelng,
739                     count - count/2, datatype, op);
740         Request::send(scr1buf + (count/2)*typelng, count - count/2,
741                  mpi_datatype, myrank-1, 1223, comm);
742       }
743     }
744
745   /*...step 3+4 */
746
747 #   ifdef DEBUG
748      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
749 #   endif
750     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
751          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
752     else mynewrank = -1;
753
754     if (mynewrank >= 0)
755     { /* begin -- only for nodes with new rank */
756
757 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
758
759   /*...step 5 */
760
761       x_start = 0;
762       x_count = count;
763       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
764       {
765         start_even[idx] = x_start;
766         count_even[idx] = x_count / 2;
767         start_odd [idx] = x_start + count_even[idx];
768         count_odd [idx] = x_count - count_even[idx];
769         if (((mynewrank/x_base) % 2) == 0 /*even*/)
770         {
771 #         ifdef DEBUG
772             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
773                    myrank,mynewrank,idx+1,computed); fflush(stdout);
774 #         endif
775           x_start = start_even[idx];
776           x_count = count_even[idx];
777           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
778                          + start_odd[idx]*typelng, count_odd[idx],
779                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
780                          scr2buf + x_start*typelng, x_count,
781                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
782                          comm, &status);
783           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
784                       scr2buf                    + x_start*typelng,
785                       ((root==myrank) && (idx==(n-1))
786                         ? recvbuf + x_start*typelng
787                         : scr3buf + x_start*typelng),
788                       x_count, datatype, op);
789         }
790         else /*odd*/
791         {
792 #         ifdef DEBUG
793             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
794                    myrank,mynewrank,idx+1,computed); fflush(stdout);
795 #         endif
796           x_start = start_odd[idx];
797           x_count = count_odd[idx];
798           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
799                          +start_even[idx]*typelng, count_even[idx],
800                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
801                          scr2buf + x_start*typelng, x_count,
802                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
803                          comm, &status);
804           MPI_I_do_op(scr2buf                    + x_start*typelng,
805                       (computed?scr1buf:sendbuf) + x_start*typelng,
806                       ((root==myrank) && (idx==(n-1))
807                         ? recvbuf + x_start*typelng
808                         : scr3buf + x_start*typelng),
809                       x_count, datatype, op);
810         }
811         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
812         computed = 1;
813 #       ifdef DEBUG
814         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
815                         myrank,mynewrank,idx+1,x_start,x_count);
816           for (i=0; i<x_count; i++)
817            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
818                              ? recvbuf + x_start*typelng
819                              : scr1buf + x_start*typelng))[i] );
820           printf("\n"); fflush(stdout);
821         }
822 #       endif
823       } /*for*/
824
825 #     undef OLDRANK
826
827
828     } /* end -- only for nodes with new rank */
829
830     if (is_all)
831     {
832       /*...steps 6.1 to 6.n */
833
834       if (mynewrank >= 0)
835       { /* begin -- only for nodes with new rank */
836
837 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
838
839         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
840         {
841 #         ifdef DEBUG
842             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
843 #         endif
844           if (((mynewrank/x_base) % 2) == 0 /*even*/)
845           {
846             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
847                                      count_even[idx],
848                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
849                            recvbuf + start_odd[idx]*typelng,
850                                      count_odd[idx],
851                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
852                            comm, &status);
853 #           ifdef DEBUG
854               x_start = start_odd[idx];
855               x_count = count_odd[idx];
856 #           endif
857           }
858           else /*odd*/
859           {
860             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
861                                      count_odd[idx],
862                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
863                            recvbuf + start_even[idx]*typelng,
864                                      count_even[idx],
865                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
866                            comm, &status);
867 #           ifdef DEBUG
868               x_start = start_even[idx];
869               x_count = count_even[idx];
870 #           endif
871           }
872 #         ifdef DEBUG
873           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
874                           myrank,mynewrank,n-idx,x_start,x_count);
875             for (i=0; i<x_count; i++)
876              printf(" %5.0lf",((double*)(root==myrank
877                                ? recvbuf + x_start*typelng
878                                : scr1buf + x_start*typelng))[i] );
879             printf("\n"); fflush(stdout);
880           }
881 #         endif
882         } /*for*/
883
884 #       undef OLDRANK
885
886       } /* end -- only for nodes with new rank */
887
888       /*...step 7 */
889
890       if (myrank < 2*r)
891       {
892 #       ifdef DEBUG
893           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
894 #       endif
895         if (myrank%2 == 0 /*even*/)
896           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
897         else /*odd*/
898           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
899       }
900
901     }
902     else /* not is_all, i.e. Reduce */
903     {
904
905     /*...step 6.0 */
906
907       if ((root < 2*r) && (root%2 == 1))
908       {
909 #       ifdef DEBUG
910           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
911 #       endif
912         if (myrank == 0) /* then mynewrank==0, x_start==0
913                                  x_count == count/x_size  */
914         {
915           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
916           mynewrank = -1;
917         }
918
919         if (myrank == root)
920         {
921           mynewrank = 0;
922           x_start = 0;
923           x_count = count;
924           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
925           {
926             start_even[idx] = x_start;
927             count_even[idx] = x_count / 2;
928             start_odd [idx] = x_start + count_even[idx];
929             count_odd [idx] = x_count - count_even[idx];
930             /* code for always even in each bit of mynewrank: */
931             x_start = start_even[idx];
932             x_count = count_even[idx];
933           }
934           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
935         }
936         newroot = 0;
937       }
938       else
939       {
940         newroot = (root < 2*r ? root/2 : root-r);
941       }
942
943     /*...steps 6.1 to 6.n */
944
945       if (mynewrank >= 0)
946       { /* begin -- only for nodes with new rank */
947
948 #define OLDRANK(new) ((new) == newroot ? root : ((new) < r ? (new) * 2 : (new) + r))
949
950         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
951         {
952 #         ifdef DEBUG
953             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
954 #         endif
955           if ((mynewrank & x_base) != (newroot & x_base))
956           {
957             if (((mynewrank/x_base) % 2) == 0 /*even*/)
958             { x_start = start_even[idx]; x_count = count_even[idx];
959               partner = mynewrank+x_base; }
960             else
961             { x_start = start_odd[idx]; x_count = count_odd[idx];
962               partner = mynewrank-x_base; }
963             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
964                      OLDRANK(partner), 1244, comm);
965           }
966           else /*odd*/
967           {
968             if (((mynewrank/x_base) % 2) == 0 /*even*/)
969             { x_start = start_odd[idx]; x_count = count_odd[idx];
970               partner = mynewrank+x_base; }
971             else
972             { x_start = start_even[idx]; x_count = count_even[idx];
973               partner = mynewrank-x_base; }
974             Request::recv((myrank==root ? recvbuf : scr1buf)
975                      + x_start*typelng, x_count, mpi_datatype,
976                      OLDRANK(partner), 1244, comm, &status);
977 #           ifdef DEBUG
978             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
979                             myrank,mynewrank,n-idx,x_start,x_count);
980               for (i=0; i<x_count; i++)
981                printf(" %5.0lf",((double*)(root==myrank
982                                  ? recvbuf + x_start*typelng
983                                  : scr1buf + x_start*typelng))[i] );
984               printf("\n"); fflush(stdout);
985             }
986 #           endif
987           }
988         } /*for*/
989
990 #       undef OLDRANK
991
992       } /* end -- only for nodes with new rank */
993     }
994
995 #   ifdef NO_CACHE_TESTING
996     delete[] scr1buf;
997     delete[] scr2buf;
998     delete[] scr3buf;
999 #   else
1000     delete[] scr2buf;             /* scr1buf and scr3buf are part of scr2buf */
1001 #   endif
1002     return(MPI_SUCCESS);
1003   } /* new_prot */
1004   /*otherwise:*/
1005   if (is_all)
1006     return (colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm));
1007   else
1008     return (colls::reduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, root, comm));
1009 }
1010 #endif /*REDUCE_LIMITS*/
1011
1012 int reduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root,
1013                 MPI_Comm comm)
1014 {
1015   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, false);
1016 }
1017
1018 int allreduce__rab(const void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
1019 {
1020   return MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, -1, comm, true);
1021 }
1022 }
1023 }