Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix MPI_Type_dup for derived datatypes.
authorAugustin Degomme <adegomme@gmail.com>
Sat, 6 Jun 2020 20:52:23 +0000 (22:52 +0200)
committerAugustin Degomme <adegomme@gmail.com>
Sat, 6 Jun 2020 20:52:46 +0000 (22:52 +0200)
src/smpi/bindings/smpi_pmpi_type.cpp
src/smpi/include/smpi_datatype.hpp
src/smpi/include/smpi_datatype_derived.hpp
src/smpi/mpi/smpi_datatype.cpp
src/smpi/mpi/smpi_datatype_derived.cpp
teshsuite/smpi/mpich3-test/datatype/sendrecvt4.c

index da2f4cb..74cd22d 100644 (file)
@@ -77,7 +77,7 @@ int PMPI_Type_ub(MPI_Datatype datatype, MPI_Aint * disp)
 int PMPI_Type_dup(MPI_Datatype datatype, MPI_Datatype *newtype){
   int retval = MPI_SUCCESS;
   CHECK_MPI_NULL(1, MPI_DATATYPE_NULL, MPI_ERR_TYPE, datatype)
-  *newtype = new simgrid::smpi::Datatype(datatype, &retval);
+  *newtype = datatype->clone();
   //error when duplicating, free the new datatype
   if(retval!=MPI_SUCCESS){
     simgrid::smpi::Datatype::unref(*newtype);
index 8cf0cba..5e387ba 100644 (file)
@@ -142,6 +142,7 @@ public:
   void set_name(const char* name);
   static int copy(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                   MPI_Datatype recvtype);
+  virtual MPI_Datatype clone();
   virtual void serialize(const void* noncontiguous, void* contiguous, int count);
   virtual void unserialize(const void* contiguous, void* noncontiguous, int count, MPI_Op op);
   static int keyval_create(MPI_Type_copy_attr_function* copy_fn, MPI_Type_delete_attr_function* delete_fn, int* keyval,
index fb5922b..48ca0a8 100644 (file)
@@ -21,11 +21,13 @@ public:
   Type_Contiguous(const Type_Contiguous&) = delete;
   Type_Contiguous& operator=(const Type_Contiguous&) = delete;
   ~Type_Contiguous();
+  Type_Contiguous* clone();
   void serialize(const void* noncontiguous, void* contiguous, int count) override;
   void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override;
 };
 
 class Type_Hvector: public Datatype{
+public:
   int block_count_;
   int block_length_;
   MPI_Aint block_stride_;
@@ -37,6 +39,7 @@ public:
   Type_Hvector(const Type_Hvector&) = delete;
   Type_Hvector& operator=(const Type_Hvector&) = delete;
   ~Type_Hvector();
+  Type_Hvector* clone();
   void serialize(const void* noncontiguous, void* contiguous, int count) override;
   void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override;
 };
@@ -45,9 +48,11 @@ class Type_Vector : public Type_Hvector {
 public:
   Type_Vector(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, int blocklen, int stride,
               MPI_Datatype old_type);
+  Type_Vector* clone();
 };
 
 class Type_Hindexed: public Datatype{
+public:
   int block_count_;
   int* block_lengths_;
   MPI_Aint* block_indices_;
@@ -60,6 +65,7 @@ public:
                 MPI_Datatype old_type, MPI_Aint factor);
   Type_Hindexed(const Type_Hindexed&) = delete;
   Type_Hindexed& operator=(const Type_Hindexed&) = delete;
+  Type_Hindexed* clone();
   ~Type_Hindexed();
   void serialize(const void* noncontiguous, void* contiguous, int count) override;
   void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override;
@@ -69,6 +75,7 @@ class Type_Indexed : public Type_Hindexed {
 public:
   Type_Indexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int block_count, const int* block_lengths, const int* block_indices,
                MPI_Datatype old_type);
+  Type_Indexed* clone();
 };
 
 class Type_Struct: public Datatype{
@@ -82,6 +89,7 @@ public:
               const MPI_Aint* block_indices, const MPI_Datatype* old_types);
   Type_Struct(const Type_Struct&) = delete;
   Type_Struct& operator=(const Type_Struct&) = delete;
+  Type_Struct* clone();
   ~Type_Struct();
   void serialize(const void* noncontiguous, void* contiguous, int count) override;
   void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override;
index 09e81eb..5547458 100644 (file)
@@ -182,6 +182,11 @@ Datatype::~Datatype()
   xbt_free(name_);
 }
 
+MPI_Datatype Datatype::clone(){
+  int ret = MPI_SUCCESS;
+  return new Datatype(this, &ret);
+}
+
 void Datatype::ref()
 {
   refcount_++;
index aa12ddc..bf6392b 100644 (file)
@@ -53,6 +53,11 @@ Type_Contiguous::~Type_Contiguous()
   Datatype::unref(old_type_);
 }
 
+Type_Contiguous* Type_Contiguous::clone()
+{
+  return new Type_Contiguous(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->old_type_);
+}
+
 void Type_Contiguous::serialize(const void* noncontiguous_buf, void* contiguous_buf, int count)
 {
   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
@@ -78,6 +83,11 @@ Type_Hvector::~Type_Hvector(){
   Datatype::unref(old_type_);
 }
 
+Type_Hvector* Type_Hvector::clone()
+{
+  return new Type_Hvector(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_length_, this->block_stride_, this->old_type_);
+}
+
 void Type_Hvector::serialize(const void* noncontiguous_buf, void *contiguous_buf,
                     int count){
   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
@@ -125,6 +135,11 @@ Type_Vector::Type_Vector(int size, MPI_Aint lb, MPI_Aint ub, int flags, int coun
   contents_ = new Datatype_contents(MPI_COMBINER_VECTOR, 3, ints, 0, nullptr, 1, &old_type);
 }
 
+Type_Vector* Type_Vector::clone()
+{
+  return new Type_Vector(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_length_, this->block_stride_, this->old_type_);
+}
+
 Type_Hindexed::Type_Hindexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, const int* block_lengths,
                              const MPI_Aint* block_indices, MPI_Datatype old_type)
     : Datatype(size, lb, ub, flags)
@@ -161,6 +176,11 @@ Type_Hindexed::Type_Hindexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int
   }
 }
 
+Type_Hindexed* Type_Hindexed::clone()
+{
+  return new Type_Hindexed(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, this->block_indices_, this->old_type_);
+}
+
 Type_Hindexed::~Type_Hindexed()
 {
   Datatype::unref(old_type_);
@@ -230,6 +250,11 @@ Type_Indexed::Type_Indexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int co
   delete[] ints;
 }
 
+Type_Indexed* Type_Indexed::clone()
+{
+  return new Type_Indexed(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, (int*)(this->block_indices_), this->old_type_);
+}
+
 Type_Struct::Type_Struct(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, const int* block_lengths,
                          const MPI_Aint* block_indices, const MPI_Datatype* old_types)
     : Datatype(size, lb, ub, flags)
@@ -263,6 +288,10 @@ Type_Struct::~Type_Struct(){
   }
 }
 
+Type_Struct* Type_Struct::clone()
+{
+  return new Type_Struct(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, this->block_indices_, this->old_types_);
+}
 
 void Type_Struct::serialize(const void* noncontiguous_buf, void *contiguous_buf,
                         int count){
index bd3338d..9b9562a 100644 (file)
@@ -75,8 +75,11 @@ int main(int argc, char **argv)
                     continue;
                 }
                 partner = np - 1;
-                MPI_Send(MPI_BOTTOM, counts[j], offsettype, partner, tag, comm);
+               MPI_Datatype dup;
+                MPI_Type_dup(offsettype, &dup);
+               MPI_Send(MPI_BOTTOM, counts[j], dup, partner, tag, comm);
                 MPI_Type_free(&offsettype);
+                MPI_Type_free(&dup);
             }
             else if (rank == np - 1) {
                 partner = 0;
@@ -101,7 +104,9 @@ int main(int argc, char **argv)
                     MPI_Type_free(&offsettype);
                     continue;
                 }
-                MPI_Recv(MPI_BOTTOM, counts[j], offsettype, partner, tag, comm, &status);
+                MPI_Datatype dup;
+               MPI_Type_dup(offsettype, &dup);
+                MPI_Recv(MPI_BOTTOM, counts[j], dup, partner, tag, comm, &status);
                 /* Test for correctness */
                 MPI_Get_count(&status, types[j], &count);
                 if (count != counts[j]) {
@@ -136,6 +141,7 @@ int main(int argc, char **argv)
                     err++;
                 }
                 MPI_Type_free(&offsettype);
+                MPI_Type_free(&dup);
             }
         }
         MTestFreeComm(&comm);