Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
10 : ! **************************************************************************************************
11 : MODULE skala_gpw_features
12 : USE cell_types, ONLY: cell_type,&
13 : pbc
14 : USE cp_array_utils, ONLY: cp_3d_r_cp_type
15 : USE kinds, ONLY: dp,&
16 : int_8
17 : USE message_passing, ONLY: mp_comm_type
18 : USE particle_types, ONLY: particle_type
19 : USE pw_grid_types, ONLY: pw_grid_type
20 : USE pw_types, ONLY: pw_r3d_rs_type
21 : USE torch_api, ONLY: &
22 : torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
23 : torch_dict_type, torch_tensor_expand_dim, torch_tensor_from_array, torch_tensor_narrow, &
24 : torch_tensor_release, torch_tensor_reset_from_array, torch_tensor_to_device_leaf, &
25 : torch_tensor_type
26 : USE xc_rho_set_types, ONLY: xc_rho_set_get,&
27 : xc_rho_set_type
28 : #include "./base/base_uses.f90"
29 :
30 : IMPLICIT NONE
31 :
32 : PRIVATE
33 :
34 : CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
35 : REAL(KIND=dp), PARAMETER, PRIVATE :: layout_tol = 1.0E-12_dp ! absolute tolerance of the cache-reuse comparisons
36 : INTEGER, PARAMETER, PRIVATE :: ndynamic_per_point = 10, nrks_dynamic_per_point = 5, & ! packed reals per point: two-spin layout / collapsed closed-shell layout
37 : nstatic_per_point = 4, ngrad_per_point = 10 ! nstatic: 3 coordinates + 1 weight per feature row; ngrad_per_point is not used in the reviewed part of the file
38 : INTEGER, PARAMETER, PUBLIC :: skala_gpw_atom_partition_hard = 1, & ! hard: every grid point belongs to its nearest atom
39 : skala_gpw_atom_partition_smooth = 2 ! smooth: a grid point can feed several atoms, each with a partition weight
40 : REAL(KIND=dp), PARAMETER, PRIVATE :: smooth_partition_eps = 1.0E-12_dp ! smooth partition weights at or below this value are dropped
41 :
42 : PUBLIC :: skala_gpw_atom_subchunk_count, skala_gpw_feature_build, &
43 : skala_gpw_feature_build_atom_subchunk, skala_gpw_feature_release, &
44 : skala_gpw_feature_type, skala_gpw_smooth_partition_derivatives
45 :
46 : TYPE skala_gpw_layout_cache_type ! static grid-to-atom layout kept between feature builds
47 : INTEGER :: chunk_atom_begin = 1, chunk_atom_end = 0, &
48 : chunk_feature_begin = 1, &
49 : chunk_feature_count = 0, chunk_natom = 0, & ! chunk_feature_count > 0 is required before atom chunks are used
50 : natom = 0, nflat = 0, nflat_local = 0, & ! atoms, feature rows, local grid points the cache was built for
51 : npoint = 0, nproc = 0, & ! npoint sizes the gathered per-point buffer; nproc is checked against the grid group size
52 : atom_partition = skala_gpw_atom_partition_hard ! partition mode the cache was built for
53 : INTEGER, DIMENSION(2, 3) :: bo = 0, bounds = 0 ! checked against pw_grid%bounds_local and pw_grid%bounds
54 : INTEGER, DIMENSION(3) :: npts = 0 ! checked against pw_grid%npts
55 : INTEGER, ALLOCATABLE, DIMENSION(:) :: dynamic_counts, dynamic_displs, & ! counts and displacements of the allgatherv of packed dynamic values
56 : chunk_feature_counts, chunk_feature_displs, &
57 : chunk_grad_counts, chunk_grad_displs, &
58 : feature_counts, feature_displs, &
59 : feature_source_points, global_to_feature, & ! feature_source_points(row): gathered grid point backing that feature row
60 : local_feature_counts, local_feature_offsets, &
61 : local_feature_points, local_feature_rows, &
62 : route_grad_return_recv_counts, &
63 : route_grad_return_recv_displs, &
64 : route_grad_return_send_counts, &
65 : route_grad_return_send_displs, &
66 : route_local_dest, chunk_return_positions, &
67 : route_point_recv_counts, &
68 : route_point_recv_displs, &
69 : route_point_send_counts, &
70 : route_point_send_displs, &
71 : route_send_local_rows
72 : INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index ! local row number of each local grid point (i, j, k)
73 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
74 : chunk_feature_indices
75 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: local_feature_indices
76 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape, &
77 : chunk_atomic_grid_size_bound_shape
78 : TYPE(torch_dict_type) :: chunk_inputs
79 : TYPE(torch_dict_type) :: chunk_static_inputs
80 : TYPE(torch_dict_type) :: inputs
81 : TYPE(torch_dict_type) :: static_inputs
82 : TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
83 : TYPE(torch_tensor_type) :: atomic_grid_sizes_t
84 : TYPE(torch_tensor_type) :: atomic_grid_weights_t
85 : TYPE(torch_tensor_type) :: chunk_atomic_grid_size_bound_shape_t
86 : TYPE(torch_tensor_type) :: chunk_atomic_grid_sizes_t
87 : TYPE(torch_tensor_type) :: chunk_atomic_grid_weights_t
88 : TYPE(torch_tensor_type) :: chunk_coarse_0_atomic_coords_t
89 : TYPE(torch_tensor_type) :: chunk_density_t
90 : TYPE(torch_tensor_type) :: chunk_density_input_t
91 : TYPE(torch_tensor_type) :: chunk_feature_indices_t
92 : TYPE(torch_tensor_type) :: chunk_grad_t
93 : TYPE(torch_tensor_type) :: chunk_grad_input_t
94 : TYPE(torch_tensor_type) :: chunk_grid_coords_t
95 : TYPE(torch_tensor_type) :: chunk_grid_weights_t
96 : TYPE(torch_tensor_type) :: chunk_kin_t
97 : TYPE(torch_tensor_type) :: chunk_kin_input_t
98 : TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
99 : TYPE(torch_tensor_type) :: density_t
100 : TYPE(torch_tensor_type) :: grid_coords_t
101 : TYPE(torch_tensor_type) :: grid_weights_t
102 : TYPE(torch_tensor_type) :: grad_t
103 : TYPE(torch_tensor_type) :: kin_t
104 : TYPE(torch_tensor_type) :: local_feature_indices_t
105 : REAL(KIND=dp) :: dvol = 0.0_dp, weight_sum = 0.0_dp, & ! dvol is checked against pw_grid%dvol
106 : weight_sumsq = 0.0_dp ! weight_sum and weight_sumsq are checked against the values from weights_signature
107 : REAL(KIND=dp), DIMENSION(3, 3) :: cell_hmat = 0.0_dp, dh = 0.0_dp ! checked against cell%hmat and pw_grid%dh
108 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, chunk_atomic_grid_weights, &
109 : chunk_grid_weights, grid_weights ! chunk_grid_weights weight the chunk densities in the electron-count sums
110 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords, chunk_coarse_0_atomic_coords, & ! atom_coords(:, iatom) is checked against particle_set(iatom)%r
111 : chunk_grid_coords, coarse_0_atomic_coords, &
112 : grid_coords
113 : LOGICAL :: active = .FALSE., has_weights = .FALSE., & ! active must be .TRUE. before the cache can be reused
114 : chunk_dynamic_input_views_active = .FALSE., &
115 : chunk_dynamic_tensors_active = .FALSE., &
116 : chunk_inputs_active = .FALSE., &
117 : chunk_inputs_use_collapsed_rks = .FALSE., &
118 : chunk_static_tensors_active = .FALSE., &
119 : dynamic_tensors_active = .FALSE., &
120 : inputs_active = .FALSE., &
121 : static_tensors_active = .FALSE.
122 : END TYPE skala_gpw_layout_cache_type
123 :
124 : TYPE skala_gpw_feature_type ! per-call feature container filled by skala_gpw_feature_build
125 : INTEGER :: chunk_feature_count = 0, nflat = 0, &
126 : nflat_local = 0, &
127 : atom_partition = skala_gpw_atom_partition_hard
128 : TYPE(torch_dict_type) :: inputs
129 : TYPE(torch_tensor_type) :: atomic_grid_size_bound_shape_t
130 : TYPE(torch_tensor_type) :: atomic_grid_sizes_t
131 : TYPE(torch_tensor_type) :: atomic_grid_weights_t
132 : TYPE(torch_tensor_type) :: coarse_0_atomic_coords_t
133 : TYPE(torch_tensor_type) :: density_input_t
134 : TYPE(torch_tensor_type) :: density_t
135 : TYPE(torch_tensor_type) :: grad_t
136 : TYPE(torch_tensor_type) :: grad_input_t
137 : TYPE(torch_tensor_type) :: grid_coords_t
138 : TYPE(torch_tensor_type) :: grid_weights_t
139 : TYPE(torch_tensor_type) :: kin_input_t
140 : TYPE(torch_tensor_type) :: kin_t
141 : TYPE(torch_tensor_type) :: local_feature_indices_t
142 : INTEGER, ALLOCATABLE, DIMENSION(:) :: chunk_grad_counts, chunk_grad_displs, &
143 : local_feature_counts, local_feature_offsets, &
144 : local_feature_rows, &
145 : chunk_return_positions, &
146 : route_grad_return_recv_counts, &
147 : route_grad_return_recv_displs, &
148 : route_grad_return_send_counts, &
149 : route_grad_return_send_displs, &
150 : route_point_recv_counts, &
151 : route_point_recv_displs, &
152 : route_point_send_counts, &
153 : route_point_send_displs, &
154 : route_send_local_rows
155 : INTEGER, ALLOCATABLE, DIMENSION(:, :, :) :: feature_index
156 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:) :: atomic_grid_sizes
157 : INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
158 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: atomic_grid_weights, grid_weights
159 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: chunk_density, chunk_kin, & ! chunk_* arrays are read on the atom-chunk path
160 : coarse_0_atomic_coords, density, & ! density(nflat, 2): one column per spin channel
161 : grid_coords, kin ! kin(nflat, 2): one column per spin channel
162 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: chunk_grad, grad ! grad(nflat, 3, 2): Cartesian component, then spin channel
163 : REAL(KIND=dp) :: electron_count = 0.0_dp, & ! weighted sum of the spin-summed density
164 : grid_weight_sum = 0.0_dp, & ! SUM(grid_weights)
165 : spin_moment = 0.0_dp ! weighted sum of the spin-density difference
166 : LOGICAL :: active = .FALSE., owns_coordinate_tensor = .FALSE., & ! active is set at the end of skala_gpw_feature_build
167 : owns_grid_coordinate_tensor = .FALSE., &
168 : owns_weight_tensors = .FALSE., &
169 : owns_dynamic_tensors = .TRUE., &
170 : owns_inputs = .TRUE., &
171 : owns_static_tensors = .TRUE., &
172 : uses_atom_chunk_routing = .FALSE., & ! set when dynamic values were routed instead of gathered
173 : uses_atom_chunks = .FALSE., & ! selects the chunk_* arrays for sums and tensors
174 : uses_collapsed_rks_dynamic = .FALSE. ! chunk_density then holds one spin channel only
175 : END TYPE skala_gpw_feature_type
176 :
177 : TYPE(skala_gpw_layout_cache_type), SAVE :: cached_layout ! single module-level cache; released and refilled by rebuild_layout_cache
178 :
179 : CONTAINS
180 :
181 : ! **************************************************************************************************
182 : !> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid: pack the local density, gradient and tau values, bring them into the cached per-atom row order and create the tensors.
183 : !> \param features feature container; released on entry, refilled here and marked active on exit
184 : !> \param rho_set source of rho/drho/tau (one spin) or rhoa/rhob/drhoa/drhob/tau_a/tau_b (two spins)
185 : !> \param rho_r real-space densities; only the array size (1 or 2 spins) and the pw_grid of element 1 are used here
186 : !> \param particle_set atoms; must be associated, forwarded to the layout cache
187 : !> \param cell simulation cell; must be associated, forwarded to the layout cache
188 : !> \param requires_grad optional, default .FALSE.; forwarded to add_feature_tensors
189 : !> \param weights optional integration weights, forwarded to the layout cache
190 : !> \param requires_coordinate_grad optional, default .FALSE.; forwarded to copy_cached_layout and add_feature_tensors
191 : !> \param requires_stress_grad optional, default .FALSE.; aborts when combined with usable atom chunks
192 : !> \param use_atom_chunks optional, default .FALSE.; only effective if the cached layout has chunk_feature_count > 0
193 : !> \param route_atom_chunks optional, default .FALSE.; with usable atom chunks the dynamic values are routed instead of gathered
194 : !> \param atom_partition optional, default skala_gpw_atom_partition_hard; any value other than hard or smooth aborts
195 : ! **************************************************************************************************
196 144 : SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
197 : requires_grad, weights, requires_coordinate_grad, &
198 : requires_stress_grad, use_atom_chunks, route_atom_chunks, &
199 : atom_partition)
200 : TYPE(skala_gpw_feature_type), INTENT(INOUT) :: features
201 : TYPE(xc_rho_set_type), INTENT(IN) :: rho_set
202 : TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN) :: rho_r
203 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
204 : TYPE(cell_type), POINTER :: cell
205 : LOGICAL, INTENT(IN), OPTIONAL :: requires_grad
206 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
207 : LOGICAL, INTENT(IN), OPTIONAL :: requires_coordinate_grad, &
208 : requires_stress_grad, use_atom_chunks, &
209 : route_atom_chunks
210 : INTEGER, INTENT(IN), OPTIONAL :: atom_partition
211 :
212 : INTEGER :: handle, i, ipt, ispin, j, k, local_row, my_atom_partition, &
213 : ndynamic_local_per_point, nflat, nflat_local, nspins, phase_handle, real_base, row
214 : INTEGER, DIMENSION(2, 3) :: bo
215 : LOGICAL :: can_use_atom_chunks, collapse_spin_dynamics, my_requires_coordinate_grad, &
216 : my_requires_grad, my_requires_stress_grad, my_route_atom_chunks, my_use_atom_chunks
217 144 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: global_dynamic, local_dynamic
218 144 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: rho, rhoa, rhob, tau_a, tau_b, tau_total
219 1728 : TYPE(cp_3d_r_cp_type), DIMENSION(3) :: drho, drhoa, drhob
220 : TYPE(pw_grid_type), POINTER :: pw_grid
221 :
222 144 : CALL timeset("skala_gpw_feature_build", handle)
223 :
224 144 : my_requires_grad = .FALSE. ! resolve every optional argument to a local value with its default
225 144 : IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
226 144 : my_requires_coordinate_grad = .FALSE.
227 144 : IF (PRESENT(requires_coordinate_grad)) &
228 144 : my_requires_coordinate_grad = requires_coordinate_grad
229 144 : my_requires_stress_grad = .FALSE.
230 144 : IF (PRESENT(requires_stress_grad)) my_requires_stress_grad = requires_stress_grad
231 144 : my_use_atom_chunks = .FALSE.
232 144 : IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
233 144 : my_route_atom_chunks = .FALSE.
234 144 : IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
235 144 : my_atom_partition = skala_gpw_atom_partition_hard
236 144 : IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
237 144 : IF (my_atom_partition /= skala_gpw_atom_partition_hard .AND. &
238 : my_atom_partition /= skala_gpw_atom_partition_smooth) THEN
239 0 : CALL cp_abort(__LOCATION__, "Unknown native SKALA atom-partition mode.")
240 : END IF
241 144 : CPASSERT(ASSOCIATED(cell))
242 144 : CPASSERT(ASSOCIATED(particle_set))
243 144 : CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
244 144 : CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
245 144 : pw_grid => rho_r(1)%pw_grid
246 :
247 144 : nspins = SIZE(rho_r)
248 1440 : bo = pw_grid%bounds_local
249 144 : nflat_local = pw_grid%ngpts_local
250 :
251 144 : CALL timeset("skala_gpw_pre_release", phase_handle)
252 144 : CALL skala_gpw_feature_release(features) ! start from an empty container
253 144 : CALL timestop(phase_handle)
254 :
255 144 : CALL timeset("skala_gpw_layout_cache", phase_handle)
256 144 : CALL ensure_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition) ! reuse or rebuild the static layout
257 144 : CALL timestop(phase_handle)
258 144 : nflat = cached_layout%nflat
259 144 : can_use_atom_chunks = my_use_atom_chunks .AND. cached_layout%chunk_feature_count > 0 ! a request alone is not enough, the cache must hold chunk features
260 144 : IF (my_requires_stress_grad .AND. can_use_atom_chunks) THEN
261 : CALL cp_abort(__LOCATION__, &
262 0 : "Native SKALA analytical stress is not implemented with atom-chunk routing yet.")
263 : END IF
264 : collapse_spin_dynamics = (nspins == 1 .AND. can_use_atom_chunks .AND. & ! one spin with routed chunks: pack a single spin channel
265 144 : my_route_atom_chunks)
266 144 : ndynamic_local_per_point = ndynamic_per_point
267 144 : IF (collapse_spin_dynamics) ndynamic_local_per_point = nrks_dynamic_per_point
268 432 : ALLOCATE (local_dynamic(ndynamic_local_per_point*nflat_local))
269 144 : local_dynamic = 0.0_dp
270 :
271 144 : CALL timeset("skala_gpw_pack_local", phase_handle)
272 144 : IF (nspins == 1) THEN
273 90 : CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
274 : ELSE
275 : CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
276 54 : tau_a=tau_a, tau_b=tau_b)
277 : END IF
278 :
279 144 : local_row = 0
280 3260 : DO k = bo(1, 3), bo(2, 3)
281 86728 : DO j = bo(1, 2), bo(2, 2)
282 1453202 : DO i = bo(1, 1), bo(2, 1)
283 1366618 : local_row = local_row + 1
284 1366618 : real_base = ndynamic_local_per_point*(local_row - 1) ! zero-based offset of this point in the packed buffer
285 :
286 1450086 : IF (nspins == 1) THEN
287 1012243 : IF (collapse_spin_dynamics) THEN
288 64000 : local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k) ! collapsed layout: half of the total density, gradient and tau
289 64000 : local_dynamic(real_base + 2) = 0.5_dp*drho(1)%array(i, j, k)
290 64000 : local_dynamic(real_base + 3) = 0.5_dp*drho(2)%array(i, j, k)
291 64000 : local_dynamic(real_base + 4) = 0.5_dp*drho(3)%array(i, j, k)
292 64000 : local_dynamic(real_base + 5) = 0.5_dp*tau_total(i, j, k)
293 : ELSE
294 948243 : local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k) ! two-spin layout from one spin: both spin slots get half of the total
295 948243 : local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
296 2844729 : DO ispin = 1, 2
297 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = &
298 1896486 : 0.5_dp*drho(1)%array(i, j, k)
299 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = &
300 1896486 : 0.5_dp*drho(2)%array(i, j, k)
301 : local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = &
302 1896486 : 0.5_dp*drho(3)%array(i, j, k)
303 2844729 : local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
304 : END DO
305 : END IF
306 : ELSE
307 354375 : local_dynamic(real_base + 1) = rhoa(i, j, k) ! two spins: slots 1-2 densities, 3-8 gradients, 9-10 tau
308 354375 : local_dynamic(real_base + 2) = rhob(i, j, k)
309 354375 : local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
310 354375 : local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
311 354375 : local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
312 354375 : local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
313 354375 : local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
314 354375 : local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
315 354375 : local_dynamic(real_base + 9) = tau_a(i, j, k)
316 354375 : local_dynamic(real_base + 10) = tau_b(i, j, k)
317 : END IF
318 : END DO
319 : END DO
320 : END DO
321 144 : CALL timestop(phase_handle)
322 :
323 144 : CALL timeset("skala_gpw_copy_layout", phase_handle)
324 : CALL copy_cached_layout(features, my_requires_coordinate_grad .OR. my_requires_stress_grad, &
325 : my_requires_stress_grad .OR. &
326 : (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
327 278 : (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
328 144 : CALL timestop(phase_handle)
329 :
330 144 : IF (can_use_atom_chunks .AND. my_route_atom_chunks) THEN
331 2 : CALL timeset("skala_gpw_route_dyn", phase_handle)
332 : CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group, &
333 2 : collapse_spin_dynamics)
334 2 : features%uses_atom_chunk_routing = .TRUE.
335 2 : features%uses_atom_chunks = .TRUE.
336 2 : CALL timestop(phase_handle)
337 : ELSE
338 426 : ALLOCATE (global_dynamic(ndynamic_per_point*cached_layout%npoint)) ! receives the packed values of all ranks
339 142 : CALL timeset("skala_gpw_allgatherv", phase_handle)
340 : CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
341 : cached_layout%dynamic_counts, &
342 142 : cached_layout%dynamic_displs)
343 142 : CALL timestop(phase_handle)
344 :
345 142 : CALL timeset("skala_gpw_reorder_dyn", phase_handle)
346 0 : ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
347 994 : features%kin(nflat, 2))
348 5485978 : features%density = 0.0_dp
349 16457934 : features%grad = 0.0_dp
350 5485978 : features%kin = 0.0_dp
351 :
352 2742918 : DO row = 1, nflat
353 2742776 : ipt = cached_layout%feature_source_points(row) ! gathered grid point backing this feature row
354 2742776 : real_base = ndynamic_per_point*(ipt - 1)
355 8228328 : features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
356 2742776 : features%grad(row, 1, 1) = global_dynamic(real_base + 3)
357 2742776 : features%grad(row, 2, 1) = global_dynamic(real_base + 4)
358 2742776 : features%grad(row, 3, 1) = global_dynamic(real_base + 5)
359 2742776 : features%grad(row, 1, 2) = global_dynamic(real_base + 6)
360 2742776 : features%grad(row, 2, 2) = global_dynamic(real_base + 7)
361 2742776 : features%grad(row, 3, 2) = global_dynamic(real_base + 8)
362 8228470 : features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
363 : END DO
364 426 : CALL timestop(phase_handle)
365 : END IF
366 :
367 144 : CALL timeset("skala_gpw_feature_sums", phase_handle)
368 144 : IF (features%uses_atom_chunks) THEN
369 2 : IF (features%uses_collapsed_rks_dynamic) THEN
370 : features%electron_count = SUM(2.0_dp*features%chunk_density(:, 1)* & ! factor 2 undoes the halving applied while packing the single channel
371 64002 : cached_layout%chunk_grid_weights)
372 2 : features%spin_moment = 0.0_dp
373 : ELSE
374 : features%electron_count = SUM((features%chunk_density(:, 1) + &
375 : features%chunk_density(:, 2))* &
376 0 : cached_layout%chunk_grid_weights)
377 : features%spin_moment = SUM((features%chunk_density(:, 1) - &
378 : features%chunk_density(:, 2))* &
379 0 : cached_layout%chunk_grid_weights)
380 : END IF
381 2 : CALL pw_grid%para%group%sum(features%electron_count) ! chunk sums are partial per rank and are reduced over the group
382 2 : CALL pw_grid%para%group%sum(features%spin_moment)
383 : ELSE
384 : features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
385 2742918 : features%grid_weights)
386 : features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
387 2742918 : features%grid_weights)
388 : END IF
389 2870920 : features%grid_weight_sum = SUM(features%grid_weights)
390 144 : CALL timestop(phase_handle)
391 :
392 144 : CALL timeset("skala_gpw_tensor_update", phase_handle)
393 144 : IF (can_use_atom_chunks .AND. .NOT. features%uses_atom_chunks) THEN ! chunks usable but not routed; NOTE(review): presumably derives chunk data from the gathered arrays, confirm
394 0 : CALL extract_atom_chunk_dynamics(features)
395 0 : features%uses_atom_chunks = .TRUE.
396 : END IF
397 : CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
398 : my_requires_stress_grad, &
399 : features%uses_atom_chunks, &
400 : requires_weight_grad= &
401 : (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
402 284 : (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
403 144 : CALL timestop(phase_handle)
404 144 : features%active = .TRUE.
405 :
406 144 : IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic) ! only allocated on the allgatherv path
407 144 : DEALLOCATE (local_dynamic)
408 144 : CALL timestop(handle)
409 :
410 1152 : END SUBROUTINE skala_gpw_feature_build
411 :
412 : ! **************************************************************************************************
413 : !> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry: keep the cache if layout_cache_matches accepts it, otherwise call rebuild_layout_cache.
414 : !> \param pw_grid real-space grid, forwarded to the match test and to the rebuild
415 : !> \param particle_set atoms, forwarded to the match test and to the rebuild
416 : !> \param cell simulation cell, forwarded to the match test and to the rebuild
417 : !> \param weights optional integration weights; forwarded only when present
418 : !> \param atom_partition optional partition mode, default skala_gpw_atom_partition_hard
419 : ! **************************************************************************************************
420 144 : SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
421 : TYPE(pw_grid_type), POINTER :: pw_grid
422 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
423 : TYPE(cell_type), POINTER :: cell
424 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
425 : INTEGER, INTENT(IN), OPTIONAL :: atom_partition
426 :
427 : INTEGER :: my_atom_partition, phase_handle
428 : LOGICAL :: cache_matches
429 :
430 144 : my_atom_partition = skala_gpw_atom_partition_hard
431 144 : IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
432 144 : IF (PRESENT(weights)) THEN
433 144 : CALL timeset("skala_gpw_layout_match", phase_handle)
434 : cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights, &
435 144 : my_atom_partition)
436 144 : CALL timestop(phase_handle)
437 144 : IF (cache_matches) RETURN ! cache still valid, nothing to rebuild
438 52 : CALL timeset("skala_gpw_layout_rebuild", phase_handle)
439 52 : CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
440 52 : CALL timestop(phase_handle)
441 : ELSE ! weights absent: same steps without forwarding the argument
442 0 : CALL timeset("skala_gpw_layout_match", phase_handle)
443 : cache_matches = layout_cache_matches(pw_grid, particle_set, cell, &
444 0 : atom_partition=my_atom_partition)
445 0 : CALL timestop(phase_handle)
446 0 : IF (cache_matches) RETURN
447 0 : CALL timeset("skala_gpw_layout_rebuild", phase_handle)
448 : CALL rebuild_layout_cache(pw_grid, particle_set, cell, &
449 0 : atom_partition=my_atom_partition)
450 0 : CALL timestop(phase_handle)
451 : END IF
452 :
453 : END SUBROUTINE ensure_layout_cache
454 :
455 : ! **************************************************************************************************
456 : !> \brief Check whether the current static layout cache can be reused, i.e. whether partition mode, sizes, grid bounds, cell, atom positions and weights all agree with the cached values.
457 : !> \param pw_grid grid whose local size, group size, bounds, npts, dvol and dh are compared
458 : !> \param particle_set atoms whose count and positions are compared
459 : !> \param cell cell whose hmat is compared
460 : !> \param weights optional integration weights, forwarded to layout_weights_match when present
461 : !> \param atom_partition optional partition mode, default skala_gpw_atom_partition_hard
462 : !> \return .TRUE. only if the cache is active and every comparison passes; real values are compared with layout_tol
463 : ! **************************************************************************************************
464 144 : FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights, atom_partition) RESULT(matches)
465 : TYPE(pw_grid_type), POINTER :: pw_grid
466 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
467 : TYPE(cell_type), POINTER :: cell
468 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
469 : INTEGER, INTENT(IN), OPTIONAL :: atom_partition
470 : LOGICAL :: matches
471 :
472 : INTEGER :: iatom, my_atom_partition
473 : LOGICAL :: weights_match
474 :
475 144 : my_atom_partition = skala_gpw_atom_partition_hard
476 144 : IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
477 144 : matches = .FALSE. ! every early RETURN below reports a mismatch
478 144 : IF (.NOT. cached_layout%active) RETURN ! nothing cached yet
479 100 : IF (cached_layout%atom_partition /= my_atom_partition) RETURN
480 100 : IF (cached_layout%natom /= SIZE(particle_set)) RETURN
481 100 : IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
482 100 : IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
483 1000 : IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN
484 1000 : IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
485 400 : IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN
486 100 : IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN
487 1300 : IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
488 1300 : IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
489 100 : IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN ! guards the per-atom comparison below
490 :
491 292 : DO iatom = 1, SIZE(particle_set) ! any coordinate moved by more than layout_tol invalidates the cache
492 884 : IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
493 : END DO
494 :
495 92 : IF (PRESENT(weights)) THEN
496 92 : weights_match = layout_weights_match(pw_grid, weights)
497 : ELSE
498 0 : weights_match = layout_weights_match(pw_grid)
499 : END IF
500 92 : IF (.NOT. weights_match) RETURN
501 :
502 144 : matches = .TRUE.
503 :
504 : END FUNCTION layout_cache_matches
505 :
506 : ! **************************************************************************************************
507 : !> \brief Check whether current optional integration weights match the cached static tensors by comparing the values returned by weights_signature with the cached has_weights, weight_sum and weight_sumsq.
508 : !> \param pw_grid not used in this routine (only passed to MARK_USED)
509 : !> \param weights optional integration weights; forwarded to weights_signature only when present
510 : !> \return .TRUE. if has_weights agrees and both sums agree with the cached values within layout_tol
511 : ! **************************************************************************************************
512 92 : FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
513 : TYPE(pw_grid_type), POINTER :: pw_grid
514 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
515 : LOGICAL :: matches
516 :
517 : LOGICAL :: has_weights
518 : REAL(KIND=dp) :: weight_sum, weight_sumsq
519 :
520 92 : matches = .FALSE.
521 : MARK_USED(pw_grid)
522 92 : IF (PRESENT(weights)) THEN
523 92 : CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
524 : ELSE
525 : CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
526 0 : weight_sumsq=weight_sumsq)
527 : END IF
528 :
529 92 : IF (cached_layout%has_weights .NEQV. has_weights) RETURN ! weights appeared or disappeared since the cache was built
530 92 : IF (ABS(cached_layout%weight_sum - weight_sum) > layout_tol) RETURN
531 92 : IF (ABS(cached_layout%weight_sumsq - weight_sumsq) > layout_tol) RETURN
532 :
533 92 : matches = .TRUE.
534 :
535 : END FUNCTION layout_weights_match
536 :
537 : ! **************************************************************************************************
538 : !> \brief Build the static SKALA layout cache.
539 : !> \param pw_grid ...
540 : !> \param particle_set ...
541 : !> \param cell ...
542 : !> \param weights ...
543 : !> \param atom_partition ...
544 : ! **************************************************************************************************
545 52 : SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
546 : TYPE(pw_grid_type), POINTER :: pw_grid
547 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
548 : TYPE(cell_type), POINTER :: cell
549 : TYPE(pw_r3d_rs_type), OPTIONAL, POINTER :: weights
550 : INTEGER, INTENT(IN), OPTIONAL :: atom_partition
551 :
552 : INTEGER :: feature_local, i, iatom, ipt, j, k, local_row, max_grid_size, max_local_features, &
553 : my_atom_partition, natom, nfeature_local, nflat, nflat_local, npoint, nproc, owner, pe, &
554 : pe_index, phase_handle, row, source_global, source_local, static_base
555 52 : INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
556 52 : chunk_atom_end, cursor, feature_counts, feature_displs, global_owner, &
557 52 : global_source_points, local_feature_counts_tmp, local_owner, local_source_global, &
558 52 : local_source_points, point_counts, point_displs, static_counts, static_displs
559 : INTEGER, DIMENSION(2, 3) :: bo
560 : LOGICAL :: has_weights
561 : REAL(KIND=dp) :: base_weight, included_sum, &
562 : partition_weight, weight_sum, &
563 : weight_sumsq
564 52 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: distances, global_static, local_static, &
565 : partition_weights
566 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords_pbc, atom_image_coords
567 : REAL(KIND=dp), DIMENSION(3) :: grid_point, owner_coord
568 :
569 52 : CALL release_layout_cache(cached_layout)
570 :
571 52 : my_atom_partition = skala_gpw_atom_partition_hard
572 52 : IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
573 52 : natom = SIZE(particle_set)
574 520 : bo = pw_grid%bounds_local
575 52 : nflat_local = pw_grid%ngpts_local
576 52 : nproc = pw_grid%para%group%num_pe
577 52 : pe_index = pw_grid%para%group%mepos + 1
578 :
579 52 : IF (PRESENT(weights)) THEN
580 52 : CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
581 : ELSE
582 : CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
583 0 : weight_sumsq=weight_sumsq)
584 : END IF
585 :
586 52 : max_local_features = nflat_local
587 52 : IF (my_atom_partition == skala_gpw_atom_partition_smooth) &
588 6 : max_local_features = nflat_local*natom
589 : ALLOCATE (local_owner(max_local_features), &
590 : local_source_points(max_local_features), &
591 : local_static(nstatic_per_point*max_local_features), &
592 : local_feature_counts_tmp(nflat_local), feature_counts(nproc), &
593 : feature_displs(nproc), point_counts(nproc), point_displs(nproc), &
594 : static_counts(nproc), static_displs(nproc), atom_coords_pbc(3, natom), &
595 1092 : atom_image_coords(3, natom), distances(natom), partition_weights(natom))
596 0 : ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
597 : bo(1, 2):bo(2, 2), &
598 260 : bo(1, 3):bo(2, 3)))
599 1208910 : cached_layout%feature_index = 0
600 52 : local_static = 0.0_dp
601 52 : local_feature_counts_tmp = 0
602 184 : DO iatom = 1, natom
603 184 : atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
604 : END DO
605 :
606 52 : CALL timeset("skala_gpw_layout_local", phase_handle)
607 52 : local_row = 0
608 52 : nfeature_local = 0
609 1680 : DO k = bo(1, 3), bo(2, 3)
610 60236 : DO j = bo(1, 2), bo(2, 2)
611 1208858 : DO i = bo(1, 1), bo(2, 1)
612 1148674 : local_row = local_row + 1
613 4594696 : grid_point = grid_coordinate(pw_grid, [i, j, k])
614 1148674 : base_weight = pw_grid%dvol
615 1148674 : IF (PRESENT(weights)) THEN
616 1148674 : IF (ASSOCIATED(weights)) base_weight = base_weight*weights%array(i, j, k)
617 : END IF
618 1148674 : cached_layout%feature_index(i, j, k) = local_row
619 :
620 1207230 : IF (my_atom_partition == skala_gpw_atom_partition_hard) THEN
621 1107202 : owner = nearest_atom(grid_point, atom_coords_pbc, cell)
622 4428808 : owner_coord = atom_coords_pbc(:, owner)
623 1107202 : nfeature_local = nfeature_local + 1
624 1107202 : local_feature_counts_tmp(local_row) = 1
625 1107202 : local_owner(nfeature_local) = owner
626 1107202 : local_source_points(nfeature_local) = local_row
627 1107202 : static_base = nstatic_per_point*(nfeature_local - 1)
628 : local_static(static_base + 1:static_base + 3) = &
629 1107202 : nearest_image_coordinate(owner_coord, grid_point, cell)
630 1107202 : local_static(static_base + 4) = base_weight
631 : ELSE
632 : CALL smooth_atom_partition(grid_point, atom_coords_pbc, cell, &
633 41472 : partition_weights, atom_image_coords, distances)
634 124416 : included_sum = SUM(partition_weights, MASK=partition_weights > smooth_partition_eps)
635 41472 : IF (included_sum <= 0.0_dp) THEN
636 0 : owner = nearest_atom(grid_point, atom_coords_pbc, cell)
637 0 : partition_weights = 0.0_dp
638 0 : partition_weights(owner) = 1.0_dp
639 0 : included_sum = 1.0_dp
640 : END IF
641 124416 : DO iatom = 1, natom
642 82944 : IF (partition_weights(iatom) <= smooth_partition_eps) CYCLE
643 82734 : partition_weight = partition_weights(iatom)/included_sum
644 82734 : nfeature_local = nfeature_local + 1
645 : local_feature_counts_tmp(local_row) = &
646 82734 : local_feature_counts_tmp(local_row) + 1
647 82734 : local_owner(nfeature_local) = iatom
648 82734 : local_source_points(nfeature_local) = local_row
649 82734 : static_base = nstatic_per_point*(nfeature_local - 1)
650 : local_static(static_base + 1:static_base + 3) = &
651 330936 : atom_image_coords(:, iatom)
652 124416 : local_static(static_base + 4) = base_weight*partition_weight
653 : END DO
654 : END IF
655 : END DO
656 : END DO
657 : END DO
658 52 : CALL timestop(phase_handle)
659 :
660 : ! SKALA groups all grid points by atom. This ordering is static while the
661 : ! grid, cell, atom positions, and optional integration weights are unchanged.
662 52 : CALL timeset("skala_gpw_layout_gather", phase_handle)
663 52 : CALL pw_grid%para%group%allgather(nflat_local, point_counts)
664 52 : CALL counts_to_displs(point_counts, point_displs)
665 156 : npoint = SUM(point_counts)
666 52 : CALL pw_grid%para%group%allgather(nfeature_local, feature_counts)
667 52 : CALL counts_to_displs(feature_counts, feature_displs)
668 156 : DO pe = 1, nproc
669 104 : static_counts(pe) = nstatic_per_point*feature_counts(pe)
670 156 : static_displs(pe) = nstatic_per_point*feature_displs(pe)
671 : END DO
672 156 : nflat = SUM(feature_counts)
673 : ALLOCATE (global_owner(nflat), global_source_points(nflat), &
674 416 : global_static(nstatic_per_point*nflat), local_source_global(nfeature_local))
675 1189988 : DO feature_local = 1, nfeature_local
676 1189988 : local_source_global(feature_local) = point_displs(pe_index) + local_source_points(feature_local)
677 : END DO
678 : CALL pw_grid%para%group%allgatherv(local_owner(1:nfeature_local), global_owner, feature_counts, &
679 52 : feature_displs)
680 : CALL pw_grid%para%group%allgatherv(local_source_global, global_source_points, feature_counts, &
681 52 : feature_displs)
682 : CALL pw_grid%para%group%allgatherv(local_static(1:nstatic_per_point*nfeature_local), &
683 : global_static, static_counts, &
684 52 : static_displs)
685 52 : CALL timestop(phase_handle)
686 :
687 0 : ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
688 0 : cached_layout%chunk_feature_displs(nproc), &
689 0 : cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
690 0 : cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
691 0 : cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
692 0 : cached_layout%route_grad_return_recv_counts(nproc), &
693 0 : cached_layout%route_grad_return_recv_displs(nproc), &
694 0 : cached_layout%route_grad_return_send_counts(nproc), &
695 0 : cached_layout%route_grad_return_send_displs(nproc), &
696 0 : cached_layout%route_point_recv_counts(nproc), &
697 0 : cached_layout%route_point_recv_displs(nproc), &
698 0 : cached_layout%route_point_send_counts(nproc), &
699 0 : cached_layout%route_point_send_displs(nproc), &
700 0 : cached_layout%feature_source_points(nflat), &
701 0 : cached_layout%global_to_feature(npoint), cached_layout%atomic_grid_sizes(natom), &
702 0 : cached_layout%local_feature_counts(nflat_local), &
703 0 : cached_layout%local_feature_offsets(nflat_local + 1), &
704 0 : cached_layout%local_feature_rows(nfeature_local), &
705 0 : cached_layout%local_feature_points(nfeature_local), &
706 0 : cached_layout%local_feature_indices(nfeature_local), atom_offset(natom + 1), &
707 : atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
708 1820 : cursor(nflat_local))
709 156 : cached_layout%feature_counts(:) = feature_counts
710 156 : cached_layout%feature_displs(:) = feature_displs
711 156 : cached_layout%dynamic_counts(:) = ndynamic_per_point*point_counts
712 156 : cached_layout%dynamic_displs(:) = ndynamic_per_point*point_displs
713 184 : cached_layout%atomic_grid_sizes = 0_int_8
714 2297400 : cached_layout%global_to_feature = 0
715 1148726 : cached_layout%local_feature_counts(:) = local_feature_counts_tmp
716 52 : cached_layout%local_feature_offsets(1) = 1
717 1148726 : DO local_row = 1, nflat_local
718 : cached_layout%local_feature_offsets(local_row + 1) = &
719 : cached_layout%local_feature_offsets(local_row) + &
720 1148726 : cached_layout%local_feature_counts(local_row)
721 : END DO
722 1148726 : cursor(:) = cached_layout%local_feature_offsets(1:nflat_local)
723 :
724 52 : CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
725 2379924 : DO ipt = 1, nflat
726 : cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
727 2379924 : cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
728 : END DO
729 52 : atom_offset(1) = 1
730 184 : DO iatom = 1, natom
731 184 : atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
732 : END DO
733 184 : DO iatom = 1, natom
734 184 : atom_position(iatom) = atom_offset(iatom)
735 : END DO
736 184 : max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
737 : CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
738 : chunk_atom_begin, chunk_atom_end, &
739 : cached_layout%chunk_feature_counts, &
740 52 : cached_layout%chunk_feature_displs)
741 156 : cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
742 156 : cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
743 52 : cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
744 52 : cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
745 52 : cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
746 52 : cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
747 : cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
748 52 : cached_layout%chunk_atom_begin + 1
749 :
750 0 : ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
751 0 : cached_layout%atomic_grid_weights(nflat), &
752 0 : cached_layout%coarse_0_atomic_coords(3, natom), &
753 0 : cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
754 468 : cached_layout%atom_coords(3, natom))
755 9519540 : cached_layout%grid_coords = 0.0_dp
756 2379924 : cached_layout%grid_weights = 0.0_dp
757 2379924 : cached_layout%atomic_grid_weights = 0.0_dp
758 980456 : cached_layout%atomic_grid_size_bound_shape = 0_int_8
759 :
760 184 : DO iatom = 1, natom
761 528 : cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
762 580 : cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
763 : END DO
764 :
765 2379924 : DO ipt = 1, nflat
766 2379872 : owner = global_owner(ipt)
767 2379872 : row = atom_position(owner)
768 2379872 : atom_position(owner) = atom_position(owner) + 1
769 2379872 : source_global = global_source_points(ipt)
770 2379872 : cached_layout%feature_source_points(row) = source_global
771 2379872 : IF (cached_layout%global_to_feature(source_global) == 0) &
772 2297348 : cached_layout%global_to_feature(source_global) = row
773 2379872 : static_base = nstatic_per_point*(ipt - 1)
774 9519488 : cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
775 2379872 : cached_layout%grid_weights(row) = global_static(static_base + 4)
776 2379872 : cached_layout%atomic_grid_weights(row) = cached_layout%grid_weights(row)
777 2379872 : source_local = source_global - point_displs(pe_index)
778 2379924 : IF (source_local >= 1 .AND. source_local <= nflat_local) THEN
779 1189936 : feature_local = cursor(source_local)
780 1189936 : cursor(source_local) = cursor(source_local) + 1
781 1189936 : cached_layout%local_feature_rows(feature_local) = row
782 1189936 : cached_layout%local_feature_points(feature_local) = source_local
783 : END IF
784 : END DO
785 :
786 2297400 : CPASSERT(ALL(cached_layout%global_to_feature > 0))
787 1189988 : CPASSERT(ALL(cached_layout%local_feature_rows > 0))
788 1189988 : CPASSERT(ALL(cached_layout%local_feature_points > 0))
789 1680 : DO k = bo(1, 3), bo(2, 3)
790 60236 : DO j = bo(1, 2), bo(2, 2)
791 1208858 : DO i = bo(1, 1), bo(2, 1)
792 1148674 : local_row = cached_layout%feature_index(i, j, k)
793 : cached_layout%feature_index(i, j, k) = &
794 1207230 : cached_layout%local_feature_rows(cached_layout%local_feature_offsets(local_row))
795 : END DO
796 : END DO
797 : END DO
798 1189988 : DO feature_local = 1, nfeature_local
799 : cached_layout%local_feature_indices(feature_local) = &
800 1189988 : INT(cached_layout%local_feature_rows(feature_local) - 1, KIND=int_8)
801 : END DO
802 52 : CALL timestop(phase_handle)
803 52 : CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
804 : CALL build_atom_chunk_routes(cached_layout, cached_layout%local_feature_rows, &
805 52 : pw_grid%para%group)
806 52 : CALL build_atom_chunk_layout(cached_layout)
807 52 : CALL timestop(phase_handle)
808 :
809 52 : cached_layout%natom = natom
810 52 : cached_layout%nflat = nflat
811 52 : cached_layout%nflat_local = nflat_local
812 52 : cached_layout%npoint = npoint
813 52 : cached_layout%nproc = nproc
814 52 : cached_layout%atom_partition = my_atom_partition
815 520 : cached_layout%bo = bo
816 520 : cached_layout%bounds = pw_grid%bounds
817 208 : cached_layout%npts = pw_grid%npts
818 52 : cached_layout%dvol = pw_grid%dvol
819 676 : cached_layout%dh = pw_grid%dh
820 676 : cached_layout%cell_hmat = cell%hmat
821 52 : cached_layout%weight_sum = weight_sum
822 52 : cached_layout%weight_sumsq = weight_sumsq
823 52 : cached_layout%has_weights = has_weights
824 52 : CALL timeset("skala_gpw_layout_tensors", phase_handle)
825 52 : CALL build_static_layout_tensors(cached_layout)
826 52 : CALL timestop(phase_handle)
827 52 : cached_layout%active = .TRUE.
828 :
829 0 : DEALLOCATE (atom_coords_pbc, atom_image_coords, atom_offset, atom_position, &
830 0 : chunk_atom_begin, chunk_atom_end, cursor, feature_counts, feature_displs, &
831 0 : global_owner, global_source_points, global_static, local_feature_counts_tmp, &
832 0 : distances, local_owner, local_source_global, local_source_points, &
833 0 : local_static, partition_weights, point_counts, point_displs, static_counts, &
834 52 : static_displs)
835 :
836 260 : END SUBROUTINE rebuild_layout_cache
837 :
! **************************************************************************************************
!> \brief Wrap the static layout arrays of a cache in Torch tensors and fill the static input
!>        dictionaries from them.
!> \param cache layout cache; its static arrays are read, its tensor handles, dictionaries and
!>        the two "tensors active" flags are set
!> \note The chunk tensors and chunk dictionary are only built when chunk_feature_count > 0.
! **************************************************************************************************
   SUBROUTINE build_static_layout_tensors(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      ! Building twice would overwrite live tensor handles.
      CPASSERT(.NOT. cache%static_tensors_active)

      ! --- tensors for the full static layout -----------------------------------------------------
      CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
      CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
      CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)

      CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
                                   cache%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)

      CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
      CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)

      ! --- dictionary for the full static layout --------------------------------------------------
      ! The coarse_0_atomic_coords and local_feature_indices tensors are created above but are
      ! not inserted here.
      CALL torch_dict_create(cache%static_inputs)
      CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
      CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
      CALL torch_dict_insert(cache%static_inputs, &
                             "atomic_grid_weights", cache%atomic_grid_weights_t)
      CALL torch_dict_insert(cache%static_inputs, &
                             "atomic_grid_sizes", cache%atomic_grid_sizes_t)
      CALL torch_dict_insert(cache%static_inputs, &
                             "atomic_grid_size_bound_shape", cache%atomic_grid_size_bound_shape_t)
      cache%static_tensors_active = .TRUE.

      ! Without chunk rows there is nothing further to wrap.
      IF (cache%chunk_feature_count <= 0) RETURN

      ! --- tensors for the atom-chunk layout ------------------------------------------------------
      CPASSERT(.NOT. cache%chunk_static_tensors_active)

      CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
      CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
                                   cache%chunk_atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
                                   cache%chunk_atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
                                   cache%chunk_coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
                                   cache%chunk_atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)

      CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
      CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)

      ! --- dictionary for the atom-chunk layout (same keys as the full dictionary) ----------------
      CALL torch_dict_create(cache%chunk_static_inputs)
      CALL torch_dict_insert(cache%chunk_static_inputs, &
                             "grid_coords", cache%chunk_grid_coords_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, &
                             "grid_weights", cache%chunk_grid_weights_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, &
                             "atomic_grid_weights", cache%chunk_atomic_grid_weights_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, &
                             "atomic_grid_sizes", cache%chunk_atomic_grid_sizes_t)
      CALL torch_dict_insert(cache%chunk_static_inputs, &
                             "atomic_grid_size_bound_shape", &
                             cache%chunk_atomic_grid_size_bound_shape_t)
      cache%chunk_static_tensors_active = .TRUE.

   END SUBROUTINE build_static_layout_tensors
910 :
! **************************************************************************************************
!> \brief Allocate the layout arrays of a feature bundle and fill them from the module-level
!>        cached layout.
!> \param features feature bundle receiving copies of the cached index maps, weights, atom grid
!>        sizes and routing tables
!> \param needs_coordinate_array if true, coarse_0_atomic_coords is copied as well
!> \param needs_grid_coordinate_array if true, grid_coords and atomic_grid_weights are copied
!>        as well
! **************************************************************************************************
   SUBROUTINE copy_cached_layout(features, needs_coordinate_array, needs_grid_coordinate_array)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: needs_coordinate_array, &
                                                            needs_grid_coordinate_array

      ! The cache must have been built before anything can be copied out of it.
      CPASSERT(cached_layout%active)

      ASSOCIATE (src => cached_layout)

         ! Scalar bookkeeping.
         features%nflat = src%nflat
         features%nflat_local = src%nflat_local
         features%chunk_feature_count = src%chunk_feature_count
         features%atom_partition = src%atom_partition

         ! Grid-point index map; keeps the lower and upper bounds of the cached array.
         ALLOCATE (features%feature_index( &
                   LBOUND(src%feature_index, 1):UBOUND(src%feature_index, 1), &
                   LBOUND(src%feature_index, 2):UBOUND(src%feature_index, 2), &
                   LBOUND(src%feature_index, 3):UBOUND(src%feature_index, 3)))
         features%feature_index(:, :, :) = src%feature_index

         ALLOCATE (features%grid_weights(src%nflat))
         features%grid_weights(:) = src%grid_weights

         ! Per-local-point feature counts, offsets and rows.
         ALLOCATE (features%local_feature_counts(src%nflat_local), &
                   features%local_feature_offsets(src%nflat_local + 1), &
                   features%local_feature_rows(SIZE(src%local_feature_rows)))
         features%local_feature_counts(:) = src%local_feature_counts
         features%local_feature_offsets(:) = src%local_feature_offsets
         features%local_feature_rows(:) = src%local_feature_rows

         ALLOCATE (features%atomic_grid_sizes(src%natom))
         features%atomic_grid_sizes(:) = src%atomic_grid_sizes

         ! Grid coordinates and atomic grid weights are only copied on request.
         IF (needs_grid_coordinate_array) THEN
            ALLOCATE (features%grid_coords(3, src%nflat))
            ALLOCATE (features%atomic_grid_weights(src%nflat))
            features%grid_coords(:, :) = src%grid_coords
            features%atomic_grid_weights(:) = src%atomic_grid_weights
         END IF

         ! Per-rank count/displacement tables and the send-ordered local rows.
         ALLOCATE (features%chunk_grad_counts(src%nproc), &
                   features%chunk_grad_displs(src%nproc), &
                   features%route_grad_return_recv_counts(src%nproc), &
                   features%route_grad_return_recv_displs(src%nproc), &
                   features%route_grad_return_send_counts(src%nproc), &
                   features%route_grad_return_send_displs(src%nproc), &
                   features%route_point_recv_counts(src%nproc), &
                   features%route_point_recv_displs(src%nproc), &
                   features%route_point_send_counts(src%nproc), &
                   features%route_point_send_displs(src%nproc), &
                   features%route_send_local_rows(SIZE(src%route_send_local_rows)))
         features%chunk_grad_counts(:) = src%chunk_grad_counts
         features%chunk_grad_displs(:) = src%chunk_grad_displs
         features%route_grad_return_recv_counts(:) = src%route_grad_return_recv_counts
         features%route_grad_return_recv_displs(:) = src%route_grad_return_recv_displs
         features%route_grad_return_send_counts(:) = src%route_grad_return_send_counts
         features%route_grad_return_send_displs(:) = src%route_grad_return_send_displs
         features%route_point_recv_counts(:) = src%route_point_recv_counts
         features%route_point_recv_displs(:) = src%route_point_recv_displs
         features%route_point_send_counts(:) = src%route_point_send_counts
         features%route_point_send_displs(:) = src%route_point_send_displs
         features%route_send_local_rows(:) = src%route_send_local_rows

         ! Atom coordinates are only copied on request.
         IF (needs_coordinate_array) THEN
            ALLOCATE (features%coarse_0_atomic_coords(3, src%natom))
            features%coarse_0_atomic_coords(:, :) = src%coarse_0_atomic_coords
         END IF

      END ASSOCIATE

   END SUBROUTINE copy_cached_layout
980 :
! **************************************************************************************************
!> \brief Split the atom-ordered feature rows into contiguous atom chunks, one per rank, such
!>        that the largest chunk is as small as possible.
!> \param atomic_grid_sizes number of feature rows owned by each atom
!> \param atom_offset one-based first row of each atom; element natom + 1 is one past the last row
!> \param nproc number of ranks to distribute the atoms over
!> \param chunk_atom_begin first atom of each rank's chunk (natom + 1 for an empty chunk)
!> \param chunk_atom_end last atom of each rank's chunk (natom for an empty chunk)
!> \param chunk_feature_counts number of feature rows in each rank's chunk
!> \param chunk_feature_displs zero-based row displacement of each rank's chunk
!> \note A bisection on the per-chunk row limit (checked with atom_chunks_fit_limit) finds the
!>       smallest feasible limit; the chunks are then filled greedily while every remaining
!>       rank among the first MIN(nproc, natom) is guaranteed at least one atom.
! **************************************************************************************************
   SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
                                chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
      INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
      INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
      INTEGER, INTENT(IN)                                :: nproc
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
                                                            chunk_feature_counts, &
                                                            chunk_feature_displs

      INTEGER :: average_ceiling, best_limit, chunk_count, displ, end_atom, lower_limit, &
                 max_end_atom, midpoint, natom, next_atom, next_size, pe, ranks_left, &
                 target_chunks, total_count, upper_limit

      natom = SIZE(atomic_grid_sizes)
      ! Defaults describe an empty chunk (begin > end, zero rows) for ranks that get no atoms.
      chunk_atom_begin = natom + 1
      chunk_atom_end = natom
      chunk_feature_counts = 0
      chunk_feature_displs = 0
      IF (natom == 0) RETURN

      target_chunks = MIN(nproc, natom)
      total_count = atom_offset(natom + 1) - 1

      ! No limit below the largest single atom or below the average chunk size can be feasible.
      ! The ceiling division avoids forming total_count + target_chunks - 1, which may overflow.
      average_ceiling = total_count/target_chunks
      IF (MOD(total_count, target_chunks) /= 0) average_ceiling = average_ceiling + 1
      lower_limit = MAX(MAXVAL(INT(atomic_grid_sizes)), average_ceiling)
      upper_limit = total_count
      best_limit = upper_limit

      ! Bisection for the smallest limit that still fits into target_chunks chunks.
      DO WHILE (lower_limit <= upper_limit)
         ! Overflow-safe midpoint: lower_limit + upper_limit may exceed HUGE(0).
         midpoint = lower_limit + (upper_limit - lower_limit)/2
         IF (atom_chunks_fit_limit(atomic_grid_sizes, midpoint, target_chunks)) THEN
            best_limit = midpoint
            upper_limit = midpoint - 1
         ELSE
            lower_limit = midpoint + 1
         END IF
      END DO

      displ = 0
      next_atom = 1
      DO pe = 1, nproc
         chunk_feature_displs(pe) = displ
         IF (pe > target_chunks .OR. next_atom > natom) CYCLE

         ranks_left = target_chunks - pe + 1
         chunk_atom_begin(pe) = next_atom
         ! Stop early enough that each of the remaining ranks still receives one atom.
         max_end_atom = natom - ranks_left + 1
         end_atom = next_atom
         chunk_count = INT(atomic_grid_sizes(end_atom))
         DO WHILE (end_atom < max_end_atom)
            next_size = INT(atomic_grid_sizes(end_atom + 1))
            ! Same test as chunk_count + next_size > best_limit, without the overflowing sum.
            IF (next_size > best_limit - chunk_count) EXIT
            end_atom = end_atom + 1
            chunk_count = chunk_count + next_size
         END DO

         chunk_atom_end(pe) = end_atom
         chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
         displ = displ + chunk_feature_counts(pe)
         next_atom = end_atom + 1
      END DO

      ! Every feature row must have been assigned to exactly one chunk.
      CPASSERT(displ == atom_offset(natom + 1) - 1)

   END SUBROUTINE build_atom_chunks
1053 :
! **************************************************************************************************
!> \brief Check if contiguous atom chunks can stay below a feature-count limit.
!> \param atomic_grid_sizes number of feature rows owned by each atom, in atom order
!> \param limit largest number of feature rows allowed in one chunk
!> \param nchunks number of chunks available
!> \return .TRUE. if packing the atoms greedily, in order, into chunks of at most limit rows
!>         needs no more than nchunks chunks; .FALSE. if it needs more or if a single atom
!>         already exceeds limit. An empty atom list always fits.
!> \note Sizes are compared and accumulated in int_8, so neither a grid size above HUGE(0) nor
!>       a chunk sum close to HUGE(0) can wrap around.
! **************************************************************************************************
   PURE FUNCTION atom_chunks_fit_limit(atomic_grid_sizes, limit, nchunks) RESULT(fits)
      INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
      INTEGER, INTENT(IN)                                :: limit, nchunks
      LOGICAL                                            :: fits

      INTEGER                                            :: iatom, used_chunks
      INTEGER(KIND=int_8)                                :: atom_count, chunk_count, limit_8

      fits = .FALSE.
      IF (SIZE(atomic_grid_sizes) == 0) THEN
         fits = .TRUE.
         RETURN
      END IF

      limit_8 = INT(limit, KIND=int_8)
      used_chunks = 1
      chunk_count = 0_int_8
      DO iatom = 1, SIZE(atomic_grid_sizes)
         atom_count = atomic_grid_sizes(iatom)
         ! An atom cannot be split, so one oversized atom makes the limit infeasible.
         IF (atom_count > limit_8) RETURN
         IF (chunk_count + atom_count > limit_8) THEN
            ! The current chunk is full: open the next one with this atom.
            used_chunks = used_chunks + 1
            chunk_count = atom_count
         ELSE
            chunk_count = chunk_count + atom_count
         END IF
      END DO
      fits = used_chunks <= nchunks

   END FUNCTION atom_chunks_fit_limit
1090 :
! **************************************************************************************************
!> \brief Find the one-based rank whose chunk contains an atom-ordered feature row.
!> \param row one-based feature row to look up
!> \param counts number of rows in each rank's chunk
!> \param displs zero-based row displacement of each rank's chunk
!> \return first rank pe with displs(pe) < row <= displs(pe) + counts(pe), or 0 if there is none
! **************************************************************************************************
   FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
      INTEGER, INTENT(IN)                                :: row
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
      INTEGER                                            :: owner

      INTEGER                                            :: irank

      ! Zero signals "no owner"; it survives if the scan below finds nothing.
      owner = 0
      DO irank = 1, SIZE(counts)
         ! Skip ranks whose chunk ends before, or starts after, the requested row.
         IF (row <= displs(irank)) CYCLE
         IF (row > displs(irank) + counts(irank)) CYCLE
         owner = irank
         EXIT
      END DO

   END FUNCTION feature_row_chunk_owner
1114 :
! **************************************************************************************************
!> \brief Build zero-based displacement arrays from per-rank counts.
!> \param counts number of elements per rank
!> \param displs exclusive prefix sum of counts: displs(1) = 0 and
!>        displs(pe) = counts(1) + ... + counts(pe - 1)
!> \note A zero-size displs is left untouched instead of being written out of bounds.
! **************************************************************************************************
   SUBROUTINE counts_to_displs(counts, displs)
      INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs

      INTEGER                                            :: pe

      ! displs(1) does not exist for an empty array.
      IF (SIZE(displs) == 0) RETURN

      displs(1) = 0
      DO pe = 2, SIZE(counts)
         displs(pe) = displs(pe - 1) + counts(pe - 1)
      END DO

   END SUBROUTINE counts_to_displs
1132 :
! **************************************************************************************************
!> \brief Precompute all-to-all routing between local grid rows and atom chunks.
!> \param cache layout cache; chunk_feature_counts/displs, chunk_feature_begin/count and
!>        local_feature_points are read, the route_* tables, route_local_dest,
!>        route_send_local_rows and chunk_return_positions are allocated and filled
!> \param local_to_global atom-ordered feature row of every local feature
!> \param group communicator used for the two all-to-all exchanges
!> \note The send and receive totals are asserted before the variable-count exchange, because
!>       recv_meta is sized by chunk_feature_count and a mismatch would overrun it.
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
      INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global

      CLASS(mp_comm_type), INTENT(IN)                    :: group

      INTEGER                                            :: chunk_row, dest, local_feature, &
                                                            nlocal, point_pos, row
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_meta, send_meta

      nlocal = SIZE(local_to_global)
      ALLOCATE (cache%route_local_dest(nlocal), &
                cache%route_send_local_rows(nlocal), &
                cache%chunk_return_positions(cache%chunk_feature_count), &
                cursor(SIZE(cache%route_point_send_counts)), &
                send_meta(nlocal))
      cache%route_point_send_counts = 0
      ! Zero marks "not yet filled"; the closing assertions rely on it.
      cache%route_send_local_rows = 0
      cache%chunk_return_positions = 0

      ! Pass 1: destination rank of every local feature and the per-rank send counts.
      DO local_feature = 1, nlocal
         dest = feature_row_chunk_owner(local_to_global(local_feature), &
                                        cache%chunk_feature_counts, &
                                        cache%chunk_feature_displs)
         CPASSERT(dest > 0)
         cache%route_local_dest(local_feature) = dest
         cache%route_point_send_counts(dest) = cache%route_point_send_counts(dest) + 1
      END DO
      CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)
      CPASSERT(SUM(cache%route_point_send_counts) == nlocal)

      ! Pass 2: one walk over the send slots fills both the local row of each slot and the
      ! global feature row announced to the destination rank.
      cursor(:) = cache%route_point_send_displs + 1
      DO local_feature = 1, nlocal
         dest = cache%route_local_dest(local_feature)
         point_pos = cursor(dest)
         cursor(dest) = cursor(dest) + 1
         cache%route_send_local_rows(point_pos) = cache%local_feature_points(local_feature)
         send_meta(point_pos) = local_to_global(local_feature)
      END DO

      ! Exchange the counts so that every rank knows how many rows it receives from whom.
      CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
      CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
      CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)

      ALLOCATE (recv_meta(cache%chunk_feature_count))
      CALL group%alltoall(send_meta, cache%route_point_send_counts, &
                          cache%route_point_send_displs, recv_meta, &
                          cache%route_point_recv_counts, &
                          cache%route_point_recv_displs)

      ! Record at which receive position each row of this rank's chunk arrives.
      DO point_pos = 1, cache%chunk_feature_count
         row = recv_meta(point_pos)
         chunk_row = row - cache%chunk_feature_begin + 1
         CPASSERT(chunk_row >= 1 .AND. chunk_row <= cache%chunk_feature_count)
         cache%chunk_return_positions(chunk_row) = point_pos
      END DO

      ! The gradient return route is the reverse of the point route, scaled by the number of
      ! gradient values per point.
      cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
      cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
      cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
      cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs

      ! Every send slot and every chunk row must have been filled.
      CPASSERT(ALL(cache%route_send_local_rows > 0))
      CPASSERT(ALL(cache%chunk_return_positions > 0))

      DEALLOCATE (cursor, recv_meta, send_meta)

   END SUBROUTINE build_atom_chunk_routes
1206 :
! **************************************************************************************************
!> \brief Copy this rank's slice of the atom-ordered static layout into the chunk arrays.
!> \param cache layout cache; the full grid/atom arrays and the chunk ranges are read, the
!>        chunk_* arrays are allocated and filled
!> \note Returns without allocating anything when the chunk holds no feature rows or no atoms.
! **************************************************************************************************
   SUBROUTINE build_atom_chunk_layout(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      INTEGER                                            :: first_row, ipoint, last_row, &
                                                            largest_atom_grid

      IF (cache%chunk_feature_count <= 0) RETURN
      IF (cache%chunk_natom <= 0) RETURN

      ASSOCIATE (nrow => cache%chunk_feature_count, natom_chunk => cache%chunk_natom, &
                 first_atom => cache%chunk_atom_begin, last_atom => cache%chunk_atom_end)

         ! Row range of this chunk inside the full atom-ordered arrays.
         first_row = cache%chunk_feature_begin
         last_row = first_row + nrow - 1

         ALLOCATE (cache%chunk_grid_coords(3, nrow), &
                   cache%chunk_grid_weights(nrow), &
                   cache%chunk_atomic_grid_weights(nrow), &
                   cache%chunk_atomic_grid_sizes(natom_chunk), &
                   cache%chunk_coarse_0_atomic_coords(3, natom_chunk), &
                   cache%chunk_feature_indices(nrow))

         ! Per-row quantities are sliced by row range, per-atom quantities by atom range.
         cache%chunk_grid_coords(:, :) = cache%grid_coords(:, first_row:last_row)
         cache%chunk_grid_weights(:) = cache%grid_weights(first_row:last_row)
         cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(first_row:last_row)
         cache%chunk_atomic_grid_sizes(:) = cache%atomic_grid_sizes(first_atom:last_atom)
         cache%chunk_coarse_0_atomic_coords(:, :) = &
            cache%coarse_0_atomic_coords(:, first_atom:last_atom)

         ! Zero-extent first dimension: the array holds no data, only its second extent (the
         ! largest atomic grid in the chunk) is meaningful.
         largest_atom_grid = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
         ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, largest_atom_grid))
         cache%chunk_atomic_grid_size_bound_shape = 0_int_8

         ! Zero-based running index 0 .. nrow - 1.
         DO ipoint = 0, nrow - 1
            cache%chunk_feature_indices(ipoint + 1) = INT(ipoint, KIND=int_8)
         END DO

      END ASSOCIATE

   END SUBROUTINE build_atom_chunk_layout
1242 :
! **************************************************************************************************
!> \brief Redistribute local dynamic feature rows to the ranks that own their atom chunks.
!>
!>        Every local feature row is packed into the send-buffer slot reserved for its destination
!>        rank, the buffers are exchanged with one all-to-all call, and the received rows are
!>        unpacked into features%chunk_density, features%chunk_grad and features%chunk_kin in the
!>        order given by cached_layout%chunk_return_positions.
!> \param features bundle whose chunk_density, chunk_grad, chunk_kin and chunk_return_positions
!>                 components are allocated and filled here
!> \param local_dynamic flat buffer with a fixed number of consecutive values per local row
!>                      (nrks_dynamic_per_point if collapsed, ndynamic_per_point otherwise)
!> \param group communicator on which the all-to-all exchange is performed
!> \param collapse_spin_dynamics if true, a single spin channel is routed (density, three gradient
!>                               components, kinetic term); otherwise two channels are routed
! **************************************************************************************************
   SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group, collapse_spin_dynamics)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic

      CLASS(mp_comm_type), INTENT(IN)                    :: group
      LOGICAL, INTENT(IN)                                :: collapse_spin_dynamics

      INTEGER                                            :: irank_dest, irow, isend, nchunk, nrecv, &
                                                            nsend, nspin_routed, nval, recv_off, &
                                                            send_off, slot, src_off
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: next_slot, rcounts, rdispls, scounts, &
                                                            sdispls
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_buf, send_buf

      nchunk = cached_layout%chunk_feature_count
      CPASSERT(nchunk > 0)
      nsend = SIZE(cached_layout%route_local_dest)
      nrecv = SUM(cached_layout%route_point_recv_counts)
      CPASSERT(nsend == SIZE(cached_layout%local_feature_rows))
      CPASSERT(nrecv == nchunk)

      ! Number of values per routed point and number of spin channels they unpack into.
      IF (collapse_spin_dynamics) THEN
         nval = nrks_dynamic_per_point
         nspin_routed = 1
      ELSE
         nval = ndynamic_per_point
         nspin_routed = 2
      END IF

      ALLOCATE (send_buf(nval*nsend), recv_buf(nval*nrecv))
      ALLOCATE (next_slot(cached_layout%nproc), scounts(cached_layout%nproc), &
                sdispls(cached_layout%nproc), rcounts(cached_layout%nproc), &
                rdispls(cached_layout%nproc))

      ! The cached counts and displacements are per point; scale them to buffer elements.
      scounts(:) = nval*cached_layout%route_point_send_counts
      sdispls(:) = nval*cached_layout%route_point_send_displs
      rcounts(:) = nval*cached_layout%route_point_recv_counts
      rdispls(:) = nval*cached_layout%route_point_recv_displs

      ! next_slot(r) is the 1-based point slot in send_buf that the next row bound for r will use.
      next_slot(:) = cached_layout%route_point_send_displs + 1
      DO isend = 1, nsend
         irank_dest = cached_layout%route_local_dest(isend)
         slot = next_slot(irank_dest)
         next_slot(irank_dest) = slot + 1
         send_off = nval*(slot - 1)
         src_off = nval*(cached_layout%local_feature_points(isend) - 1)
         send_buf(send_off + 1:send_off + nval) = local_dynamic(src_off + 1:src_off + nval)
      END DO

      CALL group%alltoall(send_buf, scounts, sdispls, recv_buf, rcounts, rdispls)

      ALLOCATE (features%chunk_density(nchunk, nspin_routed), &
                features%chunk_grad(nchunk, 3, nspin_routed), &
                features%chunk_kin(nchunk, nspin_routed), &
                features%chunk_return_positions(nchunk))
      IF (collapse_spin_dynamics) features%uses_collapsed_rks_dynamic = .TRUE.
      features%chunk_return_positions(:) = cached_layout%chunk_return_positions

      ! Unpack: chunk row irow takes the values stored at received point slot
      ! chunk_return_positions(irow).
      DO irow = 1, nchunk
         slot = cached_layout%chunk_return_positions(irow)
         CPASSERT(slot >= 1 .AND. slot <= nchunk)
         recv_off = nval*(slot - 1)
         IF (collapse_spin_dynamics) THEN
            ! Packed layout: density, grad(1:3), kin
            features%chunk_density(irow, 1) = recv_buf(recv_off + 1)
            features%chunk_grad(irow, :, 1) = recv_buf(recv_off + 2:recv_off + 4)
            features%chunk_kin(irow, 1) = recv_buf(recv_off + 5)
         ELSE
            ! Packed layout: density(1:2), grad(1:3, 1), grad(1:3, 2), kin(1:2)
            features%chunk_density(irow, :) = recv_buf(recv_off + 1:recv_off + 2)
            features%chunk_grad(irow, :, 1) = recv_buf(recv_off + 3:recv_off + 5)
            features%chunk_grad(irow, :, 2) = recv_buf(recv_off + 6:recv_off + 8)
            features%chunk_kin(irow, :) = recv_buf(recv_off + 9:recv_off + 10)
         END IF
      END DO
      CPASSERT(ALL(features%chunk_return_positions > 0))

      DEALLOCATE (send_buf, recv_buf, next_slot, scounts, sdispls, rcounts, rdispls)

   END SUBROUTINE route_atom_chunk_dynamics
1337 :
! **************************************************************************************************
!> \brief Copy this rank's atom-chunk rows out of the full dynamic feature arrays.
!>
!>        The chunk covers cached_layout%chunk_feature_count consecutive rows starting at
!>        cached_layout%chunk_feature_begin; both spin channels are copied.
!> \param features bundle whose density, grad and kin arrays are read and whose chunk_density,
!>                 chunk_grad and chunk_kin arrays are allocated and filled
! **************************************************************************************************
   SUBROUTINE extract_atom_chunk_dynamics(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      INTEGER                                            :: first_row, last_row, nrow

      nrow = cached_layout%chunk_feature_count
      CPASSERT(nrow > 0)
      first_row = cached_layout%chunk_feature_begin
      last_row = first_row + nrow - 1

      ALLOCATE (features%chunk_density(nrow, 2), &
                features%chunk_grad(nrow, 3, 2), &
                features%chunk_kin(nrow, 2))

      features%chunk_density(:, :) = features%density(first_row:last_row, :)
      features%chunk_grad(:, :, :) = features%grad(first_row:last_row, :, :)
      features%chunk_kin(:, :) = features%kin(first_row:last_row, :)

   END SUBROUTINE extract_atom_chunk_dynamics
1358 :
! **************************************************************************************************
!> \brief Summarise optional integration weights by their local sum and sum of squares.
!> \param weights optional pointer to the weight grid; may be absent or unassociated
!> \param has_weights true only if weights is present and associated
!> \param weight_sum sum over weights%array, or zero when there are no weights
!> \param weight_sumsq sum over the squared entries of weights%array, or zero without weights
! **************************************************************************************************
   SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
      TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
      LOGICAL, INTENT(OUT)                               :: has_weights
      REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq

      ! Defaults describe the "no weights" case.
      has_weights = .FALSE.
      weight_sum = 0.0_dp
      weight_sumsq = 0.0_dp

      ! PRESENT must be checked on its own before the pointer may be queried.
      IF (.NOT. PRESENT(weights)) RETURN
      IF (.NOT. ASSOCIATED(weights)) RETURN

      has_weights = .TRUE.
      weight_sum = SUM(weights%array)
      weight_sumsq = SUM(weights%array*weights%array)

   END SUBROUTINE weights_signature
1383 :
! **************************************************************************************************
!> \brief Release every Torch object and array held by a layout cache and reset it to defaults.
!>
!>        Torch dictionaries and tensors are released only when the corresponding *_active flag is
!>        set; allocatable arrays are deallocated if allocated. All scalar members and flags are
!>        reset afterwards, so the routine can be called on a cache in any state.
!> \param cache layout cache to release
! **************************************************************************************************
   SUBROUTINE release_layout_cache(cache)
      TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache

      ! --- Torch objects; the *_active flags are all cleared at the end of the routine ---
      IF (cache%inputs_active) CALL torch_dict_release(cache%inputs)

      IF (cache%chunk_inputs_active) CALL torch_dict_release(cache%chunk_inputs)

      IF (cache%dynamic_tensors_active) THEN
         CALL torch_tensor_release(cache%density_t)
         CALL torch_tensor_release(cache%grad_t)
         CALL torch_tensor_release(cache%kin_t)
      END IF

      IF (cache%chunk_dynamic_tensors_active) THEN
         ! The input views are only released together with the chunk tensors.
         IF (cache%chunk_dynamic_input_views_active) THEN
            CALL torch_tensor_release(cache%chunk_density_input_t)
            CALL torch_tensor_release(cache%chunk_grad_input_t)
            CALL torch_tensor_release(cache%chunk_kin_input_t)
         END IF
         CALL torch_tensor_release(cache%chunk_density_t)
         CALL torch_tensor_release(cache%chunk_grad_t)
         CALL torch_tensor_release(cache%chunk_kin_t)
      END IF

      IF (cache%static_tensors_active) THEN
         CALL torch_tensor_release(cache%grid_coords_t)
         CALL torch_tensor_release(cache%grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_weights_t)
         CALL torch_tensor_release(cache%atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%local_feature_indices_t)
         CALL torch_dict_release(cache%static_inputs)
      END IF

      IF (cache%chunk_static_tensors_active) THEN
         CALL torch_tensor_release(cache%chunk_grid_coords_t)
         CALL torch_tensor_release(cache%chunk_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
         CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
         CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
         CALL torch_tensor_release(cache%chunk_feature_indices_t)
         CALL torch_dict_release(cache%chunk_static_inputs)
      END IF

      ! --- Full-layout arrays ---
      IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
      IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
      IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
      IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
      IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
      IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
      IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%atomic_grid_size_bound_shape)
      IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
      IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
      IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
      IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
      IF (ALLOCATED(cache%feature_source_points)) DEALLOCATE (cache%feature_source_points)
      IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
      IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)

      ! --- Rank-local feature bookkeeping ---
      IF (ALLOCATED(cache%local_feature_counts)) DEALLOCATE (cache%local_feature_counts)
      IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
      IF (ALLOCATED(cache%local_feature_offsets)) DEALLOCATE (cache%local_feature_offsets)
      IF (ALLOCATED(cache%local_feature_points)) DEALLOCATE (cache%local_feature_points)
      IF (ALLOCATED(cache%local_feature_rows)) DEALLOCATE (cache%local_feature_rows)

      ! --- Atom-chunk arrays ---
      IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
      IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
      IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
         DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
      IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
      IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
      IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
         DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
      IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
      IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
      IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
      IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
      IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
      IF (ALLOCATED(cache%chunk_return_positions)) DEALLOCATE (cache%chunk_return_positions)

      ! --- Routing tables ---
      IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
      IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
      IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
      IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
      IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
      IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
      IF (ALLOCATED(cache%route_grad_return_send_counts)) &
         DEALLOCATE (cache%route_grad_return_send_counts)
      IF (ALLOCATED(cache%route_grad_return_send_displs)) &
         DEALLOCATE (cache%route_grad_return_send_displs)
      IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
         DEALLOCATE (cache%route_grad_return_recv_counts)
      IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
         DEALLOCATE (cache%route_grad_return_recv_displs)

      ! --- Sizes and ranges: begin = 1 with end/count = 0 denotes an empty range ---
      cache%natom = 0
      cache%npoint = 0
      cache%nproc = 0
      cache%nflat = 0
      cache%nflat_local = 0
      cache%chunk_natom = 0
      cache%chunk_atom_begin = 1
      cache%chunk_atom_end = 0
      cache%chunk_feature_begin = 1
      cache%chunk_feature_count = 0
      cache%atom_partition = skala_gpw_atom_partition_hard

      ! --- Grid and cell description ---
      cache%bo = 0
      cache%bounds = 0
      cache%npts = 0
      cache%dvol = 0.0_dp
      cache%dh = 0.0_dp
      cache%cell_hmat = 0.0_dp

      ! --- Weight signature ---
      cache%has_weights = .FALSE.
      cache%weight_sum = 0.0_dp
      cache%weight_sumsq = 0.0_dp

      ! --- State flags ---
      cache%active = .FALSE.
      cache%inputs_active = .FALSE.
      cache%static_tensors_active = .FALSE.
      cache%dynamic_tensors_active = .FALSE.
      cache%chunk_inputs_active = .FALSE.
      cache%chunk_inputs_use_collapsed_rks = .FALSE.
      cache%chunk_static_tensors_active = .FALSE.
      cache%chunk_dynamic_tensors_active = .FALSE.
      cache%chunk_dynamic_input_views_active = .FALSE.

   END SUBROUTINE release_layout_cache
1525 :
! **************************************************************************************************
!> \brief Free the Torch handles and backing arrays of a feature bundle and reset its state.
!>
!>        Torch handles are released only for an active bundle, and only those the ownership
!>        flags mark as owned. Allocatable arrays and the trailing scalars are reset regardless
!>        of the active flag, so calling this on an inactive bundle is allowed.
!> \param features feature bundle to release
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_release(features)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features

      IF (features%active) THEN
         IF (features%owns_dynamic_tensors) THEN
            ! The *_input_t handles are only released for the collapsed RKS layout.
            IF (features%uses_collapsed_rks_dynamic) THEN
               CALL torch_tensor_release(features%density_input_t)
               CALL torch_tensor_release(features%grad_input_t)
               CALL torch_tensor_release(features%kin_input_t)
            END IF
            CALL torch_tensor_release(features%density_t)
            CALL torch_tensor_release(features%grad_t)
            CALL torch_tensor_release(features%kin_t)
         END IF

         IF (features%owns_static_tensors) THEN
            CALL torch_tensor_release(features%grid_coords_t)
            CALL torch_tensor_release(features%grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_sizes_t)
            CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
         END IF

         ! Individually owned tensors.
         ! NOTE(review): presumably never set together with owns_static_tensors, otherwise the
         ! same handle would be released twice - confirm against the builders.
         IF (features%owns_grid_coordinate_tensor) THEN
            CALL torch_tensor_release(features%grid_coords_t)
         END IF
         IF (features%owns_weight_tensors) THEN
            CALL torch_tensor_release(features%grid_weights_t)
            CALL torch_tensor_release(features%atomic_grid_weights_t)
         END IF
         IF (features%owns_static_tensors .OR. features%owns_coordinate_tensor) THEN
            CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
         END IF

         IF (features%owns_inputs) THEN
            CALL torch_dict_release(features%inputs)
         END IF

         ! Reset the ownership flags to the values used for a fresh bundle.
         features%active = .FALSE.
         features%owns_dynamic_tensors = .TRUE.
         features%owns_inputs = .TRUE.
         features%owns_static_tensors = .TRUE.
         features%owns_coordinate_tensor = .FALSE.
         features%owns_grid_coordinate_tensor = .FALSE.
         features%owns_weight_tensors = .FALSE.
         ! NOTE(review): uses_atom_chunks is cleared only for an active bundle, whereas the
         ! routing/collapsed flags below are cleared unconditionally - confirm this is intended.
         features%uses_atom_chunks = .FALSE.
      END IF

      ! --- Dynamic feature arrays ---
      IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
      IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
      IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
      IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
      IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
      IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)

      ! --- Static feature arrays ---
      IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
      IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
      IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
      IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
      IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
         DEALLOCATE (features%atomic_grid_size_bound_shape)
      IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)

      ! --- Index and local bookkeeping arrays ---
      IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
      IF (ALLOCATED(features%local_feature_counts)) DEALLOCATE (features%local_feature_counts)
      IF (ALLOCATED(features%local_feature_offsets)) DEALLOCATE (features%local_feature_offsets)
      IF (ALLOCATED(features%local_feature_rows)) DEALLOCATE (features%local_feature_rows)

      ! --- Chunk gather and routing tables ---
      IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
      IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
      IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
      IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)
      IF (ALLOCATED(features%route_point_send_counts)) &
         DEALLOCATE (features%route_point_send_counts)
      IF (ALLOCATED(features%route_point_send_displs)) &
         DEALLOCATE (features%route_point_send_displs)
      IF (ALLOCATED(features%route_point_recv_counts)) &
         DEALLOCATE (features%route_point_recv_counts)
      IF (ALLOCATED(features%route_point_recv_displs)) &
         DEALLOCATE (features%route_point_recv_displs)
      IF (ALLOCATED(features%route_grad_return_send_counts)) &
         DEALLOCATE (features%route_grad_return_send_counts)
      IF (ALLOCATED(features%route_grad_return_send_displs)) &
         DEALLOCATE (features%route_grad_return_send_displs)
      IF (ALLOCATED(features%route_grad_return_recv_counts)) &
         DEALLOCATE (features%route_grad_return_recv_counts)
      IF (ALLOCATED(features%route_grad_return_recv_displs)) &
         DEALLOCATE (features%route_grad_return_recv_displs)

      ! --- Scalars reset for active and inactive bundles alike ---
      features%nflat = 0
      features%nflat_local = 0
      features%chunk_feature_count = 0
      features%atom_partition = skala_gpw_atom_partition_hard
      features%uses_atom_chunk_routing = .FALSE.
      features%uses_collapsed_rks_dynamic = .FALSE.

   END SUBROUTINE skala_gpw_feature_release
1619 :
! **************************************************************************************************
!> \brief Count the atom-contiguous subchunks needed to cover the cached rank-local atom chunk.
!>
!>        Atoms are appended to the current subchunk until adding the next one would exceed
!>        max_rows; a single atom larger than max_rows still forms a subchunk of its own.
!> \param max_rows row limit per subchunk; a non-positive value disables splitting
!> \return number of subchunks, always at least 1 (also when the cached layout is inactive or
!>         holds no chunk atoms)
! **************************************************************************************************
   FUNCTION skala_gpw_atom_subchunk_count(max_rows) RESULT(nsubchunks)
      INTEGER, INTENT(IN)                                :: max_rows
      INTEGER                                            :: nsubchunks

      INTEGER                                            :: atom_rows, iatom, nclosed, pending_rows
      LOGICAL                                            :: splittable

      nsubchunks = 1
      splittable = (max_rows > 0) .AND. cached_layout%active .AND. (cached_layout%chunk_natom > 0)
      IF (.NOT. splittable) RETURN

      nclosed = 0
      pending_rows = 0
      DO iatom = 1, cached_layout%chunk_natom
         atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
         ! Close the open subchunk only if it is non-empty and this atom would overflow it.
         IF (pending_rows > 0) THEN
            IF (pending_rows + atom_rows > max_rows) THEN
               nclosed = nclosed + 1
               pending_rows = 0
            END IF
         END IF
         pending_rows = pending_rows + atom_rows
      END DO
      ! Account for the subchunk that is still open after the last atom.
      IF (pending_rows > 0) nclosed = nclosed + 1

      nsubchunks = MAX(1, nclosed)

   END FUNCTION skala_gpw_atom_subchunk_count
1650 :
! **************************************************************************************************
!> \brief Assemble the feature bundle of one atom-contiguous subchunk of a rank-local atom chunk.
!>
!>        The target bundle is released first; its tensors are then created by
!>        add_subchunk_feature_tensors from the cached chunk tensors and the parent's dynamic
!>        tensors.
!> \param parent active atom-chunk bundle the subchunk is taken from (uses_atom_chunks must be set)
!> \param features bundle to (re)build; any previous content is released
!> \param subchunk_index 1-based index of the requested subchunk
!> \param max_rows row limit per subchunk, forwarded to atom_subchunk_bounds
!> \param requires_grad currently unused (marked as such)
! **************************************************************************************************
   SUBROUTINE skala_gpw_feature_build_atom_subchunk(parent, features, subchunk_index, &
                                                    max_rows, requires_grad)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
      LOGICAL, INTENT(IN)                                :: requires_grad

      INTEGER                                            :: first_atom, first_row, grid_size_bound, &
                                                            last_atom, last_row, natom_sub, &
                                                            nrow_sub

      CALL skala_gpw_feature_release(features)
      CPASSERT(parent%uses_atom_chunks)

      CALL atom_subchunk_bounds(subchunk_index, max_rows, atom_begin=first_atom, &
                                atom_end=last_atom, row_begin=first_row, row_end=last_row)
      natom_sub = last_atom - first_atom + 1
      nrow_sub = last_row - first_row + 1
      CPASSERT(natom_sub > 0)
      CPASSERT(nrow_sub > 0)
      MARK_USED(requires_grad)

      ! Zero-extent array: only its second extent (the largest atomic grid size in this
      ! subchunk) carries information.
      grid_size_bound = MAXVAL(INT(cached_layout%chunk_atomic_grid_sizes(first_atom:last_atom)))
      ALLOCATE (features%atomic_grid_size_bound_shape(0, grid_size_bound))
      features%atomic_grid_size_bound_shape = 0_int_8

      features%nflat = parent%nflat
      features%nflat_local = parent%nflat_local
      features%chunk_feature_count = nrow_sub
      features%grid_weight_sum = SUM(cached_layout%chunk_grid_weights(first_row:last_row))
      features%uses_atom_chunks = .TRUE.
      features%uses_atom_chunk_routing = parent%uses_atom_chunk_routing

      CALL add_subchunk_feature_tensors(parent, features, atom_begin=first_atom, &
                                        atom_count=natom_sub, row_begin=first_row, &
                                        row_count=nrow_sub)
      features%active = .TRUE.

   END SUBROUTINE skala_gpw_feature_build_atom_subchunk
1695 :
! **************************************************************************************************
!> \brief Locate the atom range and row range of one atom-contiguous rank-local subchunk.
!>
!>        Walks the cached chunk atoms with the same splitting rule as
!>        skala_gpw_atom_subchunk_count and aborts if the requested subchunk does not exist.
!> \param subchunk_index 1-based index of the requested subchunk (must be positive)
!> \param max_rows row limit per subchunk (must be positive)
!> \param atom_begin first chunk-local atom of the subchunk
!> \param atom_end last chunk-local atom of the subchunk
!> \param row_begin first chunk-local row of the subchunk
!> \param row_end last chunk-local row of the subchunk
! **************************************************************************************************
   SUBROUTINE atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
                                   row_begin, row_end)
      INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
      INTEGER, INTENT(OUT)                               :: atom_begin, atom_end, row_begin, row_end

      INTEGER                                            :: atom_rows, filled, iatom, isub, &
                                                            last_atom, next_row

      CPASSERT(subchunk_index > 0)
      CPASSERT(max_rows > 0)
      CPASSERT(cached_layout%chunk_natom > 0)

      ! Start with an empty range beginning at the first atom/row.
      atom_begin = 1
      row_begin = 1
      atom_end = 0
      row_end = 0

      isub = 1
      next_row = 1
      filled = 0
      ! Unless the scan stops early, the open subchunk runs up to the last chunk atom.
      last_atom = cached_layout%chunk_natom
      scan_atoms: DO iatom = 1, cached_layout%chunk_natom
         atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
         IF (filled > 0 .AND. filled + atom_rows > max_rows) THEN
            ! Atom iatom opens a new subchunk; the open one ends at the previous atom.
            IF (isub == subchunk_index) THEN
               last_atom = iatom - 1
               EXIT scan_atoms
            END IF
            isub = isub + 1
            atom_begin = iatom
            row_begin = next_row
            filled = 0
         END IF
         filled = filled + atom_rows
         next_row = next_row + atom_rows
      END DO scan_atoms

      IF (isub == subchunk_index) THEN
         atom_end = last_atom
         row_end = next_row - 1
      ELSE
         CPABORT("Requested native SKALA atom subchunk does not exist.")
      END IF

   END SUBROUTINE atom_subchunk_bounds
1750 :
! **************************************************************************************************
!> \brief Create the tensors of a subchunk bundle and collect them in its Torch dictionary.
!>
!>        Static tensors are narrowed from the cached chunk tensors and dynamic tensors from the
!>        parent's tensors. Only the bound-shape tensor is built from an array of the bundle.
!> \param parent active bundle providing density_t, grad_t and kin_t
!> \param features bundle receiving the tensors and the inputs dictionary; its
!>                 atomic_grid_size_bound_shape array must already be allocated
!> \param atom_begin 1-based first chunk-local atom of the subchunk
!> \param atom_count number of atoms in the subchunk
!> \param row_begin 1-based first chunk-local row of the subchunk
!> \param row_count number of rows in the subchunk
! **************************************************************************************************
   SUBROUTINE add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
                                           row_count)
      TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      INTEGER, INTENT(IN)                                :: atom_begin, atom_count, row_begin, &
                                                            row_count

      INTEGER                                            :: atom_offset, row_offset

      CPASSERT(cached_layout%chunk_static_tensors_active)
      CPASSERT(parent%active)
      CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))

      features%owns_coordinate_tensor = .FALSE.
      features%owns_dynamic_tensors = .TRUE.
      features%owns_inputs = .TRUE.
      features%owns_static_tensors = .TRUE.
      features%uses_collapsed_rks_dynamic = parent%uses_collapsed_rks_dynamic

      ! torch_tensor_narrow is called with zero-based start offsets.
      row_offset = row_begin - 1
      atom_offset = atom_begin - 1

      ! Static tensors, narrowed from the cached chunk tensors
      CALL torch_tensor_narrow(cached_layout%chunk_grid_coords_t, 0, row_offset, row_count, &
                               features%grid_coords_t)
      CALL torch_tensor_narrow(cached_layout%chunk_grid_weights_t, 0, row_offset, row_count, &
                               features%grid_weights_t)
      CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_weights_t, 0, row_offset, &
                               row_count, features%atomic_grid_weights_t)
      CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_sizes_t, 0, atom_offset, &
                               atom_count, features%atomic_grid_sizes_t)
      CALL torch_tensor_narrow(cached_layout%chunk_coarse_0_atomic_coords_t, 0, atom_offset, &
                               atom_count, features%coarse_0_atomic_coords_t)
      CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
                                   features%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)

      ! Dynamic tensors, narrowed from the parent along their row dimension
      CALL torch_tensor_narrow(parent%density_t, 1, row_offset, row_count, features%density_t)
      CALL torch_tensor_narrow(parent%grad_t, 2, row_offset, row_count, features%grad_t)
      CALL torch_tensor_narrow(parent%kin_t, 1, row_offset, row_count, features%kin_t)

      ! The dictionary is filled in the same key order as before.
      CALL torch_dict_create(features%inputs)
      CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
      CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                             features%atomic_grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                             features%atomic_grid_sizes_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                             features%atomic_grid_size_bound_shape_t)
      IF (features%uses_collapsed_rks_dynamic) THEN
         ! Collapsed layout: the dictionary receives the expanded *_input_t tensors.
         CALL torch_tensor_expand_dim(features%density_t, 0, 2, features%density_input_t)
         CALL torch_tensor_expand_dim(features%grad_t, 0, 2, features%grad_input_t)
         CALL torch_tensor_expand_dim(features%kin_t, 0, 2, features%kin_input_t)
         CALL torch_dict_insert(features%inputs, "density", features%density_input_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_input_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_input_t)
      ELSE
         CALL torch_dict_insert(features%inputs, "density", features%density_t)
         CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
         CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
      END IF
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)

   END SUBROUTINE add_subchunk_feature_tensors
1822 :
! **************************************************************************************************
!> \brief Wrap the arrays owned by a bundle in Torch tensors and collect them in its dictionary.
!>
!>        Every tensor is created from the corresponding array of the bundle and then passed to
!>        torch_tensor_to_device_leaf; only the dynamic tensors receive the requires_grad flag.
!> \param features bundle whose chunk_density, chunk_grad, chunk_kin and static arrays must all be
!>                 allocated; receives the tensors and the inputs dictionary
!> \param requires_grad flag forwarded to torch_tensor_to_device_leaf for density, grad and kin
! **************************************************************************************************
   SUBROUTINE add_owned_feature_tensors(features, requires_grad)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: requires_grad

      ! All backing arrays have to exist before tensors are created from them.
      CPASSERT(ALLOCATED(features%chunk_density))
      CPASSERT(ALLOCATED(features%chunk_grad))
      CPASSERT(ALLOCATED(features%chunk_kin))
      CPASSERT(ALLOCATED(features%grid_coords))
      CPASSERT(ALLOCATED(features%grid_weights))
      CPASSERT(ALLOCATED(features%atomic_grid_weights))
      CPASSERT(ALLOCATED(features%atomic_grid_sizes))
      CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
      CPASSERT(ALLOCATED(features%coarse_0_atomic_coords))

      ! The bundle owns everything it creates here.
      features%owns_inputs = .TRUE.
      features%owns_static_tensors = .TRUE.
      features%owns_dynamic_tensors = .TRUE.
      features%owns_coordinate_tensor = .FALSE.

      ! Static tensors: never tracked for gradients
      CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
      CALL torch_tensor_to_device_leaf(features%grid_coords_t, .FALSE.)

      CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
      CALL torch_tensor_to_device_leaf(features%grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(features%atomic_grid_weights_t, features%atomic_grid_weights)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .FALSE.)

      CALL torch_tensor_from_array(features%atomic_grid_sizes_t, features%atomic_grid_sizes)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_sizes_t, .FALSE.)

      CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                   features%coarse_0_atomic_coords)
      CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .FALSE.)

      CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
                                   features%atomic_grid_size_bound_shape)
      CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)

      ! Dynamic tensors: gradient tracking as requested by the caller
      CALL torch_tensor_from_array(features%density_t, features%chunk_density)
      CALL torch_tensor_to_device_leaf(features%density_t, requires_grad)

      CALL torch_tensor_from_array(features%grad_t, features%chunk_grad)
      CALL torch_tensor_to_device_leaf(features%grad_t, requires_grad)

      CALL torch_tensor_from_array(features%kin_t, features%chunk_kin)
      CALL torch_tensor_to_device_leaf(features%kin_t, requires_grad)

      ! Collect everything in the input dictionary
      CALL torch_dict_create(features%inputs)
      CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
      CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                             features%atomic_grid_weights_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                             features%atomic_grid_sizes_t)
      CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                             features%atomic_grid_size_bound_shape_t)
      CALL torch_dict_insert(features%inputs, "density", features%density_t)
      CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
      CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
      CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                             features%coarse_0_atomic_coords_t)

   END SUBROUTINE add_owned_feature_tensors
1884 :
! **************************************************************************************************
!> \brief Insert all SKALA feature tensors into the Torch dictionary.
!>        Tensor handles and dictionaries held in the module-level cached_layout are reused
!>        whenever possible; fresh tensors are only built for inputs that are converted with
!>        torch_tensor_to_device_leaf(..., .TRUE.).
!> \param features feature container; its tensor handles, its inputs dictionary and its owns_*
!>        flags are set here
!> \param requires_grad forwarded as requires_grad to the density, grad and kin tensors
!> \param requires_coordinate_grad build a fresh leaf tensor for coarse_0_atomic_coords
!> \param requires_stress_grad build fresh leaf tensors for grid_coords, grid_weights,
!>        atomic_grid_weights and coarse_0_atomic_coords
!> \param use_atom_chunks use the atom-chunk tensors of cached_layout; aborts when combined with
!>        stress, weight or coordinate gradients
!> \param requires_weight_grad optional, defaults to .FALSE.; build fresh leaf tensors for
!>        grid_weights and atomic_grid_weights
!> \note  The owns_* flags record which handles were created here rather than copied from
!>        cached_layout.
! **************************************************************************************************
   SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
                                  requires_stress_grad, use_atom_chunks, requires_weight_grad)
      TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
      LOGICAL, INTENT(IN)                                :: requires_grad, &
                                                            requires_coordinate_grad, &
                                                            requires_stress_grad, use_atom_chunks
      LOGICAL, INTENT(IN), OPTIONAL                      :: requires_weight_grad

      LOGICAL                                            :: my_requires_weight_grad

      my_requires_weight_grad = .FALSE.
      IF (PRESENT(requires_weight_grad)) my_requires_weight_grad = requires_weight_grad

      CPASSERT(cached_layout%static_tensors_active)
      ! Start from "nothing owned"; the branches below raise the flags for tensors built here.
      features%owns_static_tensors = .FALSE.
      features%owns_coordinate_tensor = .FALSE.
      features%owns_grid_coordinate_tensor = .FALSE.
      features%owns_weight_tensors = .FALSE.
      features%owns_dynamic_tensors = .FALSE.
      features%owns_inputs = .TRUE.
      IF (use_atom_chunks) THEN
         IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
            CALL cp_abort(__LOCATION__, &
                          "Native SKALA analytical stress/SMOOTH derivatives are not implemented "// &
                          "with atom-chunk tensors yet.")
         END IF
         CPASSERT(cached_layout%chunk_static_tensors_active)
         features%grid_coords_t = cached_layout%chunk_grid_coords_t
         features%grid_weights_t = cached_layout%chunk_grid_weights_t
         features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
         features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
         features%atomic_grid_size_bound_shape_t = &
            cached_layout%chunk_atomic_grid_size_bound_shape_t
         features%local_feature_indices_t = cached_layout%chunk_feature_indices_t

         ! Drop the cached dictionary if it was assembled for the other dynamic-tensor layout.
         IF (cached_layout%chunk_inputs_active .AND. &
             (cached_layout%chunk_inputs_use_collapsed_rks .NEQV. &
              features%uses_collapsed_rks_dynamic)) THEN
            CALL torch_dict_release(cached_layout%chunk_inputs)
            cached_layout%chunk_inputs_active = .FALSE.
         END IF
         ! The expanded input views are only needed for the collapsed RKS layout.
         IF (.NOT. features%uses_collapsed_rks_dynamic .AND. &
             cached_layout%chunk_dynamic_input_views_active) THEN
            CALL torch_tensor_release(cached_layout%chunk_density_input_t)
            CALL torch_tensor_release(cached_layout%chunk_grad_input_t)
            CALL torch_tensor_release(cached_layout%chunk_kin_input_t)
            cached_layout%chunk_dynamic_input_views_active = .FALSE.
         END IF

         CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
                                            features%chunk_density, requires_grad=requires_grad)
         features%density_t = cached_layout%chunk_density_t
         CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
                                            requires_grad=requires_grad)
         features%grad_t = cached_layout%chunk_grad_t
         CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
                                            requires_grad=requires_grad)
         features%kin_t = cached_layout%chunk_kin_t
         cached_layout%chunk_dynamic_tensors_active = .TRUE.

         IF (features%uses_collapsed_rks_dynamic .AND. &
             .NOT. cached_layout%chunk_dynamic_input_views_active) THEN
            CALL torch_tensor_expand_dim(cached_layout%chunk_density_t, 0, 2, &
                                         cached_layout%chunk_density_input_t)
            CALL torch_tensor_expand_dim(cached_layout%chunk_grad_t, 0, 2, &
                                         cached_layout%chunk_grad_input_t)
            CALL torch_tensor_expand_dim(cached_layout%chunk_kin_t, 0, 2, &
                                         cached_layout%chunk_kin_input_t)
            cached_layout%chunk_dynamic_input_views_active = .TRUE.
         END IF
         IF (features%uses_collapsed_rks_dynamic) THEN
            features%density_input_t = cached_layout%chunk_density_input_t
            features%grad_input_t = cached_layout%chunk_grad_input_t
            features%kin_input_t = cached_layout%chunk_kin_input_t
         END IF

         IF (.NOT. cached_layout%chunk_inputs_active) THEN
            CALL torch_dict_clone(cached_layout%chunk_static_inputs, cached_layout%chunk_inputs)
            IF (features%uses_collapsed_rks_dynamic) THEN
               CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
                                      features%density_input_t)
               CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
                                      features%grad_input_t)
               CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
                                      features%kin_input_t)
            ELSE
               CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
                                      cached_layout%chunk_density_t)
               CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
                                      cached_layout%chunk_grad_t)
               CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
                                      cached_layout%chunk_kin_t)
            END IF
            CALL torch_dict_insert(cached_layout%chunk_inputs, "coarse_0_atomic_coords", &
                                   cached_layout%chunk_coarse_0_atomic_coords_t)
            cached_layout%chunk_inputs_use_collapsed_rks = features%uses_collapsed_rks_dynamic
            cached_layout%chunk_inputs_active = .TRUE.
         END IF
         features%inputs = cached_layout%chunk_inputs
         features%owns_inputs = .FALSE.
         features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
      ELSE
         IF (.NOT. requires_stress_grad .AND. .NOT. my_requires_weight_grad) THEN
            features%grid_coords_t = cached_layout%grid_coords_t
            features%grid_weights_t = cached_layout%grid_weights_t
            features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
         END IF
         features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
         features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
         features%local_feature_indices_t = cached_layout%local_feature_indices_t

         CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
                                            requires_grad=requires_grad)
         features%density_t = cached_layout%density_t
         CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
                                            requires_grad=requires_grad)
         features%grad_t = cached_layout%grad_t
         CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
                                            requires_grad=requires_grad)
         features%kin_t = cached_layout%kin_t
         cached_layout%dynamic_tensors_active = .TRUE.

         IF (requires_coordinate_grad .OR. requires_stress_grad .OR. my_requires_weight_grad) THEN
            ! A gradient with respect to a static input is requested: build a private dictionary.
            IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
               CALL torch_dict_create(features%inputs)
               IF (requires_stress_grad) THEN
                  CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
                  CALL torch_tensor_to_device_leaf(features%grid_coords_t, .TRUE.)
                  CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
                  features%owns_grid_coordinate_tensor = .TRUE.
               ELSE
                  features%grid_coords_t = cached_layout%grid_coords_t
                  CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
               END IF
               CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
               CALL torch_tensor_to_device_leaf(features%grid_weights_t, .TRUE.)
               CALL torch_tensor_from_array(features%atomic_grid_weights_t, &
                                            features%atomic_grid_weights)
               CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .TRUE.)
               CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
               CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
                                      features%atomic_grid_weights_t)
               CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
                                      features%atomic_grid_sizes_t)
               CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
                                      features%atomic_grid_size_bound_shape_t)
               features%owns_weight_tensors = .TRUE.
               IF (.NOT. (requires_coordinate_grad .OR. requires_stress_grad)) THEN
                  ! Weight-gradient-only request: no coordinate leaf is built at the end of this
                  ! routine, so supply the cached (non-owned) atomic coordinates here; otherwise
                  ! the dictionary would lack the "coarse_0_atomic_coords" entry.
                  features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
                  CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                         features%coarse_0_atomic_coords_t)
               END IF
            ELSE
               CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
            END IF
            CALL torch_dict_insert(features%inputs, "density", features%density_t)
            CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
            CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
         ELSE
            ! No gradient with respect to static inputs: share the cached dictionary.
            IF (.NOT. cached_layout%inputs_active) THEN
               CALL torch_dict_clone(cached_layout%static_inputs, cached_layout%inputs)
               CALL torch_dict_insert(cached_layout%inputs, "density", cached_layout%density_t)
               CALL torch_dict_insert(cached_layout%inputs, "grad", cached_layout%grad_t)
               CALL torch_dict_insert(cached_layout%inputs, "kin", cached_layout%kin_t)
               CALL torch_dict_insert(cached_layout%inputs, "coarse_0_atomic_coords", &
                                      cached_layout%coarse_0_atomic_coords_t)
               cached_layout%inputs_active = .TRUE.
            END IF
            features%inputs = cached_layout%inputs
            features%owns_inputs = .FALSE.
            features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
         END IF
      END IF

      IF (requires_coordinate_grad .OR. requires_stress_grad) THEN
         CPASSERT(.NOT. use_atom_chunks)
         CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
                                      features%coarse_0_atomic_coords)
         CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
         CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
                                features%coarse_0_atomic_coords_t)
         features%owns_coordinate_tensor = .TRUE.
      END IF

   END SUBROUTINE add_feature_tensors
2073 :
! **************************************************************************************************
!> \brief Return the Cartesian coordinate of a regular GPW grid point.
!> \param pw_grid grid descriptor providing the lower index bounds and the step matrix dh
!> \param index three-dimensional grid index of the point
!> \return sum of the dh columns scaled by the index offsets from the lower bounds
! **************************************************************************************************
   FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
      TYPE(pw_grid_type), POINTER                        :: pw_grid
      INTEGER, DIMENSION(3), INTENT(IN)                  :: index
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      REAL(KIND=dp), DIMENSION(3)                        :: steps

      ! Number of grid steps from the lower bound along each grid direction
      steps = REAL(index - pw_grid%bounds(1, :), KIND=dp)

      coord = steps(1)*pw_grid%dh(:, 1) + &
              steps(2)*pw_grid%dh(:, 2) + &
              steps(3)*pw_grid%dh(:, 3)

   END FUNCTION grid_coordinate
2093 :
! **************************************************************************************************
!> \brief Build Becke-like smooth atom weights for one native-grid point.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atom coordinates, one column per atom
!> \param cell simulation cell used for the periodic image search
!> \param weights normalised weight of every atom at this point (sums to one)
!> \param atom_image_coords per atom, the image of the grid point nearest to that atom
!> \param distances per atom, distance between the grid point and the nearest atom image
! **************************************************************************************************
   SUBROUTINE smooth_atom_partition(grid_point, atom_coords, cell, weights, atom_image_coords, &
                                    distances)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
      REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT)        :: atom_image_coords
      REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: distances

      INTEGER                                            :: closest, ia, ib, natom
      REAL(KIND=dp)                                      :: cell_share, mu, pair_distance, &
                                                            shortest, weight_sum
      REAL(KIND=dp), DIMENSION(3)                        :: separation
      REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: atom_images

      natom = SIZE(atom_coords, 2)
      CPASSERT(SIZE(weights) == natom)
      CPASSERT(SIZE(atom_image_coords, 1) == 3)
      CPASSERT(SIZE(atom_image_coords, 2) == natom)
      CPASSERT(SIZE(distances) == natom)

      ! Periodic images and point-atom distances
      DO ia = 1, natom
         atom_image_coords(:, ia) = nearest_image_coordinate(atom_coords(:, ia), grid_point, cell)
         atom_images(:, ia) = nearest_atom_image_coordinate(atom_coords(:, ia), grid_point, cell)
         separation = grid_point - atom_images(:, ia)
         distances(ia) = SQRT(SUM(separation**2))
      END DO

      ! Product of pairwise switching functions; coinciding atom images are skipped
      weights = 1.0_dp
      DO ia = 1, natom - 1
         DO ib = ia + 1, natom
            separation = atom_images(:, ia) - atom_images(:, ib)
            pair_distance = SQRT(SUM(separation**2))
            IF (pair_distance <= layout_tol) CYCLE
            mu = (distances(ia) - distances(ib))/pair_distance
            mu = MAX(-1.0_dp, MIN(1.0_dp, mu))
            cell_share = 0.5_dp*(1.0_dp - becke_shape(mu))
            weights(ia) = weights(ia)*cell_share
            weights(ib) = weights(ib)*(1.0_dp - cell_share)
         END DO
      END DO

      weight_sum = SUM(weights)
      IF (weight_sum > 0.0_dp) THEN
         weights = weights/weight_sum
         RETURN
      END IF

      ! All products vanished: hand the point to the closest atom (first one on ties)
      shortest = HUGE(1.0_dp)
      closest = 1
      DO ia = 1, natom
         IF (distances(ia) < shortest) THEN
            shortest = distances(ia)
            closest = ia
         END IF
      END DO
      weights = 0.0_dp
      weights(closest) = 1.0_dp

   END SUBROUTINE smooth_atom_partition
2163 :
! **************************************************************************************************
!> \brief Build smooth atom weights and their atom/cell deformation derivatives.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atom coordinates, one column per atom
!> \param cell simulation cell used for the periodic image search
!> \param weights weights renormalised over the included atoms; zero for excluded atoms
!> \param included true for atoms whose normalised weight exceeds smooth_partition_eps
!> \param dweights_datom (:, jatom, iatom) holds the derivative of weights(iatom) with respect
!>        to the position of atom jatom
!> \param dweights_dstrain (:, :, iatom) holds the derivative of weights(iatom) with respect to
!>        a homogeneous deformation of all point-atom and atom-atom vectors
!> \note  If every weight vanishes the point is assigned to the closest atom and all derivatives
!>        are returned as zero.
! **************************************************************************************************
   SUBROUTINE skala_gpw_smooth_partition_derivatives(grid_point, atom_coords, cell, &
                                                     weights, included, dweights_datom, &
                                                     dweights_dstrain)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
      LOGICAL, DIMENSION(:), INTENT(OUT)                 :: included
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(OUT)     :: dweights_datom, dweights_dstrain

      INTEGER                                            :: closest, ia, ib, icomp, jcomp, natom
      REAL(KIND=dp)                                      :: delta_r, dswitch_dmu, kept_sum, mu, &
                                                            mu_unclamped, pair_length, share_a, &
                                                            share_b, shortest, weight_sum
      REAL(KIND=dp), DIMENSION(3)                        :: dmu_da, dmu_db, dshare_da, dshare_db, &
                                                            pair_dir, pair_vec
      REAL(KIND=dp), DIMENSION(3, 3)                     :: avg_dlog_strain, dmu_deps, dshare_deps
      REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2), &
                               SIZE(atom_coords, 2))     :: dlog_atom
      REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: atom_images, avg_dlog_atom, to_grid, &
                                                            to_grid_dir
      REAL(KIND=dp), &
         DIMENSION(3, 3, SIZE(atom_coords, 2))           :: dlog_strain
      REAL(KIND=dp), DIMENSION(SIZE(atom_coords, 2))     :: cell_products, distances

      natom = SIZE(atom_coords, 2)
      CPASSERT(SIZE(weights) == natom)
      CPASSERT(SIZE(included) == natom)
      CPASSERT(SIZE(dweights_datom, 1) == 3)
      CPASSERT(SIZE(dweights_datom, 2) == natom)
      CPASSERT(SIZE(dweights_datom, 3) == natom)
      CPASSERT(SIZE(dweights_dstrain, 1) == 3)
      CPASSERT(SIZE(dweights_dstrain, 2) == 3)
      CPASSERT(SIZE(dweights_dstrain, 3) == natom)

      weights = 0.0_dp
      included = .FALSE.
      dweights_datom = 0.0_dp
      dweights_dstrain = 0.0_dp
      cell_products = 1.0_dp
      dlog_atom = 0.0_dp
      dlog_strain = 0.0_dp

      ! Vectors from the nearest atom images to the grid point, their lengths and directions
      DO ia = 1, natom
         atom_images(:, ia) = nearest_atom_image_coordinate(atom_coords(:, ia), grid_point, cell)
         to_grid(:, ia) = grid_point - atom_images(:, ia)
         distances(ia) = SQRT(SUM(to_grid(:, ia)**2))
         to_grid_dir(:, ia) = 0.0_dp
         IF (distances(ia) > layout_tol) to_grid_dir(:, ia) = to_grid(:, ia)/distances(ia)
      END DO

      ! Accumulate the pairwise switching products and the derivatives of their logarithms
      DO ia = 1, natom - 1
         DO ib = ia + 1, natom
            pair_vec = atom_images(:, ia) - atom_images(:, ib)
            pair_length = SQRT(SUM(pair_vec**2))
            IF (pair_length <= layout_tol) CYCLE
            pair_dir = pair_vec/pair_length
            delta_r = distances(ia) - distances(ib)
            mu_unclamped = delta_r/pair_length
            mu = MAX(-1.0_dp, MIN(1.0_dp, mu_unclamped))
            share_a = 0.5_dp*(1.0_dp - becke_shape(mu))
            share_b = 1.0_dp - share_a

            ! The clamped switching function is flat outside (-1, 1)
            dswitch_dmu = 0.0_dp
            IF (ABS(mu_unclamped) < 1.0_dp) dswitch_dmu = -0.5_dp*becke_shape_derivative(mu)

            IF (ABS(dswitch_dmu) > 0.0_dp .AND. share_a > TINY(1.0_dp) .AND. &
                share_b > TINY(1.0_dp)) THEN
               dmu_da = (-to_grid_dir(:, ia)*pair_length - delta_r*pair_dir)/pair_length**2
               dmu_db = (to_grid_dir(:, ib)*pair_length + delta_r*pair_dir)/pair_length**2
               dshare_da = dswitch_dmu*dmu_da
               dshare_db = dswitch_dmu*dmu_db
               dlog_atom(:, ia, ia) = dlog_atom(:, ia, ia) + dshare_da/share_a
               dlog_atom(:, ia, ib) = dlog_atom(:, ia, ib) - dshare_da/share_b
               dlog_atom(:, ib, ia) = dlog_atom(:, ib, ia) + dshare_db/share_a
               dlog_atom(:, ib, ib) = dlog_atom(:, ib, ib) - dshare_db/share_b

               DO jcomp = 1, 3
                  DO icomp = 1, 3
                     dmu_deps(icomp, jcomp) = &
                        ((to_grid_dir(icomp, ia)*to_grid(jcomp, ia) - &
                          to_grid_dir(icomp, ib)*to_grid(jcomp, ib))*pair_length - &
                         delta_r*pair_dir(icomp)*pair_vec(jcomp))/pair_length**2
                  END DO
               END DO
               dshare_deps = dswitch_dmu*dmu_deps
               dlog_strain(:, :, ia) = dlog_strain(:, :, ia) + dshare_deps/share_a
               dlog_strain(:, :, ib) = dlog_strain(:, :, ib) - dshare_deps/share_b
            END IF

            cell_products(ia) = cell_products(ia)*share_a
            cell_products(ib) = cell_products(ib)*share_b
         END DO
      END DO

      ! Keep only atoms with a non-negligible share of the point
      weight_sum = SUM(cell_products)
      kept_sum = 0.0_dp
      IF (weight_sum > 0.0_dp) THEN
         included = cell_products/weight_sum > smooth_partition_eps
         kept_sum = SUM(cell_products, MASK=included)
      END IF

      IF (kept_sum <= 0.0_dp) THEN
         ! Nothing survived: assign the point to the closest atom, derivatives stay zero
         shortest = HUGE(1.0_dp)
         closest = 1
         DO ia = 1, natom
            IF (distances(ia) < shortest) THEN
               shortest = distances(ia)
               closest = ia
            END IF
         END DO
         included = .FALSE.
         included(closest) = .TRUE.
         weights = 0.0_dp
         weights(closest) = 1.0_dp
         RETURN
      END IF

      WHERE (included) weights = cell_products/kept_sum

      ! Weighted averages of the logarithmic derivatives over the included atoms
      avg_dlog_atom = 0.0_dp
      avg_dlog_strain = 0.0_dp
      DO ia = 1, natom
         IF (.NOT. included(ia)) CYCLE
         avg_dlog_strain = avg_dlog_strain + weights(ia)*dlog_strain(:, :, ia)
         avg_dlog_atom = avg_dlog_atom + weights(ia)*dlog_atom(:, :, ia)
      END DO

      ! Derivative of a normalised weight: w*(dlog(raw) - <dlog(raw)>)
      DO ia = 1, natom
         IF (.NOT. included(ia)) CYCLE
         dweights_dstrain(:, :, ia) = weights(ia)*(dlog_strain(:, :, ia) - avg_dlog_strain)
         dweights_datom(:, :, ia) = weights(ia)*(dlog_atom(:, :, ia) - avg_dlog_atom)
      END DO

   END SUBROUTINE skala_gpw_smooth_partition_derivatives
2343 :
! **************************************************************************************************
!> \brief Becke fuzzy-cell shape function.
!> \param mu argument of the shape function
!> \return the polynomial p(x) = x*(3 - x**2)/2 applied three times in succession to mu
! **************************************************************************************************
   PURE FUNCTION becke_shape(mu) RESULT(val)
      REAL(KIND=dp), INTENT(IN)                          :: mu
      REAL(KIND=dp)                                      :: val

      INTEGER, PARAMETER                                 :: npass = 3

      INTEGER                                            :: ipass

      val = mu
      DO ipass = 1, npass
         val = 0.5_dp*val*(3.0_dp - val*val)
      END DO

   END FUNCTION becke_shape
2361 :
! **************************************************************************************************
!> \brief Derivative of the Becke fuzzy-cell shape function.
!> \param mu argument at which the derivative is evaluated
!> \return chain-rule product of the three factors 1.5*(1 - x**2), where x runs through the
!>         successive iterates of p(x) = x*(3 - x**2)/2 starting from mu
! **************************************************************************************************
   PURE FUNCTION becke_shape_derivative(mu) RESULT(val)
      REAL(KIND=dp), INTENT(IN)                          :: mu
      REAL(KIND=dp)                                      :: val

      INTEGER, PARAMETER                                 :: npass = 3

      INTEGER                                            :: ipass
      REAL(KIND=dp)                                      :: arg

      arg = mu
      val = 1.0_dp
      DO ipass = 1, npass
         ! Multiply by p'(arg), then advance arg to the next iterate p(arg)
         val = val*1.5_dp*(1.0_dp - arg*arg)
         arg = 0.5_dp*arg*(3.0_dp - arg*arg)
      END DO

   END FUNCTION becke_shape_derivative
2382 :
! **************************************************************************************************
!> \brief Return the atom image nearest to a regular-grid point.
!> \param atom_coord Cartesian coordinate of the atom
!> \param grid_point Cartesian coordinate of the grid point
!> \param cell simulation cell; the orthorhombic case is wrapped per axis, otherwise pbc is used
!> \return grid_point plus the wrapped vector from the grid point to the atom
! **************************************************************************************************
   FUNCTION nearest_atom_image_coordinate(atom_coord, grid_point, cell) RESULT(coord)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord, grid_point
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: idir
      REAL(KIND=dp), DIMENSION(3)                        :: shift

      IF (.NOT. cell%orthorhombic) THEN
         coord = grid_point + pbc(grid_point, atom_coord, cell)
         RETURN
      END IF

      ! Orthorhombic cell: fold each component separately; perd switches off
      ! the folding along non-periodic directions
      shift = atom_coord - grid_point
      DO idir = 1, 3
         shift(idir) = shift(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                       ANINT(cell%h_inv(idir, idir)*shift(idir))
      END DO
      coord = grid_point + shift

   END FUNCTION nearest_atom_image_coordinate
2410 :
! **************************************************************************************************
!> \brief Return the grid-point image nearest to the owning atom coordinate.
!> \param owner_coord Cartesian coordinate of the owning atom
!> \param grid_point Cartesian coordinate of the grid point
!> \param cell simulation cell; the orthorhombic case is wrapped per axis, otherwise pbc is used
!> \return owner_coord plus the wrapped vector from the owner to the grid point
! **************************************************************************************************
   FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
      TYPE(cell_type), POINTER                           :: cell
      REAL(KIND=dp), DIMENSION(3)                        :: coord

      INTEGER                                            :: idir
      REAL(KIND=dp), DIMENSION(3)                        :: shift

      IF (.NOT. cell%orthorhombic) THEN
         coord = owner_coord + pbc(owner_coord, grid_point, cell)
         RETURN
      END IF

      ! Orthorhombic cell: fold each component separately; perd switches off
      ! the folding along non-periodic directions
      shift = grid_point - owner_coord
      DO idir = 1, 3
         shift(idir) = shift(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                       ANINT(cell%h_inv(idir, idir)*shift(idir))
      END DO
      coord = owner_coord + shift

   END FUNCTION nearest_image_coordinate
2438 :
! **************************************************************************************************
!> \brief Assign a grid point to the nearest periodic atom.
!> \param grid_point Cartesian coordinate of the grid point
!> \param atom_coords atom coordinates, one column per atom
!> \param cell simulation cell; the orthorhombic case is wrapped per axis, otherwise pbc is used
!> \return index of the atom with the smallest minimum-image distance (lowest index on ties,
!>         1 if atom_coords holds no atoms)
! **************************************************************************************************
   FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
      TYPE(cell_type), POINTER                           :: cell
      INTEGER                                            :: owner

      INTEGER                                            :: candidate, idir
      LOGICAL                                            :: is_orthorhombic
      REAL(KIND=dp)                                      :: dist2, shortest2
      REAL(KIND=dp), DIMENSION(3)                        :: shift

      owner = 1
      shortest2 = HUGE(1.0_dp)
      is_orthorhombic = cell%orthorhombic

      DO candidate = 1, SIZE(atom_coords, 2)
         IF (is_orthorhombic) THEN
            ! Per-axis folding; perd switches it off along non-periodic directions
            shift = grid_point - atom_coords(:, candidate)
            DO idir = 1, 3
               shift(idir) = shift(idir) - cell%hmat(idir, idir)*cell%perd(idir)* &
                             ANINT(cell%h_inv(idir, idir)*shift(idir))
            END DO
            dist2 = shift(1)*shift(1) + shift(2)*shift(2) + shift(3)*shift(3)
         ELSE
            shift = pbc(grid_point, atom_coords(:, candidate), cell)
            dist2 = SUM(shift**2)
         END IF
         ! Strict comparison keeps the first atom among equally distant ones
         IF (dist2 < shortest2) THEN
            shortest2 = dist2
            owner = candidate
         END IF
      END DO

   END FUNCTION nearest_atom
2484 :
2485 0 : END MODULE skala_gpw_features
|