From b01173b82d64fa12bef1afd2a3f9b3b1259296c5 Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Sat, 6 Jun 2020 22:52:23 +0200 Subject: [PATCH] Fix MPI_Type_dup for derived datatypes. --- src/smpi/bindings/smpi_pmpi_type.cpp | 2 +- src/smpi/include/smpi_datatype.hpp | 1 + src/smpi/include/smpi_datatype_derived.hpp | 8 +++++ src/smpi/mpi/smpi_datatype.cpp | 5 ++++ src/smpi/mpi/smpi_datatype_derived.cpp | 29 +++++++++++++++++++ .../smpi/mpich3-test/datatype/sendrecvt4.c | 10 +++++-- 6 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_type.cpp b/src/smpi/bindings/smpi_pmpi_type.cpp index da2f4cb04a..74cd22d96c 100644 --- a/src/smpi/bindings/smpi_pmpi_type.cpp +++ b/src/smpi/bindings/smpi_pmpi_type.cpp @@ -77,7 +77,7 @@ int PMPI_Type_ub(MPI_Datatype datatype, MPI_Aint * disp) int PMPI_Type_dup(MPI_Datatype datatype, MPI_Datatype *newtype){ int retval = MPI_SUCCESS; CHECK_MPI_NULL(1, MPI_DATATYPE_NULL, MPI_ERR_TYPE, datatype) - *newtype = new simgrid::smpi::Datatype(datatype, &retval); + *newtype = datatype->clone(); //error when duplicating, free the new datatype if(retval!=MPI_SUCCESS){ simgrid::smpi::Datatype::unref(*newtype); diff --git a/src/smpi/include/smpi_datatype.hpp b/src/smpi/include/smpi_datatype.hpp index 8cf0cbac97..5e387bab81 100644 --- a/src/smpi/include/smpi_datatype.hpp +++ b/src/smpi/include/smpi_datatype.hpp @@ -142,6 +142,7 @@ public: void set_name(const char* name); static int copy(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype); + virtual MPI_Datatype clone(); virtual void serialize(const void* noncontiguous, void* contiguous, int count); virtual void unserialize(const void* contiguous, void* noncontiguous, int count, MPI_Op op); static int keyval_create(MPI_Type_copy_attr_function* copy_fn, MPI_Type_delete_attr_function* delete_fn, int* keyval, diff --git a/src/smpi/include/smpi_datatype_derived.hpp b/src/smpi/include/smpi_datatype_derived.hpp index fb5922b048..48ca0a8f40 100644 --- a/src/smpi/include/smpi_datatype_derived.hpp +++ b/src/smpi/include/smpi_datatype_derived.hpp @@ -21,11 +21,13 @@ public: Type_Contiguous(const Type_Contiguous&) = delete; Type_Contiguous& operator=(const Type_Contiguous&) = delete; ~Type_Contiguous(); + Type_Contiguous* clone(); void serialize(const void* noncontiguous, void* contiguous, int count) override; void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override; }; class Type_Hvector: public Datatype{ +public: int block_count_; int block_length_; MPI_Aint block_stride_; @@ -37,6 +39,7 @@ public: Type_Hvector(const Type_Hvector&) = delete; Type_Hvector& operator=(const Type_Hvector&) = delete; ~Type_Hvector(); + Type_Hvector* clone(); void serialize(const void* noncontiguous, void* contiguous, int count) override; void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override; }; @@ -45,9 +48,11 @@ class Type_Vector : public Type_Hvector { public: Type_Vector(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, int blocklen, int stride, MPI_Datatype old_type); + Type_Vector* clone(); }; class Type_Hindexed: public Datatype{ +public: int block_count_; int* block_lengths_; MPI_Aint* block_indices_; @@ -60,6 +65,7 @@ public: MPI_Datatype old_type, MPI_Aint factor); Type_Hindexed(const Type_Hindexed&) = delete; Type_Hindexed& operator=(const Type_Hindexed&) = delete; + Type_Hindexed* clone(); ~Type_Hindexed(); void serialize(const void* noncontiguous, void* contiguous, int count) override; void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override; @@ -69,6 +75,7 @@ class Type_Indexed : public Type_Hindexed { public: Type_Indexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int block_count, const int* block_lengths, const int* block_indices, MPI_Datatype old_type); + Type_Indexed* clone(); }; class Type_Struct: public Datatype{ @@ -82,6 +89,7 @@ public: const MPI_Aint* block_indices, const MPI_Datatype* old_types); Type_Struct(const Type_Struct&) = delete; Type_Struct& operator=(const Type_Struct&) = delete; + Type_Struct* clone(); ~Type_Struct(); void serialize(const void* noncontiguous, void* contiguous, int count) override; void unserialize(const void* contiguous_vector, void* noncontiguous_vector, int count, MPI_Op op) override; diff --git a/src/smpi/mpi/smpi_datatype.cpp b/src/smpi/mpi/smpi_datatype.cpp index 09e81eb3a7..55474585bf 100644 --- a/src/smpi/mpi/smpi_datatype.cpp +++ b/src/smpi/mpi/smpi_datatype.cpp @@ -182,6 +182,11 @@ Datatype::~Datatype() xbt_free(name_); } +MPI_Datatype Datatype::clone(){ + int ret = MPI_SUCCESS; + return new Datatype(this, &ret); +} + void Datatype::ref() { refcount_++; diff --git a/src/smpi/mpi/smpi_datatype_derived.cpp b/src/smpi/mpi/smpi_datatype_derived.cpp index aa12ddcfb5..bf6392b15e 100644 --- a/src/smpi/mpi/smpi_datatype_derived.cpp +++ b/src/smpi/mpi/smpi_datatype_derived.cpp @@ -53,6 +53,11 @@ Type_Contiguous::~Type_Contiguous() Datatype::unref(old_type_); } +Type_Contiguous* Type_Contiguous::clone() +{ + return new Type_Contiguous(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->old_type_); +} + void Type_Contiguous::serialize(const void* noncontiguous_buf, void* contiguous_buf, int count) { char* contiguous_buf_char = static_cast(contiguous_buf); @@ -78,6 +83,11 @@ Type_Hvector::~Type_Hvector(){ Datatype::unref(old_type_); } +Type_Hvector* Type_Hvector::clone() +{ + return new Type_Hvector(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_length_, this->block_stride_, this->old_type_); +} + void Type_Hvector::serialize(const void* noncontiguous_buf, void *contiguous_buf, int count){ char* contiguous_buf_char = static_cast(contiguous_buf); @@ -125,6 +135,11 @@ Type_Vector::Type_Vector(int size, MPI_Aint lb, MPI_Aint ub, int flags, int coun contents_ = new Datatype_contents(MPI_COMBINER_VECTOR, 3, ints, 0, nullptr, 1, &old_type); } +Type_Vector* Type_Vector::clone() +{ + return new Type_Vector(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_length_, this->block_stride_, this->old_type_); +} + Type_Hindexed::Type_Hindexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, const int* block_lengths, const MPI_Aint* block_indices, MPI_Datatype old_type) : Datatype(size, lb, ub, flags) @@ -161,6 +176,11 @@ Type_Hindexed::Type_Hindexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int } } +Type_Hindexed* Type_Hindexed::clone() +{ + return new Type_Hindexed(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, this->block_indices_, this->old_type_); +} + Type_Hindexed::~Type_Hindexed() { Datatype::unref(old_type_); @@ -230,6 +250,11 @@ Type_Indexed::Type_Indexed(int size, MPI_Aint lb, MPI_Aint ub, int flags, int co delete[] ints; } +Type_Indexed* Type_Indexed::clone() +{ + return new Type_Indexed(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, (int*)(this->block_indices_), this->old_type_); +} + Type_Struct::Type_Struct(int size, MPI_Aint lb, MPI_Aint ub, int flags, int count, const int* block_lengths, const MPI_Aint* block_indices, const MPI_Datatype* old_types) : Datatype(size, lb, ub, flags) @@ -263,6 +288,10 @@ Type_Struct::~Type_Struct(){ } } +Type_Struct* Type_Struct::clone() +{ + return new Type_Struct(this->size(), this->lb(), this->ub(), this->flags(), this->block_count_, this->block_lengths_, this->block_indices_, this->old_types_); +} void Type_Struct::serialize(const void* noncontiguous_buf, void *contiguous_buf, int count){ diff --git a/teshsuite/smpi/mpich3-test/datatype/sendrecvt4.c b/teshsuite/smpi/mpich3-test/datatype/sendrecvt4.c index bd3338dc1f..9b9562addc 100644 --- a/teshsuite/smpi/mpich3-test/datatype/sendrecvt4.c +++ b/teshsuite/smpi/mpich3-test/datatype/sendrecvt4.c @@ -75,8 +75,11 @@ int main(int argc, char **argv) continue; } partner = np - 1; - MPI_Send(MPI_BOTTOM, counts[j], offsettype, partner, tag, comm); + MPI_Datatype dup; + MPI_Type_dup(offsettype, &dup); + MPI_Send(MPI_BOTTOM, counts[j], dup, partner, tag, comm); MPI_Type_free(&offsettype); + MPI_Type_free(&dup); } else if (rank == np - 1) { partner = 0; @@ -101,7 +104,9 @@ int main(int argc, char **argv) MPI_Type_free(&offsettype); continue; } - MPI_Recv(MPI_BOTTOM, counts[j], offsettype, partner, tag, comm, &status); + MPI_Datatype dup; + MPI_Type_dup(offsettype, &dup); + MPI_Recv(MPI_BOTTOM, counts[j], dup, partner, tag, comm, &status); /* Test for correctness */ MPI_Get_count(&status, types[j], &count); if (count != counts[j]) { @@ -136,6 +141,7 @@ int main(int argc, char **argv) err++; } MPI_Type_free(&offsettype); + MPI_Type_free(&dup); } } MTestFreeComm(&comm); -- 2.20.1