Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Fix [#17799] : have mpi_group_range_incl and mpi_group_range_excl better test some...
[simgrid.git] / src / smpi / smpi_pmpi.c
index 485a8c0..ed24e85 100644 (file)
@@ -478,30 +478,14 @@ int PMPI_Group_difference(MPI_Group group1, MPI_Group group2, MPI_Group * newgro
 
 int PMPI_Group_incl(MPI_Group group, int n, int *ranks, MPI_Group * newgroup)
 {
-  int retval, i, index;
+  int retval;
 
   if (group == MPI_GROUP_NULL) {
     retval = MPI_ERR_GROUP;
   } else if (newgroup == NULL) {
     retval = MPI_ERR_ARG;
   } else {
-    if (n == 0) {
-      *newgroup = MPI_GROUP_EMPTY;
-    } else if (n == smpi_group_size(group)) {
-      *newgroup = group;
-      if(group!= smpi_comm_group(MPI_COMM_WORLD)
-                && group != MPI_GROUP_NULL
-                && group != smpi_comm_group(MPI_COMM_SELF)
-                && group != MPI_GROUP_EMPTY)
-      smpi_group_use(group);
-    } else {
-      *newgroup = smpi_group_new(n);
-      for (i = 0; i < n; i++) {
-        index = smpi_group_index(group, ranks[i]);
-        smpi_group_set_mapping(*newgroup, index, i);
-      }
-    }
-    retval = MPI_SUCCESS;
+    retval = smpi_group_incl(group, n, ranks, newgroup);
   }
   return retval;
 }
@@ -567,10 +551,12 @@ int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3],
       size = 0;
       for (i = 0; i < n; i++) {
         for (rank = ranges[i][0];       /* First */
-             rank >= 0; /* Last */
+             rank >= 0 && rank < smpi_group_size(group); /* Last */
               ) {
           size++;
-
+          if(rank == ranges[i][1]){/*already last ?*/
+            break;
+          }
           rank += ranges[i][2]; /* Stride */
          if (ranges[i][0]<ranges[i][1]){
              if(rank > ranges[i][1])
@@ -586,11 +572,14 @@ int PMPI_Group_range_incl(MPI_Group group, int n, int ranges[][3],
       j = 0;
       for (i = 0; i < n; i++) {
         for (rank = ranges[i][0];     /* First */
-             rank >= 0; /* Last */
+             rank >= 0 && rank < smpi_group_size(group); /* Last */
              ) {
           index = smpi_group_index(group, rank);
           smpi_group_set_mapping(*newgroup, index, j);
           j++;
+          if(rank == ranges[i][1]){/*already last ?*/
+            break;
+          }
           rank += ranges[i][2]; /* Stride */
          if (ranges[i][0]<ranges[i][1]){
            if(rank > ranges[i][1])
@@ -628,10 +617,12 @@ int PMPI_Group_range_excl(MPI_Group group, int n, int ranges[][3],
       size = smpi_group_size(group);
       for (i = 0; i < n; i++) {
         for (rank = ranges[i][0];       /* First */
-             rank >= 0; /* Last */
+             rank >= 0 && rank < smpi_group_size(group); /* Last */
               ) {
           size--;
-
+          if(rank == ranges[i][1]){/*already last ?*/
+            break;
+          }
           rank += ranges[i][2]; /* Stride */
          if (ranges[i][0]<ranges[i][1]){
              if(rank > ranges[i][1])
@@ -651,14 +642,17 @@ int PMPI_Group_range_excl(MPI_Group group, int n, int ranges[][3],
         while (newrank < size) {
           add=1;
           for (i = 0; i < n; i++) {
-            for (rank = ranges[i][0];rank >= 0;){
+            for (rank = ranges[i][0];
+                rank >= 0 && rank < smpi_group_size(group);
+                ){
               if(rank==oldrank){
                   add=0;
                   break;
               }
-
+              if(rank == ranges[i][1]){/*already last ?*/
+                break;
+              }
               rank += ranges[i][2]; /* Stride */
-
               if (ranges[i][0]<ranges[i][1]){
                   if(rank > ranges[i][1])
                     break;
@@ -1691,6 +1685,8 @@ int PMPI_Bcast(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm c
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
+  } else if (!is_datatype_valid(datatype)) {
+      retval = MPI_ERR_ARG;
   } else {
 #ifdef HAVE_TRACING
   int rank = comm != MPI_COMM_NULL ? smpi_process_index() : -1;
@@ -1956,8 +1952,8 @@ int PMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (((smpi_comm_rank(comm)==root) && (sendtype == MPI_DATATYPE_NULL))
-             || ((recvbuf !=MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL))) {
+  } else if (((smpi_comm_rank(comm)==root) && (!is_datatype_valid(sendtype)))
+             || ((recvbuf !=MPI_IN_PLACE) && (!is_datatype_valid(recvtype)))){
     retval = MPI_ERR_TYPE;
   } else if ((sendbuf == recvbuf) ||
       ((smpi_comm_rank(comm)==root) && sendcount>0 && (sendbuf == NULL))){