Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
10 : ! **************************************************************************************************
11 : MODULE skala_gpw_functional
12 : USE cell_types, ONLY: cell_type,&
13 : pbc
14 : USE cp_array_utils, ONLY: cp_3d_r_cp_type
15 : USE cp_log_handling, ONLY: cp_logger_get_default_io_unit
16 : USE input_section_types, ONLY: section_get_rval,&
17 : section_vals_get_subs_vals,&
18 : section_vals_get_subs_vals2,&
19 : section_vals_type,&
20 : section_vals_val_get
21 : USE kinds, ONLY: default_path_length,&
22 : dp,&
23 : int_8
24 : USE message_passing, ONLY: mp_comm_type
25 : USE offload_api, ONLY: offload_set_chosen_device
26 : USE particle_types, ONLY: particle_type
27 : USE pw_grid_types, ONLY: pw_grid_type
28 : USE pw_methods, ONLY: pw_scale,&
29 : pw_zero
30 : USE pw_pool_types, ONLY: pw_pool_type
31 : USE pw_types, ONLY: pw_c1d_gs_type,&
32 : pw_r3d_rs_type
33 : USE qs_grid_atom, ONLY: grid_atom_type
34 : USE skala_gpw_features, ONLY: skala_gpw_atom_partition_hard,&
35 : skala_gpw_atom_partition_smooth,&
36 : skala_gpw_atom_subchunk_count,&
37 : skala_gpw_feature_build,&
38 : skala_gpw_feature_build_atom_subchunk,&
39 : skala_gpw_feature_release,&
40 : skala_gpw_feature_type,&
41 : skala_gpw_smooth_partition_derivatives
42 : USE skala_torch_api, ONLY: skala_torch_model_get_exc,&
43 : skala_torch_model_load,&
44 : skala_torch_model_release,&
45 : skala_torch_model_type
46 : USE string_utilities, ONLY: uppercase
47 : USE torch_api, ONLY: &
48 : torch_cuda_device_count, torch_cuda_is_available, torch_dict_create, torch_dict_insert, &
49 : torch_dict_release, torch_dict_type, torch_tensor_backward_scalar, torch_tensor_data_ptr, &
50 : torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
51 : torch_tensor_to_device_leaf, torch_tensor_type, torch_use_cuda
52 : USE xc_rho_cflags_types, ONLY: xc_rho_cflags_type
53 : USE xc_rho_set_types, ONLY: xc_rho_set_create,&
54 : xc_rho_set_get,&
55 : xc_rho_set_release,&
56 : xc_rho_set_type,&
57 : xc_rho_set_update
58 : USE xc_util, ONLY: xc_pw_divergence,&
59 : xc_requires_tmp_g
60 : #include "./base/base_uses.f90"
61 :
62 : IMPLICIT NONE
63 :
64 : PRIVATE
65 :
! Name of this module as a character constant.
CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'
! Integer tuning constants. Their consumers are not part of this excerpt:
! NOTE(review): the atom_chunk_auto_* values presumably bound/quantise the automatic atom-chunk
! row limit (see auto_atom_chunk_max_rows), and the *_per_point values presumably count packed
! gradient entries per grid point - confirm against the routines that use them.
INTEGER, PARAMETER, PRIVATE :: atom_chunk_auto_max_rows = 400000, &
                               atom_chunk_auto_min_rows = 100000, &
                               atom_chunk_auto_row_quantum = 100000, &
                               ncollapsed_grad_per_point = 5, ngrad_per_point = 10

! Everything else in the module is PRIVATE; only these entry points are exported.
PUBLIC :: ensure_native_skala_grid_scope, get_gauxc_section, skala_gapw_atom_vxc_of_r, &
          skala_gpw_eval, xc_section_uses_native_skala_grid

! Module-level (SAVE) state shared by all calls into this module.
! cached_model is the model handle passed to skala_torch_model_get_exc by the evaluation routines.
! NOTE(review): cached_model_path / cached_model_loaded / cached_model_cuda_device are presumably
! maintained by ensure_model_loaded to decide when a reload is needed - confirm.
TYPE(skala_torch_model_type), SAVE :: cached_model
CHARACTER(len=default_path_length), SAVE :: cached_model_path = ""
LOGICAL, SAVE :: cached_model_loaded = .FALSE.
INTEGER, SAVE :: cached_model_cuda_device = -3
! NOTE(review): the logged_cuda_* values presumably remember the CUDA configuration that was last
! reported so it is not printed repeatedly; their use is outside this excerpt - confirm.
INTEGER, SAVE :: logged_cuda_device = -3, &
                 logged_cuda_device_count = -1, &
                 logged_cuda_nproc = -1, &
                 logged_cuda_request = -3
83 :
84 : CONTAINS
85 :
! **************************************************************************************************
!> \brief Query whether the GAUXC subsection selects the CP2K-native GPW grid path.
!> \param xc_section XC input section handed to get_gauxc_section to locate the GAUXC subsection
!> \return value of the GAUXC keyword NATIVE_GRID; .FALSE. when get_gauxc_section returns a
!>         disassociated pointer
! **************************************************************************************************
FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
   LOGICAL                                            :: uses_native_grid

   TYPE(section_vals_type), POINTER                   :: gauxc

   ! Default answer when there is no GAUXC subsection to consult
   uses_native_grid = .FALSE.

   gauxc => get_gauxc_section(xc_section)
   IF (.NOT. ASSOCIATED(gauxc)) RETURN

   CALL section_vals_val_get(gauxc, "NATIVE_GRID", l_val=uses_native_grid)

END FUNCTION xc_section_uses_native_skala_grid
104 :
! **************************************************************************************************
!> \brief Abort unless the input matches what the native SKALA GPW path currently supports:
!>        an XC_FUNCTIONAL section whose only active subsection is GAUXC and, when NATIVE_GRID
!>        is enabled, a GAUXC%MODEL value that is neither blank nor NONE.
!> \param xc_section XC input section to validate (must be associated)
! **************************************************************************************************
SUBROUTINE ensure_native_skala_grid_scope(xc_section)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section

   CHARACTER(len=default_path_length)                 :: model_key, model_name
   INTEGER                                            :: nfun
   LOGICAL                                            :: native_grid
   TYPE(section_vals_type), POINTER                   :: candidate, gauxc_section, xc_functionals

   NULLIFY (gauxc_section)
   IF (.NOT. ASSOCIATED(xc_section)) CPABORT("Native SKALA GPW requires an XC section")

   xc_functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
   IF (.NOT. ASSOCIATED(xc_functionals)) THEN
      CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL section")
   END IF

   ! Walk the functional subsections by index until the lookup comes back empty.
   ! nfun counts the subsections seen so far, so the next index to query is nfun + 1.
   nfun = 0
   candidate => section_vals_get_subs_vals2(xc_functionals, i_section=1)
   DO WHILE (ASSOCIATED(candidate))
      nfun = nfun + 1
      IF (candidate%section%name == "GAUXC") gauxc_section => candidate
      candidate => section_vals_get_subs_vals2(xc_functionals, i_section=nfun + 1)
   END DO

   IF (.NOT. ASSOCIATED(gauxc_section)) THEN
      CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
   END IF
   IF (nfun /= 1) THEN
      CPABORT("Native SKALA GPW requires GAUXC to be the only XC functional")
   END IF

   ! The model requirement only applies when the native grid path is switched on
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=native_grid)
   IF (.NOT. native_grid) RETURN

   ! Compare the model name case-insensitively and without leading blanks
   CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
   model_key = ADJUSTL(model_name)
   CALL uppercase(model_key)
   IF (LEN_TRIM(model_key) == 0 .OR. model_key == "NONE") THEN
      CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
   END IF

END SUBROUTINE ensure_native_skala_grid_scope
155 :
! **************************************************************************************************
!> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
!>
!>        Flow: read the GAUXC native-grid keywords, make sure the cached Torch model is loaded,
!>        build an xc_rho_set and the SKALA feature tensors, evaluate the model (optionally in
!>        atom sub-chunks), and - unless just_energy is set - back-propagate to obtain feature
!>        gradients from which the potentials, the optional atom force and the optional virial
!>        are assembled.
!> \param vxc_rho potential array handed to build_vxc_from_feature_grads; not touched when
!>                just_energy is .TRUE.
!> \param vxc_tau potential array handed to build_vxc_from_feature_grads; not touched when
!>                just_energy is .TRUE.
!> \param exc energy returned by the model; summed over the grid communicator when atom chunks
!>            are in use
!> \param rho_r real-space density, one entry per spin; SIZE(rho_r) defines nspins and
!>              rho_r(1)%pw_grid supplies the local bounds and the communicator
!> \param rho_g reciprocal-space density; must be associated
!> \param tau kinetic-energy density; must be associated
!> \param xc_section XC input section (GAUXC keywords, XC_GRID settings and cutoffs are read)
!> \param weights forwarded to skala_gpw_feature_build
!> \param pw_pool pool used by xc_rho_set_update, build_vxc_from_feature_grads and
!>                xc_rho_set_release
!> \param particle_set forwarded to the feature build and the smooth-partition routines
!> \param cell forwarded to the feature build and the smooth-partition routines
!> \param compute_virial accumulate the virial in virial_xc; aborts if combined with just_energy
!> \param virial_xc zeroed on entry; receives the virial contributions when compute_virial is set
!> \param just_energy optional (default .FALSE.); when .TRUE. no backward pass is performed
!> \param atom_force optional; when present it is zeroed and the explicit coordinate derivative
!>                   contributions are added. Aborts if atom chunks are enabled.
! **************************************************************************************************
SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
                          weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
                          just_energy, atom_force)
   TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
   REAL(KIND=dp), INTENT(OUT)                         :: exc
   TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
   TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
   TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
   TYPE(section_vals_type), POINTER                   :: xc_section
   TYPE(pw_r3d_rs_type), POINTER                      :: weights
   TYPE(pw_pool_type), POINTER                        :: pw_pool
   TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
   TYPE(cell_type), POINTER                           :: cell
   LOGICAL, INTENT(IN)                                :: compute_virial
   REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
   LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
   REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT), &
      OPTIONAL                                        :: atom_force

   CHARACTER(len=default_path_length)                 :: model_path
   INTEGER :: iw, native_grid_atom_chunk_max_rows, native_grid_atom_partition, &
              native_grid_atom_subchunks, native_grid_cuda_device, nspins, phase_handle, &
              selected_cuda_device, xc_deriv_method_id, xc_rho_smooth_id
   LOGICAL :: have_atom_coord_grad, lsd, my_just_energy, native_grid_atom_chunk_routing, &
              native_grid_atom_chunks, native_grid_diagnostics, native_grid_use_cuda, needs_atom_force, &
              use_atom_subchunks
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: density_grad, kin_grad
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad_grad
   REAL(KIND=dp), DIMENSION(3, 3)                     :: virial_before
   TYPE(section_vals_type), POINTER                   :: gauxc_section
   TYPE(skala_gpw_feature_type)                       :: features
   TYPE(torch_tensor_type)                            :: atom_coord_grad_t, &
                                                         atomic_grid_weight_grad_t, exc_tensor, &
                                                         grid_coord_grad_t, grid_weight_grad_t
   TYPE(xc_rho_cflags_type)                           :: needs
   TYPE(xc_rho_set_type)                              :: rho_set

   ! Define all outputs and resolve the optional arguments before any early abort
   virial_xc = 0.0_dp
   exc = 0.0_dp
   my_just_energy = .FALSE.
   IF (PRESENT(just_energy)) my_just_energy = just_energy
   needs_atom_force = PRESENT(atom_force)
   IF (needs_atom_force) atom_force = 0.0_dp
   ! Tracks whether atom_coord_grad_t currently holds a gradient tensor that must be released
   have_atom_coord_grad = .FALSE.

   ! The virial is built from feature gradients, which an energy-only evaluation does not produce
   IF (compute_virial .AND. my_just_energy) THEN
      CALL cp_abort(__LOCATION__, &
                    "Native SKALA GPW stress/virial requires feature gradients.")
   END IF
   IF (.NOT. ASSOCIATED(rho_g)) THEN
      CALL cp_abort(__LOCATION__, &
                    "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
   END IF
   IF (.NOT. ASSOCIATED(tau)) THEN
      CALL cp_abort(__LOCATION__, &
                    "Native SKALA GPW requires the kinetic-energy density.")
   END IF

   ! NOTE(review): rho_r and gauxc_section are dereferenced below without an association check
   nspins = SIZE(rho_r)
   lsd = (nspins /= 1)
   CALL get_skala_model_path(xc_section, model_path)
   gauxc_section => get_gauxc_section(xc_section)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                             i_val=native_grid_cuda_device)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
                             l_val=native_grid_atom_chunks)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
                             l_val=native_grid_atom_chunk_routing)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_MAX_ROWS", &
                             i_val=native_grid_atom_chunk_max_rows)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
                             i_val=native_grid_atom_partition)
   ! Translate the input keyword value (1 or 2) into the constants of skala_gpw_features
   SELECT CASE (native_grid_atom_partition)
   CASE (1)
      native_grid_atom_partition = skala_gpw_atom_partition_hard
   CASE (2)
      native_grid_atom_partition = skala_gpw_atom_partition_smooth
   CASE DEFAULT
      CALL cp_abort(__LOCATION__, &
                    "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
   END SELECT
   ! After these two statements both flags hold the logical OR of the two input keywords,
   ! i.e. enabling either atom chunks or chunk routing enables both
   native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
   native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
   IF (native_grid_atom_chunk_max_rows < -1) THEN
      CALL cp_abort(__LOCATION__, &
                    "GAUXC%NATIVE_GRID_ATOM_CHUNK_MAX_ROWS must be -1, zero, or positive.")
   END IF
   IF (native_grid_atom_chunks .AND. needs_atom_force) THEN
      CALL cp_abort(__LOCATION__, &
                    "Native SKALA GPW atom chunks are not implemented for atom forces yet.")
   END IF
   ! The portable SKALA export used by the regtests builds ragged-index tensors on CPU.
   CALL torch_use_cuda(native_grid_use_cuda)
   selected_cuda_device = configure_native_grid_cuda( &
                          native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
   CALL ensure_model_loaded(model_path, selected_cuda_device)

   ! Request density, density gradient and tau from the rho_set, spin-resolved when nspins /= 1
   IF (lsd) THEN
      needs%rho_spin = .TRUE.
      needs%drho_spin = .TRUE.
      needs%tau_spin = .TRUE.
   ELSE
      needs%rho = .TRUE.
      needs%drho = .TRUE.
      needs%tau = .TRUE.
   END IF

   CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
   CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)

   CALL xc_rho_set_create(rho_set, &
                          rho_r(1)%pw_grid%bounds_local, &
                          rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
                          drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
                          tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
   CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
                          xc_deriv_method_id, xc_rho_smooth_id, pw_pool)

   ! Gradients are requested unless this is an energy-only call; coordinate gradients are
   ! requested when either an atom force or the virial has to be produced
   CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                requires_grad=(.NOT. my_just_energy), weights=weights, &
                                requires_coordinate_grad=(needs_atom_force .OR. compute_virial), &
                                requires_stress_grad=compute_virial, &
                                use_atom_chunks=native_grid_atom_chunks, &
                                route_atom_chunks=native_grid_atom_chunk_routing, &
                                atom_partition=native_grid_atom_partition)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
   IF (native_grid_diagnostics) THEN
      CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
   END IF

   ! A chunk row limit of -1 means "choose automatically": an automatic value is computed when
   ! CUDA is used, otherwise the limit becomes 0, which disables sub-chunking below
   IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows == -1) THEN
      IF (native_grid_use_cuda) THEN
         native_grid_atom_chunk_max_rows = auto_atom_chunk_max_rows(features, &
                                                                    rho_r(1)%pw_grid%para%group)
      ELSE
         native_grid_atom_chunk_max_rows = 0
      END IF
   END IF
   ! Only the rank with mepos == 0 reports the effective row limit
   IF (native_grid_diagnostics .AND. features%uses_atom_chunks .AND. &
       rho_r(1)%pw_grid%para%group%mepos == 0) THEN
      iw = cp_logger_get_default_io_unit()
      IF (iw > 0) THEN
         WRITE (UNIT=iw, FMT="(T2,A,1X,I0)") &
            "SKALA_GPW| Native grid atom chunk max rows", native_grid_atom_chunk_max_rows
      END IF
   END IF
   ! Sub-chunked evaluation is used only when more than one sub-chunk results from the row limit
   native_grid_atom_subchunks = 1
   IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows > 0) THEN
      native_grid_atom_subchunks = skala_gpw_atom_subchunk_count(native_grid_atom_chunk_max_rows)
   END IF
   use_atom_subchunks = native_grid_atom_subchunks > 1
   IF (use_atom_subchunks) THEN
      ! This path returns the energy and (if requested) the feature gradients directly;
      ! exc_tensor is not produced
      CALL evaluate_atom_subchunks(features, rho_r(1)%pw_grid%para%group, &
                                   native_grid_atom_chunk_max_rows, &
                                   compute_grads=(.NOT. my_just_energy), exc=exc, &
                                   density_grad=density_grad, grad_grad=grad_grad, &
                                   kin_grad=kin_grad, collapse_spin_grads=(nspins == 1))
   ELSE
      CALL skala_torch_model_get_exc(cached_model, features%inputs, &
                                     features%grid_weights_t, exc_tensor, exc)
   END IF
   ! With atom chunks the energy is reduced over the grid communicator here
   IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)

   IF (.NOT. my_just_energy) THEN
      IF (.NOT. use_atom_subchunks) THEN
         ! Single-shot evaluation: back-propagate from the energy tensor, then fetch the gradients
         CALL timeset("skala_gpw_backward", phase_handle)
         CALL torch_tensor_backward_scalar(exc_tensor)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_grad_fetch", phase_handle)
         IF (features%uses_atom_chunks) THEN
            CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
                                                   density_grad, grad_grad, kin_grad)
         ELSE
            CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
         END IF
         CALL timestop(phase_handle)
      END IF
      IF (needs_atom_force) THEN
         ! add_explicit_coordinate_force fills atom_coord_grad_t, which is released further below
         CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
                                            rho_r(1)%pw_grid%para%group%mepos == 0)
         IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
            CALL add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
                                            grid_weight_grad_t, atomic_grid_weight_grad_t)
         END IF
         have_atom_coord_grad = .TRUE.
      END IF

      CALL timeset("skala_gpw_vxc_unpack", phase_handle)
      IF (compute_virial) THEN
         ! virial_before is only defined and used when diagnostics are on; it lets each
         ! contribution be printed as a difference
         IF (native_grid_diagnostics) virial_before = virial_xc
         CALL build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
         IF (native_grid_diagnostics) THEN
            CALL print_virial_delta("feature-gradient", virial_xc - virial_before, &
                                    rho_r(1)%pw_grid%para%group%mepos == 0)
            virial_before = virial_xc
         END IF
         ! Reuse the coordinate gradient fetched for the force, or fetch it now
         IF (.NOT. have_atom_coord_grad) THEN
            CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
            have_atom_coord_grad = .TRUE.
         END IF
         CALL build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
                                             grid_coord_grad_t, &
                                             rho_r(1)%pw_grid%para%group%mepos == 0, &
                                             native_grid_diagnostics)
         IF (native_grid_diagnostics) THEN
            CALL print_virial_delta("static-coordinates", virial_xc - virial_before, &
                                    rho_r(1)%pw_grid%para%group%mepos == 0)
            virial_before = virial_xc
         END IF
         IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
            CALL build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
                                               grid_weight_grad_t, atomic_grid_weight_grad_t)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("smooth-partition", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
               virial_before = virial_xc
            END IF
         END IF
         CALL build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
                                  atomic_grid_weight_grad_t, &
                                  rho_r(1)%pw_grid%para%group%mepos == 0, &
                                  native_grid_diagnostics)
         IF (native_grid_diagnostics) THEN
            CALL print_virial_delta("weight-residual", virial_xc - virial_before, &
                                    rho_r(1)%pw_grid%para%group%mepos == 0)
         END IF
      END IF
      CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                        density_grad, grad_grad, kin_grad, &
                                        xc_deriv_method_id)
      CALL timestop(phase_handle)

      CALL timeset("skala_gpw_grad_release", phase_handle)
      DEALLOCATE (density_grad, grad_grad, kin_grad)
      ! NOTE(review): only atom_coord_grad_t is released in this routine. grid_coord_grad_t,
      ! grid_weight_grad_t and atomic_grid_weight_grad_t are passed to helpers above but never
      ! released here - confirm that the helpers release them, otherwise they leak.
      IF (have_atom_coord_grad) CALL torch_tensor_release(atom_coord_grad_t)
      CALL timestop(phase_handle)
   END IF

   CALL timeset("skala_gpw_cleanup", phase_handle)
   ! exc_tensor only exists on the single-shot path
   IF (.NOT. use_atom_subchunks) CALL torch_tensor_release(exc_tensor)
   CALL skala_gpw_feature_release(features)
   CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
   ! NOTE(review): the CUDA switch is set to .TRUE. regardless of its value on entry,
   ! presumably to restore a global default - confirm
   CALL torch_use_cuda(.TRUE.)
   CALL timestop(phase_handle)

END SUBROUTINE skala_gpw_eval
422 :
! **************************************************************************************************
!> \brief Evaluate SKALA on a GAPW one-center atomic grid.
!>
!>        The (angular, radial) grid of one atom is flattened into na*nr rows, packed into
!>        two-spin-channel Torch tensors, and passed to the cached SKALA model. Unless
!>        energy_only is set, the energy is back-propagated and the derivatives with respect
!>        to density, density gradient and kinetic-energy density are unpacked into
!>        vxc, vxg and vtau.
!> \param xc_section XC input section; its GAUXC subsection supplies NATIVE_GRID_USE_CUDA and
!>                   NATIVE_GRID_CUDA_DEVICE
!> \param grid_atom atomic grid; rad(ir) and the sin/cos arrays indexed by ia define the points
!> \param group communicator passed to configure_native_grid_cuda
!> \param atom_coord position of the atom; added to every grid point offset
!> \param rho density, indexed (ia, ir, ispin)
!> \param drho density gradient, indexed (idir, ia, ir, ispin)
!> \param tau kinetic-energy density, indexed (ia, ir, ispin)
!> \param weights integration weights, indexed (ia, ir); used for both the "grid_weights" and
!>                the "atomic_grid_weights" model inputs
!> \param lsd selects spin-resolved unpacking of the derivatives
!> \param nspins number of spin channels in the inputs; for nspins == 1 the input is split
!>               equally over the two model spin channels
!> \param na extent of the ia index
!> \param nr extent of the ir index
!> \param exc energy returned by the model
!> \param vxc derivative with respect to the density (must be associated unless energy_only)
!> \param vxg derivative with respect to the density gradient (must be associated unless
!>            energy_only)
!> \param vtau derivative with respect to tau (must be associated unless energy_only)
!> \param energy_only optional (default .FALSE.); skip the backward pass and leave
!>                    vxc, vxg and vtau untouched
! **************************************************************************************************
SUBROUTINE skala_gapw_atom_vxc_of_r(xc_section, grid_atom, group, atom_coord, &
                                    rho, drho, tau, weights, lsd, nspins, na, nr, &
                                    exc, vxc, vxg, vtau, energy_only)
   TYPE(section_vals_type), POINTER                   :: xc_section
   TYPE(grid_atom_type), POINTER                      :: grid_atom

   CLASS(mp_comm_type), INTENT(IN)                    :: group
   REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord
   REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, tau, vxc, vtau
   REAL(KIND=dp), DIMENSION(:, :, :, :), POINTER      :: drho, vxg
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: weights
   LOGICAL, INTENT(IN)                                :: lsd
   INTEGER, INTENT(IN)                                :: nspins, na, nr
   REAL(KIND=dp), INTENT(OUT)                         :: exc
   LOGICAL, INTENT(IN), OPTIONAL                      :: energy_only

   CHARACTER(len=default_path_length)                 :: model_path
   INTEGER                                            :: ia, idir, ir, native_grid_cuda_device, &
                                                         nflat, row, selected_cuda_device
   INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes
   INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape
   LOGICAL                                            :: my_energy_only, native_grid_use_cuda
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, grid_weights
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coarse_0_atomic_coords, density, &
                                                         grid_coords, kin
   REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad, kin_grad
   REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
   TYPE(section_vals_type), POINTER                   :: gauxc_section
   TYPE(torch_dict_type)                              :: inputs
   TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t, &
                                                         atomic_grid_sizes_t, &
                                                         atomic_grid_weights_t, &
                                                         coarse_0_atomic_coords_t, density_t, &
                                                         density_grad_t, exc_tensor, grad_t, &
                                                         grad_grad_t, grid_coords_t, &
                                                         grid_weights_t, kin_t, kin_grad_t

   CPASSERT(ASSOCIATED(xc_section))
   CPASSERT(ASSOCIATED(grid_atom))
   CPASSERT(ASSOCIATED(rho))
   CPASSERT(ASSOCIATED(drho))
   CPASSERT(ASSOCIATED(tau))

   my_energy_only = .FALSE.
   IF (PRESENT(energy_only)) my_energy_only = energy_only
   exc = 0.0_dp
   IF (.NOT. my_energy_only) THEN
      ! The derivative arrays are written below, so a disassociated target must be rejected
      ! here instead of being dereferenced
      CPASSERT(ASSOCIATED(vxc))
      CPASSERT(ASSOCIATED(vxg))
      CPASSERT(ASSOCIATED(vtau))
      vxc = 0.0_dp
      vxg = 0.0_dp
      vtau = 0.0_dp
   END IF

   ! Select the device and make sure the cached model matches the requested path/device
   CALL get_skala_model_path(xc_section, model_path)
   gauxc_section => get_gauxc_section(xc_section)
   CPASSERT(ASSOCIATED(gauxc_section))
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
   CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                             i_val=native_grid_cuda_device)
   CALL torch_use_cuda(native_grid_use_cuda)
   selected_cuda_device = configure_native_grid_cuda( &
                          native_grid_use_cuda, native_grid_cuda_device, group)
   CALL ensure_model_loaded(model_path, selected_cuda_device)

   ! One row per grid point; the model inputs always carry two spin channels.
   ! atomic_grid_size_bound_shape has a zero-extent first dimension, i.e. it holds no
   ! elements and therefore needs no initialisation; only its shape is handed to Torch.
   nflat = na*nr
   ALLOCATE (density(nflat, 2), grad(nflat, 3, 2), kin(nflat, 2), &
             grid_coords(3, nflat), grid_weights(nflat), &
             atomic_grid_weights(nflat), atomic_grid_sizes(1), &
             coarse_0_atomic_coords(3, 1), atomic_grid_size_bound_shape(0, nflat))
   density = 0.0_dp
   grad = 0.0_dp
   kin = 0.0_dp
   grid_coords = 0.0_dp
   grid_weights = 0.0_dp
   atomic_grid_weights = 0.0_dp
   atomic_grid_sizes(1) = INT(nflat, KIND=int_8)
   coarse_0_atomic_coords(:, 1) = atom_coord

   ! Flatten (ia, ir) into rows with ia running fastest
   row = 0
   DO ir = 1, nr
      DO ia = 1, na
         row = row + 1
         grid_coords(1, row) = atom_coord(1) + grid_atom%rad(ir)* &
                               grid_atom%sin_pol(ia)*grid_atom%cos_azi(ia)
         grid_coords(2, row) = atom_coord(2) + grid_atom%rad(ir)* &
                               grid_atom%sin_pol(ia)*grid_atom%sin_azi(ia)
         grid_coords(3, row) = atom_coord(3) + grid_atom%rad(ir)*grid_atom%cos_pol(ia)
         grid_weights(row) = weights(ia, ir)
         atomic_grid_weights(row) = weights(ia, ir)
         ! NOTE(review): packing branches on nspins while unpacking below branches on lsd;
         ! the two are only consistent when lsd is equivalent to nspins /= 1 - confirm callers
         IF (nspins == 1) THEN
            ! Spin-restricted input: each model spin channel receives half of the total
            density(row, :) = 0.5_dp*rho(ia, ir, 1)
            DO idir = 1, 3
               grad(row, idir, :) = 0.5_dp*drho(idir, ia, ir, 1)
            END DO
            kin(row, :) = 0.5_dp*tau(ia, ir, 1)
         ELSE
            density(row, :) = rho(ia, ir, 1:2)
            DO idir = 1, 3
               grad(row, idir, :) = drho(idir, ia, ir, 1:2)
            END DO
            kin(row, :) = tau(ia, ir, 1:2)
         END IF
      END DO
   END DO

   ! Geometry and weight tensors never require gradients; density, grad and kin do
   ! unless this is an energy-only evaluation
   CALL torch_tensor_from_array(grid_coords_t, grid_coords)
   CALL torch_tensor_to_device_leaf(grid_coords_t, .FALSE.)
   CALL torch_tensor_from_array(grid_weights_t, grid_weights)
   CALL torch_tensor_to_device_leaf(grid_weights_t, .FALSE.)
   CALL torch_tensor_from_array(atomic_grid_weights_t, atomic_grid_weights)
   CALL torch_tensor_to_device_leaf(atomic_grid_weights_t, .FALSE.)
   CALL torch_tensor_from_array(atomic_grid_sizes_t, atomic_grid_sizes)
   CALL torch_tensor_to_device_leaf(atomic_grid_sizes_t, .FALSE.)
   CALL torch_tensor_from_array(atomic_grid_size_bound_shape_t, &
                                atomic_grid_size_bound_shape)
   CALL torch_tensor_to_device_leaf(atomic_grid_size_bound_shape_t, .FALSE.)
   CALL torch_tensor_from_array(coarse_0_atomic_coords_t, coarse_0_atomic_coords)
   CALL torch_tensor_to_device_leaf(coarse_0_atomic_coords_t, .FALSE.)
   CALL torch_tensor_from_array(density_t, density)
   CALL torch_tensor_to_device_leaf(density_t, .NOT. my_energy_only)
   CALL torch_tensor_from_array(grad_t, grad)
   CALL torch_tensor_to_device_leaf(grad_t, .NOT. my_energy_only)
   CALL torch_tensor_from_array(kin_t, kin)
   CALL torch_tensor_to_device_leaf(kin_t, .NOT. my_energy_only)

   CALL torch_dict_create(inputs)
   CALL torch_dict_insert(inputs, "grid_coords", grid_coords_t)
   CALL torch_dict_insert(inputs, "grid_weights", grid_weights_t)
   CALL torch_dict_insert(inputs, "atomic_grid_weights", atomic_grid_weights_t)
   CALL torch_dict_insert(inputs, "atomic_grid_sizes", atomic_grid_sizes_t)
   CALL torch_dict_insert(inputs, "atomic_grid_size_bound_shape", &
                          atomic_grid_size_bound_shape_t)
   CALL torch_dict_insert(inputs, "density", density_t)
   CALL torch_dict_insert(inputs, "grad", grad_t)
   CALL torch_dict_insert(inputs, "kin", kin_t)
   CALL torch_dict_insert(inputs, "coarse_0_atomic_coords", coarse_0_atomic_coords_t)

   CALL skala_torch_model_get_exc(cached_model, inputs, grid_weights_t, exc_tensor, exc)

   IF (.NOT. my_energy_only) THEN
      NULLIFY (density_grad, grad_grad, kin_grad)
      CALL torch_tensor_backward_scalar(exc_tensor)
      CALL torch_tensor_grad(density_t, density_grad_t)
      CALL torch_tensor_grad(grad_t, grad_grad_t)
      CALL torch_tensor_grad(kin_t, kin_grad_t)
      CALL torch_tensor_data_ptr(density_grad_t, density_grad)
      CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
      CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)

      ! The gradients are indexed exactly like the packed inputs; verify their shapes
      ! before unpacking (same guard as in add_explicit_coordinate_force)
      CPASSERT(ASSOCIATED(density_grad))
      CPASSERT(ASSOCIATED(grad_grad))
      CPASSERT(ASSOCIATED(kin_grad))
      CPASSERT(ALL(SHAPE(density_grad) == [nflat, 2]))
      CPASSERT(ALL(SHAPE(grad_grad) == [nflat, 3, 2]))
      CPASSERT(ALL(SHAPE(kin_grad) == [nflat, 2]))

      row = 0
      DO ir = 1, nr
         DO ia = 1, na
            row = row + 1
            IF (lsd) THEN
               vxc(ia, ir, 1:2) = density_grad(row, 1:2)
               DO idir = 1, 3
                  vxg(idir, ia, ir, 1:2) = grad_grad(row, idir, 1:2)
               END DO
               vtau(ia, ir, 1:2) = kin_grad(row, 1:2)
            ELSE
               ! Each spin channel carried half of the total quantity, so the derivative
               ! with respect to the total is the mean of the two channel derivatives
               vxc(ia, ir, 1) = 0.5_dp*(density_grad(row, 1) + density_grad(row, 2))
               DO idir = 1, 3
                  vxg(idir, ia, ir, 1) = &
                     0.5_dp*(grad_grad(row, idir, 1) + grad_grad(row, idir, 2))
               END DO
               vtau(ia, ir, 1) = 0.5_dp*(kin_grad(row, 1) + kin_grad(row, 2))
            END IF
         END DO
      END DO

      CALL torch_tensor_release(density_grad_t)
      CALL torch_tensor_release(grad_grad_t)
      CALL torch_tensor_release(kin_grad_t)
   END IF

   ! Release every tensor created above, then the dictionary and the host buffers
   CALL torch_tensor_release(exc_tensor)
   CALL torch_tensor_release(density_t)
   CALL torch_tensor_release(grad_t)
   CALL torch_tensor_release(kin_t)
   CALL torch_tensor_release(grid_coords_t)
   CALL torch_tensor_release(grid_weights_t)
   CALL torch_tensor_release(atomic_grid_weights_t)
   CALL torch_tensor_release(atomic_grid_sizes_t)
   CALL torch_tensor_release(atomic_grid_size_bound_shape_t)
   CALL torch_tensor_release(coarse_0_atomic_coords_t)
   CALL torch_dict_release(inputs)
   DEALLOCATE (atomic_grid_size_bound_shape, atomic_grid_sizes, atomic_grid_weights, &
               coarse_0_atomic_coords, density, grad, grid_coords, grid_weights, kin)
   ! NOTE(review): the CUDA switch is set to .TRUE. regardless of its value on entry,
   ! presumably to restore a global default - confirm
   CALL torch_use_cuda(.TRUE.)

END SUBROUTINE skala_gapw_atom_vxc_of_r
635 :
! **************************************************************************************************
!> \brief Fetch the gradient of the SKALA energy with respect to the atom-center coordinates
!>        and add it to the force array on the root rank.
!> \param atom_force array that the coordinate gradient is added to (root rank only)
!> \param features feature set whose coarse_0_atomic_coords_t tensor provides the gradient
!> \param atom_coord_grad_t receives the gradient tensor on every rank; it is not released here
!> \param root_rank .TRUE. on the rank that accumulates into atom_force
! **************************************************************************************************
SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
   REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
   TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
   TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
   LOGICAL, INTENT(IN)                                :: root_rank

   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: coord_grad

   NULLIFY (coord_grad)

   ! The gradient tensor is obtained on all ranks; only the root rank reads its data
   CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
   IF (.NOT. root_rank) RETURN

   CALL torch_tensor_data_ptr(atom_coord_grad_t, coord_grad)
   CPASSERT(SIZE(atom_force, 1) == SIZE(coord_grad, 1))
   CPASSERT(SIZE(atom_force, 2) == SIZE(coord_grad, 2))
   atom_force = atom_force + coord_grad

END SUBROUTINE add_explicit_coordinate_force
661 :
662 : ! **************************************************************************************************
663 : !> \brief Add the force from SMOOTH native-grid atom partition weights.
664 : !> \param atom_force array asserted to be (3, natom); incremented in place, never reset here
665 : !> \param features supplies the weight tensors, grid_weights, local row maps and nflat_local
666 : !> \param particle_set atoms; each position is wrapped with pbc(..., positive_range=.TRUE.)
667 : !> \param cell cell used for the wrapping and passed on to the partition derivative routine
668 : !> \param rho_r only rho_r(1)%pw_grid is read (local bounds and grid point coordinates)
669 : !> \param grid_weight_grad_t work tensor filled from features%grid_weights_t; released before return
670 : !> \param atomic_grid_weight_grad_t work tensor filled from features%atomic_grid_weights_t; released before return
671 : ! **************************************************************************************************
672 4 : SUBROUTINE add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
673 : grid_weight_grad_t, atomic_grid_weight_grad_t)
674 : REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT) :: atom_force
675 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
676 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
677 : TYPE(cell_type), POINTER :: cell
678 : TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER :: rho_r
679 : TYPE(torch_tensor_type), INTENT(INOUT) :: grid_weight_grad_t, &
680 : atomic_grid_weight_grad_t
681 :
682 : INTEGER :: feature_begin, feature_end, feature_pos, &
683 : i, iatom, j, jatom, k, local_row, &
684 : natom, row
685 : INTEGER, DIMENSION(2, 3) :: bo
686 : LOGICAL, ALLOCATABLE, DIMENSION(:) :: included
687 : REAL(KIND=dp) :: base_weight, weight_grad
688 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: weights
689 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords_pbc
690 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: dweights_datom, dweights_dstrain
691 : REAL(KIND=dp), DIMENSION(3) :: grid_point
692 4 : REAL(KIND=dp), DIMENSION(:), POINTER :: atomic_grid_weight_grad, grid_weight_grad
693 :
694 4 : NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
695 4 : CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t) ! gradient tensors of the two weight inputs
696 4 : CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
697 4 : CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad) ! rank-1 views, indexed by feature row below
698 4 : CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
699 :
700 4 : natom = SIZE(particle_set)
701 4 : CPASSERT(SIZE(atom_force, 1) == 3)
702 4 : CPASSERT(SIZE(atom_force, 2) == natom)
703 : ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
704 48 : dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
705 12 : DO iatom = 1, natom
706 12 : atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
707 : END DO
708 :
709 40 : bo = rho_r(1)%pw_grid%bounds_local
710 4 : local_row = 0 ! running index of the local grid point, i fastest, k slowest
711 100 : DO k = bo(1, 3), bo(2, 3)
712 2404 : DO j = bo(1, 2), bo(2, 2)
713 30048 : DO i = bo(1, 1), bo(2, 1)
714 27648 : local_row = local_row + 1
715 110592 : grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
716 : CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
717 : weights, included, dweights_datom, &
718 27648 : dweights_dstrain) ! weights and dweights_dstrain are not read in this routine
719 27648 : feature_begin = features%local_feature_offsets(local_row) ! feature rows of this point: feature_begin..feature_end
720 27648 : feature_end = features%local_feature_offsets(local_row + 1) - 1
721 82944 : CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
722 27648 : base_weight = 0.0_dp ! becomes the sum of grid_weights over all feature rows of this point
723 82804 : DO feature_pos = feature_begin, feature_end
724 55156 : row = features%local_feature_rows(feature_pos)
725 82804 : base_weight = base_weight + features%grid_weights(row)
726 : END DO
727 : feature_pos = feature_begin ! the n-th included atom is paired with the n-th feature row of this point
728 82944 : DO iatom = 1, natom
729 55296 : IF (.NOT. included(iatom)) CYCLE
730 55156 : row = features%local_feature_rows(feature_pos)
731 55156 : weight_grad = grid_weight_grad(row) + atomic_grid_weight_grad(row) ! both weight gradients act on the same row
732 165468 : DO jatom = 1, natom
733 : atom_force(:, jatom) = atom_force(:, jatom) + & ! contribution of row iatom to every atom jatom
734 : weight_grad*base_weight* &
735 496404 : dweights_datom(:, jatom, iatom)
736 : END DO
737 82944 : feature_pos = feature_pos + 1
738 : END DO
739 29952 : CPASSERT(feature_pos == feature_end + 1)
740 : END DO
741 : END DO
742 : END DO
743 4 : CPASSERT(local_row == features%nflat_local)
744 :
745 4 : DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
746 4 : CALL torch_tensor_release(grid_weight_grad_t) ! the data views obtained above must not be used after this point
747 4 : CALL torch_tensor_release(atomic_grid_weight_grad_t)
748 :
749 4 : END SUBROUTINE add_smooth_partition_force
750 :
751 : ! **************************************************************************************************
752 : !> \brief Add the virial from SMOOTH native-grid atom partition weights.
753 : !> \param virial_xc 3x3 array incremented in place; each off-diagonal pair is overwritten with a common value
754 : !> \param features supplies the weight tensors, grid_weights, local row maps and nflat_local
755 : !> \param particle_set atoms; each position is wrapped with pbc(..., positive_range=.TRUE.)
756 : !> \param cell cell used for the wrapping and passed on to the partition derivative routine
757 : !> \param rho_r only rho_r(1)%pw_grid is read (local bounds and grid point coordinates)
758 : !> \param grid_weight_grad_t work tensor filled from features%grid_weights_t; released before return
759 : !> \param atomic_grid_weight_grad_t work tensor filled from features%atomic_grid_weights_t; released before return
760 : ! **************************************************************************************************
761 2 : SUBROUTINE build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
762 : grid_weight_grad_t, atomic_grid_weight_grad_t)
763 : REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT) :: virial_xc
764 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
765 : TYPE(particle_type), DIMENSION(:), POINTER :: particle_set
766 : TYPE(cell_type), POINTER :: cell
767 : TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER :: rho_r
768 : TYPE(torch_tensor_type), INTENT(INOUT) :: grid_weight_grad_t, &
769 : atomic_grid_weight_grad_t
770 :
771 : INTEGER :: feature_begin, feature_end, feature_pos, &
772 : i, iatom, idir, j, jdir, k, local_row, &
773 : natom, row
774 : INTEGER, DIMENSION(2, 3) :: bo
775 : LOGICAL, ALLOCATABLE, DIMENSION(:) :: included
776 : REAL(KIND=dp) :: base_weight, tmp, weight_grad
777 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: weights
778 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :) :: atom_coords_pbc
779 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :) :: dweights_datom, dweights_dstrain
780 : REAL(KIND=dp), DIMENSION(3) :: grid_point
781 2 : REAL(KIND=dp), DIMENSION(:), POINTER :: atomic_grid_weight_grad, grid_weight_grad
782 :
783 2 : NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
784 2 : CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t) ! gradient tensors of the two weight inputs
785 2 : CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
786 2 : CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad) ! rank-1 views, indexed by feature row below
787 2 : CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
788 :
789 2 : natom = SIZE(particle_set)
790 : ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
791 24 : dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
792 6 : DO iatom = 1, natom
793 6 : atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
794 : END DO
795 :
796 20 : bo = rho_r(1)%pw_grid%bounds_local
797 2 : local_row = 0 ! running index of the local grid point, i fastest, k slowest
798 50 : DO k = bo(1, 3), bo(2, 3)
799 1202 : DO j = bo(1, 2), bo(2, 2)
800 15024 : DO i = bo(1, 1), bo(2, 1)
801 13824 : local_row = local_row + 1
802 55296 : grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
803 : CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
804 : weights, included, dweights_datom, &
805 13824 : dweights_dstrain) ! weights and dweights_datom are not read in this routine
806 13824 : feature_begin = features%local_feature_offsets(local_row) ! feature rows of this point: feature_begin..feature_end
807 13824 : feature_end = features%local_feature_offsets(local_row + 1) - 1
808 41472 : CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
809 13824 : base_weight = 0.0_dp ! becomes the sum of grid_weights over all feature rows of this point
810 41402 : DO feature_pos = feature_begin, feature_end
811 27578 : row = features%local_feature_rows(feature_pos)
812 41402 : base_weight = base_weight + features%grid_weights(row)
813 : END DO
814 : feature_pos = feature_begin ! the n-th included atom is paired with the n-th feature row of this point
815 41472 : DO iatom = 1, natom
816 27648 : IF (.NOT. included(iatom)) CYCLE
817 27578 : row = features%local_feature_rows(feature_pos)
818 27578 : weight_grad = grid_weight_grad(row) + atomic_grid_weight_grad(row) ! both weight gradients act on the same row
819 110312 : DO idir = 1, 3
820 275780 : DO jdir = 1, idir ! only entries with jdir <= idir of dweights_dstrain are read
821 165468 : tmp = weight_grad*base_weight*dweights_dstrain(idir, jdir, iatom)
822 165468 : virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
823 248202 : virial_xc(idir, jdir) = virial_xc(jdir, idir) ! mirror: (idir, jdir) takes the value just stored in (jdir, idir)
824 : END DO
825 : END DO
826 41472 : feature_pos = feature_pos + 1
827 : END DO
828 14976 : CPASSERT(feature_pos == feature_end + 1)
829 : END DO
830 : END DO
831 : END DO
832 2 : CPASSERT(local_row == features%nflat_local)
833 :
834 2 : DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
835 2 : CALL torch_tensor_release(grid_weight_grad_t) ! the data views obtained above must not be used after this point
836 2 : CALL torch_tensor_release(atomic_grid_weight_grad_t)
837 :
838 2 : END SUBROUTINE build_smooth_partition_virial
839 :
840 : ! **************************************************************************************************
841 : !> \brief Return the Cartesian coordinate of a regular GPW grid point.
842 : !> \param pw_grid grid whose bounds(1, :) and dh(:, 1:3) define the mapping
843 : !> \param index integer grid index triple, in the same index range as pw_grid%bounds
844 : !> \return sum over directions of (index - bounds(1, :)) times the matching dh column
845 : ! **************************************************************************************************
846 41472 : FUNCTION native_grid_coordinate(pw_grid, index) RESULT(coord)
847 : TYPE(pw_grid_type), POINTER :: pw_grid
848 : INTEGER, DIMENSION(3), INTENT(IN) :: index
849 : REAL(KIND=dp), DIMENSION(3) :: coord
850 :
851 : INTEGER, DIMENSION(3) :: relative_index
852 :
853 165888 : relative_index = index - pw_grid%bounds(1, :) ! zero-based offset, so index == bounds(1, :) maps to the zero vector
854 : coord = REAL(relative_index(1), KIND=dp)*pw_grid%dh(:, 1) + &
855 : REAL(relative_index(2), KIND=dp)*pw_grid%dh(:, 2) + &
856 165888 : REAL(relative_index(3), KIND=dp)*pw_grid%dh(:, 3)
857 :
858 41472 : END FUNCTION native_grid_coordinate
859 :
860 : ! **************************************************************************************************
861 : !> \brief Evaluate a rank-local atom chunk as multiple atom-contiguous Torch subchunks.
862 : !> \param features atom-chunk feature set; uses_atom_chunks is asserted
863 : !> \param group communicator used for the alltoall that returns the gradients
864 : !> \param max_rows positive row cap handed to the subchunk counter and the subchunk builder
865 : !> \param compute_grads run the backward pass and return gradients; asserts uses_atom_chunk_routing
866 : !> \param exc sum of the subchunk energies evaluated here; no reduction over group is done
867 : !> \param density_grad allocated as (nflat_local, 2) only if compute_grads is set
868 : !> \param grad_grad allocated as (nflat_local, 3, 2) only if compute_grads is set
869 : !> \param kin_grad allocated as (nflat_local, 2) only if compute_grads is set
870 : !> \param collapse_spin_grads route ncollapsed_grad_per_point values per point and add each to both last-dimension entries
871 : ! **************************************************************************************************
872 2 : SUBROUTINE evaluate_atom_subchunks(features, group, max_rows, compute_grads, exc, &
873 : density_grad, grad_grad, kin_grad, collapse_spin_grads)
874 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
875 :
876 : CLASS(mp_comm_type), INTENT(IN) :: group
877 : INTEGER, INTENT(IN) :: max_rows
878 : LOGICAL, INTENT(IN) :: compute_grads, collapse_spin_grads
879 : REAL(KIND=dp), INTENT(OUT) :: exc
880 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
881 : INTENT(OUT) :: density_grad, kin_grad
882 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
883 : INTENT(OUT) :: grad_grad
884 :
885 : INTEGER :: base, isubchunk, local_row, nflat_local, &
886 : nroute_grad_per_point, nroute_points, &
887 : nsubchunks, phase_handle, point_pos, &
888 : subphase_handle
889 2 : INTEGER, ALLOCATABLE, DIMENSION(:) :: route_grad_return_recv_counts, &
890 2 : route_grad_return_recv_displs, &
891 2 : route_grad_return_send_counts, &
892 2 : route_grad_return_send_displs
893 : REAL(KIND=dp) :: subchunk_exc
894 2 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: recv_grad_buffer, send_grad_buffer
895 2 : TYPE(skala_gpw_feature_type) :: subchunk
896 : TYPE(torch_tensor_type) :: subchunk_exc_tensor
897 :
898 0 : CPASSERT(features%uses_atom_chunks)
899 2 : CPASSERT(max_rows > 0)
900 2 : nflat_local = features%nflat_local
901 2 : nsubchunks = skala_gpw_atom_subchunk_count(max_rows)
902 2 : CPASSERT(nsubchunks > 0)
903 :
904 2 : exc = 0.0_dp
905 2 : IF (compute_grads) THEN
906 2 : CPASSERT(features%uses_atom_chunk_routing)
907 6 : CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
908 2 : nroute_points = SIZE(features%route_send_local_rows)
909 6 : CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
910 2 : nroute_grad_per_point = ngrad_per_point ! number of values exchanged per routed point
911 2 : IF (collapse_spin_grads) nroute_grad_per_point = ncollapsed_grad_per_point
912 : ALLOCATE (send_grad_buffer(nroute_grad_per_point*features%chunk_feature_count), &
913 : recv_grad_buffer(nroute_grad_per_point*nroute_points), &
914 : route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
915 : route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
916 : route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
917 26 : route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
918 : route_grad_return_send_counts(:) = & ! return route: send layout is the point receive layout times values per point
919 6 : nroute_grad_per_point*features%route_point_recv_counts
920 : route_grad_return_send_displs(:) = &
921 6 : nroute_grad_per_point*features%route_point_recv_displs
922 : route_grad_return_recv_counts(:) = & ! return route: receive layout is the point send layout times values per point
923 6 : nroute_grad_per_point*features%route_point_send_counts
924 : route_grad_return_recv_displs(:) = &
925 6 : nroute_grad_per_point*features%route_point_send_displs
926 : END IF
927 :
928 2 : CALL timeset("skala_gpw_atom_subchunks", phase_handle)
929 6 : DO isubchunk = 1, nsubchunks
930 4 : CALL timeset("skala_gpw_atom_subchunk_build", subphase_handle)
931 : CALL skala_gpw_feature_build_atom_subchunk(features, subchunk, isubchunk, &
932 4 : max_rows, compute_grads)
933 4 : CALL timestop(subphase_handle)
934 4 : CALL timeset("skala_gpw_atom_subchunk_forward", subphase_handle)
935 : CALL skala_torch_model_get_exc(cached_model, subchunk%inputs, &
936 : subchunk%grid_weights_t, subchunk_exc_tensor, &
937 4 : subchunk_exc)
938 4 : CALL timestop(subphase_handle)
939 4 : exc = exc + subchunk_exc
940 4 : IF (compute_grads) THEN
941 4 : CALL timeset("skala_gpw_atom_subchunk_backward", subphase_handle)
942 4 : CALL torch_tensor_backward_scalar(subchunk_exc_tensor)
943 4 : CALL timestop(subphase_handle)
944 : END IF
945 4 : CALL timeset("skala_gpw_atom_subchunk_release", subphase_handle)
946 4 : CALL torch_tensor_release(subchunk_exc_tensor) ! released in every iteration, with or without a backward pass
947 4 : CALL skala_gpw_feature_release(subchunk)
948 18 : CALL timestop(subphase_handle)
949 : END DO
950 2 : IF (compute_grads) THEN
951 2 : CALL timeset("skala_gpw_atom_subchunk_grad_pack", subphase_handle)
952 2 : CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., collapse_spin_grads) ! NOTE(review): presumably the subchunk backward passes accumulate into the gradients of the parent features tensors - confirm
953 2 : CALL timestop(subphase_handle)
954 : END IF
955 2 : CALL timestop(phase_handle)
956 :
957 2 : IF (compute_grads) THEN
958 2 : CALL timeset("skala_gpw_grad_route_comm", phase_handle)
959 : CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
960 : route_grad_return_send_displs, recv_grad_buffer, &
961 2 : route_grad_return_recv_counts, route_grad_return_recv_displs)
962 2 : CALL timestop(phase_handle)
963 :
964 2 : CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
965 0 : ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
966 14 : kin_grad(nflat_local, 2))
967 2 : density_grad = 0.0_dp
968 2 : grad_grad = 0.0_dp
969 2 : kin_grad = 0.0_dp
970 64002 : DO point_pos = 1, nroute_points
971 64000 : local_row = features%route_send_local_rows(point_pos)
972 64000 : CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
973 64000 : base = nroute_grad_per_point*(point_pos - 1) ! zero-based start of this point in recv_grad_buffer
974 64002 : IF (collapse_spin_grads) THEN
975 : density_grad(local_row, :) = density_grad(local_row, :) + & ! collapsed layout: density, grad(1:3), kin; each scalar goes to both columns
976 192000 : recv_grad_buffer(base + 1)
977 : grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
978 192000 : recv_grad_buffer(base + 2)
979 : grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
980 192000 : recv_grad_buffer(base + 3)
981 : grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
982 192000 : recv_grad_buffer(base + 4)
983 192000 : kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
984 : ELSE
985 : density_grad(local_row, :) = density_grad(local_row, :) + & ! full layout: density(1:2), grad(1:3, 1), grad(1:3, 2), kin(1:2)
986 0 : recv_grad_buffer(base + 1:base + 2)
987 : grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
988 0 : recv_grad_buffer(base + 3)
989 : grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
990 0 : recv_grad_buffer(base + 4)
991 : grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
992 0 : recv_grad_buffer(base + 5)
993 : grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
994 0 : recv_grad_buffer(base + 6)
995 : grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
996 0 : recv_grad_buffer(base + 7)
997 : grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
998 0 : recv_grad_buffer(base + 8)
999 : kin_grad(local_row, :) = kin_grad(local_row, :) + &
1000 0 : recv_grad_buffer(base + 9:base + 10)
1001 : END IF
1002 : END DO
1003 2 : CALL timestop(phase_handle)
1004 :
1005 0 : DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
1006 0 : route_grad_return_recv_displs, route_grad_return_send_counts, &
1007 6 : route_grad_return_send_displs, send_grad_buffer)
1008 : END IF
1009 :
1010 4 : END SUBROUTINE evaluate_atom_subchunks
1011 :
1012 : ! **************************************************************************************************
1013 : !> \brief Select an automatic CUDA atom-subchunk row cap.
1014 : !> \param features only features%chunk_feature_count is read
1015 : !> \param group communicator; its max reduction and num_pe select the rule that is applied
1016 : !> \return 0 if the reduced row count is not positive, otherwise a quantum multiple clamped to the module bounds
1017 : ! **************************************************************************************************
1018 0 : FUNCTION auto_atom_chunk_max_rows(features, group) RESULT(max_rows)
1019 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
1020 :
1021 : CLASS(mp_comm_type), INTENT(IN) :: group
1022 : INTEGER :: max_rows
1023 :
1024 : INTEGER :: local_rows_max, target_rows
1025 :
1026 0 : local_rows_max = features%chunk_feature_count
1027 0 : CALL group%max(local_rows_max) ! reduced in place, so every rank derives the cap from the same number
1028 0 : IF (local_rows_max <= 0) THEN
1029 0 : max_rows = 0 ! returned as is; the clamp at the end is skipped
1030 : RETURN
1031 : END IF
1032 :
1033 0 : IF (group%num_pe > 1) THEN
1034 0 : target_rows = CEILING(REAL(local_rows_max, KIND=dp)/2.0_dp) ! half of the largest chunk, rounded up
1035 : max_rows = atom_chunk_auto_row_quantum* & ! round target_rows up to a multiple of the quantum
1036 0 : ((target_rows + atom_chunk_auto_row_quantum - 1)/atom_chunk_auto_row_quantum)
1037 : ELSE
1038 0 : target_rows = NINT(REAL(local_rows_max, KIND=dp)/4.0_dp) ! a quarter of the chunk, rounded to nearest
1039 : max_rows = atom_chunk_auto_row_quantum* & ! nearest multiple of the quantum, but at least one quantum
1040 : MAX(1, NINT(REAL(target_rows, KIND=dp)/ &
1041 0 : REAL(atom_chunk_auto_row_quantum, KIND=dp)))
1042 : END IF
1043 0 : max_rows = MAX(atom_chunk_auto_min_rows, MIN(atom_chunk_auto_max_rows, max_rows)) ! clamp; may leave a value that is no quantum multiple
1044 :
1045 0 : END FUNCTION auto_atom_chunk_max_rows
1046 :
1047 : ! **************************************************************************************************
1048 : !> \brief Map full Torch feature gradients back to this rank's local grid order.
1049 : !> \param features supplies the feature tensors, feature_index bounds, local row maps, nflat and nflat_local
1050 : !> \param density_grad allocated here as (nflat_local, 2); sum of the gradient rows mapped to each local point
1051 : !> \param grad_grad allocated here as (nflat_local, 3, 2); accumulated in the same way
1052 : !> \param kin_grad allocated here as (nflat_local, 2); accumulated in the same way
1053 : ! **************************************************************************************************
1054 142 : SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
1055 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
1056 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
1057 : INTENT(OUT) :: density_grad
1058 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
1059 : INTENT(OUT) :: grad_grad
1060 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
1061 : INTENT(OUT) :: kin_grad
1062 :
1063 : INTEGER :: feature_pos, i, j, k, local_row, row
1064 142 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: density_grad_all, kin_grad_all
1065 142 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: grad_grad_all
1066 : TYPE(torch_tensor_type) :: density_grad_t, grad_grad_t, kin_grad_t
1067 :
1068 142 : NULLIFY (density_grad_all, grad_grad_all, kin_grad_all)
1069 : CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
1070 142 : density_grad_all, grad_grad_all, kin_grad_all)
1071 142 : CPASSERT(SIZE(density_grad_all, 1) == features%nflat)
1072 142 : CPASSERT(SIZE(density_grad_all, 2) == 2)
1073 142 : CPASSERT(SIZE(grad_grad_all, 1) == features%nflat)
1074 142 : CPASSERT(SIZE(grad_grad_all, 2) == 3)
1075 142 : CPASSERT(SIZE(grad_grad_all, 3) == 2)
1076 142 : CPASSERT(SIZE(kin_grad_all, 1) == features%nflat)
1077 142 : CPASSERT(SIZE(kin_grad_all, 2) == 2)
1078 :
1079 0 : ALLOCATE (density_grad(features%nflat_local, 2), &
1080 0 : grad_grad(features%nflat_local, 3, 2), &
1081 994 : kin_grad(features%nflat_local, 2))
1082 142 : density_grad = 0.0_dp
1083 142 : grad_grad = 0.0_dp
1084 142 : kin_grad = 0.0_dp
1085 142 : local_row = 0 ! running index of the local grid point, i fastest, k slowest
1086 3462 : DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
1087 89518 : DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
1088 1546458 : DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
1089 1302618 : local_row = local_row + 1
1090 2674006 : DO feature_pos = features%local_feature_offsets(local_row), & ! all feature rows attached to this local point
1091 1382886 : features%local_feature_offsets(local_row + 1) - 1
1092 1371388 : row = features%local_feature_rows(feature_pos)
1093 1371388 : CPASSERT(row >= 1 .AND. row <= features%nflat)
1094 : density_grad(local_row, :) = density_grad(local_row, :) + & ! a point with several rows receives the sum of their gradients
1095 4114164 : density_grad_all(row, :)
1096 : grad_grad(local_row, :, :) = grad_grad(local_row, :, :) + &
1097 12342492 : grad_grad_all(row, :, :)
1098 5416782 : kin_grad(local_row, :) = kin_grad(local_row, :) + kin_grad_all(row, :)
1099 : END DO
1100 : END DO
1101 : END DO
1102 : END DO
1103 142 : CPASSERT(local_row == features%nflat_local)
1104 :
1105 142 : CALL torch_tensor_release(density_grad_t) ! the *_all views must not be used after these releases
1106 142 : CALL torch_tensor_release(grad_grad_t)
1107 142 : CALL torch_tensor_release(kin_grad_t)
1108 :
1109 142 : END SUBROUTINE fetch_local_feature_grads
1110 :
1111 : ! **************************************************************************************************
1112 : !> \brief Pack atom-chunk Torch gradients into CP2K communication buffers.
1113 : !> \param features supplies the feature tensors, chunk_feature_count, chunk_return_positions and uses_collapsed_rks_dynamic
1114 : !> \param TARGET flat buffer of whole points with room for at least chunk_feature_count points; only addressed slots are written
1115 : !> \param route_to_return_positions if set, row irow goes to slot chunk_return_positions(irow), otherwise to slot irow
1116 : !> \param collapse_spin_grads optional, default .FALSE.; if set, ncollapsed_grad_per_point halved values are packed per point
1117 : ! **************************************************************************************************
1118 2 : SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions, &
1119 : collapse_spin_grads)
1120 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
1121 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
1122 : INTENT(INOUT) :: target
1123 : LOGICAL, INTENT(IN) :: route_to_return_positions
1124 : LOGICAL, INTENT(IN), OPTIONAL :: collapse_spin_grads
1125 :
1126 : INTEGER :: base, irow, ngrad_buffer_per_point, &
1127 : point_pos, target_points
1128 : LOGICAL :: my_collapse_spin_grads
1129 2 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: chunk_density_grad, chunk_kin_grad
1130 2 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: chunk_grad_grad
1131 : TYPE(torch_tensor_type) :: density_grad_t, grad_grad_t, kin_grad_t
1132 :
1133 2 : my_collapse_spin_grads = .FALSE.
1134 4 : IF (PRESENT(collapse_spin_grads)) my_collapse_spin_grads = collapse_spin_grads
1135 2 : ngrad_buffer_per_point = ngrad_per_point ! number of buffer values that make up one point
1136 2 : IF (my_collapse_spin_grads) ngrad_buffer_per_point = ncollapsed_grad_per_point
1137 :
1138 2 : NULLIFY (chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
1139 : CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
1140 2 : chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
1141 2 : CPASSERT(MOD(SIZE(TARGET), ngrad_buffer_per_point) == 0)
1142 2 : target_points = SIZE(TARGET)/ngrad_buffer_per_point
1143 2 : CPASSERT(target_points >= features%chunk_feature_count)
1144 2 : CPASSERT(SIZE(chunk_density_grad, 1) == features%chunk_feature_count)
1145 2 : CPASSERT(SIZE(chunk_grad_grad, 1) == features%chunk_feature_count)
1146 2 : CPASSERT(SIZE(chunk_grad_grad, 2) == 3)
1147 2 : CPASSERT(SIZE(chunk_kin_grad, 1) == features%chunk_feature_count)
1148 2 : IF (features%uses_collapsed_rks_dynamic) THEN
1149 2 : CPASSERT(my_collapse_spin_grads) ! single-column gradients can only be packed in the collapsed layout
1150 2 : CPASSERT(SIZE(chunk_density_grad, 2) == 1)
1151 2 : CPASSERT(SIZE(chunk_grad_grad, 3) == 1)
1152 2 : CPASSERT(SIZE(chunk_kin_grad, 2) == 1)
1153 : ELSE
1154 0 : CPASSERT(SIZE(chunk_density_grad, 2) == 2)
1155 0 : CPASSERT(SIZE(chunk_grad_grad, 3) == 2)
1156 0 : CPASSERT(SIZE(chunk_kin_grad, 2) == 2)
1157 : END IF
1158 :
1159 64002 : DO irow = 1, features%chunk_feature_count
1160 64000 : IF (route_to_return_positions) THEN
1161 64000 : point_pos = features%chunk_return_positions(irow)
1162 64000 : CPASSERT(point_pos >= 1 .AND. point_pos <= target_points)
1163 : ELSE
1164 : point_pos = irow
1165 : END IF
1166 64000 : base = ngrad_buffer_per_point*(point_pos - 1) ! zero-based start of this point in TARGET
1167 64002 : IF (my_collapse_spin_grads) THEN
1168 64000 : IF (features%uses_collapsed_rks_dynamic) THEN
1169 64000 : TARGET(base + 1) = 0.5_dp*chunk_density_grad(irow, 1) ! collapsed layout: density, grad(1:3), kin; half of the single column
1170 64000 : TARGET(base + 2) = 0.5_dp*chunk_grad_grad(irow, 1, 1)
1171 64000 : TARGET(base + 3) = 0.5_dp*chunk_grad_grad(irow, 2, 1)
1172 64000 : TARGET(base + 4) = 0.5_dp*chunk_grad_grad(irow, 3, 1)
1173 64000 : TARGET(base + 5) = 0.5_dp*chunk_kin_grad(irow, 1)
1174 : ELSE
1175 : TARGET(base + 1) = 0.5_dp*(chunk_density_grad(irow, 1) + & ! same layout, but the mean of the two columns
1176 0 : chunk_density_grad(irow, 2))
1177 : TARGET(base + 2) = 0.5_dp*(chunk_grad_grad(irow, 1, 1) + &
1178 0 : chunk_grad_grad(irow, 1, 2))
1179 : TARGET(base + 3) = 0.5_dp*(chunk_grad_grad(irow, 2, 1) + &
1180 0 : chunk_grad_grad(irow, 2, 2))
1181 : TARGET(base + 4) = 0.5_dp*(chunk_grad_grad(irow, 3, 1) + &
1182 0 : chunk_grad_grad(irow, 3, 2))
1183 0 : TARGET(base + 5) = 0.5_dp*(chunk_kin_grad(irow, 1) + chunk_kin_grad(irow, 2))
1184 : END IF
1185 : ELSE
1186 0 : TARGET(base + 1:base + 2) = chunk_density_grad(irow, :) ! full layout: density(1:2), grad(1:3, 1), grad(1:3, 2), kin(1:2)
1187 0 : TARGET(base + 3) = chunk_grad_grad(irow, 1, 1)
1188 0 : TARGET(base + 4) = chunk_grad_grad(irow, 2, 1)
1189 0 : TARGET(base + 5) = chunk_grad_grad(irow, 3, 1)
1190 0 : TARGET(base + 6) = chunk_grad_grad(irow, 1, 2)
1191 0 : TARGET(base + 7) = chunk_grad_grad(irow, 2, 2)
1192 0 : TARGET(base + 8) = chunk_grad_grad(irow, 3, 2)
1193 0 : TARGET(base + 9:base + 10) = chunk_kin_grad(irow, :)
1194 : END IF
1195 : END DO
1196 :
1197 2 : CALL torch_tensor_release(density_grad_t) ! the chunk_* views must not be used after these releases
1198 2 : CALL torch_tensor_release(grad_grad_t)
1199 2 : CALL torch_tensor_release(kin_grad_t)
1200 :
1201 2 : END SUBROUTINE pack_atom_chunk_grads
1202 :
1203 : ! **************************************************************************************************
1204 : !> \brief Return CPU views of autograd outputs for the SKALA dynamic feature tensors.
1205 : !> \param features supplies density_t, grad_t and kin_t, whose gradients are requested
1206 : !> \param density_grad_t tensor filled from features%density_t; the caller has to release it
1207 : !> \param grad_grad_t tensor filled from features%grad_t; the caller has to release it
1208 : !> \param kin_grad_t tensor filled from features%kin_t; the caller has to release it
1209 : !> \param density_grad rank-2 pointer set from density_grad_t by torch_tensor_data_ptr
1210 : !> \param grad_grad rank-3 pointer set from grad_grad_t by torch_tensor_data_ptr
1211 : !> \param kin_grad rank-2 pointer set from kin_grad_t by torch_tensor_data_ptr
1212 : ! **************************************************************************************************
1213 144 : SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
1214 : density_grad, grad_grad, kin_grad)
1215 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
1216 : TYPE(torch_tensor_type), INTENT(INOUT) :: density_grad_t, grad_grad_t, kin_grad_t
1217 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: density_grad
1218 : REAL(KIND=dp), DIMENSION(:, :, :), POINTER :: grad_grad
1219 : REAL(KIND=dp), DIMENSION(:, :), POINTER :: kin_grad
1220 :
1221 144 : NULLIFY (density_grad, grad_grad, kin_grad) ! any previous association of the output pointers is dropped
1222 144 : CALL torch_tensor_grad(features%density_t, density_grad_t)
1223 144 : CALL torch_tensor_grad(features%grad_t, grad_grad_t)
1224 144 : CALL torch_tensor_grad(features%kin_t, kin_grad_t)
1225 144 : CALL torch_tensor_data_ptr(density_grad_t, density_grad) ! NOTE(review): presumably the views alias tensor storage and are valid only until the tensors are released - confirm
1226 144 : CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
1227 144 : CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
1228 :
1229 144 : END SUBROUTINE get_feature_grad_views
1230 :
1231 : ! **************************************************************************************************
1232 : !> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
1233 : !> \param features ...
1234 : !> \param group ...
1235 : !> \param density_grad ...
1236 : !> \param grad_grad ...
1237 : !> \param kin_grad ...
1238 : ! **************************************************************************************************
1239 0 : SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
1240 : kin_grad)
1241 : TYPE(skala_gpw_feature_type), INTENT(IN) :: features
1242 :
1243 : CLASS(mp_comm_type), INTENT(IN) :: group
1244 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
1245 : INTENT(OUT) :: density_grad, kin_grad
1246 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
1247 : INTENT(OUT) :: grad_grad
1248 :
1249 : INTEGER :: base, feature_pos, i, j, k, local_row, &
1250 : nflat_local, nroute_grad_per_point, &
1251 : nroute_points, phase_handle, point_pos, row
1252 0 : INTEGER, ALLOCATABLE, DIMENSION(:) :: route_grad_return_recv_counts, &
1253 0 : route_grad_return_recv_displs, &
1254 0 : route_grad_return_send_counts, &
1255 0 : route_grad_return_send_displs
1256 0 : REAL(KIND=dp), ALLOCATABLE, DIMENSION(:) :: chunk_grad_buffer, global_grad_buffer, &
1257 0 : recv_grad_buffer, send_grad_buffer
1258 :
1259 0 : CPASSERT(features%uses_atom_chunks)
1260 0 : CPASSERT(features%chunk_feature_count > 0)
1261 :
1262 0 : nflat_local = features%nflat_local
1263 0 : IF (features%uses_atom_chunk_routing) THEN
1264 0 : CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
1265 0 : nroute_points = SIZE(features%route_send_local_rows)
1266 0 : CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
1267 :
1268 0 : nroute_grad_per_point = ngrad_per_point
1269 0 : IF (features%uses_collapsed_rks_dynamic) &
1270 0 : nroute_grad_per_point = ncollapsed_grad_per_point
1271 : ALLOCATE (send_grad_buffer(nroute_grad_per_point*features%chunk_feature_count), &
1272 : recv_grad_buffer(nroute_grad_per_point*nroute_points), &
1273 : route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
1274 : route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
1275 : route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
1276 0 : route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
1277 : route_grad_return_send_counts(:) = &
1278 0 : nroute_grad_per_point*features%route_point_recv_counts
1279 : route_grad_return_send_displs(:) = &
1280 0 : nroute_grad_per_point*features%route_point_recv_displs
1281 : route_grad_return_recv_counts(:) = &
1282 0 : nroute_grad_per_point*features%route_point_send_counts
1283 : route_grad_return_recv_displs(:) = &
1284 0 : nroute_grad_per_point*features%route_point_send_displs
1285 :
1286 0 : CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
1287 : CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., &
1288 0 : features%uses_collapsed_rks_dynamic)
1289 0 : CALL timestop(phase_handle)
1290 :
1291 0 : CALL timeset("skala_gpw_grad_route_comm", phase_handle)
1292 : CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
1293 : route_grad_return_send_displs, recv_grad_buffer, &
1294 0 : route_grad_return_recv_counts, route_grad_return_recv_displs)
1295 0 : CALL timestop(phase_handle)
1296 :
1297 0 : CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
1298 0 : ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
1299 0 : kin_grad(nflat_local, 2))
1300 0 : density_grad = 0.0_dp
1301 0 : grad_grad = 0.0_dp
1302 0 : kin_grad = 0.0_dp
1303 0 : DO point_pos = 1, nroute_points
1304 0 : local_row = features%route_send_local_rows(point_pos)
1305 0 : CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
1306 0 : base = nroute_grad_per_point*(point_pos - 1)
1307 0 : IF (features%uses_collapsed_rks_dynamic) THEN
1308 : density_grad(local_row, :) = density_grad(local_row, :) + &
1309 0 : recv_grad_buffer(base + 1)
1310 : grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
1311 0 : recv_grad_buffer(base + 2)
1312 : grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
1313 0 : recv_grad_buffer(base + 3)
1314 : grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
1315 0 : recv_grad_buffer(base + 4)
1316 0 : kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
1317 : ELSE
1318 : density_grad(local_row, :) = density_grad(local_row, :) + &
1319 0 : recv_grad_buffer(base + 1:base + 2)
1320 : grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
1321 0 : recv_grad_buffer(base + 3)
1322 : grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
1323 0 : recv_grad_buffer(base + 4)
1324 : grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
1325 0 : recv_grad_buffer(base + 5)
1326 : grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
1327 0 : recv_grad_buffer(base + 6)
1328 : grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
1329 0 : recv_grad_buffer(base + 7)
1330 : grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
1331 0 : recv_grad_buffer(base + 8)
1332 : kin_grad(local_row, :) = kin_grad(local_row, :) + &
1333 0 : recv_grad_buffer(base + 9:base + 10)
1334 : END IF
1335 : END DO
1336 0 : CALL timestop(phase_handle)
1337 :
1338 0 : DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
1339 0 : route_grad_return_recv_displs, route_grad_return_send_counts, &
1340 0 : route_grad_return_send_displs, send_grad_buffer)
1341 : ELSE
1342 : ALLOCATE (chunk_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
1343 0 : global_grad_buffer(ngrad_per_point*features%nflat))
1344 0 : CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
1345 0 : CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .FALSE.)
1346 0 : CALL timestop(phase_handle)
1347 :
1348 0 : CALL timeset("skala_gpw_grad_allgatherv", phase_handle)
1349 : CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
1350 0 : features%chunk_grad_counts, features%chunk_grad_displs)
1351 0 : CALL timestop(phase_handle)
1352 :
1353 0 : CALL timeset("skala_gpw_grad_scatter", phase_handle)
1354 0 : ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
1355 0 : kin_grad(nflat_local, 2))
1356 0 : density_grad = 0.0_dp
1357 0 : grad_grad = 0.0_dp
1358 0 : kin_grad = 0.0_dp
1359 0 : local_row = 0
1360 0 : DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
1361 0 : DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
1362 0 : DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
1363 0 : local_row = local_row + 1
1364 0 : DO feature_pos = features%local_feature_offsets(local_row), &
1365 0 : features%local_feature_offsets(local_row + 1) - 1
1366 0 : row = features%local_feature_rows(feature_pos)
1367 0 : CPASSERT(row >= 1 .AND. row <= features%nflat)
1368 0 : base = ngrad_per_point*(row - 1)
1369 : density_grad(local_row, :) = density_grad(local_row, :) + &
1370 0 : global_grad_buffer(base + 1:base + 2)
1371 : grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
1372 0 : global_grad_buffer(base + 3)
1373 : grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
1374 0 : global_grad_buffer(base + 4)
1375 : grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
1376 0 : global_grad_buffer(base + 5)
1377 : grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
1378 0 : global_grad_buffer(base + 6)
1379 : grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
1380 0 : global_grad_buffer(base + 7)
1381 : grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
1382 0 : global_grad_buffer(base + 8)
1383 : kin_grad(local_row, :) = kin_grad(local_row, :) + &
1384 0 : global_grad_buffer(base + 9:base + 10)
1385 : END DO
1386 : END DO
1387 : END DO
1388 : END DO
1389 0 : CALL timestop(phase_handle)
1390 0 : DEALLOCATE (chunk_grad_buffer, global_grad_buffer)
1391 :
1392 : END IF
1393 :
1394 0 : END SUBROUTINE fetch_and_gather_atom_chunk_grads
1395 :
! **************************************************************************************************
!> \brief Accumulate the density-gradient part of the native SKALA XC virial.
!> \param virial_xc 3x3 virial tensor, updated in place; every updated upper-triangle element
!>        (jdir,idir) is copied into its mirror element (idir,jdir)
!> \param rho_set density set queried for drho (one spin) or drhoa/drhob (two spins)
!> \param rho_r real-space densities; SIZE(rho_r) selects the spin case and the first entry
!>        supplies the local grid bounds
!> \param grad_grad feature gradients indexed as (local point, Cartesian direction, spin)
!> \note The local point counter advances with the first grid index running fastest; grad_grad
!>       is expected to follow that ordering (this is not checked here).
! **************************************************************************************************
SUBROUTINE build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
   REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
   TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
   TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
   REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad

   INTEGER                                            :: ia, ib, ix, iy, iz, point
   INTEGER, DIMENSION(2, 3)                           :: bounds
   REAL(KIND=dp)                                      :: contrib, spin_avg
   TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob

   bounds = rho_r(1)%pw_grid%bounds_local
   point = 0

   IF (SIZE(rho_r) /= 1) THEN
      ! Spin-resolved case: each channel is contracted with its own density gradient.
      CALL xc_rho_set_get(rho_set, drhoa=drhoa, drhob=drhob)
      DO iz = bounds(1, 3), bounds(2, 3)
         DO iy = bounds(1, 2), bounds(2, 2)
            DO ix = bounds(1, 1), bounds(2, 1)
               point = point + 1
               DO ia = 1, 3
                  DO ib = 1, ia
                     ! first channel, then second channel (same summation order as a spin loop)
                     contrib = 0.0_dp - grad_grad(point, ia, 1)*drhoa(ib)%array(ix, iy, iz)
                     contrib = contrib - grad_grad(point, ia, 2)*drhob(ib)%array(ix, iy, iz)
                     virial_xc(ib, ia) = virial_xc(ib, ia) + contrib
                     virial_xc(ia, ib) = virial_xc(ib, ia)
                  END DO
               END DO
            END DO
         END DO
      END DO
   ELSE
      ! Single-density case: the two spin channels of grad_grad are averaged.
      CALL xc_rho_set_get(rho_set, drho=drho)
      DO iz = bounds(1, 3), bounds(2, 3)
         DO iy = bounds(1, 2), bounds(2, 2)
            DO ix = bounds(1, 1), bounds(2, 1)
               point = point + 1
               DO ia = 1, 3
                  spin_avg = 0.5_dp*(grad_grad(point, ia, 1) + grad_grad(point, ia, 2))
                  DO ib = 1, ia
                     contrib = -spin_avg*drho(ib)%array(ix, iy, iz)
                     virial_xc(ib, ia) = virial_xc(ib, ia) + contrib
                     virial_xc(ia, ib) = virial_xc(ib, ia)
                  END DO
               END DO
            END DO
         END DO
      END DO
   END IF

END SUBROUTINE build_virial_from_feature_grads
1461 :
! **************************************************************************************************
!> \brief Print one 3x3 native SKALA XC virial contribution for diagnostics.
!> \param label short tag appended to the heading line
!> \param delta 3x3 contribution, written one row per output line
!> \param root_rank nothing is written unless this is .TRUE.
!> \note Output goes to the default logger unit; nothing is written if that unit is not positive.
! **************************************************************************************************
SUBROUTINE print_virial_delta(label, delta, root_rank)
   CHARACTER(LEN=*), INTENT(IN)                       :: label
   REAL(KIND=dp), DIMENSION(3, 3), INTENT(IN)         :: delta
   LOGICAL, INTENT(IN)                                :: root_rank

   INTEGER                                            :: irow, out_unit

   IF (root_rank) THEN
      out_unit = cp_logger_get_default_io_unit()
      IF (out_unit > 0) THEN
         WRITE (out_unit, "(T2,A,1X,A)") "SKALA_GPW| XC virial contribution", TRIM(label)
         DO irow = 1, 3
            WRITE (out_unit, "(T2,A,1X,3ES20.10)") "SKALA_GPW|", delta(irow, :)
         END DO
      END IF
   END IF

END SUBROUTINE print_virial_delta
1484 :
! **************************************************************************************************
!> \brief Add explicit SKALA coordinate-feature contributions to the XC virial.
!> \param virial_xc 3x3 virial tensor, updated in place and kept symmetric for every touched pair
!> \param features feature block supplying the grid/atom coordinates, the coordinate tensor and
!>        the local-point to feature-row mapping
!> \param atom_coord_grad_t tensor whose data is read as (direction, atom); it is not released here
!> \param grid_coord_grad_t filled via torch_tensor_grad from features%grid_coords_t, read as
!>        (direction, feature row) and released before returning
!> \param root_rank the atom-coordinate term is only added when this is .TRUE.
!> \param print_components optional; when .TRUE. the grid and atom parts are printed (default off)
!> \note Each term is gradient(idir)*coordinate(jdir) for jdir <= idir, stored at (jdir,idir) and
!>       mirrored into (idir,jdir).
! **************************************************************************************************
SUBROUTINE build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
                                          grid_coord_grad_t, root_rank, print_components)
   REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
   TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
   TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t, grid_coord_grad_t
   LOGICAL, INTENT(IN)                                :: root_rank
   LOGICAL, INTENT(IN), OPTIONAL                      :: print_components

   INTEGER                                            :: feature_pos, iatom, idir, jdir, &
                                                         local_row, nlocal, row
   LOGICAL                                            :: my_print_components
   REAL(KIND=dp)                                      :: term
   REAL(KIND=dp), DIMENSION(3, 3)                     :: atom_virial, grid_virial
   REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad, grid_coord_grad

   IF (PRESENT(print_components)) THEN
      my_print_components = print_components
   ELSE
      my_print_components = .FALSE.
   END IF

   NULLIFY (atom_coord_grad, grid_coord_grad)
   CALL torch_tensor_grad(features%grid_coords_t, grid_coord_grad_t)
   CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
   CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)

   atom_virial = 0.0_dp
   grid_virial = 0.0_dp

   ! Local grid points are visited in storage order of feature_index (first index fastest),
   ! so a flat counter over the first three extents reproduces the nested grid loops.
   nlocal = SIZE(features%feature_index, 1)*SIZE(features%feature_index, 2)* &
            SIZE(features%feature_index, 3)
   DO local_row = 1, nlocal
      DO feature_pos = features%local_feature_offsets(local_row), &
         features%local_feature_offsets(local_row + 1) - 1
         row = features%local_feature_rows(feature_pos)
         DO idir = 1, 3
            DO jdir = 1, idir
               term = grid_coord_grad(idir, row)*features%grid_coords(jdir, row)
               grid_virial(jdir, idir) = grid_virial(jdir, idir) + term
               grid_virial(idir, jdir) = grid_virial(jdir, idir)
               virial_xc(jdir, idir) = virial_xc(jdir, idir) + term
               virial_xc(idir, jdir) = virial_xc(jdir, idir)
            END DO
         END DO
      END DO
   END DO
   CPASSERT(nlocal == features%nflat_local)

   IF (root_rank) THEN
      DO iatom = 1, SIZE(features%coarse_0_atomic_coords, 2)
         DO idir = 1, 3
            DO jdir = 1, idir
               term = atom_coord_grad(idir, iatom)*features%coarse_0_atomic_coords(jdir, iatom)
               atom_virial(jdir, idir) = atom_virial(jdir, idir) + term
               atom_virial(idir, jdir) = atom_virial(jdir, idir)
               virial_xc(jdir, idir) = virial_xc(jdir, idir) + term
               virial_xc(idir, jdir) = virial_xc(jdir, idir)
            END DO
         END DO
      END DO
   END IF

   ! print_virial_delta itself is silent off the root rank and without a valid output unit.
   IF (my_print_components) THEN
      CALL print_virial_delta("static-grid", grid_virial, root_rank)
      CALL print_virial_delta("static-atom", atom_virial, root_rank)
   END IF

   CALL torch_tensor_release(grid_coord_grad_t)

END SUBROUTINE build_static_coordinate_virial
1567 :
! **************************************************************************************************
!> \brief Add residual SKALA weight-feature contributions to the XC virial.
!> \param virial_xc 3x3 virial tensor; the same scalar is added to its three diagonal elements
!> \param features feature block supplying the weight tensors, the weight arrays and the
!>        local-point to feature-row mapping
!> \param exc XC energy; -exc enters the scalar on the root rank only
!> \param grid_weight_grad_t filled via torch_tensor_grad from features%grid_weights_t and
!>        released before returning
!> \param atomic_grid_weight_grad_t filled via torch_tensor_grad from
!>        features%atomic_grid_weights_t and released before returning
!> \param root_rank selects the rank that adds -exc and that may print
!> \param print_components optional; when .TRUE. the individual parts are printed (default off)
!> \note The scalar is sum(grad*weight) over the grid weights plus the same sum over the atomic
!>       grid weights, plus -exc on the root rank.
! **************************************************************************************************
SUBROUTINE build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
                               atomic_grid_weight_grad_t, root_rank, print_components)
   REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
   TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
   REAL(KIND=dp), INTENT(IN)                          :: exc
   TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
                                                         atomic_grid_weight_grad_t
   LOGICAL, INTENT(IN)                                :: root_rank
   LOGICAL, INTENT(IN), OPTIONAL                      :: print_components

   INTEGER                                            :: feature_pos, idir, local_row, nlocal, &
                                                         out_unit, row
   LOGICAL                                            :: my_print_components
   REAL(KIND=dp)                                      :: atomic_part, exc_part, grid_part, &
                                                         residual
   REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, &
                                                         grid_weight_grad

   IF (PRESENT(print_components)) THEN
      my_print_components = print_components
   ELSE
      my_print_components = .FALSE.
   END IF

   NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
   CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
   CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
   CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
   CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)

   atomic_part = 0.0_dp
   grid_part = 0.0_dp

   ! One pass per local grid point, in storage order of feature_index (first index fastest).
   nlocal = SIZE(features%feature_index, 1)*SIZE(features%feature_index, 2)* &
            SIZE(features%feature_index, 3)
   DO local_row = 1, nlocal
      DO feature_pos = features%local_feature_offsets(local_row), &
         features%local_feature_offsets(local_row + 1) - 1
         row = features%local_feature_rows(feature_pos)
         grid_part = grid_part + grid_weight_grad(row)*features%grid_weights(row)
         atomic_part = atomic_part + &
                       atomic_grid_weight_grad(row)*features%atomic_grid_weights(row)
      END DO
   END DO
   CPASSERT(nlocal == features%nflat_local)

   ! The energy term is added on a single rank only.
   exc_part = MERGE(-exc, 0.0_dp, root_rank)
   residual = grid_part + atomic_part + exc_part

   IF (my_print_components .AND. root_rank) THEN
      out_unit = cp_logger_get_default_io_unit()
      IF (out_unit > 0) THEN
         WRITE (out_unit, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight grid", grid_part
         WRITE (out_unit, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight atomic", atomic_part
         WRITE (out_unit, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight final", exc_part
         WRITE (out_unit, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight residual", residual
      END IF
   END IF

   DO idir = 1, 3
      virial_xc(idir, idir) = virial_xc(idir, idir) + residual
   END DO

   CALL torch_tensor_release(grid_weight_grad_t)
   CALL torch_tensor_release(atomic_grid_weight_grad_t)

END SUBROUTINE build_weight_virial
1643 :
! **************************************************************************************************
!> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
!> \param vxc_rho allocated here (one grid per spin) and filled from density_grad; the result of
!>        xc_pw_divergence on the negated gradient-channel grids is passed in the same grid
!> \param vxc_tau allocated here (one grid per spin) and filled from kin_grad
!> \param rho_r real-space densities; supplies the spin count, local grid bounds, volume element
!>        and the spherical-grid flag
!> \param pw_pool pool from which all grids are created and to which work grids are returned
!> \param density_grad feature gradients indexed as (local point, spin)
!> \param grad_grad feature gradients indexed as (local point, Cartesian direction, spin)
!> \param kin_grad feature gradients indexed as (local point, spin)
!> \param xc_deriv_method_id derivative method forwarded to xc_requires_tmp_g/xc_pw_divergence
!> \note All values are divided by the grid volume element. With a single density the two spin
!>       channels of every gradient array are averaged. The caller owns vxc_rho and vxc_tau.
! **************************************************************************************************
SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                        density_grad, grad_grad, kin_grad, &
                                        xc_deriv_method_id)
   TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
   TYPE(pw_pool_type), POINTER                        :: pw_pool
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
   REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
   REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
   INTEGER, INTENT(IN)                                :: xc_deriv_method_id

   INTEGER                                            :: idir, ispin, ix, iy, iz, nspins, point
   INTEGER, DIMENSION(2, 3)                           :: bounds
   REAL(KIND=dp)                                      :: half_inv_dvol, inv_dvol
   TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
   TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw

   nspins = SIZE(rho_r)
   bounds = rho_r(1)%pw_grid%bounds_local
   inv_dvol = 1.0_dp/rho_r(1)%pw_grid%dvol
   half_inv_dvol = 0.5_dp*inv_dvol

   ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
   DO ispin = 1, nspins
      CALL pw_pool%create_pw(vxc_rho(ispin))
      CALL pw_pool%create_pw(vxc_tau(ispin))
      CALL pw_zero(vxc_rho(ispin))
      CALL pw_zero(vxc_tau(ispin))
   END DO

   ! Reciprocal-space work grids are only created when the divergence routine needs them.
   IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
      CALL pw_pool%create_pw(vxc_g)
      IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
   END IF

   DO ispin = 1, nspins
      DO idir = 1, 3
         CALL pw_pool%create_pw(grad_pw(idir))
         CALL pw_zero(grad_pw(idir))
      END DO

      point = 0
      IF (nspins == 1) THEN
         ! Single density: average the two spin channels of every gradient array.
         DO iz = bounds(1, 3), bounds(2, 3)
            DO iy = bounds(1, 2), bounds(2, 2)
               DO ix = bounds(1, 1), bounds(2, 1)
                  point = point + 1
                  vxc_rho(1)%array(ix, iy, iz) = half_inv_dvol* &
                                                 (density_grad(point, 1) + density_grad(point, 2))
                  vxc_tau(1)%array(ix, iy, iz) = half_inv_dvol* &
                                                 (kin_grad(point, 1) + kin_grad(point, 2))
                  DO idir = 1, 3
                     grad_pw(idir)%array(ix, iy, iz) = half_inv_dvol* &
                                                       (grad_grad(point, idir, 1) + &
                                                        grad_grad(point, idir, 2))
                  END DO
               END DO
            END DO
         END DO
      ELSE
         ! Spin-resolved: take the channel that belongs to the current spin.
         DO iz = bounds(1, 3), bounds(2, 3)
            DO iy = bounds(1, 2), bounds(2, 2)
               DO ix = bounds(1, 1), bounds(2, 1)
                  point = point + 1
                  vxc_rho(ispin)%array(ix, iy, iz) = inv_dvol*density_grad(point, ispin)
                  vxc_tau(ispin)%array(ix, iy, iz) = inv_dvol*kin_grad(point, ispin)
                  DO idir = 1, 3
                     grad_pw(idir)%array(ix, iy, iz) = inv_dvol*grad_grad(point, idir, ispin)
                  END DO
               END DO
            END DO
         END DO
      END IF

      ! The gradient-channel grids are negated before they are handed to the divergence routine.
      DO idir = 1, 3
         CALL pw_scale(grad_pw(idir), -1.0_dp)
      END DO
      CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))

      DO idir = 1, 3
         CALL pw_pool%give_back_pw(grad_pw(idir))
      END DO
   END DO

   IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
   IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)

END SUBROUTINE build_vxc_from_feature_grads
1735 :
! **************************************************************************************************
!> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
!> \param features feature block whose electron count, spin moment, weight sum, per-atom row
!>        counts and (if atom chunks are used) chunk row statistics are reported
!> \param print_active nothing is printed unless this is .TRUE.
!> \note Output goes to the default logger unit; nothing is printed if that unit is not positive.
!>       Chunk row counts are chunk_grad_counts divided by ngrad_per_point; the imbalance is the
!>       largest count over the smallest one, with the divisor clamped to at least 1.
! **************************************************************************************************
SUBROUTINE print_native_grid_diagnostics(features, print_active)
   TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
   LOGICAL, INTENT(IN)                                :: print_active

   INTEGER                                            :: most_atom_rows, most_chunk_rows, &
                                                         fewest_atom_rows, fewest_chunk_rows, &
                                                         out_unit
   REAL(KIND=dp)                                      :: imbalance

   out_unit = -1
   IF (print_active) out_unit = cp_logger_get_default_io_unit()
   IF (out_unit <= 0) RETURN

   WRITE (UNIT=out_unit, FMT="(/,T2,A,1X,ES19.11)") &
      "SKALA_GPW| Native grid feature electrons", features%electron_count
   WRITE (UNIT=out_unit, FMT="(T2,A,1X,ES19.11)") &
      "SKALA_GPW| Native grid feature spin moment", features%spin_moment
   WRITE (UNIT=out_unit, FMT="(T2,A,1X,ES19.11)") &
      "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum

   IF (ALLOCATED(features%atomic_grid_sizes)) THEN
      fewest_atom_rows = INT(MINVAL(features%atomic_grid_sizes))
      most_atom_rows = INT(MAXVAL(features%atomic_grid_sizes))
      WRITE (UNIT=out_unit, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
         "SKALA_GPW| Native grid atom row range", fewest_atom_rows, "to", &
         most_atom_rows, "sum", INT(SUM(features%atomic_grid_sizes))
   END IF

   IF (.NOT. features%uses_atom_chunks) RETURN

   WRITE (UNIT=out_unit, FMT="(T2,A,1X,I0,1X,A,1X,I0)") &
      "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
      "of", features%nflat
   IF (ALLOCATED(features%chunk_grad_counts)) THEN
      fewest_chunk_rows = MINVAL(features%chunk_grad_counts)/ngrad_per_point
      most_chunk_rows = MAXVAL(features%chunk_grad_counts)/ngrad_per_point
      ! clamp the divisor so an empty chunk cannot cause a division by zero
      imbalance = REAL(most_chunk_rows, KIND=dp)/REAL(MAX(1, fewest_chunk_rows), KIND=dp)
      WRITE (UNIT=out_unit, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,ES12.5)") &
         "SKALA_GPW| Native grid atom chunk row range", fewest_chunk_rows, &
         "to", most_chunk_rows, "imbalance", imbalance
   END IF

END SUBROUTINE print_native_grid_diagnostics
1781 :
! **************************************************************************************************
!> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
!> \param use_cuda if .FALSE. the function returns -1 immediately without any communication
!> \param requested_device explicit device index; a negative value selects the device
!>        MOD(rank, visible device count)
!> \param group communicator; every rank takes part in the allgather of the selected devices
!> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
!> \note Aborts if the selected index is not below the number of visible devices. Only rank 0
!>       prints, and only if the (device, device count, rank count, request) combination differs
!>       from the one remembered from the previous printout.
! **************************************************************************************************
FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
   LOGICAL, INTENT(IN)                                :: use_cuda
   INTEGER, INTENT(IN)                                :: requested_device

   CLASS(mp_comm_type), INTENT(IN)                    :: group

   INTEGER                                            :: cuda_device_count, out_unit, rank, &
                                                         selected_device
   INTEGER, ALLOCATABLE, DIMENSION(:)                 :: selected_devices
   LOGICAL                                            :: already_logged

   selected_device = -1
   IF (.NOT. use_cuda) RETURN

   cuda_device_count = 0
   IF (torch_cuda_is_available()) cuda_device_count = torch_cuda_device_count()

   IF (cuda_device_count > 0) THEN
      selected_device = requested_device
      ! no explicit request: distribute the ranks round-robin over the visible devices
      IF (requested_device < 0) selected_device = MOD(group%mepos, cuda_device_count)
   END IF
   IF (selected_device >= cuda_device_count) THEN
      CALL cp_abort(__LOCATION__, &
                    "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
                    "Torch CUDA device range.")
   END IF
   IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)

   ! Collective call: must be reached by every rank before any rank-dependent return.
   ALLOCATE (selected_devices(group%num_pe))
   CALL group%allgather(selected_device, selected_devices)

   IF (group%mepos /= 0) RETURN

   already_logged = selected_device == logged_cuda_device .AND. &
                    cuda_device_count == logged_cuda_device_count .AND. &
                    group%num_pe == logged_cuda_nproc .AND. &
                    requested_device == logged_cuda_request
   IF (already_logged) RETURN

   out_unit = cp_logger_get_default_io_unit()
   IF (out_unit <= 0) RETURN

   IF (selected_device < 0) THEN
      WRITE (UNIT=out_unit, FMT="(/,T2,A)") &
         "SKALA_GPW| Native grid Torch CUDA requested, but no Torch CUDA device is visible"
   ELSE
      WRITE (UNIT=out_unit, FMT="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
         "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
         "of", cuda_device_count, "requested", requested_device
   END IF

   ! one "rank:device" pair per rank on a single output line
   WRITE (UNIT=out_unit, FMT="(T2,A)", ADVANCE="NO") &
      "SKALA_GPW| Native grid Torch CUDA rank devices"
   DO rank = 0, group%num_pe - 1
      WRITE (UNIT=out_unit, FMT="(1X,I0,A,I0)", ADVANCE="NO") rank, ":", selected_devices(rank + 1)
   END DO
   WRITE (UNIT=out_unit, FMT=*)

   ! remember what was printed so that identical configurations are not reported again
   logged_cuda_device = selected_device
   logged_cuda_device_count = cuda_device_count
   logged_cuda_nproc = group%num_pe
   logged_cuda_request = requested_device

END FUNCTION configure_native_grid_cuda
1853 :
! **************************************************************************************************
!> \brief Load and cache the TorchScript SKALA model.
!> \param model_path path of the TorchScript model to load
!> \param cuda_device device index stored together with the path as the cache key; it is not
!>        passed to the model loader in this routine
!> \note If a model is cached for the same path and device nothing is done; a model cached for a
!>       different path or device is released before the new one is loaded.
! **************************************************************************************************
SUBROUTINE ensure_model_loaded(model_path, cuda_device)
   CHARACTER(len=*), INTENT(IN)                       :: model_path
   INTEGER, INTENT(IN)                                :: cuda_device

   LOGICAL                                            :: cache_hit

   IF (cached_model_loaded) THEN
      cache_hit = TRIM(cached_model_path) == TRIM(model_path) .AND. &
                  cached_model_cuda_device == cuda_device
      IF (cache_hit) RETURN
      ! stale entry: drop it before loading the requested model
      CALL skala_torch_model_release(cached_model)
      cached_model_loaded = .FALSE.
   END IF

   CALL skala_torch_model_load(cached_model, TRIM(model_path))
   cached_model_cuda_device = cuda_device
   cached_model_path = model_path
   cached_model_loaded = .TRUE.

END SUBROUTINE ensure_model_loaded
1876 :
! **************************************************************************************************
!> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
!> \param xc_section XC input section that must contain an XC_FUNCTIONAL%GAUXC subsection
!> \param model_path resolved model path: the MODEL keyword value as given, or, for MODEL SKALA,
!>        the value of an environment variable
!> \note For MODEL SKALA with NATIVE_GRID_USE_CUDA the variable GAUXC_SKALA_CUDA_MODEL is tried
!>       first and GAUXC_SKALA_MODEL is the fallback; without CUDA only GAUXC_SKALA_MODEL is read.
!>       Aborts if the GAUXC section is missing, if MODEL is NONE or blank, if no usable variable
!>       is set, or if a variable is set but its value does not fit into model_path
!>       (GET_ENVIRONMENT_VARIABLE reports this truncation as STATUS = -1).
! **************************************************************************************************
SUBROUTINE get_skala_model_path(xc_section, model_path)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
   CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path

   CHARACTER(len=default_path_length)                 :: model_key
   INTEGER                                            :: env_status
   LOGICAL                                            :: native_grid_use_cuda
   TYPE(section_vals_type), POINTER                   :: gauxc_section

   gauxc_section => get_gauxc_section(xc_section)
   IF (.NOT. ASSOCIATED(gauxc_section)) THEN
      CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
   END IF

   CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)
   ! compare on a left-justified, upper-cased copy; model_path itself is returned unchanged
   model_key = ADJUSTL(model_path)
   CALL uppercase(model_key)
   IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
      CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
   ELSE IF (TRIM(model_key) == "SKALA") THEN
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
      IF (native_grid_use_cuda) THEN
         CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_CUDA_MODEL", model_path, STATUS=env_status)
         ! STATUS = -1: the variable is set but was truncated. Falling back to the generic
         ! model here would silently load a different model than the one requested.
         IF (env_status == -1) THEN
            CALL cp_abort(__LOCATION__, &
                          "GAUXC_SKALA_CUDA_MODEL is set, but its value exceeds the maximum "// &
                          "supported path length")
         END IF
         IF (env_status == 0 .AND. LEN_TRIM(model_path) > 0) RETURN
      END IF
      CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_MODEL", model_path, STATUS=env_status)
      ! Report truncation as such instead of claiming that the variable is missing.
      IF (env_status == -1) THEN
         CALL cp_abort(__LOCATION__, &
                       "GAUXC_SKALA_MODEL is set, but its value exceeds the maximum "// &
                       "supported path length")
      END IF
      IF (env_status /= 0 .OR. LEN_TRIM(model_path) == 0) THEN
         IF (native_grid_use_cuda) THEN
            CALL cp_abort(__LOCATION__, &
                          "MODEL SKALA CUDA path requires GAUXC_SKALA_CUDA_MODEL or GAUXC_SKALA_MODEL")
         ELSE
            CALL cp_abort(__LOCATION__, &
                          "MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
         END IF
      END IF
   END IF

END SUBROUTINE get_skala_model_path
1920 :
! **************************************************************************************************
!> \brief Return the first GAUXC functional subsection, if present.
!> \param xc_section XC input section; may be unassociated
!> \return pointer to the first subsection of XC_FUNCTIONAL named "GAUXC", or an unassociated
!>         pointer if xc_section or XC_FUNCTIONAL is unassociated or no such subsection is found
! **************************************************************************************************
FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
   TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
   TYPE(section_vals_type), POINTER                   :: gauxc_section

   INTEGER                                            :: ifun
   TYPE(section_vals_type), POINTER                   :: candidate, functionals

   NULLIFY (gauxc_section)
   IF (.NOT. ASSOCIATED(xc_section)) RETURN

   functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
   IF (.NOT. ASSOCIATED(functionals)) RETURN

   ! Walk the functional subsections by index until the lookup returns no section.
   ifun = 1
   candidate => section_vals_get_subs_vals2(functionals, i_section=ifun)
   DO WHILE (ASSOCIATED(candidate))
      IF (candidate%section%name == "GAUXC") THEN
         gauxc_section => candidate
         RETURN
      END IF
      ifun = ifun + 1
      candidate => section_vals_get_subs_vals2(functionals, i_section=ifun)
   END DO

END FUNCTION get_gauxc_section
1951 :
1952 : END MODULE skala_gpw_functional
|