Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'mc' into mc++
[simgrid.git] / teshsuite / smpi / mpich3-test / coll / allgatherv4.c
1 /* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
2 /*
3  *
4  *  (C) 2003 by Argonne National Laboratory.
5  *      See COPYRIGHT in top-level directory.
6  */
7
8 #include "mpi.h"
9 #include "mpitest.h"
10 #include "smpi_cocci.h"
11 #include <stdio.h>
12 #include <stdlib.h>
13 #ifdef HAVE_SYS_TIME_H
14 #include <sys/time.h>
15 #endif
16 #include <time.h>
17 #include <math.h>
18 #include <assert.h>
19
20 /* FIXME: What is this test supposed to accomplish? */
21
22 #define START_BUF (1)
23 #define LARGE_BUF (256 * 1024)
24
25 /* FIXME: MAX_BUF is too large */
26 #define MAX_BUF   (32 * 1024 * 1024)
27 #define LOOPS 10
28
29 SMPI_VARINIT_GLOBAL(sbuf, char*);
30 SMPI_VARINIT_GLOBAL(rbuf, char*);
31 SMPI_VARINIT_GLOBAL(recvcounts, int*);
32 SMPI_VARINIT_GLOBAL(displs, int*);
33 SMPI_VARINIT_GLOBAL_AND_SET(errs, int, 0);
34
35 /* #define dprintf printf */
36 #define dprintf(...)
37
38 typedef enum {
39     REGULAR,
40     BCAST,
41     SPIKE,
42     HALF_FULL,
43     LINEAR_DECREASE,
44     BELL_CURVE
45 } test_t;
46
47 void comm_tests(MPI_Comm comm);
48 double run_test(long long msg_size, MPI_Comm comm, test_t test_type, double * max_time);
49
50 int main(int argc, char ** argv)
51 {
52     int comm_size, comm_rank;
53     MPI_Comm comm;
54
55     MTest_Init(&argc, &argv);
56     MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
57     MPI_Comm_rank(MPI_COMM_WORLD, &comm_rank);
58
59     if (LARGE_BUF * comm_size > MAX_BUF)
60         goto fn_exit;
61
62     SMPI_VARGET_GLOBAL(sbuf) = (void *) calloc(MAX_BUF, 1);
63     SMPI_VARGET_GLOBAL(rbuf) = (void *) calloc(MAX_BUF, 1);
64
65     srand(time(NULL));
66
67     SMPI_VARGET_GLOBAL(recvcounts) = (void *) malloc(comm_size * sizeof(int));
68     SMPI_VARGET_GLOBAL(displs) = (void *) malloc(comm_size * sizeof(int));
69     if (!SMPI_VARGET_GLOBAL(recvcounts) || !SMPI_VARGET_GLOBAL(displs) || !SMPI_VARGET_GLOBAL(sbuf) || !SMPI_VARGET_GLOBAL(rbuf)) {
70         fprintf(stderr, "Unable to allocate memory:\n");
71         if (!SMPI_VARGET_GLOBAL(sbuf)) fprintf(stderr,"\tsbuf of %d bytes\n", MAX_BUF );
72         if (!SMPI_VARGET_GLOBAL(rbuf)) fprintf(stderr,"\trbuf of %d bytes\n", MAX_BUF );
73         if (!SMPI_VARGET_GLOBAL(recvcounts)) fprintf(stderr,"\trecvcounts of %zd bytes\n", comm_size * sizeof(int) );
74         if (!SMPI_VARGET_GLOBAL(displs)) fprintf(stderr,"\tdispls of %zd bytes\n", comm_size * sizeof(int) );
75         fflush(stderr);
76         MPI_Abort(MPI_COMM_WORLD, -1);
77         exit(-1);
78     }
79
80     if (!comm_rank) {
81         dprintf("Message Range: (%d, %d); System size: %d\n", START_BUF, LARGE_BUF, comm_size);
82         fflush(stdout);
83     }
84
85
86     /* COMM_WORLD tests */
87     if (!comm_rank) {
88         dprintf("\n\n==========================================================\n");
89         dprintf("                         MPI_COMM_WORLD\n");
90         dprintf("==========================================================\n");
91     }
92     comm_tests(MPI_COMM_WORLD);
93
94     /* non-COMM_WORLD tests */
95     if (!comm_rank) {
96         dprintf("\n\n==========================================================\n");
97         dprintf("                         non-COMM_WORLD\n");
98         dprintf("==========================================================\n");
99     }
100     MPI_Comm_split(MPI_COMM_WORLD, (comm_rank == comm_size - 1) ? 0 : 1, 0, &comm);
101     if (comm_rank < comm_size - 1)
102         comm_tests(comm);
103     MPI_Comm_free(&comm);
104
105     /* Randomized communicator tests */
106     if (!comm_rank) {
107         dprintf("\n\n==========================================================\n");
108         dprintf("                         Randomized Communicator\n");
109         dprintf("==========================================================\n");
110     }
111     MPI_Comm_split(MPI_COMM_WORLD, 0, rand(), &comm);
112     comm_tests(comm);
113     MPI_Comm_free(&comm);
114
115     //free(SMPI_VARGET_GLOBAL(sbuf));
116     //free(SMPI_VARGET_GLOBAL(rbuf));
117     free(SMPI_VARGET_GLOBAL(recvcounts));
118     free(SMPI_VARGET_GLOBAL(displs));
119
120 fn_exit:
121     MTest_Finalize(SMPI_VARGET_GLOBAL(errs));
122     MPI_Finalize();
123
124     return 0;
125 }
126
127 void comm_tests(MPI_Comm comm)
128 {
129     int comm_size, comm_rank;
130     double rtime = rtime;       /* stop warning about unused variable */
131     double max_time;
132     long long msg_size;
133
134     MPI_Comm_size(comm, &comm_size);
135     MPI_Comm_rank(comm, &comm_rank);
136
137     for (msg_size = START_BUF; msg_size <= LARGE_BUF; msg_size *= 2) {
138         if (!comm_rank) {
139             dprintf("\n====> MSG_SIZE: %d\n", (int) msg_size);
140             fflush(stdout);
141         }
142
143         rtime = run_test(msg_size, comm, REGULAR, &max_time);
144         if (!comm_rank) {
145             dprintf("REGULAR:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
146             fflush(stdout);
147         }
148
149         rtime = run_test(msg_size, comm, BCAST, &max_time);
150         if (!comm_rank) {
151             dprintf("BCAST:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
152             fflush(stdout);
153         }
154
155         rtime = run_test(msg_size, comm, SPIKE, &max_time);
156         if (!comm_rank) {
157             dprintf("SPIKE:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
158             fflush(stdout);
159         }
160
161         rtime = run_test(msg_size, comm, HALF_FULL, &max_time);
162         if (!comm_rank) {
163             dprintf("HALF_FULL:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
164             fflush(stdout);
165         }
166
167         rtime = run_test(msg_size, comm, LINEAR_DECREASE, &max_time);
168         if (!comm_rank) {
169             dprintf("LINEAR_DECREASE:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
170             fflush(stdout);
171         }
172
173         rtime = run_test(msg_size, comm, BELL_CURVE, &max_time);
174         if (!comm_rank) {
175             dprintf("BELL_CURVE:\tAVG: %.3f\tMAX: %.3f\n", rtime, max_time);
176             fflush(stdout);
177         }
178     }
179 }
180
181 double run_test(long long msg_size, MPI_Comm comm, test_t test_type, 
182                 double * max_time)
183 {
184     int i, j;
185     int comm_size, comm_rank;
186     double start, end;
187     double total_time, avg_time;
188     MPI_Aint tmp;
189
190     MPI_Comm_size(comm, &comm_size);
191     MPI_Comm_rank(comm, &comm_rank);
192
193     SMPI_VARGET_GLOBAL(displs)[0] = 0;
194     for (i = 0; i < comm_size; i++) {
195         if (test_type == REGULAR)
196             SMPI_VARGET_GLOBAL(recvcounts)[i] = msg_size;
197         else if (test_type == BCAST)
198             SMPI_VARGET_GLOBAL(recvcounts)[i] = (!i) ? msg_size : 0;
199         else if (test_type == SPIKE)
200             SMPI_VARGET_GLOBAL(recvcounts)[i] = (!i) ? (msg_size / 2) : (msg_size / (2 * (comm_size - 1)));
201         else if (test_type == HALF_FULL)
202             SMPI_VARGET_GLOBAL(recvcounts)[i] = (i < (comm_size / 2)) ? (2 * msg_size) : 0;
203         else if (test_type == LINEAR_DECREASE) {
204             tmp = 2 * msg_size * (comm_size - 1 - i) / (comm_size - 1);
205             if (tmp != (int)tmp) {
206                 fprintf( stderr, "Integer overflow in variable tmp\n" );
207                 MPI_Abort( MPI_COMM_WORLD, 1 );
208                 exit(1);
209             }
210             SMPI_VARGET_GLOBAL(recvcounts)[i] = (int) tmp;
211
212             /* If the maximum message size is too large, don't run */
213             if (tmp > MAX_BUF) return 0;
214         }
215         else if (test_type == BELL_CURVE) {
216             for (j = 0; j < i; j++) {
217                 if (i - 1 + j >= comm_size) continue;
218                 tmp = msg_size * comm_size / (log(comm_size) * i);
219                 SMPI_VARGET_GLOBAL(recvcounts)[i - 1 + j] = (int) tmp;
220                 SMPI_VARGET_GLOBAL(displs)[i - 1 + j] = 0;
221
222                 /* If the maximum message size is too large, don't run */
223                 if (tmp > MAX_BUF) return 0;
224             }
225         }
226
227         if (i < comm_size - 1)
228             SMPI_VARGET_GLOBAL(displs)[i+1] = SMPI_VARGET_GLOBAL(displs)[i] + SMPI_VARGET_GLOBAL(recvcounts)[i];
229     }
230
231     /* Test that:
232        1: sbuf is large enough
233        2: rbuf is large enough
234        3: There were no failures (e.g., tmp nowhere > rbuf size 
235     */
236     MPI_Barrier(comm);
237     start = MPI_Wtime();
238     for (i = 0; i < LOOPS; i++) {
239         MPI_Allgatherv(SMPI_VARGET_GLOBAL(sbuf), SMPI_VARGET_GLOBAL(recvcounts)[comm_rank], MPI_CHAR,
240                        SMPI_VARGET_GLOBAL(rbuf), SMPI_VARGET_GLOBAL(recvcounts), SMPI_VARGET_GLOBAL(displs), MPI_CHAR, comm);
241     }
242     end = MPI_Wtime();
243     MPI_Barrier(comm);
244
245     /* Convert to microseconds (why?) */
246     total_time = 1.0e6 * (end - start);
247     MPI_Reduce(&total_time, &avg_time, 1, MPI_DOUBLE, MPI_SUM, 0, comm);
248     MPI_Reduce(&total_time, max_time, 1, MPI_DOUBLE, MPI_MAX, 0, comm);
249
250     return (avg_time / (LOOPS * comm_size));
251 }