Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
cleanup a bit the code, ensure tests do pass
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.c
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2010, 2013-2014. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is xbt_free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.h"
10
11 #include "smpi_mvapich2_selector_stampede.h"
12
13
14                             
15 int smpi_coll_tuned_alltoall_mvapich2( void *sendbuf, int sendcount, 
16                                              MPI_Datatype sendtype,
17                                              void* recvbuf, int recvcount, 
18                                              MPI_Datatype recvtype, 
19                                              MPI_Comm comm)
20 {
21
22     if(mv2_alltoall_table_ppn_conf==NULL)
23         init_mv2_alltoall_tables_stampede();
24         
25     int sendtype_size, recvtype_size, nbytes, comm_size;
26     char * tmp_buf = NULL;
27     int mpi_errno=MPI_SUCCESS;
28     int range = 0;
29     int range_threshold = 0;
30     int conf_index = 0;
31     comm_size =  smpi_comm_size(comm);
32
33     sendtype_size=smpi_datatype_size(sendtype);
34     recvtype_size=smpi_datatype_size(recvtype);
35     nbytes = sendtype_size * sendcount;
36
37     /* check if safe to use partial subscription mode */
38
39     /* Search for the corresponding system size inside the tuning table */
40     while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
41            (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
42         range++;
43     }    
44     /* Search for corresponding inter-leader function */
45     while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
46            && (nbytes >
47                mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
48            && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
49         range_threshold++;
50     }     
51     MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
52                                 .MV2_pt_Alltoall_function;
53
54     if(sendbuf != MPI_IN_PLACE) {  
55         mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
56                                               recvbuf, recvcount, recvtype,
57                                                comm);
58     } else {
59         range_threshold = 0; 
60         if(nbytes < 
61           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
62           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
63           ) {
64             tmp_buf = (char *)xbt_malloc( comm_size * recvcount * recvtype_size );
65             mpi_errno = smpi_datatype_copy((char *)recvbuf,
66                                        comm_size*recvcount, recvtype,
67                                        (char *)tmp_buf,
68                                        comm_size*recvcount, recvtype);
69
70             mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype,
71                                                recvbuf, recvcount, recvtype,
72                                                 comm );        
73             xbt_free(tmp_buf);
74         } else { 
75             mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
76                                               recvbuf, recvcount, recvtype,
77                                                comm );
78         } 
79     }
80
81     
82     return (mpi_errno);
83 }
84
85
86
87 int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
88                        void *recvbuf, int recvcount, MPI_Datatype recvtype,
89                        MPI_Comm comm)
90 {
91
92     int mpi_errno = MPI_SUCCESS;
93     int nbytes = 0, comm_size, recvtype_size;
94     int range = 0;
95     //int partial_sub_ok = 0;
96     int conf_index = 0;
97     int range_threshold = 0;
98     int is_two_level = 0;
99     //int local_size = -1;
100     //MPI_Comm shmem_comm;
101     //MPI_Comm *shmem_commptr=NULL;
102     /* Get the size of the communicator */
103     comm_size = smpi_comm_size(comm);
104     recvtype_size=smpi_datatype_size(recvtype);
105     nbytes = recvtype_size * recvcount;
106
107     if(mv2_allgather_table_ppn_conf==NULL)
108         init_mv2_allgather_tables_stampede();
109         
110     //int i;
111     /* check if safe to use partial subscription mode */
112   /*  if (comm->ch.shmem_coll_ok == 1 && comm->ch.is_uniform) {
113     
114         shmem_comm = comm->ch.shmem_comm;
115         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
116         local_size = shmem_commptr->local_size;
117         i = 0;
118         if (mv2_allgather_table_ppn_conf[0] == -1) {
119             // Indicating user defined tuning
120             conf_index = 0;
121             goto conf_check_end;
122         }
123         do {
124             if (local_size == mv2_allgather_table_ppn_conf[i]) {
125                 conf_index = i;
126                 partial_sub_ok = 1;
127                 break;
128             }
129             i++;
130         } while(i < mv2_allgather_num_ppn_conf);
131     }
132
133   conf_check_end:
134     if (partial_sub_ok != 1) {
135         conf_index = 0;
136     }*/
137     /* Search for the corresponding system size inside the tuning table */
138     while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
139            (comm_size >
140             mv2_allgather_thresholds_table[conf_index][range].numproc)) {
141         range++;
142     }
143     /* Search for corresponding inter-leader function */
144     while ((range_threshold <
145          (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
146            && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
147            && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
148                -1)) {
149         range_threshold++;
150     }
151
152     /* Set inter-leader pt */
153     MV2_Allgather_function =
154                           mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
155                           MV2_pt_Allgather_function;
156
157     is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
158
159     /* intracommunicator */
160     if(is_two_level ==1){
161         
162  /*       if(comm->ch.shmem_coll_ok == 1){
163             MPIR_T_PVAR_COUNTER_INC(MV2, mv2_num_shmem_coll_calls, 1);
164            if (1 == comm->ch.is_blocked) {
165                 mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
166                                                     recvbuf, recvcount, recvtype,
167                                                     comm, errflag);
168            }
169            else {
170                mpi_errno = MPIR_Allgather_intra(sendbuf, sendcount, sendtype,
171                                                 recvbuf, recvcount, recvtype,
172                                                 comm, errflag);
173            }
174         } else {*/
175             mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
176                                                 recvbuf, recvcount, recvtype,
177                                                 comm);
178    //     }
179     } else if(MV2_Allgather_function == &MPIR_Allgather_Bruck_MV2 
180             || MV2_Allgather_function == &MPIR_Allgather_RD_MV2
181             || MV2_Allgather_function == &MPIR_Allgather_Ring_MV2) {
182             mpi_errno = MV2_Allgather_function(sendbuf, sendcount, sendtype,
183                                           recvbuf, recvcount, recvtype,
184                                           comm);
185     }else{
186       return MPI_ERR_OTHER;
187     }
188
189     return mpi_errno;
190 }
191
192
193 int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
194                     int sendcnt,
195                     MPI_Datatype sendtype,
196                     void *recvbuf,
197                     int recvcnt,
198                     MPI_Datatype recvtype,
199                     int root, MPI_Comm  comm)
200 {
201     if(mv2_gather_thresholds_table==NULL)
202         init_mv2_gather_tables_stampede();
203         
204     int mpi_errno = MPI_SUCCESS;
205     int range = 0;
206     int range_threshold = 0;
207     int range_intra_threshold = 0;
208     int nbytes = 0;
209     int comm_size = 0;
210     int recvtype_size, sendtype_size;
211     int rank = -1;
212     comm_size = smpi_comm_size(comm);
213     rank = smpi_comm_rank(comm);
214
215     if (rank == root) {
216         recvtype_size=smpi_datatype_size(recvtype);
217         nbytes = recvcnt * recvtype_size;
218     } else {
219         sendtype_size=smpi_datatype_size(sendtype);
220         nbytes = sendcnt * sendtype_size;
221     }
222     
223     /* Search for the corresponding system size inside the tuning table */
224     while ((range < (mv2_size_gather_tuning_table - 1)) &&
225            (comm_size > mv2_gather_thresholds_table[range].numproc)) {
226         range++;
227     }
228     /* Search for corresponding inter-leader function */
229     while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
230            && (nbytes >
231                mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
232            && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
233                -1)) {
234         range_threshold++;
235     }
236
237     /* Search for corresponding intra node function */
238     while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
239            && (nbytes >
240                mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
241            && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
242                -1)) {
243         range_intra_threshold++;
244     }
245 /*
246     if (comm->ch.is_global_block == 1 && mv2_use_direct_gather == 1 &&
247             mv2_use_two_level_gather == 1 && comm->ch.shmem_coll_ok == 1) {
248         // Set intra-node function pt for gather_two_level 
249         MV2_Gather_intra_node_function = 
250                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
251                               MV2_pt_Gather_function;
252         //Set inter-leader pt 
253         MV2_Gather_inter_leader_function =
254                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
255                               MV2_pt_Gather_function;
256         // We call Gather function 
257         mpi_errno =
258             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
259                                              recvtype, root, comm);
260
261     } else {*/
262     // Indded, direct (non SMP-aware)gather is MPICH one 
263         mpi_errno = smpi_coll_tuned_gather_mpich(sendbuf, sendcnt, sendtype,
264                                       recvbuf, recvcnt, recvtype,
265                                       root, comm);
266     //}
267
268     return mpi_errno;
269 }
270
271
272 int smpi_coll_tuned_allgatherv_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
273                         void *recvbuf, int *recvcounts, int *displs,
274                         MPI_Datatype recvtype, MPI_Comm  comm )
275 {
276     int mpi_errno = MPI_SUCCESS;
277     int range = 0, comm_size, total_count, recvtype_size, i;
278     int range_threshold = 0;
279     int nbytes = 0;
280
281     if(mv2_allgatherv_thresholds_table==NULL)
282         init_mv2_allgatherv_tables_stampede();
283         
284     comm_size = smpi_comm_size(comm);
285     total_count = 0;
286     for (i = 0; i < comm_size; i++)
287         total_count += recvcounts[i];
288
289     recvtype_size=smpi_datatype_size(recvtype);
290     nbytes = total_count * recvtype_size;
291
292     /* Search for the corresponding system size inside the tuning table */
293     while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
294            (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
295         range++;
296     }
297     /* Search for corresponding inter-leader function */
298     while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
299            && (nbytes >
300                comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
301            && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
302                -1)) {
303         range_threshold++;
304     }
305     /* Set inter-leader pt */
306     MV2_Allgatherv_function =
307                           mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
308                           MV2_pt_Allgatherv_function;
309
310     if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
311     {
312         if(!(comm_size & (comm_size - 1)))
313         {
314             mpi_errno =
315                 MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount,
316                                                  sendtype, recvbuf,
317                                                  recvcounts, displs,
318                                                  recvtype, comm);
319         } else {
320             mpi_errno =
321                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
322                                           sendtype, recvbuf,
323                                           recvcounts, displs,
324                                           recvtype, comm);
325         }
326     } else {
327         mpi_errno =
328             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
329                                     recvbuf, recvcounts, displs,
330                                     recvtype, comm);
331     }
332
333     return mpi_errno;
334 }
335
336
337
338 int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
339                        void *recvbuf,
340                        int count,
341                        MPI_Datatype datatype,
342                        MPI_Op op, MPI_Comm comm)
343 {
344
345     int mpi_errno = MPI_SUCCESS;
346     //int rank = 0, 
347     int comm_size = 0;
348    
349     comm_size = smpi_comm_size(comm);
350     //rank = smpi_comm_rank(comm);
351
352     if (count == 0) {
353         return MPI_SUCCESS;
354     }
355
356   if (mv2_allreduce_thresholds_table == NULL)
357     init_mv2_allreduce_tables_stampede();
358
359     /* check if multiple threads are calling this collective function */
360
361     MPI_Aint sendtype_size = 0;
362     int nbytes = 0;
363     int range = 0, range_threshold = 0, range_threshold_intra = 0;
364     int is_two_level = 0;
365     //int is_commutative = 0;
366     MPI_Aint true_lb, true_extent;
367
368     sendtype_size=smpi_datatype_size(datatype);
369     nbytes = count * sendtype_size;
370
371     smpi_datatype_extent(datatype, &true_lb, &true_extent);
372     //MPI_Op *op_ptr;
373     //is_commutative = smpi_op_is_commute(op);
374
375     {
376         /* Search for the corresponding system size inside the tuning table */
377         while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
378                (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
379             range++;
380         }
381         /* Search for corresponding inter-leader function */
382         /* skip mcast poiters if mcast is not available */
383         if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
384             while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1)) 
385                     && ((mv2_allreduce_thresholds_table[range].
386                     inter_leader[range_threshold].MV2_pt_Allreduce_function 
387                     == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
388                     (mv2_allreduce_thresholds_table[range].
389                     inter_leader[range_threshold].MV2_pt_Allreduce_function
390                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
391                     )) {
392                     range_threshold++;
393             }
394         }
395         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
396                && (nbytes >
397                mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
398                && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
399                range_threshold++;
400         }
401         if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
402                is_two_level = 1;    
403         }
404         /* Search for corresponding intra-node function */
405         while ((range_threshold_intra <
406                (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
407                 && (nbytes >
408                 mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
409                 && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
410                 -1)) {
411                 range_threshold_intra++;
412         }
413
414         MV2_Allreduce_function = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
415                                 .MV2_pt_Allreduce_function;
416
417         MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
418                                 .MV2_pt_Allreduce_function;
419
420         /* check if mcast is ready, otherwise replace mcast with other algorithm */
421         if((MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
422           (MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
423             {
424                 MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
425             }
426             if(is_two_level != 1) {
427                 MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
428             }
429         } 
430
431         if(is_two_level == 1){
432                 // check if shm is ready, if not use other algorithm first 
433                 /*if ((comm->ch.shmem_coll_ok == 1)
434                     && (mv2_enable_shmem_allreduce)
435                     && (is_commutative)
436                     && (mv2_enable_shmem_collectives)) {
437                     mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
438                                                      datatype, op, comm);
439                 } else {*/
440                     mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
441                                                      datatype, op, comm);
442                // }
443         } else { 
444             mpi_errno = MV2_Allreduce_function(sendbuf, recvbuf, count,
445                                            datatype, op, comm);
446         }
447     } 
448
449         //comm->ch.intra_node_done=0;
450         
451     return (mpi_errno);
452
453
454 }
455
456
457 int smpi_coll_tuned_alltoallv_mvapich2(void *sbuf, int *scounts, int *sdisps,
458                                               MPI_Datatype sdtype,
459                                               void *rbuf, int *rcounts, int *rdisps,
460                                               MPI_Datatype rdtype,
461                                               MPI_Comm  comm
462                                               )
463 {
464
465 if (sbuf == MPI_IN_PLACE) {
466     return smpi_coll_tuned_alltoallv_ompi_basic_linear(sbuf, scounts, sdisps, sdtype, 
467                                                         rbuf, rcounts, rdisps,rdtype,
468                                                         comm);
469  } else     /* For starters, just keep the original algorithm. */
470     return smpi_coll_tuned_alltoallv_ring(sbuf, scounts, sdisps, sdtype, 
471                                                         rbuf, rcounts, rdisps,rdtype,
472                                                         comm);
473 }
474
475
476 int smpi_coll_tuned_barrier_mvapich2(MPI_Comm  comm)
477 {   
478     return smpi_coll_tuned_barrier_mvapich2_pair(comm);
479 }
480
481
482
483
484 int smpi_coll_tuned_bcast_mvapich2(void *buffer,
485                               int count,
486                               MPI_Datatype datatype,
487                               int root, MPI_Comm comm)
488 {
489
490 //TODO : Bcast really needs intra/inter phases in mvapich. Default to mpich if not available
491   return smpi_coll_tuned_bcast_mpich(buffer, count, datatype, root, comm);
492
493 }
494
495
496
497 int smpi_coll_tuned_reduce_mvapich2( void *sendbuf,
498                     void *recvbuf,
499                     int count,
500                     MPI_Datatype datatype,
501                     MPI_Op op, int root, MPI_Comm comm)
502 {
503    if(mv2_reduce_thresholds_table == NULL)
504      init_mv2_reduce_tables_stampede();
505
506     int mpi_errno = MPI_SUCCESS;
507     int range = 0;
508     int range_threshold = 0;
509     int range_intra_threshold = 0;
510     int is_commutative, pof2;
511     int comm_size = 0;
512     int nbytes = 0;
513     int sendtype_size;
514     int is_two_level = 0;
515
516     comm_size = smpi_comm_size(comm);
517     sendtype_size=smpi_datatype_size(datatype);
518     nbytes = count * sendtype_size;
519
520     if (count == 0)
521         return MPI_SUCCESS;
522
523     is_commutative = smpi_op_is_commute(op);
524
525     /* find nearest power-of-two less than or equal to comm_size */
526     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
527     pof2 >>=1;
528     
529
530     /* Search for the corresponding system size inside the tuning table */
531     while ((range < (mv2_size_reduce_tuning_table - 1)) &&
532            (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
533         range++;
534     }
535     /* Search for corresponding inter-leader function */
536     while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
537            && (nbytes >
538                mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
539            && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
540                -1)) {
541         range_threshold++;
542     }
543
544     /* Search for corresponding intra node function */
545     while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
546            && (nbytes >
547                mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
548            && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
549                -1)) {
550         range_intra_threshold++;
551     }
552
553     /* Set intra-node function pt for reduce_two_level */
554     MV2_Reduce_intra_function = 
555                           mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
556                           MV2_pt_Reduce_function;
557     /* Set inter-leader pt */
558     MV2_Reduce_function =
559                           mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
560                           MV2_pt_Reduce_function;
561
562     if(mv2_reduce_intra_knomial_factor<0)
563     {
564         mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
565     }
566     if(mv2_reduce_inter_knomial_factor<0)
567     {
568         mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
569     }
570     if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
571                is_two_level = 1;
572     }
573     /* We call Reduce function */
574     if(is_two_level == 1)
575     {
576        /* if (comm->ch.shmem_coll_ok == 1
577             && is_commutative == 1) {
578             mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
579                                            datatype, op, root, comm, errflag);
580         } else {*/
581             mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count, 
582                                            datatype, op, root, comm);
583        //}
584     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
585         if(is_commutative ==1)
586         {
587             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
588                                            datatype, op, root, comm);
589         } else {
590             mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count, 
591                                            datatype, op, root, comm);
592         }
593     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
594         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
595         {
596             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
597                                             datatype, op, root, comm);
598         } else {
599             mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count, 
600                                             datatype, op, root, comm);
601         }
602     } else {
603         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
604                                         datatype, op, root, comm);
605     }
606
607
608       return mpi_errno;
609
610 }
611
612
613 int smpi_coll_tuned_reduce_scatter_mvapich2(void *sendbuf, void *recvbuf, int *recvcnts,
614                                                         MPI_Datatype datatype, MPI_Op op,
615                                                         MPI_Comm comm)
616 {
617         int mpi_errno = MPI_SUCCESS;
618         int i = 0, comm_size = smpi_comm_size(comm), total_count = 0, type_size =
619                 0, nbytes = 0;
620     int range = 0;
621     int range_threshold = 0;
622         int is_commutative = 0;
623         int *disps = xbt_malloc(comm_size * sizeof (int));
624
625     if(mv2_red_scat_thresholds_table==NULL)
626       init_mv2_reduce_scatter_tables_stampede();
627       
628     is_commutative=smpi_op_is_commute(op);
629         for (i = 0; i < comm_size; i++) {
630                 disps[i] = total_count;
631                 total_count += recvcnts[i];
632         }
633
634         type_size=smpi_datatype_size(datatype);
635         nbytes = total_count * type_size;
636
637         if (is_commutative) {
638
639         /* Search for the corresponding system size inside the tuning table */
640         while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
641                (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
642             range++;
643         }
644         /* Search for corresponding inter-leader function */
645         while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
646                && (nbytes >
647                    mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
648                && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
649                    -1)) {
650             range_threshold++;
651         }
652     
653         /* Set inter-leader pt */
654         MV2_Red_scat_function =
655                               mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
656                               MV2_pt_Red_scat_function;
657
658                 mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
659                                           recvcnts, datatype,
660                                           op, comm);
661         } else {
662           int is_block_regular = 1;
663         for (i = 0; i < (comm_size - 1); ++i) {
664             if (recvcnts[i] != recvcnts[i+1]) {
665                 is_block_regular = 0;
666                 break;
667             }
668         }
669           int pof2 = 1;
670       while (pof2 < comm_size) pof2 <<= 1;
671         if (pof2 == comm_size && is_block_regular) {
672        /* noncommutative, pof2 size, and block regular */
673           mpi_errno = MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
674                                                       recvcnts, datatype,
675                                                       op, comm);
676         }
677         mpi_errno =  smpi_coll_tuned_reduce_scatter_mpich_rdb(sendbuf, recvbuf,
678                                                              recvcnts, datatype,
679                                                              op, comm);
680         }
681
682     return mpi_errno;
683
684 }
685
686
687
688 int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
689                            int sendcnt,
690                            MPI_Datatype sendtype,
691                            void *recvbuf,
692                            int recvcnt,
693                            MPI_Datatype recvtype,
694                            int root, MPI_Comm comm_ptr)
695 {
696     int range = 0, range_threshold = 0, range_threshold_intra = 0;
697     int mpi_errno = MPI_SUCCESS;
698  //   int mpi_errno_ret = MPI_SUCCESS;
699     int rank, nbytes, comm_size;
700     int recvtype_size, sendtype_size;
701     int partial_sub_ok = 0;
702     int conf_index = 0;
703   //  int local_size = -1;
704   //  int i;
705  //   MPI_Comm shmem_comm;
706 //    MPID_Comm *shmem_commptr=NULL;
707     if(mv2_scatter_thresholds_table==NULL)
708       init_mv2_scatter_tables_stampede();
709
710     comm_size = smpi_comm_size(comm_ptr);
711
712     rank = smpi_comm_rank(comm_ptr);
713
714     if (rank == root) {
715         sendtype_size=smpi_datatype_size(sendtype);
716         nbytes = sendcnt * sendtype_size;
717     } else {
718         recvtype_size=smpi_datatype_size(recvtype);
719         nbytes = recvcnt * recvtype_size;
720     }
721 /*
722     // check if safe to use partial subscription mode 
723     if (comm_ptr->ch.shmem_coll_ok == 1 && comm_ptr->ch.is_uniform) {
724     
725         shmem_comm = comm_ptr->ch.shmem_comm;
726         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
727         local_size = shmem_commptr->local_size;
728         i = 0;
729         if (mv2_scatter_table_ppn_conf[0] == -1) {
730             // Indicating user defined tuning 
731             conf_index = 0;
732             goto conf_check_end;
733         }
734         do {
735             if (local_size == mv2_scatter_table_ppn_conf[i]) {
736                 conf_index = i;
737                 partial_sub_ok = 1;
738                 break;
739             }
740             i++;
741         } while(i < mv2_scatter_num_ppn_conf);
742     }
743     */
744     if (partial_sub_ok != 1) {
745         conf_index = 0;
746     }
747
748     /* Search for the corresponding system size inside the tuning table */
749     while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
750            (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
751         range++;
752     }
753     /* Search for corresponding inter-leader function */
754     while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
755            && (nbytes >
756            mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
757            && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
758            range_threshold++;
759     }
760
761     /* Search for corresponding intra-node function */
762     while ((range_threshold_intra <
763            (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
764             && (nbytes >
765                 mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
766             && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
767             -1)) {
768             range_threshold_intra++;
769     }
770
771     MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
772                             .MV2_pt_Scatter_function;
773
774     if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) { 
775 #if defined(_MCST_SUPPORT_)
776         if(comm_ptr->ch.is_mcast_ok == 1 
777            && mv2_use_mcast_scatter == 1 
778            && comm_ptr->ch.shmem_coll_ok == 1) {
779             MV2_Scatter_function = &MPIR_Scatter_mcst_MV2; 
780         } else
781 #endif /*#if defined(_MCST_SUPPORT_) */
782         {
783             if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
784                MV2_pt_Scatter_function != NULL) { 
785                   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
786                                                                           .MV2_pt_Scatter_function;
787             } else { 
788                   /* Fallback! */ 
789                   MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial; 
790             }  
791         } 
792     } 
793  
794     if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) || 
795         (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) { 
796         /* if( comm_ptr->ch.shmem_coll_ok == 1 && 
797              comm_ptr->ch.is_global_block == 1 ) {
798              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
799                                 .MV2_pt_Scatter_function;
800
801              mpi_errno =
802                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
803                                         recvbuf, recvcnt, recvtype, root,
804                                         comm_ptr);
805          } else {*/
806              mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
807                                         recvbuf, recvcnt, recvtype, root,
808                                         comm_ptr);
809
810          //}
811     } else { 
812          mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
813                                     recvbuf, recvcnt, recvtype, root,
814                                     comm_ptr);
815     } 
816     return (mpi_errno);
817 }
818