Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'smpi_cpp'
[simgrid.git] / src / smpi / colls / reduce-rab.cpp
1 /* extracted from mpig_myreduce.c with
2    :3,$s/MPL/MPI/g     and  :%s/\\\\$/ \\/   */
3
4 /* Copyright: Rolf Rabenseifner, 1997
5  *            Computing Center University of Stuttgart
6  *            rabenseifner@rus.uni-stuttgart.de
7  *
8  * The usage of this software is free,
9  * but this header must not be removed.
10  */
11
12 #include "colls_private.h"
13 #include <stdio.h>
14 #include <stdlib.h>
15
16 #define REDUCE_NEW_ALWAYS 1
17
18 #ifdef CRAY
19 #      define SCR_LNG_OPTIM(bytelng)  128 + ((bytelng+127)/256) * 256;
20                                  /* =  16 + multiple of 32 doubles*/
21
22 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
23        /*  routine =    reduce                allreduce             */ \
24        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
25  static int Lsh[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
26  static int Lin[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
27  static int Llg[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
28  static int Lfp[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
29  static int Ldb[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};   \
30  static int Lby[2][4]={{ 896,1728, 576, 736},{ 448,1280, 512, 512}};
31 #endif
32
33 #ifdef REDUCE_NEW_ALWAYS
34 # undef  REDUCE_LIMITS
35 # define REDUCE_LIMITS  /*  values are lower limits for count arg.  */ \
36        /*  routine =    reduce                allreduce             */ \
37        /*  size    =       2,   3,2**n,other     2,   3,2**n,other  */ \
38  static int Lsh[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
39  static int Lin[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
40  static int Llg[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
41  static int Lfp[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
42  static int Ldb[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};   \
43  static int Lby[2][4]={{   1,   1,   1,   1},{   1,   1,   1,   1}};
44 #endif
45
46 /* Fast reduce and allreduce algorithm for longer buffers and predefined
47    operations.
48
49    This algorithm is explaned with the example of 13 nodes.
50    The nodes are numbered   0, 1, 2, ... 12.
51    The sendbuf content is   a, b, c, ...  m.
52    The buffer array is notated with ABCDEFGH, this means that
53    e.g. 'C' is the third 1/8 of the buffer.
54
55    The algorithm computes
56
57    { [(a+b)+(c+d)] + [(e+f)+(g+h)] }  +  { [(i+j)+k] + [l+m] }
58
59    This is equivalent to the mostly used binary tree algorithm
60    (e.g. used in mpich).
61
62    size := number of nodes in the communicator.
63    2**n := the power of 2 that is next smaller or equal to the size.
64    r    := size - 2**n
65
66    Exa.: size=13 ==> n=3, r=5  (i.e. size == 13 == 2**n+r ==  2**3 + 5)
67
68    The algoritm needs for the execution of one mpi_coll_reduce_fun
69
70    - for r==0
71      exec_time = n*(L1+L2)     + buf_lng * (1-1/2**n) * (T1 + T2 + O/d)
72
73    - for r>0
74      exec_time = (n+1)*(L1+L2) + buf_lng               * T1
75                                + buf_lng*(1+1/2-1/2**n)*     (T2 + O/d)
76
77    with L1 = latency of a message transfer e.g. by send+recv
78         L2 = latency of a message exchange e.g. by sendrecv
79         T1 = additional time to transfer 1 byte (= 1/bandwidth)
80         T2 = additional time to exchange 1 byte in both directions.
81         O  = time for one operation
82         d  = size of the datatype
83
84         On a MPP system it is expected that T1==T2.
85
86    In Comparison with the binary tree algorithm that needs:
87
88    - for r==0
89      exec_time_bin_tree =  n * L1   +  buf_lng * n * (T1 + O/d)
90
91    - for r>0
92      exec_time_bin_tree = (n+1)*L1  +  buf_lng*(n+1)*(T1 + O/d)
93
94    the new algorithm is faster if (assuming T1=T2)
95
96     for n>2:
97      for r==0:  buf_lng  >  L2 *   n   / [ (n-2) * T1 + (n-1) * O/d ]
98      for r>0:   buf_lng  >  L2 * (n+1) / [ (n-1.5)*T1 + (n-0.5)*O/d ]
99
100     for size = 5, 6, and 7:
101                 buf_lng  >  L2 / [0.25 * T1  +  0.58 * O/d ]
102
103     for size = 4:
104                 buf_lng  >  L2 / [0.25 * T1  +  0.62 * O/d ]
105     for size = 2 and 3:
106                 buf_lng  >  L2 / [  0.5 * O/d ]
107
108    and for O/d >> T1 we can summarize:
109
110     The new algorithm is faster for about   buf_lng > 2 * L2 * d / O
111
112    Example L1 = L2 = 50 us,
113            bandwidth=300MB/s, i.e. T1 = T2 = 3.3 ns/byte
114            and 10 MFLOP/s,    i.e. O  = 100 ns/FLOP
115            for double,        i.e. d  =   8 byte/FLOP
116
117            ==> the new algorithm is faster for about buf_lng>8000 bytes
118                i.e count > 1000 doubles !
119 Step 1)
120
121    compute n and r
122
123 Step 2)
124
125    if  myrank < 2*r
126
127      split the buffer into ABCD and EFGH
128
129      even myrank: send    buffer EFGH to   myrank+1 (buf_lng/2 exchange)
130                   receive buffer ABCD from myrank+1
131                   compute op for ABCD                 (count/2 ops)
132                   receive result EFGH               (buf_lng/2 transfer)
133      odd  myrank: send    buffer ABCD to   myrank+1
134                   receive buffer EFGH from myrank+1
135                   compute op for EFGH
136                   send    result EFGH
137
138    Result: node:     0    2    4    6    8  10  11  12
139            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
140
141 Step 3)
142
143    The following algorithm uses only the nodes with
144
145    (myrank is even  &&  myrank < 2*r) || (myrank >= 2*r)
146
147 Step 4)
148
149    define NEWRANK(old) := (old < 2*r ? old/2 : old-r)
150    define OLDRANK(new) := (new < r   ? new*2 : new+r)
151
152    Result:
153        old ranks:    0    2    4    6    8  10  11  12
154        new ranks:    0    1    2    3    4   5   6   7
155            value:  a+b  c+d  e+f  g+h  i+j   k   l   m
156
157 Step 5.1)
158
159    Split the buffer (ABCDEFGH) in the middle,
160    the lower half (ABCD) is computed on even (new) ranks,
161    the opper half (EFGH) is computed on odd  (new) ranks.
162
163    exchange: ABCD from 1 to 0, from 3 to 2, from 5 to 4 and from 7 to 6
164              EFGH from 0 to 1, from 2 to 3, from 4 to 5 and from 6 to 7
165                                                (i.e. buf_lng/2 exchange)
166    compute op in each node on its half:        (i.e.   count/2 ops)
167
168    Result: node 0: (a+b)+(c+d)   for  ABCD
169            node 1: (a+b)+(c+d)   for  EFGH
170            node 2: (e+f)+(g+h)   for  ABCD
171            node 3: (e+f)+(g+h)   for  EFGH
172            node 4: (i+j)+ k      for  ABCD
173            node 5: (i+j)+ k      for  EFGH
174            node 6:    l + m      for  ABCD
175            node 7:    l + m      for  EFGH
176
177 Step 5.2)
178
179    Same with double distance and oncemore the half of the buffer.
180                                                (i.e. buf_lng/4 exchange)
181                                                (i.e.   count/4 ops)
182
183    Result: node 0: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  AB
184            node 1: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  EF
185            node 2: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  CD
186            node 3: [(a+b)+(c+d)] + [(e+f)+(g+h)]  for  GH
187            node 4: [(i+j)+ k   ] + [   l + m   ]  for  AB
188            node 5: [(i+j)+ k   ] + [   l + m   ]  for  EF
189            node 6: [(i+j)+ k   ] + [   l + m   ]  for  CD
190            node 7: [(i+j)+ k   ] + [   l + m   ]  for  GH
191
192 ...
193 Step 5.n)
194
195    Same with double distance and oncemore the half of the buffer.
196                                             (i.e. buf_lng/2**n exchange)
197                                             (i.e.   count/2**n ops)
198
199    Result:
200      0: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for A
201      1: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for E
202      2: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for C
203      3: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for G
204      4: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for B
205      5: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for F
206      6: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for D
207      7: { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] } for H
208
209
210 For mpi_coll_allreduce_fun:
211 ------------------
212
213 Step 6.1)
214
215    Exchange on the last distance (2*n) the last result
216                                             (i.e. buf_lng/2**n exchange)
217
218 Step 6.2)
219
220    Same with half the distance and double of the result
221                                         (i.e. buf_lng/2**(n-1) exchange)
222
223 ...
224 Step 6.n)
225
226    Same with distance 1 and double of the result (== half of the
227    original buffer)                            (i.e. buf_lng/2 exchange)
228
229    Result after 6.1     6.2         6.n
230
231     on node 0:   AB    ABCD    ABCDEFGH
232     on node 1:   EF    EFGH    ABCDEFGH
233     on node 2:   CD    ABCD    ABCDEFGH
234     on node 3:   GH    EFGH    ABCDEFGH
235     on node 4:   AB    ABCD    ABCDEFGH
236     on node 5:   EF    EFGH    ABCDEFGH
237     on node 6:   CD    ABCD    ABCDEFGH
238     on node 7:   GH    EFGH    ABCDEFGH
239
240 Step 7)
241
242    If r > 0
243      transfer the result from the even nodes with old rank < 2*r
244                            to the odd  nodes with old rank < 2*r
245                                                  (i.e. buf_lng transfer)
246    Result:
247      { [(a+b)+(c+d)] + [(e+f)+(g+h)] } + { [(i+j)+k] + [l+m] }
248      for ABCDEFGH
249      on all nodes 0..12
250
251
252 For mpi_coll_reduce_fun:
253 ---------------
254
255 Step 6.0)
256
257    If root node not in the list of the nodes with newranks
258    (see steps 3+4) then
259
260      send last result from node 0 to the root    (buf_lng/2**n transfer)
261      and replace the role of node 0 by the root node.
262
263 Step 6.1)
264
265    Send on the last distance (2**(n-1)) the last result
266    from node with bit '2**(n-1)' in the 'new rank' unequal to that of
267    root's new rank  to the node with same '2**(n-1)' bit.
268                                             (i.e. buf_lng/2**n transfer)
269
270 Step 6.2)
271
272    Same with half the distance and double of the result
273    and bit '2**(n-2)'
274                                         (i.e. buf_lng/2**(n-1) transfer)
275
276 ...
277 Step 6.n)
278
279    Same with distance 1 and double of the result (== half of the
280    original buffer) and bit '2**0'             (i.e. buf_lng/2 transfer)
281
282    Example: roots old rank: 10
283             roots new rank:  5
284
285    Results:          6.1               6.2                 6.n
286                 action result     action result       action result
287
288     on node 0:  send A
289     on node 1:  send E
290     on node 2:  send C
291     on node 3:  send G
292     on node 4:  recv A => AB     recv CD => ABCD   send ABCD
293     on node 5:  recv E => EF     recv GH => EFGH   recv ABCD => ABCDEFGH
294     on node 6:  recv C => CD     send CD
295     on node 7:  recv G => GH     send GH
296
297 Benchmark results on CRAY T3E
298 -----------------------------
299
300    uname -a: (sn6307 hwwt3e 1.6.1.51 unicosmk CRAY T3E)
301    MPI:      /opt/ctl/mpt/1.1.0.3
302    datatype: MPI_DOUBLE
303    Ldb[][] = {{ 896,1728, 576, 736},{ 448,1280, 512, 512}}
304    env: export MPI_BUFFER_MAX=4099
305    compiled with: cc -c -O3 -h restrict=f
306
307    old = binary tree protocol of the vendor
308    new = the new protocol and its implementation
309    mixed = 'new' is used if count > limit(datatype, communicator size)
310
311    REDUCE:
312                                           communicator size
313    measurement count prot. unit    2    3    4    6    8   12   16   24
314    --------------------------------------------------------------------
315    latency         1 mixed  us  20.7 35.1 35.6 49.1 49.2 61.8 62.4 74.8
316                        old  us  19.0 32.9 33.8 47.1 47.2 61.7 62.1 73.2
317    --------------------------------------------------------------------
318    bandwidth    128k mixed MB/s 75.4 34.9 49.1 27.3 41.1 24.6 38.0 23.7
319    (=buf_lng/time)     old MB/s 28.8 16.2 16.3 11.6 11.6  8.8  8.8  7.2
320    ration = mixed/old            2.6  2.1  3.0  2.4  3.5  2.8  4.3  3.3
321    --------------------------------------------------------------------
322    limit                doubles  896 1536  576  736  576  736  576  736
323    bandwidth   limit mixed MB/s 35.9 20.5 18.6 12.3 13.5  9.2 10.0  8.6
324                        old MB/s 35.9 20.3 17.8 13.0 12.5  9.6  9.2  7.8
325    ratio = mixed/old            1.00 1.01 1.04 0.95 1.08 0.96 1.09 1.10
326
327                                           communicator size
328    measurement count prot. unit   32   48   64   96  128  192  256
329    ---------------------------------------------------------------
330    latency         1 mixed  us  77.8 88.5 90.6  102                 1)
331                        old  us  78.6 87.2 90.1 99.7  108  119  120
332    ---------------------------------------------------------------
333    bandwidth    128k mixed MB/s 35.1 23.3 34.1 22.8 34.4 22.4 33.9
334    (=buf_lng/time)     old MB/s  6.0  6.0  6.0  5.2  5.2  4.6  4.6
335    ration = mixed/old            5.8  3.9  5.7  4.4  6.6  4.8  7.4  5)
336    ---------------------------------------------------------------
337    limit                doubles  576  736  576  736  576  736  576  2)
338    bandwidth   limit mixed MB/s  9.7  7.5  8.4  6.5  6.9  5.5  5.1  3)
339                        old MB/s  7.7  6.4  6.4  5.7  5.5  4.9  4.7  3)
340    ratio = mixed/old            1.26 1.17 1.31 1.14 1.26 1.12 1.08  4)
341
342    ALLREDUCE:
343                                           communicator size
344    measurement count prot. unit    2    3    4    6    8   12   16   24
345    --------------------------------------------------------------------
346    latency         1 mixed  us  28.2 51.0 44.5 74.4 59.9  102 74.2  133
347                        old  us  26.9 48.3 42.4 69.8 57.6 96.0 75.7  126
348    --------------------------------------------------------------------
349    bandwidth    128k mixed MB/s 74.0 29.4 42.4 23.3 35.0 20.9 32.8 19.7
350    (=buf_lng/time)     old MB/s 20.9 14.4 13.2  9.7  9.6  7.3  7.4  5.8
351    ration = mixed/old            3.5  2.0  3.2  2.4  3.6  2.9  4.4  3.4
352    --------------------------------------------------------------------
353    limit                doubles  448 1280  512  512  512  512  512  512
354    bandwidth   limit mixed MB/s 26.4 15.1 16.2  8.2 12.4  7.2 10.8  5.7
355                        old MB/s 26.1 14.9 15.0  9.1 10.5  6.7  7.3  5.3
356    ratio = mixed/old            1.01 1.01 1.08 0.90 1.18 1.07 1.48 1.08
357
358                                           communicator size
359    measurement count prot. unit   32   48   64   96  128  192  256
360    ---------------------------------------------------------------
361    latency         1 mixed  us  90.3  162  109  190
362                        old  us  92.7  152  104  179  122  225  135
363    ---------------------------------------------------------------
364    bandwidth    128k mixed MB/s 31.1 19.7 30.1 19.2 29.8 18.7 29.0
365    (=buf_lng/time)     old MB/s  5.9  4.8  5.0  4.1  4.4  3.4  3.8
366    ration = mixed/old            5.3  4.1  6.0  4.7  6.8  5.5  7.7
367    ---------------------------------------------------------------
368    limit                doubles  512  512  512  512  512  512  512
369    bandwidth   limit mixed MB/s  6.6  5.6  5.7  3.5  4.4  3.2  3.8
370                        old MB/s  6.3  4.2  5.4  3.6  4.4  3.1
371    ratio = mixed/old            1.05 1.33 1.06 0.97 1.00 1.03
372
373    Footnotes:
374    1) This line shows that the overhead to decide which protocol
375       should be used can be ignored.
376    2) This line shows the limit for the count argument.
377       If count < limit then the vendor protocol is used,
378       otherwise the new protocol is used (see variable Ldb).
379    3) These lines show the bandwidth (=bufer length / execution time)
380       for both protocols.
381    4) This line shows that the limit is choosen well if the ratio is
382       between 0.95 (loosing 5% for buffer length near and >=limit)
383       and 1.10 (not gaining 10% for buffer length near and <limit).
384    5) This line shows that the new protocol is 2..7 times faster
385       for long counts.
386
387 */
388
389 #ifdef REDUCE_LIMITS
390
391 #ifdef USE_Irecv
392 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
393            { MPI_Request req;                                          \
394              req=smpi_mpi_irecv(rb,rc,rd,source,rt,comm);                  \
395              smpi_mpi_send(sb,sc,sd,dest,st,comm);                          \
396              smpi_mpi_wait(&req,stat);                                      \
397            }
398 #else
399 #ifdef USE_Isend
400 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
401            { MPI_Request req;                                          \
402              req=mpi_mpi_isend(sb,sc,sd,dest,st,comm);                    \
403              smpi_mpi_recv(rb,rc,rd,source,rt,comm,stat);                   \
404              smpi_mpi_wait(&req,stat);                                      \
405            }
406 #else
407 #define  MPI_I_Sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat) \
408            smpi_mpi_sendrecv(sb,sc,sd,dest,st,rb,rc,rd,source,rt,comm,stat)
409 #endif
410 #endif
411
412 typedef enum {MPIM_SHORT, MPIM_INT, MPIM_LONG, MPIM_UNSIGNED_SHORT,
413               MPIM_UNSIGNED, MPIM_UNSIGNED_LONG, MPIM_UNSIGNED_LONG_LONG, MPIM_FLOAT,
414               MPIM_DOUBLE, MPIM_BYTE} MPIM_Datatype;
415
416 typedef enum {MPIM_MAX, MPIM_MIN, MPIM_SUM, MPIM_PROD,
417               MPIM_LAND, MPIM_BAND, MPIM_LOR, MPIM_BOR,
418               MPIM_LXOR, MPIM_BXOR} MPIM_Op;
419 #define MPI_I_DO_OP_C_INTEGER(MPI_I_do_op_TYPE,TYPE)                   \
420 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
421 { int i;                                                               \
422   switch (op) {                                                        \
423    case MPIM_MAX :                                                     \
424       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
425    case MPIM_MIN :                                                     \
426       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
427    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
428    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
429    case MPIM_LAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] && b2[i]; break;  \
430    case MPIM_LOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] || b2[i]; break;  \
431    case MPIM_LXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] != b2[i]; break;  \
432    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
433    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
434    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
435    default: break;                                                     \
436   }                                                                    \
437 }
438
439 #define MPI_I_DO_OP_FP(MPI_I_do_op_TYPE,TYPE)                          \
440 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op) \
441 { int i;                                                               \
442   switch (op) {                                                        \
443    case MPIM_MAX :                                                     \
444       for(i=0;i<cnt;i++) rslt[i] = (b1[i]>b2[i]?b1[i]:b2[i]); break;   \
445    case MPIM_MIN :                                                     \
446       for(i=0;i<cnt;i++) rslt[i] = (b1[i]<b2[i]?b1[i]:b2[i]); break;   \
447    case MPIM_SUM :for(i=0;i<cnt;i++) rslt[i] = b1[i] +  b2[i]; break;  \
448    case MPIM_PROD:for(i=0;i<cnt;i++) rslt[i] = b1[i] *  b2[i]; break;  \
449    default: break;                                                     \
450   }                                                                    \
451 }
452
453 #define MPI_I_DO_OP_BYTE(MPI_I_do_op_TYPE,TYPE)                        \
454 static void MPI_I_do_op_TYPE(TYPE* b1,TYPE* b2,TYPE* rslt, int cnt,MPIM_Op op)\
455 { int i;                                                               \
456   switch (op) {                                                        \
457    case MPIM_BAND:for(i=0;i<cnt;i++) rslt[i] = b1[i] &  b2[i]; break;  \
458    case MPIM_BOR :for(i=0;i<cnt;i++) rslt[i] = b1[i] |  b2[i]; break;  \
459    case MPIM_BXOR:for(i=0;i<cnt;i++) rslt[i] = b1[i] ^  b2[i]; break;  \
460    default: break;                                                     \
461   }                                                                    \
462 }
463
464 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_short,  short)
465 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_int,    int)
466 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_long,   long)
467 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ushort, unsigned short)
468 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_uint,   unsigned int)
469 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulong,  unsigned long)
470 MPI_I_DO_OP_C_INTEGER( MPI_I_do_op_ulonglong,  unsigned long long)
471 MPI_I_DO_OP_FP(        MPI_I_do_op_float,  float)
472 MPI_I_DO_OP_FP(        MPI_I_do_op_double, double)
473 MPI_I_DO_OP_BYTE(      MPI_I_do_op_byte,   char)
474
475 #define MPI_I_DO_OP_CALL(MPI_I_do_op_TYPE,TYPE)                        \
476    MPI_I_do_op_TYPE ((TYPE*)b1, (TYPE*)b2, (TYPE*)rslt, cnt, op); break;
477
478 static void MPI_I_do_op(void* b1, void* b2, void* rslt, int cnt,
479                  MPIM_Datatype datatype, MPIM_Op op)
480 {
481  switch (datatype) {
482   case MPIM_SHORT : MPI_I_DO_OP_CALL(MPI_I_do_op_short,  short)
483   case MPIM_INT   : MPI_I_DO_OP_CALL(MPI_I_do_op_int,    int)
484   case MPIM_LONG  : MPI_I_DO_OP_CALL(MPI_I_do_op_long,   long)
485   case MPIM_UNSIGNED_SHORT:
486                    MPI_I_DO_OP_CALL(MPI_I_do_op_ushort, unsigned short)
487   case MPIM_UNSIGNED:
488                    MPI_I_DO_OP_CALL(MPI_I_do_op_uint,   unsigned int)
489   case MPIM_UNSIGNED_LONG:
490                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulong,  unsigned long)
491   case MPIM_UNSIGNED_LONG_LONG:
492                    MPI_I_DO_OP_CALL(MPI_I_do_op_ulonglong,  unsigned long long)
493   case MPIM_FLOAT : MPI_I_DO_OP_CALL(MPI_I_do_op_float,  float)
494   case MPIM_DOUBLE: MPI_I_DO_OP_CALL(MPI_I_do_op_double, double)
495   case MPIM_BYTE  : MPI_I_DO_OP_CALL(MPI_I_do_op_byte,   char)
496  }
497 }
498
499 REDUCE_LIMITS
500
501 static int MPI_I_anyReduce(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype mpi_datatype, MPI_Op mpi_op, int root, MPI_Comm comm, int is_all)
502 {
503   char *scr1buf, *scr2buf, *scr3buf, *xxx, *sendbuf, *recvbuf;
504   int myrank, size, x_base, x_size, computed, idx;
505   int x_start, x_count = 0, r, n, mynewrank, newroot, partner;
506   int start_even[20], start_odd[20], count_even[20], count_odd[20];
507   MPI_Aint typelng;
508   MPI_Status status;
509   size_t scrlng;
510   int new_prot;
511   MPIM_Datatype datatype = MPIM_INT; MPIM_Op op = MPIM_MAX;
512
513   if     (mpi_datatype==MPI_SHORT         ) datatype=MPIM_SHORT;
514   else if(mpi_datatype==MPI_INT           ) datatype=MPIM_INT;
515   else if(mpi_datatype==MPI_LONG          ) datatype=MPIM_LONG;
516   else if(mpi_datatype==MPI_UNSIGNED_SHORT) datatype=MPIM_UNSIGNED_SHORT;
517   else if(mpi_datatype==MPI_UNSIGNED      ) datatype=MPIM_UNSIGNED;
518   else if(mpi_datatype==MPI_UNSIGNED_LONG ) datatype=MPIM_UNSIGNED_LONG;
519   else if(mpi_datatype==MPI_UNSIGNED_LONG_LONG ) datatype=MPIM_UNSIGNED_LONG_LONG;
520   else if(mpi_datatype==MPI_FLOAT         ) datatype=MPIM_FLOAT;
521   else if(mpi_datatype==MPI_DOUBLE        ) datatype=MPIM_DOUBLE;
522   else if(mpi_datatype==MPI_BYTE          ) datatype=MPIM_BYTE;
523   else
524    THROWF(arg_error,0, "reduce rab algorithm can't be used with this datatype ! ");
525
526   if     (mpi_op==MPI_MAX     ) op=MPIM_MAX;
527   else if(mpi_op==MPI_MIN     ) op=MPIM_MIN;
528   else if(mpi_op==MPI_SUM     ) op=MPIM_SUM;
529   else if(mpi_op==MPI_PROD    ) op=MPIM_PROD;
530   else if(mpi_op==MPI_LAND    ) op=MPIM_LAND;
531   else if(mpi_op==MPI_BAND    ) op=MPIM_BAND;
532   else if(mpi_op==MPI_LOR     ) op=MPIM_LOR;
533   else if(mpi_op==MPI_BOR     ) op=MPIM_BOR;
534   else if(mpi_op==MPI_LXOR    ) op=MPIM_LXOR;
535   else if(mpi_op==MPI_BXOR    ) op=MPIM_BXOR;
536
537   new_prot = 0;
538   MPI_Comm_size(comm, &size);
539   if (size > 1) /*otherwise no balancing_protocol*/
540   { int ss;
541     if      (size==2) ss=0;
542     else if (size==3) ss=1;
543     else { int s = size; while (!(s & 1)) s = s >> 1;
544            if (s==1) /* size == power of 2 */ ss = 2;
545            else      /* size != power of 2 */ ss = 3; }
546     switch(op) {
547      case MPIM_MAX:   case MPIM_MIN: case MPIM_SUM:  case MPIM_PROD:
548      case MPIM_LAND:  case MPIM_LOR: case MPIM_LXOR:
549      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
550       switch(datatype) {
551         case MPIM_SHORT:  case MPIM_UNSIGNED_SHORT:
552          new_prot = count >= Lsh[is_all][ss]; break;
553         case MPIM_INT:    case MPIM_UNSIGNED:
554          new_prot = count >= Lin[is_all][ss]; break;
555         case MPIM_LONG:   case MPIM_UNSIGNED_LONG: case MPIM_UNSIGNED_LONG_LONG:
556          new_prot = count >= Llg[is_all][ss]; break;
557      default:
558         break;
559     } default:
560         break;}
561     switch(op) {
562      case MPIM_MAX:  case MPIM_MIN: case MPIM_SUM: case MPIM_PROD:
563       switch(datatype) {
564         case MPIM_FLOAT:
565          new_prot = count >= Lfp[is_all][ss]; break;
566         case MPIM_DOUBLE:
567          new_prot = count >= Ldb[is_all][ss]; break;
568               default:
569         break;
570     } default:
571         break;}
572     switch(op) {
573      case MPIM_BAND:  case MPIM_BOR: case MPIM_BXOR:
574       switch(datatype) {
575         case MPIM_BYTE:
576          new_prot = count >= Lby[is_all][ss]; break;
577               default:
578         break;
579     } default:
580         break;}
581 #   ifdef DEBUG
582     { char *ss_str[]={"two","three","power of 2","no power of 2"};
583       printf("MPI_(All)Reduce: is_all=%1d, size=%1d=%s, new_prot=%1d\n",
584              is_all, size, ss_str[ss], new_prot); fflush(stdout);
585     }
586 #   endif
587   }
588
589   if (new_prot)
590   {
591     sendbuf = (char*) Sendbuf;
592     recvbuf = (char*) Recvbuf;
593     MPI_Comm_rank(comm, &myrank);
594     MPI_Type_extent(mpi_datatype, &typelng);
595     scrlng  = typelng * count;
596 #ifdef NO_CACHE_OPTIMIZATION
597     scr1buf = static_cast<char*>(xbt_malloc(scrlng));
598     scr2buf = static_cast<char*>(xbt_malloc(scrlng));
599     scr3buf = static_cast<char*>(xbt_malloc(scrlng));
600 #else
601 #  ifdef SCR_LNG_OPTIM
602     scrlng = SCR_LNG_OPTIM(scrlng);
603 #  endif
604     scr2buf = static_cast<char*>(xbt_malloc(3*scrlng));   /* To test cache problems.     */
605     scr1buf = scr2buf + 1*scrlng; /* scr1buf and scr3buf must not*/
606     scr3buf = scr2buf + 2*scrlng; /* be used for malloc because  */
607                                   /* they are interchanged below.*/
608 #endif
609     computed = 0;
610     if (is_all) root = myrank; /* for correct recvbuf handling */
611
612   /*...step 1 */
613
614 #   ifdef DEBUG
615      printf("[%2d] step 1 begin\n",myrank); fflush(stdout);
616 #   endif
617     n = 0; x_size = 1;
618     while (2*x_size <= size) { n++; x_size = x_size * 2; }
619     /* x_sixe == 2**n */
620     r = size - x_size;
621
622   /*...step 2 */
623
624 #   ifdef DEBUG
625     printf("[%2d] step 2 begin n=%d, r=%d\n",myrank,n,r);fflush(stdout);
626 #   endif
627     if (myrank < 2*r)
628     {
629       if ((myrank % 2) == 0 /*even*/)
630       {
631         MPI_I_Sendrecv(sendbuf + (count/2)*typelng,
632                        count - count/2, mpi_datatype, myrank+1, 1220,
633                        scr2buf, count/2,mpi_datatype, myrank+1, 1221,
634                        comm, &status);
635         MPI_I_do_op(sendbuf, scr2buf, scr1buf,
636                     count/2, datatype, op);
637         smpi_mpi_recv(scr1buf + (count/2)*typelng, count - count/2,
638                  mpi_datatype, myrank+1, 1223, comm, &status);
639         computed = 1;
640 #       ifdef DEBUG
641         { int i; printf("[%2d] after step 2: val=",
642                         myrank);
643           for (i=0; i<count; i++)
644            printf(" %5.0lf",((double*)scr1buf)[i] );
645           printf("\n"); fflush(stdout);
646         }
647 #       endif
648       }
649       else /*odd*/
650       {
651         MPI_I_Sendrecv(sendbuf, count/2,mpi_datatype, myrank-1, 1221,
652                        scr2buf + (count/2)*typelng,
653                        count - count/2, mpi_datatype, myrank-1, 1220,
654                        comm, &status);
655         MPI_I_do_op(scr2buf + (count/2)*typelng,
656                     sendbuf + (count/2)*typelng,
657                     scr1buf + (count/2)*typelng,
658                     count - count/2, datatype, op);
659         smpi_mpi_send(scr1buf + (count/2)*typelng, count - count/2,
660                  mpi_datatype, myrank-1, 1223, comm);
661       }
662     }
663
664   /*...step 3+4 */
665
666 #   ifdef DEBUG
667      printf("[%2d] step 3+4 begin\n",myrank); fflush(stdout);
668 #   endif
669     if ((myrank >= 2*r) || ((myrank%2 == 0)  &&  (myrank < 2*r)))
670          mynewrank = (myrank < 2*r ? myrank/2 : myrank-r);
671     else mynewrank = -1;
672
673     if (mynewrank >= 0)
674     { /* begin -- only for nodes with new rank */
675
676 #     define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
677
678   /*...step 5 */
679
680       x_start = 0;
681       x_count = count;
682       for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
683       {
684         start_even[idx] = x_start;
685         count_even[idx] = x_count / 2;
686         start_odd [idx] = x_start + count_even[idx];
687         count_odd [idx] = x_count - count_even[idx];
688         if (((mynewrank/x_base) % 2) == 0 /*even*/)
689         {
690 #         ifdef DEBUG
691             printf("[%2d](%2d) step 5.%d begin even c=%1d\n",
692                    myrank,mynewrank,idx+1,computed); fflush(stdout);
693 #         endif
694           x_start = start_even[idx];
695           x_count = count_even[idx];
696           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
697                          + start_odd[idx]*typelng, count_odd[idx],
698                          mpi_datatype, OLDRANK(mynewrank+x_base), 1231,
699                          scr2buf + x_start*typelng, x_count,
700                          mpi_datatype, OLDRANK(mynewrank+x_base), 1232,
701                          comm, &status);
702           MPI_I_do_op((computed?scr1buf:sendbuf) + x_start*typelng,
703                       scr2buf                    + x_start*typelng,
704                       ((root==myrank) && (idx==(n-1))
705                         ? recvbuf + x_start*typelng
706                         : scr3buf + x_start*typelng),
707                       x_count, datatype, op);
708         }
709         else /*odd*/
710         {
711 #         ifdef DEBUG
712             printf("[%2d](%2d) step 5.%d begin  odd c=%1d\n",
713                    myrank,mynewrank,idx+1,computed); fflush(stdout);
714 #         endif
715           x_start = start_odd[idx];
716           x_count = count_odd[idx];
717           MPI_I_Sendrecv((computed ? scr1buf : sendbuf)
718                          +start_even[idx]*typelng, count_even[idx],
719                          mpi_datatype, OLDRANK(mynewrank-x_base), 1232,
720                          scr2buf + x_start*typelng, x_count,
721                          mpi_datatype, OLDRANK(mynewrank-x_base), 1231,
722                          comm, &status);
723           MPI_I_do_op(scr2buf                    + x_start*typelng,
724                       (computed?scr1buf:sendbuf) + x_start*typelng,
725                       ((root==myrank) && (idx==(n-1))
726                         ? recvbuf + x_start*typelng
727                         : scr3buf + x_start*typelng),
728                       x_count, datatype, op);
729         }
730         xxx = scr3buf; scr3buf = scr1buf; scr1buf = xxx;
731         computed = 1;
732 #       ifdef DEBUG
733         { int i; printf("[%2d](%2d) after step 5.%d   end: start=%2d  count=%2d  val=",
734                         myrank,mynewrank,idx+1,x_start,x_count);
735           for (i=0; i<x_count; i++)
736            printf(" %5.0lf",((double*)((root==myrank)&&(idx==(n-1))
737                              ? recvbuf + x_start*typelng
738                              : scr1buf + x_start*typelng))[i] );
739           printf("\n"); fflush(stdout);
740         }
741 #       endif
742       } /*for*/
743
744 #     undef OLDRANK
745
746
747     } /* end -- only for nodes with new rank */
748
749     if (is_all)
750     {
751       /*...steps 6.1 to 6.n */
752
753       if (mynewrank >= 0)
754       { /* begin -- only for nodes with new rank */
755
756 #       define OLDRANK(new)   ((new) < r ? (new)*2 : (new)+r)
757
758         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
759         {
760 #         ifdef DEBUG
761             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
762 #         endif
763           if (((mynewrank/x_base) % 2) == 0 /*even*/)
764           {
765             MPI_I_Sendrecv(recvbuf + start_even[idx]*typelng,
766                                      count_even[idx],
767                            mpi_datatype, OLDRANK(mynewrank+x_base),1241,
768                            recvbuf + start_odd[idx]*typelng,
769                                      count_odd[idx],
770                            mpi_datatype, OLDRANK(mynewrank+x_base),1242,
771                            comm, &status);
772 #           ifdef DEBUG
773               x_start = start_odd[idx];
774               x_count = count_odd[idx];
775 #           endif
776           }
777           else /*odd*/
778           {
779             MPI_I_Sendrecv(recvbuf + start_odd[idx]*typelng,
780                                      count_odd[idx],
781                            mpi_datatype, OLDRANK(mynewrank-x_base),1242,
782                            recvbuf + start_even[idx]*typelng,
783                                      count_even[idx],
784                            mpi_datatype, OLDRANK(mynewrank-x_base),1241,
785                            comm, &status);
786 #           ifdef DEBUG
787               x_start = start_even[idx];
788               x_count = count_even[idx];
789 #           endif
790           }
791 #         ifdef DEBUG
792           { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
793                           myrank,mynewrank,n-idx,x_start,x_count);
794             for (i=0; i<x_count; i++)
795              printf(" %5.0lf",((double*)(root==myrank
796                                ? recvbuf + x_start*typelng
797                                : scr1buf + x_start*typelng))[i] );
798             printf("\n"); fflush(stdout);
799           }
800 #         endif
801         } /*for*/
802
803 #       undef OLDRANK
804
805       } /* end -- only for nodes with new rank */
806
807       /*...step 7 */
808
809       if (myrank < 2*r)
810       {
811 #       ifdef DEBUG
812           printf("[%2d] step 7 begin\n",myrank); fflush(stdout);
813 #       endif
814         if (myrank%2 == 0 /*even*/)
815           smpi_mpi_send(recvbuf, count, mpi_datatype, myrank+1, 1253, comm);
816         else /*odd*/
817           smpi_mpi_recv(recvbuf, count, mpi_datatype, myrank-1, 1253, comm, &status);
818       }
819
820     }
821     else /* not is_all, i.e. Reduce */
822     {
823
824     /*...step 6.0 */
825
826       if ((root < 2*r) && (root%2 == 1))
827       {
828 #       ifdef DEBUG
829           printf("[%2d] step 6.0 begin\n",myrank); fflush(stdout);
830 #       endif
831         if (myrank == 0) /* then mynewrank==0, x_start==0
832                                  x_count == count/x_size  */
833         {
834           smpi_mpi_send(scr1buf,x_count,mpi_datatype,root,1241,comm);
835           mynewrank = -1;
836         }
837
838         if (myrank == root)
839         {
840           mynewrank = 0;
841           x_start = 0;
842           x_count = count;
843           for (idx=0, x_base=1; idx<n; idx++, x_base=x_base*2)
844           {
845             start_even[idx] = x_start;
846             count_even[idx] = x_count / 2;
847             start_odd [idx] = x_start + count_even[idx];
848             count_odd [idx] = x_count - count_even[idx];
849             /* code for always even in each bit of mynewrank: */
850             x_start = start_even[idx];
851             x_count = count_even[idx];
852           }
853           smpi_mpi_recv(recvbuf,x_count,mpi_datatype,0,1241,comm,&status);
854         }
855         newroot = 0;
856       }
857       else
858       {
859         newroot = (root < 2*r ? root/2 : root-r);
860       }
861
862     /*...steps 6.1 to 6.n */
863
864       if (mynewrank >= 0)
865       { /* begin -- only for nodes with new rank */
866
867 #       define OLDRANK(new) ((new)==newroot ? root                     \
868                              : ((new)<r ? (new)*2 : (new)+r) )
869
870         for(idx=n-1, x_base=x_size/2; idx>=0; idx--, x_base=x_base/2)
871         {
872 #         ifdef DEBUG
873             printf("[%2d](%2d) step 6.%d begin\n",myrank,mynewrank,n-idx); fflush(stdout);
874 #         endif
875           if ((mynewrank & x_base) != (newroot & x_base))
876           {
877             if (((mynewrank/x_base) % 2) == 0 /*even*/)
878             { x_start = start_even[idx]; x_count = count_even[idx];
879               partner = mynewrank+x_base; }
880             else
881             { x_start = start_odd[idx]; x_count = count_odd[idx];
882               partner = mynewrank-x_base; }
883             smpi_mpi_send(scr1buf + x_start*typelng, x_count, mpi_datatype,
884                      OLDRANK(partner), 1244, comm);
885           }
886           else /*odd*/
887           {
888             if (((mynewrank/x_base) % 2) == 0 /*even*/)
889             { x_start = start_odd[idx]; x_count = count_odd[idx];
890               partner = mynewrank+x_base; }
891             else
892             { x_start = start_even[idx]; x_count = count_even[idx];
893               partner = mynewrank-x_base; }
894             smpi_mpi_recv((myrank==root ? recvbuf : scr1buf)
895                      + x_start*typelng, x_count, mpi_datatype,
896                      OLDRANK(partner), 1244, comm, &status);
897 #           ifdef DEBUG
898             { int i; printf("[%2d](%2d) after step 6.%d   end: start=%2d  count=%2d  val=",
899                             myrank,mynewrank,n-idx,x_start,x_count);
900               for (i=0; i<x_count; i++)
901                printf(" %5.0lf",((double*)(root==myrank
902                                  ? recvbuf + x_start*typelng
903                                  : scr1buf + x_start*typelng))[i] );
904               printf("\n"); fflush(stdout);
905             }
906 #           endif
907           }
908         } /*for*/
909
910 #       undef OLDRANK
911
912       } /* end -- only for nodes with new rank */
913     }
914
915 #   ifdef NO_CACHE_TESTING
916      xbt_free(scr1buf); xbt_free(scr2buf); xbt_free(scr3buf);
917 #   else
918      xbt_free(scr2buf); /* scr1buf and scr3buf are part of scr2buf */
919 #   endif
920     return(MPI_SUCCESS);
921   } /* new_prot */
922   /*otherwise:*/
923   if (is_all)
924    return( mpi_coll_allreduce_fun(Sendbuf, Recvbuf, count, mpi_datatype, mpi_op, comm) );
925   else
926    return( mpi_coll_reduce_fun(Sendbuf,Recvbuf, count,mpi_datatype,mpi_op, root, comm) );
927 }
928 #endif /*REDUCE_LIMITS*/
929
930
931 int smpi_coll_tuned_reduce_rab(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
932 {
933   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op, root, comm, 0) );
934 }
935
936 int smpi_coll_tuned_allreduce_rab(void* Sendbuf, void* Recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
937 {
938   return( MPI_I_anyReduce(Sendbuf, Recvbuf, count, datatype, op,   -1, comm, 1) );
939 }