Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Actor: make the refcount observable, and improve debug messages
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.cpp
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2019. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.hpp"
10
11 #include "smpi_mvapich2_selector_stampede.hpp"
12
13 namespace simgrid{
14 namespace smpi{
15
16
17 int Coll_alltoall_mvapich2::alltoall( const void *sendbuf, int sendcount,
18     MPI_Datatype sendtype,
19     void* recvbuf, int recvcount,
20     MPI_Datatype recvtype,
21     MPI_Comm comm)
22 {
23
24   if(mv2_alltoall_table_ppn_conf==NULL)
25     init_mv2_alltoall_tables_stampede();
26
27   int sendtype_size, recvtype_size, comm_size;
28   int mpi_errno=MPI_SUCCESS;
29   int range = 0;
30   int range_threshold = 0;
31   int conf_index = 0;
32   comm_size =  comm->size();
33
34   sendtype_size=sendtype->size();
35   recvtype_size=recvtype->size();
36   long nbytes = sendtype_size * sendcount;
37
38   /* check if safe to use partial subscription mode */
39
40   /* Search for the corresponding system size inside the tuning table */
41   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
42       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
43       range++;
44   }
45   /* Search for corresponding inter-leader function */
46   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
47       && (nbytes >
48   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
49   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
50       range_threshold++;
51   }
52   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
53                                                                                       .MV2_pt_Alltoall_function;
54
55   if(sendbuf != MPI_IN_PLACE) {
56       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
57           recvbuf, recvcount, recvtype,
58           comm);
59   } else {
60       range_threshold = 0;
61       if(nbytes <
62           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
63           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
64       ) {
65         unsigned char* tmp_buf = smpi_get_tmp_sendbuffer(comm_size * recvcount * recvtype_size);
66         Datatype::copy(recvbuf, comm_size * recvcount, recvtype, tmp_buf, comm_size * recvcount, recvtype);
67
68         mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype, recvbuf, recvcount, recvtype, comm);
69         smpi_free_tmp_buffer(tmp_buf);
70       } else {
71           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
72               recvbuf, recvcount, recvtype,
73               comm );
74       }
75   }
76
77
78   return (mpi_errno);
79 }
80
81 int Coll_allgather_mvapich2::allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
82     void *recvbuf, int recvcount, MPI_Datatype recvtype,
83     MPI_Comm comm)
84 {
85
86   int mpi_errno = MPI_SUCCESS;
87   long nbytes = 0, comm_size, recvtype_size;
88   int range = 0;
89   int partial_sub_ok = 0;
90   int conf_index = 0;
91   int range_threshold = 0;
92   int is_two_level = 0;
93   MPI_Comm shmem_comm;
94   //MPI_Comm *shmem_commptr=NULL;
95   /* Get the size of the communicator */
96   comm_size = comm->size();
97   recvtype_size=recvtype->size();
98   nbytes = recvtype_size * recvcount;
99
100   if(mv2_allgather_table_ppn_conf==NULL)
101     init_mv2_allgather_tables_stampede();
102
103   if(comm->get_leaders_comm()==MPI_COMM_NULL){
104     comm->init_smp();
105   }
106
107   if (comm->is_uniform()){
108     shmem_comm = comm->get_intra_comm();
109     int local_size = shmem_comm->size();
110     int i          = 0;
111     if (mv2_allgather_table_ppn_conf[0] == -1) {
112       // Indicating user defined tuning
113       conf_index = 0;
114       goto conf_check_end;
115     }
116     do {
117       if (local_size == mv2_allgather_table_ppn_conf[i]) {
118         conf_index = i;
119         partial_sub_ok = 1;
120         break;
121       }
122       i++;
123     } while(i < mv2_allgather_num_ppn_conf);
124   }
125   conf_check_end:
126   if (partial_sub_ok != 1) {
127     conf_index = 0;
128   }
129
130   /* Search for the corresponding system size inside the tuning table */
131   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
132       (comm_size >
133   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
134       range++;
135   }
136   /* Search for corresponding inter-leader function */
137   while ((range_threshold <
138       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
139       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
140       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
141           -1)) {
142       range_threshold++;
143   }
144
145   /* Set inter-leader pt */
146   MV2_Allgatherction =
147       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
148       MV2_pt_Allgatherction;
149
150   is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
151
152   /* intracommunicator */
153   if(is_two_level ==1){
154     if(partial_sub_ok ==1){
155       if (comm->is_blocked()){
156       mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
157                             recvbuf, recvcount, recvtype,
158                             comm);
159       }else{
160       mpi_errno = Coll_allgather_mpich::allgather(sendbuf, sendcount, sendtype,
161                             recvbuf, recvcount, recvtype,
162                             comm);
163       }
164     } else {
165       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
166           recvbuf, recvcount, recvtype,
167           comm);
168     }
169   } else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
170       || MV2_Allgatherction == &MPIR_Allgather_RD_MV2
171       || MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
172       mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
173           recvbuf, recvcount, recvtype,
174           comm);
175   }else{
176       return MPI_ERR_OTHER;
177   }
178
179   return mpi_errno;
180 }
181
182 int Coll_gather_mvapich2::gather(const void *sendbuf,
183     int sendcnt,
184     MPI_Datatype sendtype,
185     void *recvbuf,
186     int recvcnt,
187     MPI_Datatype recvtype,
188     int root, MPI_Comm  comm)
189 {
190   if(mv2_gather_thresholds_table==NULL)
191     init_mv2_gather_tables_stampede();
192
193   int mpi_errno = MPI_SUCCESS;
194   int range = 0;
195   int range_threshold = 0;
196   int range_intra_threshold = 0;
197   long nbytes = 0;
198   int comm_size = comm->size();
199   int rank      = comm->rank();
200
201   if (rank == root) {
202     int recvtype_size = recvtype->size();
203     nbytes            = recvcnt * recvtype_size;
204   } else {
205     int sendtype_size = sendtype->size();
206     nbytes            = sendcnt * sendtype_size;
207   }
208
209   /* Search for the corresponding system size inside the tuning table */
210   while ((range < (mv2_size_gather_tuning_table - 1)) &&
211       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
212       range++;
213   }
214   /* Search for corresponding inter-leader function */
215   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
216       && (nbytes >
217   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
218   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
219       -1)) {
220       range_threshold++;
221   }
222
223   /* Search for corresponding intra node function */
224   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
225       && (nbytes >
226   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
227   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
228       -1)) {
229       range_intra_threshold++;
230   }
231
232     if (comm->is_blocked() ) {
233         // Set intra-node function pt for gather_two_level
234         MV2_Gather_intra_node_function =
235                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
236                               MV2_pt_Gather_function;
237         //Set inter-leader pt
238         MV2_Gather_inter_leader_function =
239                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
240                               MV2_pt_Gather_function;
241         // We call Gather function
242         mpi_errno =
243             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
244                                              recvtype, root, comm);
245
246     } else {
247   // Indeed, direct (non SMP-aware)gather is MPICH one
248   mpi_errno = Coll_gather_mpich::gather(sendbuf, sendcnt, sendtype,
249       recvbuf, recvcnt, recvtype,
250       root, comm);
251   }
252
253   return mpi_errno;
254 }
255
256 int Coll_allgatherv_mvapich2::allgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
257     void *recvbuf, const int *recvcounts, const int *displs,
258     MPI_Datatype recvtype, MPI_Comm  comm )
259 {
260   int mpi_errno = MPI_SUCCESS;
261   int range = 0, comm_size, total_count, recvtype_size, i;
262   int range_threshold = 0;
263   long nbytes = 0;
264
265   if(mv2_allgatherv_thresholds_table==NULL)
266     init_mv2_allgatherv_tables_stampede();
267
268   comm_size = comm->size();
269   total_count = 0;
270   for (i = 0; i < comm_size; i++)
271     total_count += recvcounts[i];
272
273   recvtype_size=recvtype->size();
274   nbytes = total_count * recvtype_size;
275
276   /* Search for the corresponding system size inside the tuning table */
277   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
278       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
279       range++;
280   }
281   /* Search for corresponding inter-leader function */
282   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
283       && (nbytes >
284   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
285   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
286       -1)) {
287       range_threshold++;
288   }
289   /* Set inter-leader pt */
290   MV2_Allgatherv_function =
291       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
292       MV2_pt_Allgatherv_function;
293
294   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
295     {
296     if (not(comm_size & (comm_size - 1))) {
297       mpi_errno =
298           MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
299         } else {
300             mpi_errno =
301                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
302                     sendtype, recvbuf,
303                     recvcounts, displs,
304                     recvtype, comm);
305         }
306     } else {
307         mpi_errno =
308             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
309                 recvbuf, recvcounts, displs,
310                 recvtype, comm);
311     }
312
313   return mpi_errno;
314 }
315
316
317
318 int Coll_allreduce_mvapich2::allreduce(const void *sendbuf,
319     void *recvbuf,
320     int count,
321     MPI_Datatype datatype,
322     MPI_Op op, MPI_Comm comm)
323 {
324
325   int mpi_errno = MPI_SUCCESS;
326   //int rank = 0,
327   int comm_size = 0;
328
329   comm_size = comm->size();
330   //rank = comm->rank();
331
332   if (count == 0) {
333       return MPI_SUCCESS;
334   }
335
336   if (mv2_allreduce_thresholds_table == NULL)
337     init_mv2_allreduce_tables_stampede();
338
339   /* check if multiple threads are calling this collective function */
340
341   MPI_Aint sendtype_size = 0;
342   long nbytes = 0;
343   int is_commutative = 0;
344   MPI_Aint true_lb, true_extent;
345
346   sendtype_size=datatype->size();
347   nbytes = count * sendtype_size;
348
349   datatype->extent(&true_lb, &true_extent);
350   is_commutative = op->is_commutative();
351
352   {
353     int range = 0, range_threshold = 0, range_threshold_intra = 0;
354     int is_two_level = 0;
355
356     /* Search for the corresponding system size inside the tuning table */
357     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
358         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
359         range++;
360     }
361     /* Search for corresponding inter-leader function */
362     /* skip mcast poiters if mcast is not available */
363     if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
364         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
365             && ((mv2_allreduce_thresholds_table[range].
366                 inter_leader[range_threshold].MV2_pt_Allreducection
367                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
368                 (mv2_allreduce_thresholds_table[range].
369                     inter_leader[range_threshold].MV2_pt_Allreducection
370                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
371             )) {
372             range_threshold++;
373         }
374     }
375     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
376         && (nbytes >
377     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
378     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
379         range_threshold++;
380     }
381     if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
382         is_two_level = 1;
383     }
384     /* Search for corresponding intra-node function */
385     while ((range_threshold_intra <
386         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
387         && (nbytes >
388     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
389     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
390         -1)) {
391         range_threshold_intra++;
392     }
393
394     MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
395                                                                                 .MV2_pt_Allreducection;
396
397     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
398                                                                                     .MV2_pt_Allreducection;
399
400     /* check if mcast is ready, otherwise replace mcast with other algorithm */
401     if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
402         (MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
403         {
404           MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
405         }
406         if(is_two_level != 1) {
407             MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
408         }
409     }
410
411     if(is_two_level == 1){
412         // check if shm is ready, if not use other algorithm first
413         if (is_commutative) {
414           if(comm->get_leaders_comm()==MPI_COMM_NULL){
415             comm->init_smp();
416           }
417           mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
418                                                      datatype, op, comm);
419                 } else {
420         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
421             datatype, op, comm);
422         }
423     } else {
424         mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
425             datatype, op, comm);
426     }
427   }
428
429   //comm->ch.intra_node_done=0;
430
431   return (mpi_errno);
432
433
434 }
435
436
437 int Coll_alltoallv_mvapich2::alltoallv(const void *sbuf, const int *scounts, const int *sdisps,
438     MPI_Datatype sdtype,
439     void *rbuf, const int *rcounts, const int *rdisps,
440     MPI_Datatype rdtype,
441     MPI_Comm  comm
442 )
443 {
444
445   if (sbuf == MPI_IN_PLACE) {
446       return Coll_alltoallv_ompi_basic_linear::alltoallv(sbuf, scounts, sdisps, sdtype,
447           rbuf, rcounts, rdisps,rdtype,
448           comm);
449   } else     /* For starters, just keep the original algorithm. */
450   return Coll_alltoallv_ring::alltoallv(sbuf, scounts, sdisps, sdtype,
451       rbuf, rcounts, rdisps,rdtype,
452       comm);
453 }
454
455
456 int Coll_barrier_mvapich2::barrier(MPI_Comm  comm)
457 {
458   return Coll_barrier_mvapich2_pair::barrier(comm);
459 }
460
461
462
463
464 int Coll_bcast_mvapich2::bcast(void *buffer,
465     int count,
466     MPI_Datatype datatype,
467     int root, MPI_Comm comm)
468 {
469     int mpi_errno = MPI_SUCCESS;
470     int comm_size/*, rank*/;
471     int two_level_bcast = 1;
472     long nbytes = 0;
473     int range = 0;
474     int range_threshold = 0;
475     int range_threshold_intra = 0;
476     // int is_homogeneous, is_contig;
477     MPI_Aint type_size;
478     //, position;
479     // unsigned char *tmp_buf = NULL;
480     MPI_Comm shmem_comm;
481     //MPID_Datatype *dtp;
482
483     if (count == 0)
484         return MPI_SUCCESS;
485     if(comm->get_leaders_comm()==MPI_COMM_NULL){
486       comm->init_smp();
487     }
488     if (not mv2_bcast_thresholds_table)
489       init_mv2_bcast_tables_stampede();
490     comm_size = comm->size();
491     //rank = comm->rank();
492
493     //is_contig=1;
494 /*    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
495 /*        is_contig = 1;*/
496 /*    else {*/
497 /*        MPID_Datatype_get_ptr(datatype, dtp);*/
498 /*        is_contig = dtp->is_contig;*/
499 /*    }*/
500
501     // is_homogeneous = 1;
502
503     /* MPI_Type_size() might not give the accurate size of the packed
504      * datatype for heterogeneous systems (because of padding, encoding,
505      * etc). On the other hand, MPI_Pack_size() can become very
506      * expensive, depending on the implementation, especially for
507      * heterogeneous systems. We want to use MPI_Type_size() wherever
508      * possible, and MPI_Pack_size() in other places.
509      */
510     //if (is_homogeneous) {
511         type_size=datatype->size();
512
513    /* } else {
514         MPIR_Pack_size_impl(1, datatype, &type_size);
515     }*/
516     nbytes =  (count) * (type_size);
517
518     /* Search for the corresponding system size inside the tuning table */
519     while ((range < (mv2_size_bcast_tuning_table - 1)) &&
520            (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
521         range++;
522     }
523     /* Search for corresponding inter-leader function */
524     while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
525            && (nbytes >
526                mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
527            && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
528         range_threshold++;
529     }
530
531     /* Search for corresponding intra-node function */
532     while ((range_threshold_intra <
533             (mv2_bcast_thresholds_table[range].size_intra_table - 1))
534            && (nbytes >
535                mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
536            && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
537                -1)) {
538         range_threshold_intra++;
539     }
540
541     MV2_Bcast_function =
542         mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
543         MV2_pt_Bcast_function;
544
545     MV2_Bcast_intra_node_function =
546         mv2_bcast_thresholds_table[range].
547         intra_node[range_threshold_intra].MV2_pt_Bcast_function;
548
549 /*    if (mv2_user_bcast_intra == NULL && */
550 /*            MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
551 /*            MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
552 /*    }*/
553
554     if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
555         zcpy_pipelined_knomial_factor != -1) {
556         zcpy_knomial_factor =
557             mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
558             zcpy_pipelined_knomial_factor;
559     }
560
561     if (mv2_pipelined_zcpy_knomial_factor != -1) {
562         zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
563     }
564
565     if(MV2_Bcast_intra_node_function == NULL) {
566         /* if tuning table do not have any intra selection, set func pointer to
567         ** default one for mcast intra node */
568         MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
569     }
570
571     /* Set value of pipeline segment size */
572     bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
573
574     /* Set value of inter node knomial factor */
575     mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
576
577     /* Set value of intra node knomial factor */
578     mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
579
580     /* Check if we will use a two level algorithm or not */
581     two_level_bcast =
582 #if defined(_MCST_SUPPORT_)
583         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold]
584         || comm->ch.is_mcast_ok;
585 #else
586         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
587 #endif
588      if (two_level_bcast == 1) {
589        // if (not is_contig || not is_homogeneous) {
590 //   tmp_buf = smpi_get_tmp_sendbuffer(nbytes);
591
592 /*            position = 0;*/
593 /*            if (rank == root) {*/
594 /*                mpi_errno =*/
595 /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
596 /*                if (mpi_errno)*/
597 /*                    MPIU_ERR_POP(mpi_errno);*/
598 /*            }*/
599 // }
600 #ifdef CHANNEL_MRAIL_GEN2
601         if ((mv2_enable_zcpy_bcast == 1) &&
602               (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {
603           // if (not is_contig || not is_homogeneous) {
604           //   mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
605           // } else {
606                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
607                                                  root, comm);
608           // }
609         } else
610 #endif /* defined(CHANNEL_MRAIL_GEN2) */
611         {
612             shmem_comm = comm->get_intra_comm();
613             // if (not is_contig || not is_homogeneous) {
614             //   MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE, root, comm);
615             // } else {
616               MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root, comm);
617             // }
618
619             /* We are now done with the inter-node phase */
620
621
622                     root = INTRA_NODE_ROOT;
623
624                     // if (not is_contig || not is_homogeneous) {
625                     //       mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes, MPI_BYTE, root, shmem_comm);
626                     // } else {
627                     mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
628                                                               datatype, root, shmem_comm);
629
630                     // }
631         }
632         /*        if (not is_contig || not is_homogeneous) {*/
633         /*            if (rank != root) {*/
634         /*                position = 0;*/
635         /*                mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
636         /*                                             count, datatype);*/
637         /*            }*/
638         /*        }*/
639     } else {
640         /* We use Knomial for intra node */
641         MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
642 /*        if (mv2_enable_shmem_bcast == 0) {*/
643             /* Fall back to non-tuned version */
644 /*            MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
645 /*        } else {*/
646             mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
647                                            comm);
648
649 /*        }*/
650     }
651
652
653     return mpi_errno;
654
655 }
656
657
658
659 int Coll_reduce_mvapich2::reduce(const void *sendbuf,
660     void *recvbuf,
661     int count,
662     MPI_Datatype datatype,
663     MPI_Op op, int root, MPI_Comm comm)
664 {
665   if(mv2_reduce_thresholds_table == NULL)
666     init_mv2_reduce_tables_stampede();
667
668   int mpi_errno = MPI_SUCCESS;
669   int range = 0;
670   int range_threshold = 0;
671   int range_intra_threshold = 0;
672   int is_commutative, pof2;
673   int comm_size = 0;
674   long nbytes = 0;
675   int sendtype_size;
676   int is_two_level = 0;
677
678   comm_size = comm->size();
679   sendtype_size=datatype->size();
680   nbytes = count * sendtype_size;
681
682   if (count == 0)
683     return MPI_SUCCESS;
684
685   is_commutative = (op==MPI_OP_NULL || op->is_commutative());
686
687   /* find nearest power-of-two less than or equal to comm_size */
688   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
689   pof2 >>=1;
690
691
692   /* Search for the corresponding system size inside the tuning table */
693   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
694       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
695       range++;
696   }
697   /* Search for corresponding inter-leader function */
698   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
699       && (nbytes >
700   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
701   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
702       -1)) {
703       range_threshold++;
704   }
705
706   /* Search for corresponding intra node function */
707   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
708       && (nbytes >
709   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
710   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
711       -1)) {
712       range_intra_threshold++;
713   }
714
715   /* Set intra-node function pt for reduce_two_level */
716   MV2_Reduce_intra_function =
717       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
718       MV2_pt_Reduce_function;
719   /* Set inter-leader pt */
720   MV2_Reduce_function =
721       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
722       MV2_pt_Reduce_function;
723
724   if(mv2_reduce_intra_knomial_factor<0)
725     {
726       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
727     }
728   if(mv2_reduce_inter_knomial_factor<0)
729     {
730       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
731     }
732   if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
733       is_two_level = 1;
734   }
735   /* We call Reduce function */
736   if(is_two_level == 1)
737     {
738        if (is_commutative == 1) {
739          if(comm->get_leaders_comm()==MPI_COMM_NULL){
740            comm->init_smp();
741          }
742          mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count,
743                                            datatype, op, root, comm);
744         } else {
745       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
746           datatype, op, root, comm);
747       }
748     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
749         if(is_commutative ==1)
750           {
751             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
752                 datatype, op, root, comm);
753           } else {
754               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
755                   datatype, op, root, comm);
756           }
757     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
758         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
759           {
760             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
761                 datatype, op, root, comm);
762           } else {
763               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
764                   datatype, op, root, comm);
765           }
766     } else {
767         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count,
768             datatype, op, root, comm);
769     }
770
771
772   return mpi_errno;
773
774 }
775
776
777 int Coll_reduce_scatter_mvapich2::reduce_scatter(const void *sendbuf, void *recvbuf, const int *recvcnts,
778     MPI_Datatype datatype, MPI_Op op,
779     MPI_Comm comm)
780 {
781   int mpi_errno = MPI_SUCCESS;
782   int i = 0, comm_size = comm->size(), total_count = 0, type_size =
783       0, nbytes = 0;
784   int is_commutative = 0;
785   int* disps          = new int[comm_size];
786
787   if(mv2_red_scat_thresholds_table==NULL)
788     init_mv2_reduce_scatter_tables_stampede();
789
790   is_commutative=(op==MPI_OP_NULL || op->is_commutative());
791   for (i = 0; i < comm_size; i++) {
792       disps[i] = total_count;
793       total_count += recvcnts[i];
794   }
795
796   type_size=datatype->size();
797   nbytes = total_count * type_size;
798
799   if (is_commutative) {
800     int range           = 0;
801     int range_threshold = 0;
802
803       /* Search for the corresponding system size inside the tuning table */
804       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
805           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
806           range++;
807       }
808       /* Search for corresponding inter-leader function */
809       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
810           && (nbytes >
811       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
812       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
813           -1)) {
814           range_threshold++;
815       }
816
817       /* Set inter-leader pt */
818       MV2_Red_scat_function =
819           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
820           MV2_pt_Red_scat_function;
821
822       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
823           recvcnts, datatype,
824           op, comm);
825   } else {
826       int is_block_regular = 1;
827       for (i = 0; i < (comm_size - 1); ++i) {
828           if (recvcnts[i] != recvcnts[i+1]) {
829               is_block_regular = 0;
830               break;
831           }
832       }
833       int pof2 = 1;
834       while (pof2 < comm_size) pof2 <<= 1;
835       if (pof2 == comm_size && is_block_regular) {
836           /* noncommutative, pof2 size, and block regular */
837           MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
838               recvcnts, datatype,
839               op, comm);
840       }
841       mpi_errno =  Coll_reduce_scatter_mpich_rdb::reduce_scatter(sendbuf, recvbuf,
842           recvcnts, datatype,
843           op, comm);
844   }
845   delete[] disps;
846   return mpi_errno;
847
848 }
849
850
851
852 int Coll_scatter_mvapich2::scatter(const void *sendbuf,
853     int sendcnt,
854     MPI_Datatype sendtype,
855     void *recvbuf,
856     int recvcnt,
857     MPI_Datatype recvtype,
858     int root, MPI_Comm comm)
859 {
860   int range = 0, range_threshold = 0, range_threshold_intra = 0;
861   int mpi_errno = MPI_SUCCESS;
862   //   int mpi_errno_ret = MPI_SUCCESS;
863   int rank, nbytes, comm_size;
864   int partial_sub_ok = 0;
865   int conf_index = 0;
866      MPI_Comm shmem_comm;
867   //    MPID_Comm *shmem_commptr=NULL;
868   if(mv2_scatter_thresholds_table==NULL)
869     init_mv2_scatter_tables_stampede();
870
871   if(comm->get_leaders_comm()==MPI_COMM_NULL){
872     comm->init_smp();
873   }
874
875   comm_size = comm->size();
876
877   rank = comm->rank();
878
879   if (rank == root) {
880     int sendtype_size = sendtype->size();
881     nbytes            = sendcnt * sendtype_size;
882   } else {
883     int recvtype_size = recvtype->size();
884     nbytes            = recvcnt * recvtype_size;
885   }
886
887     // check if safe to use partial subscription mode
888     if (comm->is_uniform()) {
889
890         shmem_comm = comm->get_intra_comm();
891         if (mv2_scatter_table_ppn_conf[0] == -1) {
892             // Indicating user defined tuning
893             conf_index = 0;
894         }else{
895           int local_size = shmem_comm->size();
896           int i          = 0;
897             do {
898                 if (local_size == mv2_scatter_table_ppn_conf[i]) {
899                     conf_index = i;
900                     partial_sub_ok = 1;
901                     break;
902                 }
903                 i++;
904             } while(i < mv2_scatter_num_ppn_conf);
905         }
906     }
907
908   if (partial_sub_ok != 1) {
909       conf_index = 0;
910   }
911
912   /* Search for the corresponding system size inside the tuning table */
913   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
914       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
915       range++;
916   }
917   /* Search for corresponding inter-leader function */
918   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
919       && (nbytes >
920   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
921   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
922       range_threshold++;
923   }
924
925   /* Search for corresponding intra-node function */
926   while ((range_threshold_intra <
927       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
928       && (nbytes >
929   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
930   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
931       -1)) {
932       range_threshold_intra++;
933   }
934
935   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
936                                                                                       .MV2_pt_Scatter_function;
937
938   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
939 #if defined(_MCST_SUPPORT_)
940       if(comm->ch.is_mcast_ok == 1
941           && mv2_use_mcast_scatter == 1
942           && comm->ch.shmem_coll_ok == 1) {
943           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
944       } else
945 #endif /*#if defined(_MCST_SUPPORT_) */
946         {
947           if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
948               MV2_pt_Scatter_function != NULL) {
949               MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
950                                                                                                   .MV2_pt_Scatter_function;
951           } else {
952               /* Fallback! */
953               MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
954           }
955         }
956   }
957
958   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
959       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
960        if( comm->is_blocked()) {
961              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
962                                 .MV2_pt_Scatter_function;
963
964              mpi_errno =
965                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
966                                         recvbuf, recvcnt, recvtype, root,
967                                         comm);
968          } else {
969       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
970           recvbuf, recvcnt, recvtype, root,
971           comm);
972
973       }
974   } else {
975       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
976           recvbuf, recvcnt, recvtype, root,
977           comm);
978   }
979   return (mpi_errno);
980 }
981
982 }
983 }
984
985 void smpi_coll_cleanup_mvapich2()
986 {
987   if (mv2_alltoall_thresholds_table)
988     delete[] mv2_alltoall_thresholds_table[0];
989   delete[] mv2_alltoall_thresholds_table;
990   delete[] mv2_size_alltoall_tuning_table;
991   delete[] mv2_alltoall_table_ppn_conf;
992
993   delete[] mv2_gather_thresholds_table;
994   if (mv2_allgather_thresholds_table)
995     delete[] mv2_allgather_thresholds_table[0];
996   delete[] mv2_size_allgather_tuning_table;
997   delete[] mv2_allgather_table_ppn_conf;
998   delete[] mv2_allgather_thresholds_table;
999
1000   delete[] mv2_allgatherv_thresholds_table;
1001   delete[] mv2_reduce_thresholds_table;
1002   delete[] mv2_red_scat_thresholds_table;
1003   delete[] mv2_allreduce_thresholds_table;
1004   delete[] mv2_bcast_thresholds_table;
1005   if (mv2_scatter_thresholds_table)
1006     delete[] mv2_scatter_thresholds_table[0];
1007   delete[] mv2_scatter_thresholds_table;
1008   delete[] mv2_size_scatter_tuning_table;
1009   delete[] mv2_scatter_table_ppn_conf;
1010 }