Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'mc'
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.c
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2010, 2013-2014. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is xbt_free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.h"
10
11 #include "smpi_mvapich2_selector_stampede.h"
12
13
14
15 int smpi_coll_tuned_alltoall_mvapich2( void *sendbuf, int sendcount, 
16     MPI_Datatype sendtype,
17     void* recvbuf, int recvcount,
18     MPI_Datatype recvtype,
19     MPI_Comm comm)
20 {
21
22   if(mv2_alltoall_table_ppn_conf==NULL)
23     init_mv2_alltoall_tables_stampede();
24
25   int sendtype_size, recvtype_size, nbytes, comm_size;
26   char * tmp_buf = NULL;
27   int mpi_errno=MPI_SUCCESS;
28   int range = 0;
29   int range_threshold = 0;
30   int conf_index = 0;
31   comm_size =  smpi_comm_size(comm);
32
33   sendtype_size=smpi_datatype_size(sendtype);
34   recvtype_size=smpi_datatype_size(recvtype);
35   nbytes = sendtype_size * sendcount;
36
37   /* check if safe to use partial subscription mode */
38
39   /* Search for the corresponding system size inside the tuning table */
40   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
41       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
42       range++;
43   }
44   /* Search for corresponding inter-leader function */
45   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
46       && (nbytes >
47   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
48   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
49       range_threshold++;
50   }
51   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
52                                                                                       .MV2_pt_Alltoall_function;
53
54   if(sendbuf != MPI_IN_PLACE) {
55       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
56           recvbuf, recvcount, recvtype,
57           comm);
58   } else {
59       range_threshold = 0;
60       if(nbytes <
61           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
62           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
63       ) {
64           tmp_buf = (char *)xbt_malloc( comm_size * recvcount * recvtype_size );
65           mpi_errno = smpi_datatype_copy((char *)recvbuf,
66               comm_size*recvcount, recvtype,
67               (char *)tmp_buf,
68               comm_size*recvcount, recvtype);
69
70           mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype,
71               recvbuf, recvcount, recvtype,
72               comm );
73           xbt_free(tmp_buf);
74       } else {
75           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
76               recvbuf, recvcount, recvtype,
77               comm );
78       }
79   }
80
81
82   return (mpi_errno);
83 }
84
85
86
87 int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
88     void *recvbuf, int recvcount, MPI_Datatype recvtype,
89     MPI_Comm comm)
90 {
91
92   int mpi_errno = MPI_SUCCESS;
93   int nbytes = 0, comm_size, recvtype_size;
94   int range = 0;
95   //int partial_sub_ok = 0;
96   int conf_index = 0;
97   int range_threshold = 0;
98   int is_two_level = 0;
99   //int local_size = -1;
100   //MPI_Comm shmem_comm;
101   //MPI_Comm *shmem_commptr=NULL;
102   /* Get the size of the communicator */
103   comm_size = smpi_comm_size(comm);
104   recvtype_size=smpi_datatype_size(recvtype);
105   nbytes = recvtype_size * recvcount;
106
107   if(mv2_allgather_table_ppn_conf==NULL)
108     init_mv2_allgather_tables_stampede();
109
110   //int i;
111   /* check if safe to use partial subscription mode */
112   /*  if (comm->ch.shmem_coll_ok == 1 && comm->ch.is_uniform) {
113
114         shmem_comm = comm->ch.shmem_comm;
115         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
116         local_size = shmem_commptr->local_size;
117         i = 0;
118         if (mv2_allgather_table_ppn_conf[0] == -1) {
119             // Indicating user defined tuning
120             conf_index = 0;
121             goto conf_check_end;
122         }
123         do {
124             if (local_size == mv2_allgather_table_ppn_conf[i]) {
125                 conf_index = i;
126                 partial_sub_ok = 1;
127                 break;
128             }
129             i++;
130         } while(i < mv2_allgather_num_ppn_conf);
131     }
132
133   conf_check_end:
134     if (partial_sub_ok != 1) {
135         conf_index = 0;
136     }*/
137   /* Search for the corresponding system size inside the tuning table */
138   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
139       (comm_size >
140   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
141       range++;
142   }
143   /* Search for corresponding inter-leader function */
144   while ((range_threshold <
145       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
146       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
147       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
148           -1)) {
149       range_threshold++;
150   }
151
152   /* Set inter-leader pt */
153   MV2_Allgather_function =
154       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
155       MV2_pt_Allgather_function;
156
157   is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
158
159   /* intracommunicator */
160   if(is_two_level ==1){
161
162       /*       if(comm->ch.shmem_coll_ok == 1){
163             MPIR_T_PVAR_COUNTER_INC(MV2, mv2_num_shmem_coll_calls, 1);
164            if (1 == comm->ch.is_blocked) {
165                 mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
166                                                     recvbuf, recvcount, recvtype,
167                                                     comm, errflag);
168            }
169            else {
170                mpi_errno = MPIR_Allgather_intra(sendbuf, sendcount, sendtype,
171                                                 recvbuf, recvcount, recvtype,
172                                                 comm, errflag);
173            }
174         } else {*/
175       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
176           recvbuf, recvcount, recvtype,
177           comm);
178       //     }
179   } else if(MV2_Allgather_function == &MPIR_Allgather_Bruck_MV2
180       || MV2_Allgather_function == &MPIR_Allgather_RD_MV2
181       || MV2_Allgather_function == &MPIR_Allgather_Ring_MV2) {
182       mpi_errno = MV2_Allgather_function(sendbuf, sendcount, sendtype,
183           recvbuf, recvcount, recvtype,
184           comm);
185   }else{
186       return MPI_ERR_OTHER;
187   }
188
189   return mpi_errno;
190 }
191
192
193 int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
194     int sendcnt,
195     MPI_Datatype sendtype,
196     void *recvbuf,
197     int recvcnt,
198     MPI_Datatype recvtype,
199     int root, MPI_Comm  comm)
200 {
201   if(mv2_gather_thresholds_table==NULL)
202     init_mv2_gather_tables_stampede();
203
204   int mpi_errno = MPI_SUCCESS;
205   int range = 0;
206   int range_threshold = 0;
207   int range_intra_threshold = 0;
208   int nbytes = 0;
209   int comm_size = 0;
210   int recvtype_size, sendtype_size;
211   int rank = -1;
212   comm_size = smpi_comm_size(comm);
213   rank = smpi_comm_rank(comm);
214
215   if (rank == root) {
216       recvtype_size=smpi_datatype_size(recvtype);
217       nbytes = recvcnt * recvtype_size;
218   } else {
219       sendtype_size=smpi_datatype_size(sendtype);
220       nbytes = sendcnt * sendtype_size;
221   }
222
223   /* Search for the corresponding system size inside the tuning table */
224   while ((range < (mv2_size_gather_tuning_table - 1)) &&
225       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
226       range++;
227   }
228   /* Search for corresponding inter-leader function */
229   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
230       && (nbytes >
231   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
232   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
233       -1)) {
234       range_threshold++;
235   }
236
237   /* Search for corresponding intra node function */
238   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
239       && (nbytes >
240   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
241   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
242       -1)) {
243       range_intra_threshold++;
244   }
245   /*
246     if (comm->ch.is_global_block == 1 && mv2_use_direct_gather == 1 &&
247             mv2_use_two_level_gather == 1 && comm->ch.shmem_coll_ok == 1) {
248         // Set intra-node function pt for gather_two_level 
249         MV2_Gather_intra_node_function = 
250                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
251                               MV2_pt_Gather_function;
252         //Set inter-leader pt 
253         MV2_Gather_inter_leader_function =
254                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
255                               MV2_pt_Gather_function;
256         // We call Gather function 
257         mpi_errno =
258             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
259                                              recvtype, root, comm);
260
261     } else {*/
262   // Indded, direct (non SMP-aware)gather is MPICH one
263   mpi_errno = smpi_coll_tuned_gather_mpich(sendbuf, sendcnt, sendtype,
264       recvbuf, recvcnt, recvtype,
265       root, comm);
266   //}
267
268   return mpi_errno;
269 }
270
271
272 int smpi_coll_tuned_allgatherv_mvapich2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
273     void *recvbuf, int *recvcounts, int *displs,
274     MPI_Datatype recvtype, MPI_Comm  comm )
275 {
276   int mpi_errno = MPI_SUCCESS;
277   int range = 0, comm_size, total_count, recvtype_size, i;
278   int range_threshold = 0;
279   int nbytes = 0;
280
281   if(mv2_allgatherv_thresholds_table==NULL)
282     init_mv2_allgatherv_tables_stampede();
283
284   comm_size = smpi_comm_size(comm);
285   total_count = 0;
286   for (i = 0; i < comm_size; i++)
287     total_count += recvcounts[i];
288
289   recvtype_size=smpi_datatype_size(recvtype);
290   nbytes = total_count * recvtype_size;
291
292   /* Search for the corresponding system size inside the tuning table */
293   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
294       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
295       range++;
296   }
297   /* Search for corresponding inter-leader function */
298   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
299       && (nbytes >
300   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
301   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
302       -1)) {
303       range_threshold++;
304   }
305   /* Set inter-leader pt */
306   MV2_Allgatherv_function =
307       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
308       MV2_pt_Allgatherv_function;
309
310   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
311     {
312       if(!(comm_size & (comm_size - 1)))
313         {
314           mpi_errno =
315               MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount,
316                   sendtype, recvbuf,
317                   recvcounts, displs,
318                   recvtype, comm);
319         } else {
320             mpi_errno =
321                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
322                     sendtype, recvbuf,
323                     recvcounts, displs,
324                     recvtype, comm);
325         }
326     } else {
327         mpi_errno =
328             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
329                 recvbuf, recvcounts, displs,
330                 recvtype, comm);
331     }
332
333   return mpi_errno;
334 }
335
336
337
338 int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
339     void *recvbuf,
340     int count,
341     MPI_Datatype datatype,
342     MPI_Op op, MPI_Comm comm)
343 {
344
345   int mpi_errno = MPI_SUCCESS;
346   //int rank = 0,
347   int comm_size = 0;
348
349   comm_size = smpi_comm_size(comm);
350   //rank = smpi_comm_rank(comm);
351
352   if (count == 0) {
353       return MPI_SUCCESS;
354   }
355
356   if (mv2_allreduce_thresholds_table == NULL)
357     init_mv2_allreduce_tables_stampede();
358
359   /* check if multiple threads are calling this collective function */
360
361   MPI_Aint sendtype_size = 0;
362   int nbytes = 0;
363   int range = 0, range_threshold = 0, range_threshold_intra = 0;
364   int is_two_level = 0;
365   //int is_commutative = 0;
366   MPI_Aint true_lb, true_extent;
367
368   sendtype_size=smpi_datatype_size(datatype);
369   nbytes = count * sendtype_size;
370
371   smpi_datatype_extent(datatype, &true_lb, &true_extent);
372   //MPI_Op *op_ptr;
373   //is_commutative = smpi_op_is_commute(op);
374
375   {
376     /* Search for the corresponding system size inside the tuning table */
377     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
378         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
379         range++;
380     }
381     /* Search for corresponding inter-leader function */
382     /* skip mcast poiters if mcast is not available */
383     if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
384         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
385             && ((mv2_allreduce_thresholds_table[range].
386                 inter_leader[range_threshold].MV2_pt_Allreduce_function
387                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
388                 (mv2_allreduce_thresholds_table[range].
389                     inter_leader[range_threshold].MV2_pt_Allreduce_function
390                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
391             )) {
392             range_threshold++;
393         }
394     }
395     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
396         && (nbytes >
397     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
398     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
399         range_threshold++;
400     }
401     if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
402         is_two_level = 1;
403     }
404     /* Search for corresponding intra-node function */
405     while ((range_threshold_intra <
406         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
407         && (nbytes >
408     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
409     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
410         -1)) {
411         range_threshold_intra++;
412     }
413
414     MV2_Allreduce_function = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
415                                                                                 .MV2_pt_Allreduce_function;
416
417     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
418                                                                                     .MV2_pt_Allreduce_function;
419
420     /* check if mcast is ready, otherwise replace mcast with other algorithm */
421     if((MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
422         (MV2_Allreduce_function == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
423         {
424           MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
425         }
426         if(is_two_level != 1) {
427             MV2_Allreduce_function = &MPIR_Allreduce_pt2pt_rd_MV2;
428         }
429     }
430
431     if(is_two_level == 1){
432         // check if shm is ready, if not use other algorithm first
433         /*if ((comm->ch.shmem_coll_ok == 1)
434                     && (mv2_enable_shmem_allreduce)
435                     && (is_commutative)
436                     && (mv2_enable_shmem_collectives)) {
437                     mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
438                                                      datatype, op, comm);
439                 } else {*/
440         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
441             datatype, op, comm);
442         // }
443     } else {
444         mpi_errno = MV2_Allreduce_function(sendbuf, recvbuf, count,
445             datatype, op, comm);
446     }
447   }
448
449   //comm->ch.intra_node_done=0;
450
451   return (mpi_errno);
452
453
454 }
455
456
457 int smpi_coll_tuned_alltoallv_mvapich2(void *sbuf, int *scounts, int *sdisps,
458     MPI_Datatype sdtype,
459     void *rbuf, int *rcounts, int *rdisps,
460     MPI_Datatype rdtype,
461     MPI_Comm  comm
462 )
463 {
464
465   if (sbuf == MPI_IN_PLACE) {
466       return smpi_coll_tuned_alltoallv_ompi_basic_linear(sbuf, scounts, sdisps, sdtype,
467           rbuf, rcounts, rdisps,rdtype,
468           comm);
469   } else     /* For starters, just keep the original algorithm. */
470   return smpi_coll_tuned_alltoallv_ring(sbuf, scounts, sdisps, sdtype,
471       rbuf, rcounts, rdisps,rdtype,
472       comm);
473 }
474
475
476 int smpi_coll_tuned_barrier_mvapich2(MPI_Comm  comm)
477 {   
478   return smpi_coll_tuned_barrier_mvapich2_pair(comm);
479 }
480
481
482
483
484 int smpi_coll_tuned_bcast_mvapich2(void *buffer,
485     int count,
486     MPI_Datatype datatype,
487     int root, MPI_Comm comm)
488 {
489
490   //TODO : Bcast really needs intra/inter phases in mvapich. Default to mpich if not available
491   return smpi_coll_tuned_bcast_mpich(buffer, count, datatype, root, comm);
492
493 }
494
495
496
497 int smpi_coll_tuned_reduce_mvapich2( void *sendbuf,
498     void *recvbuf,
499     int count,
500     MPI_Datatype datatype,
501     MPI_Op op, int root, MPI_Comm comm)
502 {
503   if(mv2_reduce_thresholds_table == NULL)
504     init_mv2_reduce_tables_stampede();
505
506   int mpi_errno = MPI_SUCCESS;
507   int range = 0;
508   int range_threshold = 0;
509   int range_intra_threshold = 0;
510   int is_commutative, pof2;
511   int comm_size = 0;
512   int nbytes = 0;
513   int sendtype_size;
514   int is_two_level = 0;
515
516   comm_size = smpi_comm_size(comm);
517   sendtype_size=smpi_datatype_size(datatype);
518   nbytes = count * sendtype_size;
519
520   if (count == 0)
521     return MPI_SUCCESS;
522
523   is_commutative = smpi_op_is_commute(op);
524
525   /* find nearest power-of-two less than or equal to comm_size */
526   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
527   pof2 >>=1;
528
529
530   /* Search for the corresponding system size inside the tuning table */
531   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
532       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
533       range++;
534   }
535   /* Search for corresponding inter-leader function */
536   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
537       && (nbytes >
538   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
539   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
540       -1)) {
541       range_threshold++;
542   }
543
544   /* Search for corresponding intra node function */
545   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
546       && (nbytes >
547   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
548   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
549       -1)) {
550       range_intra_threshold++;
551   }
552
553   /* Set intra-node function pt for reduce_two_level */
554   MV2_Reduce_intra_function =
555       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
556       MV2_pt_Reduce_function;
557   /* Set inter-leader pt */
558   MV2_Reduce_function =
559       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
560       MV2_pt_Reduce_function;
561
562   if(mv2_reduce_intra_knomial_factor<0)
563     {
564       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
565     }
566   if(mv2_reduce_inter_knomial_factor<0)
567     {
568       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
569     }
570   if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
571       is_two_level = 1;
572   }
573   /* We call Reduce function */
574   if(is_two_level == 1)
575     {
576       /* if (comm->ch.shmem_coll_ok == 1
577             && is_commutative == 1) {
578             mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
579                                            datatype, op, root, comm, errflag);
580         } else {*/
581       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
582           datatype, op, root, comm);
583       //}
584     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
585         if(is_commutative ==1)
586           {
587             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
588                 datatype, op, root, comm);
589           } else {
590               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
591                   datatype, op, root, comm);
592           }
593     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
594         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
595           {
596             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
597                 datatype, op, root, comm);
598           } else {
599               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
600                   datatype, op, root, comm);
601           }
602     } else {
603         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
604             datatype, op, root, comm);
605     }
606
607
608   return mpi_errno;
609
610 }
611
612
613 int smpi_coll_tuned_reduce_scatter_mvapich2(void *sendbuf, void *recvbuf, int *recvcnts,
614     MPI_Datatype datatype, MPI_Op op,
615     MPI_Comm comm)
616 {
617   int mpi_errno = MPI_SUCCESS;
618   int i = 0, comm_size = smpi_comm_size(comm), total_count = 0, type_size =
619       0, nbytes = 0;
620   int range = 0;
621   int range_threshold = 0;
622   int is_commutative = 0;
623   int *disps = xbt_malloc(comm_size * sizeof (int));
624
625   if(mv2_red_scat_thresholds_table==NULL)
626     init_mv2_reduce_scatter_tables_stampede();
627
628   is_commutative=smpi_op_is_commute(op);
629   for (i = 0; i < comm_size; i++) {
630       disps[i] = total_count;
631       total_count += recvcnts[i];
632   }
633
634   type_size=smpi_datatype_size(datatype);
635   nbytes = total_count * type_size;
636
637   if (is_commutative) {
638
639       /* Search for the corresponding system size inside the tuning table */
640       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
641           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
642           range++;
643       }
644       /* Search for corresponding inter-leader function */
645       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
646           && (nbytes >
647       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
648       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
649           -1)) {
650           range_threshold++;
651       }
652
653       /* Set inter-leader pt */
654       MV2_Red_scat_function =
655           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
656           MV2_pt_Red_scat_function;
657
658       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
659           recvcnts, datatype,
660           op, comm);
661   } else {
662       int is_block_regular = 1;
663       for (i = 0; i < (comm_size - 1); ++i) {
664           if (recvcnts[i] != recvcnts[i+1]) {
665               is_block_regular = 0;
666               break;
667           }
668       }
669       int pof2 = 1;
670       while (pof2 < comm_size) pof2 <<= 1;
671       if (pof2 == comm_size && is_block_regular) {
672           /* noncommutative, pof2 size, and block regular */
673           mpi_errno = MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
674               recvcnts, datatype,
675               op, comm);
676       }
677       mpi_errno =  smpi_coll_tuned_reduce_scatter_mpich_rdb(sendbuf, recvbuf,
678           recvcnts, datatype,
679           op, comm);
680   }
681
682   return mpi_errno;
683
684 }
685
686
687
688 int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
689     int sendcnt,
690     MPI_Datatype sendtype,
691     void *recvbuf,
692     int recvcnt,
693     MPI_Datatype recvtype,
694     int root, MPI_Comm comm_ptr)
695 {
696   int range = 0, range_threshold = 0, range_threshold_intra = 0;
697   int mpi_errno = MPI_SUCCESS;
698   //   int mpi_errno_ret = MPI_SUCCESS;
699   int rank, nbytes, comm_size;
700   int recvtype_size, sendtype_size;
701   int partial_sub_ok = 0;
702   int conf_index = 0;
703   //  int local_size = -1;
704   //  int i;
705   //   MPI_Comm shmem_comm;
706   //    MPID_Comm *shmem_commptr=NULL;
707   if(mv2_scatter_thresholds_table==NULL)
708     init_mv2_scatter_tables_stampede();
709
710   comm_size = smpi_comm_size(comm_ptr);
711
712   rank = smpi_comm_rank(comm_ptr);
713
714   if (rank == root) {
715       sendtype_size=smpi_datatype_size(sendtype);
716       nbytes = sendcnt * sendtype_size;
717   } else {
718       recvtype_size=smpi_datatype_size(recvtype);
719       nbytes = recvcnt * recvtype_size;
720   }
721   /*
722     // check if safe to use partial subscription mode 
723     if (comm_ptr->ch.shmem_coll_ok == 1 && comm_ptr->ch.is_uniform) {
724
725         shmem_comm = comm_ptr->ch.shmem_comm;
726         MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
727         local_size = shmem_commptr->local_size;
728         i = 0;
729         if (mv2_scatter_table_ppn_conf[0] == -1) {
730             // Indicating user defined tuning 
731             conf_index = 0;
732             goto conf_check_end;
733         }
734         do {
735             if (local_size == mv2_scatter_table_ppn_conf[i]) {
736                 conf_index = i;
737                 partial_sub_ok = 1;
738                 break;
739             }
740             i++;
741         } while(i < mv2_scatter_num_ppn_conf);
742     }
743    */
744   if (partial_sub_ok != 1) {
745       conf_index = 0;
746   }
747
748   /* Search for the corresponding system size inside the tuning table */
749   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
750       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
751       range++;
752   }
753   /* Search for corresponding inter-leader function */
754   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
755       && (nbytes >
756   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
757   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
758       range_threshold++;
759   }
760
761   /* Search for corresponding intra-node function */
762   while ((range_threshold_intra <
763       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
764       && (nbytes >
765   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
766   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
767       -1)) {
768       range_threshold_intra++;
769   }
770
771   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
772                                                                                       .MV2_pt_Scatter_function;
773
774   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
775 #if defined(_MCST_SUPPORT_)
776       if(comm_ptr->ch.is_mcast_ok == 1
777           && mv2_use_mcast_scatter == 1
778           && comm_ptr->ch.shmem_coll_ok == 1) {
779           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
780       } else
781 #endif /*#if defined(_MCST_SUPPORT_) */
782         {
783           if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
784               MV2_pt_Scatter_function != NULL) {
785               MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
786                                                                                                   .MV2_pt_Scatter_function;
787           } else {
788               /* Fallback! */
789               MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
790           }
791         } 
792   }
793
794   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
795       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
796       /* if( comm_ptr->ch.shmem_coll_ok == 1 &&
797              comm_ptr->ch.is_global_block == 1 ) {
798              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
799                                 .MV2_pt_Scatter_function;
800
801              mpi_errno =
802                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
803                                         recvbuf, recvcnt, recvtype, root,
804                                         comm_ptr);
805          } else {*/
806       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
807           recvbuf, recvcnt, recvtype, root,
808           comm_ptr);
809
810       //}
811   } else {
812       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
813           recvbuf, recvcnt, recvtype, root,
814           comm_ptr);
815   }
816   return (mpi_errno);
817 }
818