Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.cpp
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2010, 2013-2014. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.h"
10
11 #include "smpi_mvapich2_selector_stampede.h"
12
13
14
15 int smpi_coll_tuned_alltoall_mvapich2( void *sendbuf, int sendcount, 
16     MPI_Datatype sendtype,
17     void* recvbuf, int recvcount,
18     MPI_Datatype recvtype,
19     MPI_Comm comm)
20 {
21
22   if(mv2_alltoall_table_ppn_conf==NULL)
23     init_mv2_alltoall_tables_stampede();
24
25   int sendtype_size, recvtype_size, comm_size;
26   char * tmp_buf = NULL;
27   int mpi_errno=MPI_SUCCESS;
28   int range = 0;
29   int range_threshold = 0;
30   int conf_index = 0;
31   comm_size =  comm->size();
32
33   sendtype_size=smpi_datatype_size(sendtype);
34   recvtype_size=smpi_datatype_size(recvtype);
35   long nbytes = sendtype_size * sendcount;
36
37   /* check if safe to use partial subscription mode */
38
39   /* Search for the corresponding system size inside the tuning table */
40   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
41       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
42       range++;
43   }
44   /* Search for corresponding inter-leader function */
45   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
46       && (nbytes >
47   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
48   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
49       range_threshold++;
50   }
51   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
52                                                                                       .MV2_pt_Alltoall_function;
53
54   if(sendbuf != MPI_IN_PLACE) {
55       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
56           recvbuf, recvcount, recvtype,
57           comm);
58   } else {
59       range_threshold = 0;
60       if(nbytes <
61           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
62           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
63       ) {
64           tmp_buf = (char *)smpi_get_tmp_sendbuffer( comm_size * recvcount * recvtype_size );
65           mpi_errno = smpi_datatype_copy((char *)recvbuf,
66               comm_size*recvcount, recvtype,
67               (char *)tmp_buf,
68               comm_size*recvcount, recvtype);
69
70           mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype,
71               recvbuf, recvcount, recvtype,
72               comm );
73           smpi_free_tmp_buffer(tmp_buf);
74       } else {
75           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
76               recvbuf, recvcount, recvtype,
77               comm );
78       }
79   }
80
81
82   return (mpi_errno);
83 }
84
85 int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
86     void *recvbuf, int recvcount, MPI_Datatype recvtype,
87     MPI_Comm comm)
88 {
89
90   int mpi_errno = MPI_SUCCESS;
91   long nbytes = 0, comm_size, recvtype_size;
92   int range = 0;
93   int partial_sub_ok = 0;
94   int conf_index = 0;
95   int range_threshold = 0;
96   int is_two_level = 0;
97   int local_size = -1;
98   MPI_Comm shmem_comm;
99   //MPI_Comm *shmem_commptr=NULL;
100   /* Get the size of the communicator */
101   comm_size = comm->size();
102   recvtype_size=smpi_datatype_size(recvtype);
103   nbytes = recvtype_size * recvcount;
104
105   if(mv2_allgather_table_ppn_conf==NULL)
106     init_mv2_allgather_tables_stampede();
107     
108   if(comm->get_leaders_comm()==MPI_COMM_NULL){
109     comm->init_smp();
110   }
111
112   int i;
113   if (comm->is_uniform()){
114     shmem_comm = comm->get_intra_comm();
115     local_size = shmem_comm->size();
116     i = 0;
117     if (mv2_allgather_table_ppn_conf[0] == -1) {
118       // Indicating user defined tuning
119       conf_index = 0;
120       goto conf_check_end;
121     }
122     do {
123       if (local_size == mv2_allgather_table_ppn_conf[i]) {
124         conf_index = i;
125         partial_sub_ok = 1;
126         break;
127       }
128       i++;
129     } while(i < mv2_allgather_num_ppn_conf);
130   }
131   conf_check_end:
132   if (partial_sub_ok != 1) {
133     conf_index = 0;
134   }
135   
136   /* Search for the corresponding system size inside the tuning table */
137   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
138       (comm_size >
139   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
140       range++;
141   }
142   /* Search for corresponding inter-leader function */
143   while ((range_threshold <
144       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
145       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
146       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
147           -1)) {
148       range_threshold++;
149   }
150
151   /* Set inter-leader pt */
152   MV2_Allgather_function =
153       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
154       MV2_pt_Allgather_function;
155
156   is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
157
158   /* intracommunicator */
159   if(is_two_level ==1){
160     if(partial_sub_ok ==1){
161       if (comm->is_blocked()){
162       mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
163                             recvbuf, recvcount, recvtype,
164                             comm);
165       }else{
166       mpi_errno = smpi_coll_tuned_allgather_mpich(sendbuf, sendcount, sendtype,
167                             recvbuf, recvcount, recvtype,
168                             comm);
169       }
170     } else {
171       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
172           recvbuf, recvcount, recvtype,
173           comm);
174     }
175   } else if(MV2_Allgather_function == &MPIR_Allgather_Bruck_MV2
176       || MV2_Allgather_function == &MPIR_Allgather_RD_MV2
177       || MV2_Allgather_function == &MPIR_Allgather_Ring_MV2) {
178       mpi_errno = MV2_Allgather_function(sendbuf, sendcount, sendtype,
179           recvbuf, recvcount, recvtype,
180           comm);
181   }else{
182       return MPI_ERR_OTHER;
183   }
184
185   return mpi_errno;
186 }
187
188
189 int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
190     int sendcnt,
191     MPI_Datatype sendtype,
192     void *recvbuf,
193     int recvcnt,
194     MPI_Datatype recvtype,
195     int root, MPI_Comm  comm)
196 {
197   if(mv2_gather_thresholds_table==NULL)
198     init_mv2_gather_tables_stampede();
199
200   int mpi_errno = MPI_SUCCESS;
201   int range = 0;
202   int range_threshold = 0;
203   int range_intra_threshold = 0;
204   long nbytes = 0;
205   int comm_size = 0;
206   int recvtype_size, sendtype_size;
207   int rank = -1;
208   comm_size = comm->size();
209   rank = comm->rank();
210
211   if (rank == root) {
212       recvtype_size=smpi_datatype_size(recvtype);
213       nbytes = recvcnt * recvtype_size;
214   } else {
215       sendtype_size=smpi_datatype_size(sendtype);
216       nbytes = sendcnt * sendtype_size;
217   }
218
219   /* Search for the corresponding system size inside the tuning table */
220   while ((range < (mv2_size_gather_tuning_table - 1)) &&
221       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
222       range++;
223   }
224   /* Search for corresponding inter-leader function */
225   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
226       && (nbytes >
227   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
228   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
229       -1)) {
230       range_threshold++;
231   }
232
233   /* Search for corresponding intra node function */
234   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
235       && (nbytes >
236   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
237   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
238       -1)) {
239       range_intra_threshold++;
240   }
241   
242     if (comm->is_blocked() ) {
243         // Set intra-node function pt for gather_two_level 
244         MV2_Gather_intra_node_function = 
245                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
246                               MV2_pt_Gather_function;
247         //Set inter-leader pt 
248         MV2_Gather_inter_leader_function =
249                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
250                               MV2_pt_Gather_function;
251         // We call Gather function 
252         mpi_errno =
253             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
254                                              recvtype, root, comm);
255
256     } else {
257   // Indeed, direct (non SMP-aware)gather is MPICH one
258   mpi_errno = smpi_coll_tuned_gather_mpich(sendbuf, sendcnt, sendtype,
259       recvbuf, recvcnt, recvtype,
260       root, comm);
261   }
262
263   return mpi_errno;
264 }
265
266
267 int smpi_coll_tuned_allgatherv_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
268     void *recvbuf, int *recvcounts, int *displs,
269     MPI_Datatype recvtype, MPI_Comm  comm )
270 {
271   int mpi_errno = MPI_SUCCESS;
272   int range = 0, comm_size, total_count, recvtype_size, i;
273   int range_threshold = 0;
274   long nbytes = 0;
275
276   if(mv2_allgatherv_thresholds_table==NULL)
277     init_mv2_allgatherv_tables_stampede();
278
279   comm_size = comm->size();
280   total_count = 0;
281   for (i = 0; i < comm_size; i++)
282     total_count += recvcounts[i];
283
284   recvtype_size=smpi_datatype_size(recvtype);
285   nbytes = total_count * recvtype_size;
286
287   /* Search for the corresponding system size inside the tuning table */
288   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
289       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
290       range++;
291   }
292   /* Search for corresponding inter-leader function */
293   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
294       && (nbytes >
295   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
296   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
297       -1)) {
298       range_threshold++;
299   }
300   /* Set inter-leader pt */
301   MV2_Allgatherv_function =
302       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
303       MV2_pt_Allgatherv_function;
304
305   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
306     {
307       if(!(comm_size & (comm_size - 1)))
308         {
309           mpi_errno =
310               MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount,
311                   sendtype, recvbuf,
312                   recvcounts, displs,
313                   recvtype, comm);
314         } else {
315             mpi_errno =
316                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
317                     sendtype, recvbuf,
318                     recvcounts, displs,
319                     recvtype, comm);
320         }
321     } else {
322         mpi_errno =
323             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
324                 recvbuf, recvcounts, displs,
325                 recvtype, comm);
326     }
327
328   return mpi_errno;
329 }
330
331
332
333 int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
334     void *recvbuf,
335     int count,
336     MPI_Datatype datatype,
337     MPI_Op op, MPI_Comm comm)
338 {
339
340   int mpi_errno = MPI_SUCCESS;
341   //int rank = 0,
342   int comm_size = 0;
343
344   comm_size = comm->size();
345   //rank = comm->rank();
346
347   if (count == 0) {
348       return MPI_SUCCESS;
349   }
350
351   if (mv2_allreduce_thresholds_table == NULL)
352     init_mv2_allreduce_tables_stampede();
353
354   /* check if multiple threads are calling this collective function */
355
356   MPI_Aint sendtype_size = 0;
357   long nbytes = 0;
358   int range = 0, range_threshold = 0, range_threshold_intra = 0;
359   int is_two_level = 0;
360   int is_commutative = 0;
361   MPI_Aint true_lb, true_extent;
362
363   sendtype_size=smpi_datatype_size(datatype);
364   nbytes = count * sendtype_size;
365
366   smpi_datatype_extent(datatype, &true_lb, &true_extent);
367   //MPI_Op *op_ptr;
368   //is_commutative = op->is_commutative();
369
370   {
371     /* Search for the corresponding system size inside the tuning table */
372     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
373         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
374         range++;
375     }
376     /* Search for corresponding inter-leader function */
377     /* skip mcast poiters if mcast is not available */
378     if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
379         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
380             && ((mv2_allreduce_thresholds_table[range].
381                 inter_leader[range_threshold].MV2_pt_Allreduce_function
382                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
383                 (mv2_allreduce_thresholds_table[range].
384                     inter_leader[range_threshold].MV2_pt_Allreduce_function
385                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
386             )) {
387             range_threshold++;
388         }
389     }
390     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
391         && (nbytes >
392     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
393     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
394         range_threshold++;
395     }
396     if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
397         is_two_level = 1;
398     }
399     /* Search for corresponding intra-node function */
400     while ((range_threshold_intra <
401         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
402         && (nbytes >
403     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
404     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
405         -1)) {
406         range_threshold_intra++;
407     }
408
409     MV2_Allreduce_function = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
410                                                                                 .MV2_pt_Allreduce_function;
411
412     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
413                                                                                     .MV2_pt_Allreduce_function;
414
415     /* check if mcast is ready, otherwise replace mcast with other algorithm */
416     if((MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
417         (MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
418         {
419           MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
420         }
421         if(is_two_level != 1) {
422             MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
423         }
424     }
425
426     if(is_two_level == 1){
427         // check if shm is ready, if not use other algorithm first
428         if (is_commutative) {
429           if(comm->get_leaders_comm()==MPI_COMM_NULL){
430             comm->init_smp();
431           }
432           mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
433                                                      datatype, op, comm);
434                 } else {
435         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
436             datatype, op, comm);
437         }
438     } else {
439         mpi_errno = MV2_Allreduce_function(sendbuf, recvbuf, count,
440             datatype, op, comm);
441     }
442   }
443
444   //comm->ch.intra_node_done=0;
445
446   return (mpi_errno);
447
448
449 }
450
451
452 int smpi_coll_tuned_alltoallv_mvapich2(void *sbuf, int *scounts, int *sdisps,
453     MPI_Datatype sdtype,
454     void *rbuf, int *rcounts, int *rdisps,
455     MPI_Datatype rdtype,
456     MPI_Comm  comm
457 )
458 {
459
460   if (sbuf == MPI_IN_PLACE) {
461       return smpi_coll_tuned_alltoallv_ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
462           rbuf, rcounts, rdisps,rdtype,
463           comm);
464   } else     /* For starters, just keep the original algorithm. */
465   return smpi_coll_tuned_alltoallv_ring(sbuf, scounts, sdisps, sdtype,
466       rbuf, rcounts, rdisps,rdtype,
467       comm);
468 }
469
470
471 int smpi_coll_tuned_barrier_mvapich2(MPI_Comm  comm)
472 {   
473   return smpi_coll_tuned_barrier_mvapich2_pair(comm);
474 }
475
476
477
478
479 int smpi_coll_tuned_bcast_mvapich2(void *buffer,
480     int count,
481     MPI_Datatype datatype,
482     int root, MPI_Comm comm)
483 {
484     int mpi_errno = MPI_SUCCESS;
485     int comm_size/*, rank*/;
486     int two_level_bcast = 1;
487     long nbytes = 0; 
488     int range = 0;
489     int range_threshold = 0;
490     int range_threshold_intra = 0;
491     int is_homogeneous, is_contig;
492     MPI_Aint type_size;
493     //, position;
494     void *tmp_buf = NULL;
495     MPI_Comm shmem_comm;
496     //MPID_Datatype *dtp;
497
498     if (count == 0)
499         return MPI_SUCCESS;
500     if(comm->get_leaders_comm()==MPI_COMM_NULL){
501       comm->init_smp();
502     }
503     if(!mv2_bcast_thresholds_table)
504       init_mv2_bcast_tables_stampede();
505     comm_size = comm->size();
506     //rank = comm->rank();
507
508     is_contig=1;
509 /*    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
510 /*        is_contig = 1;*/
511 /*    else {*/
512 /*        MPID_Datatype_get_ptr(datatype, dtp);*/
513 /*        is_contig = dtp->is_contig;*/
514 /*    }*/
515
516     is_homogeneous = 1;
517
518     /* MPI_Type_size() might not give the accurate size of the packed
519      * datatype for heterogeneous systems (because of padding, encoding,
520      * etc). On the other hand, MPI_Pack_size() can become very
521      * expensive, depending on the implementation, especially for
522      * heterogeneous systems. We want to use MPI_Type_size() wherever
523      * possible, and MPI_Pack_size() in other places.
524      */
525     //if (is_homogeneous) {
526         type_size=smpi_datatype_size(datatype);
527
528    /* } else {
529         MPIR_Pack_size_impl(1, datatype, &type_size);
530     }*/
531     nbytes =  (count) * (type_size);
532
533     /* Search for the corresponding system size inside the tuning table */
534     while ((range < (mv2_size_bcast_tuning_table - 1)) &&
535            (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
536         range++;
537     }
538     /* Search for corresponding inter-leader function */
539     while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
540            && (nbytes >
541                mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
542            && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
543         range_threshold++;
544     }
545
546     /* Search for corresponding intra-node function */
547     while ((range_threshold_intra <
548             (mv2_bcast_thresholds_table[range].size_intra_table - 1))
549            && (nbytes >
550                mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
551            && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
552                -1)) {
553         range_threshold_intra++;
554     }
555
556     MV2_Bcast_function =
557         mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
558         MV2_pt_Bcast_function;
559
560     MV2_Bcast_intra_node_function =
561         mv2_bcast_thresholds_table[range].
562         intra_node[range_threshold_intra].MV2_pt_Bcast_function;
563
564 /*    if (mv2_user_bcast_intra == NULL && */
565 /*            MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
566 /*            MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
567 /*    }*/
568
569     if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
570         zcpy_pipelined_knomial_factor != -1) {
571         zcpy_knomial_factor = 
572             mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
573             zcpy_pipelined_knomial_factor;
574     }
575
576     if (mv2_pipelined_zcpy_knomial_factor != -1) {
577         zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
578     }
579
580     if(MV2_Bcast_intra_node_function == NULL) {
581         /* if tuning table do not have any intra selection, set func pointer to
582         ** default one for mcast intra node */
583         MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
584     }
585
586     /* Set value of pipeline segment size */
587     bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
588     
589     /* Set value of inter node knomial factor */
590     mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
591
592     /* Set value of intra node knomial factor */
593     mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
594
595     /* Check if we will use a two level algorithm or not */
596     two_level_bcast =
597 #if defined(_MCST_SUPPORT_)
598         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold] 
599         || comm->ch.is_mcast_ok;
600 #else
601         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
602 #endif
603      if (two_level_bcast == 1) {
604         if (!is_contig || !is_homogeneous) {
605             tmp_buf=(void *)smpi_get_tmp_sendbuffer(nbytes);
606
607 /*            position = 0;*/
608 /*            if (rank == root) {*/
609 /*                mpi_errno =*/
610 /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
611 /*                if (mpi_errno)*/
612 /*                    MPIU_ERR_POP(mpi_errno);*/
613 /*            }*/
614         }
615 #ifdef CHANNEL_MRAIL_GEN2
616         if ((mv2_enable_zcpy_bcast == 1) &&
617               (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {  
618             if (!is_contig || !is_homogeneous) {
619                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE,
620                                                  root, comm);
621             } else { 
622                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
623                                                  root, comm);
624             } 
625         } else 
626 #endif /* defined(CHANNEL_MRAIL_GEN2) */
627         { 
628             shmem_comm = comm->get_intra_comm();
629             if (!is_contig || !is_homogeneous) {
630                 mpi_errno =
631                     MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE,
632                                                           root, comm);
633             } else {
634                 mpi_errno =
635                     MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root,
636                                                           comm);
637             }
638
639             /* We are now done with the inter-node phase */
640
641
642                     root = INTRA_NODE_ROOT;
643    
644
645                 if (!is_contig || !is_homogeneous) {
646                     mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes,
647                                                               MPI_BYTE, root, shmem_comm);
648                 } else {
649                     mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
650                                                               datatype, root, shmem_comm);
651
652                 }
653         } 
654 /*        if (!is_contig || !is_homogeneous) {*/
655 /*            if (rank != root) {*/
656 /*                position = 0;*/
657 /*                mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
658 /*                                             count, datatype);*/
659 /*            }*/
660 /*        }*/
661     } else {
662         /* We use Knomial for intra node */
663         MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
664 /*        if (mv2_enable_shmem_bcast == 0) {*/
665             /* Fall back to non-tuned version */
666 /*            MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
667 /*        } else {*/
668             mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
669                                            comm);
670
671 /*        }*/
672     }
673
674
675     return mpi_errno;
676
677 }
678
679
680
681 int smpi_coll_tuned_reduce_mvapich2( void *sendbuf,
682     void *recvbuf,
683     int count,
684     MPI_Datatype datatype,
685     MPI_Op op, int root, MPI_Comm comm)
686 {
687   if(mv2_reduce_thresholds_table == NULL)
688     init_mv2_reduce_tables_stampede();
689
690   int mpi_errno = MPI_SUCCESS;
691   int range = 0;
692   int range_threshold = 0;
693   int range_intra_threshold = 0;
694   int is_commutative, pof2;
695   int comm_size = 0;
696   long nbytes = 0;
697   int sendtype_size;
698   int is_two_level = 0;
699
700   comm_size = comm->size();
701   sendtype_size=smpi_datatype_size(datatype);
702   nbytes = count * sendtype_size;
703
704   if (count == 0)
705     return MPI_SUCCESS;
706
707   is_commutative = (op==MPI_OP_NULL || op->is_commutative());
708
709   /* find nearest power-of-two less than or equal to comm_size */
710   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
711   pof2 >>=1;
712
713
714   /* Search for the corresponding system size inside the tuning table */
715   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
716       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
717       range++;
718   }
719   /* Search for corresponding inter-leader function */
720   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
721       && (nbytes >
722   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
723   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
724       -1)) {
725       range_threshold++;
726   }
727
728   /* Search for corresponding intra node function */
729   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
730       && (nbytes >
731   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
732   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
733       -1)) {
734       range_intra_threshold++;
735   }
736
737   /* Set intra-node function pt for reduce_two_level */
738   MV2_Reduce_intra_function =
739       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
740       MV2_pt_Reduce_function;
741   /* Set inter-leader pt */
742   MV2_Reduce_function =
743       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
744       MV2_pt_Reduce_function;
745
746   if(mv2_reduce_intra_knomial_factor<0)
747     {
748       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
749     }
750   if(mv2_reduce_inter_knomial_factor<0)
751     {
752       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
753     }
754   if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
755       is_two_level = 1;
756   }
757   /* We call Reduce function */
758   if(is_two_level == 1)
759     {
760        if (is_commutative == 1) {
761          if(comm->get_leaders_comm()==MPI_COMM_NULL){
762            comm->init_smp();
763          }
764          mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
765                                            datatype, op, root, comm);
766         } else {
767       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
768           datatype, op, root, comm);
769       }
770     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
771         if(is_commutative ==1)
772           {
773             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
774                 datatype, op, root, comm);
775           } else {
776               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
777                   datatype, op, root, comm);
778           }
779     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
780         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
781           {
782             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
783                 datatype, op, root, comm);
784           } else {
785               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
786                   datatype, op, root, comm);
787           }
788     } else {
789         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
790             datatype, op, root, comm);
791     }
792
793
794   return mpi_errno;
795
796 }
797
798
799 int smpi_coll_tuned_reduce_scatter_mvapich2(void *sendbuf, void *recvbuf, int *recvcnts,
800     MPI_Datatype datatype, MPI_Op op,
801     MPI_Comm comm)
802 {
803   int mpi_errno = MPI_SUCCESS;
804   int i = 0, comm_size = comm->size(), total_count = 0, type_size =
805       0, nbytes = 0;
806   int range = 0;
807   int range_threshold = 0;
808   int is_commutative = 0;
809   int *disps = static_cast<int*>(xbt_malloc(comm_size * sizeof (int)));
810
811   if(mv2_red_scat_thresholds_table==NULL)
812     init_mv2_reduce_scatter_tables_stampede();
813
814   is_commutative=(op==MPI_OP_NULL || op->is_commutative());
815   for (i = 0; i < comm_size; i++) {
816       disps[i] = total_count;
817       total_count += recvcnts[i];
818   }
819
820   type_size=smpi_datatype_size(datatype);
821   nbytes = total_count * type_size;
822
823   if (is_commutative) {
824
825       /* Search for the corresponding system size inside the tuning table */
826       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
827           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
828           range++;
829       }
830       /* Search for corresponding inter-leader function */
831       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
832           && (nbytes >
833       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
834       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
835           -1)) {
836           range_threshold++;
837       }
838
839       /* Set inter-leader pt */
840       MV2_Red_scat_function =
841           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
842           MV2_pt_Red_scat_function;
843
844       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
845           recvcnts, datatype,
846           op, comm);
847   } else {
848       int is_block_regular = 1;
849       for (i = 0; i < (comm_size - 1); ++i) {
850           if (recvcnts[i] != recvcnts[i+1]) {
851               is_block_regular = 0;
852               break;
853           }
854       }
855       int pof2 = 1;
856       while (pof2 < comm_size) pof2 <<= 1;
857       if (pof2 == comm_size && is_block_regular) {
858           /* noncommutative, pof2 size, and block regular */
859           mpi_errno = MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
860               recvcnts, datatype,
861               op, comm);
862       }
863       mpi_errno =  smpi_coll_tuned_reduce_scatter_mpich_rdb(sendbuf, recvbuf,
864           recvcnts, datatype,
865           op, comm);
866   }
867   xbt_free(disps);
868   return mpi_errno;
869
870 }
871
872
873
874 int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
875     int sendcnt,
876     MPI_Datatype sendtype,
877     void *recvbuf,
878     int recvcnt,
879     MPI_Datatype recvtype,
880     int root, MPI_Comm comm)
881 {
882   int range = 0, range_threshold = 0, range_threshold_intra = 0;
883   int mpi_errno = MPI_SUCCESS;
884   //   int mpi_errno_ret = MPI_SUCCESS;
885   int rank, nbytes, comm_size;
886   int recvtype_size, sendtype_size;
887   int partial_sub_ok = 0;
888   int conf_index = 0;
889     int local_size = -1;
890     int i;
891      MPI_Comm shmem_comm;
892   //    MPID_Comm *shmem_commptr=NULL;
893   if(mv2_scatter_thresholds_table==NULL)
894     init_mv2_scatter_tables_stampede();
895
896   if(comm->get_leaders_comm()==MPI_COMM_NULL){
897     comm->init_smp();
898   }
899   
900   comm_size = comm->size();
901
902   rank = comm->rank();
903
904   if (rank == root) {
905       sendtype_size=smpi_datatype_size(sendtype);
906       nbytes = sendcnt * sendtype_size;
907   } else {
908       recvtype_size=smpi_datatype_size(recvtype);
909       nbytes = recvcnt * recvtype_size;
910   }
911   
912     // check if safe to use partial subscription mode 
913     if (comm->is_uniform()) {
914
915         shmem_comm = comm->get_intra_comm();
916         local_size = shmem_comm->size();
917         i = 0;
918         if (mv2_scatter_table_ppn_conf[0] == -1) {
919             // Indicating user defined tuning 
920             conf_index = 0;
921         }else{
922             do {
923                 if (local_size == mv2_scatter_table_ppn_conf[i]) {
924                     conf_index = i;
925                     partial_sub_ok = 1;
926                     break;
927                 }
928                 i++;
929             } while(i < mv2_scatter_num_ppn_conf);
930         }
931     }
932    
933   if (partial_sub_ok != 1) {
934       conf_index = 0;
935   }
936
937   /* Search for the corresponding system size inside the tuning table */
938   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
939       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
940       range++;
941   }
942   /* Search for corresponding inter-leader function */
943   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
944       && (nbytes >
945   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
946   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
947       range_threshold++;
948   }
949
950   /* Search for corresponding intra-node function */
951   while ((range_threshold_intra <
952       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
953       && (nbytes >
954   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
955   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
956       -1)) {
957       range_threshold_intra++;
958   }
959
960   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
961                                                                                       .MV2_pt_Scatter_function;
962
963   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
964 #if defined(_MCST_SUPPORT_)
965       if(comm->ch.is_mcast_ok == 1
966           && mv2_use_mcast_scatter == 1
967           && comm->ch.shmem_coll_ok == 1) {
968           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
969       } else
970 #endif /*#if defined(_MCST_SUPPORT_) */
971         {
972           if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
973               MV2_pt_Scatter_function != NULL) {
974               MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
975                                                                                                   .MV2_pt_Scatter_function;
976           } else {
977               /* Fallback! */
978               MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
979           }
980         } 
981   }
982
983   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
984       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
985        if( comm->is_blocked()) {
986              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
987                                 .MV2_pt_Scatter_function;
988
989              mpi_errno =
990                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
991                                         recvbuf, recvcnt, recvtype, root,
992                                         comm);
993          } else {
994       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
995           recvbuf, recvcnt, recvtype, root,
996           comm);
997
998       }
999   } else {
1000       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
1001           recvbuf, recvcnt, recvtype, root,
1002           comm);
1003   }
1004   return (mpi_errno);
1005 }
1006
1007 void smpi_coll_cleanup_mvapich2(void){
1008 int i=0;
1009 if(mv2_alltoall_thresholds_table)
1010   xbt_free(mv2_alltoall_thresholds_table[i]);
1011 xbt_free(mv2_alltoall_thresholds_table);
1012 xbt_free(mv2_size_alltoall_tuning_table);
1013 xbt_free(mv2_alltoall_table_ppn_conf);
1014
1015 xbt_free(mv2_gather_thresholds_table);
1016 if(mv2_allgather_thresholds_table)
1017   xbt_free(mv2_allgather_thresholds_table[0]);
1018 xbt_free(mv2_size_allgather_tuning_table);
1019 xbt_free(mv2_allgather_table_ppn_conf);
1020 xbt_free(mv2_allgather_thresholds_table);
1021
1022 xbt_free(mv2_allgatherv_thresholds_table);
1023 xbt_free(mv2_reduce_thresholds_table);
1024 xbt_free(mv2_red_scat_thresholds_table);
1025 xbt_free(mv2_allreduce_thresholds_table);
1026 xbt_free(mv2_bcast_thresholds_table);
1027 if(mv2_scatter_thresholds_table)
1028   xbt_free(mv2_scatter_thresholds_table[0]);
1029 xbt_free(mv2_scatter_thresholds_table);
1030 xbt_free(mv2_size_scatter_tuning_table);
1031 xbt_free(mv2_scatter_table_ppn_conf);
1032 }