1 #include "colls_private.h"
/* Pipeline segment size, in bytes, for the SMP linear broadcast below.
 * The broadcast divides this value by the datatype extent to obtain the
 * number of elements carried by each pipelined segment. */
6 int bcast_SMP_linear_segment_byte = 8192;
8 int smpi_coll_tuned_bcast_SMP_linear(void *buf, int count,
9 MPI_Datatype datatype, int root,
12 int tag = COLL_TAG_BCAST;
15 MPI_Request *request_array;
16 MPI_Status *status_array;
20 extent = smpi_datatype_get_extent(datatype);
22 rank = smpi_comm_rank(comm);
23 size = smpi_comm_size(comm);
26 THROWF(arg_error,0, "bcast SMP linear can't be used with non multiple of NUM_CORE=%d number of processes ! ",NUM_CORE);
28 int segment = bcast_SMP_linear_segment_byte / extent;
29 int pipe_length = count / segment;
30 int remainder = count % segment;
31 int increment = segment * extent;
34 /* leader of each SMP do inter-communication
35 and act as a root for intra-communication */
36 int to_inter = (rank + NUM_CORE) % size;
37 int to_intra = (rank + 1) % size;
38 int from_inter = (rank - NUM_CORE + size) % size;
39 int from_intra = (rank + size - 1) % size;
41 // call native when MPI communication size is too small
42 if (size <= NUM_CORE) {
43 XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");
44 smpi_mpi_bcast(buf, count, datatype, root, comm);
47 // if root is not zero send to rank zero first
50 smpi_mpi_send(buf, count, datatype, 0, tag, comm);
52 smpi_mpi_recv(buf, count, datatype, root, tag, comm, &status);
54 // when a message is smaller than a block size => no pipeline
55 if (count <= segment) {
58 smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
59 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
61 // case last ROOT of each SMP
62 else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
63 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
64 smpi_mpi_wait(&request, &status);
65 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
67 // case intermediate ROOT of each SMP
68 else if (rank % NUM_CORE == 0) {
69 request = smpi_mpi_irecv(buf, count, datatype, from_inter, tag, comm);
70 smpi_mpi_wait(&request, &status);
71 smpi_mpi_send(buf, count, datatype, to_inter, tag, comm);
72 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
74 // case last non-ROOT of each SMP
75 else if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
76 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
77 smpi_mpi_wait(&request, &status);
79 // case intermediate non-ROOT of each SMP
81 request = smpi_mpi_irecv(buf, count, datatype, from_intra, tag, comm);
82 smpi_mpi_wait(&request, &status);
83 smpi_mpi_send(buf, count, datatype, to_intra, tag, comm);
90 (MPI_Request *) xbt_malloc((size + pipe_length) * sizeof(MPI_Request));
92 (MPI_Status *) xbt_malloc((size + pipe_length) * sizeof(MPI_Status));
94 // case ROOT of each SMP
95 if (rank % NUM_CORE == 0) {
98 for (i = 0; i < pipe_length; i++) {
99 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
101 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
105 // case last ROOT of each SMP
106 else if (rank == (((size - 1) / NUM_CORE) * NUM_CORE)) {
107 for (i = 0; i < pipe_length; i++) {
108 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
109 from_inter, (tag + i), comm);
111 for (i = 0; i < pipe_length; i++) {
112 smpi_mpi_wait(&request_array[i], &status);
113 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
117 // case intermediate ROOT of each SMP
119 for (i = 0; i < pipe_length; i++) {
120 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
121 from_inter, (tag + i), comm);
123 for (i = 0; i < pipe_length; i++) {
124 smpi_mpi_wait(&request_array[i], &status);
125 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_inter,
127 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
131 } else { // case last non-ROOT of each SMP
132 if (((rank + 1) % NUM_CORE == 0) || (rank == (size - 1))) {
133 for (i = 0; i < pipe_length; i++) {
134 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
135 from_intra, (tag + i), comm);
137 for (i = 0; i < pipe_length; i++) {
138 smpi_mpi_wait(&request_array[i], &status);
141 // case intermediate non-ROOT of each SMP
143 for (i = 0; i < pipe_length; i++) {
144 request_array[i] = smpi_mpi_irecv((char *) buf + (i * increment), segment, datatype,
145 from_intra, (tag + i), comm);
147 for (i = 0; i < pipe_length; i++) {
148 smpi_mpi_wait(&request_array[i], &status);
149 smpi_mpi_send((char *) buf + (i * increment), segment, datatype, to_intra,
158 // when count is not divisible by block size, use default BCAST for the remainder
159 if ((remainder != 0) && (count > segment)) {
160 XBT_WARN("MPI_bcast_SMP_linear use default MPI_bcast.");
161 smpi_mpi_bcast((char *) buf + (pipe_length * increment), remainder, datatype,