Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix mpi bcast flattree-pipeline collective
[simgrid.git] / src / smpi / colls / bcast-flattree-pipeline.c
index 5212032..9033bf5 100644 (file)
@@ -1,4 +1,4 @@
-#include "colls.h"
+#include "colls_private.h"
 
 int flattree_segment_in_byte = 8192;
 
@@ -11,27 +11,29 @@ smpi_coll_tuned_bcast_flattree_pipeline(void *buff, int count,
   int tag = 1;
 
   MPI_Aint extent;
-  MPI_Type_extent(data_type, &extent);
+  extent = smpi_datatype_get_extent(data_type);
 
   int segment = flattree_segment_in_byte / extent;
   int pipe_length = count / segment;
   int increment = segment * extent;
-
-  MPI_Comm_rank(comm, &rank);
-  MPI_Comm_size(comm, &num_procs);
+  if (pipe_length==0) {
+    XBT_WARN("MPI_bcast_flattree_pipeline use default MPI_bcast_flattree.");
+    return smpi_coll_tuned_bcast_flattree(buff, count, data_type, root, comm);
+  }
+  rank = smpi_comm_rank(comm);
+  num_procs = smpi_comm_size(comm);
 
   MPI_Request *request_array;
   MPI_Status *status_array;
 
-  request_array = (MPI_Request *) malloc(pipe_length * sizeof(MPI_Request));
-  status_array = (MPI_Status *) malloc(pipe_length * sizeof(MPI_Status));
+  request_array = (MPI_Request *) xbt_malloc(pipe_length * sizeof(MPI_Request));
+  status_array = (MPI_Status *) xbt_malloc(pipe_length * sizeof(MPI_Status));
 
   if (rank != root) {
     for (i = 0; i < pipe_length; i++) {
-      MPI_Irecv((char *)buff + (i * increment), segment, data_type, root, tag, comm,
-                &request_array[i]);
+      request_array[i] = smpi_mpi_irecv((char *)buff + (i * increment), segment, data_type, root, tag, comm);
     }
-    MPI_Waitall(pipe_length, request_array, status_array);
+    smpi_mpi_waitall(pipe_length, request_array, status_array);
   }
 
   else {
@@ -41,7 +43,7 @@ smpi_coll_tuned_bcast_flattree_pipeline(void *buff, int count,
         continue;
       else {
         for (i = 0; i < pipe_length; i++) {
-          MPI_Send((char *)buff + (i * increment), segment, data_type, j, tag, comm);
+          smpi_mpi_send((char *)buff + (i * increment), segment, data_type, j, tag, comm);
         }
       }
     }