Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
ee7e3d2a66240973585d176aa7af4159fc771e31
[simgrid.git] / src / smpi / colls / bcast-SMP-linear.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8 #ifndef NUM_CORE
9 #define NUM_CORE 8
10 #endif
11
12 int bcast_SMP_linear_segment_byte = 8192;
13
14 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
15                                      MPI_Datatype datatype, int root,
16                                      MPI_Comm comm)
17 {
18   int tag = COLL_TAG_BCAST;
19   MPI_Status status;
20   MPI_Request request;
21   MPI_Request *request_array;
22   MPI_Status *status_array;
23   int rank, size;
24   int i;
25   MPI_Aint extent;
26   extent = smpi_datatype_get_extent(datatype);
27
28   rank = smpi_comm_rank(comm);
29   size = smpi_comm_size(comm);
30   int num_core = simcall_host_get_core(SIMIX_host_self());
31   // do we use the default one or the number of cores in the platform ?
32   // if the number of cores is one, the platform may be simulated with 1 node = 1 core
33   if (num_core == 1) num_core = NUM_CORE;
34
35   int segment = bcast_SMP_linear_segment_byte / extent;
36   segment =  segment == 0 ? 1 :segment; 
37   int pipe_length = count / segment;
38   int remainder = count % segment;
39   int increment = segment * extent;
40
41
42   /* leader of each SMP do inter-communication
43      and act as a root for intra-communication */
44   int to_inter = (rank + num_core) % size;
45   int to_intra = (rank + 1) % size;
46   int from_inter = (rank - num_core + size) % size;
47   int from_intra = (rank + size - 1) % size;
48
49   // call native when MPI communication size is too small
50   if (size <= num_core) {
51     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");              
52     smpi_mpi_bcast(buf, count, datatype, root, comm);
53     return MPI_SUCCESS;            
54   }
55   // if root is not zero send to rank zero first
56   if (root != 0) {
57     if (rank == root)
58       smpi_mpi_send(buf, count, datatype, 0, tag, comm);
59     else if (rank == 0)
60       smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
61   }
62   // when a message is smaller than a block size => no pipeline 
63   if (count <= segment) {
64     // case ROOT
65     if (rank == 0) {
66       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
67       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
68     }
69     // case last ROOT of each SMP
70     else if (rank == (((size - 1) / num_core) * num_core)) {
71       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
72       smpi_mpi_wait(&request, &status);
73       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
74     }
75     // case intermediate ROOT of each SMP
76     else if (rank % num_core == 0) {
77       request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
78       smpi_mpi_wait(&request, &status);
79       smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
80       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
81     }
82     // case last non-ROOT of each SMP
83     else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
84       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
85       smpi_mpi_wait(&request, &status);
86     }
87     // case intermediate non-ROOT of each SMP
88     else {
89       request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
90       smpi_mpi_wait(&request, &status);
91       smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
92     }
93     return MPI_SUCCESS;
94   }
95   // pipeline bcast
96   else {
97     request_array =
98         (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
99     status_array =
100         (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
101
102     // case ROOT of each SMP
103     if (rank % num_core == 0) {
104       // case real root
105       if (rank == 0) {
106         for (i = 0; i < pipe_length; i++) {
107           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
108                    (tag + i), comm);
109           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
110                    (tag + i), comm);
111         }
112       }
113       // case last ROOT of each SMP
114       else if (rank == (((size - 1) / num_core) * num_core)) {
115         for (i = 0; i < pipe_length; i++) {
116           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
117                     from_inter, (tag + i), comm);
118         }
119         for (i = 0; i < pipe_length; i++) {
120           smpi_mpi_wait(&request_array[i], &status);
121           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
122                    (tag + i), comm);
123         }
124       }
125       // case intermediate ROOT of each SMP
126       else {
127         for (i = 0; i < pipe_length; i++) {
128           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
129                     from_inter, (tag + i), comm);
130         }
131         for (i = 0; i < pipe_length; i++) {
132           smpi_mpi_wait(&request_array[i], &status);
133           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
134                    (tag + i), comm);
135           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
136                    (tag + i), comm);
137         }
138       }
139     } else {                    // case last non-ROOT of each SMP
140       if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
141         for (i = 0; i < pipe_length; i++) {
142           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
143                     from_intra, (tag + i), comm);
144         }
145         for (i = 0; i < pipe_length; i++) {
146           smpi_mpi_wait(&request_array[i], &status);
147         }
148       }
149       // case intermediate non-ROOT of each SMP
150       else {
151         for (i = 0; i < pipe_length; i++) {
152           request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
153                     from_intra, (tag + i), comm);
154         }
155         for (i = 0; i < pipe_length; i++) {
156           smpi_mpi_wait(&request_array[i], &status);
157           smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
158                    (tag + i), comm);
159         }
160       }
161     }
162     free(request_array);
163     free(status_array);
164   }
165
166   // when count is not divisible by block size, use default BCAST for the remainder
167   if ((remainder != 0) && (count > segment)) {
168     XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");                     
169     smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,
170               root, comm);
171   }
172
173   return MPI_SUCCESS;
174 }