Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Concatenate nested namespaces (sonar).
[simgrid.git] / src / smpi / colls / smpi_openmpi_selector.cpp
1 /* selector for collective algorithms based on openmpi's default coll_tuned_decision_fixed selector
2  * Updated 02/2022                                                          */
3
4 /* Copyright (c) 2009-2022. The SimGrid Team.
5  * All rights reserved.                                                     */
6
7 /* This program is free software; you can redistribute it and/or modify it
8  * under the terms of the license (GNU LGPL) which comes with this package. */
9
10 #include "colls_private.hpp"
11
12 #include <memory>
13
14 namespace simgrid::smpi {
15
16 int allreduce__ompi(const void *sbuf, void *rbuf, int count,
17                     MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
18 {
19     size_t total_dsize = dtype->size() * (ptrdiff_t)count;
20     int communicator_size = comm->size();
21     int alg = 1;
22     int(*funcs[]) (const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm)={
23         &allreduce__redbcast,
24         &allreduce__redbcast,
25         &allreduce__rdb,
26         &allreduce__lr,
27         &allreduce__ompi_ring_segmented,
28         &allreduce__rab_rdb
29     };
30
31     /** Algorithms:
32      *  {1, "basic_linear"},
33      *  {2, "nonoverlapping"},
34      *  {3, "recursive_doubling"},
35      *  {4, "ring"},
36      *  {5, "segmented_ring"},
37      *  {6, "rabenseifner"
38      *
39      * Currently, ring, segmented ring, and rabenseifner do not support
40      * non-commutative operations.
41      */
42     if ((op != MPI_OP_NULL) && not op->is_commutative()) {
43         if (communicator_size < 4) {
44             if (total_dsize < 131072) {
45                 alg = 3;
46             } else {
47                 alg = 1;
48             }
49         } else if (communicator_size < 8) {
50             alg = 3;
51         } else if (communicator_size < 16) {
52             if (total_dsize < 1048576) {
53                 alg = 3;
54             } else {
55                 alg = 2;
56             }
57         } else if (communicator_size < 128) {
58             alg = 3;
59         } else if (communicator_size < 256) {
60             if (total_dsize < 131072) {
61                 alg = 2;
62             } else if (total_dsize < 524288) {
63                 alg = 3;
64             } else {
65                 alg = 2;
66             }
67         } else if (communicator_size < 512) {
68             if (total_dsize < 4096) {
69                 alg = 2;
70             } else if (total_dsize < 524288) {
71                 alg = 3;
72             } else {
73                 alg = 2;
74             }
75         } else {
76             if (total_dsize < 2048) {
77                 alg = 2;
78             } else {
79                 alg = 3;
80             }
81         }
82     } else {
83         if (communicator_size < 4) {
84             if (total_dsize < 8) {
85                 alg = 4;
86             } else if (total_dsize < 4096) {
87                 alg = 3;
88             } else if (total_dsize < 8192) {
89                 alg = 4;
90             } else if (total_dsize < 16384) {
91                 alg = 3;
92             } else if (total_dsize < 65536) {
93                 alg = 4;
94             } else if (total_dsize < 262144) {
95                 alg = 5;
96             } else {
97                 alg = 6;
98             }
99         } else if (communicator_size < 8) {
100             if (total_dsize < 16) {
101                 alg = 4;
102             } else if (total_dsize < 8192) {
103                 alg = 3;
104             } else {
105                 alg = 6;
106             }
107         } else if (communicator_size < 16) {
108             if (total_dsize < 8192) {
109                 alg = 3;
110             } else {
111                 alg = 6;
112             }
113         } else if (communicator_size < 32) {
114             if (total_dsize < 64) {
115                 alg = 5;
116             } else if (total_dsize < 4096) {
117                 alg = 3;
118             } else {
119                 alg = 6;
120             }
121         } else if (communicator_size < 64) {
122             if (total_dsize < 128) {
123                 alg = 5;
124             } else {
125                 alg = 6;
126             }
127         } else if (communicator_size < 128) {
128             if (total_dsize < 262144) {
129                 alg = 3;
130             } else {
131                 alg = 6;
132             }
133         } else if (communicator_size < 256) {
134             if (total_dsize < 131072) {
135                 alg = 2;
136             } else if (total_dsize < 262144) {
137                 alg = 3;
138             } else {
139                 alg = 6;
140             }
141         } else if (communicator_size < 512) {
142             if (total_dsize < 4096) {
143                 alg = 2;
144             } else {
145                 alg = 6;
146             }
147         } else if (communicator_size < 2048) {
148             if (total_dsize < 2048) {
149                 alg = 2;
150             } else if (total_dsize < 16384) {
151                 alg = 3;
152             } else {
153                 alg = 6;
154             }
155         } else if (communicator_size < 4096) {
156             if (total_dsize < 2048) {
157                 alg = 2;
158             } else if (total_dsize < 4096) {
159                 alg = 5;
160             } else if (total_dsize < 16384) {
161                 alg = 3;
162             } else {
163                 alg = 6;
164             }
165         } else {
166             if (total_dsize < 2048) {
167                 alg = 2;
168             } else if (total_dsize < 16384) {
169                 alg = 5;
170             } else if (total_dsize < 32768) {
171                 alg = 3;
172             } else {
173                 alg = 6;
174             }
175         }
176     }
177     return funcs[alg-1](sbuf, rbuf, count, dtype, op, comm);
178 }
179
180
181
182 int alltoall__ompi(const void *sbuf, int scount,
183                    MPI_Datatype sdtype,
184                    void* rbuf, int rcount,
185                    MPI_Datatype rdtype,
186                    MPI_Comm comm)
187 {
188     int alg = 1;
189     size_t dsize, total_dsize;
190     int communicator_size = comm->size();
191
192     if (MPI_IN_PLACE != sbuf) {
193         dsize = sdtype->size();
194     } else {
195         dsize = rdtype->size();
196     }
197     total_dsize = dsize * (ptrdiff_t)scount;
198     int (*funcs[])(const void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
199         &alltoall__basic_linear,
200         &alltoall__pair,
201         &alltoall__bruck,
202         &alltoall__basic_linear,
203         &alltoall__basic_linear
204     };
205     /** Algorithms:
206      *  {1, "linear"},
207      *  {2, "pairwise"},
208      *  {3, "modified_bruck"},
209      *  {4, "linear_sync"},
210      *  {5, "two_proc"},
211      */
212     if (communicator_size == 2) {
213         if (total_dsize < 2) {
214             alg = 2;
215         } else if (total_dsize < 4) {
216             alg = 5;
217         } else if (total_dsize < 16) {
218             alg = 2;
219         } else if (total_dsize < 64) {
220             alg = 5;
221         } else if (total_dsize < 256) {
222             alg = 2;
223         } else if (total_dsize < 4096) {
224             alg = 5;
225         } else if (total_dsize < 32768) {
226             alg = 2;
227         } else if (total_dsize < 262144) {
228             alg = 4;
229         } else if (total_dsize < 1048576) {
230             alg = 5;
231         } else {
232             alg = 2;
233         }
234     } else if (communicator_size < 8) {
235         if (total_dsize < 8192) {
236             alg = 4;
237         } else if (total_dsize < 16384) {
238             alg = 1;
239         } else if (total_dsize < 65536) {
240             alg = 4;
241         } else if (total_dsize < 524288) {
242             alg = 1;
243         } else if (total_dsize < 1048576) {
244             alg = 2;
245         } else {
246             alg = 1;
247         }
248     } else if (communicator_size < 16) {
249         if (total_dsize < 262144) {
250             alg = 4;
251         } else {
252             alg = 1;
253         }
254     } else if (communicator_size < 32) {
255         if (total_dsize < 4) {
256             alg = 4;
257         } else if (total_dsize < 512) {
258             alg = 3;
259         } else if (total_dsize < 8192) {
260             alg = 4;
261         } else if (total_dsize < 32768) {
262             alg = 1;
263         } else if (total_dsize < 262144) {
264             alg = 4;
265         } else if (total_dsize < 524288) {
266             alg = 1;
267         } else {
268             alg = 4;
269         }
270     } else if (communicator_size < 64) {
271         if (total_dsize < 512) {
272             alg = 3;
273         } else if (total_dsize < 524288) {
274             alg = 1;
275         } else {
276             alg = 4;
277         }
278     } else if (communicator_size < 128) {
279         if (total_dsize < 1024) {
280             alg = 3;
281         } else if (total_dsize < 2048) {
282             alg = 1;
283         } else if (total_dsize < 4096) {
284             alg = 4;
285         } else if (total_dsize < 262144) {
286             alg = 1;
287         } else {
288             alg = 2;
289         }
290     } else if (communicator_size < 256) {
291         if (total_dsize < 1024) {
292             alg = 3;
293         } else if (total_dsize < 2048) {
294             alg = 4;
295         } else if (total_dsize < 262144) {
296             alg = 1;
297         } else {
298             alg = 2;
299         }
300     } else if (communicator_size < 512) {
301         if (total_dsize < 1024) {
302             alg = 3;
303         } else if (total_dsize < 8192) {
304             alg = 4;
305         } else if (total_dsize < 32768) {
306             alg = 1;
307         } else {
308             alg = 2;
309         }
310     } else if (communicator_size < 1024) {
311         if (total_dsize < 512) {
312             alg = 3;
313         } else if (total_dsize < 8192) {
314             alg = 4;
315         } else if (total_dsize < 16384) {
316             alg = 1;
317         } else if (total_dsize < 131072) {
318             alg = 4;
319         } else if (total_dsize < 262144) {
320             alg = 1;
321         } else {
322             alg = 2;
323         }
324     } else if (communicator_size < 2048) {
325         if (total_dsize < 512) {
326             alg = 3;
327         } else if (total_dsize < 1024) {
328             alg = 4;
329         } else if (total_dsize < 2048) {
330             alg = 1;
331         } else if (total_dsize < 16384) {
332             alg = 4;
333         } else if (total_dsize < 262144) {
334             alg = 1;
335         } else {
336             alg = 4;
337         }
338     } else if (communicator_size < 4096) {
339         if (total_dsize < 1024) {
340             alg = 3;
341         } else if (total_dsize < 4096) {
342             alg = 4;
343         } else if (total_dsize < 8192) {
344             alg = 1;
345         } else if (total_dsize < 131072) {
346             alg = 4;
347         } else {
348             alg = 1;
349         }
350     } else {
351         if (total_dsize < 2048) {
352             alg = 3;
353         } else if (total_dsize < 8192) {
354             alg = 4;
355         } else if (total_dsize < 16384) {
356             alg = 1;
357         } else if (total_dsize < 32768) {
358             alg = 4;
359         } else if (total_dsize < 65536) {
360             alg = 1;
361         } else {
362             alg = 4;
363         }
364     }
365
366     return funcs[alg-1](sbuf, scount, sdtype,
367                           rbuf, rcount, rdtype, comm);
368 }
369
370 int alltoallv__ompi(const void *sbuf, const int *scounts, const int *sdisps,
371                     MPI_Datatype sdtype,
372                     void *rbuf, const int *rcounts, const int *rdisps,
373                     MPI_Datatype rdtype,
374                     MPI_Comm  comm
375                     )
376 {
377     int communicator_size = comm->size();
378     int alg = 1;
379     int (*funcs[])(const void *, const int*, const int*, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
380         &alltoallv__ompi_basic_linear,
381         &alltoallv__pair
382     };
383    /** Algorithms:
384      *  {1, "basic_linear"},
385      *  {2, "pairwise"},
386      *
387      * We can only optimize based on com size
388      */
389     if (communicator_size < 4) {
390         alg = 2;
391     } else if (communicator_size < 64) {
392         alg = 1;
393     } else if (communicator_size < 128) {
394         alg = 2;
395     } else if (communicator_size < 256) {
396         alg = 1;
397     } else if (communicator_size < 1024) {
398         alg = 2;
399     } else {
400         alg = 1;
401     }
402     return funcs[alg-1](sbuf, scounts, sdisps, sdtype,
403                            rbuf, rcounts, rdisps,rdtype,
404                            comm);
405 }
406
407 int barrier__ompi(MPI_Comm  comm)
408 {
409     int communicator_size = comm->size();
410     int alg = 1;
411     int (*funcs[])(MPI_Comm) = {
412         &barrier__ompi_basic_linear,
413         &barrier__ompi_basic_linear,
414         &barrier__ompi_recursivedoubling,
415         &barrier__ompi_bruck,
416         &barrier__ompi_two_procs,
417         &barrier__ompi_tree
418     };
419     /** Algorithms:
420      *  {1, "linear"},
421      *  {2, "double_ring"},
422      *  {3, "recursive_doubling"},
423      *  {4, "bruck"},
424      *  {5, "two_proc"},
425      *  {6, "tree"},
426      *
427      * We can only optimize based on com size
428      */
429     if (communicator_size < 4) {
430         alg = 3;
431     } else if (communicator_size < 8) {
432         alg = 1;
433     } else if (communicator_size < 64) {
434         alg = 3;
435     } else if (communicator_size < 256) {
436         alg = 4;
437     } else if (communicator_size < 512) {
438         alg = 6;
439     } else if (communicator_size < 1024) {
440         alg = 4;
441     } else if (communicator_size < 4096) {
442         alg = 6;
443     } else {
444         alg = 4;
445     }
446
447     return funcs[alg-1](comm);
448 }
449
450 int bcast__ompi(void *buff, int count, MPI_Datatype datatype, int root, MPI_Comm  comm)
451 {
452     int alg = 1;
453     size_t total_dsize, dsize;
454
455     int communicator_size = comm->size();
456
457     dsize = datatype->size();
458     total_dsize = dsize * (unsigned long)count;
459     int (*funcs[])(void*, int, MPI_Datatype, int, MPI_Comm) = {
460         &bcast__NTSL,
461         &bcast__ompi_pipeline,
462         &bcast__ompi_pipeline,
463         &bcast__ompi_split_bintree,
464         &bcast__NTSB,
465         &bcast__binomial_tree,
466         &bcast__mvapich2_knomial_intra_node,
467         &bcast__scatter_rdb_allgather,
468         &bcast__scatter_LR_allgather,
469     };
470     /** Algorithms:
471      *  {1, "basic_linear"},
472      *  {2, "chain"},
473      *  {3, "pipeline"},
474      *  {4, "split_binary_tree"},
475      *  {5, "binary_tree"},
476      *  {6, "binomial"},
477      *  {7, "knomial"},
478      *  {8, "scatter_allgather"},
479      *  {9, "scatter_allgather_ring"},
480      */
481     if (communicator_size < 4) {
482         if (total_dsize < 32) {
483             alg = 3;
484         } else if (total_dsize < 256) {
485             alg = 5;
486         } else if (total_dsize < 512) {
487             alg = 3;
488         } else if (total_dsize < 1024) {
489             alg = 7;
490         } else if (total_dsize < 32768) {
491             alg = 1;
492         } else if (total_dsize < 131072) {
493             alg = 5;
494         } else if (total_dsize < 262144) {
495             alg = 2;
496         } else if (total_dsize < 524288) {
497             alg = 1;
498         } else if (total_dsize < 1048576) {
499             alg = 6;
500         } else {
501             alg = 5;
502         }
503     } else if (communicator_size < 8) {
504         if (total_dsize < 64) {
505             alg = 5;
506         } else if (total_dsize < 128) {
507             alg = 6;
508         } else if (total_dsize < 2048) {
509             alg = 5;
510         } else if (total_dsize < 8192) {
511             alg = 6;
512         } else if (total_dsize < 1048576) {
513             alg = 1;
514         } else {
515             alg = 2;
516         }
517     } else if (communicator_size < 16) {
518         if (total_dsize < 8) {
519             alg = 7;
520         } else if (total_dsize < 64) {
521             alg = 5;
522         } else if (total_dsize < 4096) {
523             alg = 7;
524         } else if (total_dsize < 16384) {
525             alg = 5;
526         } else if (total_dsize < 32768) {
527             alg = 6;
528         } else {
529             alg = 1;
530         }
531     } else if (communicator_size < 32) {
532         if (total_dsize < 4096) {
533             alg = 7;
534         } else if (total_dsize < 1048576) {
535             alg = 6;
536         } else {
537             alg = 8;
538         }
539     } else if (communicator_size < 64) {
540         if (total_dsize < 2048) {
541             alg = 6;
542         } else {
543             alg = 7;
544         }
545     } else if (communicator_size < 128) {
546         alg = 7;
547     } else if (communicator_size < 256) {
548         if (total_dsize < 2) {
549             alg = 6;
550         } else if (total_dsize < 16384) {
551             alg = 5;
552         } else if (total_dsize < 32768) {
553             alg = 1;
554         } else if (total_dsize < 65536) {
555             alg = 5;
556         } else {
557             alg = 7;
558         }
559     } else if (communicator_size < 1024) {
560         if (total_dsize < 16384) {
561             alg = 7;
562         } else if (total_dsize < 32768) {
563             alg = 4;
564         } else {
565             alg = 7;
566         }
567     } else if (communicator_size < 2048) {
568         if (total_dsize < 524288) {
569             alg = 7;
570         } else {
571             alg = 8;
572         }
573     } else if (communicator_size < 4096) {
574         if (total_dsize < 262144) {
575             alg = 7;
576         } else {
577             alg = 8;
578         }
579     } else {
580         if (total_dsize < 8192) {
581             alg = 7;
582         } else if (total_dsize < 16384) {
583             alg = 5;
584         } else if (total_dsize < 262144) {
585             alg = 7;
586         } else {
587             alg = 8;
588         }
589     }
590     return funcs[alg-1](buff, count, datatype, root, comm);
591 }
592
593 int reduce__ompi(const void *sendbuf, void *recvbuf,
594                  int count, MPI_Datatype  datatype,
595                  MPI_Op   op, int root,
596                  MPI_Comm   comm)
597 {
598     size_t total_dsize, dsize;
599     int alg = 1;
600     int communicator_size = comm->size();
601
602     dsize=datatype->size();
603     total_dsize = dsize * count;
604     int (*funcs[])(const void*, void*, int, MPI_Datatype, MPI_Op, int, MPI_Comm) = {
605         &reduce__ompi_basic_linear,
606         &reduce__ompi_chain,
607         &reduce__ompi_pipeline,
608         &reduce__ompi_binary,
609         &reduce__ompi_binomial,
610         &reduce__ompi_in_order_binary,
611         //&reduce__rab our rab can't be used with all datatypes
612         &reduce__ompi_basic_linear
613     };
614     /** Algorithms:
615      *  {1, "linear"},
616      *  {2, "chain"},
617      *  {3, "pipeline"},
618      *  {4, "binary"},
619      *  {5, "binomial"},
620      *  {6, "in-order_binary"},
621      *  {7, "rabenseifner"},
622      *
623      * Currently, only linear and in-order binary tree algorithms are
624      * capable of non commutative ops.
625      */
626      if ((op != MPI_OP_NULL) && not op->is_commutative()) {
627         if (communicator_size < 4) {
628             if (total_dsize < 8) {
629                 alg = 6;
630             } else {
631                 alg = 1;
632             }
633         } else if (communicator_size < 8) {
634             alg = 1;
635         } else if (communicator_size < 16) {
636             if (total_dsize < 1024) {
637                 alg = 6;
638             } else if (total_dsize < 8192) {
639                 alg = 1;
640             } else if (total_dsize < 16384) {
641                 alg = 6;
642             } else if (total_dsize < 262144) {
643                 alg = 1;
644             } else {
645                 alg = 6;
646             }
647         } else if (communicator_size < 128) {
648             alg = 6;
649         } else if (communicator_size < 256) {
650             if (total_dsize < 512) {
651                 alg = 6;
652             } else if (total_dsize < 1024) {
653                 alg = 1;
654             } else {
655                 alg = 6;
656             }
657         } else {
658             alg = 6;
659         }
660     } else {
661         if (communicator_size < 4) {
662             if (total_dsize < 8) {
663                 alg = 7;
664             } else if (total_dsize < 16) {
665                 alg = 4;
666             } else if (total_dsize < 32) {
667                 alg = 3;
668             } else if (total_dsize < 262144) {
669                 alg = 1;
670             } else if (total_dsize < 524288) {
671                 alg = 3;
672             } else if (total_dsize < 1048576) {
673                 alg = 2;
674             } else {
675                 alg = 3;
676             }
677         } else if (communicator_size < 8) {
678             if (total_dsize < 4096) {
679                 alg = 4;
680             } else if (total_dsize < 65536) {
681                 alg = 2;
682             } else if (total_dsize < 262144) {
683                 alg = 5;
684             } else if (total_dsize < 524288) {
685                 alg = 1;
686             } else if (total_dsize < 1048576) {
687                 alg = 5;
688             } else {
689                 alg = 1;
690             }
691         } else if (communicator_size < 16) {
692             if (total_dsize < 8192) {
693                 alg = 4;
694             } else {
695                 alg = 5;
696             }
697         } else if (communicator_size < 32) {
698             if (total_dsize < 4096) {
699                 alg = 4;
700             } else {
701                 alg = 5;
702             }
703         } else if (communicator_size < 256) {
704             alg = 5;
705         } else if (communicator_size < 512) {
706             if (total_dsize < 8192) {
707                 alg = 5;
708             } else if (total_dsize < 16384) {
709                 alg = 6;
710             } else {
711                 alg = 5;
712             }
713         } else if (communicator_size < 2048) {
714             alg = 5;
715         } else if (communicator_size < 4096) {
716             if (total_dsize < 512) {
717                 alg = 5;
718             } else if (total_dsize < 1024) {
719                 alg = 6;
720             } else if (total_dsize < 8192) {
721                 alg = 5;
722             } else if (total_dsize < 16384) {
723                 alg = 6;
724             } else {
725                 alg = 5;
726             }
727         } else {
728             if (total_dsize < 16) {
729                 alg = 5;
730             } else if (total_dsize < 32) {
731                 alg = 6;
732             } else if (total_dsize < 1024) {
733                 alg = 5;
734             } else if (total_dsize < 2048) {
735                 alg = 6;
736             } else if (total_dsize < 8192) {
737                 alg = 5;
738             } else if (total_dsize < 16384) {
739                 alg = 6;
740             } else {
741                 alg = 5;
742             }
743         }
744     }
745
746     return funcs[alg-1] (sendbuf, recvbuf, count, datatype, op, root, comm);
747 }
748
749 int reduce_scatter__ompi(const void *sbuf, void *rbuf,
750                          const int *rcounts,
751                          MPI_Datatype dtype,
752                          MPI_Op  op,
753                          MPI_Comm  comm
754                          )
755 {
756     size_t total_dsize, dsize;
757     int communicator_size = comm->size();
758     int alg = 1;
759     int zerocounts = 0;
760     dsize=dtype->size();
761     total_dsize = 0;
762     for (int i = 0; i < communicator_size; i++) {
763         total_dsize += rcounts[i];
764        // if (0 == rcounts[i]) {
765         //    zerocounts = 1;
766         //}
767     }
768     total_dsize *= dsize;
769     int (*funcs[])(const void*, void*, const int*, MPI_Datatype, MPI_Op, MPI_Comm) = {
770         &reduce_scatter__default,
771         &reduce_scatter__ompi_basic_recursivehalving,
772         &reduce_scatter__ompi_ring,
773         &reduce_scatter__ompi_butterfly,
774     };
775     /** Algorithms:
776      *  {1, "non-overlapping"},
777      *  {2, "recursive_halving"},
778      *  {3, "ring"},
779      *  {4, "butterfly"},
780      *
781      * Non commutative algorithm capability needs re-investigation.
782      * Defaulting to non overlapping for non commutative ops.
783      */
784     if (((op != MPI_OP_NULL) && not op->is_commutative()) || (zerocounts)) {
785         alg = 1;
786     } else {
787         if (communicator_size < 4) {
788             if (total_dsize < 65536) {
789                 alg = 3;
790             } else if (total_dsize < 131072) {
791                 alg = 4;
792             } else {
793                 alg = 3;
794             }
795         } else if (communicator_size < 8) {
796             if (total_dsize < 8) {
797                 alg = 1;
798             } else if (total_dsize < 262144) {
799                 alg = 2;
800             } else {
801                 alg = 3;
802             }
803         } else if (communicator_size < 32) {
804             if (total_dsize < 262144) {
805                 alg = 2;
806             } else {
807                 alg = 3;
808             }
809         } else if (communicator_size < 64) {
810             if (total_dsize < 64) {
811                 alg = 1;
812             } else if (total_dsize < 2048) {
813                 alg = 2;
814             } else if (total_dsize < 524288) {
815                 alg = 4;
816             } else {
817                 alg = 3;
818             }
819         } else if (communicator_size < 128) {
820             if (total_dsize < 256) {
821                 alg = 1;
822             } else if (total_dsize < 512) {
823                 alg = 2;
824             } else if (total_dsize < 2048) {
825                 alg = 4;
826             } else if (total_dsize < 4096) {
827                 alg = 2;
828             } else {
829                 alg = 4;
830             }
831         } else if (communicator_size < 256) {
832             if (total_dsize < 256) {
833                 alg = 1;
834             } else if (total_dsize < 512) {
835                 alg = 2;
836             } else {
837                 alg = 4;
838             }
839         } else if (communicator_size < 512) {
840             if (total_dsize < 256) {
841                 alg = 1;
842             } else if (total_dsize < 1024) {
843                 alg = 2;
844             } else {
845                 alg = 4;
846             }
847         } else if (communicator_size < 1024) {
848             if (total_dsize < 512) {
849                 alg = 1;
850             } else if (total_dsize < 2048) {
851                 alg = 2;
852             } else if (total_dsize < 8192) {
853                 alg = 4;
854             } else if (total_dsize < 16384) {
855                 alg = 2;
856             } else {
857                 alg = 4;
858             }
859         } else if (communicator_size < 2048) {
860             if (total_dsize < 512) {
861                 alg = 1;
862             } else if (total_dsize < 4096) {
863                 alg = 2;
864             } else if (total_dsize < 16384) {
865                 alg = 4;
866             } else if (total_dsize < 32768) {
867                 alg = 2;
868             } else {
869                 alg = 4;
870             }
871         } else if (communicator_size < 4096) {
872             if (total_dsize < 512) {
873                 alg = 1;
874             } else if (total_dsize < 4096) {
875                 alg = 2;
876             } else {
877                 alg = 4;
878             }
879         } else {
880             if (total_dsize < 1024) {
881                 alg = 1;
882             } else if (total_dsize < 8192) {
883                 alg = 2;
884             } else {
885                 alg = 4;
886             }
887         }
888     }
889
890     return funcs[alg-1] (sbuf, rbuf, rcounts, dtype, op, comm);
891 }
892
893 int allgather__ompi(const void *sbuf, int scount,
894                     MPI_Datatype sdtype,
895                     void* rbuf, int rcount,
896                     MPI_Datatype rdtype,
897                     MPI_Comm  comm
898                     )
899 {
900     int communicator_size;
901     size_t dsize, total_dsize;
902     int alg = 1;
903     communicator_size = comm->size();
904     if (MPI_IN_PLACE != sbuf) {
905         dsize = sdtype->size();
906     } else {
907         dsize = rdtype->size();
908     }
909     total_dsize = dsize * (ptrdiff_t)scount;
910     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm) = {
911         &allgather__NTSLR_NB,
912         &allgather__bruck,
913         &allgather__rdb,
914         &allgather__ring,
915         &allgather__ompi_neighborexchange,
916         &allgather__pair
917     };
918     /** Algorithms:
919      *  {1, "linear"},
920      *  {2, "bruck"},
921      *  {3, "recursive_doubling"},
922      *  {4, "ring"},
923      *  {5, "neighbor"},
924      *  {6, "two_proc"}
925      */
926     if (communicator_size == 2) {
927         alg = 6;
928     } else if (communicator_size < 32) {
929         alg = 3;
930     } else if (communicator_size < 64) {
931         if (total_dsize < 1024) {
932             alg = 3;
933         } else if (total_dsize < 65536) {
934             alg = 5;
935         } else {
936             alg = 4;
937         }
938     } else if (communicator_size < 128) {
939         if (total_dsize < 512) {
940             alg = 3;
941         } else if (total_dsize < 65536) {
942             alg = 5;
943         } else {
944             alg = 4;
945         }
946     } else if (communicator_size < 256) {
947         if (total_dsize < 512) {
948             alg = 3;
949         } else if (total_dsize < 131072) {
950             alg = 5;
951         } else if (total_dsize < 524288) {
952             alg = 4;
953         } else if (total_dsize < 1048576) {
954             alg = 5;
955         } else {
956             alg = 4;
957         }
958     } else if (communicator_size < 512) {
959         if (total_dsize < 32) {
960             alg = 3;
961         } else if (total_dsize < 128) {
962             alg = 2;
963         } else if (total_dsize < 1024) {
964             alg = 3;
965         } else if (total_dsize < 131072) {
966             alg = 5;
967         } else if (total_dsize < 524288) {
968             alg = 4;
969         } else if (total_dsize < 1048576) {
970             alg = 5;
971         } else {
972             alg = 4;
973         }
974     } else if (communicator_size < 1024) {
975         if (total_dsize < 64) {
976             alg = 3;
977         } else if (total_dsize < 256) {
978             alg = 2;
979         } else if (total_dsize < 2048) {
980             alg = 3;
981         } else {
982             alg = 5;
983         }
984     } else if (communicator_size < 2048) {
985         if (total_dsize < 4) {
986             alg = 3;
987         } else if (total_dsize < 8) {
988             alg = 2;
989         } else if (total_dsize < 16) {
990             alg = 3;
991         } else if (total_dsize < 32) {
992             alg = 2;
993         } else if (total_dsize < 256) {
994             alg = 3;
995         } else if (total_dsize < 512) {
996             alg = 2;
997         } else if (total_dsize < 4096) {
998             alg = 3;
999         } else {
1000             alg = 5;
1001         }
1002     } else if (communicator_size < 4096) {
1003         if (total_dsize < 32) {
1004             alg = 2;
1005         } else if (total_dsize < 128) {
1006             alg = 3;
1007         } else if (total_dsize < 512) {
1008             alg = 2;
1009         } else if (total_dsize < 4096) {
1010             alg = 3;
1011         } else {
1012             alg = 5;
1013         }
1014     } else {
1015         if (total_dsize < 2) {
1016             alg = 3;
1017         } else if (total_dsize < 8) {
1018             alg = 2;
1019         } else if (total_dsize < 16) {
1020             alg = 3;
1021         } else if (total_dsize < 512) {
1022             alg = 2;
1023         } else if (total_dsize < 4096) {
1024             alg = 3;
1025         } else {
1026             alg = 5;
1027         }
1028     }
1029
1030     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
1031
1032 }
1033
1034 int allgatherv__ompi(const void *sbuf, int scount,
1035                      MPI_Datatype sdtype,
1036                      void* rbuf, const int *rcounts,
1037                      const int *rdispls,
1038                      MPI_Datatype rdtype,
1039                      MPI_Comm  comm
1040                      )
1041 {
1042     int i;
1043     int communicator_size;
1044     size_t dsize, total_dsize;
1045     int alg = 1;
1046     communicator_size = comm->size();
1047     if (MPI_IN_PLACE != sbuf) {
1048         dsize = sdtype->size();
1049     } else {
1050         dsize = rdtype->size();
1051     }
1052
1053     total_dsize = 0;
1054     for (i = 0; i < communicator_size; i++) {
1055         total_dsize += dsize * rcounts[i];
1056     }
1057
1058     /* use the per-rank data size as basis, similar to allgather */
1059     size_t per_rank_dsize = total_dsize / communicator_size;
1060
1061     int (*funcs[])(const void*, int, MPI_Datatype, void*, const int*, const int*, MPI_Datatype, MPI_Comm) = {
1062         &allgatherv__GB,
1063         &allgatherv__ompi_bruck,
1064         &allgatherv__mpich_ring,
1065         &allgatherv__ompi_neighborexchange,
1066         &allgatherv__pair
1067     };
1068     /** Algorithms:
1069      *  {1, "default"},
1070      *  {2, "bruck"},
1071      *  {3, "ring"},
1072      *  {4, "neighbor"},
1073      *  {5, "two_proc"},
1074      */
1075     if (communicator_size == 2) {
1076         if (per_rank_dsize < 2048) {
1077             alg = 3;
1078         } else if (per_rank_dsize < 4096) {
1079             alg = 5;
1080         } else if (per_rank_dsize < 8192) {
1081             alg = 3;
1082         } else {
1083             alg = 5;
1084         }
1085     } else if (communicator_size < 8) {
1086         if (per_rank_dsize < 256) {
1087             alg = 1;
1088         } else if (per_rank_dsize < 4096) {
1089             alg = 4;
1090         } else if (per_rank_dsize < 8192) {
1091             alg = 3;
1092         } else if (per_rank_dsize < 16384) {
1093             alg = 4;
1094         } else if (per_rank_dsize < 262144) {
1095             alg = 2;
1096         } else {
1097             alg = 4;
1098         }
1099     } else if (communicator_size < 16) {
1100         if (per_rank_dsize < 1024) {
1101             alg = 1;
1102         } else {
1103             alg = 2;
1104         }
1105     } else if (communicator_size < 32) {
1106         if (per_rank_dsize < 128) {
1107             alg = 1;
1108         } else if (per_rank_dsize < 262144) {
1109             alg = 2;
1110         } else {
1111             alg = 3;
1112         }
1113     } else if (communicator_size < 64) {
1114         if (per_rank_dsize < 256) {
1115             alg = 1;
1116         } else if (per_rank_dsize < 8192) {
1117             alg = 2;
1118         } else {
1119             alg = 3;
1120         }
1121     } else if (communicator_size < 128) {
1122         if (per_rank_dsize < 256) {
1123             alg = 1;
1124         } else if (per_rank_dsize < 4096) {
1125             alg = 2;
1126         } else {
1127             alg = 3;
1128         }
1129     } else if (communicator_size < 256) {
1130         if (per_rank_dsize < 1024) {
1131             alg = 2;
1132         } else if (per_rank_dsize < 65536) {
1133             alg = 4;
1134         } else {
1135             alg = 3;
1136         }
1137     } else if (communicator_size < 512) {
1138         if (per_rank_dsize < 1024) {
1139             alg = 2;
1140         } else {
1141             alg = 3;
1142         }
1143     } else if (communicator_size < 1024) {
1144         if (per_rank_dsize < 512) {
1145             alg = 2;
1146         } else if (per_rank_dsize < 1024) {
1147             alg = 1;
1148         } else if (per_rank_dsize < 4096) {
1149             alg = 2;
1150         } else if (per_rank_dsize < 1048576) {
1151             alg = 4;
1152         } else {
1153             alg = 3;
1154         }
1155     } else {
1156         if (per_rank_dsize < 4096) {
1157             alg = 2;
1158         } else {
1159             alg = 4;
1160         }
1161     }
1162
1163     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcounts, rdispls, rdtype, comm);
1164 }
1165
1166 int gather__ompi(const void *sbuf, int scount,
1167                  MPI_Datatype sdtype,
1168                  void* rbuf, int rcount,
1169                  MPI_Datatype rdtype,
1170                  int root,
1171                  MPI_Comm  comm
1172                  )
1173 {
1174     int communicator_size, rank;
1175     size_t dsize, total_dsize;
1176     int alg = 1;
1177     communicator_size = comm->size();
1178     rank = comm->rank();
1179
1180     if (rank == root) {
1181         dsize = rdtype->size();
1182         total_dsize = dsize * rcount;
1183     } else {
1184         dsize = sdtype->size();
1185         total_dsize = dsize * scount;
1186     }
1187     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1188         &gather__ompi_basic_linear,
1189         &gather__ompi_binomial,
1190         &gather__ompi_linear_sync
1191     };
1192     /** Algorithms:
1193      *  {1, "basic_linear"},
1194      *  {2, "binomial"},
1195      *  {3, "linear_sync"},
1196      *
1197      * We do not make any rank specific checks since the params
1198      * should be uniform across ranks.
1199      */
1200     if (communicator_size < 4) {
1201         if (total_dsize < 2) {
1202             alg = 3;
1203         } else if (total_dsize < 4) {
1204             alg = 1;
1205         } else if (total_dsize < 32768) {
1206             alg = 2;
1207         } else if (total_dsize < 65536) {
1208             alg = 1;
1209         } else if (total_dsize < 131072) {
1210             alg = 2;
1211         } else {
1212             alg = 3;
1213         }
1214     } else if (communicator_size < 8) {
1215         if (total_dsize < 1024) {
1216             alg = 2;
1217         } else if (total_dsize < 8192) {
1218             alg = 1;
1219         } else if (total_dsize < 32768) {
1220             alg = 2;
1221         } else if (total_dsize < 262144) {
1222             alg = 1;
1223         } else {
1224             alg = 3;
1225         }
1226     } else if (communicator_size < 256) {
1227         alg = 2;
1228     } else if (communicator_size < 512) {
1229         if (total_dsize < 2048) {
1230             alg = 2;
1231         } else if (total_dsize < 8192) {
1232             alg = 1;
1233         } else {
1234             alg = 2;
1235         }
1236     } else {
1237         alg = 2;
1238     }
1239
1240     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1241 }
1242
1243
1244 int scatter__ompi(const void *sbuf, int scount,
1245                   MPI_Datatype sdtype,
1246                   void* rbuf, int rcount,
1247                   MPI_Datatype rdtype,
1248                   int root, MPI_Comm  comm
1249                   )
1250 {
1251     int communicator_size, rank;
1252     size_t dsize, total_dsize;
1253     int alg = 1;
1254
1255     communicator_size = comm->size();
1256     rank = comm->rank();
1257     if (root == rank) {
1258         dsize=sdtype->size();
1259         total_dsize = dsize * scount;
1260     } else {
1261         dsize=rdtype->size();
1262         total_dsize = dsize * rcount;
1263     }
1264     int (*funcs[])(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm) = {
1265         &scatter__ompi_basic_linear,
1266         &scatter__ompi_binomial,
1267         &scatter__ompi_linear_nb
1268     };
1269     /** Algorithms:
1270      *  {1, "basic_linear"},
1271      *  {2, "binomial"},
1272      *  {3, "linear_nb"},
1273      *
1274      * We do not make any rank specific checks since the params
1275      * should be uniform across ranks.
1276      */
1277     if (communicator_size < 4) {
1278         if (total_dsize < 2) {
1279             alg = 3;
1280         } else if (total_dsize < 131072) {
1281             alg = 1;
1282         } else if (total_dsize < 262144) {
1283             alg = 3;
1284         } else {
1285             alg = 1;
1286         }
1287     } else if (communicator_size < 8) {
1288         if (total_dsize < 2048) {
1289             alg = 2;
1290         } else if (total_dsize < 4096) {
1291             alg = 1;
1292         } else if (total_dsize < 8192) {
1293             alg = 2;
1294         } else if (total_dsize < 32768) {
1295             alg = 1;
1296         } else if (total_dsize < 1048576) {
1297             alg = 3;
1298         } else {
1299             alg = 1;
1300         }
1301     } else if (communicator_size < 16) {
1302         if (total_dsize < 16384) {
1303             alg = 2;
1304         } else if (total_dsize < 1048576) {
1305             alg = 3;
1306         } else {
1307             alg = 1;
1308         }
1309     } else if (communicator_size < 32) {
1310         if (total_dsize < 16384) {
1311             alg = 2;
1312         } else if (total_dsize < 32768) {
1313             alg = 1;
1314         } else {
1315             alg = 3;
1316         }
1317     } else if (communicator_size < 64) {
1318         if (total_dsize < 512) {
1319             alg = 2;
1320         } else if (total_dsize < 8192) {
1321             alg = 3;
1322         } else if (total_dsize < 16384) {
1323             alg = 2;
1324         } else {
1325             alg = 3;
1326         }
1327     } else {
1328         if (total_dsize < 512) {
1329             alg = 2;
1330         } else {
1331             alg = 3;
1332         }
1333     }
1334
1335     return funcs[alg-1](sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
1336 }
1337
1338 } // namespace simgrid::smpi