Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Handle duplicated datatypes within predefined MPI_Op.
[simgrid.git] / src / smpi / mpi / smpi_op.cpp
index 5ec47aa..afc3264 100644 (file)
@@ -45,8 +45,13 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_op, smpi, "Logging specific to SMPI (op)");
   }                                        \
 }
 
+#define APPLY_BEGIN_OP_LOOP()                                                                                          \
+  MPI_Datatype datatype_base = *datatype;                                                                              \
+  while (datatype_base->duplicated_datatype() != MPI_DATATYPE_NULL)                                                    \
+    datatype_base = datatype_base->duplicated_datatype();
+
 #define APPLY_OP_LOOP(dtype, type, op)                                                                                 \
-  if (*datatype == (dtype)) {                                                                                          \
+  if (datatype_base == (dtype)) {                                                                                      \
     APPLY_FUNC(a, b, length, type, op)                                                                                 \
   } else
 
@@ -121,6 +126,7 @@ APPLY_OP_LOOP(MPI_COMPLEX32, double_double,op)
 
 static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(MAX_OP)
   APPLY_FLOAT_OP_LOOP(MAX_OP)
   APPLY_END_OP_LOOP(MAX_OP)
@@ -128,6 +134,7 @@ static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(MIN_OP)
   APPLY_FLOAT_OP_LOOP(MIN_OP)
   APPLY_END_OP_LOOP(MIN_OP)
@@ -135,6 +142,7 @@ static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void sum_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(SUM_OP)
   APPLY_FLOAT_OP_LOOP(SUM_OP)
   APPLY_COMPLEX_OP_LOOP(SUM_OP)
@@ -144,6 +152,7 @@ static void sum_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void prod_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(PROD_OP)
   APPLY_FLOAT_OP_LOOP(PROD_OP)
   APPLY_COMPLEX_OP_LOOP(PROD_OP)
@@ -153,6 +162,7 @@ static void prod_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void land_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LAND_OP)
   APPLY_FLOAT_OP_LOOP(LAND_OP)
   APPLY_BOOL_OP_LOOP(LAND_OP)
@@ -161,6 +171,7 @@ static void land_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void lor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LOR_OP)
   APPLY_FLOAT_OP_LOOP(LOR_OP)
   APPLY_BOOL_OP_LOOP(LOR_OP)
@@ -169,6 +180,7 @@ static void lor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void lxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LXOR_OP)
   APPLY_FLOAT_OP_LOOP(LXOR_OP)
   APPLY_BOOL_OP_LOOP(LXOR_OP)
@@ -177,6 +189,7 @@ static void lxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void band_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BAND_OP)
   APPLY_BOOL_OP_LOOP(BAND_OP)
   APPLY_BYTE_OP_LOOP(BAND_OP)
@@ -185,6 +198,7 @@ static void band_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void bor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BOR_OP)
   APPLY_BOOL_OP_LOOP(BOR_OP)
   APPLY_BYTE_OP_LOOP(BOR_OP)
@@ -193,6 +207,7 @@ static void bor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void bxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BXOR_OP)
   APPLY_BOOL_OP_LOOP(BXOR_OP)
   APPLY_BYTE_OP_LOOP(BXOR_OP)
@@ -201,12 +216,14 @@ static void bxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void minloc_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_PAIR_OP_LOOP(MINLOC_OP)
   APPLY_END_OP_LOOP(MINLOC_OP)
 }
 
 static void maxloc_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_PAIR_OP_LOOP(MAXLOC_OP)
   APPLY_END_OP_LOOP(MAXLOC_OP)
 }
@@ -256,10 +273,10 @@ void Op::apply(const void* invec, void* inoutvec, const int* len, MPI_Datatype d
   }
 
   if (not smpi_process()->replaying() && *len > 0) {
+    XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec);
     if (not is_fortran_op_)
       this->func_(const_cast<void*>(invec), inoutvec, const_cast<int*>(len), &datatype);
     else{
-      XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec);
       int tmp = datatype->c2f();
       /* Unfortunately, the C and Fortran version of the MPI standard do not agree on the type here,
          thus the reinterpret_cast. */