Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Matrix multiplication for tall-and-skinny matrices.
10 : !> This uses the k-split (non-recursive) CARMA algorithm that is communication-optimal
11 : !> as long as the two smaller dimensions have the same size.
12 : !> Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
13 : !> submatrices uses the DBM Cannon algorithm. Due to the unknown sparsity pattern of the result
14 : !> matrix, parameters (group sizes and process grid dimensions) cannot be derived from matrix
15 : !> dimensions and need to be set manually.
16 : !> \author Patrick Seewald
17 : ! **************************************************************************************************
18 : MODULE dbt_tas_mm
19 : USE dbm_api, ONLY: &
20 : dbm_add, dbm_clear, dbm_copy, dbm_create, dbm_create_from_template, dbm_distribution_new, &
21 : dbm_distribution_obj, dbm_distribution_release, dbm_get_col_block_sizes, &
22 : dbm_get_distribution, dbm_get_name, dbm_get_nze, dbm_get_row_block_sizes, dbm_multiply, &
23 : dbm_redistribute, dbm_release, dbm_scale, dbm_type, dbm_zero
24 : USE dbt_tas_base, ONLY: &
25 : dbt_tas_clear, dbt_tas_copy, dbt_tas_create, dbt_tas_destroy, dbt_tas_distribution_new, &
26 : dbt_tas_filter, dbt_tas_get_info, dbt_tas_get_nze_total, dbt_tas_info, &
27 : dbt_tas_iterator_blocks_left, dbt_tas_iterator_next_block, dbt_tas_iterator_start, &
28 : dbt_tas_iterator_stop, dbt_tas_nblkcols_total, dbt_tas_nblkrows_total, dbt_tas_put_block, &
29 : dbt_tas_reserve_blocks
30 : USE dbt_tas_global, ONLY: dbt_tas_blk_size_one,&
31 : dbt_tas_default_distvec,&
32 : dbt_tas_dist_arb,&
33 : dbt_tas_dist_arb_default,&
34 : dbt_tas_dist_cyclic,&
35 : dbt_tas_distribution,&
36 : dbt_tas_rowcol_data
37 : USE dbt_tas_io, ONLY: dbt_tas_write_dist,&
38 : dbt_tas_write_matrix_info,&
39 : dbt_tas_write_split_info,&
40 : prep_output_unit
41 : USE dbt_tas_reshape_ops, ONLY: dbt_tas_merge,&
42 : dbt_tas_replicate,&
43 : dbt_tas_reshape
44 : USE dbt_tas_split, ONLY: &
45 : accept_pgrid_dims, colsplit, dbt_tas_create_split, dbt_tas_get_split_info, &
46 : dbt_tas_info_hold, dbt_tas_mp_comm, dbt_tas_release_info, default_nsplit_accept_ratio, &
47 : rowsplit
48 : USE dbt_tas_types, ONLY: dbt_tas_distribution_type,&
49 : dbt_tas_iterator,&
50 : dbt_tas_split_info,&
51 : dbt_tas_type
52 : USE dbt_tas_util, ONLY: array_eq,&
53 : swap
54 : USE kinds, ONLY: default_string_length,&
55 : dp,&
56 : int_8
57 : USE message_passing, ONLY: mp_cart_type
58 : #include "../../base/base_uses.f90"
59 :
60 : IMPLICIT NONE
61 : PRIVATE
62 :
63 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbt_tas_mm'
64 :
65 : PUBLIC :: &
66 : dbt_tas_multiply, &
67 : dbt_tas_batched_mm_init, &
68 : dbt_tas_batched_mm_finalize, &
69 : dbt_tas_set_batched_state, &
70 : dbt_tas_batched_mm_complete
71 :
72 : CONTAINS
73 :
74 : ! **************************************************************************************************
75 : !> \brief tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical
76 : !> to arguments of dbm_multiply (see dbm_mm, dbm_multiply_generic).
77 : !> \param transa whether matrix_a enters the product transposed (its dimensions are swapped if set)
78 : !> \param transb whether matrix_b enters the product transposed (its dimensions are swapped if set)
79 : !> \param transc transposition flag for matrix_c, taken into account when reshaping the result
80 : !> \param alpha scaling factor forwarded as alpha to dbm_multiply
81 : !> \param matrix_a first input matrix; cleared on return if move_data_a is set
82 : !> \param matrix_b second input matrix; cleared on return if move_data_b is set
83 : !> \param beta scaling factor for matrix_c, forwarded as beta to dbm_multiply
84 : !> \param matrix_c result matrix
85 : !> \param optimize_dist Whether distribution should be optimized internally. In the current
86 : !> implementation this guarantees optimal parameters only for dense matrices.
87 : !> \param split_opt optionally return split info containing optimal grid and split parameters.
88 : !> This can be used to choose optimal process grids for subsequent matrix
89 : !> multiplications with matrices of similar shape and sparsity.
90 : !> \param filter_eps filter threshold: forwarded to dbm_multiply (0.0 if absent) and used to filter the result
91 : !> \param flop on return: flop count of dbm_multiply summed over ranks, divided by the number of ranks (rounded up)
92 : !> \param move_data_a memory optimization: move data to matrix_c such that matrix_a is empty on return
93 : !> (for internal use only)
94 : !> \param move_data_b memory optimization: move data to matrix_c such that matrix_b is empty on return
95 : !> (for internal use only)
96 : !> \param retain_sparsity forwarded to dbm_multiply; if set, matrix_c is reshaped with nodata=.FALSE.
97 : !> \param simple_split if .TRUE., the split factor is not estimated from the matrix occupancies
98 : !> \param unit_nr unit number for logging output
99 : !> \param log_verbose only for testing: verbose output
100 : !> \author Patrick Seewald
101 : ! **************************************************************************************************
102 1020030 : RECURSIVE SUBROUTINE dbt_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
103 : optimize_dist, split_opt, filter_eps, flop, move_data_a, &
104 : move_data_b, retain_sparsity, simple_split, unit_nr, log_verbose)
105 :
106 : LOGICAL, INTENT(IN) :: transa, transb, transc
107 : REAL(dp), INTENT(IN) :: alpha
108 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b
109 : REAL(dp), INTENT(IN) :: beta
110 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_c
111 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
112 : TYPE(dbt_tas_split_info), INTENT(OUT), OPTIONAL :: split_opt
113 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
114 : INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
115 : LOGICAL, INTENT(IN), OPTIONAL :: move_data_a, move_data_b, &
116 : retain_sparsity, simple_split
117 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
118 : LOGICAL, INTENT(IN), OPTIONAL :: log_verbose
119 :
120 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_multiply'
121 :
122 : INTEGER :: batched_repl, handle, handle2, handle3, handle4, max_mm_dim, max_mm_dim_batched, &
123 : nsplit, nsplit_batched, nsplit_opt, numproc, split_a, split_b, split_c, split_rc, &
124 : unit_nr_prv
125 : INTEGER(KIND=int_8) :: nze_a, nze_b, nze_c, nze_c_sum
126 : INTEGER(KIND=int_8), DIMENSION(2) :: dims_a, dims_b, dims_c
127 : INTEGER(KIND=int_8), DIMENSION(3) :: dims
128 : INTEGER, DIMENSION(2) :: pdims, pdims_sub
129 : LOGICAL :: do_batched, move_a, move_b, new_a, new_b, new_c, nodata_3, opt_pgrid, &
130 : simple_split_prv, tr_case, transa_prv, transb_prv, transc_prv
131 : REAL(KIND=dp) :: filter_eps_prv
132 : TYPE(dbm_type) :: matrix_a_mm, matrix_b_mm, matrix_c_mm
133 4117332 : TYPE(dbt_tas_split_info) :: info, info_a, info_b, info_c
134 : TYPE(dbt_tas_type), POINTER :: matrix_a_rep, matrix_a_rs, matrix_b_rep, &
135 : matrix_b_rs, matrix_c_rep, matrix_c_rs
136 242196 : TYPE(mp_cart_type) :: comm_tmp, mp_comm, mp_comm_group, &
137 242196 : mp_comm_mm, mp_comm_opt
138 :
139 242196 : CALL timeset(routineN, handle)
140 242196 : CALL matrix_a%dist%info%mp_comm%sync()
141 242196 : CALL timeset("dbt_tas_total", handle2)
142 :
143 242196 : NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs)
144 :
145 242196 : unit_nr_prv = prep_output_unit(unit_nr)
146 : ! default for simple_split: .TRUE. if strict_split(1) is set for any of the three matrices
147 242196 : IF (PRESENT(simple_split)) THEN
148 72143 : simple_split_prv = simple_split
149 : ELSE
150 170053 : simple_split_prv = .FALSE.
151 :
152 510159 : info_a = dbt_tas_info(matrix_a); info_b = dbt_tas_info(matrix_b); info_c = dbt_tas_info(matrix_c)
153 170053 : IF (info_a%strict_split(1) .OR. info_b%strict_split(1) .OR. info_c%strict_split(1)) simple_split_prv = .TRUE.
154 : END IF
155 : ! nodata_3 is the nodata flag used for matrix_c below; it is only .FALSE. if retain_sparsity is set
156 242196 : nodata_3 = .TRUE.
157 242196 : IF (PRESENT(retain_sparsity)) THEN
158 4794 : IF (retain_sparsity) nodata_3 = .FALSE.
159 : END IF
160 :
161 : ! get prestored info for multiplication strategy in case of batched mm
162 242196 : batched_repl = 0
163 242196 : do_batched = .FALSE.
164 242196 : IF (matrix_a%do_batched > 0) THEN
165 50876 : do_batched = .TRUE.
166 50876 : IF (matrix_a%do_batched == 3) THEN
167 : CPASSERT(batched_repl == 0)
168 19211 : batched_repl = 1
169 : CALL dbt_tas_get_split_info( &
170 : dbt_tas_info(matrix_a%mm_storage%store_batched_repl), &
171 19211 : nsplit=nsplit_batched)
172 19211 : CPASSERT(nsplit_batched > 0)
173 : max_mm_dim_batched = 3
174 : END IF
175 : END IF
176 :
177 242196 : IF (matrix_b%do_batched > 0) THEN
178 15358 : do_batched = .TRUE.
179 15358 : IF (matrix_b%do_batched == 3) THEN
180 2816 : CPASSERT(batched_repl == 0)
181 2816 : batched_repl = 2
182 : CALL dbt_tas_get_split_info( &
183 : dbt_tas_info(matrix_b%mm_storage%store_batched_repl), &
184 2816 : nsplit=nsplit_batched)
185 2816 : CPASSERT(nsplit_batched > 0)
186 : max_mm_dim_batched = 1
187 : END IF
188 : END IF
189 :
190 242196 : IF (matrix_c%do_batched > 0) THEN
191 38932 : do_batched = .TRUE.
192 38932 : IF (matrix_c%do_batched == 3) THEN
193 9392 : CPASSERT(batched_repl == 0)
194 9392 : batched_repl = 3
195 : CALL dbt_tas_get_split_info( &
196 : dbt_tas_info(matrix_c%mm_storage%store_batched_repl), &
197 9392 : nsplit=nsplit_batched)
198 9392 : CPASSERT(nsplit_batched > 0)
199 : max_mm_dim_batched = 2
200 : END IF
201 : END IF
202 : ! input data is only moved if requested via move_data_a / move_data_b
203 242196 : move_a = .FALSE.
204 242196 : move_b = .FALSE.
205 :
206 242196 : IF (PRESENT(move_data_a)) move_a = move_data_a
207 242196 : IF (PRESENT(move_data_b)) move_b = move_data_b
208 :
209 242196 : transa_prv = transa; transb_prv = transb; transc_prv = transc
210 : ! block dimensions (rows, columns) of the three matrices
211 726588 : dims_a = [dbt_tas_nblkrows_total(matrix_a), dbt_tas_nblkcols_total(matrix_a)]
212 726588 : dims_b = [dbt_tas_nblkrows_total(matrix_b), dbt_tas_nblkcols_total(matrix_b)]
213 726588 : dims_c = [dbt_tas_nblkrows_total(matrix_c), dbt_tas_nblkcols_total(matrix_c)]
214 :
215 242196 : IF (unit_nr_prv > 0) THEN
216 34 : WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
217 : WRITE (unit_nr_prv, "(A)") &
218 : "DBT TAS MATRIX MULTIPLICATION: "// &
219 : TRIM(dbm_get_name(matrix_a%matrix))//" x "// &
220 : TRIM(dbm_get_name(matrix_b%matrix))//" = "// &
221 34 : TRIM(dbm_get_name(matrix_c%matrix))
222 34 : WRITE (unit_nr_prv, "(A)") REPEAT("-", 80)
223 : END IF
224 242196 : IF (do_batched) THEN
225 101430 : IF (unit_nr_prv > 0) THEN
226 : WRITE (unit_nr_prv, "(T2,A)") &
227 0 : "BATCHED PROCESSING OF MATMUL"
228 0 : IF (batched_repl > 0) THEN
229 0 : WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
230 : END IF
231 : END IF
232 : END IF
233 : ! swap the dimensions of transposed operands so that dims_a / dims_b refer to op(A) / op(B)
234 242196 : IF (transa_prv) THEN
235 76455 : CALL swap(dims_a)
236 : END IF
237 :
238 242196 : IF (transb_prv) THEN
239 126632 : CALL swap(dims_b)
240 : END IF
241 :
242 726588 : dims_c = [dims_a(1), dims_b(2)]
243 : ! the inner dimensions of the product must agree
244 242196 : IF (.NOT. (dims_a(2) == dims_b(1))) THEN
245 0 : CPABORT("inconsistent matrix dimensions")
246 : END IF
247 : ! dims = (m, k, n) of the product op(A) (m x k) times op(B) (k x n)
248 968784 : dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
249 :
250 242196 : IF (unit_nr_prv > 0) THEN
251 34 : WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
252 : END IF
253 :
254 242196 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
255 242196 : numproc = mp_comm%num_pe
256 :
257 : ! derive optimal matrix layout and split factor from occupancies
258 242196 : nze_a = dbt_tas_get_nze_total(matrix_a)
259 242196 : nze_b = dbt_tas_get_nze_total(matrix_b)
260 :
261 242196 : IF (.NOT. simple_split_prv) THEN
262 : CALL dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
263 : estimated_nze=nze_c, filter_eps=filter_eps, &
264 72261 : retain_sparsity=retain_sparsity)
265 :
266 289044 : max_mm_dim = MAXLOC(dims, 1)
267 72261 : nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
268 72261 : nsplit_opt = nsplit
269 :
270 72261 : IF (unit_nr_prv > 0) THEN
271 : WRITE (unit_nr_prv, "(T2,A)") &
272 34 : "MM PARAMETERS"
273 34 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
274 68 : (nze_c + numproc - 1)/numproc
275 :
276 34 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
277 : END IF
278 :
279 169935 : ELSEIF (batched_repl > 0) THEN
280 31419 : nsplit = nsplit_batched
281 31419 : nsplit_opt = nsplit
282 31419 : max_mm_dim = max_mm_dim_batched
283 31419 : IF (unit_nr_prv > 0) THEN
284 : WRITE (unit_nr_prv, "(T2,A)") &
285 0 : "MM PARAMETERS"
286 0 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
287 : END IF
288 :
289 : ELSE
290 138516 : nsplit = 0
291 554064 : max_mm_dim = MAXLOC(dims, 1)
292 : END IF
293 :
294 : ! reshape matrices to the optimal layout and split factor
295 242196 : split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
296 70558 : SELECT CASE (max_mm_dim)
297 : CASE (1)
298 : ! split dimension is m: A and C are split by rows, B is reshaped as the small matrix
299 : split_a = rowsplit; split_c = rowsplit
300 : CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
301 : new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
302 : nsplit=nsplit, &
303 : opt_nsplit=batched_repl == 0, &
304 : split_rc_1=split_a, split_rc_2=split_c, &
305 : nodata2=nodata_3, comm_new=comm_tmp, &
306 70558 : move_data_1=move_a, unit_nr=unit_nr_prv)
307 :
308 70558 : info = dbt_tas_info(matrix_a_rs)
309 70558 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
310 :
311 70558 : new_b = .FALSE.
312 70558 : IF (matrix_b%do_batched <= 2) THEN
313 338710 : ALLOCATE (matrix_b_rs)
314 67742 : CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv, move_data=move_b)
315 67742 : transb_prv = .FALSE.
316 67742 : new_b = .TRUE.
317 : END IF
318 :
319 70558 : tr_case = transa_prv
320 :
321 141127 : IF (unit_nr_prv > 0) THEN
322 11 : IF (.NOT. tr_case) THEN
323 11 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
324 : ELSE
325 0 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
326 : END IF
327 : END IF
328 :
329 : CASE (2)
330 : ! split dimension is k: A is split by columns and B by rows, C is reshaped as the small matrix
331 73084 : split_a = colsplit; split_b = rowsplit
332 : CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
333 : optimize_dist=optimize_dist, &
334 : nsplit=nsplit, &
335 : opt_nsplit=batched_repl == 0, &
336 : split_rc_1=split_a, split_rc_2=split_b, &
337 : comm_new=comm_tmp, &
338 73084 : move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
339 :
340 73084 : info = dbt_tas_info(matrix_a_rs)
341 73084 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
342 : ! batched_beta: set to beta in batched state 1, multiplied by beta in later batched states
343 73084 : IF (matrix_c%do_batched == 1) THEN
344 28072 : matrix_c%mm_storage%batched_beta = beta
345 45012 : ELSEIF (matrix_c%do_batched > 1) THEN
346 10546 : matrix_c%mm_storage%batched_beta = matrix_c%mm_storage%batched_beta*beta
347 : END IF
348 :
349 73084 : IF (matrix_c%do_batched <= 2) THEN
350 318460 : ALLOCATE (matrix_c_rs)
351 63692 : CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv, nodata=nodata_3)
352 63692 : transc_prv = .FALSE.
353 :
354 : ! just leave sparsity structure for retain sparsity but no values
355 63692 : IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rs%matrix)
356 :
357 63692 : IF (matrix_c%do_batched >= 1) matrix_c%mm_storage%store_batched => matrix_c_rs
358 9392 : ELSEIF (matrix_c%do_batched == 3) THEN
359 9392 : matrix_c_rs => matrix_c%mm_storage%store_batched
360 : END IF
361 :
362 73084 : new_c = matrix_c%do_batched == 0
363 73084 : tr_case = transa_prv
364 :
365 146181 : IF (unit_nr_prv > 0) THEN
366 13 : IF (.NOT. tr_case) THEN
367 2 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
368 : ELSE
369 11 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
370 : END IF
371 : END IF
372 :
373 : CASE (3)
374 : ! split dimension is n: B and C are split by columns, A is reshaped as the small matrix
375 98554 : split_b = colsplit; split_c = colsplit
376 : CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
377 : transc_prv, optimize_dist=optimize_dist, &
378 : nsplit=nsplit, &
379 : opt_nsplit=batched_repl == 0, &
380 : split_rc_1=split_b, split_rc_2=split_c, &
381 : nodata2=nodata_3, comm_new=comm_tmp, &
382 98554 : move_data_1=move_b, unit_nr=unit_nr_prv)
383 98554 : info = dbt_tas_info(matrix_b_rs)
384 98554 : CALL dbt_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
385 :
386 98554 : new_a = .FALSE.
387 98554 : IF (matrix_a%do_batched <= 2) THEN
388 396715 : ALLOCATE (matrix_a_rs)
389 79343 : CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv, move_data=move_a)
390 79343 : transa_prv = .FALSE.
391 79343 : new_a = .TRUE.
392 : END IF
393 :
394 98554 : tr_case = transb_prv
395 :
396 439304 : IF (unit_nr_prv > 0) THEN
397 10 : IF (.NOT. tr_case) THEN
398 0 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
399 : ELSE
400 10 : WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
401 : END IF
402 : END IF
403 :
404 : END SELECT
405 : ! split factor and communicators actually used by the reshaped matrices
406 242196 : CALL dbt_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
407 :
408 242196 : numproc = mp_comm%num_pe
409 726588 : pdims_sub = mp_comm_group%num_pe_cart
410 : ! the process grid is only changed if the subgroup grid dimensions are not accepted
411 242196 : opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
412 :
413 242196 : IF (PRESENT(filter_eps)) THEN
414 191472 : filter_eps_prv = filter_eps
415 : ELSE
416 50724 : filter_eps_prv = 0.0_dp
417 : END IF
418 :
419 242196 : IF (unit_nr_prv /= 0) THEN
420 52224 : IF (unit_nr_prv > 0) THEN
421 34 : WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
422 : END IF
423 52224 : CALL dbt_tas_write_split_info(info, unit_nr_prv)
424 52224 : IF (ASSOCIATED(matrix_a_rs)) CALL dbt_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
425 52224 : IF (ASSOCIATED(matrix_b_rs)) CALL dbt_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
426 52224 : IF (ASSOCIATED(matrix_c_rs)) CALL dbt_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
427 52224 : IF (unit_nr_prv > 0) THEN
428 34 : IF (opt_pgrid) THEN
429 0 : WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
430 : ELSE
431 34 : WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
432 : END IF
433 : END IF
434 : END IF
435 : ! communicator for the DBM multiplication, created from the subgroup communicator
436 242196 : pdims = 0
437 242196 : CALL mp_comm_mm%create(mp_comm_group, 2, pdims)
438 :
439 : ! Convert DBM submatrices to optimized process grids and multiply
440 70558 : SELECT CASE (max_mm_dim)
441 : CASE (1)
442 70558 : IF (matrix_b%do_batched <= 2) THEN
443 338710 : ALLOCATE (matrix_b_rep)
444 67742 : CALL dbt_tas_replicate(matrix_b_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
445 67742 : IF (matrix_b%do_batched == 1 .OR. matrix_b%do_batched == 2) THEN
446 8082 : matrix_b%mm_storage%store_batched_repl => matrix_b_rep
447 8082 : CALL dbt_tas_set_batched_state(matrix_b, state=3)
448 : END IF
449 2816 : ELSEIF (matrix_b%do_batched == 3) THEN
450 2816 : matrix_b_rep => matrix_b%mm_storage%store_batched_repl
451 : END IF
452 :
453 70558 : IF (new_b) THEN
454 67742 : CALL dbt_tas_destroy(matrix_b_rs)
455 67742 : DEALLOCATE (matrix_b_rs)
456 : END IF
457 70558 : IF (unit_nr_prv /= 0) THEN
458 438 : CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
459 438 : CALL dbt_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
460 : END IF
461 :
462 70558 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
463 :
464 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
465 70558 : info_a = dbt_tas_info(matrix_a_rs)
466 70558 : CALL dbt_tas_info_hold(info_a)
467 :
468 70558 : IF (new_a) THEN
469 6768 : CALL dbt_tas_destroy(matrix_a_rs)
470 6768 : DEALLOCATE (matrix_a_rs)
471 : END IF
472 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
473 70558 : move_data=matrix_b%do_batched == 0)
474 :
475 70558 : info_b = dbt_tas_info(matrix_b_rep)
476 70558 : CALL dbt_tas_info_hold(info_b)
477 :
478 70558 : IF (matrix_b%do_batched == 0) THEN
479 59660 : CALL dbt_tas_destroy(matrix_b_rep)
480 59660 : DEALLOCATE (matrix_b_rep)
481 : END IF
482 :
483 70558 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
484 :
485 70558 : info_c = dbt_tas_info(matrix_c_rs)
486 70558 : CALL dbt_tas_info_hold(info_c)
487 :
488 70558 : CALL matrix_a%dist%info%mp_comm%sync()
489 70558 : CALL timeset("dbt_tas_dbm", handle4)
490 70558 : IF (.NOT. tr_case) THEN
491 64082 : CALL timeset("dbt_tas_mm_1N", handle3)
492 :
493 : CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
494 : matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
495 64082 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
496 64082 : CALL timestop(handle3)
497 : ELSE
498 6476 : CALL timeset("dbt_tas_mm_1T", handle3)
499 : CALL dbm_multiply(transa=.TRUE., transb=.FALSE., alpha=alpha, &
500 : matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
501 6476 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
502 :
503 6476 : CALL timestop(handle3)
504 : END IF
505 70558 : CALL matrix_a%dist%info%mp_comm%sync()
506 70558 : CALL timestop(handle4)
507 :
508 70558 : CALL dbm_release(matrix_a_mm)
509 70558 : CALL dbm_release(matrix_b_mm)
510 :
511 70558 : nze_c = dbm_get_nze(matrix_c_mm)
512 : ! if matrix_c_rs is a temporary (new_c), beta is applied to matrix_c at the end instead of here
513 70558 : IF (.NOT. new_c) THEN
514 64026 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
515 : ELSE
516 6532 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
517 : END IF
518 :
519 70558 : CALL dbm_release(matrix_c_mm)
520 :
521 70558 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
522 :
523 282670 : IF (unit_nr_prv /= 0) THEN
524 438 : CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
525 : END IF
526 :
527 : CASE (2)
528 73084 : IF (matrix_c%do_batched <= 1) THEN
529 312690 : ALLOCATE (matrix_c_rep)
530 62538 : CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
531 62538 : IF (matrix_c%do_batched == 1) THEN
532 28072 : matrix_c%mm_storage%store_batched_repl => matrix_c_rep
533 28072 : CALL dbt_tas_set_batched_state(matrix_c, state=3)
534 : END IF
535 10546 : ELSEIF (matrix_c%do_batched == 2) THEN
536 5770 : ALLOCATE (matrix_c_rep)
537 1154 : CALL dbt_tas_replicate(matrix_c_rs%matrix, dbt_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
538 : ! just leave sparsity structure for retain sparsity but no values
539 1154 : IF (.NOT. nodata_3) CALL dbm_zero(matrix_c_rep%matrix)
540 1154 : matrix_c%mm_storage%store_batched_repl => matrix_c_rep
541 1154 : CALL dbt_tas_set_batched_state(matrix_c, state=3)
542 9392 : ELSEIF (matrix_c%do_batched == 3) THEN
543 9392 : matrix_c_rep => matrix_c%mm_storage%store_batched_repl
544 : END IF
545 :
546 73084 : IF (unit_nr_prv /= 0) THEN
547 22760 : CALL dbt_tas_write_dist(matrix_a_rs, unit_nr_prv)
548 22760 : CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
549 : END IF
550 :
551 73084 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
552 :
553 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
554 73084 : info_a = dbt_tas_info(matrix_a_rs)
555 73084 : CALL dbt_tas_info_hold(info_a)
556 :
557 73084 : IF (new_a) THEN
558 614 : CALL dbt_tas_destroy(matrix_a_rs)
559 614 : DEALLOCATE (matrix_a_rs)
560 : END IF
561 :
562 73084 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
563 :
564 73084 : info_b = dbt_tas_info(matrix_b_rs)
565 73084 : CALL dbt_tas_info_hold(info_b)
566 :
567 73084 : IF (new_b) THEN
568 950 : CALL dbt_tas_destroy(matrix_b_rs)
569 950 : DEALLOCATE (matrix_b_rs)
570 : END IF
571 :
572 73084 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
573 :
574 73084 : info_c = dbt_tas_info(matrix_c_rep)
575 73084 : CALL dbt_tas_info_hold(info_c)
576 :
577 73084 : CALL matrix_a%dist%info%mp_comm%sync()
578 73084 : CALL timeset("dbt_tas_dbm", handle4)
579 73084 : CALL timeset("dbt_tas_mm_2", handle3)
580 : CALL dbm_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
581 : matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
582 73084 : filter_eps=filter_eps_prv/REAL(nsplit, KIND=dp), retain_sparsity=retain_sparsity, flop=flop)
583 73084 : CALL matrix_a%dist%info%mp_comm%sync()
584 73084 : CALL timestop(handle3)
585 73084 : CALL timestop(handle4)
586 :
587 73084 : CALL dbm_release(matrix_a_mm)
588 73084 : CALL dbm_release(matrix_b_mm)
589 :
590 73084 : nze_c = dbm_get_nze(matrix_c_mm)
591 :
592 73084 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
593 73084 : nze_c_sum = dbt_tas_get_nze_total(matrix_c_rep)
594 :
595 73084 : CALL dbm_release(matrix_c_mm)
596 :
597 73084 : IF (unit_nr_prv /= 0) THEN
598 22760 : CALL dbt_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
599 : END IF
600 :
601 73084 : IF (matrix_c%do_batched == 0) THEN
602 34466 : CALL dbt_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
603 : ELSE
604 38618 : matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbt_tas_batched_mm_finalize
605 : END IF
606 :
607 73084 : IF (matrix_c%do_batched == 0) THEN
608 34466 : CALL dbt_tas_destroy(matrix_c_rep)
609 34466 : DEALLOCATE (matrix_c_rep)
610 : END IF
611 :
612 73084 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
613 :
614 : ! set upper limit on memory consumption for replicated matrix and complete batched mm
615 : ! if limit is exceeded
616 367212 : IF (nze_c_sum > default_nsplit_accept_ratio*MAX(nze_a, nze_b)) THEN
617 1792 : CALL dbt_tas_batched_mm_complete(matrix_c)
618 : END IF
619 :
620 : CASE (3)
621 98554 : IF (matrix_a%do_batched <= 2) THEN
622 396715 : ALLOCATE (matrix_a_rep)
623 79343 : CALL dbt_tas_replicate(matrix_a_rs%matrix, dbt_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
624 79343 : IF (matrix_a%do_batched == 1 .OR. matrix_a%do_batched == 2) THEN
625 27705 : matrix_a%mm_storage%store_batched_repl => matrix_a_rep
626 27705 : CALL dbt_tas_set_batched_state(matrix_a, state=3)
627 : END IF
628 19211 : ELSEIF (matrix_a%do_batched == 3) THEN
629 19211 : matrix_a_rep => matrix_a%mm_storage%store_batched_repl
630 : END IF
631 :
632 98554 : IF (new_a) THEN
633 79343 : CALL dbt_tas_destroy(matrix_a_rs)
634 79343 : DEALLOCATE (matrix_a_rs)
635 : END IF
636 98554 : IF (unit_nr_prv /= 0) THEN
637 29026 : CALL dbt_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
638 29026 : CALL dbt_tas_write_dist(matrix_b_rs, unit_nr_prv)
639 : END IF
640 :
641 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
642 98554 : move_data=matrix_a%do_batched == 0)
643 :
644 : ! keep communicators alive even after releasing TAS matrices (communicator management does not work between DBM and TAS)
645 98554 : info_a = dbt_tas_info(matrix_a_rep)
646 98554 : CALL dbt_tas_info_hold(info_a)
647 :
648 98554 : IF (matrix_a%do_batched == 0) THEN
649 51638 : CALL dbt_tas_destroy(matrix_a_rep)
650 51638 : DEALLOCATE (matrix_a_rep)
651 : END IF
652 :
653 98554 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
654 :
655 98554 : info_b = dbt_tas_info(matrix_b_rs)
656 98554 : CALL dbt_tas_info_hold(info_b)
657 :
658 98554 : IF (new_b) THEN
659 16 : CALL dbt_tas_destroy(matrix_b_rs)
660 16 : DEALLOCATE (matrix_b_rs)
661 : END IF
662 98554 : CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
663 :
664 98554 : info_c = dbt_tas_info(matrix_c_rs)
665 98554 : CALL dbt_tas_info_hold(info_c)
666 :
667 98554 : CALL matrix_a%dist%info%mp_comm%sync()
668 98554 : CALL timeset("dbt_tas_dbm", handle4)
669 98554 : IF (.NOT. tr_case) THEN
670 44952 : CALL timeset("dbt_tas_mm_3N", handle3)
671 : CALL dbm_multiply(transa=.FALSE., transb=.FALSE., alpha=alpha, &
672 : matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
673 44952 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
674 44952 : CALL timestop(handle3)
675 : ELSE
676 53602 : CALL timeset("dbt_tas_mm_3T", handle3)
677 : CALL dbm_multiply(transa=.FALSE., transb=.TRUE., alpha=alpha, &
678 : matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
679 53602 : filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
680 53602 : CALL timestop(handle3)
681 : END IF
682 98554 : CALL matrix_a%dist%info%mp_comm%sync()
683 98554 : CALL timestop(handle4)
684 :
685 98554 : CALL dbm_release(matrix_a_mm)
686 98554 : CALL dbm_release(matrix_b_mm)
687 :
688 98554 : nze_c = dbm_get_nze(matrix_c_mm)
689 : ! if matrix_c_rs is a temporary (new_c), beta is applied to matrix_c at the end instead of here
690 98554 : IF (.NOT. new_c) THEN
691 88380 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=beta)
692 : ELSE
693 10174 : CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, local_copy=.NOT. opt_pgrid, alpha=1.0_dp)
694 : END IF
695 :
696 98554 : CALL dbm_release(matrix_c_mm)
697 :
698 98554 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c_rs, filter_eps)
699 :
700 636412 : IF (unit_nr_prv /= 0) THEN
701 29026 : CALL dbt_tas_write_dist(matrix_c_rs, unit_nr_prv)
702 : END IF
703 : END SELECT
704 :
705 242196 : CALL mp_comm_mm%free()
706 :
707 242196 : CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm)
708 : ! optionally re-estimate the split factor from the nonzeros of the result and return it in split_opt
709 242196 : IF (PRESENT(split_opt)) THEN
710 112845 : SELECT CASE (max_mm_dim)
711 : CASE (1, 3)
712 112845 : CALL mp_comm%sum(nze_c)
713 : CASE (2)
714 57160 : CALL dbt_tas_get_split_info(info_c, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
715 57160 : CALL mp_comm%sum(nze_c)
716 227165 : CALL mp_comm%max(nze_c)
717 :
718 : END SELECT
719 170005 : nsplit_opt = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
720 : ! ideally we should rederive the split factor from the actual sparsity of C, but
721 : ! due to parameter beta, we can not get the sparsity of AxB from DBM if not new_c
722 170005 : mp_comm_opt = dbt_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
723 170005 : CALL dbt_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
724 170005 : IF (unit_nr_prv > 0) THEN
725 : WRITE (unit_nr_prv, "(T2,A)") &
726 10 : "MM PARAMETERS"
727 10 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Number of matrix elements per CPU of result matrix:", &
728 20 : (nze_c + numproc - 1)/numproc
729 :
730 10 : WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Optimal split factor:", nsplit_opt
731 : END IF
732 :
733 : END IF
734 : ! new_c: scale matrix_c by beta and add the reshaped result; matrix_c_rs is a temporary and is freed
735 242196 : IF (new_c) THEN
736 51172 : CALL dbm_scale(matrix_c%matrix, beta)
737 : CALL dbt_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., &
738 : transposed=(transc_prv .NEQV. transc), &
739 51172 : move_data=.TRUE.)
740 51172 : CALL dbt_tas_destroy(matrix_c_rs)
741 51172 : DEALLOCATE (matrix_c_rs)
742 51172 : IF (PRESENT(filter_eps)) CALL dbt_tas_filter(matrix_c, filter_eps)
743 191024 : ELSEIF (matrix_c%do_batched > 0) THEN
744 38924 : IF (matrix_c%mm_storage%batched_out) THEN
745 38618 : matrix_c%mm_storage%batched_trans = (transc_prv .NEQV. transc)
746 : END IF
747 : END IF
748 : ! empty the input matrices if the caller requested their data to be moved
749 242196 : IF (PRESENT(move_data_a)) THEN
750 242148 : IF (move_data_a) CALL dbt_tas_clear(matrix_a)
751 : END IF
752 242196 : IF (PRESENT(move_data_b)) THEN
753 242148 : IF (move_data_b) CALL dbt_tas_clear(matrix_b)
754 : END IF
755 : ! flop: sum over all ranks of mp_comm, then divide by the number of ranks (rounded up)
756 242196 : IF (PRESENT(flop)) THEN
757 131704 : CALL mp_comm%sum(flop)
758 131704 : flop = (flop + numproc - 1)/numproc
759 : END IF
760 :
761 242196 : IF (PRESENT(optimize_dist)) THEN
762 48 : IF (optimize_dist) CALL comm_tmp%free()
763 : END IF
764 242196 : IF (unit_nr_prv > 0) THEN
765 34 : WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
766 34 : WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
767 34 : WRITE (unit_nr_prv, '(A)') REPEAT("-", 80)
768 : END IF
769 : ! release the split infos that were held above to keep the communicators alive
770 242196 : CALL dbt_tas_release_info(info_a)
771 242196 : CALL dbt_tas_release_info(info_b)
772 242196 : CALL dbt_tas_release_info(info_c)
773 :
774 242196 : CALL matrix_a%dist%info%mp_comm%sync()
775 242196 : CALL timestop(handle2)
776 242196 : CALL timestop(handle)
777 :
778 484392 : END SUBROUTINE dbt_tas_multiply
779 :
780 : ! **************************************************************************************************
781 : !> \brief Scale matrix_out by alpha and add matrix_in to it, redistributing matrix_in first if needed.
782 : !> \param matrix_in matrix that is added to matrix_out
783 : !> \param matrix_out matrix that is scaled by alpha (if alpha /= 1) and then accumulates matrix_in
784 : !> \param local_copy if .TRUE., add matrix_in directly without dbm_redistribute (default .FALSE.)
785 : !> \param alpha scaling factor applied to matrix_out before the addition
786 : !> \author Patrick Seewald
787 : ! **************************************************************************************************
788 242196 : SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, local_copy, alpha)
789 : TYPE(dbm_type), INTENT(IN) :: matrix_in
790 : TYPE(dbm_type), INTENT(INOUT) :: matrix_out
791 : LOGICAL, INTENT(IN), OPTIONAL :: local_copy
792 : REAL(dp), INTENT(IN) :: alpha
793 :
794 : LOGICAL :: local_copy_prv
795 : TYPE(dbm_type) :: matrix_tmp
796 :
797 242196 : IF (PRESENT(local_copy)) THEN
798 242196 : local_copy_prv = local_copy
799 : ELSE
800 : local_copy_prv = .FALSE.
801 : END IF
802 :
803 242196 : IF (alpha /= 1.0_dp) THEN
804 151490 : CALL dbm_scale(matrix_out, alpha)
805 : END IF
806 :
807 242196 : IF (.NOT. local_copy_prv) THEN
808 0 : CALL dbm_create_from_template(matrix_tmp, name="tmp", template=matrix_out)
809 0 : CALL dbm_redistribute(matrix_in, matrix_tmp)
810 0 : CALL dbm_add(matrix_out, matrix_tmp)
811 0 : CALL dbm_release(matrix_tmp)
812 : ELSE
813 242196 : CALL dbm_add(matrix_out, matrix_in)
814 : END IF
815 :
816 242196 : END SUBROUTINE redistribute_and_sum
817 :
818 : ! **************************************************************************************************
819 : !> \brief Make sure that smallest matrix involved in a multiplication is not split and bring it to
820 : !> the same process grid as the other 2 matrices.
821 : !> \param mp_comm communicator that defines Cartesian topology
822 : !> \param matrix_in matrix to be reshaped; provides name, block sizes and (unless nodata) data
823 : !> \param matrix_out new matrix created on mp_comm with default distributions and nosplit=.TRUE.
824 : !> \param transposed Whether matrix_out should be transposed (dims and block sizes of matrix_in swapped)
825 : !> \param nodata Data of matrix_in should not be copied to matrix_out (default .FALSE.)
826 : !> \param move_data memory optimization: move data such that matrix_in is empty on return.
827 : !> \author Patrick Seewald
828 : ! **************************************************************************************************
829 1475439 : SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, nodata, move_data)
830 : TYPE(mp_cart_type), INTENT(IN) :: mp_comm
831 : TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
832 : TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
833 : LOGICAL, INTENT(IN) :: transposed
834 : LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
835 :
836 : CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_small'
837 :
838 : INTEGER :: handle
839 : INTEGER(KIND=int_8), DIMENSION(2) :: dims
840 : INTEGER, DIMENSION(2) :: pdims
841 : LOGICAL :: nodata_prv
842 210777 : TYPE(dbt_tas_dist_arb) :: new_col_dist, new_row_dist
843 1053885 : TYPE(dbt_tas_distribution_type) :: dist
844 :
845 210777 : CALL timeset(routineN, handle)
846 :
847 210777 : IF (PRESENT(nodata)) THEN
848 63692 : nodata_prv = nodata
849 : ELSE
850 : nodata_prv = .FALSE.
851 : END IF
852 :
853 632331 : pdims = mp_comm%num_pe_cart
854 :
855 632331 : dims = [dbt_tas_nblkrows_total(matrix_in), dbt_tas_nblkcols_total(matrix_in)]
856 :
857 210777 : IF (transposed) CALL swap(dims)
858 :
859 210777 : IF (.NOT. transposed) THEN
860 143955 : new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
861 143955 : new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
862 143955 : CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
863 : CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
864 143955 : matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
865 : ELSE
866 66822 : new_row_dist = dbt_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
867 66822 : new_col_dist = dbt_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
868 66822 : CALL dbt_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
869 : CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist, &
870 66822 : matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
871 : END IF
872 210777 : IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
873 :
874 210777 : CALL timestop(handle)
875 :
876 210777 : END SUBROUTINE reshape_mm_small
877 :
878 : ! **************************************************************************************************
879 : !> \brief Reshape either matrix1 or matrix2 to make sure that their process grids are compatible
880 : !> with the same split factor.
881 : !> \param matrix1_in first matrix
882 : !> \param matrix2_in second matrix
883 : !> \param matrix1_out points to matrix1_in or to a newly allocated matrix (see new1)
884 : !> \param matrix2_out points to matrix2_in or to a newly allocated matrix (see new2)
885 : !> \param new1 Whether matrix1_out is a new matrix or simply pointing to matrix1_in
886 : !> \param new2 Whether matrix2_out is a new matrix or simply pointing to matrix2_in
887 : !> \param trans1 transpose flag of matrix1_in for multiplication (negated if matrix1_out is transposed)
888 : !> \param trans2 transpose flag of matrix2_in for multiplication (negated if matrix2_out is transposed)
889 : !> \param optimize_dist experimental: optimize matrix splitting and distribution
890 : !> \param nsplit Optimal split factor (set to 0 if split factor should not be changed)
891 : !> \param opt_nsplit forwarded to change_split (whether nsplit should be optimized for the process grid)
892 : !> \param split_rc_1 Whether to split rows or columns for matrix 1 (swapped on entry if trans1 is set)
893 : !> \param split_rc_2 Whether to split rows or columns for matrix 2 (swapped on entry if trans2 is set)
894 : !> \param nodata1 Don't copy matrix data from matrix1_in to matrix1_out
895 : !> \param nodata2 Don't copy matrix data from matrix2_in to matrix2_out
896 : !> \param move_data_1 memory optimization: matrix1_in may be empty on return. Set to .TRUE. if new1.
897 : !> \param move_data_2 memory optimization: matrix2_in may be empty on return. Set to .TRUE. if new2.
898 : !> \param comm_new returns the new communicator only if optimize_dist
899 : !> \param unit_nr output unit
900 : !> \author Patrick Seewald
901 : ! **************************************************************************************************
902 242196 : SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
903 : optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
904 : move_data_1, move_data_2, comm_new, unit_nr)
905 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix1_in, matrix2_in
906 : TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix1_out, matrix2_out
907 : LOGICAL, INTENT(OUT) :: new1, new2
908 : LOGICAL, INTENT(INOUT) :: trans1, trans2
909 : LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
910 : INTEGER, INTENT(IN), OPTIONAL :: nsplit
911 : LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
912 : INTEGER, INTENT(INOUT) :: split_rc_1, split_rc_2
913 : LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
914 : LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
915 : TYPE(mp_cart_type), INTENT(OUT), OPTIONAL :: comm_new
916 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
917 :
918 : CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_compatible'
919 :
920 : INTEGER :: handle, nsplit_prv, ref, split_rc_ref, &
921 : unit_nr_prv
922 : INTEGER(KIND=int_8) :: d1, d2, nze1, nze2
923 : INTEGER(KIND=int_8), DIMENSION(2) :: dims1, dims2, dims_ref
924 : INTEGER, DIMENSION(2) :: pdims
925 : LOGICAL :: nodata1_prv, nodata2_prv, &
926 : optimize_dist_prv, trans1_newdist, &
927 : trans2_newdist
928 : TYPE(dbt_tas_dist_cyclic) :: col_dist_1, col_dist_2, row_dist_1, &
929 : row_dist_2
930 2179764 : TYPE(dbt_tas_distribution_type) :: dist_1, dist_2
931 1210980 : TYPE(dbt_tas_split_info) :: split_info
932 242196 : TYPE(mp_cart_type) :: mp_comm
933 :
934 242196 : CALL timeset(routineN, handle)
935 242196 : new1 = .FALSE.; new2 = .FALSE.
936 :
937 242196 : IF (PRESENT(nodata1)) THEN
938 0 : nodata1_prv = nodata1
939 : ELSE
940 : nodata1_prv = .FALSE.
941 : END IF
942 :
943 242196 : IF (PRESENT(nodata2)) THEN
944 169112 : nodata2_prv = nodata2
945 : ELSE
946 : nodata2_prv = .FALSE.
947 : END IF
948 :
949 242196 : unit_nr_prv = prep_output_unit(unit_nr)
950 :
951 242196 : NULLIFY (matrix1_out, matrix2_out)
952 :
953 242196 : IF (PRESENT(optimize_dist)) THEN
954 48 : optimize_dist_prv = optimize_dist
955 : ELSE
956 : optimize_dist_prv = .FALSE.
957 : END IF
958 :
959 726588 : dims1 = [dbt_tas_nblkrows_total(matrix1_in), dbt_tas_nblkcols_total(matrix1_in)]
960 726588 : dims2 = [dbt_tas_nblkrows_total(matrix2_in), dbt_tas_nblkcols_total(matrix2_in)]
961 242196 : nze1 = dbt_tas_get_nze_total(matrix1_in)
962 242196 : nze2 = dbt_tas_get_nze_total(matrix2_in)
963 :
964 242196 : IF (trans1) split_rc_1 = MOD(split_rc_1, 2) + 1
965 :
966 242196 : IF (trans2) split_rc_2 = MOD(split_rc_2, 2) + 1
967 :
968 242196 : IF (nze1 >= nze2) THEN
969 225251 : ref = 1
970 225251 : split_rc_ref = split_rc_1
971 225251 : dims_ref = dims1
972 : ELSE
973 16945 : ref = 2
974 16945 : split_rc_ref = split_rc_2
975 16945 : dims_ref = dims2
976 : END IF
977 :
978 242196 : IF (PRESENT(nsplit)) THEN
979 242196 : nsplit_prv = nsplit
980 : ELSE
981 0 : nsplit_prv = 0
982 : END IF
983 :
984 242196 : IF (optimize_dist_prv) THEN
985 48 : CPASSERT(PRESENT(comm_new))
986 : END IF
987 :
988 242148 : IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
989 : CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
990 223682 : move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
991 223682 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_out), nsplit=nsplit_prv)
992 : CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
993 223682 : move_data=move_data_2, nodata=nodata2, opt_nsplit=.FALSE.)
994 223682 : IF (unit_nr_prv > 0) THEN
995 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", &
996 10 : TRIM(dbm_get_name(matrix1_in%matrix)), &
997 20 : "and", TRIM(dbm_get_name(matrix2_in%matrix))
998 10 : IF (new1) THEN
999 0 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1000 0 : TRIM(dbm_get_name(matrix1_in%matrix)), ": Yes"
1001 : ELSE
1002 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1003 20 : TRIM(dbm_get_name(matrix1_in%matrix)), ": No"
1004 : END IF
1005 10 : IF (new2) THEN
1006 0 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1007 0 : TRIM(dbm_get_name(matrix2_in%matrix)), ": Yes"
1008 : ELSE
1009 10 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", &
1010 20 : TRIM(dbm_get_name(matrix2_in%matrix)), ": No"
1011 : END IF
1012 : END IF
1013 : ELSE
1014 :
1015 18466 : IF (optimize_dist_prv) THEN
1016 48 : IF (unit_nr_prv > 0) THEN
1017 24 : WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", &
1018 24 : TRIM(dbm_get_name(matrix1_in%matrix)), &
1019 48 : "and", TRIM(dbm_get_name(matrix2_in%matrix))
1020 : END IF
1021 :
1022 48 : trans1_newdist = (split_rc_1 == colsplit)
1023 48 : trans2_newdist = (split_rc_2 == colsplit)
1024 :
1025 48 : IF (trans1_newdist) THEN
1026 24 : CALL swap(dims1)
1027 24 : trans1 = .NOT. trans1
1028 : END IF
1029 :
1030 48 : IF (trans2_newdist) THEN
1031 24 : CALL swap(dims2)
1032 24 : trans2 = .NOT. trans2
1033 : END IF
1034 :
1035 48 : IF (nsplit_prv == 0) THEN
1036 0 : SELECT CASE (split_rc_ref)
1037 : CASE (rowsplit)
1038 0 : d1 = dims_ref(1)
1039 0 : d2 = dims_ref(2)
1040 : CASE (colsplit)
1041 0 : d1 = dims_ref(2)
1042 0 : d2 = dims_ref(1)
1043 : END SELECT
1044 0 : nsplit_prv = INT((d1 - 1)/d2 + 1)
1045 : END IF
1046 :
1047 48 : CPASSERT(nsplit_prv > 0)
1048 :
1049 48 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix1_in), mp_comm=mp_comm)
1050 48 : comm_new = dbt_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
1051 48 : CALL dbt_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
1052 :
1053 144 : pdims = comm_new%num_pe_cart
1054 :
1055 : ! use a very simple cyclic distribution that may not be load balanced if block
1056 : ! sizes are not equal. However we can not use arbitrary distributions
1057 : ! for large dimensions since this would require storing distribution vectors as arrays
1058 : ! which can not be stored for large dimensions.
1059 48 : row_dist_1 = dbt_tas_dist_cyclic(1, pdims(1), dims1(1))
1060 48 : col_dist_1 = dbt_tas_dist_cyclic(1, pdims(2), dims1(2))
1061 :
1062 48 : row_dist_2 = dbt_tas_dist_cyclic(1, pdims(1), dims2(1))
1063 48 : col_dist_2 = dbt_tas_dist_cyclic(1, pdims(2), dims2(2))
1064 :
1065 48 : CALL dbt_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
1066 48 : CALL dbt_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
1067 48 : CALL dbt_tas_release_info(split_info)
1068 :
1069 240 : ALLOCATE (matrix1_out)
1070 48 : IF (.NOT. trans1_newdist) THEN
1071 : CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1072 24 : matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)
1073 :
1074 : ELSE
1075 : CALL dbt_tas_create(matrix1_out, dbm_get_name(matrix1_in%matrix), dist_1, &
1076 24 : matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
1077 : END IF
1078 :
1079 240 : ALLOCATE (matrix2_out)
1080 48 : IF (.NOT. trans2_newdist) THEN
1081 : CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1082 24 : matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
1083 : ELSE
1084 : CALL dbt_tas_create(matrix2_out, dbm_get_name(matrix2_in%matrix), dist_2, &
1085 24 : matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
1086 : END IF
1087 :
1088 48 : IF (.NOT. nodata1_prv) CALL dbt_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
1089 48 : IF (.NOT. nodata2_prv) CALL dbt_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
1090 48 : new1 = .TRUE.
1091 48 : new2 = .TRUE.
1092 :
1093 : ELSE
1094 17584 : SELECT CASE (ref)
1095 : CASE (1)
1096 17584 : IF (unit_nr_prv > 0) THEN
1097 0 : WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1098 0 : TRIM(dbm_get_name(matrix2_in%matrix))
1099 : END IF
1100 :
1101 : CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1102 17584 : move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1103 :
1104 87920 : ALLOCATE (matrix2_out)
1105 : CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1106 17584 : nodata=nodata2, move_data=move_data_2)
1107 17584 : new2 = .TRUE.
1108 : CASE (2)
1109 882 : IF (unit_nr_prv > 0) THEN
1110 0 : WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", &
1111 0 : TRIM(dbm_get_name(matrix1_in%matrix))
1112 : END IF
1113 :
1114 : CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1115 882 : move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1116 :
1117 4410 : ALLOCATE (matrix1_out)
1118 : CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1119 882 : nodata=nodata1, move_data=move_data_1)
1120 37814 : new1 = .TRUE.
1121 : END SELECT
1122 : END IF
1123 : END IF
1124 :
1125 242196 : IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
1126 242196 : IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.
1127 :
1128 242196 : CALL timestop(handle)
1129 :
1130 726588 : END SUBROUTINE reshape_mm_compatible
1131 :
1132 : ! **************************************************************************************************
1133 : !> \brief Change split factor without redistribution
1134 : !> \param matrix_in matrix whose split is to be changed
1135 : !> \param matrix_out points to matrix_in if the split is unchanged, else to a newly allocated matrix
1136 : !> \param nsplit new split factor, 0 keeps the split of matrix_in (1 is used if split_rowcol differs)
1137 : !> \param split_rowcol split rows or columns
1138 : !> \param is_new whether matrix_out is new or a pointer to matrix_in
1139 : !> \param opt_nsplit whether nsplit should be optimized for current process grid
1140 : !> \param move_data memory optimization: clear matrix_in after copy. Set to .TRUE. if data was copied.
1141 : !> \param nodata Data of matrix_in should not be copied to matrix_out
1142 : !> \author Patrick Seewald
1143 : ! **************************************************************************************************
1144 465830 : SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
1145 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_in
1146 : TYPE(dbt_tas_type), INTENT(OUT), POINTER :: matrix_out
1147 : INTEGER, INTENT(IN) :: nsplit, split_rowcol
1148 : LOGICAL, INTENT(OUT) :: is_new
1149 : LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
1150 : LOGICAL, INTENT(INOUT), OPTIONAL :: move_data
1151 : LOGICAL, INTENT(IN), OPTIONAL :: nodata
1152 :
1153 : CHARACTER(len=default_string_length) :: name
1154 : INTEGER :: handle, nsplit_new, nsplit_old, &
1155 : nsplit_prv, split_rc
1156 : LOGICAL :: nodata_prv
1157 2329150 : TYPE(dbt_tas_distribution_type) :: dist
1158 2329150 : TYPE(dbt_tas_split_info) :: split_info
1159 465830 : TYPE(mp_cart_type) :: mp_comm
1160 :
1161 1863320 : CLASS(dbt_tas_distribution), ALLOCATABLE :: rdist, cdist
1162 931660 : CLASS(dbt_tas_rowcol_data), ALLOCATABLE :: rbsize, cbsize
1163 : CHARACTER(LEN=*), PARAMETER :: routineN = 'change_split'
1164 :
1165 465830 : NULLIFY (matrix_out)
1166 :
1167 465830 : is_new = .TRUE.
1168 :
1169 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix_in), mp_comm=mp_comm, &
1170 465830 : split_rowcol=split_rc, nsplit=nsplit_old)
1171 :
1172 465830 : IF (nsplit == 0) THEN
1173 138516 : IF (split_rowcol == split_rc) THEN
1174 135116 : matrix_out => matrix_in
1175 135116 : is_new = .FALSE.
1176 135116 : RETURN
1177 : ELSE
1178 3400 : nsplit_prv = 1
1179 : END IF
1180 : ELSE
1181 327314 : nsplit_prv = nsplit
1182 : END IF
1183 :
1184 330714 : CALL timeset(routineN, handle)
1185 :
1186 330714 : nodata_prv = .FALSE.
1187 330714 : IF (PRESENT(nodata)) nodata_prv = nodata
1188 :
1189 : CALL dbt_tas_get_info(matrix_in, name=name, &
1190 : row_blk_size=rbsize, col_blk_size=cbsize, &
1191 : proc_row_dist=rdist, proc_col_dist=cdist)
1192 :
1193 330714 : CALL dbt_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
1194 :
1195 330714 : CALL dbt_tas_get_split_info(split_info, nsplit=nsplit_new)
1196 :
1197 330714 : IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
1198 324222 : matrix_out => matrix_in
1199 324222 : is_new = .FALSE.
1200 324222 : CALL dbt_tas_release_info(split_info)
1201 324222 : CALL timestop(handle)
1202 324222 : RETURN
1203 : END IF
1204 :
1205 : CALL dbt_tas_distribution_new(dist, mp_comm, rdist, cdist, &
1206 6492 : split_info=split_info)
1207 :
1208 6492 : CALL dbt_tas_release_info(split_info)
1209 :
1210 32460 : ALLOCATE (matrix_out)
1211 6492 : CALL dbt_tas_create(matrix_out, name, dist, rbsize, cbsize, own_dist=.TRUE.)
1212 :
1213 6492 : IF (.NOT. nodata_prv) CALL dbt_tas_copy(matrix_out, matrix_in)
1214 :
1215 6492 : IF (PRESENT(move_data)) THEN
1216 6492 : IF (.NOT. nodata_prv) THEN
1217 6492 : IF (move_data) CALL dbt_tas_clear(matrix_in)
1218 6492 : move_data = .TRUE.
1219 : END IF
1220 : END IF
1221 :
1222 6492 : CALL timestop(handle)
1223 1656114 : END SUBROUTINE change_split
1224 :
1225 : ! **************************************************************************************************
1226 : !> \brief Check whether matrices have same distribution and same split.
1227 : !> \param mat_a first matrix
1228 : !> \param mat_b second matrix
1229 : !> \param split_rc_a required split dimension of mat_a; must match its current split and split_rc_b
1230 : !> \param split_rc_b required split dimension of mat_b; must match its current split and split_rc_a
1231 : !> \param unit_nr output unit for diagnostics on why the matrices are not compatible
1232 : !> \return .TRUE. if split dimensions, process grid dimensions and local rows/columns agree on all ranks
1233 : !> \author Patrick Seewald
1234 : ! **************************************************************************************************
1235 242148 : FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
1236 : TYPE(dbt_tas_type), INTENT(IN) :: mat_a, mat_b
1237 : INTEGER, INTENT(IN) :: split_rc_a, split_rc_b
1238 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1239 : LOGICAL :: dist_compatible
1240 :
1241 : INTEGER :: numproc, same_local_rowcols, &
1242 : split_check_a, split_check_b, &
1243 : unit_nr_prv
1244 242148 : INTEGER(int_8), ALLOCATABLE, DIMENSION(:) :: local_rowcols_a, local_rowcols_b
1245 : INTEGER, DIMENSION(2) :: pdims_a, pdims_b
1246 2179332 : TYPE(dbt_tas_split_info) :: info_a, info_b
1247 :
1248 242148 : unit_nr_prv = prep_output_unit(unit_nr)
1249 :
1250 242148 : dist_compatible = .FALSE.
1251 :
1252 242148 : info_a = dbt_tas_info(mat_a)
1253 242148 : info_b = dbt_tas_info(mat_b)
1254 242148 : CALL dbt_tas_get_split_info(info_a, split_rowcol=split_check_a)
1255 242148 : CALL dbt_tas_get_split_info(info_b, split_rowcol=split_check_b)
1256 242148 : IF (split_check_b /= split_rc_b .OR. split_check_a /= split_rc_a .OR. split_rc_a /= split_rc_b) THEN
1257 18386 : IF (unit_nr_prv > 0) THEN
1258 0 : WRITE (unit_nr_prv, *) "matrix layout a not compatible", split_check_a, split_rc_a
1259 0 : WRITE (unit_nr_prv, *) "matrix layout b not compatible", split_check_b, split_rc_b
1260 : END IF
1261 18442 : RETURN
1262 : END IF
1263 :
1264 : ! check if communicators are equivalent
1265 : ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
1266 : ! It's sufficient to check dimensions of global grid, subgrids will be determined later on (change_split)
1267 223762 : numproc = info_b%mp_comm%num_pe
1268 671286 : pdims_a = info_a%mp_comm%num_pe_cart
1269 671286 : pdims_b = info_b%mp_comm%num_pe_cart
1270 223762 : IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
1271 56 : IF (unit_nr_prv > 0) THEN
1272 0 : WRITE (unit_nr_prv, *) "mp dims not compatible:", pdims_a, "|", pdims_b
1273 : END IF
1274 56 : RETURN
1275 : END IF
1276 :
1277 : ! check that distribution is the same by comparing local rows / columns for each matrix
1278 154530 : SELECT CASE (split_rc_a)
1279 : CASE (rowsplit)
1280 154530 : CALL dbt_tas_get_info(mat_a, local_rows=local_rowcols_a)
1281 154530 : CALL dbt_tas_get_info(mat_b, local_rows=local_rowcols_b)
1282 : CASE (colsplit)
1283 69176 : CALL dbt_tas_get_info(mat_a, local_cols=local_rowcols_a)
1284 292882 : CALL dbt_tas_get_info(mat_b, local_cols=local_rowcols_b)
1285 : END SELECT
1286 :
1287 223706 : same_local_rowcols = MERGE(1, 0, array_eq(local_rowcols_a, local_rowcols_b))
1288 :
1289 223706 : CALL info_a%mp_comm%sum(same_local_rowcols)
1290 :
1291 223706 : IF (same_local_rowcols == numproc) THEN
1292 : dist_compatible = .TRUE.
1293 : ELSE
1294 24 : IF (unit_nr_prv > 0) THEN
1295 0 : WRITE (unit_nr_prv, *) "local rowcols not compatible"
1296 0 : WRITE (unit_nr_prv, *) "local rowcols A", local_rowcols_a
1297 0 : WRITE (unit_nr_prv, *) "local rowcols B", local_rowcols_b
1298 : END IF
1299 : END IF
1300 :
1301 484296 : END FUNCTION dist_compatible
1302 :
1303 : ! **************************************************************************************************
1304 : !> \brief Reshape matrix_in s.t. it has same process grid, distribution and split as template
1305 : !> \param template matrix providing process grid, split info and distribution along its split dimension
1306 : !> \param matrix_in matrix to be reshaped; provides name, block sizes and (unless nodata) data
1307 : !> \param matrix_out newly created matrix, transposed w.r.t. matrix_in if split_rc differs from template
1308 : !> \param trans transpose flag of matrix_in; negated on return if matrix_out is created transposed
1309 : !> \param split_rc split dimension requested for matrix_in, compared with the template's split_rowcol
1310 : !> \param nodata if .TRUE., data of matrix_in is not copied to matrix_out (default .FALSE.)
1311 : !> \param move_data forwarded to dbt_tas_reshape
1312 : !> \author Patrick Seewald
1313 : ! **************************************************************************************************
1314 129262 : SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
1315 : TYPE(dbt_tas_type), INTENT(IN) :: template
1316 : TYPE(dbt_tas_type), INTENT(INOUT) :: matrix_in
1317 : TYPE(dbt_tas_type), INTENT(OUT) :: matrix_out
1318 : LOGICAL, INTENT(INOUT) :: trans
1319 : INTEGER, INTENT(IN) :: split_rc
1320 : LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
1321 :
1322 18466 : CLASS(dbt_tas_distribution), ALLOCATABLE :: row_dist, col_dist
1323 :
1324 110796 : TYPE(dbt_tas_distribution_type) :: dist_new
1325 203126 : TYPE(dbt_tas_split_info) :: info_template, info_matrix
1326 : INTEGER :: dim_split_template, dim_split_matrix, &
1327 : handle
1328 : INTEGER, DIMENSION(2) :: pdims
1329 : LOGICAL :: nodata_prv, transposed
1330 18466 : TYPE(mp_cart_type) :: mp_comm
1331 : CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template'
1332 :
1333 18466 : CALL timeset(routineN, handle)
1334 :
1335 18466 : IF (PRESENT(nodata)) THEN
1336 16674 : nodata_prv = nodata
1337 : ELSE
1338 : nodata_prv = .FALSE.
1339 : END IF
1340 :
1341 18466 : info_template = dbt_tas_info(template)
1342 18466 : info_matrix = dbt_tas_info(matrix_in)
1343 :
1344 18466 : dim_split_template = info_template%split_rowcol
1345 18466 : dim_split_matrix = split_rc
1346 :
1347 18466 : transposed = dim_split_template /= dim_split_matrix
1348 18466 : IF (transposed) trans = .NOT. trans
1349 :
1350 55398 : pdims = info_template%mp_comm%num_pe_cart
1351 :
1352 11008 : SELECT CASE (dim_split_template)
1353 : CASE (1)
1354 11008 : IF (.NOT. transposed) THEN
1355 44 : ALLOCATE (row_dist, source=template%dist%row_dist)
1356 44 : ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
1357 : ELSE
1358 10964 : ALLOCATE (row_dist, source=template%dist%row_dist)
1359 10964 : ALLOCATE (col_dist, source=dbt_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
1360 : END IF
1361 : CASE (2)
1362 18466 : IF (.NOT. transposed) THEN
1363 120 : ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
1364 120 : ALLOCATE (col_dist, source=template%dist%col_dist)
1365 : ELSE
1366 14796 : ALLOCATE (row_dist, source=dbt_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
1367 14796 : ALLOCATE (col_dist, source=template%dist%col_dist)
1368 : END IF
1369 : END SELECT
1370 :
1371 18466 : CALL dbt_tas_get_split_info(info_template, mp_comm=mp_comm)
1372 18466 : CALL dbt_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
1373 18466 : IF (.NOT. transposed) THEN
1374 : CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1375 104 : matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
1376 : ELSE
1377 : CALL dbt_tas_create(matrix_out, dbm_get_name(matrix_in%matrix), dist_new, &
1378 18362 : matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
1379 : END IF
1380 :
1381 18466 : IF (.NOT. nodata_prv) CALL dbt_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
1382 :
1383 18466 : CALL timestop(handle)
1384 :
1385 36932 : END SUBROUTINE reshape_mm_template
1386 :
1387 : ! **************************************************************************************************
1388 : !> \brief Estimate sparsity pattern of C resulting from A x B = C
1389 : !> by multiplying the block norms of A and B. Same dummy arguments as dbt_tas_multiply
1390 : !> \param transa transpose flag for A, forwarded to dbt_tas_multiply
1391 : !> \param transb transpose flag for B, forwarded to dbt_tas_multiply
1392 : !> \param transc transpose flag for C, forwarded to dbt_tas_multiply
1393 : !> \param matrix_a matrix A; its communicator is used to sum the estimate over all ranks
1394 : !> \param matrix_b matrix B
1395 : !> \param matrix_c matrix C; provides the block sizes used to count elements
1396 : !> \param estimated_nze estimated number of non-zero elements of C, summed over all ranks
1397 : !> \param filter_eps forwarded to the multiplication of the block norm matrices
1398 : !> \param unit_nr output unit, forwarded to dbt_tas_multiply
1399 : !> \param retain_sparsity if .TRUE., skip the multiplication and count existing blocks of matrix_c
1400 : !> \author Patrick Seewald
1401 : ! **************************************************************************************************
1402 72261 : SUBROUTINE dbt_tas_estimate_result_nze(transa, transb, transc, matrix_a, matrix_b, matrix_c, &
1403 : estimated_nze, filter_eps, unit_nr, retain_sparsity)
1404 : LOGICAL, INTENT(IN) :: transa, transb, transc
1405 : TYPE(dbt_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b, matrix_c
1406 : INTEGER(int_8), INTENT(OUT) :: estimated_nze
1407 : REAL(KIND=dp), INTENT(IN), OPTIONAL :: filter_eps
1408 : INTEGER, INTENT(IN), OPTIONAL :: unit_nr
1409 : LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
1410 :
1411 : CHARACTER(LEN=*), PARAMETER :: routineN = 'dbt_tas_estimate_result_nze'
1412 :
1413 : INTEGER :: col_size, handle, row_size
1414 : INTEGER(int_8) :: col, row
1415 : LOGICAL :: retain_sparsity_prv
1416 : TYPE(dbt_tas_iterator) :: iter
1417 : TYPE(dbt_tas_type), POINTER :: matrix_a_bnorm, matrix_b_bnorm, &
1418 : matrix_c_bnorm
1419 72261 : TYPE(mp_cart_type) :: mp_comm
1420 :
1421 72261 : CALL timeset(routineN, handle)
1422 :
1423 72261 : IF (PRESENT(retain_sparsity)) THEN
1424 118 : retain_sparsity_prv = retain_sparsity
1425 : ELSE
1426 : retain_sparsity_prv = .FALSE.
1427 : END IF
1428 :
1429 118 : IF (.NOT. retain_sparsity_prv) THEN
1430 1370717 : ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
1431 72143 : CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
1432 72143 : CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
1433 72143 : CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)
1434 :
1435 : CALL dbt_tas_multiply(transa, transb, transc, 1.0_dp, matrix_a_bnorm, &
1436 : matrix_b_bnorm, 0.0_dp, matrix_c_bnorm, &
1437 : filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
1438 72143 : simple_split=.TRUE., unit_nr=unit_nr)
1439 72143 : CALL dbt_tas_destroy(matrix_a_bnorm)
1440 72143 : CALL dbt_tas_destroy(matrix_b_bnorm)
1441 :
1442 72143 : DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
1443 : ELSE
1444 : matrix_c_bnorm => matrix_c
1445 : END IF
1446 :
1447 72261 : estimated_nze = 0
1448 : !$OMP PARALLEL DEFAULT(NONE) REDUCTION(+:estimated_nze) SHARED(matrix_c_bnorm,matrix_c) &
1449 72261 : !$OMP PRIVATE(iter,row,col,row_size,col_size)
1450 : CALL dbt_tas_iterator_start(iter, matrix_c_bnorm)
1451 : DO WHILE (dbt_tas_iterator_blocks_left(iter))
1452 : CALL dbt_tas_iterator_next_block(iter, row, col)
1453 : row_size = matrix_c%row_blk_size%data(row)
1454 : col_size = matrix_c%col_blk_size%data(col)
1455 : estimated_nze = estimated_nze + row_size*col_size
1456 : END DO
1457 : CALL dbt_tas_iterator_stop(iter)
1458 : !$OMP END PARALLEL
1459 :
1460 72261 : CALL dbt_tas_get_split_info(dbt_tas_info(matrix_a), mp_comm=mp_comm)
1461 72261 : CALL mp_comm%sum(estimated_nze)
1462 :
1463 72261 : IF (.NOT. retain_sparsity_prv) THEN
1464 72143 : CALL dbt_tas_destroy(matrix_c_bnorm)
1465 72143 : DEALLOCATE (matrix_c_bnorm)
1466 : END IF
1467 :
1468 72261 : CALL timestop(handle)
1469 :
1470 144522 : END SUBROUTINE dbt_tas_estimate_result_nze
1471 :
1472 : ! **************************************************************************************************
1473 : !> \brief Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
1474 : !> This estimate is based on the minimization of communication volume whereby the
1475 : !> communication of CARMA n-split step and CANNON-multiplication of submatrices are considered.
1476 : !> \param max_mm_dim selects the matrix taken as smallest: 1 -> B, 2 -> C, 3 -> A (other values abort)
1477 : !> \param nze_a number of non-zeroes in A
1478 : !> \param nze_b number of non-zeroes in B
1479 : !> \param nze_c number of non-zeroes in C
1480 : !> \param numnodes number of MPI ranks
1481 : !> \return estimated split factor: rounded ratio of largest to smallest nze, capped at numnodes, min. 1
1482 : !> \author Patrick Seewald
1483 : ! **************************************************************************************************
1484 242266 : FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
1485 : INTEGER, INTENT(IN) :: max_mm_dim
1486 : INTEGER(KIND=int_8), INTENT(IN) :: nze_a, nze_b, nze_c
1487 : INTEGER, INTENT(IN) :: numnodes
1488 : INTEGER :: nsplit
1489 :
1490 : INTEGER(KIND=int_8) :: max_nze, min_nze
1491 : REAL(dp) :: s_opt_factor
1492 :
1493 242266 : s_opt_factor = 1.0_dp ! Could be further tuned.
1494 :
1495 312808 : SELECT CASE (max_mm_dim)
1496 : CASE (1)
1497 70542 : min_nze = MAX(nze_b, 1_int_8)
1498 211626 : max_nze = MAX(MAXVAL([nze_a, nze_c]), 1_int_8)
1499 : CASE (2)
1500 73068 : min_nze = MAX(nze_c, 1_int_8)
1501 219204 : max_nze = MAX(MAXVAL([nze_a, nze_b]), 1_int_8)
1502 : CASE (3)
1503 98656 : min_nze = MAX(nze_a, 1_int_8)
1504 295968 : max_nze = MAX(MAXVAL([nze_b, nze_c]), 1_int_8)
1505 : CASE DEFAULT
1506 242266 : CPABORT("")
1507 : END SELECT
1508 :
1509 242266 : nsplit = INT(MIN(INT(numnodes, KIND=int_8), NINT(REAL(max_nze, dp)/(REAL(min_nze, dp)*s_opt_factor), KIND=int_8)))
1510 242266 : IF (nsplit == 0) nsplit = 1
1511 :
1512 242266 : END FUNCTION split_factor_estimate
1513 :
! **************************************************************************************************
!> \brief Create a matrix with block sizes one that contains the block norms of matrix_in
!> \param matrix_in matrix whose blocks are reduced to their norm (must be valid)
!> \param matrix_out newly created matrix; it takes over name and distribution of matrix_in and
!>                   holds one 1x1 block with NORM2 of the data for every block of matrix_in
!> \param nodata if present and .TRUE., matrix_out is only created and no blocks are
!>               reserved or filled (default .FALSE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix_in
      TYPE(dbt_tas_type), INTENT(OUT)                    :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata

      CHARACTER(len=default_string_length)               :: name
      INTEGER(KIND=int_8)                                :: icol, irow, nblkcols, nblkrows
      LOGICAL                                            :: nodata_prv
      REAL(dp), DIMENSION(1, 1)                          :: blk_norm
      REAL(dp), DIMENSION(:, :), POINTER                 :: blk_in
      TYPE(dbt_tas_blk_size_one)                         :: col_blk_size, row_blk_size
      TYPE(dbt_tas_iterator)                             :: iter

      CPASSERT(matrix_in%valid)

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      ! every block of the output has size one in both dimensions
      CALL dbt_tas_get_info(matrix_in, name=name, nblkrows_total=nblkrows, nblkcols_total=nblkcols)
      row_blk_size = dbt_tas_blk_size_one(nblkrows)
      col_blk_size = dbt_tas_blk_size_one(nblkcols)

      ! not sure if assumption that same distribution can be taken still holds
      CALL dbt_tas_create(matrix_out, name, matrix_in%dist, row_blk_size, col_blk_size)

      ! structure-only request: leave matrix_out without any blocks
      IF (nodata_prv) RETURN

      CALL dbt_tas_reserve_blocks(matrix_in, matrix_out)
!$OMP PARALLEL DEFAULT(NONE) SHARED(matrix_in,matrix_out) &
!$OMP PRIVATE(iter,irow,icol,blk_in,blk_norm)
      CALL dbt_tas_iterator_start(iter, matrix_in)
      DO WHILE (dbt_tas_iterator_blocks_left(iter))
         CALL dbt_tas_iterator_next_block(iter, irow, icol, blk_in)
         blk_norm(1, 1) = NORM2(blk_in)
         CALL dbt_tas_put_block(matrix_out, irow, icol, blk_norm)
      END DO
      CALL dbt_tas_iterator_stop(iter)
!$OMP END PARALLEL

   END SUBROUTINE create_block_norms_matrix
1566 :
! **************************************************************************************************
!> \brief Convert a DBM matrix to a new process grid
!>        If the process grid is not to be changed, matrix_out is created from matrix_in as a
!>        template and (unless nodata) filled with a copy of its data.
!> \param mp_comm_cart new process grid
!> \param matrix_in matrix to convert
!> \param matrix_out newly created matrix
!> \param move_data memory optimization: move data such that matrix_in is empty on return.
!>                  NOTE(review): only honoured when the process grid is actually changed;
!>                  on the template/copy path matrix_in keeps its data.
!> \param nodata Data of matrix_in should not be copied to matrix_out (default .FALSE.)
!> \param optimize_pgrid Whether to change process grid (default .TRUE.)
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
      TYPE(mp_cart_type), INTENT(IN)                     :: mp_comm_cart
      TYPE(dbm_type), INTENT(INOUT)                      :: matrix_in
      TYPE(dbm_type), INTENT(OUT)                        :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data, nodata, optimize_pgrid

      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid'

      CHARACTER(len=default_string_length)               :: name
      INTEGER                                            :: handle, ncol_blocks, nrow_blocks
      INTEGER, CONTIGUOUS, DIMENSION(:), POINTER         :: col_blk_sizes, col_dist, row_blk_sizes, &
                                                            row_dist
      INTEGER, DIMENSION(2)                              :: pdims
      LOGICAL                                            :: move_data_prv, nodata_prv, &
                                                            optimize_pgrid_prv
      TYPE(dbm_distribution_obj)                         :: dist, dist_old

      NULLIFY (row_dist, col_dist, row_blk_sizes, col_blk_sizes)

      CALL timeset(routineN, handle)

      ! resolve optional arguments to their defaults
      optimize_pgrid_prv = .TRUE.
      IF (PRESENT(optimize_pgrid)) optimize_pgrid_prv = optimize_pgrid

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      move_data_prv = .FALSE.
      IF (PRESENT(move_data)) move_data_prv = move_data

      name = dbm_get_name(matrix_in)

      IF (optimize_pgrid_prv) THEN
         row_blk_sizes => dbm_get_row_block_sizes(matrix_in)
         col_blk_sizes => dbm_get_col_block_sizes(matrix_in)
         nrow_blocks = SIZE(row_blk_sizes)
         ncol_blocks = SIZE(col_blk_sizes)
         ! NOTE(review): dist_old is not used afterwards in this routine
         dist_old = dbm_get_distribution(matrix_in)
         pdims = mp_comm_cart%num_pe_cart

         ! default distribution of block rows / columns over the new grid dimensions
         ALLOCATE (row_dist(nrow_blocks), col_dist(ncol_blocks))
         CALL dbt_tas_default_distvec(nrow_blocks, pdims(1), row_blk_sizes, row_dist)
         CALL dbt_tas_default_distvec(ncol_blocks, pdims(2), col_blk_sizes, col_dist)

         CALL dbm_distribution_new(dist, mp_comm_cart, row_dist, col_dist)
         DEALLOCATE (row_dist, col_dist)

         CALL dbm_create(matrix_out, name, dist, row_blk_sizes, col_blk_sizes)
         CALL dbm_distribution_release(dist)

         IF (.NOT. nodata_prv) THEN
            CALL dbm_redistribute(matrix_in, matrix_out)
            IF (move_data_prv) CALL dbm_clear(matrix_in)
         END IF
      ELSE
         ! keep the process grid: clone the structure and optionally the data
         CALL dbm_create_from_template(matrix_out, name=name, template=matrix_in)
         IF (.NOT. nodata_prv) CALL dbm_copy(matrix_out, matrix_in)
      END IF

      CALL timestop(handle)
   END SUBROUTINE convert_to_new_pgrid
1642 :
! **************************************************************************************************
!> \brief Put a matrix into batched-multiplication mode: allocates its mm_storage (with
!>        batched_out set to .FALSE.) and sets the batched state to 1.
!> \param matrix matrix for which batched multiplication is started
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_init(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      ! fresh storage: nothing has been accumulated as batched output yet
      ALLOCATE (matrix%mm_storage)
      matrix%mm_storage%batched_out = .FALSE.

      ! state 1: batched MM but mm_storage not yet initialized
      CALL dbt_tas_set_batched_state(matrix, state=1)
   END SUBROUTINE dbt_tas_batched_mm_init
1655 :
! **************************************************************************************************
!> \brief Leave batched-multiplication mode. If the matrix was used as batched output it is
!>        scaled by the stored batched_beta, outstanding batched work is completed, mm_storage
!>        is released and the batched state is reset to 0.
!>        Returns without further action if the matrix is not in batched mode (state 0).
!> \param matrix matrix for which batched multiplication is finalized
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_finalize(matrix)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix

      INTEGER                                            :: handle

      CALL matrix%dist%info%mp_comm%sync()
      CALL timeset("dbt_tas_total", handle)

      IF (matrix%do_batched == 0) THEN
         ! nothing to finalize; the timer started above must still be stopped so that
         ! timeset/timestop calls stay balanced on this early exit
         CALL timestop(handle)
         RETURN
      END IF

      IF (matrix%mm_storage%batched_out) THEN
         CALL dbm_scale(matrix%matrix, matrix%mm_storage%batched_beta)
      END IF

      CALL dbt_tas_batched_mm_complete(matrix)

      matrix%mm_storage%batched_out = .FALSE.

      DEALLOCATE (matrix%mm_storage)
      ! state 0: no batched MM
      CALL dbt_tas_set_batched_state(matrix, state=0)

      CALL matrix%dist%info%mp_comm%sync()
      CALL timestop(handle)

   END SUBROUTINE dbt_tas_batched_mm_finalize
1686 :
! **************************************************************************************************
!> \brief set state flags during batched multiplication
!> \param matrix matrix whose batched state / strict-split flag is updated
!> \param state 0 no batched MM
!>              1 batched MM but mm_storage not yet initialized
!>              2 batched MM and mm_storage requires update
!>              3 batched MM and mm_storage initialized
!>              any other value aborts
!> \param opt_grid whether process grid was already optimized and should not be changed
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_set_batched_state(matrix, state, opt_grid)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      INTEGER, INTENT(IN), OPTIONAL                      :: state
      LOGICAL, INTENT(IN), OPTIONAL                      :: opt_grid

      ! Passing opt_grid always enforces the strict split, whatever its value.
      IF (PRESENT(opt_grid)) THEN
         matrix%has_opt_pgrid = opt_grid
         matrix%dist%info%strict_split(1) = .TRUE.
      END IF

      IF (.NOT. PRESENT(state)) RETURN

      matrix%do_batched = state

      IF (state == 0 .OR. state == 1) THEN
         ! reset to default: strict if the grid was optimized, otherwise the value
         ! kept in strict_split(2)
         matrix%dist%info%strict_split(1) = matrix%has_opt_pgrid .OR. matrix%dist%info%strict_split(2)
      ELSE IF (state == 2 .OR. state == 3) THEN
         matrix%dist%info%strict_split(1) = .TRUE.
      ELSE
         CPABORT("should not happen")
      END IF
   END SUBROUTINE dbt_tas_set_batched_state
1724 :
! **************************************************************************************************
!> \brief Complete outstanding batched-multiplication work for a matrix in batched mode.
!>        If the matrix is a batched output and its storage is initialized (state 3), the
!>        replicated store is merged and reshaped (with summation) into the matrix. The
!>        replicated store is released and the batched state is set to 2.
!>        Does nothing if the matrix is not in batched mode (state 0).
!> \param matrix matrix in batched mode
!> \param warn if present and .TRUE., emit a warning when the state is 3, i.e. when
!>             batched optimizations are being discarded
!> \author Patrick Seewald
! **************************************************************************************************
   SUBROUTINE dbt_tas_batched_mm_complete(matrix, warn)
      TYPE(dbt_tas_type), INTENT(INOUT)                  :: matrix
      LOGICAL, INTENT(IN), OPTIONAL                      :: warn

      LOGICAL                                            :: do_warn, storage_initialized

      IF (matrix%do_batched == 0) RETURN

      ! state 3: batched MM and mm_storage initialized
      storage_initialized = (matrix%do_batched == 3)

      do_warn = .FALSE.
      IF (PRESENT(warn)) do_warn = (warn .AND. storage_initialized)
      IF (do_warn) THEN
         CALL cp_warn(__LOCATION__, &
                      "Optimizations for batched multiplication are disabled because of conflicting data access")
      END IF

      IF (matrix%mm_storage%batched_out .AND. storage_initialized) THEN
         ! fold the accumulated batched result back into the matrix
         CALL dbt_tas_merge(matrix%mm_storage%store_batched%matrix, &
                            matrix%mm_storage%store_batched_repl, move_data=.TRUE.)

         CALL dbt_tas_reshape(matrix%mm_storage%store_batched, matrix, summation=.TRUE., &
                              transposed=matrix%mm_storage%batched_trans, move_data=.TRUE.)
         CALL dbt_tas_destroy(matrix%mm_storage%store_batched)
         DEALLOCATE (matrix%mm_storage%store_batched)
      END IF

      IF (ASSOCIATED(matrix%mm_storage%store_batched_repl)) THEN
         CALL dbt_tas_destroy(matrix%mm_storage%store_batched_repl)
         DEALLOCATE (matrix%mm_storage%store_batched_repl)
      END IF

      ! state 2: batched MM and mm_storage requires update
      CALL dbt_tas_set_batched_state(matrix, state=2)

   END SUBROUTINE dbt_tas_batched_mm_complete
1763 :
1764 2000171 : END MODULE dbt_tas_mm
|