Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2022.
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.cpp
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2022. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.hpp"
10
11 #include "smpi_mvapich2_selector_stampede.hpp"
12
13 namespace simgrid {
14 namespace smpi {
15
16
17 int alltoall__mvapich2( const void *sendbuf, int sendcount,
18                         MPI_Datatype sendtype,
19                         void* recvbuf, int recvcount,
20                         MPI_Datatype recvtype,
21                         MPI_Comm comm)
22 {
23
24   if (mv2_alltoall_table_ppn_conf == nullptr)
25     init_mv2_alltoall_tables_stampede();
26
27   int sendtype_size, recvtype_size, comm_size;
28   int mpi_errno=MPI_SUCCESS;
29   int range = 0;
30   int range_threshold = 0;
31   int conf_index = 0;
32   comm_size =  comm->size();
33
34   sendtype_size=sendtype->size();
35   recvtype_size=recvtype->size();
36   long nbytes = sendtype_size * sendcount;
37
38   /* check if safe to use partial subscription mode */
39
40   /* Search for the corresponding system size inside the tuning table */
41   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
42       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
43       range++;
44   }
45   /* Search for corresponding inter-leader function */
46   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
47       && (nbytes >
48   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
49   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
50       range_threshold++;
51   }
52   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
53                                                                                       .MV2_pt_Alltoall_function;
54
55   if(sendbuf != MPI_IN_PLACE) {
56       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
57           recvbuf, recvcount, recvtype,
58           comm);
59   } else {
60       range_threshold = 0;
61       if(nbytes <
62           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
63           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
64       ) {
65         unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(comm_size * recvcount * recvtype_size);
66         Datatype::copy(recvbuf, comm_size * recvcount, recvtype, tmp_buf, comm_size * recvcount, recvtype);
67
68         mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype, recvbuf, recvcount, recvtype, comm);
69         smpi_free_tmp_buffer(tmp_buf);
70       } else {
71           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
72               recvbuf, recvcount, recvtype,
73               comm );
74       }
75   }
76
77
78   return (mpi_errno);
79 }
80
81 int allgather__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
82     void *recvbuf, int recvcount, MPI_Datatype recvtype,
83     MPI_Comm comm)
84 {
85
86   int mpi_errno = MPI_SUCCESS;
87   long nbytes = 0, comm_size, recvtype_size;
88   int range = 0;
89   bool partial_sub_ok = false;
90   int conf_index = 0;
91   int range_threshold = 0;
92   MPI_Comm shmem_comm;
93   //MPI_Comm *shmem_commptr=NULL;
94   /* Get the size of the communicator */
95   comm_size = comm->size();
96   recvtype_size=recvtype->size();
97   nbytes = recvtype_size * recvcount;
98
99   if (mv2_allgather_table_ppn_conf == nullptr)
100     init_mv2_allgather_tables_stampede();
101
102   if(comm->get_leaders_comm()==MPI_COMM_NULL){
103     comm->init_smp();
104   }
105
106   if (comm->is_uniform()){
107     shmem_comm = comm->get_intra_comm();
108     int local_size = shmem_comm->size();
109     int i          = 0;
110     if (mv2_allgather_table_ppn_conf[0] == -1) {
111       // Indicating user defined tuning
112       conf_index = 0;
113       goto conf_check_end;
114     }
115     do {
116       if (local_size == mv2_allgather_table_ppn_conf[i]) {
117         conf_index = i;
118         partial_sub_ok = true;
119         break;
120       }
121       i++;
122     } while(i < mv2_allgather_num_ppn_conf);
123   }
124   conf_check_end:
125   if (not partial_sub_ok) {
126     conf_index = 0;
127   }
128
129   /* Search for the corresponding system size inside the tuning table */
130   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
131       (comm_size >
132   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
133       range++;
134   }
135   /* Search for corresponding inter-leader function */
136   while ((range_threshold <
137       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
138       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
139       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
140           -1)) {
141       range_threshold++;
142   }
143
144   /* Set inter-leader pt */
145   MV2_Allgatherction =
146       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
147       MV2_pt_Allgatherction;
148
149   bool is_two_level = mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
150
151   /* intracommunicator */
152   if (is_two_level) {
153     if (partial_sub_ok) {
154       if (comm->is_blocked()){
155       mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
156                             recvbuf, recvcount, recvtype,
157                             comm);
158       }else{
159       mpi_errno = allgather__mpich(sendbuf, sendcount, sendtype,
160                             recvbuf, recvcount, recvtype,
161                             comm);
162       }
163     } else {
164       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
165           recvbuf, recvcount, recvtype,
166           comm);
167     }
168   } else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
169       || MV2_Allgatherction == &MPIR_Allgather_RD_MV2
170       || MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
171       mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
172           recvbuf, recvcount, recvtype,
173           comm);
174   }else{
175       return MPI_ERR_OTHER;
176   }
177
178   return mpi_errno;
179 }
180
181 int gather__mvapich2(const void *sendbuf,
182     int sendcnt,
183     MPI_Datatype sendtype,
184     void *recvbuf,
185     int recvcnt,
186     MPI_Datatype recvtype,
187     int root, MPI_Comm  comm)
188 {
189   if (mv2_gather_thresholds_table == nullptr)
190     init_mv2_gather_tables_stampede();
191
192   int mpi_errno = MPI_SUCCESS;
193   int range = 0;
194   int range_threshold = 0;
195   int range_intra_threshold = 0;
196   long nbytes = 0;
197   int comm_size = comm->size();
198   int rank      = comm->rank();
199
200   if (rank == root) {
201     int recvtype_size = recvtype->size();
202     nbytes            = recvcnt * recvtype_size;
203   } else {
204     int sendtype_size = sendtype->size();
205     nbytes            = sendcnt * sendtype_size;
206   }
207
208   /* Search for the corresponding system size inside the tuning table */
209   while ((range < (mv2_size_gather_tuning_table - 1)) &&
210       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
211       range++;
212   }
213   /* Search for corresponding inter-leader function */
214   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
215       && (nbytes >
216   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
217   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
218       -1)) {
219       range_threshold++;
220   }
221
222   /* Search for corresponding intra node function */
223   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
224       && (nbytes >
225   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
226   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
227       -1)) {
228       range_intra_threshold++;
229   }
230
231     if (comm->is_blocked() ) {
232         // Set intra-node function pt for gather_two_level
233         MV2_Gather_intra_node_function =
234                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
235                               MV2_pt_Gather_function;
236         //Set inter-leader pt
237         MV2_Gather_inter_leader_function =
238                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
239                               MV2_pt_Gather_function;
240         // We call Gather function
241         mpi_errno =
242             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
243                                              recvtype, root, comm);
244
245     } else {
246   // Indeed, direct (non SMP-aware)gather is MPICH one
247   mpi_errno = gather__mpich(sendbuf, sendcnt, sendtype,
248       recvbuf, recvcnt, recvtype,
249       root, comm);
250   }
251
252   return mpi_errno;
253 }
254
255 int allgatherv__mvapich2(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
256     void *recvbuf, const int *recvcounts, const int *displs,
257     MPI_Datatype recvtype, MPI_Comm  comm )
258 {
259   int mpi_errno = MPI_SUCCESS;
260   int range = 0, comm_size, total_count, recvtype_size, i;
261   int range_threshold = 0;
262   long nbytes = 0;
263
264   if (mv2_allgatherv_thresholds_table == nullptr)
265     init_mv2_allgatherv_tables_stampede();
266
267   comm_size = comm->size();
268   total_count = 0;
269   for (i = 0; i < comm_size; i++)
270     total_count += recvcounts[i];
271
272   recvtype_size=recvtype->size();
273   nbytes = total_count * recvtype_size;
274
275   /* Search for the corresponding system size inside the tuning table */
276   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
277       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
278       range++;
279   }
280   /* Search for corresponding inter-leader function */
281   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
282       && (nbytes >
283   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
284   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
285       -1)) {
286       range_threshold++;
287   }
288   /* Set inter-leader pt */
289   MV2_Allgatherv_function =
290       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
291       MV2_pt_Allgatherv_function;
292
293   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
294     {
295     if (not(comm_size & (comm_size - 1))) {
296       mpi_errno =
297           MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
298         } else {
299             mpi_errno =
300                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
301                     sendtype, recvbuf,
302                     recvcounts, displs,
303                     recvtype, comm);
304         }
305     } else {
306         mpi_errno =
307             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
308                 recvbuf, recvcounts, displs,
309                 recvtype, comm);
310     }
311
312   return mpi_errno;
313 }
314
315
316
317 int allreduce__mvapich2(const void *sendbuf,
318     void *recvbuf,
319     int count,
320     MPI_Datatype datatype,
321     MPI_Op op, MPI_Comm comm)
322 {
323
324   int mpi_errno = MPI_SUCCESS;
325   //int rank = 0,
326   int comm_size = 0;
327
328   comm_size = comm->size();
329   //rank = comm->rank();
330
331   if (count == 0) {
332       return MPI_SUCCESS;
333   }
334
335   if (mv2_allreduce_thresholds_table == nullptr)
336     init_mv2_allreduce_tables_stampede();
337
338   /* check if multiple threads are calling this collective function */
339
340   MPI_Aint sendtype_size = 0;
341   long nbytes = 0;
342   MPI_Aint true_lb, true_extent;
343
344   sendtype_size=datatype->size();
345   nbytes = count * sendtype_size;
346
347   datatype->extent(&true_lb, &true_extent);
348   bool is_commutative = op->is_commutative();
349
350   {
351     int range = 0, range_threshold = 0, range_threshold_intra = 0;
352     bool is_two_level = false;
353
354     /* Search for the corresponding system size inside the tuning table */
355     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
356         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
357         range++;
358     }
359     /* Search for corresponding inter-leader function */
360     /* skip mcast pointers if mcast is not available */
361     if (not mv2_allreduce_thresholds_table[range].mcast_enabled) {
362         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
363             && ((mv2_allreduce_thresholds_table[range].
364                 inter_leader[range_threshold].MV2_pt_Allreducection
365                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
366                 (mv2_allreduce_thresholds_table[range].
367                     inter_leader[range_threshold].MV2_pt_Allreducection
368                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
369             )) {
370             range_threshold++;
371         }
372     }
373     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
374         && (nbytes >
375     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
376     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
377         range_threshold++;
378     }
379     if (mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold]) {
380       is_two_level = true;
381     }
382     /* Search for corresponding intra-node function */
383     while ((range_threshold_intra <
384         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
385         && (nbytes >
386     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
387     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
388         -1)) {
389         range_threshold_intra++;
390     }
391
392     MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
393                                                                                 .MV2_pt_Allreducection;
394
395     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
396                                                                                     .MV2_pt_Allreducection;
397
398     /* check if mcast is ready, otherwise replace mcast with other algorithm */
399     if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
400         (MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
401         {
402           MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
403         }
404         if (not is_two_level) {
405             MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
406         }
407     }
408
409     if (is_two_level) {
410       // check if shm is ready, if not use other algorithm first
411       if (is_commutative) {
412           if(comm->get_leaders_comm()==MPI_COMM_NULL){
413             comm->init_smp();
414           }
415           mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
416                                                      datatype, op, comm);
417       } else {
418         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
419             datatype, op, comm);
420       }
421     } else {
422         mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
423             datatype, op, comm);
424     }
425   }
426
427   //comm->ch.intra_node_done=0;
428
429   return (mpi_errno);
430
431
432 }
433
434
435 int alltoallv__mvapich2(const void *sbuf, const int *scounts, const int *sdisps,
436     MPI_Datatype sdtype,
437     void *rbuf, const int *rcounts, const int *rdisps,
438     MPI_Datatype rdtype,
439     MPI_Comm  comm
440 )
441 {
442
443   if (sbuf == MPI_IN_PLACE) {
444       return alltoallv__ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
445                                           rbuf, rcounts, rdisps, rdtype,
446                                           comm);
447   } else     /* For starters, just keep the original algorithm. */
448   return alltoallv__ring(sbuf, scounts, sdisps, sdtype,
449                          rbuf, rcounts, rdisps, rdtype,
450                          comm);
451 }
452
453
454 int barrier__mvapich2(MPI_Comm  comm)
455 {
456   return barrier__mvapich2_pair(comm);
457 }
458
459
460
461
462 int bcast__mvapich2(void *buffer,
463                     int count,
464                     MPI_Datatype datatype,
465                     int root, MPI_Comm comm)
466 {
467     int mpi_errno = MPI_SUCCESS;
468     int comm_size/*, rank*/;
469     bool two_level_bcast      = true;
470     long nbytes = 0;
471     int range = 0;
472     int range_threshold = 0;
473     int range_threshold_intra = 0;
474     MPI_Aint type_size;
475     //, position;
476     // unsigned char *tmp_buf = NULL;
477     MPI_Comm shmem_comm;
478     //MPID_Datatype *dtp;
479
480     if (count == 0)
481         return MPI_SUCCESS;
482     if(comm->get_leaders_comm()==MPI_COMM_NULL){
483       comm->init_smp();
484     }
485     if (not mv2_bcast_thresholds_table)
486       init_mv2_bcast_tables_stampede();
487     comm_size = comm->size();
488     //rank = comm->rank();
489
490     // bool is_contig = true;
491 /*    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
492 /*        is_contig = true;*/
493 /*    else {*/
494 /*        MPID_Datatype_get_ptr(datatype, dtp);*/
495 /*        is_contig = dtp->is_contig;*/
496 /*    }*/
497
498     // bool is_homogeneous = true;
499
500     /* MPI_Type_size() might not give the accurate size of the packed
501      * datatype for heterogeneous systems (because of padding, encoding,
502      * etc). On the other hand, MPI_Pack_size() can become very
503      * expensive, depending on the implementation, especially for
504      * heterogeneous systems. We want to use MPI_Type_size() wherever
505      * possible, and MPI_Pack_size() in other places.
506      */
507     //if (is_homogeneous) {
508         type_size=datatype->size();
509
510    /* } else {
511         MPIR_Pack_size_impl(1, datatype, &type_size);
512     }*/
513     nbytes =  (count) * (type_size);
514
515     /* Search for the corresponding system size inside the tuning table */
516     while ((range < (mv2_size_bcast_tuning_table - 1)) &&
517            (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
518         range++;
519     }
520     /* Search for corresponding inter-leader function */
521     while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
522            && (nbytes >
523                mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
524            && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
525         range_threshold++;
526     }
527
528     /* Search for corresponding intra-node function */
529     while ((range_threshold_intra <
530             (mv2_bcast_thresholds_table[range].size_intra_table - 1))
531            && (nbytes >
532                mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
533            && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
534                -1)) {
535         range_threshold_intra++;
536     }
537
538     MV2_Bcast_function =
539         mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
540         MV2_pt_Bcast_function;
541
542     MV2_Bcast_intra_node_function =
543         mv2_bcast_thresholds_table[range].
544         intra_node[range_threshold_intra].MV2_pt_Bcast_function;
545
546 /*    if (mv2_user_bcast_intra == NULL && */
547 /*            MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
548 /*            MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
549 /*    }*/
550
551     if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
552         zcpy_pipelined_knomial_factor != -1) {
553         zcpy_knomial_factor =
554             mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
555             zcpy_pipelined_knomial_factor;
556     }
557
558     if (mv2_pipelined_zcpy_knomial_factor != -1) {
559         zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
560     }
561
562     if (MV2_Bcast_intra_node_function == nullptr) {
563       /* if tuning table do not have any intra selection, set func pointer to
564       ** default one for mcast intra node */
565       MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
566     }
567
568     /* Set value of pipeline segment size */
569     bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
570
571     /* Set value of inter node knomial factor */
572     mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
573
574     /* Set value of intra node knomial factor */
575     mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
576
577     /* Check if we will use a two level algorithm or not */
578     two_level_bcast =
579 #if defined(_MCST_SUPPORT_)
580         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold]
581         || comm->ch.is_mcast_ok;
582 #else
583         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
584 #endif
585     if (two_level_bcast) {
586        // if (not is_contig || not is_homogeneous) {
587 //   tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
588
589 /*            position = 0;*/
590 /*            if (rank == root) {*/
591 /*                mpi_errno =*/
592 /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
593 /*                if (mpi_errno)*/
594 /*                    MPIU_ERR_POP(mpi_errno);*/
595 /*            }*/
596 // }
597 #ifdef CHANNEL_MRAIL_GEN2
598         if ((mv2_enable_zcpy_bcast == 1) &&
599               (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {
600           // if (not is_contig || not is_homogeneous) {
601           //   mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
602           // } else {
603                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
604                                                  root, comm);
605           // }
606         } else
607 #endif /* defined(CHANNEL_MRAIL_GEN2) */
608         {
609             shmem_comm = comm->get_intra_comm();
610             // if (not is_contig || not is_homogeneous) {
611             //   MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
612             // } else {
613               MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root, comm);
614             // }
615
616             /* We are now done with the inter-node phase */
617
618
619                     root = INTRA_NODE_ROOT;
620
621                     // if (not is_contig || not is_homogeneous) {
622                     //       mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
623                     // } else {
624                     mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
625                                                               datatype, root, shmem_comm);
626
627                     // }
628         }
629         /*        if (not is_contig || not is_homogeneous) {*/
630         /*            if (rank != root) {*/
631         /*                position = 0;*/
632         /*                mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
633         /*                                             count, datatype);*/
634         /*            }*/
635         /*        }*/
636     } else {
637         /* We use Knomial for intra node */
638         MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
639 /*        if (mv2_enable_shmem_bcast == 0) {*/
640             /* Fall back to non-tuned version */
641 /*            MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
642 /*        } else {*/
643             mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
644                                            comm);
645
646 /*        }*/
647     }
648
649
650     return mpi_errno;
651
652 }
653
654
655
656 int reduce__mvapich2(const void *sendbuf,
657     void *recvbuf,
658     int count,
659     MPI_Datatype datatype,
660     MPI_Op op, int root, MPI_Comm comm)
661 {
662   if (mv2_reduce_thresholds_table == nullptr)
663     init_mv2_reduce_tables_stampede();
664
665   int mpi_errno = MPI_SUCCESS;
666   int range = 0;
667   int range_threshold = 0;
668   int range_intra_threshold = 0;
669   int pof2;
670   int comm_size = 0;
671   long nbytes = 0;
672   int sendtype_size;
673   bool is_two_level = false;
674
675   comm_size = comm->size();
676   sendtype_size=datatype->size();
677   nbytes = count * sendtype_size;
678
679   if (count == 0)
680     return MPI_SUCCESS;
681
682   bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
683
684   /* find nearest power-of-two less than or equal to comm_size */
685   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
686   pof2 >>=1;
687
688
689   /* Search for the corresponding system size inside the tuning table */
690   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
691       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
692       range++;
693   }
694   /* Search for corresponding inter-leader function */
695   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
696       && (nbytes >
697   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
698   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
699       -1)) {
700       range_threshold++;
701   }
702
703   /* Search for corresponding intra node function */
704   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
705       && (nbytes >
706   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
707   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
708       -1)) {
709       range_intra_threshold++;
710   }
711
712   /* Set intra-node function pt for reduce_two_level */
713   MV2_Reduce_intra_function =
714       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
715       MV2_pt_Reduce_function;
716   /* Set inter-leader pt */
717   MV2_Reduce_function =
718       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
719       MV2_pt_Reduce_function;
720
721   if(mv2_reduce_intra_knomial_factor<0)
722     {
723       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
724     }
725   if(mv2_reduce_inter_knomial_factor<0)
726     {
727       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
728     }
729   if (mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold]) {
730     is_two_level = true;
731   }
732   /* We call Reduce function */
733   if (is_two_level) {
734     if (is_commutative) {
735          if(comm->get_leaders_comm()==MPI_COMM_NULL){
736            comm->init_smp();
737          }
738          mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count,
739                                            datatype, op, root, comm);
740     } else {
741       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
742           datatype, op, root, comm);
743     }
744     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
745         if (is_commutative)
746           {
747             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
748                 datatype, op, root, comm);
749           } else {
750               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
751                   datatype, op, root, comm);
752           }
753     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
754         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
755           {
756             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
757                 datatype, op, root, comm);
758           } else {
759               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
760                   datatype, op, root, comm);
761           }
762     } else {
763         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
764             datatype, op, root, comm);
765     }
766
767
768   return mpi_errno;
769
770 }
771
772
773 int reduce_scatter__mvapich2(const void *sendbuf, void *recvbuf, const int *recvcnts,
774     MPI_Datatype datatype, MPI_Op op,
775     MPI_Comm comm)
776 {
777   int mpi_errno = MPI_SUCCESS;
778   int i = 0, comm_size = comm->size(), total_count = 0, type_size =
779       0, nbytes = 0;
780   int* disps          = new int[comm_size];
781
782   if (mv2_red_scat_thresholds_table == nullptr)
783     init_mv2_reduce_scatter_tables_stampede();
784
785   bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
786   for (i = 0; i < comm_size; i++) {
787       disps[i] = total_count;
788       total_count += recvcnts[i];
789   }
790
791   type_size=datatype->size();
792   nbytes = total_count * type_size;
793
794   if (is_commutative) {
795     int range           = 0;
796     int range_threshold = 0;
797
798       /* Search for the corresponding system size inside the tuning table */
799       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
800           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
801           range++;
802       }
803       /* Search for corresponding inter-leader function */
804       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
805           && (nbytes >
806       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
807       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
808           -1)) {
809           range_threshold++;
810       }
811
812       /* Set inter-leader pt */
813       MV2_Red_scat_function =
814           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
815           MV2_pt_Red_scat_function;
816
817       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
818           recvcnts, datatype,
819           op, comm);
820   } else {
821       bool is_block_regular = true;
822       for (i = 0; i < (comm_size - 1); ++i) {
823           if (recvcnts[i] != recvcnts[i+1]) {
824               is_block_regular = false;
825               break;
826           }
827       }
828       int pof2 = 1;
829       while (pof2 < comm_size) pof2 <<= 1;
830       if (pof2 == comm_size && is_block_regular) {
831           /* noncommutative, pof2 size, and block regular */
832           MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
833               recvcnts, datatype,
834               op, comm);
835       }
836       mpi_errno =  reduce_scatter__mpich_rdb(sendbuf, recvbuf,
837                                              recvcnts, datatype,
838                                              op, comm);
839   }
840   delete[] disps;
841   return mpi_errno;
842
843 }
844
845
846
847 int scatter__mvapich2(const void *sendbuf,
848     int sendcnt,
849     MPI_Datatype sendtype,
850     void *recvbuf,
851     int recvcnt,
852     MPI_Datatype recvtype,
853     int root, MPI_Comm comm)
854 {
855   int range = 0, range_threshold = 0, range_threshold_intra = 0;
856   int mpi_errno = MPI_SUCCESS;
857   //   int mpi_errno_ret = MPI_SUCCESS;
858   int rank, nbytes, comm_size;
859   bool partial_sub_ok = false;
860   int conf_index = 0;
861      MPI_Comm shmem_comm;
862   //    MPID_Comm *shmem_commptr=NULL;
863      if (mv2_scatter_thresholds_table == nullptr)
864        init_mv2_scatter_tables_stampede();
865
866      if (comm->get_leaders_comm() == MPI_COMM_NULL) {
867        comm->init_smp();
868      }
869
870   comm_size = comm->size();
871
872   rank = comm->rank();
873
874   if (rank == root) {
875     int sendtype_size = sendtype->size();
876     nbytes            = sendcnt * sendtype_size;
877   } else {
878     int recvtype_size = recvtype->size();
879     nbytes            = recvcnt * recvtype_size;
880   }
881
882     // check if safe to use partial subscription mode
883     if (comm->is_uniform()) {
884
885         shmem_comm = comm->get_intra_comm();
886         if (mv2_scatter_table_ppn_conf[0] == -1) {
887             // Indicating user defined tuning
888             conf_index = 0;
889         }else{
890           int local_size = shmem_comm->size();
891           int i          = 0;
892             do {
893                 if (local_size == mv2_scatter_table_ppn_conf[i]) {
894                     conf_index = i;
895                     partial_sub_ok = true;
896                     break;
897                 }
898                 i++;
899             } while(i < mv2_scatter_num_ppn_conf);
900         }
901     }
902
903   if (not partial_sub_ok) {
904       conf_index = 0;
905   }
906
907   /* Search for the corresponding system size inside the tuning table */
908   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
909       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
910       range++;
911   }
912   /* Search for corresponding inter-leader function */
913   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
914       && (nbytes >
915   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
916   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
917       range_threshold++;
918   }
919
920   /* Search for corresponding intra-node function */
921   while ((range_threshold_intra <
922       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
923       && (nbytes >
924   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
925   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
926       -1)) {
927       range_threshold_intra++;
928   }
929
930   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
931                                                                                       .MV2_pt_Scatter_function;
932
933   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
934 #if defined(_MCST_SUPPORT_)
935       if(comm->ch.is_mcast_ok == 1
936           && mv2_use_mcast_scatter == 1
937           && comm->ch.shmem_coll_ok == 1) {
938           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
939       } else
940 #endif /*#if defined(_MCST_SUPPORT_) */
941         {
942         if (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function !=
943             nullptr) {
944           MV2_Scatter_function =
945               mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].MV2_pt_Scatter_function;
946         } else {
947           /* Fallback! */
948           MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
949         }
950         }
951   }
952
953   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
954       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
955        if( comm->is_blocked()) {
956              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
957                                 .MV2_pt_Scatter_function;
958
959              mpi_errno =
960                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
961                                         recvbuf, recvcnt, recvtype, root,
962                                         comm);
963          } else {
964       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
965           recvbuf, recvcnt, recvtype, root,
966           comm);
967
968       }
969   } else {
970       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
971           recvbuf, recvcnt, recvtype, root,
972           comm);
973   }
974   return (mpi_errno);
975 }
976
977 }
978 }
979
980 void smpi_coll_cleanup_mvapich2()
981 {
982   if (mv2_alltoall_thresholds_table)
983     delete[] mv2_alltoall_thresholds_table[0];
984   delete[] mv2_alltoall_thresholds_table;
985   delete[] mv2_size_alltoall_tuning_table;
986   delete[] mv2_alltoall_table_ppn_conf;
987
988   delete[] mv2_gather_thresholds_table;
989   if (mv2_allgather_thresholds_table)
990     delete[] mv2_allgather_thresholds_table[0];
991   delete[] mv2_size_allgather_tuning_table;
992   delete[] mv2_allgather_table_ppn_conf;
993   delete[] mv2_allgather_thresholds_table;
994
995   delete[] mv2_allgatherv_thresholds_table;
996   delete[] mv2_reduce_thresholds_table;
997   delete[] mv2_red_scat_thresholds_table;
998   delete[] mv2_allreduce_thresholds_table;
999   delete[] mv2_bcast_thresholds_table;
1000   if (mv2_scatter_thresholds_table)
1001     delete[] mv2_scatter_thresholds_table[0];
1002   delete[] mv2_scatter_thresholds_table;
1003   delete[] mv2_size_scatter_tuning_table;
1004   delete[] mv2_scatter_table_ppn_conf;
1005 }