Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Make smpi_switch_data_segment check if a switch is needed, and return true when it...
[simgrid.git] / src / smpi / mpi / smpi_op.cpp
index e81b93f..a0501c1 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved.          */
+/* Copyright (c) 2009-2021. The SimGrid Team. All rights reserved.          */
 
 /* This program is free software; you can redistribute it and/or modify it
  * under the terms of the license (GNU LGPL) which comes with this package. */
@@ -13,15 +13,27 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_op, smpi, "Logging specific to SMPI (op)");
 #define MAX_OP(a, b)  (b) = (a) < (b) ? (b) : (a)
 #define MIN_OP(a, b)  (b) = (a) < (b) ? (a) : (b)
 #define SUM_OP(a, b)  (b) += (a)
+#define SUM_OP_COMPLEX(a, b)                                                                                           \
+  {                                                                                                                    \
+    ((b).value) += ((a).value);                                                                                        \
+    ((b).index) += ((a).index);                                                                                        \
+  }
 #define PROD_OP(a, b) (b) *= (a)
+#define PROD_OP_COMPLEX(a, b)                                                                                          \
+  {                                                                                                                    \
+    ((b).value) *= ((a).value);                                                                                        \
+    ((b).index) *= ((a).index);                                                                                        \
+  }
 #define LAND_OP(a, b) (b) = (a) && (b)
 #define LOR_OP(a, b)  (b) = (a) || (b)
-#define LXOR_OP(a, b) (b) = (not(a) && (b)) || ((a) && not(b))
+#define LXOR_OP(a, b) (b) = bool(a) != bool(b)
 #define BAND_OP(a, b) (b) &= (a)
 #define BOR_OP(a, b)  (b) |= (a)
 #define BXOR_OP(a, b) (b) ^= (a)
-#define MAXLOC_OP(a, b)  (b) = (a.value) < (b.value) ? (b) : ((a.value) == (b.value) ? ((a.index) < (b.index) ? (a) : (b)) : (a))
-#define MINLOC_OP(a, b)  (b) = (a.value) < (b.value) ? (a) : ((a.value) == (b.value) ? ((a.index) < (b.index) ? (a) : (b)) : (b))
+#define MAXLOC_OP(a, b)                                                                                                \
+  (b) = ((a).value) < ((b).value) ? (b) : (((a).value) == ((b).value) ? (((a).index) < ((b).index) ? (a) : (b)) : (a))
+#define MINLOC_OP(a, b)                                                                                                \
+  (b) = ((a).value) < ((b).value) ? (a) : (((a).value) == ((b).value) ? (((a).index) < ((b).index) ? (a) : (b)) : (b))
 
 #define APPLY_FUNC(a, b, length, type, func) \
 {                                          \
@@ -33,11 +45,15 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_op, smpi, "Logging specific to SMPI (op)");
   }                                        \
 }
 
-#define APPLY_OP_LOOP(dtype, type, op) \
-  if (*datatype == dtype) {\
-    APPLY_FUNC(a, b, length, type, op)\
-  } else \
+#define APPLY_BEGIN_OP_LOOP()                                                                                          \
+  MPI_Datatype datatype_base = *datatype;                                                                              \
+  while (datatype_base->duplicated_datatype() != MPI_DATATYPE_NULL)                                                    \
+    datatype_base = datatype_base->duplicated_datatype();
 
+#define APPLY_OP_LOOP(dtype, type, op)                                                                                 \
+  if (datatype_base == (dtype)) {                                                                                      \
+    APPLY_FUNC(a, b, length, type, op)                                                                                 \
+  } else
 
 #define APPLY_BASIC_OP_LOOP(op)\
 APPLY_OP_LOOP(MPI_CHAR, char,op)\
@@ -52,7 +68,6 @@ APPLY_OP_LOOP(MPI_UNSIGNED, unsigned int,op)\
 APPLY_OP_LOOP(MPI_UNSIGNED_LONG, unsigned long,op)\
 APPLY_OP_LOOP(MPI_UNSIGNED_LONG_LONG, unsigned long long,op)\
 APPLY_OP_LOOP(MPI_WCHAR, wchar_t,op)\
-APPLY_OP_LOOP(MPI_BYTE, int8_t,op)\
 APPLY_OP_LOOP(MPI_INT8_T, int8_t,op)\
 APPLY_OP_LOOP(MPI_INT16_T, int16_t,op)\
 APPLY_OP_LOOP(MPI_INT32_T, int32_t,op)\
@@ -66,11 +81,16 @@ APPLY_OP_LOOP(MPI_OFFSET, MPI_Offset,op)\
 APPLY_OP_LOOP(MPI_INTEGER1, int,op)\
 APPLY_OP_LOOP(MPI_INTEGER2, int16_t,op)\
 APPLY_OP_LOOP(MPI_INTEGER4, int32_t,op)\
-APPLY_OP_LOOP(MPI_INTEGER8, int64_t,op)
+APPLY_OP_LOOP(MPI_INTEGER8, int64_t,op)\
+APPLY_OP_LOOP(MPI_COUNT, long long,op)
+
 
 #define APPLY_BOOL_OP_LOOP(op)\
 APPLY_OP_LOOP(MPI_C_BOOL, bool,op)
 
+#define APPLY_BYTE_OP_LOOP(op)\
+APPLY_OP_LOOP(MPI_BYTE, int8_t,op)
+
 #define APPLY_FLOAT_OP_LOOP(op)\
 APPLY_OP_LOOP(MPI_FLOAT, float,op)\
 APPLY_OP_LOOP(MPI_DOUBLE, double,op)\
@@ -94,15 +114,19 @@ APPLY_OP_LOOP(MPI_2INT, int_int,op)\
 APPLY_OP_LOOP(MPI_2FLOAT, float_float,op)\
 APPLY_OP_LOOP(MPI_2DOUBLE, double_double,op)\
 APPLY_OP_LOOP(MPI_LONG_DOUBLE_INT, long_double_int,op)\
-APPLY_OP_LOOP(MPI_2LONG, long_long,op)
-
-#define APPLY_END_OP_LOOP(op)\
-  {\
-    xbt_die("Failed to apply " #op " to type %s", (*datatype)->name());\
+APPLY_OP_LOOP(MPI_2LONG, long_long,op)\
+APPLY_OP_LOOP(MPI_COMPLEX8, float_float,op)\
+APPLY_OP_LOOP(MPI_COMPLEX16, double_double,op)\
+APPLY_OP_LOOP(MPI_COMPLEX32, double_double,op)
+
+#define APPLY_END_OP_LOOP(op)                                                                                          \
+  {                                                                                                                    \
+    xbt_die("Failed to apply " _XBT_STRINGIFY(op) " to type %s", (*datatype)->name().c_str());                         \
   }
 
 static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(MAX_OP)
   APPLY_FLOAT_OP_LOOP(MAX_OP)
   APPLY_END_OP_LOOP(MAX_OP)
@@ -110,6 +134,7 @@ static void max_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(MIN_OP)
   APPLY_FLOAT_OP_LOOP(MIN_OP)
   APPLY_END_OP_LOOP(MIN_OP)
@@ -117,70 +142,88 @@ static void min_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 
 static void sum_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(SUM_OP)
   APPLY_FLOAT_OP_LOOP(SUM_OP)
   APPLY_COMPLEX_OP_LOOP(SUM_OP)
+  APPLY_PAIR_OP_LOOP(SUM_OP_COMPLEX)
   APPLY_END_OP_LOOP(SUM_OP)
 }
 
 static void prod_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(PROD_OP)
   APPLY_FLOAT_OP_LOOP(PROD_OP)
   APPLY_COMPLEX_OP_LOOP(PROD_OP)
+  APPLY_PAIR_OP_LOOP(PROD_OP_COMPLEX)
   APPLY_END_OP_LOOP(PROD_OP)
 }
 
 static void land_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LAND_OP)
+  APPLY_FLOAT_OP_LOOP(LAND_OP)
   APPLY_BOOL_OP_LOOP(LAND_OP)
   APPLY_END_OP_LOOP(LAND_OP)
 }
 
 static void lor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LOR_OP)
+  APPLY_FLOAT_OP_LOOP(LOR_OP)
   APPLY_BOOL_OP_LOOP(LOR_OP)
   APPLY_END_OP_LOOP(LOR_OP)
 }
 
 static void lxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(LXOR_OP)
+  APPLY_FLOAT_OP_LOOP(LXOR_OP)
   APPLY_BOOL_OP_LOOP(LXOR_OP)
   APPLY_END_OP_LOOP(LXOR_OP)
 }
 
 static void band_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BAND_OP)
   APPLY_BOOL_OP_LOOP(BAND_OP)
+  APPLY_BYTE_OP_LOOP(BAND_OP)
   APPLY_END_OP_LOOP(BAND_OP)
 }
 
 static void bor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BOR_OP)
   APPLY_BOOL_OP_LOOP(BOR_OP)
+  APPLY_BYTE_OP_LOOP(BOR_OP)
   APPLY_END_OP_LOOP(BOR_OP)
 }
 
 static void bxor_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_BASIC_OP_LOOP(BXOR_OP)
   APPLY_BOOL_OP_LOOP(BXOR_OP)
+  APPLY_BYTE_OP_LOOP(BXOR_OP)
   APPLY_END_OP_LOOP(BXOR_OP)
 }
 
 static void minloc_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_PAIR_OP_LOOP(MINLOC_OP)
   APPLY_END_OP_LOOP(MINLOC_OP)
 }
 
 static void maxloc_func(void *a, void *b, int *length, MPI_Datatype * datatype)
 {
+  APPLY_BEGIN_OP_LOOP()
   APPLY_PAIR_OP_LOOP(MAXLOC_OP)
   APPLY_END_OP_LOOP(MAXLOC_OP)
 }
@@ -195,45 +238,46 @@ static void no_func(void*, void*, int*, MPI_Datatype*)
   /* obviously a no-op */
 }
 
-#define CREATE_MPI_OP(name, func)                             \
-  static SMPI_Op mpi_##name (&(func) /* func */, true ); \
-MPI_Op name = &mpi_##name;
-
-CREATE_MPI_OP(MPI_MAX, max_func);
-CREATE_MPI_OP(MPI_MIN, min_func);
-CREATE_MPI_OP(MPI_SUM, sum_func);
-CREATE_MPI_OP(MPI_PROD, prod_func);
-CREATE_MPI_OP(MPI_LAND, land_func);
-CREATE_MPI_OP(MPI_LOR, lor_func);
-CREATE_MPI_OP(MPI_LXOR, lxor_func);
-CREATE_MPI_OP(MPI_BAND, band_func);
-CREATE_MPI_OP(MPI_BOR, bor_func);
-CREATE_MPI_OP(MPI_BXOR, bxor_func);
-CREATE_MPI_OP(MPI_MAXLOC, maxloc_func);
-CREATE_MPI_OP(MPI_MINLOC, minloc_func);
-CREATE_MPI_OP(MPI_REPLACE, replace_func);
-CREATE_MPI_OP(MPI_NO_OP, no_func);
+
+#define CREATE_MPI_OP(name, func, types)                                                                                      \
+  SMPI_Op _XBT_CONCAT(smpi_MPI_, name)(&(func) /* func */, true, true, types);
+
+#define MAX_TYPES DT_FLAG_C_INTEGER|DT_FLAG_F_INTEGER|DT_FLAG_FP|DT_FLAG_MULTILANG
+#define LAND_TYPES DT_FLAG_C_INTEGER|DT_FLAG_FP|DT_FLAG_LOGICAL|DT_FLAG_MULTILANG
+#define BAND_TYPES DT_FLAG_C_INTEGER|DT_FLAG_F_INTEGER|DT_FLAG_BYTE|DT_FLAG_MULTILANG
+
+CREATE_MPI_OP(MAX, max_func, MAX_TYPES)
+CREATE_MPI_OP(MIN, min_func, MAX_TYPES)
+CREATE_MPI_OP(SUM, sum_func, MAX_TYPES|DT_FLAG_COMPLEX)
+CREATE_MPI_OP(PROD, prod_func, MAX_TYPES|DT_FLAG_COMPLEX)
+CREATE_MPI_OP(LAND, land_func, LAND_TYPES)
+CREATE_MPI_OP(LOR, lor_func, LAND_TYPES)
+CREATE_MPI_OP(LXOR, lxor_func, LAND_TYPES)
+CREATE_MPI_OP(BAND, band_func, BAND_TYPES)
+CREATE_MPI_OP(BOR, bor_func, BAND_TYPES)
+CREATE_MPI_OP(BXOR, bxor_func, BAND_TYPES)
+CREATE_MPI_OP(MAXLOC, maxloc_func, DT_FLAG_REDUCTION)
+CREATE_MPI_OP(MINLOC, minloc_func, DT_FLAG_REDUCTION)
+CREATE_MPI_OP(REPLACE, replace_func, 0)
+CREATE_MPI_OP(NO_OP, no_func, 0)
 
 namespace simgrid{
 namespace smpi{
 
-void Op::apply(void *invec, void *inoutvec, int *len, MPI_Datatype datatype)
+void Op::apply(const void* invec, void* inoutvec, const int* len, MPI_Datatype datatype) const
 {
-  if (smpi_privatize_global_variables == SmpiPrivStrategies::MMAP) {
-    // we need to switch as the called function may silently touch global variables
-    XBT_DEBUG("Applying operation, switch to the right data frame ");
-    smpi_switch_data_segment(simgrid::s4u::Actor::self());
-  }
+  // we need to switch as the called function may silently touch global variables
+  smpi_switch_data_segment(simgrid::s4u::Actor::self());
 
   if (not smpi_process()->replaying() && *len > 0) {
+    XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec);
     if (not is_fortran_op_)
-      this->func_(invec, inoutvec, len, &datatype);
+      this->func_(const_cast<void*>(invec), inoutvec, const_cast<int*>(len), &datatype);
     else{
-      XBT_DEBUG("Applying operation of length %d from %p and from/to %p", *len, invec, inoutvec);
       int tmp = datatype->c2f();
       /* Unfortunately, the C and Fortran version of the MPI standard do not agree on the type here,
          thus the reinterpret_cast. */
-      this->func_(invec, inoutvec, len, reinterpret_cast<MPI_Datatype*>(&tmp) );
+      this->func_(const_cast<void*>(invec), inoutvec, const_cast<int*>(len), reinterpret_cast<MPI_Datatype*>(&tmp));
     }
   }
 }
@@ -242,5 +286,19 @@ Op* Op::f2c(int id){
   return static_cast<Op*>(F2C::f2c(id));
 }
 
+void Op::ref(){
+  refcount_++;
+}
+
+void Op::unref(MPI_Op* op){
+  if((*op)!=MPI_OP_NULL){
+    (*op)->refcount_--;
+    if ((*op)->refcount_ == 0 && not (*op)->is_predefined_){
+      F2C::free_f((*op)->f2c_id());
+      delete(*op);
+    }
+  }
+}
+
 }
 }