1 #include "colls_private.h"
/* Pipeline segment size, in bytes, for smpi_coll_tuned_bcast_SMP_linear.
 * The broadcast divides this by the datatype extent to get the number of
 * elements per pipelined segment, and uses (segment * extent) as the byte
 * stride between consecutive segments in the buffer.
 * NOTE(review): if this value is smaller than the extent of the datatype
 * being broadcast, the per-segment element count becomes 0 and the
 * subsequent `count / segment` and `count % segment` divide by zero; keep
 * it >= the largest datatype extent in use. Left non-static, presumably so
 * it can be tuned from another translation unit (TODO confirm). */
6 int bcast_SMP_linear_segment_byte = 8192;
8 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
9 MPI_Datatype datatype, int root,
12 int tag = COLL_TAG_BCAST;
15 MPI_Request *request_array;
16 MPI_Status *status_array;
20 extent = smpi_datatype_get_extent(datatype);
22 rank = smpi_comm_rank(comm);
23 size = smpi_comm_size(comm);
25 int segment = bcast_SMP_linear_segment_byte / extent;
26 int pipe_length = count / segment;
27 int remainder = count % segment;
28 int increment = segment * extent;
31 /* leader of each SMP do inter-communication
32 and act as a root for intra-communication */
33 int to_inter = (rank + NUM_CORE) % size;
34 int to_intra = (rank + 1) % size;
35 int from_inter = (rank - NUM_CORE + size) % size;
36 int from_intra = (rank + size - 1) % size;
38 // call native when MPI communication size is too small
39 if (size <= NUM_CORE) {
40 XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");
41 smpi_mpi_bcast(buf, count, datatype, root, comm);
44 // if root is not zero send to rank zero first
47 smpi_mpi_send(buf, count, datatype, 0, tag, comm);
49 smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
51 // when a message is smaller than a block size => no pipeline
52 if (count <= segment) {
55 smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
56 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
58 // case last ROOT of each SMP
59 else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
60 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
61 smpi_mpi_wait(&request, &status);
62 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
64 // case intermediate ROOT of each SMP
65 else if (rank % NUM_CORE == 0) {
66 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
67 smpi_mpi_wait(&request, &status);
68 smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
69 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
71 // case last non-ROOT of each SMP
72 else if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
73 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
74 smpi_mpi_wait(&request, &status);
76 // case intermediate non-ROOT of each SMP
78 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
79 smpi_mpi_wait(&request, &status);
80 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
87 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
89 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
91 // case ROOT of each SMP
92 if (rank % NUM_CORE == 0) {
95 for (i = 0; i < pipe_length; i++) {
96 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
98 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
102 // case last ROOT of each SMP
103 else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
104 for (i = 0; i < pipe_length; i++) {
105 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
106 from_inter, (tag + i), comm);
108 for (i = 0; i < pipe_length; i++) {
109 smpi_mpi_wait(&request_array[i], &status);
110 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
114 // case intermediate ROOT of each SMP
116 for (i = 0; i < pipe_length; i++) {
117 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
118 from_inter, (tag + i), comm);
120 for (i = 0; i < pipe_length; i++) {
121 smpi_mpi_wait(&request_array[i], &status);
122 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
124 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
128 } else { // case last non-ROOT of each SMP
129 if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
130 for (i = 0; i < pipe_length; i++) {
131 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
132 from_intra, (tag + i), comm);
134 for (i = 0; i < pipe_length; i++) {
135 smpi_mpi_wait(&request_array[i], &status);
138 // case intermediate non-ROOT of each SMP
140 for (i = 0; i < pipe_length; i++) {
141 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
142 from_intra, (tag + i), comm);
144 for (i = 0; i < pipe_length; i++) {
145 smpi_mpi_wait(&request_array[i], &status);
146 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
155 // when count is not divisible by block size, use default BCAST for the remainder
156 if ((remainder != 0) && (count > segment)) {
157 XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");
158 smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,