From a4fe82686e41ee595eaf2a59434091695ec28b59 Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Sun, 4 Apr 2021 01:05:01 +0200 Subject: [PATCH] group inputs validation --- src/smpi/bindings/smpi_pmpi_group.cpp | 75 +++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_group.cpp b/src/smpi/bindings/smpi_pmpi_group.cpp index 9908277bbb..ea9898fd74 100644 --- a/src/smpi/bindings/smpi_pmpi_group.cpp +++ b/src/smpi/bindings/smpi_pmpi_group.cpp @@ -17,6 +17,7 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi); int PMPI_Group_free(MPI_Group * group) { CHECK_NULL(1, MPI_ERR_ARG, group) + CHECK_MPI_NULL(1, MPI_GROUP_NULL, MPI_ERR_GROUP, *group) if(*group != MPI_COMM_WORLD->group() && *group != MPI_GROUP_EMPTY) simgrid::smpi::Group::unref(*group); *group = MPI_GROUP_NULL; @@ -42,8 +43,13 @@ int PMPI_Group_rank(MPI_Group group, int *rank) int PMPI_Group_translate_ranks(MPI_Group group1, int n, const int *ranks1, MPI_Group group2, int *ranks2) { CHECK_GROUP(1, group1) + CHECK_NEGATIVE(2, MPI_ERR_ARG, n) + CHECK_NULL(3, MPI_ERR_ARG, ranks1) + CHECK_NULL(5, MPI_ERR_ARG, ranks2) CHECK_GROUP(4, group2) for (int i = 0; i < n; i++) { + if (ranks1[i] != MPI_PROC_NULL && (ranks1[i] < 0 || ranks1[i] >= group1->size())) + return MPI_ERR_RANK; if(ranks1[i]==MPI_PROC_NULL){ ranks2[i]=MPI_PROC_NULL; }else{ @@ -76,6 +82,10 @@ int PMPI_Group_intersection(MPI_Group group1, MPI_Group group2, MPI_Group * newg CHECK_GROUP(1, group1) CHECK_GROUP(2, group2) CHECK_NULL(3, MPI_ERR_ARG, newgroup) + if(group1 == MPI_GROUP_EMPTY || group2 == MPI_GROUP_EMPTY){ + *newgroup = MPI_GROUP_EMPTY; + return MPI_SUCCESS; + } return group1->intersection(group2,newgroup); } @@ -90,15 +100,43 @@ int PMPI_Group_difference(MPI_Group group1, MPI_Group group2, MPI_Group * newgro int PMPI_Group_incl(MPI_Group group, int n, const int *ranks, MPI_Group * newgroup) { CHECK_GROUP(1, group) + CHECK_NEGATIVE(2, MPI_ERR_ARG, n) + CHECK_NULL(3, MPI_ERR_ARG, ranks) CHECK_NULL(4, MPI_ERR_ARG, newgroup) - return group->incl(n, ranks, newgroup); + for(int i = 0; i < n; i++){ + if (ranks[i] < 0 || ranks[i] >= group->size()) + return MPI_ERR_RANK; + for(int j = i+1; j < n; j++){ + if(ranks[i] == ranks[j]) + return MPI_ERR_RANK; + } + } + if (n > group->size()){ + XBT_WARN("MPI_Group_excl, param 2 > group size"); + return MPI_ERR_ARG; + } else { + return group->incl(n, ranks, newgroup); + } } int PMPI_Group_excl(MPI_Group group, int n, const int *ranks, MPI_Group * newgroup) { CHECK_GROUP(1, group) + CHECK_NEGATIVE(2, MPI_ERR_ARG, n) + CHECK_NULL(3, MPI_ERR_ARG, ranks) CHECK_NULL(4, MPI_ERR_ARG, newgroup) - if (n == 0) { + for(int i = 0; i < n; i++){ + if (ranks[i] < 0 || ranks[i] >= group->size()) + return MPI_ERR_RANK; + for(int j = i+1; j < n; j++){ + if(ranks[i] == ranks[j]) + return MPI_ERR_RANK; + } + } + if (n > group->size()){ + XBT_WARN("MPI_Group_excl, param 2 > group size"); + return MPI_ERR_ARG; + } else if (n == 0) { *newgroup = group; if (group != MPI_COMM_WORLD->group() && group != MPI_COMM_SELF->group() && group != MPI_GROUP_EMPTY) group->ref(); @@ -114,8 +152,25 @@ int PMPI_Group_excl(MPI_Group group, int n, const int *ranks, MPI_Group * newgro int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3], MPI_Group * newgroup) { CHECK_GROUP(1, group) + CHECK_NEGATIVE(2, MPI_ERR_ARG, n) + CHECK_NULL(3, MPI_ERR_ARG, ranges) CHECK_NULL(4, MPI_ERR_ARG, newgroup) - if (n == 0) { + for(int i = 0; i < n; i++){ + if (ranges[i][0] < 0 || ranges[i][0] >= group->size() || + ranges[i][1] < 0 || ranges[i][1] >= group->size()){ + return MPI_ERR_RANK; + } + if ((ranges[i][0] < ranges[i][1] && ranges[i][2] < 0) || + (ranges[i][0] > ranges[i][1] && ranges[i][2] > 0)){ + return MPI_ERR_ARG; + } + if (ranges[i][2] == 0) + return MPI_ERR_ARG; + } + if (n > group->size()){ + XBT_WARN("MPI_Group_range_incl, param 2 > group size"); + return MPI_ERR_ARG; + } else if (n == 0) { *newgroup = MPI_GROUP_EMPTY; return MPI_SUCCESS; } else { @@ -126,7 +181,21 @@ int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3], MPI_Group * n int PMPI_Group_range_excl(MPI_Group group, int n, int ranges[][3], MPI_Group * newgroup) { CHECK_GROUP(1, group) + CHECK_NEGATIVE(2, MPI_ERR_ARG, n) + CHECK_NULL(3, MPI_ERR_ARG, ranges) CHECK_NULL(4, MPI_ERR_ARG, newgroup) + for(int i = 0; i < n; i++){ + if (ranges[i][0] < 0 || ranges[i][0] >= group->size() || + ranges[i][1] < 0 || ranges[i][1] >= group->size()){ + return MPI_ERR_RANK; + } + if ((ranges[i][0] < ranges[i][1] && ranges[i][2] < 0) || + (ranges[i][0] > ranges[i][1] && ranges[i][2] > 0)){ + return MPI_ERR_ARG; + } + if (ranges[i][2] == 0) + return MPI_ERR_ARG; + } if (n == 0) { *newgroup = group; if (group != MPI_COMM_WORLD->group() && group != MPI_COMM_SELF->group() && -- 2.20.1