Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
group inputs validation
authorAugustin Degomme <adegomme@gmail.com>
Sat, 3 Apr 2021 23:05:01 +0000 (01:05 +0200)
committerAugustin Degomme <adegomme@gmail.com>
Sat, 3 Apr 2021 23:05:01 +0000 (01:05 +0200)
src/smpi/bindings/smpi_pmpi_group.cpp

index 9908277..ea9898f 100644 (file)
@@ -17,6 +17,7 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi);
 int PMPI_Group_free(MPI_Group * group)
 {
   CHECK_NULL(1, MPI_ERR_ARG, group)
+  CHECK_MPI_NULL(1, MPI_GROUP_NULL, MPI_ERR_GROUP, *group)
   if(*group != MPI_COMM_WORLD->group() && *group != MPI_GROUP_EMPTY)
     simgrid::smpi::Group::unref(*group);
   *group = MPI_GROUP_NULL;
@@ -42,8 +43,13 @@ int PMPI_Group_rank(MPI_Group group, int *rank)
 int PMPI_Group_translate_ranks(MPI_Group group1, int n, const int *ranks1, MPI_Group group2, int *ranks2)
 {
   CHECK_GROUP(1, group1)
+  CHECK_NEGATIVE(2, MPI_ERR_ARG, n)
+  CHECK_NULL(3, MPI_ERR_ARG, ranks1)
+  CHECK_NULL(5, MPI_ERR_ARG, ranks2)
   CHECK_GROUP(4, group2)
   for (int i = 0; i < n; i++) {
+    if (ranks1[i] != MPI_PROC_NULL && (ranks1[i] < 0 || ranks1[i] >= group1->size()))
+      return MPI_ERR_RANK;
     if(ranks1[i]==MPI_PROC_NULL){
       ranks2[i]=MPI_PROC_NULL;
     }else{
@@ -76,6 +82,10 @@ int PMPI_Group_intersection(MPI_Group group1, MPI_Group group2, MPI_Group * newg
   CHECK_GROUP(1, group1)
   CHECK_GROUP(2, group2)
   CHECK_NULL(3, MPI_ERR_ARG, newgroup)
+  if(group1 == MPI_GROUP_EMPTY || group2 == MPI_GROUP_EMPTY){
+    *newgroup = MPI_GROUP_EMPTY;
+    return MPI_SUCCESS;
+  }
   return group1->intersection(group2,newgroup);
 }
 
@@ -90,15 +100,43 @@ int PMPI_Group_difference(MPI_Group group1, MPI_Group group2, MPI_Group * newgro
 int PMPI_Group_incl(MPI_Group group, int n, const int *ranks, MPI_Group * newgroup)
 {
   CHECK_GROUP(1, group)
+  CHECK_NEGATIVE(2, MPI_ERR_ARG, n)
+  CHECK_NULL(3, MPI_ERR_ARG, ranks)
   CHECK_NULL(4, MPI_ERR_ARG, newgroup)
-  return group->incl(n, ranks, newgroup);
+  for(int i = 0; i < n; i++){
+    if (ranks[i] < 0 || ranks[i] >= group->size())
+      return MPI_ERR_RANK;
+    for(int j = i+1; j < n; j++){
+      if(ranks[i] == ranks[j])
+        return MPI_ERR_RANK;
+    }
+  }
+  if (n > group->size()){
+    XBT_WARN("MPI_Group_excl, param 2 > group size");
+    return MPI_ERR_ARG;
+  } else {
+    return group->incl(n, ranks, newgroup);
+  }
 }
 
 int PMPI_Group_excl(MPI_Group group, int n, const int *ranks, MPI_Group * newgroup)
 {
   CHECK_GROUP(1, group)
+  CHECK_NEGATIVE(2, MPI_ERR_ARG, n)
+  CHECK_NULL(3, MPI_ERR_ARG, ranks)
   CHECK_NULL(4, MPI_ERR_ARG, newgroup)
-  if (n == 0) {
+  for(int i = 0; i < n; i++){
+    if (ranks[i] < 0 || ranks[i] >= group->size())
+      return MPI_ERR_RANK;
+    for(int j = i+1; j < n; j++){
+      if(ranks[i] == ranks[j])
+        return MPI_ERR_RANK;
+    }
+  }
+  if (n > group->size()){
+    XBT_WARN("MPI_Group_excl, param 2 > group size");
+    return MPI_ERR_ARG;
+  } else if (n == 0) {
     *newgroup = group;
     if (group != MPI_COMM_WORLD->group() && group != MPI_COMM_SELF->group() && group != MPI_GROUP_EMPTY)
       group->ref();
@@ -114,8 +152,25 @@ int PMPI_Group_excl(MPI_Group group, int n, const int *ranks, MPI_Group * newgro
 int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3], MPI_Group * newgroup)
 {
   CHECK_GROUP(1, group)
+  CHECK_NEGATIVE(2, MPI_ERR_ARG, n)
+  CHECK_NULL(3, MPI_ERR_ARG, ranges)
   CHECK_NULL(4, MPI_ERR_ARG, newgroup)
-  if (n == 0) {
+  for(int i = 0; i < n; i++){
+    if (ranges[i][0] < 0 || ranges[i][0] >= group->size() ||
+        ranges[i][1] < 0 || ranges[i][1] >= group->size()){
+      return MPI_ERR_RANK;
+    }
+    if ((ranges[i][0] < ranges[i][1] && ranges[i][2] < 0) ||
+        (ranges[i][0] > ranges[i][1] && ranges[i][2] > 0)){
+      return MPI_ERR_ARG;
+    }
+    if (ranges[i][2] == 0)
+      return MPI_ERR_ARG;
+  }
+  if (n > group->size()){
+    XBT_WARN("MPI_Group_range_incl, param 2 > group size");
+    return MPI_ERR_ARG;
+  } else if (n == 0) {
     *newgroup = MPI_GROUP_EMPTY;
     return MPI_SUCCESS;
   } else {
@@ -126,7 +181,21 @@ int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3], MPI_Group * n
 int PMPI_Group_range_excl(MPI_Group group, int n, int ranges[][3], MPI_Group * newgroup)
 {
   CHECK_GROUP(1, group)
+  CHECK_NEGATIVE(2, MPI_ERR_ARG, n)
+  CHECK_NULL(3, MPI_ERR_ARG, ranges)
   CHECK_NULL(4, MPI_ERR_ARG, newgroup)
+  for(int i = 0; i < n; i++){
+    if (ranges[i][0] < 0 || ranges[i][0] >= group->size() ||
+        ranges[i][1] < 0 || ranges[i][1] >= group->size()){
+      return MPI_ERR_RANK;
+    }
+    if ((ranges[i][0] < ranges[i][1] && ranges[i][2] < 0) ||
+        (ranges[i][0] > ranges[i][1] && ranges[i][2] > 0)){
+      return MPI_ERR_ARG;
+    }
+    if (ranges[i][2] == 0)
+      return MPI_ERR_ARG;
+  }
   if (n == 0) {
     *newgroup = group;
     if (group != MPI_COMM_WORLD->group() && group != MPI_COMM_SELF->group() &&