Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
some preliminary additions to implement more collectives
[simgrid.git] / src / smpi / smpi_base.c
1 #include "private.h"
2 #include "xbt/time.h"
3
4 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_base, smpi,
5                                 "Logging specific to SMPI (base)");
6 XBT_LOG_EXTERNAL_CATEGORY(smpi_base);
7 XBT_LOG_EXTERNAL_CATEGORY(smpi_bench);
8 XBT_LOG_EXTERNAL_CATEGORY(smpi_kernel);
9 XBT_LOG_EXTERNAL_CATEGORY(smpi_mpi);
10 XBT_LOG_EXTERNAL_CATEGORY(smpi_receiver);
11 XBT_LOG_EXTERNAL_CATEGORY(smpi_sender);
12 XBT_LOG_EXTERNAL_CATEGORY(smpi_util);
13
14 smpi_mpi_global_t smpi_mpi_global = NULL;
15
16 /**
17  * Operations of MPI_OP : implemented=land,sum,min,max
18  **/
19 void smpi_mpi_land_func(void *a, void *b, int *length,
20                         MPI_Datatype * datatype);
21
22 void smpi_mpi_land_func(void *a, void *b, int *length,
23                         MPI_Datatype * datatype)
24 {
25   int i;
26   if (*datatype == smpi_mpi_global->mpi_int) {
27     int *x = a, *y = b;
28     for (i = 0; i < *length; i++) {
29       y[i] = x[i] && y[i];
30     }
31   }
32 }
33
34 /**
35  * sum two vectors element-wise
36  *
37  * @param a the first vectors 
38  * @param b the second vectors
39  * @return the second vector is modified and contains the element-wise sums
40  **/
41 void smpi_mpi_sum_func(void *a, void *b, int *length,
42                        MPI_Datatype * datatype);
43
44 void smpi_mpi_sum_func(void *a, void *b, int *length, MPI_Datatype * datatype)
45 {
46           int i;
47           if (*datatype == smpi_mpi_global->mpi_byte) {
48                                 char *x = a, *y = b;
49                                 for (i = 0; i < *length; i++) {
50                                           y[i] = x[i] + y[i];
51                                 }
52           } else {
53           if (*datatype == smpi_mpi_global->mpi_int) {
54                                 int *x = a, *y = b;
55                                 for (i = 0; i < *length; i++) {
56                                           y[i] = x[i] + y[i];
57                                 }
58           } else {
59           if (*datatype == smpi_mpi_global->mpi_float) {
60                                 float *x = a, *y = b;
61                                 for (i = 0; i < *length; i++) {
62                                           y[i] = x[i] + y[i];
63                                 }
64           } else {
65           if (*datatype == smpi_mpi_global->mpi_double) {
66                                 double *x = a, *y = b;
67                                 for (i = 0; i < *length; i++) {
68                                           y[i] = x[i] + y[i];
69                                 }
70           }}}}
71 }
72 /**
73  * compute the min of two vectors element-wise
74  **/
75 void smpi_mpi_min_func(void *a, void *b, int *length, MPI_Datatype * datatype);
76
77 void smpi_mpi_min_func(void *a, void *b, int *length, MPI_Datatype * datatype)
78 {
79           int i;
80           if (*datatype == smpi_mpi_global->mpi_byte) {
81                                 char *x = a, *y = b;
82                                 for (i = 0; i < *length; i++) {
83                                           y[i] = x[i] < y[i] ? x[i] : y[i];
84                                 }
85           } else {
86           if (*datatype == smpi_mpi_global->mpi_int) {
87                                 int *x = a, *y = b;
88                                 for (i = 0; i < *length; i++) {
89                                           y[i] = x[i] < y[i] ? x[i] : y[i];
90                                 }
91           } else {
92           if (*datatype == smpi_mpi_global->mpi_float) {
93                                 float *x = a, *y = b;
94                                 for (i = 0; i < *length; i++) {
95                                           y[i] = x[i] < y[i] ? x[i] : y[i];
96                                 }
97           } else {
98           if (*datatype == smpi_mpi_global->mpi_double) {
99                                 double *x = a, *y = b;
100                                 for (i = 0; i < *length; i++) {
101                                           y[i] = x[i] < y[i] ? x[i] : y[i];
102                                 }
103
104           }}}}
105 }
106 /**
107  * compute the max of two vectors element-wise
108  **/
109 void smpi_mpi_max_func(void *a, void *b, int *length, MPI_Datatype * datatype);
110
111 void smpi_mpi_max_func(void *a, void *b, int *length, MPI_Datatype * datatype)
112 {
113           int i;
114           if (*datatype == smpi_mpi_global->mpi_byte) {
115                                 char *x = a, *y = b;
116                                 for (i = 0; i < *length; i++) {
117                                           y[i] = x[i] > y[i] ? x[i] : y[i];
118                                 }
119           } else {
120           if (*datatype == smpi_mpi_global->mpi_int) {
121                                 int *x = a, *y = b;
122                                 for (i = 0; i > *length; i++) {
123                                           y[i] = x[i] < y[i] ? x[i] : y[i];
124                                 }
125           } else {
126           if (*datatype == smpi_mpi_global->mpi_float) {
127                                 float *x = a, *y = b;
128                                 for (i = 0; i > *length; i++) {
129                                           y[i] = x[i] < y[i] ? x[i] : y[i];
130                                 }
131           } else {
132           if (*datatype == smpi_mpi_global->mpi_double) {
133                                 double *x = a, *y = b;
134                                 for (i = 0; i > *length; i++) {
135                                           y[i] = x[i] < y[i] ? x[i] : y[i];
136                                 }
137
138           }}}}
139 }
140
141
142
143
144 /**
145  * tell the MPI rank of the calling process (from its SIMIX process id)
146  **/
147 int smpi_mpi_comm_rank(smpi_mpi_communicator_t comm)
148 {
149   return comm->index_to_rank_map[smpi_process_index()];
150 }
151
152 void smpi_process_init(int *argc, char***argv)
153 {
154   smpi_process_data_t pdata;
155
156   // initialize some local variables
157
158   pdata = xbt_new(s_smpi_process_data_t, 1);
159   SIMIX_process_set_data(SIMIX_process_self(),pdata);
160
161   /* get rank from command line, and remove it from argv */
162   pdata->index = atoi( (*argv)[1] );
163   DEBUG1("I'm rank %d",pdata->index);
164   if (*argc>2) {
165           memmove((*argv)[1],(*argv)[2], sizeof(char*)* (*argc-2));
166           (*argv)[ (*argc)-1] = NULL;
167   }
168   (*argc)--;
169
170   pdata->mutex = SIMIX_mutex_init();
171   pdata->cond = SIMIX_cond_init();
172   pdata->finalize = 0;
173
174   pdata->pending_recv_request_queue = xbt_fifo_new();
175   pdata->pending_send_request_queue = xbt_fifo_new();
176   pdata->received_message_queue = xbt_fifo_new();
177
178   pdata->main = SIMIX_process_self();
179   pdata->sender = SIMIX_process_create("smpi_sender",
180           smpi_sender, pdata,
181           SIMIX_host_get_name(SIMIX_host_self()), 0, NULL,
182           /*props */ NULL);
183   pdata->receiver = SIMIX_process_create("smpi_receiver",
184           smpi_receiver, pdata,
185           SIMIX_host_get_name(SIMIX_host_self()), 0, NULL,
186           /*props */ NULL);
187
188   smpi_global->main_processes[pdata->index] = SIMIX_process_self();
189   return;
190 }
191
192 void smpi_process_finalize()
193 {
194   smpi_process_data_t pdata =  SIMIX_process_get_data(SIMIX_process_self());
195
196   pdata->finalize = 2; /* Tell sender and receiver to quit */
197   SIMIX_process_resume(pdata->sender);
198   SIMIX_process_resume(pdata->receiver);
199   while (pdata->finalize>0) { /* wait until it's done */
200           SIMIX_cond_wait(pdata->cond,pdata->mutex);
201   }
202
203   SIMIX_mutex_destroy(pdata->mutex);
204   SIMIX_cond_destroy(pdata->cond);
205   xbt_fifo_free(pdata->pending_recv_request_queue);
206   xbt_fifo_free(pdata->pending_send_request_queue);
207   xbt_fifo_free(pdata->received_message_queue);
208 }
209
210 int smpi_mpi_barrier(smpi_mpi_communicator_t comm)
211 {
212
213   SIMIX_mutex_lock(comm->barrier_mutex);
214   ++comm->barrier_count;
215   if (comm->barrier_count > comm->size) {       // only happens on second barrier...
216     comm->barrier_count = 0;
217   } else if (comm->barrier_count == comm->size) {
218     SIMIX_cond_broadcast(comm->barrier_cond);
219   }
220   while (comm->barrier_count < comm->size) {
221     SIMIX_cond_wait(comm->barrier_cond, comm->barrier_mutex);
222   }
223   SIMIX_mutex_unlock(comm->barrier_mutex);
224
225   return MPI_SUCCESS;
226 }
227
228 int smpi_mpi_isend(smpi_mpi_request_t request)
229 {
230         smpi_process_data_t pdata =  SIMIX_process_get_data(SIMIX_process_self());
231   int retval = MPI_SUCCESS;
232
233   if (NULL == request) {
234     retval = MPI_ERR_INTERN;
235   } else {
236     xbt_fifo_push(pdata->pending_send_request_queue, request);
237     SIMIX_process_resume(pdata->sender);
238   }
239
240   return retval;
241 }
242
243 int smpi_mpi_irecv(smpi_mpi_request_t request)
244 {
245   int retval = MPI_SUCCESS;
246   smpi_process_data_t pdata =  SIMIX_process_get_data(SIMIX_process_self());
247
248   if (NULL == request) {
249     retval = MPI_ERR_INTERN;
250   } else {
251     xbt_fifo_push(pdata->pending_recv_request_queue, request);
252
253     if (SIMIX_process_is_suspended(pdata->receiver)) {
254       SIMIX_process_resume(pdata->receiver);
255     }
256   }
257
258   return retval;
259 }
260
261 int smpi_mpi_wait(smpi_mpi_request_t request, smpi_mpi_status_t * status)
262 {
263   int retval = MPI_SUCCESS;
264
265   if (NULL == request) {
266     retval = MPI_ERR_INTERN;
267   } else {
268     SIMIX_mutex_lock(request->mutex);
269     while (!request->completed) {
270       SIMIX_cond_wait(request->cond, request->mutex);
271     }
272     if (NULL != status) {
273       status->MPI_SOURCE = request->src;
274       status->MPI_TAG = request->tag;
275       status->MPI_ERROR = MPI_SUCCESS;
276     }
277     SIMIX_mutex_unlock(request->mutex);
278   }
279
280   return retval;
281 }