Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
use tuned algo here
[simgrid.git] / src / smpi / colls / allreduce-NTS.c
index eac171c..85c790b 100644 (file)
@@ -1,4 +1,4 @@
-#include "colls.h"
+#include "colls_private.h"
 /* IMPLEMENTED BY PITCH PATARASUK 
    Non-topoloty-specific all-reduce operation designed bandwidth optimally */
 
@@ -14,32 +14,32 @@ int
 smpi_coll_tuned_allreduce_NTS(void *sbuf, void *rbuf, int rcount,
                               MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
 {
-  int tag = 5000;
+  int tag = COLL_TAG_ALLREDUCE;
   MPI_Status status;
   int rank, i, size, count;
   int send_offset, recv_offset;
   int remainder, remainder_flag, remainder_offset;
 
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
-  MPI_Comm_size(MPI_COMM_WORLD, &size);
+  rank = smpi_comm_rank(comm);
+  size = smpi_comm_size(comm);
 
   /* make it compatible with all data type */
   MPI_Aint extent;
-  MPI_Type_extent(dtype, &extent);
+  extent = smpi_datatype_get_extent(dtype);
 
   /* when communication size is smaller than number of process (not support) */
   if (rcount < size) {
-    return MPI_Allreduce(sbuf, rbuf, rcount, dtype, op, comm);
+    return mpi_coll_allreduce_fun(sbuf, rbuf, rcount, dtype, op, comm);
   }
 
   /* when communication size is not divisible by number of process: 
      call the native implementation for the remain chunk at the end of the operation */
-  else if (rcount % size != 0) {
+  if (rcount % size != 0) {
     remainder = rcount % size;
     remainder_flag = 1;
     remainder_offset = (rcount / size) * size * extent;
   } else {
-    remainder_flag = remainder_offset = 0;
+    remainder = remainder_flag = remainder_offset = 0;
   }
 
   /* size of each point-to-point communication is equal to the size of the whole message
@@ -56,7 +56,7 @@ smpi_coll_tuned_allreduce_NTS(void *sbuf, void *rbuf, int rcount,
   // copy partial data
   send_offset = ((rank - 1 + size) % size) * count * extent;
   recv_offset = ((rank - 1 + size) % size) * count * extent;
-  MPI_Sendrecv((char *) sbuf + send_offset, count, dtype, rank, tag - 1,
+  smpi_mpi_sendrecv((char *) sbuf + send_offset, count, dtype, rank, tag - 1,
                (char *) rbuf + recv_offset, count, dtype, rank, tag - 1, comm,
                &status);
 
@@ -64,19 +64,19 @@ smpi_coll_tuned_allreduce_NTS(void *sbuf, void *rbuf, int rcount,
   for (i = 0; i < (size - 1); i++) {
     send_offset = ((rank - 1 - i + size) % size) * count * extent;
     recv_offset = ((rank - 2 - i + size) % size) * count * extent;
-    MPI_Sendrecv((char *) rbuf + send_offset, count, dtype, ((rank + 1) % size),
+    smpi_mpi_sendrecv((char *) rbuf + send_offset, count, dtype, ((rank + 1) % size),
                  tag + i, (char *) rbuf + recv_offset, count, dtype,
                  ((rank + size - 1) % size), tag + i, comm, &status);
 
     // compute result to rbuf+recv_offset
-    star_reduction(op, (char *)sbuf + recv_offset, (char *)rbuf + recv_offset, &count, &dtype);
+    smpi_op_apply(op, (char *)sbuf + recv_offset, (char *)rbuf + recv_offset, &count, &dtype);
   }
 
   // all-gather
   for (i = 0; i < (size - 1); i++) {
     send_offset = ((rank - i + size) % size) * count * extent;
     recv_offset = ((rank - 1 - i + size) % size) * count * extent;
-    MPI_Sendrecv((char *) rbuf + send_offset, count, dtype, ((rank + 1) % size),
+    smpi_mpi_sendrecv((char *) rbuf + send_offset, count, dtype, ((rank + 1) % size),
                  tag + i, (char *) rbuf + recv_offset, count, dtype,
                  ((rank + size - 1) % size), tag + i, comm, &status);
   }
@@ -84,10 +84,12 @@ smpi_coll_tuned_allreduce_NTS(void *sbuf, void *rbuf, int rcount,
   /* when communication size is not divisible by number of process: 
      call the native implementation for the remain chunk at the end of the operation */
   if (remainder_flag) {
-    return MPI_Allreduce((char *) sbuf + remainder_offset,
+    XBT_WARN("MPI_allreduce_NTS use default MPI_allreduce.");
+    smpi_mpi_allreduce((char *) sbuf + remainder_offset,
                          (char *) rbuf + remainder_offset, remainder, dtype, op,
                          comm);
+    return MPI_SUCCESS;    
   }
 
-  return 0;
+  return MPI_SUCCESS;
 }