1 /* Copyright (c) 2013-2023. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
// Pipeline segment size, in bytes, used by bcast__SMP_linear: it is divided by the
// datatype extent to obtain the number of elements sent per pipelined segment.
// NOTE(review): this is a mutable global; presumably it is tuned from elsewhere — confirm
// before making it const.
9 int bcast_SMP_linear_segment_byte = 8192;
10 namespace simgrid::smpi {
// bcast__SMP_linear: SMP-aware linear broadcast of `count` elements of `datatype` from `root`.
// If root is not rank 0, the whole buffer is first handed from root to rank 0. Data then
// travels along two linear chains:
//   - from SMP leader to SMP leader (ranks with rank % num_core == 0, via to_inter);
//   - inside each SMP node from rank to rank (via to_intra).
// Messages larger than one segment are pipelined segment by segment, each segment tagged
// `tag + i`. Elements that do not fill a whole segment are broadcast afterwards with
// colls::bcast. The function falls back to bcast__mpich when num_core cannot be taken from a
// uniform communicator, and to bcast__default when the communicator fits in one SMP node.
//
// NOTE(review): this listing is not contiguous — its embedded line numbers skip values. The
// `comm` parameter, the declarations of status/request/rank/size/i/extent, several branch
// conditions, the final return and most closing braces are not shown. The comments below
// describe only the visible statements.
11 int bcast__SMP_linear(void *buf, int count,
12 MPI_Datatype datatype, int root,
15 int tag = COLL_TAG_BCAST;
// Byte size of one datatype element; used to turn segment lengths and offsets into bytes.
21 extent = datatype->get_extent();
// Set up the leaders communicator on first use (the branch body is not shown).
25 if(comm->get_leaders_comm()==MPI_COMM_NULL){
29 if (comm->is_uniform()){
// num_core = number of processes per SMP node, taken from the intra-node communicator.
30 num_core = comm->get_intra_comm()->size();
// Reached when the communicator is not uniform (the `else` line is not shown).
32 //implementation buggy in this case
33 return bcast__mpich(buf, count, datatype, root, comm);
// Segment length in elements: the byte budget divided by the element extent, clamped to at
// least one element.
// NOTE(review): nothing visible guards against a zero extent (division by zero) — confirm
// that datatypes reaching this point always have extent > 0.
36 int segment = bcast_SMP_linear_segment_byte / extent;
37 segment = segment == 0 ? 1 :segment;
// Number of full segments, leftover elements, and byte stride between consecutive segments.
// NOTE(review): byte offsets below (i * increment, pipe_length * increment) are computed
// in int and would overflow for buffers larger than INT_MAX bytes.
38 int pipe_length = count / segment;
39 int remainder = count % segment;
40 int increment = segment * extent;
43 /* the leader of each SMP does inter-communication
44 and acts as the root for intra-communication */
// Chain neighbours (all modulo size): next/previous SMP leader (rank +/- num_core) and
// next/previous rank (rank +/- 1).
45 int to_inter = (rank + num_core) % size;
46 int to_intra = (rank + 1) % size;
47 int from_inter = (rank - num_core + size) % size;
48 int from_intra = (rank + size - 1) % size;
50 // call native when MPI communication size is too small
51 if (size <= num_core) {
52 XBT_INFO("size <= num_core : MPI_bcast_SMP_linear use default MPI_bcast.");
53 bcast__default(buf, count, datatype, root, comm);
// Step 1: a non-zero root sends the whole buffer to rank 0, which then seeds the chains
// (the guarding `if` conditions are not shown).
56 // if root is not zero send to rank zero first
59 Request::send(buf, count, datatype, 0, tag, comm);
61 Request::recv(buf, count, datatype, root, tag, comm, &status);
63 // when a message fits in a single segment => no pipeline
64 if (count <= segment) {
// Real root (presumably rank 0; the branch condition is not shown): seed both chains.
67 Request::send(buf, count, datatype, to_inter, tag, comm);
68 Request::send(buf, count, datatype, to_intra, tag, comm);
70 // case last ROOT of each SMP
// The last SMP leader receives from the previous leader and only feeds its own node.
// NOTE(review): irecv immediately followed by wait behaves like a blocking recv.
71 else if (rank == (((size - 1) / num_core) * num_core)) {
72 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
73 Request::wait(&request, &status);
74 Request::send(buf, count, datatype, to_intra, tag, comm);
76 // case intermediate ROOT of each SMP
// Other leaders receive from the previous leader, then forward to the next leader and
// into their own node.
77 else if (rank % num_core == 0) {
78 request = Request::irecv(buf, count, datatype, from_inter, tag, comm);
79 Request::wait(&request, &status);
80 Request::send(buf, count, datatype, to_inter, tag, comm);
81 Request::send(buf, count, datatype, to_intra, tag, comm);
83 // case last non-ROOT of each SMP
// End of an intra-node chain (last rank of the node, or last rank overall): receive only.
84 else if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
85 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
86 Request::wait(&request, &status);
88 // case intermediate non-ROOT of each SMP
// Receive from the previous rank in the node and pass the data on to the next one.
90 request = Request::irecv(buf, count, datatype, from_intra, tag, comm);
91 Request::wait(&request, &status);
92 Request::send(buf, count, datatype, to_intra, tag, comm);
// Pipelined path (count > segment): each of the pipe_length full segments is forwarded
// separately, with tag + i identifying segment i.
// NOTE(review):
// - tag + i grows with pipe_length; confirm it cannot collide with other collective tags.
// - Raw new[]/delete[] would leak if anything in between can throw; std::vector would be safer.
// - The visible code only uses request_array[0, pipe_length) and never reads status_array,
//   so the `size +` slack and status_array look removable; confirm against the omitted lines.
98 auto* request_array = new MPI_Request[size + pipe_length];
99 auto* status_array = new MPI_Status[size + pipe_length];
101 // case ROOT of each SMP
// NOTE(review): the visible body of this branch only sends. The lines that single out the
// real root (presumably rank 0) from the other leaders are not shown.
102 if (rank % num_core == 0) {
// Real root: push every segment to the next leader and to the next rank in its node.
105 for (i = 0; i < pipe_length; i++) {
106 Request::send((char *) buf + (i * increment), segment, datatype, to_inter,
108 Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
112 // case last ROOT of each SMP
113 else if (rank == (((size - 1) / num_core) * num_core)) {
// Post every segment receive up front, then forward each segment into the node as it arrives.
114 for (i = 0; i < pipe_length; i++) {
115 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
116 from_inter, (tag + i), comm);
118 for (i = 0; i < pipe_length; i++) {
119 Request::wait(&request_array[i], &status);
120 Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
124 // case intermediate ROOT of each SMP
// Same as the last leader, but each segment is also forwarded to the next SMP leader.
126 for (i = 0; i < pipe_length; i++) {
127 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
128 from_inter, (tag + i), comm);
130 for (i = 0; i < pipe_length; i++) {
131 Request::wait(&request_array[i], &status);
132 Request::send((char *) buf + (i * increment), segment, datatype, to_inter,
134 Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
138 } else { // non-ROOT ranks of each SMP
// case last non-ROOT of each SMP: end of the intra-node chain, receive every segment only.
139 if (((rank + 1) % num_core == 0) || (rank == (size - 1))) {
140 for (i = 0; i < pipe_length; i++) {
141 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
142 from_intra, (tag + i), comm);
144 for (i = 0; i < pipe_length; i++) {
145 Request::wait(&request_array[i], &status);
148 // case intermediate non-ROOT of each SMP
// Receive each segment from the previous rank and pass it on to the next rank in the node.
150 for (i = 0; i < pipe_length; i++) {
151 request_array[i] = Request::irecv((char *) buf + (i * increment), segment, datatype,
152 from_intra, (tag + i), comm);
154 for (i = 0; i < pipe_length; i++) {
155 Request::wait(&request_array[i], &status);
156 Request::send((char *) buf + (i * increment), segment, datatype, to_intra,
161 delete[] request_array;
162 delete[] status_array;
// Tail: the `remainder` elements after the last full segment are broadcast from the original
// root. This is only needed on the pipelined path; the unpipelined path sent the whole
// buffer at once.
// NOTE(review): the log message says "default MPI_bcast", but the call goes through
// colls::bcast — confirm which bcast implementation that dispatches to.
165 if ((remainder != 0) && (count > segment)) {
166 XBT_INFO("MPI_bcast_SMP_linear: count is not divisible by block size, use default MPI_bcast for remainder.");
167 colls::bcast((char*)buf + (pipe_length * increment), remainder, datatype, root, comm);
173 } // namespace simgrid::smpi