Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / colls / bcast-arrival-pattern-aware.c
index f4a482c..09fbbcd 100644 (file)
@@ -1,3 +1,9 @@
+/* Copyright (c) 2013-2014. The SimGrid Team.
+ * All rights reserved.                                                     */
+
+/* This program is free software; you can redistribute it and/or modify it
+ * under the terms of the license (GNU LGPL) which comes with this package. */
+
 #include "colls_private.h"
 
 static int bcast_NTSL_segment_size_in_byte = 8192;
@@ -10,7 +16,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
                                                 MPI_Datatype datatype, int root,
                                                 MPI_Comm comm)
 {
-  int tag = 50;
+  int tag = -COLL_TAG_BCAST;
   MPI_Status status;
   MPI_Request request;
   MPI_Request *send_request_array;
@@ -27,7 +33,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
   int header_index;
   int flag_array[MAX_NODE];
   int already_sent[MAX_NODE];
-
+  int to_clean[MAX_NODE];
   int header_buf[HEADER_SIZE];
   char temp_buf[MAX_NODE];
 
@@ -39,13 +45,13 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
 
 
 
-  rank = smpi_comm_rank(MPI_COMM_WORLD);
-  size = smpi_comm_size(MPI_COMM_WORLD);
+  rank = smpi_comm_rank(comm);
+  size = smpi_comm_size(comm);
 
 
   /* segment is segment size in number of elements (not bytes) */
   int segment = bcast_NTSL_segment_size_in_byte / extent;
-
+  segment =  segment == 0 ? 1 :segment; 
   /* pipeline length */
   int pipe_length = count / segment;
 
@@ -70,6 +76,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
   /* value == 0 means root has not send data (or header) to the node yet */
   for (i = 0; i < MAX_NODE; i++) {
     already_sent[i] = 0;
+    to_clean[i] = 0;
   }
 
   /* when a message is smaller than a block size => no pipeline */
@@ -79,7 +86,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
 
       while (sent_count < (size - 1)) {
         for (i = 1; i < size; i++) {
-          smpi_mpi_iprobe(i, MPI_ANY_TAG, MPI_COMM_WORLD, &flag_array[i],
+          smpi_mpi_iprobe(i, MPI_ANY_TAG, comm, &flag_array[i],
                      MPI_STATUSES_IGNORE);
         }
 
@@ -89,7 +96,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
 
           /* message arrive */
           if ((flag_array[i] == 1) && (already_sent[i] == 0)) {
-            smpi_mpi_recv(temp_buf, 1, MPI_CHAR, i, tag, MPI_COMM_WORLD, &status);
+            smpi_mpi_recv(temp_buf, 1, MPI_CHAR, i, tag, comm, &status);
             header_buf[header_index] = i;
             header_index++;
             sent_count++;
@@ -171,7 +178,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
         //iteration++;
         //start = MPI_Wtime();
         for (i = 1; i < size; i++) {
-          smpi_mpi_iprobe(i, MPI_ANY_TAG, MPI_COMM_WORLD, &flag_array[i],
+          smpi_mpi_iprobe(i, MPI_ANY_TAG, comm, &flag_array[i],
                      &temp_status_array[i]);
         }
         //total = MPI_Wtime() - start;
@@ -184,7 +191,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
         for (i = 1; i < size; i++) {
           /* message arrive */
           if ((flag_array[i] == 1) && (already_sent[i] == 0)) {
-            smpi_mpi_recv(&temp_buf[i], 1, MPI_CHAR, i, tag, MPI_COMM_WORLD,
+            smpi_mpi_recv(&temp_buf[i], 1, MPI_CHAR, i, tag, comm,
                      &status);
             header_buf[header_index] = i;
             header_index++;
@@ -274,6 +281,7 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
 
 
               already_sent[i] = 1;
+              to_clean[i]=1;
               sent_count++;
               break;
             }
@@ -282,6 +290,9 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
 
       }                         /* while loop */
 
+      for(i=0; i<size; i++)
+        if(to_clean[i]!=0)smpi_mpi_recv(&temp_buf[i], 1, MPI_CHAR, i, tag, comm,
+                     &status);
       //total = MPI_Wtime() - start2;
       //total *= 1000;
       //printf("Node zero iter = %d time = %.2f\n",iteration,total);
@@ -331,8 +342,10 @@ int smpi_coll_tuned_bcast_arrival_pattern_aware(void *buf, int count,
                     header_buf[myordering + 1], tag, comm);
         }
         smpi_mpi_waitall((pipe_length), send_request_array, send_status_array);
-      }
-
+      }else{
+          smpi_mpi_waitall(pipe_length, recv_request_array, recv_status_array);
+          }
+    
     }
 
     free(send_request_array);