Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add collectives for allgather, allreduce, bcast and reduce
[simgrid.git] / src / smpi / smpi_mpi.c
index bb04f50..518e265 100644 (file)
@@ -14,17 +14,35 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_mpi, smpi,
 
 int MPI_Init(int *argc, char ***argv)
 {
+  int allgather_id = find_coll_description(mpi_coll_allgather_description,
+                                           sg_cfg_get_string("smpi/allgather"));
+  mpi_coll_allgather_fun = (int (*)(void *, int, MPI_Datatype,
+                                   void*, int, MPI_Datatype, MPI_Comm))
+                          mpi_coll_allgather_description[allgather_id].coll;
+
+  int allreduce_id = find_coll_description(mpi_coll_allreduce_description,
+                                           sg_cfg_get_string("smpi/allreduce"));
+  mpi_coll_allreduce_fun = (int (*)(void *sbuf, void *rbuf, int rcount, \
+                                    MPI_Datatype dtype, MPI_Op op, MPI_Comm comm))
+                          mpi_coll_allreduce_description[allreduce_id].coll;
+
   int alltoall_id = find_coll_description(mpi_coll_alltoall_description,
                                           sg_cfg_get_string("smpi/alltoall"));
   mpi_coll_alltoall_fun = (int (*)(void *, int, MPI_Datatype,
                                   void*, int, MPI_Datatype, MPI_Comm))
                          mpi_coll_alltoall_description[alltoall_id].coll;
 
-  int allgather_id = find_coll_description(mpi_coll_allgather_description,
-                                           sg_cfg_get_string("smpi/allgather"));
-  mpi_coll_allgather_fun = (int (*)(void *, int, MPI_Datatype,
-                                   void*, int, MPI_Datatype, MPI_Comm))
-                          mpi_coll_allgather_description[allgather_id].coll;
+  int bcast_id = find_coll_description(mpi_coll_bcast_description,
+                                          sg_cfg_get_string("smpi/bcast"));
+  mpi_coll_bcast_fun = (int (*)(void *buf, int count, MPI_Datatype datatype, \
+                               int root, MPI_Comm com))
+                      mpi_coll_bcast_description[bcast_id].coll;
+
+  int reduce_id = find_coll_description(mpi_coll_reduce_description,
+                                          sg_cfg_get_string("smpi/reduce"));
+  mpi_coll_reduce_fun = (int (*)(void *buf, void *rbuf, int count, MPI_Datatype datatype, \
+                                 MPI_Op op, int root, MPI_Comm comm))
+                       mpi_coll_reduce_description[reduce_id].coll;
 
   return PMPI_Init(argc, argv);
 }