/*----------------------------------------------------------------------------*/
/* CP2K: A general program to perform molecular dynamics simulations */
/* Copyright 2000-2026 CP2K developers group <https://cp2k.org> */
/* */
/* SPDX-License-Identifier: BSD-3-Clause */
/*----------------------------------------------------------------------------*/
#include "dbm_multiply_comm.h"
#include "../mpiwrap/cp_mpi.h"
#include "../offload/offload_mempool.h"

#include <assert.h>
#include <stdlib.h>
#include <string.h>

#if 1
#define DBM_MULTIPLY_COMM_MEMPOOL
#endif
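
// When defined, DBM_MULTIPLY_COMM_MEMPOOL routes the receive buffers through
// the offload host mempool instead of cp_mpi_alloc_mem (see pack_matrix and
// free_packed_matrix below).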

/*******************************************************************************
 * \brief Private routine for computing greatest common divisor of two numbers.
 * \author Ole Schuett
 ******************************************************************************/
static int gcd(const int a, const int b) {
  if (a == 0) {
    return b;
  }
  return gcd(b % a, a); // Euclid's algorithm.
}
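
// Example: gcd(12, 18) recurses as gcd(12, 18) -> gcd(6, 12) -> gcd(0, 6)
// and returns 6.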

/*******************************************************************************
 * \brief Private routine for computing least common multiple of two numbers.
 * \author Ole Schuett
 ******************************************************************************/
static int lcm(const int a, const int b) { return (a * b) / gcd(a, b); }
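
// Example: lcm(2, 3) = 6, which is how many communication ticks a 2 x 3
// process grid needs (see dbm_comm_iterator_start below).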

/*******************************************************************************
 * \brief Private routine for computing the sum of the given integers.
 * \author Ole Schuett
 ******************************************************************************/
static inline int isum(const int n, const int input[n]) {
  int output = 0;
  for (int i = 0; i < n; i++) {
    output += input[i];
  }
  return output;
}

/*******************************************************************************
 * \brief Private routine for computing the cumulative sums of given numbers.
 * \author Ole Schuett
 ******************************************************************************/
static inline void icumsum(const int n, const int input[n], int output[n]) {
  output[0] = 0;
  for (int i = 1; i < n; i++) {
    output[i] = output[i - 1] + input[i - 1];
  }
}
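
// Example: icumsum(3, {2, 3, 1}, out) yields out = {0, 2, 5}, i.e. exclusive
// prefix sums as used for the alltoallv displacements below.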

/*******************************************************************************
 * \brief Private struct used for planning during pack_matrix.
 * \author Ole Schuett
 ******************************************************************************/
typedef struct {
  const dbm_block_t *blk; // source block
  int rank;               // target mpi rank
  int row_size;
  int col_size;
} plan_t;

/*******************************************************************************
 * \brief Private routine for calculating tick indices in pack plans.
 * \author Maximilian Graml
 ******************************************************************************/
static inline unsigned long long calculate_tick_index(int sum_index,
                                                       int nticks) {
  // 1021 is an arbitrary prime used to scramble the index.
  return ((unsigned long long)sum_index * 1021ULL) % (unsigned long long)nticks;
}
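
// Example: with nticks = 7 the scrambling maps sum_index 0..6 onto the ticks
// 0, 6, 5, 4, 3, 2, 1 (since 1021 % 7 == 6), spreading consecutive indices
// over different ticks.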

/*******************************************************************************
 * \brief Private routine for planning packs.
 * \author Ole Schuett
 ******************************************************************************/
static void create_pack_plans(const bool trans_matrix, const bool trans_dist,
                              const dbm_matrix_t *matrix,
                              const cp_mpi_comm_t comm,
                              const dbm_dist_1d_t *dist_indices,
                              const dbm_dist_1d_t *dist_ticks, const int nticks,
                              const int npacks, plan_t *plans_per_pack[npacks],
                              int nblks_per_pack[npacks],
                              int ndata_per_pack[npacks]) {

  memset(nblks_per_pack, 0, npacks * sizeof(int));
  memset(ndata_per_pack, 0, npacks * sizeof(int));

#pragma omp parallel
  {
    // 1st pass: Compute the number of blocks that will be sent in each pack.
    int nblks_mythread[npacks];
    memset(nblks_mythread, 0, npacks * sizeof(int));
#pragma omp for schedule(static)
    for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      dbm_shard_t *shard = &matrix->shards[ishard];
      for (int iblock = 0; iblock < shard->nblocks; iblock++) {
        const dbm_block_t *blk = &shard->blocks[iblock];
        const int sum_index = (trans_matrix) ? blk->row : blk->col;
        unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
        const int ipack = itick64 / dist_ticks->nranks;
        nblks_mythread[ipack]++;
      }
    }

    // Sum nblocks across threads and allocate arrays for plans.
#pragma omp critical
    for (int ipack = 0; ipack < npacks; ipack++) {
      nblks_per_pack[ipack] += nblks_mythread[ipack];
      nblks_mythread[ipack] = nblks_per_pack[ipack];
    }
#pragma omp barrier
#pragma omp for
    for (int ipack = 0; ipack < npacks; ipack++) {
      const int nblks = nblks_per_pack[ipack];
      plans_per_pack[ipack] = malloc(nblks * sizeof(plan_t));
      assert(plans_per_pack[ipack] != NULL || nblks == 0);
    }

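    // Note: after the critical section above, nblks_mythread[ipack] holds
    // this thread's end offset within pack ipack, so the 2nd pass below fills
    // each thread's slice of plans_per_pack back-to-front via --nblks_mythread.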
    // 2nd pass: Plan where to send each block.
    int ndata_mythread[npacks];
    memset(ndata_mythread, 0, npacks * sizeof(int));
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
      dbm_shard_t *shard = &matrix->shards[ishard];
      for (int iblock = 0; iblock < shard->nblocks; iblock++) {
        const dbm_block_t *blk = &shard->blocks[iblock];
        const int free_index = (trans_matrix) ? blk->col : blk->row;
        const int sum_index = (trans_matrix) ? blk->row : blk->col;
        unsigned long long itick64 = calculate_tick_index(sum_index, nticks);
        const int ipack = itick64 / dist_ticks->nranks;
        // Compute rank to which this block should be sent.
        const int coord_free_idx = dist_indices->index2coord[free_index];
        const int coord_sum_idx = itick64 % dist_ticks->nranks;
        const int coords[2] = {(trans_dist) ? coord_sum_idx : coord_free_idx,
                               (trans_dist) ? coord_free_idx : coord_sum_idx};
        const int rank = cp_mpi_cart_rank(comm, coords);
        const int row_size = matrix->row_sizes[blk->row];
        const int col_size = matrix->col_sizes[blk->col];
        ndata_mythread[ipack] += row_size * col_size;
        // Create plan.
        const int iplan = --nblks_mythread[ipack];
        plans_per_pack[ipack][iplan].blk = blk;
        plans_per_pack[ipack][iplan].rank = rank;
        plans_per_pack[ipack][iplan].row_size = row_size;
        plans_per_pack[ipack][iplan].col_size = col_size;
      }
    }
#pragma omp critical
    for (int ipack = 0; ipack < npacks; ipack++) {
      ndata_per_pack[ipack] += ndata_mythread[ipack];
    }
  } // end of omp parallel region
}

/*******************************************************************************
 * \brief Private routine for filling send buffers.
 * \author Ole Schuett
 ******************************************************************************/
static void fill_send_buffers(
    const dbm_matrix_t *matrix, const bool trans_matrix, const int nblks_send,
    const int ndata_send, plan_t plans[nblks_send], const int nranks,
    int blks_send_count[nranks], int data_send_count[nranks],
    int blks_send_displ[nranks], int data_send_displ[nranks],
    dbm_pack_block_t blks_send[nblks_send], double data_send[ndata_send]) {

  memset(blks_send_count, 0, nranks * sizeof(int));
  memset(data_send_count, 0, nranks * sizeof(int));

#pragma omp parallel
  {
    // 3rd pass: Compute per-rank nblks and ndata.
    int nblks_mythread[nranks], ndata_mythread[nranks];
    memset(nblks_mythread, 0, nranks * sizeof(int));
    memset(ndata_mythread, 0, nranks * sizeof(int));
#pragma omp for schedule(static)
    for (int iblock = 0; iblock < nblks_send; iblock++) {
      const plan_t *plan = &plans[iblock];
      nblks_mythread[plan->rank] += 1;
      ndata_mythread[plan->rank] += plan->row_size * plan->col_size;
    }

    // Sum nblks and ndata across threads.
#pragma omp critical
    for (int irank = 0; irank < nranks; irank++) {
      blks_send_count[irank] += nblks_mythread[irank];
      data_send_count[irank] += ndata_mythread[irank];
      nblks_mythread[irank] = blks_send_count[irank];
      ndata_mythread[irank] = data_send_count[irank];
    }
#pragma omp barrier
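
    // At this point the *_send_count arrays hold the per-rank totals, while
    // nblks_mythread / ndata_mythread hold this thread's end offsets within
    // each rank's slice; the 4th pass below fills those slices back-to-front.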

    // Compute send displacements.
#pragma omp master
    {
      icumsum(nranks, blks_send_count, blks_send_displ);
      icumsum(nranks, data_send_count, data_send_displ);
      const int m = nranks - 1;
      assert(nblks_send == blks_send_displ[m] + blks_send_count[m]);
      assert(ndata_send == data_send_displ[m] + data_send_count[m]);
    }
#pragma omp barrier

    // 4th pass: Fill blks_send and data_send arrays.
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int iblock = 0; iblock < nblks_send; iblock++) {
      const plan_t *plan = &plans[iblock];
      const dbm_block_t *blk = plan->blk;
      const int ishard = dbm_get_shard_index(matrix, blk->row, blk->col);
      const dbm_shard_t *shard = &matrix->shards[ishard];
      const double *blk_data = &shard->data[blk->offset];
      const int row_size = plan->row_size, col_size = plan->col_size;
      const int plan_size = row_size * col_size;
      const int irank = plan->rank;

      // The send buffers are ordered by rank, thread, and block.
      // data_send_displ[irank]: Start of the data for irank within data_send.
      // ndata_mythread[irank]: Current thread's offset within irank's data.
      nblks_mythread[irank] -= 1;
      ndata_mythread[irank] -= plan_size;
      const int offset = data_send_displ[irank] + ndata_mythread[irank];
      const int jblock = blks_send_displ[irank] + nblks_mythread[irank];

      double norm = 0.0; // Compute norm as double...
      if (trans_matrix) {
        // Transpose block to allow for outer-product style multiplication.
        for (int i = 0; i < row_size; i++) {
          for (int j = 0; j < col_size; j++) {
            const double element = blk_data[j * row_size + i];
            data_send[offset + i * col_size + j] = element;
            norm += element * element;
          }
        }
        blks_send[jblock].free_index = plan->blk->col;
        blks_send[jblock].sum_index = plan->blk->row;
      } else {
        for (int i = 0; i < plan_size; i++) {
          const double element = blk_data[i];
          data_send[offset + i] = element;
          norm += element * element;
        }
        blks_send[jblock].free_index = plan->blk->row;
        blks_send[jblock].sum_index = plan->blk->col;
      }
      blks_send[jblock].norm = (float)norm; // ...store norm as float.

      // After the block exchange data_recv_displ will be added to the offsets.
      blks_send[jblock].offset = offset - data_send_displ[irank];
    }
  } // end of omp parallel region
}

/*******************************************************************************
 * \brief Private comparator passed to qsort to compare two blocks by sum_index.
 * \author Ole Schuett
 ******************************************************************************/
static int compare_pack_blocks_by_sum_index(const void *a, const void *b) {
  const dbm_pack_block_t *blk_a = (const dbm_pack_block_t *)a;
  const dbm_pack_block_t *blk_b = (const dbm_pack_block_t *)b;
  return blk_a->sum_index - blk_b->sum_index;
}

/*******************************************************************************
 * \brief Private routine for post-processing received blocks.
 * \author Ole Schuett
 ******************************************************************************/
static void postprocess_received_blocks(
    const int nranks, const int nshards, const int nblocks_recv,
    const int blks_recv_count[nranks], const int blks_recv_displ[nranks],
    const int data_recv_displ[nranks],
    dbm_pack_block_t blks_recv[nblocks_recv]) {

  int nblocks_per_shard[nshards], shard_start[nshards];
  memset(nblocks_per_shard, 0, nshards * sizeof(int));
  dbm_pack_block_t *blocks_tmp =
      malloc(nblocks_recv * sizeof(dbm_pack_block_t));
  assert(blocks_tmp != NULL || nblocks_recv == 0);

#pragma omp parallel
  {
    // Add data_recv_displ to the received block offsets.
    for (int irank = 0; irank < nranks; irank++) {
#pragma omp for
      for (int i = 0; i < blks_recv_count[irank]; i++) {
        blks_recv[blks_recv_displ[irank] + i].offset += data_recv_displ[irank];
      }
    }

    // First use counting sort to group blocks by their free_index shard.
    int nblocks_mythread[nshards];
    memset(nblocks_mythread, 0, nshards * sizeof(int));
#pragma omp for schedule(static)
    for (int iblock = 0; iblock < nblocks_recv; iblock++) {
      blocks_tmp[iblock] = blks_recv[iblock];
      const int ishard = blks_recv[iblock].free_index % nshards;
      nblocks_mythread[ishard]++;
    }
#pragma omp critical
    for (int ishard = 0; ishard < nshards; ishard++) {
      nblocks_per_shard[ishard] += nblocks_mythread[ishard];
      nblocks_mythread[ishard] = nblocks_per_shard[ishard];
    }
#pragma omp barrier
#pragma omp master
    icumsum(nshards, nblocks_per_shard, shard_start);
#pragma omp barrier
#pragma omp for schedule(static) // Need static to match previous loop.
    for (int iblock = 0; iblock < nblocks_recv; iblock++) {
      const int ishard = blocks_tmp[iblock].free_index % nshards;
      const int jblock = --nblocks_mythread[ishard] + shard_start[ishard];
      blks_recv[jblock] = blocks_tmp[iblock];
    }

    // Then sort blocks within each shard by their sum_index.
#pragma omp for
    for (int ishard = 0; ishard < nshards; ishard++) {
      if (nblocks_per_shard[ishard] > 1) {
        qsort(&blks_recv[shard_start[ishard]], nblocks_per_shard[ishard],
              sizeof(dbm_pack_block_t), &compare_pack_blocks_by_sum_index);
      }
    }
  } // end of omp parallel region

  free(blocks_tmp);
}
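
// Note: after postprocess_received_blocks() the blocks of a pack are grouped
// by their free_index shard (free_index % nshards) and sorted by sum_index
// within each shard, so matching blocks can be found quickly during the
// multiplication.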

/*******************************************************************************
 * \brief Private routine for redistributing a matrix along selected dimensions.
 * \author Ole Schuett
 ******************************************************************************/
static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
                                       const bool trans_dist,
                                       const dbm_matrix_t *matrix,
                                       const dbm_distribution_t *dist,
                                       const int nticks) {

  assert(cp_mpi_comms_are_similar(matrix->dist->comm, dist->comm));

  // The row/col indices are distributed along one cart dimension and the
  // ticks are distributed along the other cart dimension.
  const dbm_dist_1d_t *dist_indices = (trans_dist) ? &dist->cols : &dist->rows;
  const dbm_dist_1d_t *dist_ticks = (trans_dist) ? &dist->rows : &dist->cols;

  // Allocate packed matrix.
  const int nsend_packs = nticks / dist_ticks->nranks;
  assert(nsend_packs * dist_ticks->nranks == nticks);
  dbm_packed_matrix_t packed;
  packed.dist_indices = dist_indices;
  packed.dist_ticks = dist_ticks;
  packed.nsend_packs = nsend_packs;
  packed.send_packs = malloc(nsend_packs * sizeof(dbm_pack_t));
  assert(packed.send_packs != NULL || nsend_packs == 0);

  // Plan all packs.
  plan_t *plans_per_pack[nsend_packs];
  int nblks_send_per_pack[nsend_packs], ndata_send_per_pack[nsend_packs];
  create_pack_plans(trans_matrix, trans_dist, matrix, dist->comm, dist_indices,
                    dist_ticks, nticks, nsend_packs, plans_per_pack,
                    nblks_send_per_pack, ndata_send_per_pack);

  // Allocate send buffers for maximum number of blocks/data over all packs.
  int nblks_send_max = 0, ndata_send_max = 0;
  for (int ipack = 0; ipack < nsend_packs; ++ipack) {
    nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
    ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
  }
  dbm_pack_block_t *blks_send =
      cp_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
  double *data_send = cp_mpi_alloc_mem(ndata_send_max * sizeof(double));

  // Cannot parallelize over packs (there might be too few of them).
  for (int ipack = 0; ipack < nsend_packs; ipack++) {
    // Fill send buffers according to plans.
    const int nranks = dist->nranks;
    int blks_send_count[nranks], data_send_count[nranks];
    int blks_send_displ[nranks], data_send_displ[nranks];
    fill_send_buffers(matrix, trans_matrix, nblks_send_per_pack[ipack],
                      ndata_send_per_pack[ipack], plans_per_pack[ipack], nranks,
                      blks_send_count, data_send_count, blks_send_displ,
                      data_send_displ, blks_send, data_send);
    free(plans_per_pack[ipack]);

    // 1st communication: Exchange block counts.
    int blks_recv_count[nranks], blks_recv_displ[nranks];
    cp_mpi_alltoall_int(blks_send_count, 1, blks_recv_count, 1, dist->comm);
    icumsum(nranks, blks_recv_count, blks_recv_displ);
    const int nblocks_recv = isum(nranks, blks_recv_count);

    // 2nd communication: Exchange blocks.
    dbm_pack_block_t *blks_recv =
        cp_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
    int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
    int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
    for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
      blks_send_count_byte[i] = blks_send_count[i] * sizeof(dbm_pack_block_t);
      blks_send_displ_byte[i] = blks_send_displ[i] * sizeof(dbm_pack_block_t);
      blks_recv_count_byte[i] = blks_recv_count[i] * sizeof(dbm_pack_block_t);
      blks_recv_displ_byte[i] = blks_recv_displ[i] * sizeof(dbm_pack_block_t);
    }
    cp_mpi_alltoallv_byte(blks_send, blks_send_count_byte, blks_send_displ_byte,
                          blks_recv, blks_recv_count_byte, blks_recv_displ_byte,
                          dist->comm);

    // 3rd communication: Exchange data counts.
    // TODO: could be computed from blks_recv.
    int data_recv_count[nranks], data_recv_displ[nranks];
    cp_mpi_alltoall_int(data_send_count, 1, data_recv_count, 1, dist->comm);
    icumsum(nranks, data_recv_count, data_recv_displ);
    const int ndata_recv = isum(nranks, data_recv_count);

    // 4th communication: Exchange data.
#if defined(DBM_MULTIPLY_COMM_MEMPOOL)
    double *data_recv =
        offload_mempool_host_malloc(ndata_recv * sizeof(double));
#else
    double *data_recv = cp_mpi_alloc_mem(ndata_recv * sizeof(double));
#endif
    cp_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
                            data_recv, data_recv_count, data_recv_displ,
                            dist->comm);

    // Post-process received blocks and assemble them into a pack.
    postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
                                blks_recv_count, blks_recv_displ,
                                data_recv_displ, blks_recv);
    packed.send_packs[ipack].nblocks = nblocks_recv;
    packed.send_packs[ipack].data_size = ndata_recv;
    packed.send_packs[ipack].blocks = blks_recv;
    packed.send_packs[ipack].data = data_recv;
  }

  // Deallocate send buffers.
  cp_mpi_free_mem(blks_send);
  cp_mpi_free_mem(data_send);

  // Allocate recv_pack.
  int max_nblocks = 0, max_data_size = 0;
  for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
    max_nblocks = imax(max_nblocks, packed.send_packs[ipack].nblocks);
    max_data_size = imax(max_data_size, packed.send_packs[ipack].data_size);
  }
  cp_mpi_max_int(&max_nblocks, 1, packed.dist_ticks->comm);
  cp_mpi_max_int(&max_data_size, 1, packed.dist_ticks->comm);
  packed.max_nblocks = max_nblocks;
  packed.max_data_size = max_data_size;
  packed.recv_pack.blocks =
      cp_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
#if defined(DBM_MULTIPLY_COMM_MEMPOOL)
  packed.recv_pack.data =
      offload_mempool_host_malloc(packed.max_data_size * sizeof(double));
#else
  packed.recv_pack.data =
      cp_mpi_alloc_mem(packed.max_data_size * sizeof(double));
#endif

  return packed; // Ownership of packed transfers to caller.
}

/*******************************************************************************
 * \brief Private routine for sending and receiving the pack for the given tick.
 * \author Ole Schuett
 ******************************************************************************/
static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
                                 dbm_packed_matrix_t *packed) {
  const int nranks = packed->dist_ticks->nranks;
  const int my_rank = packed->dist_ticks->my_rank;

  // Compute send rank and pack.
  const int itick_of_rank0 = (itick + nticks - my_rank) % nticks;
  const int send_rank = (my_rank + nticks - itick_of_rank0) % nranks;
  const int send_itick = (itick_of_rank0 + send_rank) % nticks;
  const int send_ipack = send_itick / nranks;
  assert(send_itick % nranks == my_rank);

  // Compute receive rank and pack.
  const int recv_rank = itick % nranks;
  const int recv_ipack = itick / nranks;
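
  // Worked example: with nranks = 2, nticks = 6, my_rank = 0 and itick = 3,
  // this yields send_rank = 1, send_ipack = 2, recv_rank = 1, recv_ipack = 1,
  // i.e. rank 0 sends its pack 2 to rank 1 and receives rank 1's pack 1.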

  dbm_pack_t *send_pack = &packed->send_packs[send_ipack];
  if (send_rank == my_rank) {
    assert(send_rank == recv_rank && send_ipack == recv_ipack);
    return send_pack; // Local pack, no MPI needed.
  } else {
    // Exchange blocks.
    const int nblocks_in_bytes = cp_mpi_sendrecv_byte(
        /*sendbuf=*/send_pack->blocks,
        /*sendcount=*/send_pack->nblocks * sizeof(dbm_pack_block_t),
        /*dest=*/send_rank,
        /*sendtag=*/send_ipack,
        /*recvbuf=*/packed->recv_pack.blocks,
        /*recvcount=*/packed->max_nblocks * sizeof(dbm_pack_block_t),
        /*source=*/recv_rank,
        /*recvtag=*/recv_ipack,
        /*comm=*/packed->dist_ticks->comm);

    assert(nblocks_in_bytes % sizeof(dbm_pack_block_t) == 0);
    packed->recv_pack.nblocks = nblocks_in_bytes / sizeof(dbm_pack_block_t);

    // Exchange data.
    packed->recv_pack.data_size = cp_mpi_sendrecv_double(
        /*sendbuf=*/send_pack->data,
        /*sendcount=*/send_pack->data_size,
        /*dest=*/send_rank,
        /*sendtag=*/send_ipack,
        /*recvbuf=*/packed->recv_pack.data,
        /*recvcount=*/packed->max_data_size,
        /*source=*/recv_rank,
        /*recvtag=*/recv_ipack,
        /*comm=*/packed->dist_ticks->comm);

    return &packed->recv_pack;
  }
}

/*******************************************************************************
 * \brief Private routine for releasing a packed matrix.
 * \author Ole Schuett
 ******************************************************************************/
static void free_packed_matrix(dbm_packed_matrix_t *packed) {
  cp_mpi_free_mem(packed->recv_pack.blocks);
#if defined(DBM_MULTIPLY_COMM_MEMPOOL)
  offload_mempool_host_free(packed->recv_pack.data);
#else
  cp_mpi_free_mem(packed->recv_pack.data);
#endif
  for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
    cp_mpi_free_mem(packed->send_packs[ipack].blocks);
#if defined(DBM_MULTIPLY_COMM_MEMPOOL)
    offload_mempool_host_free(packed->send_packs[ipack].data);
#else
    cp_mpi_free_mem(packed->send_packs[ipack].data);
#endif
  }
  free(packed->send_packs);
}

/*******************************************************************************
 * \brief Internal routine for creating a communication iterator.
 * \author Ole Schuett
 ******************************************************************************/
dbm_comm_iterator_t *dbm_comm_iterator_start(const bool transa,
                                             const bool transb,
                                             const dbm_matrix_t *matrix_a,
                                             const dbm_matrix_t *matrix_b,
                                             const dbm_matrix_t *matrix_c) {

  dbm_comm_iterator_t *iter = malloc(sizeof(dbm_comm_iterator_t));
  assert(iter != NULL);
  iter->dist = matrix_c->dist;

  // During each communication tick we'll fetch a pack_a and pack_b.
  // Since the cart might be non-square, the number of communication ticks is
  // chosen as the least common multiple of the cart's dimensions.
  iter->nticks = lcm(iter->dist->rows.nranks, iter->dist->cols.nranks);
  iter->itick = 0;

  // 1st arg: source dimension, 2nd arg: target dimension,
  // false = rows, true = columns.
  iter->packed_a =
      pack_matrix(transa, false, matrix_a, iter->dist, iter->nticks);
  iter->packed_b =
      pack_matrix(!transb, true, matrix_b, iter->dist, iter->nticks);

  return iter;
}

/*******************************************************************************
 * \brief Internal routine for retrieving the next pair of packs from the given
 *        iterator.
 * \author Ole Schuett
 ******************************************************************************/
bool dbm_comm_iterator_next(dbm_comm_iterator_t *iter, dbm_pack_t **pack_a,
                            dbm_pack_t **pack_b) {
  if (iter->itick >= iter->nticks) {
    return false; // end of iterator reached
  }

  // Start each rank at a different tick to spread the load on the sources.
  const int shift = iter->dist->rows.my_rank + iter->dist->cols.my_rank;
  const int shifted_itick = (iter->itick + shift) % iter->nticks;
  *pack_a = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_a);
  *pack_b = sendrecv_pack(shifted_itick, iter->nticks, &iter->packed_b);

  iter->itick++;
  return true;
}

/*******************************************************************************
 * \brief Internal routine for releasing the given communication iterator.
 * \author Ole Schuett
 ******************************************************************************/
void dbm_comm_iterator_stop(dbm_comm_iterator_t *iter) {
  free_packed_matrix(&iter->packed_a);
  free_packed_matrix(&iter->packed_b);
  free(iter);
}
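
// Illustrative usage sketch (hypothetical caller; the real driver lives in
// dbm_multiply.c):
//
//   dbm_comm_iterator_t *iter =
//       dbm_comm_iterator_start(transa, transb, matrix_a, matrix_b, matrix_c);
//   dbm_pack_t *pack_a = NULL, *pack_b = NULL;
//   while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
//     // ... multiply pack_a with pack_b into the local shards of matrix_c ...
//   }
//   dbm_comm_iterator_stop(iter);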

// EOF