Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
+ MPI_Sendrecv()
[simgrid.git] / src / smpi / smpi_mpi.c
1
2
3 #include "private.h"
4 #include "smpi_coll_private.h"
5
6 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_mpi, smpi,
7                                 "Logging specific to SMPI (mpi)");
8
9 int SMPI_MPI_Init(int *argc, char ***argv)
10 {
11   smpi_process_init(argc, argv);
12   smpi_bench_begin();
13   return MPI_SUCCESS;
14 }
15
16 int SMPI_MPI_Finalize()
17 {
18   smpi_bench_end();
19   smpi_process_finalize();
20   return MPI_SUCCESS;
21 }
22
23 // right now this just exits the current node, should send abort signal to all
24 // hosts in the communicator (TODO)
25 int SMPI_MPI_Abort(MPI_Comm comm, int errorcode)
26 {
27   smpi_exit(errorcode);
28   return 0;
29 }
30
31 int SMPI_MPI_Comm_size(MPI_Comm comm, int *size)
32 {
33   int retval = MPI_SUCCESS;
34
35   smpi_bench_end();
36
37   if (NULL == comm) {
38     retval = MPI_ERR_COMM;
39   } else if (NULL == size) {
40     retval = MPI_ERR_ARG;
41   } else {
42     *size = comm->size;
43   }
44
45   smpi_bench_begin();
46
47   return retval;
48 }
49
50 int SMPI_MPI_Comm_rank(MPI_Comm comm, int *rank)
51 {
52   int retval = MPI_SUCCESS;
53
54   smpi_bench_end();
55
56   if (NULL == comm) {
57     retval = MPI_ERR_COMM;
58   } else if (NULL == rank) {
59     retval = MPI_ERR_ARG;
60   } else {
61     *rank = smpi_mpi_comm_rank(comm);
62   }
63
64   smpi_bench_begin();
65
66   return retval;
67 }
68
69 int SMPI_MPI_Type_size(MPI_Datatype datatype, size_t * size)
70 {
71   int retval = MPI_SUCCESS;
72
73   smpi_bench_end();
74
75   if (NULL == datatype) {
76     retval = MPI_ERR_TYPE;
77   } else if (NULL == size) {
78     retval = MPI_ERR_ARG;
79   } else {
80     *size = datatype->size;
81   }
82
83   smpi_bench_begin();
84
85   return retval;
86 }
87
88 int SMPI_MPI_Barrier(MPI_Comm comm)
89 {
90   int retval = MPI_SUCCESS;
91   int arity=4;
92
93   smpi_bench_end();
94
95   if (NULL == comm) {
96     retval = MPI_ERR_COMM;
97   } else {
98
99     /*
100      * original implemantation:
101      * retval = smpi_mpi_barrier(comm);
102      * this one is unrealistic: it just cond_waits, means no time.
103      */
104      retval = nary_tree_barrier( comm, arity );
105   }
106
107   smpi_bench_begin();
108
109   return retval;
110 }
111
112 int SMPI_MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src,
113                    int tag, MPI_Comm comm, MPI_Request * request)
114 {
115   int retval = MPI_SUCCESS;
116
117   smpi_bench_end();
118
119   retval = smpi_create_request(buf, count, datatype, src, 0, tag, comm,
120                                request);
121   if (NULL != *request && MPI_SUCCESS == retval) {
122     retval = smpi_mpi_irecv(*request);
123   }
124
125   smpi_bench_begin();
126
127   return retval;
128 }
129
130 int SMPI_MPI_Recv(void *buf, int count, MPI_Datatype datatype, int src,
131                   int tag, MPI_Comm comm, MPI_Status * status)
132 {
133   int retval = MPI_SUCCESS;
134   smpi_mpi_request_t request;
135
136   smpi_bench_end();
137
138   retval = smpi_create_request(buf, count, datatype, src, 0, tag, comm,
139                                &request);
140   if (NULL != request && MPI_SUCCESS == retval) {
141     retval = smpi_mpi_irecv(request);
142     if (MPI_SUCCESS == retval) {
143       retval = smpi_mpi_wait(request, status);
144     }
145     xbt_mallocator_release(smpi_global->request_mallocator, request);
146   }
147
148   smpi_bench_begin();
149
150   return retval;
151 }
152
153 int SMPI_MPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst,
154                    int tag, MPI_Comm comm, MPI_Request * request)
155 {
156   int retval = MPI_SUCCESS;
157
158   smpi_bench_end();
159
160   retval = smpi_create_request(buf, count, datatype, 0, dst, tag, comm,
161                                request);
162   if (NULL != *request && MPI_SUCCESS == retval) {
163     retval = smpi_mpi_isend(*request);
164   }
165
166   smpi_bench_begin();
167
168   return retval;
169 }
170
171 int SMPI_MPI_Send(void *buf, int count, MPI_Datatype datatype, int dst,
172                   int tag, MPI_Comm comm)
173 {
174   int retval = MPI_SUCCESS;
175   smpi_mpi_request_t request;
176
177   smpi_bench_end();
178
179   retval = smpi_create_request(buf, count, datatype, 0, dst, tag, comm,
180                                &request);
181   if (NULL != request && MPI_SUCCESS == retval) {
182     retval = smpi_mpi_isend(request);
183     if (MPI_SUCCESS == retval) {
184       smpi_mpi_wait(request, MPI_STATUS_IGNORE);
185     }
186     xbt_mallocator_release(smpi_global->request_mallocator, request);
187   }
188
189   smpi_bench_begin();
190
191   return retval;
192 }
193
194 /**
195  * MPI_Sendrecv
196  **/
197 int SMPI_MPI_Sendrecv(void *sendbuf, int sendcount, MPI_Datatype sendtype, int dest, int sendtag, 
198                     void *recvbuf, int recvcount, MPI_Datatype recvtype, int source, int recvtag,
199                     MPI_Comm comm, MPI_Status *status)
200 {
201 int rank;
202 int retval = MPI_SUCCESS;
203 smpi_mpi_request_t srequest;
204 smpi_mpi_request_t rrequest;
205
206           rank = smpi_mpi_comm_rank(comm);
207
208           /* send */
209           retval = smpi_create_request(sendbuf, sendcount, sendtype, 
210                                 rank,dest,sendtag, 
211                                 comm, &srequest);
212           smpi_mpi_isend(srequest);
213
214
215           /* recv */
216           retval = smpi_create_request(recvbuf, recvcount, recvtype, 
217                                 source,rank,recvtag, 
218                                 comm, &rrequest);
219           smpi_mpi_irecv(rrequest);
220
221           smpi_mpi_wait(srequest, MPI_STATUS_IGNORE);
222           smpi_mpi_wait(rrequest, MPI_STATUS_IGNORE);
223
224           return(retval);
225 }
226
227
228 /**
229  * MPI_Wait and friends
230  **/
231 int SMPI_MPI_Wait(MPI_Request * request, MPI_Status * status)
232 {
233   return smpi_mpi_wait(*request, status);
234 }
235
236 int SMPI_MPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
237 {
238   return smpi_mpi_waitall(count, requests, status);
239 }
240
241 int SMPI_MPI_Waitany(int count, MPI_Request requests[], int *index,
242                      MPI_Status status[])
243 {
244   return smpi_mpi_waitany(count, requests, index, status);
245 }
246
247 /**
248  * MPI_Bcast
249  **/
250
251 /**
252  * flat bcast 
253  **/
254 int flat_tree_bcast(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
255 int flat_tree_bcast(void *buf, int count, MPI_Datatype datatype, int root,
256                 MPI_Comm comm)
257 {
258         int rank;
259         int retval = MPI_SUCCESS;
260         smpi_mpi_request_t request;
261
262         rank = smpi_mpi_comm_rank(comm);
263         if (rank == root) {
264                 retval = smpi_create_request(buf, count, datatype, root,
265                                 (root + 1) % comm->size, 0, comm, &request);
266                 request->forward = comm->size - 1;
267                 smpi_mpi_isend(request);
268         } else {
269                 retval = smpi_create_request(buf, count, datatype, MPI_ANY_SOURCE, rank,
270                                 0, comm, &request);
271                 smpi_mpi_irecv(request);
272         }
273
274         smpi_mpi_wait(request, MPI_STATUS_IGNORE);
275         xbt_mallocator_release(smpi_global->request_mallocator, request);
276
277         return(retval);
278
279 }
280
281 /**
282  * Bcast user entry point
283  **/
284 int SMPI_MPI_Bcast(void *buf, int count, MPI_Datatype datatype, int root,
285                    MPI_Comm comm)
286 {
287   int retval = MPI_SUCCESS;
288
289   smpi_bench_end();
290
291   //retval = flat_tree_bcast(buf, count, datatype, root, comm);
292   retval = nary_tree_bcast(buf, count, datatype, root, comm, 2 );
293
294   smpi_bench_begin();
295
296   return retval;
297 }
298
299
300
301 //#ifdef DEBUG_REDUCE
302 /**
303  * debugging helper function
304  **/
305 static void print_buffer_int(void *buf, int len, char *msg, int rank)
306 {
307   int tmp, *v;
308   printf("**[%d] %s: ", rank, msg);
309   for (tmp = 0; tmp < len; tmp++) {
310     v = buf;
311     printf("[%d]", v[tmp]);
312   }
313   printf("\n");
314   free(msg);
315 }
316 static void print_buffer_double(void *buf, int len, char *msg, int rank)
317 {
318   int tmp;
319   double *v;
320   printf("**[%d] %s: ", rank, msg);
321   for (tmp = 0; tmp < len; tmp++) {
322     v = buf;
323     printf("[%lf]", v[tmp]);
324   }
325   printf("\n");
326   free(msg);
327 }
328
329
330 //#endif
331 /**
332  * MPI_Reduce
333  **/
334 int SMPI_MPI_Reduce(void *sendbuf, void *recvbuf, int count,
335                 MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
336 {
337         int retval = MPI_SUCCESS;
338         int rank;
339         int size;
340         int i;
341         int tag = 0;
342         smpi_mpi_request_t *requests;
343         smpi_mpi_request_t request;
344
345         smpi_bench_end();
346
347         rank = smpi_mpi_comm_rank(comm);
348         size = comm->size;
349
350         if (rank != root) {           // if i am not ROOT, simply send my buffer to root
351
352 #ifdef DEBUG_REDUCE
353                 print_buffer_int(sendbuf, count, xbt_strdup("sndbuf"), rank);
354 #endif
355                 retval =
356                         smpi_create_request(sendbuf, count, datatype, rank, root, tag, comm,
357                                         &request);
358                 smpi_mpi_isend(request);
359                 smpi_mpi_wait(request, MPI_STATUS_IGNORE);
360                 xbt_mallocator_release(smpi_global->request_mallocator, request);
361
362         } else {
363                 // i am the ROOT: wait for all buffers by creating one request by sender
364                 int src;
365                 requests = xbt_malloc((size-1) * sizeof(smpi_mpi_request_t));
366
367                 void **tmpbufs = xbt_malloc((size-1) * sizeof(void *));
368                 for (i = 0; i < size-1; i++) {
369                         // we need 1 buffer per request to store intermediate receptions
370                         tmpbufs[i] = xbt_malloc(count * datatype->size);
371                 }  
372                 // root: initiliaze recv buf with my own snd buf
373                 memcpy(recvbuf, sendbuf, count * datatype->size * sizeof(char));  
374
375                 // i can not use: 'request->forward = size-1;' (which would progagate size-1 receive reqs)
376                 // since we should op values as soon as one receiving request matches.
377                 for (i = 0; i < size-1; i++) {
378                         // reminder: for smpi_create_request() the src is always the process sending.
379                         src = i < root ? i : i + 1;
380                         retval = smpi_create_request(tmpbufs[i], count, datatype,
381                                         src, root, tag, comm, &(requests[i]));
382                         if (NULL != requests[i] && MPI_SUCCESS == retval) {
383                                 if (MPI_SUCCESS == retval) {
384                                         smpi_mpi_irecv(requests[i]);
385                                 }
386                         }
387                 }
388                 // now, wait for completion of all irecv's.
389                 for (i = 0; i < size-1; i++) {
390                         int index = MPI_UNDEFINED;
391                         smpi_mpi_waitany( size-1, requests, &index, MPI_STATUS_IGNORE);
392 #ifdef DEBUG_REDUCE
393                         printf ("MPI_Waitany() unblocked: root received (completes req[index=%d])\n",index);
394                         print_buffer_int(tmpbufs[index], count, bprintf("tmpbufs[index=%d] (value received)", index),
395                                         rank);
396 #endif
397
398                         // arg 2 is modified
399                         op->func(tmpbufs[index], recvbuf, &count, &datatype);
400 #ifdef DEBUG_REDUCE
401                         print_buffer_int(recvbuf, count, xbt_strdup("rcvbuf"), rank);
402 #endif
403                         xbt_free(tmpbufs[index]);
404                         /* FIXME: with the following line, it  generates an
405                          * [xbt_ex/CRITICAL] Conditional list not empty 162518800.
406                          */
407                         // xbt_mallocator_release(smpi_global->request_mallocator, requests[index]);
408                 }
409                 xbt_free(requests);
410                 xbt_free(tmpbufs);
411         }
412         smpi_bench_begin();
413         return retval;
414 }
415
416 /**
417  * MPI_Allreduce
418  *
419  * Same as MPI_REDUCE except that the result appears in the receive buffer of all the group members.
420  **/
421 int SMPI_MPI_Allreduce( void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
422                          MPI_Op op, MPI_Comm comm );
423 int SMPI_MPI_Allreduce( void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
424                          MPI_Op op, MPI_Comm comm )
425 {
426   int retval = MPI_SUCCESS;
427   int root=1;  // arbitrary choice
428
429   smpi_bench_end();
430
431   retval = SMPI_MPI_Reduce( sendbuf, recvbuf, count, datatype, op, root, comm);
432   if (MPI_SUCCESS != retval)
433             return(retval);
434
435   retval = SMPI_MPI_Bcast( sendbuf, count, datatype, root, comm);
436   smpi_bench_begin();
437   return( retval );
438 }
439
440
441 /**
442  * MPI_Scatter user entry point
443  **/
444 //int SMPI_MPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype datatype, 
445 //                       void *recvbuf, int recvcount, MPI_Datatype recvtype,int root,
446 //                      MPI_Comm comm);
447 int SMPI_MPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype datatype, 
448                          void *recvbuf, int recvcount, MPI_Datatype recvtype,
449                            int root, MPI_Comm comm)
450 {
451   int retval = MPI_SUCCESS;
452   int i;
453   int cnt=0;  
454   int rank;
455   int tag=0;
456   char *cptr;  // to manipulate the void * buffers
457   smpi_mpi_request_t *requests;
458   smpi_mpi_request_t request;
459   smpi_mpi_status_t status;
460
461
462   smpi_bench_end();
463
464   rank = smpi_mpi_comm_rank(comm);
465
466   requests = xbt_malloc((comm->size-1) * sizeof(smpi_mpi_request_t));
467   if (rank == root) {
468           // i am the root: distribute my sendbuf
469           //print_buffer_int(sendbuf, comm->size, xbt_strdup("rcvbuf"), rank);
470           cptr = sendbuf;
471           for (i=0; i < comm->size; i++) {
472                   if ( i!=root ) { // send to processes ...
473
474                           retval = smpi_create_request((void *)cptr, sendcount, 
475                                           datatype, root, i, tag, comm, &(requests[cnt]));
476                           if (NULL != requests[cnt] && MPI_SUCCESS == retval) {
477                                   if (MPI_SUCCESS == retval) {
478                                           smpi_mpi_isend(requests[cnt]);
479                                   }
480                                   }
481                                   cnt++;
482                         } 
483                         else { // ... except if it's me.
484                                   memcpy(recvbuf, (void *)cptr, recvcount*recvtype->size*sizeof(char));
485                         }
486                   cptr += sendcount*datatype->size;
487             }
488             for(i=0; i<cnt; i++) { // wait for send to complete
489                             /* FIXME: waitall() should be slightly better */
490                             smpi_mpi_wait(requests[i], &status);
491                             xbt_mallocator_release(smpi_global->request_mallocator, requests[i]);
492
493             }
494   } 
495   else {  // i am a non-root process: wait data from the root
496             retval = smpi_create_request(recvbuf,recvcount, 
497                                   recvtype, root, rank, tag, comm, &request);
498             if (NULL != request && MPI_SUCCESS == retval) {
499                         if (MPI_SUCCESS == retval) {
500                                   smpi_mpi_irecv(request);
501                         }
502             }
503             smpi_mpi_wait(request, &status);
504             xbt_mallocator_release(smpi_global->request_mallocator, request);
505   }
506   xbt_free(requests);
507
508   smpi_bench_begin();
509
510   return retval;
511 }
512
513
514 /**
515  * MPI_Alltoall user entry point
516  * 
517  * Uses the logic of OpenMPI (upto 1.2.7 or greater) for the optimizations
518  * ompi/mca/coll/tuned/coll_tuned_module.c
519  **/
520 int SMPI_MPI_Alltoall(void *sendbuf, int sendcount, MPI_Datatype datatype, 
521                          void *recvbuf, int recvcount, MPI_Datatype recvtype,
522                            MPI_Comm comm)
523 {
524   int retval = MPI_SUCCESS;
525   int block_dsize;
526   int rank;
527
528   smpi_bench_end();
529
530   rank = smpi_mpi_comm_rank(comm);
531   block_dsize = datatype->size * sendcount;
532
533   if ((block_dsize < 200) && (comm->size > 12)) {
534             retval = smpi_coll_tuned_alltoall_bruck(sendbuf, sendcount, datatype,
535                                   recvbuf, recvcount, recvtype, comm);
536
537   } else if (block_dsize < 3000) {
538 /* use this one !!          retval = smpi_coll_tuned_alltoall_basic_linear(sendbuf, sendcount, datatype,
539                                   recvbuf, recvcount, recvtype, comm);
540                                   */
541   retval = smpi_coll_tuned_alltoall_pairwise(sendbuf, sendcount, datatype,
542                                   recvbuf, recvcount, recvtype, comm);
543   } else {
544
545   retval = smpi_coll_tuned_alltoall_pairwise(sendbuf, sendcount, datatype,
546                                   recvbuf, recvcount, recvtype, comm);
547   }
548
549   smpi_bench_begin();
550
551   return retval;
552 }
553
554
555
556
557 // used by comm_split to sort ranks based on key values
558 int smpi_compare_rankkeys(const void *a, const void *b);
559 int smpi_compare_rankkeys(const void *a, const void *b)
560 {
561   int *x = (int *) a;
562   int *y = (int *) b;
563
564   if (x[1] < y[1])
565     return -1;
566
567   if (x[1] == y[1]) {
568     if (x[0] < y[0])
569       return -1;
570     if (x[0] == y[0])
571       return 0;
572     return 1;
573   }
574
575   return 1;
576 }
577
578 int SMPI_MPI_Comm_split(MPI_Comm comm, int color, int key,
579                         MPI_Comm * comm_out)
580 {
581   int retval = MPI_SUCCESS;
582
583   int index, rank;
584   smpi_mpi_request_t request;
585   int colorkey[2];
586   smpi_mpi_status_t status;
587
588   smpi_bench_end();
589
590   // FIXME: need to test parameters
591
592   index = smpi_process_index();
593   rank = comm->index_to_rank_map[index];
594
595   // default output
596   comm_out = NULL;
597
598   // root node does most of the real work
599   if (0 == rank) {
600     int colormap[comm->size];
601     int keymap[comm->size];
602     int rankkeymap[comm->size * 2];
603     int i, j;
604     smpi_mpi_communicator_t tempcomm = NULL;
605     int count;
606     int indextmp;
607
608     colormap[0] = color;
609     keymap[0] = key;
610
611     // FIXME: use scatter/gather or similar instead of individual comms
612     for (i = 1; i < comm->size; i++) {
613       retval = smpi_create_request(colorkey, 2, MPI_INT, MPI_ANY_SOURCE,
614                                    rank, MPI_ANY_TAG, comm, &request);
615       smpi_mpi_irecv(request);
616       smpi_mpi_wait(request, &status);
617       colormap[status.MPI_SOURCE] = colorkey[0];
618       keymap[status.MPI_SOURCE] = colorkey[1];
619       xbt_mallocator_release(smpi_global->request_mallocator, request);
620     }
621
622     for (i = 0; i < comm->size; i++) {
623       if (MPI_UNDEFINED == colormap[i]) {
624         continue;
625       }
626       // make a list of nodes with current color and sort by keys
627       count = 0;
628       for (j = i; j < comm->size; j++) {
629         if (colormap[i] == colormap[j]) {
630           colormap[j] = MPI_UNDEFINED;
631           rankkeymap[count * 2] = j;
632           rankkeymap[count * 2 + 1] = keymap[j];
633           count++;
634         }
635       }
636       qsort(rankkeymap, count, sizeof(int) * 2, &smpi_compare_rankkeys);
637
638       // new communicator
639       tempcomm = xbt_new(s_smpi_mpi_communicator_t, 1);
640       tempcomm->barrier_count = 0;
641       tempcomm->size = count;
642       tempcomm->barrier_mutex = SIMIX_mutex_init();
643       tempcomm->barrier_cond = SIMIX_cond_init();
644       tempcomm->rank_to_index_map = xbt_new(int, count);
645       tempcomm->index_to_rank_map = xbt_new(int, smpi_global->process_count);
646       for (j = 0; j < smpi_global->process_count; j++) {
647         tempcomm->index_to_rank_map[j] = -1;
648       }
649       for (j = 0; j < count; j++) {
650         indextmp = comm->rank_to_index_map[rankkeymap[j * 2]];
651         tempcomm->rank_to_index_map[j] = indextmp;
652         tempcomm->index_to_rank_map[indextmp] = j;
653       }
654       for (j = 0; j < count; j++) {
655         if (rankkeymap[j * 2]) {
656           retval = smpi_create_request(&j, 1, MPI_INT, 0,
657                                        rankkeymap[j * 2], 0, comm, &request);
658           request->data = tempcomm;
659           smpi_mpi_isend(request);
660           smpi_mpi_wait(request, &status);
661           xbt_mallocator_release(smpi_global->request_mallocator, request);
662         } else {
663           *comm_out = tempcomm;
664         }
665       }
666     }
667   } else {
668     colorkey[0] = color;
669     colorkey[1] = key;
670     retval = smpi_create_request(colorkey, 2, MPI_INT, rank, 0, 0, comm,
671                                  &request);
672     smpi_mpi_isend(request);
673     smpi_mpi_wait(request, &status);
674     xbt_mallocator_release(smpi_global->request_mallocator, request);
675     if (MPI_UNDEFINED != color) {
676       retval = smpi_create_request(colorkey, 1, MPI_INT, 0, rank, 0, comm,
677                                    &request);
678       smpi_mpi_irecv(request);
679       smpi_mpi_wait(request, &status);
680       *comm_out = request->data;
681     }
682   }
683
684   smpi_bench_begin();
685
686   return retval;
687 }
688
689 double SMPI_MPI_Wtime(void)
690 {
691   return (SIMIX_get_clock());
692 }
693
694