Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
support MPI_Op_commutative call, as it was already implemented internally
[simgrid.git] / teshsuite / smpi / mpich3-test / coll / redscat2.c
1 /* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil ; -*- */
2 /*
3  *  (C) 2001 by Argonne National Laboratory.
4  *      See COPYRIGHT in top-level directory.
5  */
6 /*
7  * Test of reduce scatter.
8  *
9  * Checks that non-commutative operations are not commuted and that
10  * all of the operations are performed.
11  *
12  * Can be called with any number of processors.
13  */
14
15 #include "mpi.h"
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include "mpitest.h"
19
20 int err = 0;
21
22 /* left(x,y) ==> x */
23 void left(void *a, void *b, int *count, MPI_Datatype * type);
24 void left(void *a, void *b, int *count, MPI_Datatype * type)
25 {
26     int *in = a;
27     int *inout = b;
28     int i;
29
30     for (i = 0; i < *count; ++i) {
31         if (in[i] > inout[i])
32             ++err;
33         inout[i] = in[i];
34     }
35 }
36
37 /* right(x,y) ==> y */
38 void right(void *a, void *b, int *count, MPI_Datatype * type);
39 void right(void *a, void *b, int *count, MPI_Datatype * type)
40 {
41     int *in = a;
42     int *inout = b;
43     int i;
44
45     for (i = 0; i < *count; ++i) {
46         if (in[i] > inout[i])
47             ++err;
48         inout[i] = inout[i];
49     }
50 }
51
52 /* Just performs a simple sum but can be marked as non-commutative to
53    potentially tigger different logic in the implementation. */
54 void nc_sum(void *a, void *b, int *count, MPI_Datatype * type);
55 void nc_sum(void *a, void *b, int *count, MPI_Datatype * type)
56 {
57     int *in = a;
58     int *inout = b;
59     int i;
60
61     for (i = 0; i < *count; ++i) {
62         inout[i] = in[i] + inout[i];
63     }
64 }
65
66 #define MAX_BLOCK_SIZE 256
67
68 int main(int argc, char **argv)
69 {
70     int *sendbuf, *recvcounts;
71     int block_size;
72     int *recvbuf;
73     int size, rank, i;
74     MPI_Comm comm;
75     MPI_Op left_op, right_op, nc_sum_op;
76
77     MTest_Init(&argc, &argv);
78     comm = MPI_COMM_WORLD;
79
80     MPI_Comm_size(comm, &size);
81     MPI_Comm_rank(comm, &rank);
82
83     MPI_Op_create(&left, 0 /*non-commutative */ , &left_op);
84     MPI_Op_create(&right, 0 /*non-commutative */ , &right_op);
85     MPI_Op_create(&nc_sum, 0 /*non-commutative */ , &nc_sum_op);
86
87     for (block_size = 1; block_size < MAX_BLOCK_SIZE; block_size *= 2) {
88         sendbuf = (int *) malloc(block_size * size * sizeof(int));
89         recvbuf = malloc(block_size * sizeof(int));
90
91         for (i = 0; i < (size * block_size); i++)
92             sendbuf[i] = rank + i;
93         for (i = 0; i < block_size; i++)
94             recvbuf[i] = 0xdeadbeef;
95         recvcounts = (int *) malloc(size * sizeof(int));
96         for (i = 0; i < size; i++)
97             recvcounts[i] = block_size;
98
99         MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, left_op, comm);
100         for (i = 0; i < block_size; ++i)
101             if (recvbuf[i] != (rank * block_size + i))
102                 ++err;
103
104         MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, right_op, comm);
105         for (i = 0; i < block_size; ++i)
106             if (recvbuf[i] != ((size - 1) + (rank * block_size) + i))
107                 ++err;
108
109         MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, nc_sum_op, comm);
110         for (i = 0; i < block_size; ++i) {
111             int x = rank * block_size + i;
112             if (recvbuf[i] != (size * x + (size - 1) * size / 2))
113                 ++err;
114         }
115
116         free(recvbuf);
117         free(sendbuf);
118         free(recvcounts);
119     }
120
121     MPI_Op_free(&left_op);
122     MPI_Op_free(&right_op);
123     MPI_Op_free(&nc_sum_op);
124
125     MTest_Finalize(err);
126     MPI_Finalize();
127
128     return err;
129 }