Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
ba77e82eefc56d98c07f9589e514a1cef914b6ce
[simgrid.git] / src / smpi / mpi / smpi_datatype_derived.cpp
/* smpi_datatype_derived.cpp -- MPI primitives to handle derived datatypes   */
2 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved.          */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "smpi_datatype_derived.hpp"
8 #include "smpi_op.hpp"
9
10 XBT_LOG_EXTERNAL_CATEGORY(smpi_datatype);
11
12 namespace simgrid{
13 namespace smpi{
14
15
16 Type_Contiguous::Type_Contiguous(int size, MPI_Aint lb, MPI_Aint ub, int flags, int block_count, MPI_Datatype old_type): Datatype(size, lb, ub, flags), block_count_(block_count), old_type_(old_type){
17   old_type_->ref();
18 }
19
20 Type_Contiguous::~Type_Contiguous(){
21   Datatype::unref(old_type_);
22 }
23
24
25 void Type_Contiguous::serialize( void* noncontiguous_buf, void *contiguous_buf,
26                             int count){
27   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
28   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+lb();
29   memcpy(contiguous_buf_char, noncontiguous_buf_char, count * block_count_ * old_type_->size());
30 }
31 void Type_Contiguous::unserialize( void* contiguous_buf, void *noncontiguous_buf,
32                               int count, MPI_Op op){
33   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
34   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+lb();
35   int n= count*block_count_;
36   if(op!=MPI_OP_NULL)
37     op->apply( contiguous_buf_char, noncontiguous_buf_char, &n, old_type_);
38 }
39
40
41 Type_Vector::Type_Vector(int size,MPI_Aint lb, MPI_Aint ub, int flags, int count, int block_length, int stride, MPI_Datatype old_type): Datatype(size, lb, ub, flags), block_count_(count), block_length_(block_length),block_stride_(stride),  old_type_(old_type){
42   old_type_->ref();
43 }
44
45 Type_Vector::~Type_Vector(){
46   Datatype::unref(old_type_);
47 }
48
49
50 void Type_Vector::serialize( void* noncontiguous_buf, void *contiguous_buf,
51                             int count){
52   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
53   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf);
54
55   for (int i = 0; i < block_count_ * count; i++) {
56     if (not(old_type_->flags() & DT_FLAG_DERIVED))
57       memcpy(contiguous_buf_char, noncontiguous_buf_char, block_length_ * old_type_->size());
58     else
59       old_type_->serialize(noncontiguous_buf_char, contiguous_buf_char, block_length_);
60
61     contiguous_buf_char += block_length_*old_type_->size();
62     if((i+1)%block_count_ ==0)
63       noncontiguous_buf_char += block_length_*old_type_->get_extent();
64     else
65       noncontiguous_buf_char += block_stride_*old_type_->get_extent();
66   }
67 }
68
69 void Type_Vector::unserialize( void* contiguous_buf, void *noncontiguous_buf,
70                               int count, MPI_Op op){
71   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
72   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf);
73
74   for (int i = 0; i < block_count_ * count; i++) {
75     if (not(old_type_->flags() & DT_FLAG_DERIVED)) {
76       if(op != MPI_OP_NULL)
77         op->apply(contiguous_buf_char, noncontiguous_buf_char, &block_length_,
78           old_type_);
79     }else
80       old_type_->unserialize(contiguous_buf_char, noncontiguous_buf_char, block_length_, op);
81
82     contiguous_buf_char += block_length_*old_type_->size();
83     if((i+1)%block_count_ ==0)
84       noncontiguous_buf_char += block_length_*old_type_->get_extent();
85     else
86       noncontiguous_buf_char += block_stride_*old_type_->get_extent();
87   }
88 }
89
90 Type_Hvector::Type_Hvector(int size,MPI_Aint lb, MPI_Aint ub, int flags, int count, int block_length, MPI_Aint stride, MPI_Datatype old_type): Datatype(size, lb, ub, flags), block_count_(count), block_length_(block_length), block_stride_(stride), old_type_(old_type){
91   old_type->ref();
92 }
93 Type_Hvector::~Type_Hvector(){
94   Datatype::unref(old_type_);
95 }
96
97 void Type_Hvector::serialize( void* noncontiguous_buf, void *contiguous_buf,
98                     int count){
99   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
100   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf);
101
102   for (int i = 0; i < block_count_ * count; i++) {
103     if (not(old_type_->flags() & DT_FLAG_DERIVED))
104       memcpy(contiguous_buf_char, noncontiguous_buf_char, block_length_ * old_type_->size());
105     else
106       old_type_->serialize( noncontiguous_buf_char, contiguous_buf_char, block_length_);
107
108     contiguous_buf_char += block_length_*old_type_->size();
109     if((i+1)%block_count_ ==0)
110       noncontiguous_buf_char += block_length_*old_type_->size();
111     else
112       noncontiguous_buf_char += block_stride_;
113   }
114 }
115
116
117 void Type_Hvector::unserialize( void* contiguous_buf, void *noncontiguous_buf,
118                               int count, MPI_Op op){
119   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
120   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf);
121
122   for (int i = 0; i < block_count_ * count; i++) {
123     if (not(old_type_->flags() & DT_FLAG_DERIVED)) {
124       if(op!=MPI_OP_NULL)
125         op->apply( contiguous_buf_char, noncontiguous_buf_char, &block_length_, old_type_);
126     }else
127       old_type_->unserialize( contiguous_buf_char, noncontiguous_buf_char, block_length_, op);
128     contiguous_buf_char += block_length_*old_type_->size();
129     if((i+1)%block_count_ ==0)
130       noncontiguous_buf_char += block_length_*old_type_->size();
131     else
132       noncontiguous_buf_char += block_stride_;
133   }
134 }
135
136 Type_Indexed::Type_Indexed(int size,MPI_Aint lb, MPI_Aint ub, int flags, int count, int* block_lengths, int* block_indices, MPI_Datatype old_type): Datatype(size, lb, ub, flags), block_count_(count), old_type_(old_type){
137   old_type->ref();
138   block_lengths_ = new int[count];
139   block_indices_ = new int[count];
140   for (int i = 0; i < count; i++) {
141     block_lengths_[i]=block_lengths[i];
142     block_indices_[i]=block_indices[i];
143   }
144 }
145
146 Type_Indexed::~Type_Indexed(){
147   Datatype::unref(old_type_);
148   if(refcount()==0){
149     delete[] block_lengths_;
150     delete[] block_indices_;
151   }
152 }
153
154
155 void Type_Indexed::serialize( void* noncontiguous_buf, void *contiguous_buf,
156                         int count){
157   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
158   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+block_indices_[0] * old_type_->size();
159   for (int j = 0; j < count; j++) {
160     for (int i = 0; i < block_count_; i++) {
161       if (not(old_type_->flags() & DT_FLAG_DERIVED))
162         memcpy(contiguous_buf_char, noncontiguous_buf_char, block_lengths_[i] * old_type_->size());
163       else
164         old_type_->serialize( noncontiguous_buf_char, contiguous_buf_char, block_lengths_[i]);
165
166       contiguous_buf_char += block_lengths_[i]*old_type_->size();
167       if (i<block_count_-1)
168         noncontiguous_buf_char =
169           static_cast<char*>(noncontiguous_buf) + block_indices_[i+1]*old_type_->get_extent();
170       else
171         noncontiguous_buf_char += block_lengths_[i]*old_type_->get_extent();
172     }
173     noncontiguous_buf=static_cast< void*>(noncontiguous_buf_char);
174   }
175 }
176
177
178 void Type_Indexed::unserialize( void* contiguous_buf, void *noncontiguous_buf,
179                       int count, MPI_Op op){
180   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
181   char* noncontiguous_buf_char =
182     static_cast<char*>(noncontiguous_buf)+block_indices_[0]*old_type_->get_extent();
183   for (int j = 0; j < count; j++) {
184     for (int i = 0; i < block_count_; i++) {
185       if (not(old_type_->flags() & DT_FLAG_DERIVED)) {
186         if(op!=MPI_OP_NULL)
187           op->apply( contiguous_buf_char, noncontiguous_buf_char, &block_lengths_[i],
188                     old_type_);
189       }else
190         old_type_->unserialize( contiguous_buf_char,noncontiguous_buf_char,block_lengths_[i], op);
191
192       contiguous_buf_char += block_lengths_[i]*old_type_->size();
193       if (i<block_count_-1)
194         noncontiguous_buf_char =
195           static_cast<char*>(noncontiguous_buf) + block_indices_[i+1]*old_type_->get_extent();
196       else
197         noncontiguous_buf_char += block_lengths_[i]*old_type_->get_extent();
198     }
199     noncontiguous_buf=static_cast<void*>(noncontiguous_buf_char);
200   }
201 }
202
203 Type_Hindexed::Type_Hindexed(int size,MPI_Aint lb, MPI_Aint ub, int flags, int count, int* block_lengths, MPI_Aint* block_indices, MPI_Datatype old_type)
204 : Datatype(size, lb, ub, flags), block_count_(count), old_type_(old_type)
205 {
206   old_type_->ref();
207   block_lengths_ = new int[count];
208   block_indices_ = new MPI_Aint[count];
209   for (int i = 0; i < count; i++) {
210     block_lengths_[i]=block_lengths[i];
211     block_indices_[i]=block_indices[i];
212   }
213 }
214
215     Type_Hindexed::~Type_Hindexed(){
216   Datatype::unref(old_type_);
217   if(refcount()==0){
218     delete[] block_lengths_;
219     delete[] block_indices_;
220   }
221 }
222
223 void Type_Hindexed::serialize( void* noncontiguous_buf, void *contiguous_buf,
224                 int count){
225   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
226   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+ block_indices_[0];
227   for (int j = 0; j < count; j++) {
228     for (int i = 0; i < block_count_; i++) {
229       if (not(old_type_->flags() & DT_FLAG_DERIVED))
230         memcpy(contiguous_buf_char, noncontiguous_buf_char, block_lengths_[i] * old_type_->size());
231       else
232         old_type_->serialize(noncontiguous_buf_char, contiguous_buf_char,block_lengths_[i]);
233
234       contiguous_buf_char += block_lengths_[i]*old_type_->size();
235       if (i<block_count_-1)
236         noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf) + block_indices_[i+1];
237       else
238         noncontiguous_buf_char += block_lengths_[i]*old_type_->get_extent();
239     }
240     noncontiguous_buf=static_cast<void*>(noncontiguous_buf_char);
241   }
242 }
243
244 void Type_Hindexed::unserialize( void* contiguous_buf, void *noncontiguous_buf,
245                           int count, MPI_Op op){
246   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
247   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+ block_indices_[0];
248   for (int j = 0; j < count; j++) {
249     for (int i = 0; i < block_count_; i++) {
250       if (not(old_type_->flags() & DT_FLAG_DERIVED)) {
251         if(op!=MPI_OP_NULL)
252           op->apply( contiguous_buf_char, noncontiguous_buf_char, &block_lengths_[i],
253                             old_type_);
254       }else
255         old_type_->unserialize( contiguous_buf_char,noncontiguous_buf_char,block_lengths_[i], op);
256
257       contiguous_buf_char += block_lengths_[i]*old_type_->size();
258       if (i<block_count_-1)
259         noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf) + block_indices_[i+1];
260       else
261         noncontiguous_buf_char += block_lengths_[i]*old_type_->get_extent();
262     }
263     noncontiguous_buf=static_cast<void*>(noncontiguous_buf_char);
264   }
265 }
266
267 Type_Struct::Type_Struct(int size,MPI_Aint lb, MPI_Aint ub, int flags, int count, int* block_lengths, MPI_Aint* block_indices, MPI_Datatype* old_types): Datatype(size, lb, ub, flags), block_count_(count), block_lengths_(block_lengths), block_indices_(block_indices), old_types_(old_types){
268   block_lengths_= new int[count];
269   block_indices_= new MPI_Aint[count];
270   old_types_=  new MPI_Datatype[count];
271   for (int i = 0; i < count; i++) {
272     block_lengths_[i]=block_lengths[i];
273     block_indices_[i]=block_indices[i];
274     old_types_[i]=old_types[i];
275     old_types_[i]->ref();
276   }
277 }
278
279 Type_Struct::~Type_Struct(){
280   for (int i = 0; i < block_count_; i++) {
281     Datatype::unref(old_types_[i]);
282   }
283   if(refcount()==0){
284     delete[] block_lengths_;
285     delete[] block_indices_;
286     delete[] old_types_;
287   }
288 }
289
290
291 void Type_Struct::serialize( void* noncontiguous_buf, void *contiguous_buf,
292                         int count){
293   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
294   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+ block_indices_[0];
295   for (int j = 0; j < count; j++) {
296     for (int i = 0; i < block_count_; i++) {
297       if (not(old_types_[i]->flags() & DT_FLAG_DERIVED))
298         memcpy(contiguous_buf_char, noncontiguous_buf_char, block_lengths_[i] * old_types_[i]->size());
299       else
300         old_types_[i]->serialize( noncontiguous_buf_char,contiguous_buf_char,block_lengths_[i]);
301
302
303       contiguous_buf_char += block_lengths_[i]*old_types_[i]->size();
304       if (i<block_count_-1)
305         noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf) + block_indices_[i+1];
306       else //let's hope this is MPI_UB ?
307         noncontiguous_buf_char += block_lengths_[i]*old_types_[i]->get_extent();
308     }
309     noncontiguous_buf=static_cast<void*>(noncontiguous_buf_char);
310   }
311 }
312
313 void Type_Struct::unserialize( void* contiguous_buf, void *noncontiguous_buf,
314                               int count, MPI_Op op){
315   char* contiguous_buf_char = static_cast<char*>(contiguous_buf);
316   char* noncontiguous_buf_char = static_cast<char*>(noncontiguous_buf)+ block_indices_[0];
317   for (int j = 0; j < count; j++) {
318     for (int i = 0; i < block_count_; i++) {
319       if (not(old_types_[i]->flags() & DT_FLAG_DERIVED)) {
320         if(op!=MPI_OP_NULL)
321           op->apply( contiguous_buf_char, noncontiguous_buf_char, &block_lengths_[i], old_types_[i]);
322       }else
323         old_types_[i]->unserialize( contiguous_buf_char, noncontiguous_buf_char,block_lengths_[i], op);
324
325       contiguous_buf_char += block_lengths_[i]*old_types_[i]->size();
326       if (i<block_count_-1)
327         noncontiguous_buf_char =  static_cast<char*>(noncontiguous_buf) + block_indices_[i+1];
328       else
329         noncontiguous_buf_char += block_lengths_[i]*old_types_[i]->get_extent();
330     }
331     noncontiguous_buf=static_cast<void*>(noncontiguous_buf_char);
332   }
333 }
334
335 }
336 }