Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
hide this from users
[simgrid.git] / src / smpi / smpi_pmpi.cpp
index 4d8d8f0..853c980 100644 (file)
@@ -4,7 +4,6 @@
  * under the terms of the license (GNU LGPL) which comes with this package. */
 
 #include <simgrid/s4u/host.hpp>
-#include <xbt/ex.hpp>
 
 #include "private.h"
 
@@ -152,7 +151,7 @@ int PMPI_Type_free(MPI_Datatype * datatype)
   if (*datatype == MPI_DATATYPE_NULL) {
     return MPI_ERR_ARG;
   } else {
-    (*datatype)->unuse();
+    Datatype::unref(*datatype);
     return MPI_SUCCESS;
   }
 }
@@ -234,12 +233,18 @@ int PMPI_Type_ub(MPI_Datatype datatype, MPI_Aint * disp)
 }
 
 int PMPI_Type_dup(MPI_Datatype datatype, MPI_Datatype *newtype){
+  int retval = MPI_SUCCESS;
   if (datatype == MPI_DATATYPE_NULL) {
-    return MPI_ERR_TYPE;
+    retval=MPI_ERR_TYPE;
   } else {
-    *newtype = new Datatype(datatype);
-    return MPI_SUCCESS;
+    *newtype = new Datatype(datatype, &retval);
+    //error when duplicating, free the new datatype
+    if(retval!=MPI_SUCCESS){
+      Datatype::unref(*newtype);
+      *newtype = MPI_DATATYPE_NULL;
+    }
   }
+  return retval;
 }
 
 int PMPI_Op_create(MPI_User_function * function, int commute, MPI_Op * op)
@@ -270,7 +275,8 @@ int PMPI_Group_free(MPI_Group * group)
   if (group == nullptr) {
     return MPI_ERR_ARG;
   } else {
-    (*group)->destroy();
+    if(*group != MPI_COMM_WORLD->group() && *group != MPI_GROUP_EMPTY)
+      Group::unref(*group);
     *group = MPI_GROUP_NULL;
     return MPI_SUCCESS;
   }
@@ -386,7 +392,7 @@ int PMPI_Group_excl(MPI_Group group, int n, int *ranks, MPI_Group * newgroup)
       *newgroup = group;
       if (group != MPI_COMM_WORLD->group()
                 && group != MPI_COMM_SELF->group() && group != MPI_GROUP_EMPTY)
-      group->use();
+      group->ref();
       return MPI_SUCCESS;
     } else if (n == group->size()) {
       *newgroup = MPI_GROUP_EMPTY;
@@ -424,7 +430,7 @@ int PMPI_Group_range_excl(MPI_Group group, int n, int ranges[][3], MPI_Group * n
       *newgroup = group;
       if (group != MPI_COMM_WORLD->group() && group != MPI_COMM_SELF->group() &&
           group != MPI_GROUP_EMPTY)
-        group->use();
+        group->ref();
       return MPI_SUCCESS;
     } else {
       return group->range_excl(n,ranges,newgroup);
@@ -477,7 +483,7 @@ int PMPI_Comm_group(MPI_Comm comm, MPI_Group * group)
   } else {
     *group = comm->group();
     if (*group != MPI_COMM_WORLD->group() && *group != MPI_GROUP_NULL && *group != MPI_GROUP_EMPTY)
-      (*group)->use();
+      (*group)->ref();
     return MPI_SUCCESS;
   }
 }
@@ -524,7 +530,7 @@ int PMPI_Comm_create(MPI_Comm comm, MPI_Group group, MPI_Comm * newcomm)
     *newcomm= MPI_COMM_NULL;
     return MPI_SUCCESS;
   }else{
-    group->use();
+    group->ref();
     *newcomm = new Comm(group, nullptr);
     return MPI_SUCCESS;
   }
@@ -537,7 +543,7 @@ int PMPI_Comm_free(MPI_Comm * comm)
   } else if (*comm == MPI_COMM_NULL) {
     return MPI_ERR_COMM;
   } else {
-    (*comm)->destroy();
+    Comm::destroy(*comm);
     *comm = MPI_COMM_NULL;
     return MPI_SUCCESS;
   }
@@ -551,7 +557,7 @@ int PMPI_Comm_disconnect(MPI_Comm * comm)
   } else if (*comm == MPI_COMM_NULL) {
     return MPI_ERR_COMM;
   } else {
-    (*comm)->destroy();
+    Comm::destroy(*comm);
     *comm = MPI_COMM_NULL;
     return MPI_SUCCESS;
   }
@@ -705,7 +711,7 @@ int PMPI_Request_free(MPI_Request * request)
   if (*request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_ARG;
   } else {
-    Request::unuse(request);
+    Request::unref(request);
     retval = MPI_SUCCESS;
   }
   smpi_bench_begin();
@@ -1380,7 +1386,7 @@ int PMPI_Bcast(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm c
     extra->send_size = count * dt_size_send;
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
     if (comm->size() > 1)
-      mpi_coll_bcast_fun(buf, count, datatype, root, comm);
+      Colls::bcast(buf, count, datatype, root, comm);
     retval = MPI_SUCCESS;
 
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
@@ -1403,7 +1409,7 @@ int PMPI_Barrier(MPI_Comm comm)
     extra->type            = TRACING_BARRIER;
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    mpi_coll_barrier_fun(comm);
+    Colls::barrier(comm);
     retval = MPI_SUCCESS;
 
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
@@ -1455,7 +1461,7 @@ int PMPI_Gather(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbu
 
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
 
-    mpi_coll_gather_fun(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount, recvtype, root, comm);
+    Colls::gather(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount, recvtype, root, comm);
 
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
@@ -1515,8 +1521,7 @@ int PMPI_Gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recv
     }
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
 
-    smpi_mpi_gatherv(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcounts, displs, recvtype, root, comm);
-    retval = MPI_SUCCESS;
+    retval = Colls::gatherv(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcounts, displs, recvtype, root, comm);
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
   }
 
@@ -1562,7 +1567,7 @@ int PMPI_Allgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    mpi_coll_allgather_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
+    Colls::allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
   }
@@ -1614,7 +1619,7 @@ int PMPI_Allgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    mpi_coll_allgatherv_fun(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
+    Colls::allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm);
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
   }
@@ -1662,7 +1667,7 @@ int PMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
     extra->recv_size = recvcount * dt_size_recv;
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
 
-    mpi_coll_scatter_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm);
+    Colls::scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm);
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
   }
@@ -1715,9 +1720,8 @@ int PMPI_Scatterv(void *sendbuf, int *sendcounts, int *displs,
     extra->recv_size = recvcount * dt_size_recv;
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
 
-    smpi_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm);
+    retval = Colls::scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm);
 
-    retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
   }
 
@@ -1750,7 +1754,7 @@ int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
 
     TRACE_smpi_collective_in(rank, root_traced, __FUNCTION__, extra);
 
-    mpi_coll_reduce_fun(sendbuf, recvbuf, count, datatype, op, root, comm);
+    Colls::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
 
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, root_traced, __FUNCTION__);
@@ -1767,7 +1771,7 @@ int PMPI_Reduce_local(void *inbuf, void *inoutbuf, int count, MPI_Datatype datat
   if (!datatype->is_valid() || op == MPI_OP_NULL) {
     retval = MPI_ERR_ARG;
   } else {
-    if(op!=MPI_OP_NULL) op->apply( inbuf, inoutbuf, &count, datatype);
+    op->apply(inbuf, inoutbuf, &count, datatype);
     retval = MPI_SUCCESS;
   }
   smpi_bench_begin();
@@ -1805,7 +1809,7 @@ int PMPI_Allreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatyp
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    mpi_coll_allreduce_fun(sendtmpbuf, recvbuf, count, datatype, op, comm);
+    Colls::allreduce(sendtmpbuf, recvbuf, count, datatype, op, comm);
 
     if( sendbuf == MPI_IN_PLACE )
       xbt_free(sendtmpbuf);
@@ -1843,9 +1847,8 @@ int PMPI_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MP
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    smpi_mpi_scan(sendbuf, recvbuf, count, datatype, op, comm);
+    retval = Colls::scan(sendbuf, recvbuf, count, datatype, op, comm);
 
-    retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
   }
 
@@ -1881,8 +1884,8 @@ int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
     }
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    smpi_mpi_exscan(sendtmpbuf, recvbuf, count, datatype, op, comm);
-    retval = MPI_SUCCESS;
+    retval = Colls::exscan(sendtmpbuf, recvbuf, count, datatype, op, comm);
+
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
     if (sendbuf == MPI_IN_PLACE)
       xbt_free(sendtmpbuf);
@@ -1932,7 +1935,7 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datat
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    mpi_coll_reduce_scatter_fun(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm);
+    Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm);
     retval = MPI_SUCCESS;
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
 
@@ -1985,7 +1988,7 @@ int PMPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount,
     int* recvcounts = static_cast<int*>(xbt_malloc(count * sizeof(int)));
     for (int i      = 0; i < count; i++)
       recvcounts[i] = recvcount;
-    mpi_coll_reduce_scatter_fun(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm);
+    Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm);
     xbt_free(recvcounts);
     retval = MPI_SUCCESS;
 
@@ -2038,7 +2041,7 @@ int PMPI_Alltoall(void* sendbuf, int sendcount, MPI_Datatype sendtype, void* rec
 
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
 
-    retval = mpi_coll_alltoall_fun(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount, recvtype, comm);
+    retval = Colls::alltoall(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount, recvtype, comm);
 
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
 
@@ -2111,7 +2114,7 @@ int PMPI_Alltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
     }
     extra->num_processes = size;
     TRACE_smpi_collective_in(rank, -1, __FUNCTION__, extra);
-    retval = mpi_coll_alltoallv_fun(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptype, recvbuf, recvcounts,
+    retval = Colls::alltoallv(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptype, recvbuf, recvcounts,
                                     recvdisps, recvtype, comm);
     TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
 
@@ -2446,7 +2449,7 @@ int PMPI_Win_free( MPI_Win* win){
   if (win == nullptr || *win == MPI_WIN_NULL) {
     retval = MPI_ERR_WIN;
   }else{
-    delete(*win);
+    delete *win;
     retval=MPI_SUCCESS;
   }
   smpi_bench_begin();
@@ -2482,7 +2485,7 @@ int PMPI_Win_get_group(MPI_Win  win, MPI_Group * group){
     return MPI_ERR_WIN;
   }else {
     win->get_group(group);
-    (*group)->use();
+    (*group)->ref();
     return MPI_SUCCESS;
   }
 }
@@ -2713,59 +2716,59 @@ int PMPI_Type_get_name(MPI_Datatype  datatype, char * name, int* len)
 }
 
 MPI_Datatype PMPI_Type_f2c(MPI_Fint datatype){
-  return smpi_type_f2c(datatype);
+  return static_cast<MPI_Datatype>(F2C::f2c(datatype));
 }
 
 MPI_Fint PMPI_Type_c2f(MPI_Datatype datatype){
-  return smpi_type_c2f( datatype);
+  return datatype->c2f();
 }
 
 MPI_Group PMPI_Group_f2c(MPI_Fint group){
-  return smpi_group_f2c( group);
+  return Group::f2c(group);
 }
 
 MPI_Fint PMPI_Group_c2f(MPI_Group group){
-  return smpi_group_c2f(group);
+  return group->c2f();
 }
 
 MPI_Request PMPI_Request_f2c(MPI_Fint request){
-  return smpi_request_f2c(request);
+  return static_cast<MPI_Request>(Request::f2c(request));
 }
 
 MPI_Fint PMPI_Request_c2f(MPI_Request request) {
-  return smpi_request_c2f(request);
+  return request->c2f();
 }
 
 MPI_Win PMPI_Win_f2c(MPI_Fint win){
-  return smpi_win_f2c(win);
+  return static_cast<MPI_Win>(Win::f2c(win));
 }
 
 MPI_Fint PMPI_Win_c2f(MPI_Win win){
-  return smpi_win_c2f(win);
+  return win->c2f();
 }
 
 MPI_Op PMPI_Op_f2c(MPI_Fint op){
-  return smpi_op_f2c(op);
+  return static_cast<MPI_Op>(Op::f2c(op));
 }
 
 MPI_Fint PMPI_Op_c2f(MPI_Op op){
-  return smpi_op_c2f(op);
+  return op->c2f();
 }
 
 MPI_Comm PMPI_Comm_f2c(MPI_Fint comm){
-  return smpi_comm_f2c(comm);
+  return static_cast<MPI_Comm>(Comm::f2c(comm));
 }
 
 MPI_Fint PMPI_Comm_c2f(MPI_Comm comm){
-  return smpi_comm_c2f(comm);
+  return comm->c2f();
 }
 
 MPI_Info PMPI_Info_f2c(MPI_Fint info){
-  return smpi_info_f2c(info);
+  return static_cast<MPI_Info>(Info::f2c(info));
 }
 
 MPI_Fint PMPI_Info_c2f(MPI_Info info){
-  return smpi_info_c2f(info);
+  return info->c2f();
 }
 
 int PMPI_Keyval_create(MPI_Copy_function* copy_fn, MPI_Delete_function* delete_fn, int* keyval, void* extra_state) {
@@ -2897,28 +2900,21 @@ int PMPI_Type_free_keyval(int* keyval) {
 int PMPI_Info_create( MPI_Info *info){
   if (info == nullptr)
     return MPI_ERR_ARG;
-  *info = xbt_new(s_smpi_mpi_info_t, 1);
-  (*info)->info_dict= xbt_dict_new_homogeneous(xbt_free_f);
-  (*info)->refcount=1;
+  *info = new Info();
   return MPI_SUCCESS;
 }
 
 int PMPI_Info_set( MPI_Info info, char *key, char *value){
   if (info == nullptr || key == nullptr || value == nullptr)
     return MPI_ERR_ARG;
-
-  xbt_dict_set(info->info_dict, key, xbt_strdup(value), nullptr);
+  info->set(key, value);
   return MPI_SUCCESS;
 }
 
 int PMPI_Info_free( MPI_Info *info){
   if (info == nullptr || *info==nullptr)
     return MPI_ERR_ARG;
-  (*info)->refcount--;
-  if((*info)->refcount==0){
-    xbt_dict_free(&((*info)->info_dict));
-    xbt_free(*info);
-  }
+  Info::unref(*info);
   *info=MPI_INFO_NULL;
   return MPI_SUCCESS;
 }
@@ -2929,78 +2925,39 @@ int PMPI_Info_get(MPI_Info info,char *key,int valuelen, char *value, int *flag){
     return MPI_ERR_ARG;
   if (value == nullptr)
     return MPI_ERR_INFO_VALUE;
-  char* tmpvalue=static_cast<char*>(xbt_dict_get_or_null(info->info_dict, key));
-  if(tmpvalue){
-    memset(value, 0, valuelen);
-    memcpy(value,tmpvalue, (strlen(tmpvalue) + 1 < static_cast<size_t>(valuelen)) ? strlen(tmpvalue) + 1 : valuelen);
-    *flag=true;
-  }
-  return MPI_SUCCESS;
+  return info->get(key, valuelen, value, flag);
 }
 
 int PMPI_Info_dup(MPI_Info info, MPI_Info *newinfo){
   if (info == nullptr || newinfo==nullptr)
     return MPI_ERR_ARG;
-  *newinfo = xbt_new(s_smpi_mpi_info_t, 1);
-  (*newinfo)->info_dict= xbt_dict_new_homogeneous(xbt_free_f);
-  (*newinfo)->refcount=1;
-  xbt_dict_cursor_t cursor = nullptr;
-  char* key;
-  void* data;
-  xbt_dict_foreach(info->info_dict,cursor,key,data){
-    xbt_dict_set((*newinfo)->info_dict, key, xbt_strdup(static_cast<char*>(data)), nullptr);
-  }
+  *newinfo = new Info(info);
   return MPI_SUCCESS;
 }
 
 int PMPI_Info_delete(MPI_Info info, char *key){
   if (info == nullptr || key==nullptr)
     return MPI_ERR_ARG;
-  try {
-    xbt_dict_remove(info->info_dict, key);
-  }
-  catch(xbt_ex& e){
-    return MPI_ERR_INFO_NOKEY;
-  }
-  return MPI_SUCCESS;
+  return info->remove(key);
 }
 
 int PMPI_Info_get_nkeys( MPI_Info info, int *nkeys){
   if (info == nullptr || nkeys==nullptr)
     return MPI_ERR_ARG;
-  *nkeys=xbt_dict_size(info->info_dict);
-  return MPI_SUCCESS;
+  return info->get_nkeys(nkeys);
 }
 
 int PMPI_Info_get_nthkey( MPI_Info info, int n, char *key){
   if (info == nullptr || key==nullptr || n<0 || n> MPI_MAX_INFO_KEY)
     return MPI_ERR_ARG;
-
-  xbt_dict_cursor_t cursor = nullptr;
-  char *keyn;
-  void* data;
-  int num=0;
-  xbt_dict_foreach(info->info_dict,cursor,keyn,data){
-    if(num==n){
-      strncpy(key,keyn,strlen(keyn)+1);
-      xbt_dict_cursor_free(&cursor);
-      return MPI_SUCCESS;
-    }
-    num++;
-  }
-  return MPI_ERR_ARG;
+  return info->get_nthkey(n, key);
 }
 
 int PMPI_Info_get_valuelen( MPI_Info info, char *key, int *valuelen, int *flag){
   *flag=false;
   if (info == nullptr || key == nullptr || valuelen==nullptr)
     return MPI_ERR_ARG;
-  char* tmpvalue=(char*)xbt_dict_get_or_null(info->info_dict, key);
-  if(tmpvalue){
-    *valuelen=strlen(tmpvalue);
-    *flag=true;
-  }
-  return MPI_SUCCESS;
+  return info->get_valuelen(key, valuelen, flag);
 }
 
 int PMPI_Unpack(void* inbuf, int incount, int* position, void* outbuf, int outcount, MPI_Datatype type, MPI_Comm comm) {