Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Introduce smpi::Process
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.cpp
1 /* selector for collective algorithms based on mvapich decision logic */
2
3 /* Copyright (c) 2009-2010, 2013-2014. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include "colls_private.h"
10
11 #include "smpi_mvapich2_selector_stampede.h"
12
13
14 namespace simgrid{
15 namespace smpi{
16
17
18 int Coll_alltoall_mvapich2::alltoall( void *sendbuf, int sendcount, 
19     MPI_Datatype sendtype,
20     void* recvbuf, int recvcount,
21     MPI_Datatype recvtype,
22     MPI_Comm comm)
23 {
24
25   if(mv2_alltoall_table_ppn_conf==NULL)
26     init_mv2_alltoall_tables_stampede();
27
28   int sendtype_size, recvtype_size, comm_size;
29   char * tmp_buf = NULL;
30   int mpi_errno=MPI_SUCCESS;
31   int range = 0;
32   int range_threshold = 0;
33   int conf_index = 0;
34   comm_size =  comm->size();
35
36   sendtype_size=sendtype->size();
37   recvtype_size=recvtype->size();
38   long nbytes = sendtype_size * sendcount;
39
40   /* check if safe to use partial subscription mode */
41
42   /* Search for the corresponding system size inside the tuning table */
43   while ((range < (mv2_size_alltoall_tuning_table[conf_index] - 1)) &&
44       (comm_size > mv2_alltoall_thresholds_table[conf_index][range].numproc)) {
45       range++;
46   }
47   /* Search for corresponding inter-leader function */
48   while ((range_threshold < (mv2_alltoall_thresholds_table[conf_index][range].size_table - 1))
49       && (nbytes >
50   mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max)
51   && (mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold].max != -1)) {
52       range_threshold++;
53   }
54   MV2_Alltoall_function = mv2_alltoall_thresholds_table[conf_index][range].algo_table[range_threshold]
55                                                                                       .MV2_pt_Alltoall_function;
56
57   if(sendbuf != MPI_IN_PLACE) {
58       mpi_errno = MV2_Alltoall_function(sendbuf, sendcount, sendtype,
59           recvbuf, recvcount, recvtype,
60           comm);
61   } else {
62       range_threshold = 0;
63       if(nbytes <
64           mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].min
65           ||nbytes > mv2_alltoall_thresholds_table[conf_index][range].in_place_algo_table[range_threshold].max
66       ) {
67           tmp_buf = (char *)smpi_get_tmp_sendbuffer( comm_size * recvcount * recvtype_size );
68           mpi_errno = Datatype::copy((char *)recvbuf,
69               comm_size*recvcount, recvtype,
70               (char *)tmp_buf,
71               comm_size*recvcount, recvtype);
72
73           mpi_errno = MV2_Alltoall_function(tmp_buf, recvcount, recvtype,
74               recvbuf, recvcount, recvtype,
75               comm );
76           smpi_free_tmp_buffer(tmp_buf);
77       } else {
78           mpi_errno = MPIR_Alltoall_inplace_MV2(sendbuf, sendcount, sendtype,
79               recvbuf, recvcount, recvtype,
80               comm );
81       }
82   }
83
84
85   return (mpi_errno);
86 }
87
88 int Coll_allgather_mvapich2::allgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
89     void *recvbuf, int recvcount, MPI_Datatype recvtype,
90     MPI_Comm comm)
91 {
92
93   int mpi_errno = MPI_SUCCESS;
94   long nbytes = 0, comm_size, recvtype_size;
95   int range = 0;
96   int partial_sub_ok = 0;
97   int conf_index = 0;
98   int range_threshold = 0;
99   int is_two_level = 0;
100   int local_size = -1;
101   MPI_Comm shmem_comm;
102   //MPI_Comm *shmem_commptr=NULL;
103   /* Get the size of the communicator */
104   comm_size = comm->size();
105   recvtype_size=recvtype->size();
106   nbytes = recvtype_size * recvcount;
107
108   if(mv2_allgather_table_ppn_conf==NULL)
109     init_mv2_allgather_tables_stampede();
110     
111   if(comm->get_leaders_comm()==MPI_COMM_NULL){
112     comm->init_smp();
113   }
114
115   int i;
116   if (comm->is_uniform()){
117     shmem_comm = comm->get_intra_comm();
118     local_size = shmem_comm->size();
119     i = 0;
120     if (mv2_allgather_table_ppn_conf[0] == -1) {
121       // Indicating user defined tuning
122       conf_index = 0;
123       goto conf_check_end;
124     }
125     do {
126       if (local_size == mv2_allgather_table_ppn_conf[i]) {
127         conf_index = i;
128         partial_sub_ok = 1;
129         break;
130       }
131       i++;
132     } while(i < mv2_allgather_num_ppn_conf);
133   }
134   conf_check_end:
135   if (partial_sub_ok != 1) {
136     conf_index = 0;
137   }
138   
139   /* Search for the corresponding system size inside the tuning table */
140   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
141       (comm_size >
142   mv2_allgather_thresholds_table[conf_index][range].numproc)) {
143       range++;
144   }
145   /* Search for corresponding inter-leader function */
146   while ((range_threshold <
147       (mv2_allgather_thresholds_table[conf_index][range].size_inter_table - 1))
148       && (nbytes > mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
149       && (mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].max !=
150           -1)) {
151       range_threshold++;
152   }
153
154   /* Set inter-leader pt */
155   MV2_Allgatherction =
156       mv2_allgather_thresholds_table[conf_index][range].inter_leader[range_threshold].
157       MV2_pt_Allgatherction;
158
159   is_two_level =  mv2_allgather_thresholds_table[conf_index][range].two_level[range_threshold];
160
161   /* intracommunicator */
162   if(is_two_level ==1){
163     if(partial_sub_ok ==1){
164       if (comm->is_blocked()){
165       mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
166                             recvbuf, recvcount, recvtype,
167                             comm);
168       }else{
169       mpi_errno = Coll_allgather_mpich::allgather(sendbuf, sendcount, sendtype,
170                             recvbuf, recvcount, recvtype,
171                             comm);
172       }
173     } else {
174       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
175           recvbuf, recvcount, recvtype,
176           comm);
177     }
178   } else if(MV2_Allgatherction == &MPIR_Allgather_Bruck_MV2
179       || MV2_Allgatherction == &MPIR_Allgather_RD_MV2
180       || MV2_Allgatherction == &MPIR_Allgather_Ring_MV2) {
181       mpi_errno = MV2_Allgatherction(sendbuf, sendcount, sendtype,
182           recvbuf, recvcount, recvtype,
183           comm);
184   }else{
185       return MPI_ERR_OTHER;
186   }
187
188   return mpi_errno;
189 }
190
191 int Coll_gather_mvapich2::gather(void *sendbuf,
192     int sendcnt,
193     MPI_Datatype sendtype,
194     void *recvbuf,
195     int recvcnt,
196     MPI_Datatype recvtype,
197     int root, MPI_Comm  comm)
198 {
199   if(mv2_gather_thresholds_table==NULL)
200     init_mv2_gather_tables_stampede();
201
202   int mpi_errno = MPI_SUCCESS;
203   int range = 0;
204   int range_threshold = 0;
205   int range_intra_threshold = 0;
206   long nbytes = 0;
207   int comm_size = 0;
208   int recvtype_size, sendtype_size;
209   int rank = -1;
210   comm_size = comm->size();
211   rank = comm->rank();
212
213   if (rank == root) {
214       recvtype_size=recvtype->size();
215       nbytes = recvcnt * recvtype_size;
216   } else {
217       sendtype_size=sendtype->size();
218       nbytes = sendcnt * sendtype_size;
219   }
220
221   /* Search for the corresponding system size inside the tuning table */
222   while ((range < (mv2_size_gather_tuning_table - 1)) &&
223       (comm_size > mv2_gather_thresholds_table[range].numproc)) {
224       range++;
225   }
226   /* Search for corresponding inter-leader function */
227   while ((range_threshold < (mv2_gather_thresholds_table[range].size_inter_table - 1))
228       && (nbytes >
229   mv2_gather_thresholds_table[range].inter_leader[range_threshold].max)
230   && (mv2_gather_thresholds_table[range].inter_leader[range_threshold].max !=
231       -1)) {
232       range_threshold++;
233   }
234
235   /* Search for corresponding intra node function */
236   while ((range_intra_threshold < (mv2_gather_thresholds_table[range].size_intra_table - 1))
237       && (nbytes >
238   mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max)
239   && (mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].max !=
240       -1)) {
241       range_intra_threshold++;
242   }
243   
244     if (comm->is_blocked() ) {
245         // Set intra-node function pt for gather_two_level 
246         MV2_Gather_intra_node_function = 
247                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
248                               MV2_pt_Gather_function;
249         //Set inter-leader pt 
250         MV2_Gather_inter_leader_function =
251                               mv2_gather_thresholds_table[range].inter_leader[range_threshold].
252                               MV2_pt_Gather_function;
253         // We call Gather function 
254         mpi_errno =
255             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
256                                              recvtype, root, comm);
257
258     } else {
259   // Indeed, direct (non SMP-aware)gather is MPICH one
260   mpi_errno = Coll_gather_mpich::gather(sendbuf, sendcnt, sendtype,
261       recvbuf, recvcnt, recvtype,
262       root, comm);
263   }
264
265   return mpi_errno;
266 }
267
268 int Coll_allgatherv_mvapich2::allgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
269     void *recvbuf, int *recvcounts, int *displs,
270     MPI_Datatype recvtype, MPI_Comm  comm )
271 {
272   int mpi_errno = MPI_SUCCESS;
273   int range = 0, comm_size, total_count, recvtype_size, i;
274   int range_threshold = 0;
275   long nbytes = 0;
276
277   if(mv2_allgatherv_thresholds_table==NULL)
278     init_mv2_allgatherv_tables_stampede();
279
280   comm_size = comm->size();
281   total_count = 0;
282   for (i = 0; i < comm_size; i++)
283     total_count += recvcounts[i];
284
285   recvtype_size=recvtype->size();
286   nbytes = total_count * recvtype_size;
287
288   /* Search for the corresponding system size inside the tuning table */
289   while ((range < (mv2_size_allgatherv_tuning_table - 1)) &&
290       (comm_size > mv2_allgatherv_thresholds_table[range].numproc)) {
291       range++;
292   }
293   /* Search for corresponding inter-leader function */
294   while ((range_threshold < (mv2_allgatherv_thresholds_table[range].size_inter_table - 1))
295       && (nbytes >
296   comm_size * mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max)
297   && (mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].max !=
298       -1)) {
299       range_threshold++;
300   }
301   /* Set inter-leader pt */
302   MV2_Allgatherv_function =
303       mv2_allgatherv_thresholds_table[range].inter_leader[range_threshold].
304       MV2_pt_Allgatherv_function;
305
306   if (MV2_Allgatherv_function == &MPIR_Allgatherv_Rec_Doubling_MV2)
307     {
308       if(!(comm_size & (comm_size - 1)))
309         {
310           mpi_errno =
311               MPIR_Allgatherv_Rec_Doubling_MV2(sendbuf, sendcount,
312                   sendtype, recvbuf,
313                   recvcounts, displs,
314                   recvtype, comm);
315         } else {
316             mpi_errno =
317                 MPIR_Allgatherv_Bruck_MV2(sendbuf, sendcount,
318                     sendtype, recvbuf,
319                     recvcounts, displs,
320                     recvtype, comm);
321         }
322     } else {
323         mpi_errno =
324             MV2_Allgatherv_function(sendbuf, sendcount, sendtype,
325                 recvbuf, recvcounts, displs,
326                 recvtype, comm);
327     }
328
329   return mpi_errno;
330 }
331
332
333
334 int Coll_allreduce_mvapich2::allreduce(void *sendbuf,
335     void *recvbuf,
336     int count,
337     MPI_Datatype datatype,
338     MPI_Op op, MPI_Comm comm)
339 {
340
341   int mpi_errno = MPI_SUCCESS;
342   //int rank = 0,
343   int comm_size = 0;
344
345   comm_size = comm->size();
346   //rank = comm->rank();
347
348   if (count == 0) {
349       return MPI_SUCCESS;
350   }
351
352   if (mv2_allreduce_thresholds_table == NULL)
353     init_mv2_allreduce_tables_stampede();
354
355   /* check if multiple threads are calling this collective function */
356
357   MPI_Aint sendtype_size = 0;
358   long nbytes = 0;
359   int range = 0, range_threshold = 0, range_threshold_intra = 0;
360   int is_two_level = 0;
361   int is_commutative = 0;
362   MPI_Aint true_lb, true_extent;
363
364   sendtype_size=datatype->size();
365   nbytes = count * sendtype_size;
366
367   datatype->extent(&true_lb, &true_extent);
368   //MPI_Op *op_ptr;
369   //is_commutative = op->is_commutative();
370
371   {
372     /* Search for the corresponding system size inside the tuning table */
373     while ((range < (mv2_size_allreduce_tuning_table - 1)) &&
374         (comm_size > mv2_allreduce_thresholds_table[range].numproc)) {
375         range++;
376     }
377     /* Search for corresponding inter-leader function */
378     /* skip mcast poiters if mcast is not available */
379     if(mv2_allreduce_thresholds_table[range].mcast_enabled != 1){
380         while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
381             && ((mv2_allreduce_thresholds_table[range].
382                 inter_leader[range_threshold].MV2_pt_Allreducection
383                 == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2) ||
384                 (mv2_allreduce_thresholds_table[range].
385                     inter_leader[range_threshold].MV2_pt_Allreducection
386                     == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)
387             )) {
388             range_threshold++;
389         }
390     }
391     while ((range_threshold < (mv2_allreduce_thresholds_table[range].size_inter_table - 1))
392         && (nbytes >
393     mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max)
394     && (mv2_allreduce_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
395         range_threshold++;
396     }
397     if(mv2_allreduce_thresholds_table[range].is_two_level_allreduce[range_threshold] == 1){
398         is_two_level = 1;
399     }
400     /* Search for corresponding intra-node function */
401     while ((range_threshold_intra <
402         (mv2_allreduce_thresholds_table[range].size_intra_table - 1))
403         && (nbytes >
404     mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max)
405     && (mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra].max !=
406         -1)) {
407         range_threshold_intra++;
408     }
409
410     MV2_Allreducection = mv2_allreduce_thresholds_table[range].inter_leader[range_threshold]
411                                                                                 .MV2_pt_Allreducection;
412
413     MV2_Allreduce_intra_function = mv2_allreduce_thresholds_table[range].intra_node[range_threshold_intra]
414                                                                                     .MV2_pt_Allreducection;
415
416     /* check if mcast is ready, otherwise replace mcast with other algorithm */
417     if((MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_redscat_gather_MV2)||
418         (MV2_Allreducection == &MPIR_Allreduce_mcst_reduce_two_level_helper_MV2)){
419         {
420           MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
421         }
422         if(is_two_level != 1) {
423             MV2_Allreducection = &MPIR_Allreduce_pt2pt_rd_MV2;
424         }
425     }
426
427     if(is_two_level == 1){
428         // check if shm is ready, if not use other algorithm first
429         if (is_commutative) {
430           if(comm->get_leaders_comm()==MPI_COMM_NULL){
431             comm->init_smp();
432           }
433           mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
434                                                      datatype, op, comm);
435                 } else {
436         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
437             datatype, op, comm);
438         }
439     } else {
440         mpi_errno = MV2_Allreducection(sendbuf, recvbuf, count,
441             datatype, op, comm);
442     }
443   }
444
445   //comm->ch.intra_node_done=0;
446
447   return (mpi_errno);
448
449
450 }
451
452
453 int Coll_alltoallv_mvapich2::alltoallv(void *sbuf, int *scounts, int *sdisps,
454     MPI_Datatype sdtype,
455     void *rbuf, int *rcounts, int *rdisps,
456     MPI_Datatype rdtype,
457     MPI_Comm  comm
458 )
459 {
460
461   if (sbuf == MPI_IN_PLACE) {
462       return Coll_alltoallv_ompi_basic_linear::alltoallv(sbuf, scounts, sdisps, sdtype,
463           rbuf, rcounts, rdisps,rdtype,
464           comm);
465   } else     /* For starters, just keep the original algorithm. */
466   return Coll_alltoallv_ring::alltoallv(sbuf, scounts, sdisps, sdtype,
467       rbuf, rcounts, rdisps,rdtype,
468       comm);
469 }
470
471
472 int Coll_barrier_mvapich2::barrier(MPI_Comm  comm)
473 {   
474   return Coll_barrier_mvapich2_pair::barrier(comm);
475 }
476
477
478
479
480 int Coll_bcast_mvapich2::bcast(void *buffer,
481     int count,
482     MPI_Datatype datatype,
483     int root, MPI_Comm comm)
484 {
485     int mpi_errno = MPI_SUCCESS;
486     int comm_size/*, rank*/;
487     int two_level_bcast = 1;
488     long nbytes = 0; 
489     int range = 0;
490     int range_threshold = 0;
491     int range_threshold_intra = 0;
492     int is_homogeneous, is_contig;
493     MPI_Aint type_size;
494     //, position;
495     void *tmp_buf = NULL;
496     MPI_Comm shmem_comm;
497     //MPID_Datatype *dtp;
498
499     if (count == 0)
500         return MPI_SUCCESS;
501     if(comm->get_leaders_comm()==MPI_COMM_NULL){
502       comm->init_smp();
503     }
504     if(!mv2_bcast_thresholds_table)
505       init_mv2_bcast_tables_stampede();
506     comm_size = comm->size();
507     //rank = comm->rank();
508
509     is_contig=1;
510 /*    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)*/
511 /*        is_contig = 1;*/
512 /*    else {*/
513 /*        MPID_Datatype_get_ptr(datatype, dtp);*/
514 /*        is_contig = dtp->is_contig;*/
515 /*    }*/
516
517     is_homogeneous = 1;
518
519     /* MPI_Type_size() might not give the accurate size of the packed
520      * datatype for heterogeneous systems (because of padding, encoding,
521      * etc). On the other hand, MPI_Pack_size() can become very
522      * expensive, depending on the implementation, especially for
523      * heterogeneous systems. We want to use MPI_Type_size() wherever
524      * possible, and MPI_Pack_size() in other places.
525      */
526     //if (is_homogeneous) {
527         type_size=datatype->size();
528
529    /* } else {
530         MPIR_Pack_size_impl(1, datatype, &type_size);
531     }*/
532     nbytes =  (count) * (type_size);
533
534     /* Search for the corresponding system size inside the tuning table */
535     while ((range < (mv2_size_bcast_tuning_table - 1)) &&
536            (comm_size > mv2_bcast_thresholds_table[range].numproc)) {
537         range++;
538     }
539     /* Search for corresponding inter-leader function */
540     while ((range_threshold < (mv2_bcast_thresholds_table[range].size_inter_table - 1))
541            && (nbytes >
542                mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max)
543            && (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].max != -1)) {
544         range_threshold++;
545     }
546
547     /* Search for corresponding intra-node function */
548     while ((range_threshold_intra <
549             (mv2_bcast_thresholds_table[range].size_intra_table - 1))
550            && (nbytes >
551                mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max)
552            && (mv2_bcast_thresholds_table[range].intra_node[range_threshold_intra].max !=
553                -1)) {
554         range_threshold_intra++;
555     }
556
557     MV2_Bcast_function =
558         mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
559         MV2_pt_Bcast_function;
560
561     MV2_Bcast_intra_node_function =
562         mv2_bcast_thresholds_table[range].
563         intra_node[range_threshold_intra].MV2_pt_Bcast_function;
564
565 /*    if (mv2_user_bcast_intra == NULL && */
566 /*            MV2_Bcast_intra_node_function == &MPIR_Knomial_Bcast_intra_node_MV2) {*/
567 /*            MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;*/
568 /*    }*/
569
570     if (mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
571         zcpy_pipelined_knomial_factor != -1) {
572         zcpy_knomial_factor = 
573             mv2_bcast_thresholds_table[range].inter_leader[range_threshold].
574             zcpy_pipelined_knomial_factor;
575     }
576
577     if (mv2_pipelined_zcpy_knomial_factor != -1) {
578         zcpy_knomial_factor = mv2_pipelined_zcpy_knomial_factor;
579     }
580
581     if(MV2_Bcast_intra_node_function == NULL) {
582         /* if tuning table do not have any intra selection, set func pointer to
583         ** default one for mcast intra node */
584         MV2_Bcast_intra_node_function = &MPIR_Shmem_Bcast_MV2;
585     }
586
587     /* Set value of pipeline segment size */
588     bcast_segment_size = mv2_bcast_thresholds_table[range].bcast_segment_size;
589     
590     /* Set value of inter node knomial factor */
591     mv2_inter_node_knomial_factor = mv2_bcast_thresholds_table[range].inter_node_knomial_factor;
592
593     /* Set value of intra node knomial factor */
594     mv2_intra_node_knomial_factor = mv2_bcast_thresholds_table[range].intra_node_knomial_factor;
595
596     /* Check if we will use a two level algorithm or not */
597     two_level_bcast =
598 #if defined(_MCST_SUPPORT_)
599         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold] 
600         || comm->ch.is_mcast_ok;
601 #else
602         mv2_bcast_thresholds_table[range].is_two_level_bcast[range_threshold];
603 #endif
604      if (two_level_bcast == 1) {
605         if (!is_contig || !is_homogeneous) {
606             tmp_buf=(void *)smpi_get_tmp_sendbuffer(nbytes);
607
608 /*            position = 0;*/
609 /*            if (rank == root) {*/
610 /*                mpi_errno =*/
611 /*                    MPIR_Pack_impl(buffer, count, datatype, tmp_buf, nbytes, &position);*/
612 /*                if (mpi_errno)*/
613 /*                    MPIU_ERR_POP(mpi_errno);*/
614 /*            }*/
615         }
616 #ifdef CHANNEL_MRAIL_GEN2
617         if ((mv2_enable_zcpy_bcast == 1) &&
618               (&MPIR_Pipelined_Bcast_Zcpy_MV2 == MV2_Bcast_function)) {  
619             if (!is_contig || !is_homogeneous) {
620                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(tmp_buf, nbytes, MPI_BYTE,
621                                                  root, comm);
622             } else { 
623                 mpi_errno = MPIR_Pipelined_Bcast_Zcpy_MV2(buffer, count, datatype,
624                                                  root, comm);
625             } 
626         } else 
627 #endif /* defined(CHANNEL_MRAIL_GEN2) */
628         { 
629             shmem_comm = comm->get_intra_comm();
630             if (!is_contig || !is_homogeneous) {
631                 mpi_errno =
632                     MPIR_Bcast_tune_inter_node_helper_MV2(tmp_buf, nbytes, MPI_BYTE,
633                                                           root, comm);
634             } else {
635                 mpi_errno =
636                     MPIR_Bcast_tune_inter_node_helper_MV2(buffer, count, datatype, root,
637                                                           comm);
638             }
639
640             /* We are now done with the inter-node phase */
641
642
643                     root = INTRA_NODE_ROOT;
644    
645
646                 if (!is_contig || !is_homogeneous) {
647                     mpi_errno = MV2_Bcast_intra_node_function(tmp_buf, nbytes,
648                                                               MPI_BYTE, root, shmem_comm);
649                 } else {
650                     mpi_errno = MV2_Bcast_intra_node_function(buffer, count,
651                                                               datatype, root, shmem_comm);
652
653                 }
654         } 
655 /*        if (!is_contig || !is_homogeneous) {*/
656 /*            if (rank != root) {*/
657 /*                position = 0;*/
658 /*                mpi_errno = MPIR_Unpack_impl(tmp_buf, nbytes, &position, buffer,*/
659 /*                                             count, datatype);*/
660 /*            }*/
661 /*        }*/
662     } else {
663         /* We use Knomial for intra node */
664         MV2_Bcast_intra_node_function = &MPIR_Knomial_Bcast_intra_node_MV2;
665 /*        if (mv2_enable_shmem_bcast == 0) {*/
666             /* Fall back to non-tuned version */
667 /*            MPIR_Bcast_intra_MV2(buffer, count, datatype, root, comm);*/
668 /*        } else {*/
669             mpi_errno = MV2_Bcast_function(buffer, count, datatype, root,
670                                            comm);
671
672 /*        }*/
673     }
674
675
676     return mpi_errno;
677
678 }
679
680
681
682 int Coll_reduce_mvapich2::reduce( void *sendbuf,
683     void *recvbuf,
684     int count,
685     MPI_Datatype datatype,
686     MPI_Op op, int root, MPI_Comm comm)
687 {
688   if(mv2_reduce_thresholds_table == NULL)
689     init_mv2_reduce_tables_stampede();
690
691   int mpi_errno = MPI_SUCCESS;
692   int range = 0;
693   int range_threshold = 0;
694   int range_intra_threshold = 0;
695   int is_commutative, pof2;
696   int comm_size = 0;
697   long nbytes = 0;
698   int sendtype_size;
699   int is_two_level = 0;
700
701   comm_size = comm->size();
702   sendtype_size=datatype->size();
703   nbytes = count * sendtype_size;
704
705   if (count == 0)
706     return MPI_SUCCESS;
707
708   is_commutative = (op==MPI_OP_NULL || op->is_commutative());
709
710   /* find nearest power-of-two less than or equal to comm_size */
711   for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
712   pof2 >>=1;
713
714
715   /* Search for the corresponding system size inside the tuning table */
716   while ((range < (mv2_size_reduce_tuning_table - 1)) &&
717       (comm_size > mv2_reduce_thresholds_table[range].numproc)) {
718       range++;
719   }
720   /* Search for corresponding inter-leader function */
721   while ((range_threshold < (mv2_reduce_thresholds_table[range].size_inter_table - 1))
722       && (nbytes >
723   mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max)
724   && (mv2_reduce_thresholds_table[range].inter_leader[range_threshold].max !=
725       -1)) {
726       range_threshold++;
727   }
728
729   /* Search for corresponding intra node function */
730   while ((range_intra_threshold < (mv2_reduce_thresholds_table[range].size_intra_table - 1))
731       && (nbytes >
732   mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max)
733   && (mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].max !=
734       -1)) {
735       range_intra_threshold++;
736   }
737
738   /* Set intra-node function pt for reduce_two_level */
739   MV2_Reduce_intra_function =
740       mv2_reduce_thresholds_table[range].intra_node[range_intra_threshold].
741       MV2_pt_Reduce_function;
742   /* Set inter-leader pt */
743   MV2_Reduce_function =
744       mv2_reduce_thresholds_table[range].inter_leader[range_threshold].
745       MV2_pt_Reduce_function;
746
747   if(mv2_reduce_intra_knomial_factor<0)
748     {
749       mv2_reduce_intra_knomial_factor = mv2_reduce_thresholds_table[range].intra_k_degree;
750     }
751   if(mv2_reduce_inter_knomial_factor<0)
752     {
753       mv2_reduce_inter_knomial_factor = mv2_reduce_thresholds_table[range].inter_k_degree;
754     }
755   if(mv2_reduce_thresholds_table[range].is_two_level_reduce[range_threshold] == 1){
756       is_two_level = 1;
757   }
758   /* We call Reduce function */
759   if(is_two_level == 1)
760     {
761        if (is_commutative == 1) {
762          if(comm->get_leaders_comm()==MPI_COMM_NULL){
763            comm->init_smp();
764          }
765          mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
766                                            datatype, op, root, comm);
767         } else {
768       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
769           datatype, op, root, comm);
770       }
771     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
772         if(is_commutative ==1)
773           {
774             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
775                 datatype, op, root, comm);
776           } else {
777               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
778                   datatype, op, root, comm);
779           }
780     } else if(MV2_Reduce_function == &MPIR_Reduce_redscat_gather_MV2){
781         if (/*(HANDLE_GET_KIND(op) == HANDLE_KIND_BUILTIN) &&*/ (count >= pof2))
782           {
783             mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
784                 datatype, op, root, comm);
785           } else {
786               mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
787                   datatype, op, root, comm);
788           }
789     } else {
790         mpi_errno = MV2_Reduce_function(sendbuf, recvbuf, count, 
791             datatype, op, root, comm);
792     }
793
794
795   return mpi_errno;
796
797 }
798
799
800 int Coll_reduce_scatter_mvapich2::reduce_scatter(void *sendbuf, void *recvbuf, int *recvcnts,
801     MPI_Datatype datatype, MPI_Op op,
802     MPI_Comm comm)
803 {
804   int mpi_errno = MPI_SUCCESS;
805   int i = 0, comm_size = comm->size(), total_count = 0, type_size =
806       0, nbytes = 0;
807   int range = 0;
808   int range_threshold = 0;
809   int is_commutative = 0;
810   int *disps = static_cast<int*>(xbt_malloc(comm_size * sizeof (int)));
811
812   if(mv2_red_scat_thresholds_table==NULL)
813     init_mv2_reduce_scatter_tables_stampede();
814
815   is_commutative=(op==MPI_OP_NULL || op->is_commutative());
816   for (i = 0; i < comm_size; i++) {
817       disps[i] = total_count;
818       total_count += recvcnts[i];
819   }
820
821   type_size=datatype->size();
822   nbytes = total_count * type_size;
823
824   if (is_commutative) {
825
826       /* Search for the corresponding system size inside the tuning table */
827       while ((range < (mv2_size_red_scat_tuning_table - 1)) &&
828           (comm_size > mv2_red_scat_thresholds_table[range].numproc)) {
829           range++;
830       }
831       /* Search for corresponding inter-leader function */
832       while ((range_threshold < (mv2_red_scat_thresholds_table[range].size_inter_table - 1))
833           && (nbytes >
834       mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max)
835       && (mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].max !=
836           -1)) {
837           range_threshold++;
838       }
839
840       /* Set inter-leader pt */
841       MV2_Red_scat_function =
842           mv2_red_scat_thresholds_table[range].inter_leader[range_threshold].
843           MV2_pt_Red_scat_function;
844
845       mpi_errno = MV2_Red_scat_function(sendbuf, recvbuf,
846           recvcnts, datatype,
847           op, comm);
848   } else {
849       int is_block_regular = 1;
850       for (i = 0; i < (comm_size - 1); ++i) {
851           if (recvcnts[i] != recvcnts[i+1]) {
852               is_block_regular = 0;
853               break;
854           }
855       }
856       int pof2 = 1;
857       while (pof2 < comm_size) pof2 <<= 1;
858       if (pof2 == comm_size && is_block_regular) {
859           /* noncommutative, pof2 size, and block regular */
860           mpi_errno = MPIR_Reduce_scatter_non_comm_MV2(sendbuf, recvbuf,
861               recvcnts, datatype,
862               op, comm);
863       }
864       mpi_errno =  Coll_reduce_scatter_mpich_rdb::reduce_scatter(sendbuf, recvbuf,
865           recvcnts, datatype,
866           op, comm);
867   }
868   xbt_free(disps);
869   return mpi_errno;
870
871 }
872
873
874
875 int Coll_scatter_mvapich2::scatter(void *sendbuf,
876     int sendcnt,
877     MPI_Datatype sendtype,
878     void *recvbuf,
879     int recvcnt,
880     MPI_Datatype recvtype,
881     int root, MPI_Comm comm)
882 {
883   int range = 0, range_threshold = 0, range_threshold_intra = 0;
884   int mpi_errno = MPI_SUCCESS;
885   //   int mpi_errno_ret = MPI_SUCCESS;
886   int rank, nbytes, comm_size;
887   int recvtype_size, sendtype_size;
888   int partial_sub_ok = 0;
889   int conf_index = 0;
890     int local_size = -1;
891     int i;
892      MPI_Comm shmem_comm;
893   //    MPID_Comm *shmem_commptr=NULL;
894   if(mv2_scatter_thresholds_table==NULL)
895     init_mv2_scatter_tables_stampede();
896
897   if(comm->get_leaders_comm()==MPI_COMM_NULL){
898     comm->init_smp();
899   }
900   
901   comm_size = comm->size();
902
903   rank = comm->rank();
904
905   if (rank == root) {
906       sendtype_size=sendtype->size();
907       nbytes = sendcnt * sendtype_size;
908   } else {
909       recvtype_size=recvtype->size();
910       nbytes = recvcnt * recvtype_size;
911   }
912   
913     // check if safe to use partial subscription mode 
914     if (comm->is_uniform()) {
915
916         shmem_comm = comm->get_intra_comm();
917         local_size = shmem_comm->size();
918         i = 0;
919         if (mv2_scatter_table_ppn_conf[0] == -1) {
920             // Indicating user defined tuning 
921             conf_index = 0;
922         }else{
923             do {
924                 if (local_size == mv2_scatter_table_ppn_conf[i]) {
925                     conf_index = i;
926                     partial_sub_ok = 1;
927                     break;
928                 }
929                 i++;
930             } while(i < mv2_scatter_num_ppn_conf);
931         }
932     }
933    
934   if (partial_sub_ok != 1) {
935       conf_index = 0;
936   }
937
938   /* Search for the corresponding system size inside the tuning table */
939   while ((range < (mv2_size_scatter_tuning_table[conf_index] - 1)) &&
940       (comm_size > mv2_scatter_thresholds_table[conf_index][range].numproc)) {
941       range++;
942   }
943   /* Search for corresponding inter-leader function */
944   while ((range_threshold < (mv2_scatter_thresholds_table[conf_index][range].size_inter_table - 1))
945       && (nbytes >
946   mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max)
947   && (mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold].max != -1)) {
948       range_threshold++;
949   }
950
951   /* Search for corresponding intra-node function */
952   while ((range_threshold_intra <
953       (mv2_scatter_thresholds_table[conf_index][range].size_intra_table - 1))
954       && (nbytes >
955   mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max)
956   && (mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra].max !=
957       -1)) {
958       range_threshold_intra++;
959   }
960
961   MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold]
962                                                                                       .MV2_pt_Scatter_function;
963
964   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
965 #if defined(_MCST_SUPPORT_)
966       if(comm->ch.is_mcast_ok == 1
967           && mv2_use_mcast_scatter == 1
968           && comm->ch.shmem_coll_ok == 1) {
969           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
970       } else
971 #endif /*#if defined(_MCST_SUPPORT_) */
972         {
973           if(mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1].
974               MV2_pt_Scatter_function != NULL) {
975               MV2_Scatter_function = mv2_scatter_thresholds_table[conf_index][range].inter_leader[range_threshold + 1]
976                                                                                                   .MV2_pt_Scatter_function;
977           } else {
978               /* Fallback! */
979               MV2_Scatter_function = &MPIR_Scatter_MV2_Binomial;
980           }
981         } 
982   }
983
984   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
985       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
986        if( comm->is_blocked()) {
987              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
988                                 .MV2_pt_Scatter_function;
989
990              mpi_errno =
991                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
992                                         recvbuf, recvcnt, recvtype, root,
993                                         comm);
994          } else {
995       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
996           recvbuf, recvcnt, recvtype, root,
997           comm);
998
999       }
1000   } else {
1001       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
1002           recvbuf, recvcnt, recvtype, root,
1003           comm);
1004   }
1005   return (mpi_errno);
1006 }
1007
1008 }
1009 }
1010 void smpi_coll_cleanup_mvapich2(void){
1011 int i=0;
1012 if(mv2_alltoall_thresholds_table)
1013   xbt_free(mv2_alltoall_thresholds_table[i]);
1014 xbt_free(mv2_alltoall_thresholds_table);
1015 xbt_free(mv2_size_alltoall_tuning_table);
1016 xbt_free(mv2_alltoall_table_ppn_conf);
1017
1018 xbt_free(mv2_gather_thresholds_table);
1019 if(mv2_allgather_thresholds_table)
1020   xbt_free(mv2_allgather_thresholds_table[0]);
1021 xbt_free(mv2_size_allgather_tuning_table);
1022 xbt_free(mv2_allgather_table_ppn_conf);
1023 xbt_free(mv2_allgather_thresholds_table);
1024
1025 xbt_free(mv2_allgatherv_thresholds_table);
1026 xbt_free(mv2_reduce_thresholds_table);
1027 xbt_free(mv2_red_scat_thresholds_table);
1028 xbt_free(mv2_allreduce_thresholds_table);
1029 xbt_free(mv2_bcast_thresholds_table);
1030 if(mv2_scatter_thresholds_table)
1031   xbt_free(mv2_scatter_thresholds_table[0]);
1032 xbt_free(mv2_scatter_thresholds_table);
1033 xbt_free(mv2_size_scatter_tuning_table);
1034 xbt_free(mv2_scatter_table_ppn_conf);
1035 }