Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
b399292eeb3ad2db46aa7f81786da96faabd7b9f
[simgrid.git] / src / smpi / colls / reduce / reduce-rab.cpp
1 /* extracted from mpig_myreduce.c with
2    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
3
4 /* Copyright: Rolf Rabenseifner, 1997
5  *            Computing Center University of Stuttgart
6  *            rabenseifner@rus.uni-stuttgart.de
7  *
8  * The usage of this software is free,
9  * but this header must not be removed.
10  */
11
12 #include "../colls_private.hpp"
13 #include <cstdio>
14 #include <cstdlib>
15
16 #define REDUCE_NEW_ALWAYS 1
17
18 #ifdef CRAY
19 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
20                                  /* =  16 + multiple of 32 doubles*/
21
22 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
23        /*  routine =    reduce                allreduce             */ \
24        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
25  static int Lsh[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
26  static int Lin[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
27  static int Llg[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
28  static int Lfp[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
29  static int Ldb[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
30  static int Lby[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};
31 #endif
32
33 #ifdef REDUCE_NEW_ALWAYS
34 # undef  REDUCE_LIMITS
35 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
36        /*  routine =    reduce                allreduce             */ \
37        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
38  static int Lsh[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
39  static int Lin[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
40  static int Llg[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
41  static int Lfp[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
42  static int Ldb[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
43  static int Lby[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};
44 #endif
45
46 /* Fast reduce and allreduce algorithm for longer buffers and predefined
47    operations.
48
49    This algorithm is explaned with the example of 13 nodes.
50    The nodes are numbered   0, 1, 2, ... 12.
51    The sendbuf content is   a, b, c, ...  m.
52    The buffer array is notated with ABCDEFGH, this means that
53    e.g. 'C' is the third 1/8 of the buffer.
54
55    The algorithm computes
56
57    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
58
59    This is equivalent to the mostly used binary tree algorithm
60    (e.g. used in mpich).
61
62    size := number of nodes in the communicator.
63    2**n := the power of 2 that is next smaller or equal to the size.
64    r    := size - 2**n
65
66    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
67
68    The algoritm needs for the execution of one Colls::reduce
69
70    - for r==0
71      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
72
73    - for r>0
74      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
75                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
76
77    with L1 = latency of a message transfer e.g. by send+recv
78         L2 = latency of a message exchange e.g. by sendrecv
79         T1 = additional time to transfer 1 byte (= 1/bandwidth)
80         T2 = additional time to exchange 1 byte in both directions.
81         O  = time for one operation
82         d  = size of the datatype
83
84         On a MPP system it is expected that T1==T2.
85
86    In Comparison with the binary tree algorithm that needs:
87
88    - for r==0
89      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
90
91    - for r>0
92      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
93
94    the new algorithm is faster if (assuming T1=T2)
95
96     for n>2:
97      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
98      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
99
100     for size = 5, 6, and 7:
101                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
102
103     for size = 4:
104                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
105     for size = 2 and 3:
106                 buf_lng  >  L2 / [  0.5 * O/d ]
107
108    and for O/d >> T1 we can summarize:
109
110     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
111
112    Example L1 = L2 = 50 us,
113            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
114            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
115            for double,        i.e. d  =   8 byte/FLOP
116
117            ==> the new algorithm is faster for about buf_lng>8000 bytes
118                i.e count > 1000 doubles !
119 Step 1)
120
121    compute n and r
122
123 Step 2)
124
125    if  myrank < 2*r
126
127      split the buffer into ABCD and EFGH
128
129      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
130                   receive buffer ABCD from myrank+1
131                   compute op for ABCD                 (count/2 ops)
132                   receive result EFGH               (buf_lng/2 transfer)
133      odd  myrank: send    buffer ABCD to   myrank+1
134                   receive buffer EFGH from myrank+1
135                   compute op for EFGH
136                   send    result EFGH
137
138    Result: node:     0    2    4    6    8  10  11  12
139            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
140
141 Step 3)
142
143    The following algorithm uses only the nodes with
144
145    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
146
147 Step 4)
148
149    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
150    define OLDRANK(new) := (new < r   ? new*2 : new+r)
151
152    Result:
153        old ranks:    0    2    4    6    8  10  11  12
154        new ranks:    0    1    2    3    4   5   6   7
155            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
156
157 Step 5.1)
158
159    Split the buffer (ABCDEFGH) in the middle,
160    the lower half (ABCD) is computed on even (new) ranks,
161    the opper half (EFGH) is computed on odd  (new) ranks.
162
163    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
164              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
165                                                (i.e. buf_lng/2 exchange)
166    compute op in each node on its half:        (i.e.   count/2 ops)
167
168    Result: node 0: (a+b)+(c+d)   for  ABCD
169            node 1: (a+b)+(c+d)   for  EFGH
170            node 2: (e+f)+(g+h)   for  ABCD
171            node 3: (e+f)+(g+h)   for  EFGH
172            node 4: (i+j)+ k      for  ABCD
173            node 5: (i+j)+ k      for  EFGH
174            node 6:    l + m      for  ABCD
175            node 7:    l + m      for  EFGH
176
177 Step 5.2)
178
179    Same with double distance and oncemore the half of the buffer.
180                                                (i.e. buf_lng/4 exchange)
181                                                (i.e.   count/4 ops)
182
183    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
184            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
185            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
186            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
187            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
188            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
189            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
190            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
191
192 ...
193 Step 5.n)
194
195    Same with double distance and oncemore the half of the buffer.
196                                             (i.e. buf_lng/2**n exchange)
197                                             (i.e.   count/2**n ops)
198
199    Result:
200      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
201      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
202      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
203      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
204      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
205      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
206      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
207      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
208
209
210 For Colls::allreduce:
211 ------------------
212
213 Step 6.1)
214
215    Exchange on the last distance (2*n) the last result
216                                             (i.e. buf_lng/2**n exchange)
217
218 Step 6.2)
219
220    Same with half the distance and double of the result
221                                         (i.e. buf_lng/2**(n-1) exchange)
222
223 ...
224 Step 6.n)
225
226    Same with distance 1 and double of the result (== half of the
227    original buffer)                            (i.e. buf_lng/2 exchange)
228
229    Result after 6.1     6.2         6.n
230
231     on node 0:   AB    ABCD    ABCDEFGH
232     on node 1:   EF    EFGH    ABCDEFGH
233     on node 2:   CD    ABCD    ABCDEFGH
234     on node 3:   GH    EFGH    ABCDEFGH
235     on node 4:   AB    ABCD    ABCDEFGH
236     on node 5:   EF    EFGH    ABCDEFGH
237     on node 6:   CD    ABCD    ABCDEFGH
238     on node 7:   GH    EFGH    ABCDEFGH
239
240 Step 7)
241
242    If r > 0
243      transfer the result from the even nodes with old rank < 2*r
244                            to the odd  nodes with old rank < 2*r
245                                                  (i.e. buf_lng transfer)
246    Result:
247      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
248      for ABCDEFGH
249      on all nodes 0..12
250
251
252 For Colls::reduce:
253 ---------------
254
255 Step 6.0)
256
257    If root node not in the list of the nodes with newranks
258    (see steps 3+4) then
259
260      send last result from node 0 to the root    (buf_lng/2**n transfer)
261      and replace the role of node 0 by the root node.
262
263 Step 6.1)
264
265    Send on the last distance (2**(n-1)) the last result
266    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
267    root's new rank  to the node with same '2**(n-1)' bit.
268                                             (i.e. buf_lng/2**n transfer)
269
270 Step 6.2)
271
272    Same with half the distance and double of the result
273    and bit '2**(n-2)'
274                                         (i.e. buf_lng/2**(n-1) transfer)
275
276 ...
277 Step 6.n)
278
279    Same with distance 1 and double of the result (== half of the
280    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
281
282    Example: roots old rank: 10
283             roots new rank:  5
284
285    Results:          6.1               6.2                 6.n
286                 action result     action result       action result
287
288     on node 0:  send A
289     on node 1:  send E
290     on node 2:  send C
291     on node 3:  send G
292     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
293     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
294     on node 6:  recv C => CD     send CD
295     on node 7:  recv G => GH     send GH
296
297 Benchmark results on CRAY T3E
298 -----------------------------
299
300    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
301    MPI:      /opt/ctl/mpt/1.1.0.3
302    datatype: MPI_DOUBLE
303    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
304    env: export MPI_BUFFER_MAX=4099
305    compiled with: cc -c -O3 -h restrict=f
306
307    old = binary tree protocol of the vendor
308    new = the new protocol and its implementation
309    mixed = 'new' is used if count > limit(datatype, communicator size)
310
311    REDUCE:
312                                           communicator size
313    measurement count prot. unit    2    3    4    6    8   12   16   24
314    --------------------------------------------------------------------
315    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
316                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
317    --------------------------------------------------------------------
318    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
319    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
320    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
321    --------------------------------------------------------------------
322    limit                doubles  896 1536  576  736  576  736  576  736
323    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
324                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
325    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
326
327                                           communicator size
328    measurement count prot. unit   32   48   64   96  128  192  256
329    ---------------------------------------------------------------
330    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
331                        old  us  78.6 87.2 90.1 99.7  108  119  120
332    ---------------------------------------------------------------
333    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
334    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
335    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
336    ---------------------------------------------------------------
337    limit                doubles  576  736  576  736  576  736  576  2)
338    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
339                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
340    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
341
342    ALLREDUCE:
343                                           communicator size
344    measurement count prot. unit    2    3    4    6    8   12   16   24
345    --------------------------------------------------------------------
346    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
347                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
348    --------------------------------------------------------------------
349    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
350    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
351    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
352    --------------------------------------------------------------------
353    limit                doubles  448 1280  512  512  512  512  512  512
354    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
355                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
356    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
357
358                                           communicator size
359    measurement count prot. unit   32   48   64   96  128  192  256
360    ---------------------------------------------------------------
361    latency         1 mixed  us  90.3  162  109  190
362                        old  us  92.7  152  104  179  122  225  135
363    ---------------------------------------------------------------
364    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
365    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
366    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
367    ---------------------------------------------------------------
368    limit                doubles  512  512  512  512  512  512  512
369    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
370                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
371    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
372
373    Footnotes:
374    1) This line shows that the overhead to decide which protocol
375       should be used can be ignored.
376    2) This line shows the limit for the count argument.
377       If count < limit then the vendor protocol is used,
378       otherwise the new protocol is used (see variable Ldb).
379    3) These lines show the bandwidth (=bufer length / execution time)
380       for both protocols.
381    4) This line shows that the limit is choosen well if the ratio is
382       between 0.95 (loosing 5% for buffer length near and >=limit)
383       and 1.10 (not gaining 10% for buffer length near and <limit).
384    5) This line shows that the new protocol is 2..7 times faster
385       for long counts.
386
387 */
388
389 #ifdef REDUCE_LIMITS
390
391 #ifdef USE_Irecv
392 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
393            { MPI_Request req;                                          \
394              req=Request::irecv(rb,rc,rd,source,rt,comm);                  \
395              Request::send(sb,sc,sd,dest,st,comm);                          \
396              Request::wait(&req,stat);                                      \
397            }
398 #else
399 #ifdef USE_Isend
400 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
401            { MPI_Request req;                                          \
402              req=mpi_mpi_isend(sb,sc,sd,dest,st,comm);                    \
403              Request::recv(rb,rc,rd,source,rt,comm,stat);                   \
404              Request::wait(&req,stat);                                      \
405            }
406 #else
407 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
408            Request::sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat)
409 #endif
410 #endif
411
412 enum MPIM_Datatype {
413   MPIM_SHORT,
414   MPIM_INT,
415   MPIM_LONG,
416   MPIM_UNSIGNED_SHORT,
417   MPIM_UNSIGNED,
418   MPIM_UNSIGNED_LONG,
419   MPIM_UNSIGNED_LONG_LONG,
420   MPIM_FLOAT,
421   MPIM_DOUBLE,
422   MPIM_BYTE
423 };
424
425 enum MPIM_Op {
426   MPIM_MAX,
427   MPIM_MIN,
428   MPIM_SUM,
429   MPIM_PROD,
430   MPIM_LAND,
431   MPIM_BAND,
432   MPIM_LOR,
433   MPIM_BOR,
434   MPIM_LXOR,
435   MPIM_BXOR
436 };
437 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE,TYPE)                   \
438 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
439 { int i;                                                               \
440   switch (op) {                                                        \
441    case MPIM_MAX :                                                     \
442       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
443    case MPIM_MIN :                                                     \
444       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
445    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
446    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
447    case MPIM_LAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] && b2[i]; break;  \
448    case MPIM_LOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] || b2[i]; break;  \
449    case MPIM_LXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] != b2[i]; break;  \
450    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
451    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
452    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
453    default: break;                                                     \
454   }                                                                    \
455 }
456
457 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE,TYPE)                          \
458 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op) \
459 { int i;                                                               \
460   switch (op) {                                                        \
461    case MPIM_MAX :                                                     \
462       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
463    case MPIM_MIN :                                                     \
464       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
465    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
466    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
467    default: break;                                                     \
468   }                                                                    \
469 }
470
471 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE,TYPE)                        \
472 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
473 { int i;                                                               \
474   switch (op) {                                                        \
475    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
476    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
477    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
478    default: break;                                                     \
479   }                                                                    \
480 }
481
482 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_short,  short)
483 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_int,    int)
484 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_long,   long)
485 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ushort, unsigned short)
486 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_uint,   unsigned int)
487 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulong,  unsigned long)
488 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulonglong,  unsigned long long)
489 MPI_I_DO_OP_FP(        MPI_I_do_op_float,  float)
490 MPI_I_DO_OP_FP(        MPI_I_do_op_double, double)
491 MPI_I_DO_OP_BYTE(      MPI_I_do_op_byte,   char)
492
493 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE,TYPE)                        \
494    MPI_I_do_op_TYPE ((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op); break;
495
496 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt,
497                  MPIM_Datatype datatype, MPIM_Op op)
498 {
499  switch (datatype) {
500   case MPIM_SHORT : MPI_I_DO_OP_CALL(MPI_I_do_op_short,  short)
501   case MPIM_INT   : MPI_I_DO_OP_CALL(MPI_I_do_op_int,    int)
502   case MPIM_LONG  : MPI_I_DO_OP_CALL(MPI_I_do_op_long,   long)
503   case MPIM_UNSIGNED_SHORT:
504                    MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
505   case MPIM_UNSIGNED:
506                    MPI_I_DO_OP_CALL(MPI_I_do_op_uint,   unsigned int)
507   case MPIM_UNSIGNED_LONG:
508                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulong,  unsigned long)
509   case MPIM_UNSIGNED_LONG_LONG:
510                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong,  unsigned long long)
511   case MPIM_FLOAT : MPI_I_DO_OP_CALL(MPI_I_do_op_float,  float)
512   case MPIM_DOUBLE: MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
513   case MPIM_BYTE  : MPI_I_DO_OP_CALL(MPI_I_do_op_byte,   char)
514  }
515 }
516
517 REDUCE_LIMITS
518 namespace simgrid{
519 namespace smpi{
520 static int MPI_I_anyReduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op, int root, MPI_Comm comm, int is_all)
521 {
522   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
523   int myrank, size, x_base, x_size, computed, idx;
524   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
525   int start_even[20], start_odd[20], count_even[20], count_odd[20];
526   MPI_Aint typelng;
527   MPI_Status status;
528   size_t scrlng;
529   int new_prot;
530   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
531
532   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
533   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
534   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
535   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
536   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
537   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
538   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
539   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
540   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
541   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
542   else
543    THROWF(arg_error,0, "reduce rab algorithm can't be used with this datatype ! ");
544
545   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
546   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
547   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
548   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
549   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
550   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
551   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
552   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
553   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
554   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
555
556   new_prot = 0;
557   MPI_Comm_size(comm, &size);
558   if (size > 1) /*otherwise no balancing_protocol*/
559   { int ss;
560     if      (size==2) ss=0;
561     else if (size==3) ss=1;
562     else { int s = size; while (!(s & 1)) s = s >> 1;
563            if (s==1) /* size == power of 2 */ ss = 2;
564            else      /* size != power of 2 */ ss = 3; }
565     switch(op) {
566      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
567      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
568      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
569       switch(datatype) {
570         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
571          new_prot = count >= Lsh[is_all][ss]; break;
572         case MPIM_INT:    case MPIM_UNSIGNED:
573          new_prot = count >= Lin[is_all][ss]; break;
574         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
575          new_prot = count >= Llg[is_all][ss]; break;
576      default:
577         break;
578     } default:
579         break;}
580     switch(op) {
581      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
582       switch(datatype) {
583         case MPIM_FLOAT:
584          new_prot = count >= Lfp[is_all][ss]; break;
585         case MPIM_DOUBLE:
586          new_prot = count >= Ldb[is_all][ss]; break;
587               default:
588         break;
589     } default:
590         break;}
591     switch(op) {
592      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
593       switch(datatype) {
594         case MPIM_BYTE:
595          new_prot = count >= Lby[is_all][ss]; break;
596               default:
597         break;
598     } default:
599         break;}
600 #   ifdef DEBUG
601     { char *ss_str[]={"two","three","power of 2","no power of 2"};
602       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
603              is_all, size, ss_str[ss], new_prot); fflush(stdout);
604     }
605 #   endif
606   }
607
608   if (new_prot)
609   {
610     sendbuf = (char*) Sendbuf;
611     recvbuf = (char*) Recvbuf;
612     MPI_Comm_rank(comm, &myrank);
613     MPI_Type_extent(mpi_datatype, &typelng);
614     scrlng  = typelng * count;
615 #ifdef NO_CACHE_OPTIMIZATION
616     scr1buf = static_cast<char*>(xbt_malloc(scrlng));
617     scr2buf = static_cast<char*>(xbt_malloc(scrlng));
618     scr3buf = static_cast<char*>(xbt_malloc(scrlng));
619 #else
620 #  ifdef SCR_LNG_OPTIM
621     scrlng = SCR_LNG_OPTIM(scrlng);
622 #  endif
623     scr2buf = static_cast<char*>(xbt_malloc(3*scrlng));   /* To test cache problems.     */
624     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
625     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
626                                   /* they are interchanged below.*/
627 #endif
628     computed = 0;
629     if (is_all) root = myrank; /* for correct recvbuf handling */
630
631   /*...step 1 */
632
633 #   ifdef DEBUG
634      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
635 #   endif
636     n = 0; x_size = 1;
637     while (2*x_size <= size) { n++; x_size = x_size * 2; }
638     /* x_sixe == 2**n */
639     r = size - x_size;
640
641   /*...step 2 */
642
643 #   ifdef DEBUG
644     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
645 #   endif
646     if (myrank < 2*r)
647     {
648       if ((myrank % 2) == 0 /*even*/)
649       {
650         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
651                        count - count/2, mpi_datatype, myrank+1, 1220,
652                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
653                        comm, &status);
654         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
655                     count/2, datatype, op);
656         Request::recv(scr1buf + (count/2)*typelng, count - count/2,
657                  mpi_datatype, myrank+1, 1223, comm, &status);
658         computed = 1;
659 #       ifdef DEBUG
660         { int i; printf("[%2d] after step 2: val=",
661                         myrank);
662           for (i=0; i<count; i++)
663            printf(" %5.0lf",((double*)scr1buf)[i] );
664           printf("\n"); fflush(stdout);
665         }
666 #       endif
667       }
668       else /*odd*/
669       {
670         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
671                        scr2buf + (count/2)*typelng,
672                        count - count/2, mpi_datatype, myrank-1, 1220,
673                        comm, &status);
674         MPI_I_do_op(scr2buf + (count/2)*typelng,
675                     sendbuf + (count/2)*typelng,
676                     scr1buf + (count/2)*typelng,
677                     count - count/2, datatype, op);
678         Request::send(scr1buf + (count/2)*typelng, count - count/2,
679                  mpi_datatype, myrank-1, 1223, comm);
680       }
681     }
682
683   /*...step 3+4 */
684
685 #   ifdef DEBUG
686      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
687 #   endif
688     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
689          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
690     else mynewrank = -1;
691
692     if (mynewrank >= 0)
693     { /* begin -- only for nodes with new rank */
694
695 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
696
697   /*...step 5 */
698
699       x_start = 0;
700       x_count = count;
701       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
702       {
703         start_even[idx] = x_start;
704         count_even[idx] = x_count / 2;
705         start_odd [idx] = x_start + count_even[idx];
706         count_odd [idx] = x_count - count_even[idx];
707         if (((mynewrank/x_base) % 2) == 0 /*even*/)
708         {
709 #         ifdef DEBUG
710             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
711                    myrank,mynewrank,idx+1,computed); fflush(stdout);
712 #         endif
713           x_start = start_even[idx];
714           x_count = count_even[idx];
715           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
716                          + start_odd[idx]*typelng, count_odd[idx],
717                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
718                          scr2buf + x_start*typelng, x_count,
719                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
720                          comm, &status);
721           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
722                       scr2buf                    + x_start*typelng,
723                       ((root==myrank) && (idx==(n-1))
724                         ? recvbuf + x_start*typelng
725                         : scr3buf + x_start*typelng),
726                       x_count, datatype, op);
727         }
728         else /*odd*/
729         {
730 #         ifdef DEBUG
731             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
732                    myrank,mynewrank,idx+1,computed); fflush(stdout);
733 #         endif
734           x_start = start_odd[idx];
735           x_count = count_odd[idx];
736           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
737                          +start_even[idx]*typelng, count_even[idx],
738                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
739                          scr2buf + x_start*typelng, x_count,
740                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
741                          comm, &status);
742           MPI_I_do_op(scr2buf                    + x_start*typelng,
743                       (computed?scr1buf:sendbuf) + x_start*typelng,
744                       ((root==myrank) && (idx==(n-1))
745                         ? recvbuf + x_start*typelng
746                         : scr3buf + x_start*typelng),
747                       x_count, datatype, op);
748         }
749         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
750         computed = 1;
751 #       ifdef DEBUG
752         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
753                         myrank,mynewrank,idx+1,x_start,x_count);
754           for (i=0; i<x_count; i++)
755            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
756                              ? recvbuf + x_start*typelng
757                              : scr1buf + x_start*typelng))[i] );
758           printf("\n"); fflush(stdout);
759         }
760 #       endif
761       } /*for*/
762
763 #     undef OLDRANK
764
765
766     } /* end -- only for nodes with new rank */
767
768     if (is_all)
769     {
770       /*...steps 6.1 to 6.n */
771
772       if (mynewrank >= 0)
773       { /* begin -- only for nodes with new rank */
774
775 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
776
777         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
778         {
779 #         ifdef DEBUG
780             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
781 #         endif
782           if (((mynewrank/x_base) % 2) == 0 /*even*/)
783           {
784             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
785                                      count_even[idx],
786                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
787                            recvbuf + start_odd[idx]*typelng,
788                                      count_odd[idx],
789                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
790                            comm, &status);
791 #           ifdef DEBUG
792               x_start = start_odd[idx];
793               x_count = count_odd[idx];
794 #           endif
795           }
796           else /*odd*/
797           {
798             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
799                                      count_odd[idx],
800                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
801                            recvbuf + start_even[idx]*typelng,
802                                      count_even[idx],
803                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
804                            comm, &status);
805 #           ifdef DEBUG
806               x_start = start_even[idx];
807               x_count = count_even[idx];
808 #           endif
809           }
810 #         ifdef DEBUG
811           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
812                           myrank,mynewrank,n-idx,x_start,x_count);
813             for (i=0; i<x_count; i++)
814              printf(" %5.0lf",((double*)(root==myrank
815                                ? recvbuf + x_start*typelng
816                                : scr1buf + x_start*typelng))[i] );
817             printf("\n"); fflush(stdout);
818           }
819 #         endif
820         } /*for*/
821
822 #       undef OLDRANK
823
824       } /* end -- only for nodes with new rank */
825
826       /*...step 7 */
827
828       if (myrank < 2*r)
829       {
830 #       ifdef DEBUG
831           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
832 #       endif
833         if (myrank%2 == 0 /*even*/)
834           Request::send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
835         else /*odd*/
836           Request::recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
837       }
838
839     }
840     else /* not is_all, i.e. Reduce */
841     {
842
843     /*...step 6.0 */
844
845       if ((root < 2*r) && (root%2 == 1))
846       {
847 #       ifdef DEBUG
848           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
849 #       endif
850         if (myrank == 0) /* then mynewrank==0, x_start==0
851                                  x_count == count/x_size  */
852         {
853           Request::send(scr1buf,x_count,mpi_datatype,root,1241,comm);
854           mynewrank = -1;
855         }
856
857         if (myrank == root)
858         {
859           mynewrank = 0;
860           x_start = 0;
861           x_count = count;
862           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
863           {
864             start_even[idx] = x_start;
865             count_even[idx] = x_count / 2;
866             start_odd [idx] = x_start + count_even[idx];
867             count_odd [idx] = x_count - count_even[idx];
868             /* code for always even in each bit of mynewrank: */
869             x_start = start_even[idx];
870             x_count = count_even[idx];
871           }
872           Request::recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
873         }
874         newroot = 0;
875       }
876       else
877       {
878         newroot = (root < 2*r ? root/2 : root-r);
879       }
880
881     /*...steps 6.1 to 6.n */
882
883       if (mynewrank >= 0)
884       { /* begin -- only for nodes with new rank */
885
886 #       define OLDRANK(new) ((new)==newroot ? root                     \
887                              : ((new)<r ? (new)*2 : (new)+r) )
888
889         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
890         {
891 #         ifdef DEBUG
892             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
893 #         endif
894           if ((mynewrank & x_base) != (newroot & x_base))
895           {
896             if (((mynewrank/x_base) % 2) == 0 /*even*/)
897             { x_start = start_even[idx]; x_count = count_even[idx];
898               partner = mynewrank+x_base; }
899             else
900             { x_start = start_odd[idx]; x_count = count_odd[idx];
901               partner = mynewrank-x_base; }
902             Request::send(scr1buf + x_start*typelng, x_count, mpi_datatype,
903                      OLDRANK(partner), 1244, comm);
904           }
905           else /*odd*/
906           {
907             if (((mynewrank/x_base) % 2) == 0 /*even*/)
908             { x_start = start_odd[idx]; x_count = count_odd[idx];
909               partner = mynewrank+x_base; }
910             else
911             { x_start = start_even[idx]; x_count = count_even[idx];
912               partner = mynewrank-x_base; }
913             Request::recv((myrank==root ? recvbuf : scr1buf)
914                      + x_start*typelng, x_count, mpi_datatype,
915                      OLDRANK(partner), 1244, comm, &status);
916 #           ifdef DEBUG
917             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
918                             myrank,mynewrank,n-idx,x_start,x_count);
919               for (i=0; i<x_count; i++)
920                printf(" %5.0lf",((double*)(root==myrank
921                                  ? recvbuf + x_start*typelng
922                                  : scr1buf + x_start*typelng))[i] );
923               printf("\n"); fflush(stdout);
924             }
925 #           endif
926           }
927         } /*for*/
928
929 #       undef OLDRANK
930
931       } /* end -- only for nodes with new rank */
932     }
933
934 #   ifdef NO_CACHE_TESTING
935      xbt_free(scr1buf); xbt_free(scr2buf); xbt_free(scr3buf);
936 #   else
937      xbt_free(scr2buf); /* scr1buf and scr3buf are part of scr2buf */
938 #   endif
939     return(MPI_SUCCESS);
940   } /* new_prot */
941   /*otherwise:*/
942   if (is_all)
943    return( Colls::allreduce(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm) );
944   else
945    return( Colls::reduce(Sendbuf,Recvbuf, count,mpi_datatype,mpi_op, root, comm) );
946 }
947 #endif /*REDUCE_LIMITS*/
948
949
950 int Coll_reduce_rab::reduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
951 {
952   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, 0) );
953 }
954
955 int Coll_allreduce_rab::allreduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
956 {
957   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op,   -1, comm, 1) );
958 }
959 }
960 }