LCOV - code coverage report
Current view: top level - src - skala_gpw_functional.F (source / functions)
Test: CP2K Regtests (git:561f475)    Lines: 74.3 % coverage (907 total, 674 hit)
Test Date: 2026-06-21 06:48:54    Functions: 87.5 % coverage (24 total, 21 hit)

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Experimental CP2K-native GPW real-space-grid path for SKALA TorchScript models.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_functional
      12              :    USE cell_types,                      ONLY: cell_type,&
      13              :                                               pbc
      14              :    USE cp_array_utils,                  ONLY: cp_3d_r_cp_type
      15              :    USE cp_log_handling,                 ONLY: cp_logger_get_default_io_unit
      16              :    USE input_section_types,             ONLY: section_get_rval,&
      17              :                                               section_vals_get_subs_vals,&
      18              :                                               section_vals_get_subs_vals2,&
      19              :                                               section_vals_type,&
      20              :                                               section_vals_val_get
      21              :    USE kinds,                           ONLY: default_path_length,&
      22              :                                               dp,&
      23              :                                               int_8
      24              :    USE message_passing,                 ONLY: mp_comm_type
      25              :    USE offload_api,                     ONLY: offload_set_chosen_device
      26              :    USE particle_types,                  ONLY: particle_type
      27              :    USE pw_grid_types,                   ONLY: pw_grid_type
      28              :    USE pw_methods,                      ONLY: pw_scale,&
      29              :                                               pw_zero
      30              :    USE pw_pool_types,                   ONLY: pw_pool_type
      31              :    USE pw_types,                        ONLY: pw_c1d_gs_type,&
      32              :                                               pw_r3d_rs_type
      33              :    USE qs_grid_atom,                    ONLY: grid_atom_type
      34              :    USE skala_gpw_features,              ONLY: skala_gpw_atom_partition_hard,&
      35              :                                               skala_gpw_atom_partition_smooth,&
      36              :                                               skala_gpw_atom_subchunk_count,&
      37              :                                               skala_gpw_feature_build,&
      38              :                                               skala_gpw_feature_build_atom_subchunk,&
      39              :                                               skala_gpw_feature_release,&
      40              :                                               skala_gpw_feature_type,&
      41              :                                               skala_gpw_smooth_partition_derivatives
      42              :    USE skala_torch_api,                 ONLY: skala_torch_model_get_exc,&
      43              :                                               skala_torch_model_load,&
      44              :                                               skala_torch_model_release,&
      45              :                                               skala_torch_model_type
      46              :    USE string_utilities,                ONLY: uppercase
      47              :    USE torch_api,                       ONLY: &
      48              :         torch_cuda_device_count, torch_cuda_is_available, torch_dict_create, torch_dict_insert, &
      49              :         torch_dict_release, torch_dict_type, torch_tensor_backward_scalar, torch_tensor_data_ptr, &
      50              :         torch_tensor_from_array, torch_tensor_grad, torch_tensor_release, &
      51              :         torch_tensor_to_device_leaf, torch_tensor_type, torch_use_cuda
      52              :    USE xc_rho_cflags_types,             ONLY: xc_rho_cflags_type
      53              :    USE xc_rho_set_types,                ONLY: xc_rho_set_create,&
      54              :                                               xc_rho_set_get,&
      55              :                                               xc_rho_set_release,&
      56              :                                               xc_rho_set_type,&
      57              :                                               xc_rho_set_update
      58              :    USE xc_util,                         ONLY: xc_pw_divergence,&
      59              :                                               xc_requires_tmp_g
      60              : #include "./base/base_uses.f90"
      61              : 
   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_functional'
   ! Row-count tuning constants. skala_gpw_eval falls back to auto_atom_chunk_max_rows when
   ! GAUXC%NATIVE_GRID_ATOM_CHUNK_MAX_ROWS is -1 on CUDA; NOTE(review): presumably that routine
   ! consumes the atom_chunk_auto_* limits — confirm there.
   ! NOTE(review): ngrad_per_point / ncollapsed_grad_per_point look like per-grid-point gradient
   ! component counts (full vs. spin-collapsed, cf. collapse_spin_grads) — confirm at their use sites.
   INTEGER, PARAMETER, PRIVATE          :: atom_chunk_auto_max_rows = 400000, &
                                           atom_chunk_auto_min_rows = 100000, &
                                           atom_chunk_auto_row_quantum = 100000, &
                                           ncollapsed_grad_per_point = 5, ngrad_per_point = 10

   PUBLIC :: ensure_native_skala_grid_scope, get_gauxc_section, skala_gapw_atom_vxc_of_r, &
             skala_gpw_eval, xc_section_uses_native_skala_grid

   ! Module-level (SAVE) state that persists between calls. cached_model is the model handed to
   ! skala_torch_model_get_exc by skala_gpw_eval. NOTE(review): the path / loaded-flag / device
   ! companions presumably let ensure_model_loaded reuse an already loaded model — confirm there.
   TYPE(skala_torch_model_type), SAVE                  :: cached_model
   CHARACTER(len=default_path_length), SAVE            :: cached_model_path = ""
   LOGICAL, SAVE                                       :: cached_model_loaded = .FALSE.
   ! NOTE(review): -3 / -1 initial values look like "not yet set" sentinels; the logged_cuda_*
   ! values presumably suppress repeated CUDA-device log output — confirm in
   ! configure_native_grid_cuda.
   INTEGER, SAVE                                       :: cached_model_cuda_device = -3
   INTEGER, SAVE                                       :: logged_cuda_device = -3, &
                                                          logged_cuda_device_count = -1, &
                                                          logged_cuda_nproc = -1, &
                                                          logged_cuda_request = -3
      84              : CONTAINS
      85              : 
      86              : ! **************************************************************************************************
      87              : !> \brief Return true if the GAUXC subsection requests the CP2K-native GPW grid path.
      88              : !> \param xc_section ...
      89              : !> \return ...
      90              : ! **************************************************************************************************
      91       183791 :    FUNCTION xc_section_uses_native_skala_grid(xc_section) RESULT(uses_native_grid)
      92              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
      93              :       LOGICAL                                            :: uses_native_grid
      94              : 
      95              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
      96              : 
      97       183791 :       uses_native_grid = .FALSE.
      98       183791 :       gauxc_section => get_gauxc_section(xc_section)
      99       183791 :       IF (ASSOCIATED(gauxc_section)) THEN
     100          674 :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=uses_native_grid)
     101              :       END IF
     102              : 
     103       183791 :    END FUNCTION xc_section_uses_native_skala_grid
     104              : 
     105              : ! **************************************************************************************************
     106              : !> \brief Enforce the currently implemented native SKALA GPW input scope.
     107              : !> \param xc_section ...
     108              : ! **************************************************************************************************
     109          288 :    SUBROUTINE ensure_native_skala_grid_scope(xc_section)
     110              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
     111              : 
     112              :       CHARACTER(len=default_path_length)                 :: model_key, model_name
     113              :       INTEGER                                            :: ifun, nfun
     114              :       LOGICAL                                            :: native_grid
     115              :       TYPE(section_vals_type), POINTER                   :: functionals, gauxc_section, xc_fun
     116              : 
     117          144 :       NULLIFY (gauxc_section)
     118          144 :       IF (.NOT. ASSOCIATED(xc_section)) THEN
     119            0 :          CPABORT("Native SKALA GPW requires an XC section")
     120              :       END IF
     121              : 
     122          144 :       functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
     123          144 :       IF (.NOT. ASSOCIATED(functionals)) THEN
     124            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL section")
     125              :       END IF
     126              : 
     127          144 :       nfun = 0
     128          144 :       ifun = 0
     129              :       DO
     130          288 :          ifun = ifun + 1
     131          288 :          xc_fun => section_vals_get_subs_vals2(functionals, i_section=ifun)
     132          288 :          IF (.NOT. ASSOCIATED(xc_fun)) EXIT
     133          144 :          nfun = nfun + 1
     134          288 :          IF (xc_fun%section%name == "GAUXC") gauxc_section => xc_fun
     135              :       END DO
     136              : 
     137          144 :       IF (.NOT. ASSOCIATED(gauxc_section)) THEN
     138            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
     139              :       END IF
     140          144 :       IF (nfun /= 1) THEN
     141            0 :          CPABORT("Native SKALA GPW requires GAUXC to be the only XC functional")
     142              :       END IF
     143              : 
     144          144 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID", l_val=native_grid)
     145          144 :       IF (.NOT. native_grid) RETURN
     146              : 
     147          144 :       CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_name)
     148          144 :       model_key = ADJUSTL(model_name)
     149          144 :       CALL uppercase(model_key)
     150          144 :       IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
     151            0 :          CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
     152              :       END IF
     153              : 
     154              :    END SUBROUTINE ensure_native_skala_grid_scope
     155              : 
! **************************************************************************************************
!> \brief Evaluate SKALA energy and first derivatives on a CP2K GPW grid.
!>
!>        Builds SKALA input features from the real-space density set, evaluates the cached
!>        TorchScript model for the XC energy and, unless only the energy is requested,
!>        back-propagates to obtain feature gradients that are turned into vxc, and optionally
!>        into explicit atom forces and the XC virial.
!> \param vxc_rho potential w.r.t. the density, filled by build_vxc_from_feature_grads
!>        (untouched here when just_energy is .TRUE.)
!> \param vxc_tau potential w.r.t. tau, filled by build_vxc_from_feature_grads
!>        (untouched here when just_energy is .TRUE.)
!> \param exc XC energy returned by the model; summed over the grid's process group when the
!>        features use atom chunks
!> \param rho_r real-space density per spin; SIZE(rho_r) /= 1 selects the spin-polarised path
!> \param rho_g reciprocal-space density; must be associated (aborts otherwise)
!> \param tau kinetic-energy density; must be associated (aborts otherwise)
!> \param xc_section XC input section (GAUXC options, XC_GRID settings and cutoffs are read)
!> \param weights grid weights passed to the feature builder
!> \param pw_pool pool used by the density set and by the vxc construction
!> \param particle_set particles passed to the feature builder and the force/virial helpers
!> \param cell cell passed to the feature builder and the smooth-partition helpers
!> \param compute_virial if .TRUE., accumulate virial_xc; not allowed together with just_energy
!> \param virial_xc zeroed on entry, then accumulated when compute_virial is .TRUE.
!> \param just_energy optional; if .TRUE., skip the backward pass, vxc, forces and virial
!> \param atom_force optional; zeroed on entry and filled with explicit coordinate forces
!>        (not supported together with atom chunks)
! **************************************************************************************************
   SUBROUTINE skala_gpw_eval(vxc_rho, vxc_tau, exc, rho_r, rho_g, tau, xc_section, &
                             weights, pw_pool, particle_set, cell, compute_virial, virial_xc, &
                             just_energy, atom_force)
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau
      REAL(KIND=dp), INTENT(OUT)                         :: exc
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
      TYPE(pw_c1d_gs_type), DIMENSION(:), POINTER        :: rho_g
      TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: tau
      TYPE(section_vals_type), POINTER                   :: xc_section
      TYPE(pw_r3d_rs_type), POINTER                      :: weights
      TYPE(pw_pool_type), POINTER                        :: pw_pool
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      LOGICAL, INTENT(IN)                                :: compute_virial
      REAL(KIND=dp), DIMENSION(3, 3), INTENT(OUT)        :: virial_xc
      LOGICAL, INTENT(IN), OPTIONAL                      :: just_energy
      REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT), &
         OPTIONAL                                        :: atom_force

      CHARACTER(len=default_path_length)                 :: model_path
      INTEGER :: iw, native_grid_atom_chunk_max_rows, native_grid_atom_partition, &
         native_grid_atom_subchunks, native_grid_cuda_device, nspins, phase_handle, &
         selected_cuda_device, xc_deriv_method_id, xc_rho_smooth_id
      LOGICAL :: have_atom_coord_grad, lsd, my_just_energy, native_grid_atom_chunk_routing, &
         native_grid_atom_chunks, native_grid_diagnostics, native_grid_use_cuda, needs_atom_force, &
         use_atom_subchunks
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: density_grad, kin_grad
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad_grad
      REAL(KIND=dp), DIMENSION(3, 3)                     :: virial_before
      TYPE(section_vals_type), POINTER                   :: gauxc_section
      TYPE(skala_gpw_feature_type)                       :: features
      TYPE(torch_tensor_type)                            :: atom_coord_grad_t, &
                                                            atomic_grid_weight_grad_t, exc_tensor, &
                                                            grid_coord_grad_t, grid_weight_grad_t
      TYPE(xc_rho_cflags_type)                           :: needs
      TYPE(xc_rho_set_type)                              :: rho_set

      ! Initialise outputs and flags.
      ! NOTE(review): atom_force is zeroed here but only filled inside the .NOT. my_just_energy
      ! branch, so a caller passing atom_force with just_energy=.TRUE. silently gets zero forces.
      virial_xc = 0.0_dp
      exc = 0.0_dp
      my_just_energy = .FALSE.
      IF (PRESENT(just_energy)) my_just_energy = just_energy
      needs_atom_force = PRESENT(atom_force)
      IF (needs_atom_force) atom_force = 0.0_dp
      have_atom_coord_grad = .FALSE.

      ! Reject unsupported argument combinations up front.
      IF (compute_virial .AND. my_just_energy) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW stress/virial requires feature gradients.")
      END IF
      IF (.NOT. ASSOCIATED(rho_g)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires the reciprocal-space density to form density gradients.")
      END IF
      IF (.NOT. ASSOCIATED(tau)) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW requires the kinetic-energy density.")
      END IF

      ! Read the GAUXC native-grid options.
      nspins = SIZE(rho_r)
      lsd = (nspins /= 1)
      CALL get_skala_model_path(xc_section, model_path)
      gauxc_section => get_gauxc_section(xc_section)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
                                i_val=native_grid_cuda_device)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNKS", &
                                l_val=native_grid_atom_chunks)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_ROUTING", &
                                l_val=native_grid_atom_chunk_routing)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_CHUNK_MAX_ROWS", &
                                i_val=native_grid_atom_chunk_max_rows)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_ATOM_PARTITION", &
                                i_val=native_grid_atom_partition)
      ! Map the input value (1 = hard, 2 = smooth) onto the feature-module constants.
      SELECT CASE (native_grid_atom_partition)
      CASE (1)
         native_grid_atom_partition = skala_gpw_atom_partition_hard
      CASE (2)
         native_grid_atom_partition = skala_gpw_atom_partition_smooth
      CASE DEFAULT
         CALL cp_abort(__LOCATION__, &
                       "Unknown GAUXC%NATIVE_GRID_ATOM_PARTITION value.")
      END SELECT
      ! Enabling either ATOM_CHUNKS or ATOM_CHUNK_ROUTING enables both.
      native_grid_atom_chunk_routing = native_grid_atom_chunk_routing .OR. native_grid_atom_chunks
      native_grid_atom_chunks = native_grid_atom_chunks .OR. native_grid_atom_chunk_routing
      IF (native_grid_atom_chunk_max_rows < -1) THEN
         CALL cp_abort(__LOCATION__, &
                       "GAUXC%NATIVE_GRID_ATOM_CHUNK_MAX_ROWS must be -1, zero, or positive.")
      END IF
      IF (native_grid_atom_chunks .AND. needs_atom_force) THEN
         CALL cp_abort(__LOCATION__, &
                       "Native SKALA GPW atom chunks are not implemented for atom forces yet.")
      END IF
      ! The portable SKALA export used by the regtests builds ragged-index tensors on CPU.
      ! The torch CUDA switch follows NATIVE_GRID_USE_CUDA for this call and is set back to .TRUE.
      ! during cleanup below. NOTE(review): that reset is unconditional, not a restore of the
      ! state on entry.
      CALL torch_use_cuda(native_grid_use_cuda)
      selected_cuda_device = configure_native_grid_cuda( &
                             native_grid_use_cuda, native_grid_cuda_device, rho_r(1)%pw_grid%para%group)
      CALL ensure_model_loaded(model_path, selected_cuda_device)

      ! Request density, density gradient and tau (spin-resolved when LSD).
      IF (lsd) THEN
         needs%rho_spin = .TRUE.
         needs%drho_spin = .TRUE.
         needs%tau_spin = .TRUE.
      ELSE
         needs%rho = .TRUE.
         needs%drho = .TRUE.
         needs%tau = .TRUE.
      END IF

      CALL section_vals_val_get(xc_section, "XC_GRID%XC_DERIV", i_val=xc_deriv_method_id)
      CALL section_vals_val_get(xc_section, "XC_GRID%XC_SMOOTH_RHO", i_val=xc_rho_smooth_id)

      CALL xc_rho_set_create(rho_set, &
                             rho_r(1)%pw_grid%bounds_local, &
                             rho_cutoff=section_get_rval(xc_section, "density_cutoff"), &
                             drho_cutoff=section_get_rval(xc_section, "gradient_cutoff"), &
                             tau_cutoff=section_get_rval(xc_section, "tau_cutoff"))
      CALL xc_rho_set_update(rho_set, rho_r, rho_g, tau, needs, &
                             xc_deriv_method_id, xc_rho_smooth_id, pw_pool)

      ! Build model inputs; gradient tracking is requested only for what is needed later.
      CALL skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
                                   requires_grad=(.NOT. my_just_energy), weights=weights, &
                                   requires_coordinate_grad=(needs_atom_force .OR. compute_virial), &
                                   requires_stress_grad=compute_virial, &
                                   use_atom_chunks=native_grid_atom_chunks, &
                                   route_atom_chunks=native_grid_atom_chunk_routing, &
                                   atom_partition=native_grid_atom_partition)
      CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_DIAGNOSTICS", l_val=native_grid_diagnostics)
      IF (native_grid_diagnostics) THEN
         CALL print_native_grid_diagnostics(features, rho_r(1)%pw_grid%para%group%mepos == 0)
      END IF

      ! MAX_ROWS = -1 means "automatic": ask auto_atom_chunk_max_rows on CUDA, otherwise use 0.
      IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows == -1) THEN
         IF (native_grid_use_cuda) THEN
            native_grid_atom_chunk_max_rows = auto_atom_chunk_max_rows(features, &
                                                                       rho_r(1)%pw_grid%para%group)
         ELSE
            native_grid_atom_chunk_max_rows = 0
         END IF
      END IF
      IF (native_grid_diagnostics .AND. features%uses_atom_chunks .AND. &
          rho_r(1)%pw_grid%para%group%mepos == 0) THEN
         iw = cp_logger_get_default_io_unit()
         IF (iw > 0) THEN
            WRITE (UNIT=iw, FMT="(T2,A,1X,I0)") &
               "SKALA_GPW| Native grid atom chunk max rows", native_grid_atom_chunk_max_rows
         END IF
      END IF
      ! A positive row limit may split the atom-chunk evaluation into several subchunks.
      native_grid_atom_subchunks = 1
      IF (features%uses_atom_chunks .AND. native_grid_atom_chunk_max_rows > 0) THEN
         native_grid_atom_subchunks = skala_gpw_atom_subchunk_count(native_grid_atom_chunk_max_rows)
      END IF
      use_atom_subchunks = native_grid_atom_subchunks > 1
      ! Energy evaluation: the subchunk path also returns the feature gradients itself (when
      ! requested); the single-pass path keeps exc_tensor for the backward pass below.
      IF (use_atom_subchunks) THEN
         CALL evaluate_atom_subchunks(features, rho_r(1)%pw_grid%para%group, &
                                      native_grid_atom_chunk_max_rows, &
                                      compute_grads=(.NOT. my_just_energy), exc=exc, &
                                      density_grad=density_grad, grad_grad=grad_grad, &
                                      kin_grad=kin_grad, collapse_spin_grads=(nspins == 1))
      ELSE
         CALL skala_torch_model_get_exc(cached_model, features%inputs, &
                                        features%grid_weights_t, exc_tensor, exc)
      END IF
      ! NOTE(review): presumably each rank evaluates only its own atom chunks, hence the sum.
      IF (features%uses_atom_chunks) CALL rho_r(1)%pw_grid%para%group%sum(exc)

      IF (.NOT. my_just_energy) THEN
         ! Backward pass and gradient fetch (already done by the subchunk path).
         IF (.NOT. use_atom_subchunks) THEN
            CALL timeset("skala_gpw_backward", phase_handle)
            CALL torch_tensor_backward_scalar(exc_tensor)
            CALL timestop(phase_handle)

            CALL timeset("skala_gpw_grad_fetch", phase_handle)
            IF (features%uses_atom_chunks) THEN
               CALL fetch_and_gather_atom_chunk_grads(features, rho_r(1)%pw_grid%para%group, &
                                                      density_grad, grad_grad, kin_grad)
            ELSE
               CALL fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
            END IF
            CALL timestop(phase_handle)
         END IF
         ! Explicit atom forces; atom_coord_grad_t is reused by the virial below and released
         ! at the end of this branch.
         IF (needs_atom_force) THEN
            CALL add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, &
                                               rho_r(1)%pw_grid%para%group%mepos == 0)
            IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
               CALL add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
                                               grid_weight_grad_t, atomic_grid_weight_grad_t)
            END IF
            have_atom_coord_grad = .TRUE.
         END IF

         CALL timeset("skala_gpw_vxc_unpack", phase_handle)
         ! Virial contributions are accumulated stage by stage; with diagnostics enabled each
         ! stage's delta is printed (virial_before is only read when diagnostics are on).
         ! NOTE(review): grid_coord_grad_t, grid_weight_grad_t and atomic_grid_weight_grad_t are
         ! never released in this routine — confirm the helpers release them.
         IF (compute_virial) THEN
            IF (native_grid_diagnostics) virial_before = virial_xc
            CALL build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("feature-gradient", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
               virial_before = virial_xc
            END IF
            ! Fetch the coordinate gradient here if the force path did not already do so.
            IF (.NOT. have_atom_coord_grad) THEN
               CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
               have_atom_coord_grad = .TRUE.
            END IF
            CALL build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
                                                grid_coord_grad_t, &
                                                rho_r(1)%pw_grid%para%group%mepos == 0, &
                                                native_grid_diagnostics)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("static-coordinates", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
               virial_before = virial_xc
            END IF
            IF (features%atom_partition == skala_gpw_atom_partition_smooth) THEN
               CALL build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
                                                  grid_weight_grad_t, atomic_grid_weight_grad_t)
               IF (native_grid_diagnostics) THEN
                  CALL print_virial_delta("smooth-partition", virial_xc - virial_before, &
                                          rho_r(1)%pw_grid%para%group%mepos == 0)
                  virial_before = virial_xc
               END IF
            END IF
            CALL build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
                                     atomic_grid_weight_grad_t, &
                                     rho_r(1)%pw_grid%para%group%mepos == 0, &
                                     native_grid_diagnostics)
            IF (native_grid_diagnostics) THEN
               CALL print_virial_delta("weight-residual", virial_xc - virial_before, &
                                       rho_r(1)%pw_grid%para%group%mepos == 0)
            END IF
         END IF
         ! Turn the per-point feature gradients into the vxc potentials.
         CALL build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
                                           density_grad, grad_grad, kin_grad, &
                                           xc_deriv_method_id)
         CALL timestop(phase_handle)

         CALL timeset("skala_gpw_grad_release", phase_handle)
         DEALLOCATE (density_grad, grad_grad, kin_grad)
         IF (have_atom_coord_grad) CALL torch_tensor_release(atom_coord_grad_t)
         CALL timestop(phase_handle)
      END IF

      ! Release model/feature resources and reset the torch CUDA switch.
      CALL timeset("skala_gpw_cleanup", phase_handle)
      IF (.NOT. use_atom_subchunks) CALL torch_tensor_release(exc_tensor)
      CALL skala_gpw_feature_release(features)
      CALL xc_rho_set_release(rho_set, pw_pool=pw_pool)
      CALL torch_use_cuda(.TRUE.)
      CALL timestop(phase_handle)

   END SUBROUTINE skala_gpw_eval
     422              : 
     423              : ! **************************************************************************************************
     424              : !> \brief Evaluate SKALA on a GAPW one-center atomic grid.
     425              : !> \param xc_section ...
     426              : !> \param grid_atom ...
     427              : !> \param group ...
     428              : !> \param atom_coord ...
     429              : !> \param rho ...
     430              : !> \param drho ...
     431              : !> \param tau ...
     432              : !> \param weights ...
     433              : !> \param lsd ...
     434              : !> \param nspins ...
     435              : !> \param na ...
     436              : !> \param nr ...
     437              : !> \param exc ...
     438              : !> \param vxc ...
     439              : !> \param vxg ...
     440              : !> \param vtau ...
     441              : !> \param energy_only ...
     442              : ! **************************************************************************************************
     443            8 :    SUBROUTINE skala_gapw_atom_vxc_of_r(xc_section, grid_atom, group, atom_coord, &
     444            8 :                                        rho, drho, tau, weights, lsd, nspins, na, nr, &
     445              :                                        exc, vxc, vxg, vtau, energy_only)
     446              :       TYPE(section_vals_type), POINTER                   :: xc_section
     447              :       TYPE(grid_atom_type), POINTER                      :: grid_atom
     448              : 
     449              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
     450              :       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord
     451              :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, tau, vxc, vtau
     452              :       REAL(KIND=dp), DIMENSION(:, :, :, :), POINTER      :: drho, vxg
     453              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: weights
     454              :       LOGICAL, INTENT(IN)                                :: lsd
     455              :       INTEGER, INTENT(IN)                                :: nspins, na, nr
     456              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     457              :       LOGICAL, INTENT(IN), OPTIONAL                      :: energy_only
     458              : 
     459              :       CHARACTER(len=default_path_length)                 :: model_path
     460              :       INTEGER                                            :: ia, idir, ir, native_grid_cuda_device, &
     461              :                                                             nflat, row, selected_cuda_device
     462            8 :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes
     463            8 :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape
     464              :       LOGICAL                                            :: my_energy_only, native_grid_use_cuda
     465            8 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, grid_weights
     466            8 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coarse_0_atomic_coords, density, &
     467            8 :                                                             grid_coords, kin
     468            8 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: grad
     469            8 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad, kin_grad
     470            8 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
     471              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
     472              :       TYPE(torch_dict_type)                              :: inputs
     473              :       TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t, &
     474              :                                                             atomic_grid_sizes_t, &
     475              :                                                             atomic_grid_weights_t, &
     476              :                                                             coarse_0_atomic_coords_t, density_t, &
     477              :                                                             density_grad_t, exc_tensor, grad_t, &
     478              :                                                             grad_grad_t, grid_coords_t, &
     479              :                                                             grid_weights_t, kin_t, kin_grad_t
     480              : 
     481            0 :       CPASSERT(ASSOCIATED(xc_section))
     482            8 :       CPASSERT(ASSOCIATED(grid_atom))
     483            8 :       CPASSERT(ASSOCIATED(rho))
     484            8 :       CPASSERT(ASSOCIATED(drho))
     485            8 :       CPASSERT(ASSOCIATED(tau))
     486              : 
     487            8 :       my_energy_only = .FALSE.
     488            8 :       IF (PRESENT(energy_only)) my_energy_only = energy_only
     489            0 :       exc = 0.0_dp
     490            8 :       IF (.NOT. my_energy_only) THEN
     491        20416 :          vxc = 0.0_dp
     492        80416 :          vxg = 0.0_dp
     493        20416 :          vtau = 0.0_dp
     494              :       END IF
     495              : 
     496            8 :       CALL get_skala_model_path(xc_section, model_path)
     497            8 :       gauxc_section => get_gauxc_section(xc_section)
     498            8 :       CPASSERT(ASSOCIATED(gauxc_section))
     499            8 :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
     500              :       CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_CUDA_DEVICE", &
     501            8 :                                 i_val=native_grid_cuda_device)
     502            8 :       CALL torch_use_cuda(native_grid_use_cuda)
     503              :       selected_cuda_device = configure_native_grid_cuda( &
     504            8 :                              native_grid_use_cuda, native_grid_cuda_device, group)
     505            8 :       CALL ensure_model_loaded(model_path, selected_cuda_device)
     506              : 
     507            8 :       nflat = na*nr
     508              :       ALLOCATE (density(nflat, 2), grad(nflat, 3, 2), kin(nflat, 2), &
     509              :                 grid_coords(3, nflat), grid_weights(nflat), &
     510              :                 atomic_grid_weights(nflat), atomic_grid_sizes(1), &
     511          104 :                 coarse_0_atomic_coords(3, 1), atomic_grid_size_bound_shape(0, nflat))
     512            8 :       density = 0.0_dp
     513            8 :       grad = 0.0_dp
     514            8 :       kin = 0.0_dp
     515            8 :       grid_coords = 0.0_dp
     516            8 :       grid_weights = 0.0_dp
     517            8 :       atomic_grid_weights = 0.0_dp
     518            8 :       atomic_grid_sizes(1) = INT(nflat, KIND=int_8)
     519              :       atomic_grid_size_bound_shape = 0_int_8
     520           32 :       coarse_0_atomic_coords(:, 1) = atom_coord
     521              : 
     522              :       row = 0
     523          408 :       DO ir = 1, nr
     524        20408 :          DO ia = 1, na
     525        20000 :             row = row + 1
     526              :             grid_coords(1, row) = atom_coord(1) + grid_atom%rad(ir)* &
     527        20000 :                                   grid_atom%sin_pol(ia)*grid_atom%cos_azi(ia)
     528              :             grid_coords(2, row) = atom_coord(2) + grid_atom%rad(ir)* &
     529        20000 :                                   grid_atom%sin_pol(ia)*grid_atom%sin_azi(ia)
     530        20000 :             grid_coords(3, row) = atom_coord(3) + grid_atom%rad(ir)*grid_atom%cos_pol(ia)
     531        20000 :             grid_weights(row) = weights(ia, ir)
     532        20000 :             atomic_grid_weights(row) = weights(ia, ir)
     533        20400 :             IF (nspins == 1) THEN
     534        60000 :                density(row, :) = 0.5_dp*rho(ia, ir, 1)
     535        80000 :                DO idir = 1, 3
     536       200000 :                   grad(row, idir, :) = 0.5_dp*drho(idir, ia, ir, 1)
     537              :                END DO
     538        60000 :                kin(row, :) = 0.5_dp*tau(ia, ir, 1)
     539              :             ELSE
     540            0 :                density(row, :) = rho(ia, ir, 1:2)
     541            0 :                DO idir = 1, 3
     542            0 :                   grad(row, idir, :) = drho(idir, ia, ir, 1:2)
     543              :                END DO
     544            0 :                kin(row, :) = tau(ia, ir, 1:2)
     545              :             END IF
     546              :          END DO
     547              :       END DO
     548              : 
     549            8 :       CALL torch_tensor_from_array(grid_coords_t, grid_coords)
     550            8 :       CALL torch_tensor_to_device_leaf(grid_coords_t, .FALSE.)
     551            8 :       CALL torch_tensor_from_array(grid_weights_t, grid_weights)
     552            8 :       CALL torch_tensor_to_device_leaf(grid_weights_t, .FALSE.)
     553            8 :       CALL torch_tensor_from_array(atomic_grid_weights_t, atomic_grid_weights)
     554            8 :       CALL torch_tensor_to_device_leaf(atomic_grid_weights_t, .FALSE.)
     555            8 :       CALL torch_tensor_from_array(atomic_grid_sizes_t, atomic_grid_sizes)
     556            8 :       CALL torch_tensor_to_device_leaf(atomic_grid_sizes_t, .FALSE.)
     557              :       CALL torch_tensor_from_array(atomic_grid_size_bound_shape_t, &
     558            8 :                                    atomic_grid_size_bound_shape)
     559            8 :       CALL torch_tensor_to_device_leaf(atomic_grid_size_bound_shape_t, .FALSE.)
     560            8 :       CALL torch_tensor_from_array(coarse_0_atomic_coords_t, coarse_0_atomic_coords)
     561            8 :       CALL torch_tensor_to_device_leaf(coarse_0_atomic_coords_t, .FALSE.)
     562            8 :       CALL torch_tensor_from_array(density_t, density)
     563            8 :       CALL torch_tensor_to_device_leaf(density_t,.NOT. my_energy_only)
     564            8 :       CALL torch_tensor_from_array(grad_t, grad)
     565            8 :       CALL torch_tensor_to_device_leaf(grad_t,.NOT. my_energy_only)
     566            8 :       CALL torch_tensor_from_array(kin_t, kin)
     567            8 :       CALL torch_tensor_to_device_leaf(kin_t,.NOT. my_energy_only)
     568              : 
     569            8 :       CALL torch_dict_create(inputs)
     570            8 :       CALL torch_dict_insert(inputs, "grid_coords", grid_coords_t)
     571            8 :       CALL torch_dict_insert(inputs, "grid_weights", grid_weights_t)
     572            8 :       CALL torch_dict_insert(inputs, "atomic_grid_weights", atomic_grid_weights_t)
     573            8 :       CALL torch_dict_insert(inputs, "atomic_grid_sizes", atomic_grid_sizes_t)
     574              :       CALL torch_dict_insert(inputs, "atomic_grid_size_bound_shape", &
     575            8 :                              atomic_grid_size_bound_shape_t)
     576            8 :       CALL torch_dict_insert(inputs, "density", density_t)
     577            8 :       CALL torch_dict_insert(inputs, "grad", grad_t)
     578            8 :       CALL torch_dict_insert(inputs, "kin", kin_t)
     579            8 :       CALL torch_dict_insert(inputs, "coarse_0_atomic_coords", coarse_0_atomic_coords_t)
     580              : 
     581            8 :       CALL skala_torch_model_get_exc(cached_model, inputs, grid_weights_t, exc_tensor, exc)
     582              : 
     583            8 :       IF (.NOT. my_energy_only) THEN
     584            8 :          NULLIFY (density_grad, grad_grad, kin_grad)
     585            8 :          CALL torch_tensor_backward_scalar(exc_tensor)
     586            8 :          CALL torch_tensor_grad(density_t, density_grad_t)
     587            8 :          CALL torch_tensor_grad(grad_t, grad_grad_t)
     588            8 :          CALL torch_tensor_grad(kin_t, kin_grad_t)
     589            8 :          CALL torch_tensor_data_ptr(density_grad_t, density_grad)
     590            8 :          CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
     591            8 :          CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
     592              : 
     593            8 :          row = 0
     594          408 :          DO ir = 1, nr
     595        20408 :             DO ia = 1, na
     596        20000 :                row = row + 1
     597        20400 :                IF (lsd) THEN
     598            0 :                   vxc(ia, ir, 1:2) = density_grad(row, 1:2)
     599            0 :                   DO idir = 1, 3
     600            0 :                      vxg(idir, ia, ir, 1:2) = grad_grad(row, idir, 1:2)
     601              :                   END DO
     602            0 :                   vtau(ia, ir, 1:2) = kin_grad(row, 1:2)
     603              :                ELSE
     604        20000 :                   vxc(ia, ir, 1) = 0.5_dp*(density_grad(row, 1) + density_grad(row, 2))
     605        80000 :                   DO idir = 1, 3
     606              :                      vxg(idir, ia, ir, 1) = &
     607        80000 :                         0.5_dp*(grad_grad(row, idir, 1) + grad_grad(row, idir, 2))
     608              :                   END DO
     609        20000 :                   vtau(ia, ir, 1) = 0.5_dp*(kin_grad(row, 1) + kin_grad(row, 2))
     610              :                END IF
     611              :             END DO
     612              :          END DO
     613              : 
     614            8 :          CALL torch_tensor_release(density_grad_t)
     615            8 :          CALL torch_tensor_release(grad_grad_t)
     616            8 :          CALL torch_tensor_release(kin_grad_t)
     617              :       END IF
     618              : 
     619            8 :       CALL torch_tensor_release(exc_tensor)
     620            8 :       CALL torch_tensor_release(density_t)
     621            8 :       CALL torch_tensor_release(grad_t)
     622            8 :       CALL torch_tensor_release(kin_t)
     623            8 :       CALL torch_tensor_release(grid_coords_t)
     624            8 :       CALL torch_tensor_release(grid_weights_t)
     625            8 :       CALL torch_tensor_release(atomic_grid_weights_t)
     626            8 :       CALL torch_tensor_release(atomic_grid_sizes_t)
     627            8 :       CALL torch_tensor_release(atomic_grid_size_bound_shape_t)
     628            8 :       CALL torch_tensor_release(coarse_0_atomic_coords_t)
     629            8 :       CALL torch_dict_release(inputs)
     630            0 :       DEALLOCATE (atomic_grid_size_bound_shape, atomic_grid_sizes, atomic_grid_weights, &
     631            8 :                   coarse_0_atomic_coords, density, grad, grid_coords, grid_weights, kin)
     632            8 :       CALL torch_use_cuda(.TRUE.)
     633              : 
     634           24 :    END SUBROUTINE skala_gapw_atom_vxc_of_r
     635              : 
     636              : ! **************************************************************************************************
     637              : !> \brief Add the explicit SKALA derivative with respect to atom-center coordinates.
     638              : !> \param atom_force ...
     639              : !> \param features ...
     640              : !> \param atom_coord_grad_t ...
     641              : !> \param root_rank ...
     642              : ! **************************************************************************************************
     643           16 :    SUBROUTINE add_explicit_coordinate_force(atom_force, features, atom_coord_grad_t, root_rank)
     644              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
     645              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     646              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t
     647              :       LOGICAL, INTENT(IN)                                :: root_rank
     648              : 
     649           16 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad
     650              : 
     651           16 :       NULLIFY (atom_coord_grad)
     652           16 :       CALL torch_tensor_grad(features%coarse_0_atomic_coords_t, atom_coord_grad_t)
     653           16 :       IF (root_rank) THEN
     654            8 :          CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
     655            8 :          CPASSERT(SIZE(atom_force, 1) == SIZE(atom_coord_grad, 1))
     656            8 :          CPASSERT(SIZE(atom_force, 2) == SIZE(atom_coord_grad, 2))
     657           72 :          atom_force(:, :) = atom_force(:, :) + atom_coord_grad(:, :)
     658              :       END IF
     659              : 
     660           16 :    END SUBROUTINE add_explicit_coordinate_force
     661              : 
     662              : ! **************************************************************************************************
     663              : !> \brief Add the force from SMOOTH native-grid atom partition weights.
     664              : !> \param atom_force ...
     665              : !> \param features ...
     666              : !> \param particle_set ...
     667              : !> \param cell ...
     668              : !> \param rho_r ...
     669              : !> \param grid_weight_grad_t ...
     670              : !> \param atomic_grid_weight_grad_t ...
     671              : ! **************************************************************************************************
     672            4 :    SUBROUTINE add_smooth_partition_force(atom_force, features, particle_set, cell, rho_r, &
     673              :                                          grid_weight_grad_t, atomic_grid_weight_grad_t)
     674              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: atom_force
     675              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     676              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     677              :       TYPE(cell_type), POINTER                           :: cell
     678              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     679              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
     680              :                                                             atomic_grid_weight_grad_t
     681              : 
     682              :       INTEGER                                            :: feature_begin, feature_end, feature_pos, &
     683              :                                                             i, iatom, j, jatom, k, local_row, &
     684              :                                                             natom, row
     685              :       INTEGER, DIMENSION(2, 3)                           :: bo
     686              :       LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
     687              :       REAL(KIND=dp)                                      :: base_weight, weight_grad
     688              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
     689              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
     690              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
     691              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point
     692            4 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
     693              : 
     694            4 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
     695            4 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
     696            4 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
     697            4 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
     698            4 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
     699              : 
     700            4 :       natom = SIZE(particle_set)
     701            4 :       CPASSERT(SIZE(atom_force, 1) == 3)
     702            4 :       CPASSERT(SIZE(atom_force, 2) == natom)
     703              :       ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
     704           48 :                 dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
     705           12 :       DO iatom = 1, natom
     706           12 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     707              :       END DO
     708              : 
     709           40 :       bo = rho_r(1)%pw_grid%bounds_local
     710            4 :       local_row = 0
     711          100 :       DO k = bo(1, 3), bo(2, 3)
     712         2404 :          DO j = bo(1, 2), bo(2, 2)
     713        30048 :             DO i = bo(1, 1), bo(2, 1)
     714        27648 :                local_row = local_row + 1
     715       110592 :                grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
     716              :                CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
     717              :                                                            weights, included, dweights_datom, &
     718        27648 :                                                            dweights_dstrain)
     719        27648 :                feature_begin = features%local_feature_offsets(local_row)
     720        27648 :                feature_end = features%local_feature_offsets(local_row + 1) - 1
     721        82944 :                CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
     722        27648 :                base_weight = 0.0_dp
     723        82804 :                DO feature_pos = feature_begin, feature_end
     724        55156 :                   row = features%local_feature_rows(feature_pos)
     725        82804 :                   base_weight = base_weight + features%grid_weights(row)
     726              :                END DO
     727              :                feature_pos = feature_begin
     728        82944 :                DO iatom = 1, natom
     729        55296 :                   IF (.NOT. included(iatom)) CYCLE
     730        55156 :                   row = features%local_feature_rows(feature_pos)
     731        55156 :                   weight_grad = grid_weight_grad(row) + atomic_grid_weight_grad(row)
     732       165468 :                   DO jatom = 1, natom
     733              :                      atom_force(:, jatom) = atom_force(:, jatom) + &
     734              :                                             weight_grad*base_weight* &
     735       496404 :                                             dweights_datom(:, jatom, iatom)
     736              :                   END DO
     737        82944 :                   feature_pos = feature_pos + 1
     738              :                END DO
     739        29952 :                CPASSERT(feature_pos == feature_end + 1)
     740              :             END DO
     741              :          END DO
     742              :       END DO
     743            4 :       CPASSERT(local_row == features%nflat_local)
     744              : 
     745            4 :       DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
     746            4 :       CALL torch_tensor_release(grid_weight_grad_t)
     747            4 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
     748              : 
     749            4 :    END SUBROUTINE add_smooth_partition_force
     750              : 
     751              : ! **************************************************************************************************
     752              : !> \brief Add the virial from SMOOTH native-grid atom partition weights.
     753              : !> \param virial_xc ...
     754              : !> \param features ...
     755              : !> \param particle_set ...
     756              : !> \param cell ...
     757              : !> \param rho_r ...
     758              : !> \param grid_weight_grad_t ...
     759              : !> \param atomic_grid_weight_grad_t ...
     760              : ! **************************************************************************************************
     761            2 :    SUBROUTINE build_smooth_partition_virial(virial_xc, features, particle_set, cell, rho_r, &
     762              :                                             grid_weight_grad_t, atomic_grid_weight_grad_t)
     763              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
     764              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     765              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     766              :       TYPE(cell_type), POINTER                           :: cell
     767              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
     768              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
     769              :                                                             atomic_grid_weight_grad_t
     770              : 
     771              :       INTEGER                                            :: feature_begin, feature_end, feature_pos, &
     772              :                                                             i, iatom, idir, j, jdir, k, local_row, &
     773              :                                                             natom, row
     774              :       INTEGER, DIMENSION(2, 3)                           :: bo
     775              :       LOGICAL, ALLOCATABLE, DIMENSION(:)                 :: included
     776              :       REAL(KIND=dp)                                      :: base_weight, tmp, weight_grad
     777              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: weights
     778              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc
     779              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: dweights_datom, dweights_dstrain
     780              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point
     781            2 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
     782              : 
     783            2 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
     784            2 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
     785            2 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
     786            2 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
     787            2 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
     788              : 
     789            2 :       natom = SIZE(particle_set)
     790              :       ALLOCATE (atom_coords_pbc(3, natom), included(natom), weights(natom), &
     791           24 :                 dweights_datom(3, natom, natom), dweights_dstrain(3, 3, natom))
     792            6 :       DO iatom = 1, natom
     793            6 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     794              :       END DO
     795              : 
     796           20 :       bo = rho_r(1)%pw_grid%bounds_local
     797            2 :       local_row = 0
     798           50 :       DO k = bo(1, 3), bo(2, 3)
     799         1202 :          DO j = bo(1, 2), bo(2, 2)
     800        15024 :             DO i = bo(1, 1), bo(2, 1)
     801        13824 :                local_row = local_row + 1
     802        55296 :                grid_point = native_grid_coordinate(rho_r(1)%pw_grid, [i, j, k])
     803              :                CALL skala_gpw_smooth_partition_derivatives(grid_point, atom_coords_pbc, cell, &
     804              :                                                            weights, included, dweights_datom, &
     805        13824 :                                                            dweights_dstrain)
     806        13824 :                feature_begin = features%local_feature_offsets(local_row)
     807        13824 :                feature_end = features%local_feature_offsets(local_row + 1) - 1
     808        41472 :                CPASSERT(feature_end - feature_begin + 1 == COUNT(included))
     809        13824 :                base_weight = 0.0_dp
     810        41402 :                DO feature_pos = feature_begin, feature_end
     811        27578 :                   row = features%local_feature_rows(feature_pos)
     812        41402 :                   base_weight = base_weight + features%grid_weights(row)
     813              :                END DO
     814              :                feature_pos = feature_begin
     815        41472 :                DO iatom = 1, natom
     816        27648 :                   IF (.NOT. included(iatom)) CYCLE
     817        27578 :                   row = features%local_feature_rows(feature_pos)
     818        27578 :                   weight_grad = grid_weight_grad(row) + atomic_grid_weight_grad(row)
     819       110312 :                   DO idir = 1, 3
     820       275780 :                      DO jdir = 1, idir
     821       165468 :                         tmp = weight_grad*base_weight*dweights_dstrain(idir, jdir, iatom)
     822       165468 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
     823       248202 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
     824              :                      END DO
     825              :                   END DO
     826        41472 :                   feature_pos = feature_pos + 1
     827              :                END DO
     828        14976 :                CPASSERT(feature_pos == feature_end + 1)
     829              :             END DO
     830              :          END DO
     831              :       END DO
     832            2 :       CPASSERT(local_row == features%nflat_local)
     833              : 
     834            2 :       DEALLOCATE (atom_coords_pbc, dweights_datom, dweights_dstrain, included, weights)
     835            2 :       CALL torch_tensor_release(grid_weight_grad_t)
     836            2 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
     837              : 
     838            2 :    END SUBROUTINE build_smooth_partition_virial
     839              : 
     840              : ! **************************************************************************************************
     841              : !> \brief Return the Cartesian coordinate of a regular GPW grid point.
     842              : !> \param pw_grid ...
     843              : !> \param index ...
     844              : !> \return ...
     845              : ! **************************************************************************************************
     846        41472 :    FUNCTION native_grid_coordinate(pw_grid, index) RESULT(coord)
     847              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     848              :       INTEGER, DIMENSION(3), INTENT(IN)                  :: index
     849              :       REAL(KIND=dp), DIMENSION(3)                        :: coord
     850              : 
     851              :       INTEGER, DIMENSION(3)                              :: relative_index
     852              : 
     853       165888 :       relative_index = index - pw_grid%bounds(1, :)
     854              :       coord = REAL(relative_index(1), KIND=dp)*pw_grid%dh(:, 1) + &
     855              :               REAL(relative_index(2), KIND=dp)*pw_grid%dh(:, 2) + &
     856       165888 :               REAL(relative_index(3), KIND=dp)*pw_grid%dh(:, 3)
     857              : 
     858        41472 :    END FUNCTION native_grid_coordinate
     859              : 
     860              : ! **************************************************************************************************
     861              : !> \brief Evaluate a rank-local atom chunk as multiple atom-contiguous Torch subchunks.
     862              : !> \param features ...
     863              : !> \param group ...
     864              : !> \param max_rows ...
     865              : !> \param compute_grads ...
     866              : !> \param exc ...
     867              : !> \param density_grad ...
     868              : !> \param grad_grad ...
     869              : !> \param kin_grad ...
     870              : !> \param collapse_spin_grads ...
     871              : ! **************************************************************************************************
     872            2 :    SUBROUTINE evaluate_atom_subchunks(features, group, max_rows, compute_grads, exc, &
     873              :                                       density_grad, grad_grad, kin_grad, collapse_spin_grads)
     874              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
     875              : 
     876              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
     877              :       INTEGER, INTENT(IN)                                :: max_rows
     878              :       LOGICAL, INTENT(IN)                                :: compute_grads, collapse_spin_grads
     879              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     880              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
     881              :          INTENT(OUT)                                     :: density_grad, kin_grad
     882              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
     883              :          INTENT(OUT)                                     :: grad_grad
     884              : 
     885              :       INTEGER                                            :: base, isubchunk, local_row, nflat_local, &
     886              :                                                             nroute_grad_per_point, nroute_points, &
     887              :                                                             nsubchunks, phase_handle, point_pos, &
     888              :                                                             subphase_handle
     889            2 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: route_grad_return_recv_counts, &
     890            2 :                                                             route_grad_return_recv_displs, &
     891            2 :                                                             route_grad_return_send_counts, &
     892            2 :                                                             route_grad_return_send_displs
     893              :       REAL(KIND=dp)                                      :: subchunk_exc
     894            2 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_grad_buffer, send_grad_buffer
     895            2 :       TYPE(skala_gpw_feature_type)                       :: subchunk
     896              :       TYPE(torch_tensor_type)                            :: subchunk_exc_tensor
     897              : 
     898            0 :       CPASSERT(features%uses_atom_chunks)
     899            2 :       CPASSERT(max_rows > 0)
     900            2 :       nflat_local = features%nflat_local
     901            2 :       nsubchunks = skala_gpw_atom_subchunk_count(max_rows)
     902            2 :       CPASSERT(nsubchunks > 0)
     903              : 
     904            2 :       exc = 0.0_dp
     905            2 :       IF (compute_grads) THEN
     906            2 :          CPASSERT(features%uses_atom_chunk_routing)
     907            6 :          CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
     908            2 :          nroute_points = SIZE(features%route_send_local_rows)
     909            6 :          CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
     910            2 :          nroute_grad_per_point = ngrad_per_point
     911            2 :          IF (collapse_spin_grads) nroute_grad_per_point = ncollapsed_grad_per_point
     912              :          ALLOCATE (send_grad_buffer(nroute_grad_per_point*features%chunk_feature_count), &
     913              :                    recv_grad_buffer(nroute_grad_per_point*nroute_points), &
     914              :                    route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
     915              :                    route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
     916              :                    route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
     917           26 :                    route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
     918              :          route_grad_return_send_counts(:) = &
     919            6 :             nroute_grad_per_point*features%route_point_recv_counts
     920              :          route_grad_return_send_displs(:) = &
     921            6 :             nroute_grad_per_point*features%route_point_recv_displs
     922              :          route_grad_return_recv_counts(:) = &
     923            6 :             nroute_grad_per_point*features%route_point_send_counts
     924              :          route_grad_return_recv_displs(:) = &
     925            6 :             nroute_grad_per_point*features%route_point_send_displs
     926              :       END IF
     927              : 
     928            2 :       CALL timeset("skala_gpw_atom_subchunks", phase_handle)
     929            6 :       DO isubchunk = 1, nsubchunks
     930            4 :          CALL timeset("skala_gpw_atom_subchunk_build", subphase_handle)
     931              :          CALL skala_gpw_feature_build_atom_subchunk(features, subchunk, isubchunk, &
     932            4 :                                                     max_rows, compute_grads)
     933            4 :          CALL timestop(subphase_handle)
     934            4 :          CALL timeset("skala_gpw_atom_subchunk_forward", subphase_handle)
     935              :          CALL skala_torch_model_get_exc(cached_model, subchunk%inputs, &
     936              :                                         subchunk%grid_weights_t, subchunk_exc_tensor, &
     937            4 :                                         subchunk_exc)
     938            4 :          CALL timestop(subphase_handle)
     939            4 :          exc = exc + subchunk_exc
     940            4 :          IF (compute_grads) THEN
     941            4 :             CALL timeset("skala_gpw_atom_subchunk_backward", subphase_handle)
     942            4 :             CALL torch_tensor_backward_scalar(subchunk_exc_tensor)
     943            4 :             CALL timestop(subphase_handle)
     944              :          END IF
     945            4 :          CALL timeset("skala_gpw_atom_subchunk_release", subphase_handle)
     946            4 :          CALL torch_tensor_release(subchunk_exc_tensor)
     947            4 :          CALL skala_gpw_feature_release(subchunk)
     948           18 :          CALL timestop(subphase_handle)
     949              :       END DO
     950            2 :       IF (compute_grads) THEN
     951            2 :          CALL timeset("skala_gpw_atom_subchunk_grad_pack", subphase_handle)
     952            2 :          CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., collapse_spin_grads)
     953            2 :          CALL timestop(subphase_handle)
     954              :       END IF
     955            2 :       CALL timestop(phase_handle)
     956              : 
     957            2 :       IF (compute_grads) THEN
     958            2 :          CALL timeset("skala_gpw_grad_route_comm", phase_handle)
     959              :          CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
     960              :                              route_grad_return_send_displs, recv_grad_buffer, &
     961            2 :                              route_grad_return_recv_counts, route_grad_return_recv_displs)
     962            2 :          CALL timestop(phase_handle)
     963              : 
     964            2 :          CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
     965            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
     966           14 :                    kin_grad(nflat_local, 2))
     967            2 :          density_grad = 0.0_dp
     968            2 :          grad_grad = 0.0_dp
     969            2 :          kin_grad = 0.0_dp
     970        64002 :          DO point_pos = 1, nroute_points
     971        64000 :             local_row = features%route_send_local_rows(point_pos)
     972        64000 :             CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
     973        64000 :             base = nroute_grad_per_point*(point_pos - 1)
     974        64002 :             IF (collapse_spin_grads) THEN
     975              :                density_grad(local_row, :) = density_grad(local_row, :) + &
     976       192000 :                                             recv_grad_buffer(base + 1)
     977              :                grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
     978       192000 :                                             recv_grad_buffer(base + 2)
     979              :                grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
     980       192000 :                                             recv_grad_buffer(base + 3)
     981              :                grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
     982       192000 :                                             recv_grad_buffer(base + 4)
     983       192000 :                kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
     984              :             ELSE
     985              :                density_grad(local_row, :) = density_grad(local_row, :) + &
     986            0 :                                             recv_grad_buffer(base + 1:base + 2)
     987              :                grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
     988            0 :                                             recv_grad_buffer(base + 3)
     989              :                grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
     990            0 :                                             recv_grad_buffer(base + 4)
     991              :                grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
     992            0 :                                             recv_grad_buffer(base + 5)
     993              :                grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
     994            0 :                                             recv_grad_buffer(base + 6)
     995              :                grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
     996            0 :                                             recv_grad_buffer(base + 7)
     997              :                grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
     998            0 :                                             recv_grad_buffer(base + 8)
     999              :                kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1000            0 :                                         recv_grad_buffer(base + 9:base + 10)
    1001              :             END IF
    1002              :          END DO
    1003            2 :          CALL timestop(phase_handle)
    1004              : 
    1005            0 :          DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
    1006            0 :                      route_grad_return_recv_displs, route_grad_return_send_counts, &
    1007            6 :                      route_grad_return_send_displs, send_grad_buffer)
    1008              :       END IF
    1009              : 
    1010            4 :    END SUBROUTINE evaluate_atom_subchunks
    1011              : 
    1012              : ! **************************************************************************************************
    1013              : !> \brief Select an automatic CUDA atom-subchunk row cap.
    1014              : !> \param features ...
    1015              : !> \param group ...
    1016              : !> \return ...
    1017              : ! **************************************************************************************************
    1018            0 :    FUNCTION auto_atom_chunk_max_rows(features, group) RESULT(max_rows)
    1019              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1020              : 
    1021              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1022              :       INTEGER                                            :: max_rows
    1023              : 
    1024              :       INTEGER                                            :: local_rows_max, target_rows
    1025              : 
    1026            0 :       local_rows_max = features%chunk_feature_count
    1027            0 :       CALL group%max(local_rows_max)
    1028            0 :       IF (local_rows_max <= 0) THEN
    1029            0 :          max_rows = 0
    1030              :          RETURN
    1031              :       END IF
    1032              : 
    1033            0 :       IF (group%num_pe > 1) THEN
    1034            0 :          target_rows = CEILING(REAL(local_rows_max, KIND=dp)/2.0_dp)
    1035              :          max_rows = atom_chunk_auto_row_quantum* &
    1036            0 :                     ((target_rows + atom_chunk_auto_row_quantum - 1)/atom_chunk_auto_row_quantum)
    1037              :       ELSE
    1038            0 :          target_rows = NINT(REAL(local_rows_max, KIND=dp)/4.0_dp)
    1039              :          max_rows = atom_chunk_auto_row_quantum* &
    1040              :                     MAX(1, NINT(REAL(target_rows, KIND=dp)/ &
    1041            0 :                                 REAL(atom_chunk_auto_row_quantum, KIND=dp)))
    1042              :       END IF
    1043            0 :       max_rows = MAX(atom_chunk_auto_min_rows, MIN(atom_chunk_auto_max_rows, max_rows))
    1044              : 
    1045            0 :    END FUNCTION auto_atom_chunk_max_rows
    1046              : 
    1047              : ! **************************************************************************************************
    1048              : !> \brief Map full Torch feature gradients back to this rank's local grid order.
    1049              : !> \param features ...
    1050              : !> \param density_grad ...
    1051              : !> \param grad_grad ...
    1052              : !> \param kin_grad ...
    1053              : ! **************************************************************************************************
    1054          142 :    SUBROUTINE fetch_local_feature_grads(features, density_grad, grad_grad, kin_grad)
    1055              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1056              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1057              :          INTENT(OUT)                                     :: density_grad
    1058              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
    1059              :          INTENT(OUT)                                     :: grad_grad
    1060              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1061              :          INTENT(OUT)                                     :: kin_grad
    1062              : 
    1063              :       INTEGER                                            :: feature_pos, i, j, k, local_row, row
    1064          142 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad_all, kin_grad_all
    1065          142 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad_all
    1066              :       TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t
    1067              : 
    1068          142 :       NULLIFY (density_grad_all, grad_grad_all, kin_grad_all)
    1069              :       CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1070          142 :                                   density_grad_all, grad_grad_all, kin_grad_all)
    1071          142 :       CPASSERT(SIZE(density_grad_all, 1) == features%nflat)
    1072          142 :       CPASSERT(SIZE(density_grad_all, 2) == 2)
    1073          142 :       CPASSERT(SIZE(grad_grad_all, 1) == features%nflat)
    1074          142 :       CPASSERT(SIZE(grad_grad_all, 2) == 3)
    1075          142 :       CPASSERT(SIZE(grad_grad_all, 3) == 2)
    1076          142 :       CPASSERT(SIZE(kin_grad_all, 1) == features%nflat)
    1077          142 :       CPASSERT(SIZE(kin_grad_all, 2) == 2)
    1078              : 
    1079            0 :       ALLOCATE (density_grad(features%nflat_local, 2), &
    1080            0 :                 grad_grad(features%nflat_local, 3, 2), &
    1081          994 :                 kin_grad(features%nflat_local, 2))
    1082          142 :       density_grad = 0.0_dp
    1083          142 :       grad_grad = 0.0_dp
    1084          142 :       kin_grad = 0.0_dp
    1085          142 :       local_row = 0
    1086         3462 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1087        89518 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1088      1546458 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1089      1302618 :                local_row = local_row + 1
    1090      2674006 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1091      1382886 :                   features%local_feature_offsets(local_row + 1) - 1
    1092      1371388 :                   row = features%local_feature_rows(feature_pos)
    1093      1371388 :                   CPASSERT(row >= 1 .AND. row <= features%nflat)
    1094              :                   density_grad(local_row, :) = density_grad(local_row, :) + &
    1095      4114164 :                                                density_grad_all(row, :)
    1096              :                   grad_grad(local_row, :, :) = grad_grad(local_row, :, :) + &
    1097     12342492 :                                                grad_grad_all(row, :, :)
    1098      5416782 :                   kin_grad(local_row, :) = kin_grad(local_row, :) + kin_grad_all(row, :)
    1099              :                END DO
    1100              :             END DO
    1101              :          END DO
    1102              :       END DO
    1103          142 :       CPASSERT(local_row == features%nflat_local)
    1104              : 
    1105          142 :       CALL torch_tensor_release(density_grad_t)
    1106          142 :       CALL torch_tensor_release(grad_grad_t)
    1107          142 :       CALL torch_tensor_release(kin_grad_t)
    1108              : 
    1109          142 :    END SUBROUTINE fetch_local_feature_grads
    1110              : 
    1111              : ! **************************************************************************************************
    1112              : !> \brief Pack atom-chunk Torch gradients into CP2K communication buffers.
    1113              : !> \param features ...
    1114              : !> \param TARGET ...
    1115              : !> \param route_to_return_positions ...
    1116              : !> \param collapse_spin_grads ...
    1117              : ! **************************************************************************************************
    1118            2 :    SUBROUTINE pack_atom_chunk_grads(features, TARGET, route_to_return_positions, &
    1119              :                                     collapse_spin_grads)
    1120              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1121              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
    1122              :          INTENT(INOUT)                                   :: target
    1123              :       LOGICAL, INTENT(IN)                                :: route_to_return_positions
    1124              :       LOGICAL, INTENT(IN), OPTIONAL                      :: collapse_spin_grads
    1125              : 
    1126              :       INTEGER                                            :: base, irow, ngrad_buffer_per_point, &
    1127              :                                                             point_pos, target_points
    1128              :       LOGICAL                                            :: my_collapse_spin_grads
    1129            2 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: chunk_density_grad, chunk_kin_grad
    1130            2 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: chunk_grad_grad
    1131              :       TYPE(torch_tensor_type)                            :: density_grad_t, grad_grad_t, kin_grad_t
    1132              : 
    1133            2 :       my_collapse_spin_grads = .FALSE.
    1134            4 :       IF (PRESENT(collapse_spin_grads)) my_collapse_spin_grads = collapse_spin_grads
    1135            2 :       ngrad_buffer_per_point = ngrad_per_point
    1136            2 :       IF (my_collapse_spin_grads) ngrad_buffer_per_point = ncollapsed_grad_per_point
    1137              : 
    1138            2 :       NULLIFY (chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
    1139              :       CALL get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1140            2 :                                   chunk_density_grad, chunk_grad_grad, chunk_kin_grad)
    1141            2 :       CPASSERT(MOD(SIZE(TARGET), ngrad_buffer_per_point) == 0)
    1142            2 :       target_points = SIZE(TARGET)/ngrad_buffer_per_point
    1143            2 :       CPASSERT(target_points >= features%chunk_feature_count)
    1144            2 :       CPASSERT(SIZE(chunk_density_grad, 1) == features%chunk_feature_count)
    1145            2 :       CPASSERT(SIZE(chunk_grad_grad, 1) == features%chunk_feature_count)
    1146            2 :       CPASSERT(SIZE(chunk_grad_grad, 2) == 3)
    1147            2 :       CPASSERT(SIZE(chunk_kin_grad, 1) == features%chunk_feature_count)
    1148            2 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1149            2 :          CPASSERT(my_collapse_spin_grads)
    1150            2 :          CPASSERT(SIZE(chunk_density_grad, 2) == 1)
    1151            2 :          CPASSERT(SIZE(chunk_grad_grad, 3) == 1)
    1152            2 :          CPASSERT(SIZE(chunk_kin_grad, 2) == 1)
    1153              :       ELSE
    1154            0 :          CPASSERT(SIZE(chunk_density_grad, 2) == 2)
    1155            0 :          CPASSERT(SIZE(chunk_grad_grad, 3) == 2)
    1156            0 :          CPASSERT(SIZE(chunk_kin_grad, 2) == 2)
    1157              :       END IF
    1158              : 
    1159        64002 :       DO irow = 1, features%chunk_feature_count
    1160        64000 :          IF (route_to_return_positions) THEN
    1161        64000 :             point_pos = features%chunk_return_positions(irow)
    1162        64000 :             CPASSERT(point_pos >= 1 .AND. point_pos <= target_points)
    1163              :          ELSE
    1164              :             point_pos = irow
    1165              :          END IF
    1166        64000 :          base = ngrad_buffer_per_point*(point_pos - 1)
    1167        64002 :          IF (my_collapse_spin_grads) THEN
    1168        64000 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1169        64000 :                TARGET(base + 1) = 0.5_dp*chunk_density_grad(irow, 1)
    1170        64000 :                TARGET(base + 2) = 0.5_dp*chunk_grad_grad(irow, 1, 1)
    1171        64000 :                TARGET(base + 3) = 0.5_dp*chunk_grad_grad(irow, 2, 1)
    1172        64000 :                TARGET(base + 4) = 0.5_dp*chunk_grad_grad(irow, 3, 1)
    1173        64000 :                TARGET(base + 5) = 0.5_dp*chunk_kin_grad(irow, 1)
    1174              :             ELSE
    1175              :                TARGET(base + 1) = 0.5_dp*(chunk_density_grad(irow, 1) + &
    1176            0 :                                           chunk_density_grad(irow, 2))
    1177              :                TARGET(base + 2) = 0.5_dp*(chunk_grad_grad(irow, 1, 1) + &
    1178            0 :                                           chunk_grad_grad(irow, 1, 2))
    1179              :                TARGET(base + 3) = 0.5_dp*(chunk_grad_grad(irow, 2, 1) + &
    1180            0 :                                           chunk_grad_grad(irow, 2, 2))
    1181              :                TARGET(base + 4) = 0.5_dp*(chunk_grad_grad(irow, 3, 1) + &
    1182            0 :                                           chunk_grad_grad(irow, 3, 2))
    1183            0 :                TARGET(base + 5) = 0.5_dp*(chunk_kin_grad(irow, 1) + chunk_kin_grad(irow, 2))
    1184              :             END IF
    1185              :          ELSE
    1186            0 :             TARGET(base + 1:base + 2) = chunk_density_grad(irow, :)
    1187            0 :             TARGET(base + 3) = chunk_grad_grad(irow, 1, 1)
    1188            0 :             TARGET(base + 4) = chunk_grad_grad(irow, 2, 1)
    1189            0 :             TARGET(base + 5) = chunk_grad_grad(irow, 3, 1)
    1190            0 :             TARGET(base + 6) = chunk_grad_grad(irow, 1, 2)
    1191            0 :             TARGET(base + 7) = chunk_grad_grad(irow, 2, 2)
    1192            0 :             TARGET(base + 8) = chunk_grad_grad(irow, 3, 2)
    1193            0 :             TARGET(base + 9:base + 10) = chunk_kin_grad(irow, :)
    1194              :          END IF
    1195              :       END DO
    1196              : 
    1197            2 :       CALL torch_tensor_release(density_grad_t)
    1198            2 :       CALL torch_tensor_release(grad_grad_t)
    1199            2 :       CALL torch_tensor_release(kin_grad_t)
    1200              : 
    1201            2 :    END SUBROUTINE pack_atom_chunk_grads
    1202              : 
    1203              : ! **************************************************************************************************
    1204              : !> \brief Return CPU views of autograd outputs for the SKALA dynamic feature tensors.
    1205              : !> \param features ...
    1206              : !> \param density_grad_t ...
    1207              : !> \param grad_grad_t ...
    1208              : !> \param kin_grad_t ...
    1209              : !> \param density_grad ...
    1210              : !> \param grad_grad ...
    1211              : !> \param kin_grad ...
    1212              : ! **************************************************************************************************
    1213          144 :    SUBROUTINE get_feature_grad_views(features, density_grad_t, grad_grad_t, kin_grad_t, &
    1214              :                                      density_grad, grad_grad, kin_grad)
    1215              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1216              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: density_grad_t, grad_grad_t, kin_grad_t
    1217              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: density_grad
    1218              :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: grad_grad
    1219              :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: kin_grad
    1220              : 
    1221          144 :       NULLIFY (density_grad, grad_grad, kin_grad)
    1222          144 :       CALL torch_tensor_grad(features%density_t, density_grad_t)
    1223          144 :       CALL torch_tensor_grad(features%grad_t, grad_grad_t)
    1224          144 :       CALL torch_tensor_grad(features%kin_t, kin_grad_t)
    1225          144 :       CALL torch_tensor_data_ptr(density_grad_t, density_grad)
    1226          144 :       CALL torch_tensor_data_ptr(grad_grad_t, grad_grad)
    1227          144 :       CALL torch_tensor_data_ptr(kin_grad_t, kin_grad)
    1228              : 
    1229          144 :    END SUBROUTINE get_feature_grad_views
    1230              : 
    1231              : ! **************************************************************************************************
    1232              : !> \brief Fetch atom-chunk gradients and route them back to their local grid owners.
    1233              : !> \param features ...
    1234              : !> \param group ...
    1235              : !> \param density_grad ...
    1236              : !> \param grad_grad ...
    1237              : !> \param kin_grad ...
    1238              : ! **************************************************************************************************
    1239            0 :    SUBROUTINE fetch_and_gather_atom_chunk_grads(features, group, density_grad, grad_grad, &
    1240              :                                                 kin_grad)
    1241              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1242              : 
    1243              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1244              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
    1245              :          INTENT(OUT)                                     :: density_grad, kin_grad
    1246              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
    1247              :          INTENT(OUT)                                     :: grad_grad
    1248              : 
    1249              :       INTEGER                                            :: base, feature_pos, i, j, k, local_row, &
    1250              :                                                             nflat_local, nroute_grad_per_point, &
    1251              :                                                             nroute_points, phase_handle, point_pos, row
    1252            0 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: route_grad_return_recv_counts, &
    1253            0 :                                                             route_grad_return_recv_displs, &
    1254            0 :                                                             route_grad_return_send_counts, &
    1255            0 :                                                             route_grad_return_send_displs
    1256            0 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: chunk_grad_buffer, global_grad_buffer, &
    1257            0 :                                                             recv_grad_buffer, send_grad_buffer
    1258              : 
    1259            0 :       CPASSERT(features%uses_atom_chunks)
    1260            0 :       CPASSERT(features%chunk_feature_count > 0)
    1261              : 
    1262            0 :       nflat_local = features%nflat_local
    1263            0 :       IF (features%uses_atom_chunk_routing) THEN
    1264            0 :          CPASSERT(SUM(features%route_point_recv_counts) == features%chunk_feature_count)
    1265            0 :          nroute_points = SIZE(features%route_send_local_rows)
    1266            0 :          CPASSERT(SUM(features%route_point_send_counts) == nroute_points)
    1267              : 
    1268            0 :          nroute_grad_per_point = ngrad_per_point
    1269            0 :          IF (features%uses_collapsed_rks_dynamic) &
    1270            0 :             nroute_grad_per_point = ncollapsed_grad_per_point
    1271              :          ALLOCATE (send_grad_buffer(nroute_grad_per_point*features%chunk_feature_count), &
    1272              :                    recv_grad_buffer(nroute_grad_per_point*nroute_points), &
    1273              :                    route_grad_return_send_counts(SIZE(features%route_point_recv_counts)), &
    1274              :                    route_grad_return_send_displs(SIZE(features%route_point_recv_displs)), &
    1275              :                    route_grad_return_recv_counts(SIZE(features%route_point_send_counts)), &
    1276            0 :                    route_grad_return_recv_displs(SIZE(features%route_point_send_displs)))
    1277              :          route_grad_return_send_counts(:) = &
    1278            0 :             nroute_grad_per_point*features%route_point_recv_counts
    1279              :          route_grad_return_send_displs(:) = &
    1280            0 :             nroute_grad_per_point*features%route_point_recv_displs
    1281              :          route_grad_return_recv_counts(:) = &
    1282            0 :             nroute_grad_per_point*features%route_point_send_counts
    1283              :          route_grad_return_recv_displs(:) = &
    1284            0 :             nroute_grad_per_point*features%route_point_send_displs
    1285              : 
    1286            0 :          CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
    1287              :          CALL pack_atom_chunk_grads(features, send_grad_buffer, .TRUE., &
    1288            0 :                                     features%uses_collapsed_rks_dynamic)
    1289            0 :          CALL timestop(phase_handle)
    1290              : 
    1291            0 :          CALL timeset("skala_gpw_grad_route_comm", phase_handle)
    1292              :          CALL group%alltoall(send_grad_buffer, route_grad_return_send_counts, &
    1293              :                              route_grad_return_send_displs, recv_grad_buffer, &
    1294            0 :                              route_grad_return_recv_counts, route_grad_return_recv_displs)
    1295            0 :          CALL timestop(phase_handle)
    1296              : 
    1297            0 :          CALL timeset("skala_gpw_grad_route_scatter", phase_handle)
    1298            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
    1299            0 :                    kin_grad(nflat_local, 2))
    1300            0 :          density_grad = 0.0_dp
    1301            0 :          grad_grad = 0.0_dp
    1302            0 :          kin_grad = 0.0_dp
    1303            0 :          DO point_pos = 1, nroute_points
    1304            0 :             local_row = features%route_send_local_rows(point_pos)
    1305            0 :             CPASSERT(local_row >= 1 .AND. local_row <= nflat_local)
    1306            0 :             base = nroute_grad_per_point*(point_pos - 1)
    1307            0 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1308              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1309            0 :                                             recv_grad_buffer(base + 1)
    1310              :                grad_grad(local_row, 1, :) = grad_grad(local_row, 1, :) + &
    1311            0 :                                             recv_grad_buffer(base + 2)
    1312              :                grad_grad(local_row, 2, :) = grad_grad(local_row, 2, :) + &
    1313            0 :                                             recv_grad_buffer(base + 3)
    1314              :                grad_grad(local_row, 3, :) = grad_grad(local_row, 3, :) + &
    1315            0 :                                             recv_grad_buffer(base + 4)
    1316            0 :                kin_grad(local_row, :) = kin_grad(local_row, :) + recv_grad_buffer(base + 5)
    1317              :             ELSE
    1318              :                density_grad(local_row, :) = density_grad(local_row, :) + &
    1319            0 :                                             recv_grad_buffer(base + 1:base + 2)
    1320              :                grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
    1321            0 :                                             recv_grad_buffer(base + 3)
    1322              :                grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
    1323            0 :                                             recv_grad_buffer(base + 4)
    1324              :                grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
    1325            0 :                                             recv_grad_buffer(base + 5)
    1326              :                grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
    1327            0 :                                             recv_grad_buffer(base + 6)
    1328              :                grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
    1329            0 :                                             recv_grad_buffer(base + 7)
    1330              :                grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
    1331            0 :                                             recv_grad_buffer(base + 8)
    1332              :                kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1333            0 :                                         recv_grad_buffer(base + 9:base + 10)
    1334              :             END IF
    1335              :          END DO
    1336            0 :          CALL timestop(phase_handle)
    1337              : 
    1338            0 :          DEALLOCATE (recv_grad_buffer, route_grad_return_recv_counts, &
    1339            0 :                      route_grad_return_recv_displs, route_grad_return_send_counts, &
    1340            0 :                      route_grad_return_send_displs, send_grad_buffer)
    1341              :       ELSE
    1342              :          ALLOCATE (chunk_grad_buffer(ngrad_per_point*features%chunk_feature_count), &
    1343            0 :                    global_grad_buffer(ngrad_per_point*features%nflat))
    1344            0 :          CALL timeset("skala_gpw_grad_torch_pack", phase_handle)
    1345            0 :          CALL pack_atom_chunk_grads(features, chunk_grad_buffer, .FALSE.)
    1346            0 :          CALL timestop(phase_handle)
    1347              : 
    1348            0 :          CALL timeset("skala_gpw_grad_allgatherv", phase_handle)
    1349              :          CALL group%allgatherv(chunk_grad_buffer, global_grad_buffer, &
    1350            0 :                                features%chunk_grad_counts, features%chunk_grad_displs)
    1351            0 :          CALL timestop(phase_handle)
    1352              : 
    1353            0 :          CALL timeset("skala_gpw_grad_scatter", phase_handle)
    1354            0 :          ALLOCATE (density_grad(nflat_local, 2), grad_grad(nflat_local, 3, 2), &
    1355            0 :                    kin_grad(nflat_local, 2))
    1356            0 :          density_grad = 0.0_dp
    1357            0 :          grad_grad = 0.0_dp
    1358            0 :          kin_grad = 0.0_dp
    1359            0 :          local_row = 0
    1360            0 :          DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1361            0 :             DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1362            0 :                DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1363            0 :                   local_row = local_row + 1
    1364            0 :                   DO feature_pos = features%local_feature_offsets(local_row), &
    1365            0 :                      features%local_feature_offsets(local_row + 1) - 1
    1366            0 :                      row = features%local_feature_rows(feature_pos)
    1367            0 :                      CPASSERT(row >= 1 .AND. row <= features%nflat)
    1368            0 :                      base = ngrad_per_point*(row - 1)
    1369              :                      density_grad(local_row, :) = density_grad(local_row, :) + &
    1370            0 :                                                   global_grad_buffer(base + 1:base + 2)
    1371              :                      grad_grad(local_row, 1, 1) = grad_grad(local_row, 1, 1) + &
    1372            0 :                                                   global_grad_buffer(base + 3)
    1373              :                      grad_grad(local_row, 2, 1) = grad_grad(local_row, 2, 1) + &
    1374            0 :                                                   global_grad_buffer(base + 4)
    1375              :                      grad_grad(local_row, 3, 1) = grad_grad(local_row, 3, 1) + &
    1376            0 :                                                   global_grad_buffer(base + 5)
    1377              :                      grad_grad(local_row, 1, 2) = grad_grad(local_row, 1, 2) + &
    1378            0 :                                                   global_grad_buffer(base + 6)
    1379              :                      grad_grad(local_row, 2, 2) = grad_grad(local_row, 2, 2) + &
    1380            0 :                                                   global_grad_buffer(base + 7)
    1381              :                      grad_grad(local_row, 3, 2) = grad_grad(local_row, 3, 2) + &
    1382            0 :                                                   global_grad_buffer(base + 8)
    1383              :                      kin_grad(local_row, :) = kin_grad(local_row, :) + &
    1384            0 :                                               global_grad_buffer(base + 9:base + 10)
    1385              :                   END DO
    1386              :                END DO
    1387              :             END DO
    1388              :          END DO
    1389            0 :          CALL timestop(phase_handle)
    1390            0 :          DEALLOCATE (chunk_grad_buffer, global_grad_buffer)
    1391              : 
    1392              :       END IF
    1393              : 
    1394            0 :    END SUBROUTINE fetch_and_gather_atom_chunk_grads
    1395              : 
    1396              : ! **************************************************************************************************
    1397              : !> \brief Build the native SKALA XC virial from feature gradients.
    1398              : !> \param virial_xc ...
    1399              : !> \param rho_set ...
    1400              : !> \param rho_r ...
    1401              : !> \param grad_grad ...
    1402              : ! **************************************************************************************************
    1403            8 :    SUBROUTINE build_virial_from_feature_grads(virial_xc, rho_set, rho_r, grad_grad)
    1404              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1405              :       TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
    1406              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: rho_r
    1407              :       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
    1408              : 
    1409              :       INTEGER                                            :: i, idir, ipt, ispin, j, jdir, k, nspins
    1410              :       INTEGER, DIMENSION(2, 3)                           :: bo
    1411              :       REAL(KIND=dp)                                      :: grad_i, tmp
    1412           96 :       TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
    1413              : 
    1414            8 :       nspins = SIZE(rho_r)
    1415           80 :       bo = rho_r(1)%pw_grid%bounds_local
    1416            8 :       ipt = 0
    1417              : 
    1418            8 :       IF (nspins == 1) THEN
    1419            8 :          CALL xc_rho_set_get(rho_set, drho=drho)
    1420          200 :          DO k = bo(1, 3), bo(2, 3)
    1421         4808 :             DO j = bo(1, 2), bo(2, 2)
    1422        60096 :                DO i = bo(1, 1), bo(2, 1)
    1423        55296 :                   ipt = ipt + 1
    1424       225792 :                   DO idir = 1, 3
    1425       165888 :                      grad_i = 0.5_dp*(grad_grad(ipt, idir, 1) + grad_grad(ipt, idir, 2))
    1426       552960 :                      DO jdir = 1, idir
    1427       331776 :                         tmp = -grad_i*drho(jdir)%array(i, j, k)
    1428       331776 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1429       497664 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1430              :                      END DO
    1431              :                   END DO
    1432              :                END DO
    1433              :             END DO
    1434              :          END DO
    1435              :       ELSE
    1436            0 :          CALL xc_rho_set_get(rho_set, drhoa=drhoa, drhob=drhob)
    1437            0 :          DO k = bo(1, 3), bo(2, 3)
    1438            0 :             DO j = bo(1, 2), bo(2, 2)
    1439            0 :                DO i = bo(1, 1), bo(2, 1)
    1440            0 :                   ipt = ipt + 1
    1441            0 :                   DO idir = 1, 3
    1442            0 :                      DO jdir = 1, idir
    1443              :                         tmp = 0.0_dp
    1444            0 :                         DO ispin = 1, 2
    1445            0 :                            IF (ispin == 1) THEN
    1446            0 :                               tmp = tmp - grad_grad(ipt, idir, ispin)*drhoa(jdir)%array(i, j, k)
    1447              :                            ELSE
    1448            0 :                               tmp = tmp - grad_grad(ipt, idir, ispin)*drhob(jdir)%array(i, j, k)
    1449              :                            END IF
    1450              :                         END DO
    1451            0 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1452            0 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1453              :                      END DO
    1454              :                   END DO
    1455              :                END DO
    1456              :             END DO
    1457              :          END DO
    1458              :       END IF
    1459              : 
    1460            8 :    END SUBROUTINE build_virial_from_feature_grads
    1461              : 
    1462              : ! **************************************************************************************************
    1463              : !> \brief Print a native SKALA XC virial contribution for diagnostics.
    1464              : !> \param label ...
    1465              : !> \param delta ...
    1466              : !> \param root_rank ...
    1467              : ! **************************************************************************************************
    1468            0 :    SUBROUTINE print_virial_delta(label, delta, root_rank)
    1469              :       CHARACTER(LEN=*), INTENT(IN)                       :: label
    1470              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(IN)         :: delta
    1471              :       LOGICAL, INTENT(IN)                                :: root_rank
    1472              : 
    1473              :       INTEGER                                            :: i, iw
    1474              : 
    1475            0 :       IF (.NOT. root_rank) RETURN
    1476            0 :       iw = cp_logger_get_default_io_unit()
    1477            0 :       IF (iw <= 0) RETURN
    1478            0 :       WRITE (iw, "(T2,A,1X,A)") "SKALA_GPW| XC virial contribution", TRIM(label)
    1479            0 :       DO i = 1, 3
    1480            0 :          WRITE (iw, "(T2,A,1X,3ES20.10)") "SKALA_GPW|", delta(i, 1:3)
    1481              :       END DO
    1482              : 
    1483              :    END SUBROUTINE print_virial_delta
    1484              : 
    1485              : ! **************************************************************************************************
    1486              : !> \brief Add explicit SKALA coordinate-feature contributions to the XC virial.
    1487              : !> \param virial_xc ...
    1488              : !> \param features ...
    1489              : !> \param atom_coord_grad_t ...
    1490              : !> \param grid_coord_grad_t ...
    1491              : !> \param root_rank ...
    1492              : !> \param print_components ...
    1493              : ! **************************************************************************************************
    1494            8 :    SUBROUTINE build_static_coordinate_virial(virial_xc, features, atom_coord_grad_t, &
    1495              :                                              grid_coord_grad_t, root_rank, print_components)
    1496              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1497              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1498              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: atom_coord_grad_t, grid_coord_grad_t
    1499              :       LOGICAL, INTENT(IN)                                :: root_rank
    1500              :       LOGICAL, INTENT(IN), OPTIONAL                      :: print_components
    1501              : 
    1502              :       INTEGER                                            :: feature_pos, i, iatom, idir, iw, j, &
    1503              :                                                             jdir, k, local_row, row
    1504              :       LOGICAL                                            :: my_print_components
    1505              :       REAL(KIND=dp)                                      :: tmp
    1506              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: atom_virial, grid_virial
    1507            8 :       REAL(KIND=dp), DIMENSION(:, :), POINTER            :: atom_coord_grad, grid_coord_grad
    1508              : 
    1509            8 :       my_print_components = .FALSE.
    1510            8 :       IF (PRESENT(print_components)) my_print_components = print_components
    1511              : 
    1512            8 :       NULLIFY (atom_coord_grad, grid_coord_grad)
    1513            8 :       CALL torch_tensor_grad(features%grid_coords_t, grid_coord_grad_t)
    1514            8 :       CALL torch_tensor_data_ptr(grid_coord_grad_t, grid_coord_grad)
    1515            8 :       CALL torch_tensor_data_ptr(atom_coord_grad_t, atom_coord_grad)
    1516              : 
    1517            8 :       grid_virial = 0.0_dp
    1518            8 :       atom_virial = 0.0_dp
    1519            8 :       local_row = 0
    1520          216 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1521         5192 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1522        69312 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1523        55296 :                local_row = local_row + 1
    1524       124346 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1525        59904 :                   features%local_feature_offsets(local_row + 1) - 1
    1526        69050 :                   row = features%local_feature_rows(feature_pos)
    1527       331496 :                   DO idir = 1, 3
    1528       690500 :                      DO jdir = 1, idir
    1529       414300 :                         tmp = grid_coord_grad(idir, row)*features%grid_coords(jdir, row)
    1530       414300 :                         grid_virial(jdir, idir) = grid_virial(jdir, idir) + tmp
    1531       414300 :                         grid_virial(idir, jdir) = grid_virial(jdir, idir)
    1532       414300 :                         virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1533       621450 :                         virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1534              :                      END DO
    1535              :                   END DO
    1536              :                END DO
    1537              :             END DO
    1538              :          END DO
    1539              :       END DO
    1540            8 :       CPASSERT(local_row == features%nflat_local)
    1541              : 
    1542            8 :       IF (root_rank) THEN
    1543           12 :          DO iatom = 1, SIZE(features%coarse_0_atomic_coords, 2)
    1544           36 :             DO idir = 1, 3
    1545           80 :                DO jdir = 1, idir
    1546           48 :                   tmp = atom_coord_grad(idir, iatom)*features%coarse_0_atomic_coords(jdir, iatom)
    1547           48 :                   atom_virial(jdir, idir) = atom_virial(jdir, idir) + tmp
    1548           48 :                   atom_virial(idir, jdir) = atom_virial(jdir, idir)
    1549           48 :                   virial_xc(jdir, idir) = virial_xc(jdir, idir) + tmp
    1550           72 :                   virial_xc(idir, jdir) = virial_xc(jdir, idir)
    1551              :                END DO
    1552              :             END DO
    1553              :          END DO
    1554              :       END IF
    1555              : 
    1556            8 :       IF (my_print_components .AND. root_rank) THEN
    1557            0 :          iw = cp_logger_get_default_io_unit()
    1558            0 :          IF (iw > 0) THEN
    1559            0 :             CALL print_virial_delta("static-grid", grid_virial, .TRUE.)
    1560            0 :             CALL print_virial_delta("static-atom", atom_virial, .TRUE.)
    1561              :          END IF
    1562              :       END IF
    1563              : 
    1564            8 :       CALL torch_tensor_release(grid_coord_grad_t)
    1565              : 
    1566            8 :    END SUBROUTINE build_static_coordinate_virial
    1567              : 
    1568              : ! **************************************************************************************************
    1569              : !> \brief Add residual SKALA weight-feature contributions to the XC virial.
    1570              : !> \param virial_xc ...
    1571              : !> \param features ...
    1572              : !> \param exc ...
    1573              : !> \param grid_weight_grad_t ...
    1574              : !> \param atomic_grid_weight_grad_t ...
    1575              : !> \param root_rank ...
    1576              : !> \param print_components ...
    1577              : ! **************************************************************************************************
    1578            8 :    SUBROUTINE build_weight_virial(virial_xc, features, exc, grid_weight_grad_t, &
    1579              :                                   atomic_grid_weight_grad_t, root_rank, print_components)
    1580              :       REAL(KIND=dp), DIMENSION(3, 3), INTENT(INOUT)      :: virial_xc
    1581              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1582              :       REAL(KIND=dp), INTENT(IN)                          :: exc
    1583              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grid_weight_grad_t, &
    1584              :                                                             atomic_grid_weight_grad_t
    1585              :       LOGICAL, INTENT(IN)                                :: root_rank
    1586              :       LOGICAL, INTENT(IN), OPTIONAL                      :: print_components
    1587              : 
    1588              :       INTEGER                                            :: feature_pos, i, idir, iw, j, k, &
    1589              :                                                             local_row, row
    1590              :       LOGICAL                                            :: my_print_components
    1591              :       REAL(KIND=dp)                                      :: atomic_tmp, exc_tmp, grid_tmp, tmp
    1592            8 :       REAL(KIND=dp), DIMENSION(:), POINTER               :: atomic_grid_weight_grad, grid_weight_grad
    1593              : 
    1594            8 :       my_print_components = .FALSE.
    1595            8 :       IF (PRESENT(print_components)) my_print_components = print_components
    1596              : 
    1597            8 :       NULLIFY (atomic_grid_weight_grad, grid_weight_grad)
    1598            8 :       CALL torch_tensor_grad(features%grid_weights_t, grid_weight_grad_t)
    1599            8 :       CALL torch_tensor_grad(features%atomic_grid_weights_t, atomic_grid_weight_grad_t)
    1600            8 :       CALL torch_tensor_data_ptr(grid_weight_grad_t, grid_weight_grad)
    1601            8 :       CALL torch_tensor_data_ptr(atomic_grid_weight_grad_t, atomic_grid_weight_grad)
    1602              : 
    1603            8 :       grid_tmp = 0.0_dp
    1604            8 :       atomic_tmp = 0.0_dp
    1605            8 :       local_row = 0
    1606          216 :       DO k = LBOUND(features%feature_index, 3), UBOUND(features%feature_index, 3)
    1607         5192 :          DO j = LBOUND(features%feature_index, 2), UBOUND(features%feature_index, 2)
    1608        69312 :             DO i = LBOUND(features%feature_index, 1), UBOUND(features%feature_index, 1)
    1609        55296 :                local_row = local_row + 1
    1610       124346 :                DO feature_pos = features%local_feature_offsets(local_row), &
    1611        59904 :                   features%local_feature_offsets(local_row + 1) - 1
    1612        69050 :                   row = features%local_feature_rows(feature_pos)
    1613        69050 :                   grid_tmp = grid_tmp + grid_weight_grad(row)*features%grid_weights(row)
    1614              :                   atomic_tmp = atomic_tmp + &
    1615       124346 :                                atomic_grid_weight_grad(row)*features%atomic_grid_weights(row)
    1616              :                END DO
    1617              :             END DO
    1618              :          END DO
    1619              :       END DO
    1620            8 :       CPASSERT(local_row == features%nflat_local)
    1621            8 :       exc_tmp = 0.0_dp
    1622            8 :       IF (root_rank) exc_tmp = -exc
    1623            8 :       tmp = grid_tmp + atomic_tmp + exc_tmp
    1624              : 
    1625            8 :       IF (my_print_components .AND. root_rank) THEN
    1626            0 :          iw = cp_logger_get_default_io_unit()
    1627            0 :          IF (iw > 0) THEN
    1628            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight grid", grid_tmp
    1629            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight atomic", atomic_tmp
    1630            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight final", exc_tmp
    1631            0 :             WRITE (iw, "(T2,A,1X,ES20.10)") "SKALA_GPW| XC virial weight residual", tmp
    1632              :          END IF
    1633              :       END IF
    1634              : 
    1635           32 :       DO idir = 1, 3
    1636           32 :          virial_xc(idir, idir) = virial_xc(idir, idir) + tmp
    1637              :       END DO
    1638              : 
    1639            8 :       CALL torch_tensor_release(grid_weight_grad_t)
    1640            8 :       CALL torch_tensor_release(atomic_grid_weight_grad_t)
    1641              : 
    1642            8 :    END SUBROUTINE build_weight_virial
    1643              : 
    1644              : ! **************************************************************************************************
    1645              : !> \brief Fill CP2K VXC real-space arrays from Torch feature gradients.
    1646              : !> \param vxc_rho ...
    1647              : !> \param vxc_tau ...
    1648              : !> \param rho_r ...
    1649              : !> \param pw_pool ...
    1650              : !> \param density_grad ...
    1651              : !> \param grad_grad ...
    1652              : !> \param kin_grad ...
    1653              : !> \param xc_deriv_method_id ...
    1654              : ! **************************************************************************************************
    1655          144 :    SUBROUTINE build_vxc_from_feature_grads(vxc_rho, vxc_tau, rho_r, pw_pool, &
    1656          144 :                                            density_grad, grad_grad, kin_grad, &
    1657              :                                            xc_deriv_method_id)
    1658              :       TYPE(pw_r3d_rs_type), DIMENSION(:), POINTER        :: vxc_rho, vxc_tau, rho_r
    1659              :       TYPE(pw_pool_type), POINTER                        :: pw_pool
    1660              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: density_grad
    1661              :       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: grad_grad
    1662              :       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: kin_grad
    1663              :       INTEGER, INTENT(IN)                                :: xc_deriv_method_id
    1664              : 
    1665              :       INTEGER                                            :: i, ipt, ispin, j, k, nspins
    1666              :       INTEGER, DIMENSION(2, 3)                           :: bo
    1667              :       REAL(KIND=dp)                                      :: dvol_inv
    1668              :       TYPE(pw_c1d_gs_type)                               :: tmp_g, vxc_g
    1669          576 :       TYPE(pw_r3d_rs_type), DIMENSION(3)                 :: grad_pw
    1670              : 
    1671          144 :       nspins = SIZE(rho_r)
    1672         1440 :       bo = rho_r(1)%pw_grid%bounds_local
    1673          144 :       dvol_inv = 1.0_dp/rho_r(1)%pw_grid%dvol
    1674              : 
    1675          972 :       ALLOCATE (vxc_rho(nspins), vxc_tau(nspins))
    1676          342 :       DO ispin = 1, nspins
    1677          198 :          CALL pw_pool%create_pw(vxc_rho(ispin))
    1678          198 :          CALL pw_pool%create_pw(vxc_tau(ispin))
    1679          198 :          CALL pw_zero(vxc_rho(ispin))
    1680          342 :          CALL pw_zero(vxc_tau(ispin))
    1681              :       END DO
    1682              : 
    1683          144 :       IF (xc_requires_tmp_g(xc_deriv_method_id) .OR. rho_r(1)%pw_grid%spherical) THEN
    1684          144 :          CALL pw_pool%create_pw(vxc_g)
    1685          144 :          IF (.NOT. rho_r(1)%pw_grid%spherical) CALL pw_pool%create_pw(tmp_g)
    1686              :       END IF
    1687              : 
    1688          342 :       DO ispin = 1, nspins
    1689          792 :          DO i = 1, 3
    1690          594 :             CALL pw_pool%create_pw(grad_pw(i))
    1691          792 :             CALL pw_zero(grad_pw(i))
    1692              :          END DO
    1693              : 
    1694          198 :          ipt = 0
    1695         4304 :          DO k = bo(1, 3), bo(2, 3)
    1696       110722 :             DO j = bo(1, 2), bo(2, 2)
    1697      1831517 :                DO i = bo(1, 1), bo(2, 1)
    1698      1720993 :                   ipt = ipt + 1
    1699      1827411 :                   IF (nspins == 1) THEN
    1700              :                      vxc_rho(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1701      1012243 :                                                  (density_grad(ipt, 1) + density_grad(ipt, 2))
    1702              :                      vxc_tau(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1703      1012243 :                                                  (kin_grad(ipt, 1) + kin_grad(ipt, 2))
    1704              :                      grad_pw(1)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1705      1012243 :                                                  (grad_grad(ipt, 1, 1) + grad_grad(ipt, 1, 2))
    1706              :                      grad_pw(2)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1707      1012243 :                                                  (grad_grad(ipt, 2, 1) + grad_grad(ipt, 2, 2))
    1708              :                      grad_pw(3)%array(i, j, k) = 0.5_dp*dvol_inv* &
    1709      1012243 :                                                  (grad_grad(ipt, 3, 1) + grad_grad(ipt, 3, 2))
    1710              :                   ELSE
    1711       708750 :                      vxc_rho(ispin)%array(i, j, k) = dvol_inv*density_grad(ipt, ispin)
    1712       708750 :                      vxc_tau(ispin)%array(i, j, k) = dvol_inv*kin_grad(ipt, ispin)
    1713       708750 :                      grad_pw(1)%array(i, j, k) = dvol_inv*grad_grad(ipt, 1, ispin)
    1714       708750 :                      grad_pw(2)%array(i, j, k) = dvol_inv*grad_grad(ipt, 2, ispin)
    1715       708750 :                      grad_pw(3)%array(i, j, k) = dvol_inv*grad_grad(ipt, 3, ispin)
    1716              :                   END IF
    1717              :                END DO
    1718              :             END DO
    1719              :          END DO
    1720              : 
    1721          792 :          DO i = 1, 3
    1722          792 :             CALL pw_scale(grad_pw(i), -1.0_dp)
    1723              :          END DO
    1724          198 :          CALL xc_pw_divergence(xc_deriv_method_id, grad_pw, tmp_g, vxc_g, vxc_rho(ispin))
    1725              : 
    1726          936 :          DO i = 1, 3
    1727          792 :             CALL pw_pool%give_back_pw(grad_pw(i))
    1728              :          END DO
    1729              :       END DO
    1730              : 
    1731          144 :       IF (ASSOCIATED(vxc_g%pw_grid)) CALL pw_pool%give_back_pw(vxc_g)
    1732          144 :       IF (ASSOCIATED(tmp_g%pw_grid)) CALL pw_pool%give_back_pw(tmp_g)
    1733              : 
    1734          144 :    END SUBROUTINE build_vxc_from_feature_grads
    1735              : 
    1736              : ! **************************************************************************************************
    1737              : !> \brief Print optional diagnostics for the CP2K-native SKALA GPW feature block.
    1738              : !> \param features ...
    1739              : !> \param print_active ...
    1740              : ! **************************************************************************************************
    1741           24 :    SUBROUTINE print_native_grid_diagnostics(features, print_active)
    1742              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: features
    1743              :       LOGICAL, INTENT(IN)                                :: print_active
    1744              : 
    1745              :       INTEGER                                            :: atom_rows_max, atom_rows_min, &
    1746              :                                                             chunk_rows_max, chunk_rows_min, iw
    1747              :       REAL(KIND=dp)                                      :: chunk_imbalance
    1748              : 
    1749           24 :       IF (.NOT. print_active) RETURN
    1750              : 
    1751           12 :       iw = cp_logger_get_default_io_unit()
    1752           12 :       IF (iw <= 0) RETURN
    1753              :       WRITE (UNIT=iw, FMT="(/,T2,A,1X,ES19.11)") &
    1754           12 :          "SKALA_GPW| Native grid feature electrons", features%electron_count
    1755              :       WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
    1756           12 :          "SKALA_GPW| Native grid feature spin moment", features%spin_moment
    1757              :       WRITE (UNIT=iw, FMT="(T2,A,1X,ES19.11)") &
    1758           12 :          "SKALA_GPW| Native grid feature weight sum", features%grid_weight_sum
    1759           12 :       IF (ALLOCATED(features%atomic_grid_sizes)) THEN
    1760           49 :          atom_rows_min = INT(MINVAL(features%atomic_grid_sizes))
    1761           49 :          atom_rows_max = INT(MAXVAL(features%atomic_grid_sizes))
    1762              :          WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
    1763           12 :             "SKALA_GPW| Native grid atom row range", atom_rows_min, "to", &
    1764           61 :             atom_rows_max, "sum", INT(SUM(features%atomic_grid_sizes))
    1765              :       END IF
    1766           12 :       IF (features%uses_atom_chunks) THEN
    1767              :          WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0)") &
    1768            1 :             "SKALA_GPW| Native grid atom chunk rows", features%chunk_feature_count, &
    1769            2 :             "of", features%nflat
    1770            1 :          IF (ALLOCATED(features%chunk_grad_counts)) THEN
    1771            3 :             chunk_rows_min = MINVAL(features%chunk_grad_counts)/ngrad_per_point
    1772            3 :             chunk_rows_max = MAXVAL(features%chunk_grad_counts)/ngrad_per_point
    1773            1 :             chunk_imbalance = REAL(chunk_rows_max, KIND=dp)/REAL(MAX(1, chunk_rows_min), KIND=dp)
    1774              :             WRITE (UNIT=iw, FMT="(T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,ES12.5)") &
    1775            1 :                "SKALA_GPW| Native grid atom chunk row range", chunk_rows_min, &
    1776            2 :                "to", chunk_rows_max, "imbalance", chunk_imbalance
    1777              :          END IF
    1778              :       END IF
    1779              : 
    1780              :    END SUBROUTINE print_native_grid_diagnostics
    1781              : 
    1782              : ! **************************************************************************************************
    1783              : !> \brief Configure CUDA device selection for the native SKALA GPW Torch path.
    1784              : !> \param use_cuda ...
    1785              : !> \param requested_device ...
    1786              : !> \param group ...
    1787              : !> \return selected CUDA device, or -1 for CPU fallback/no visible CUDA device
    1788              : ! **************************************************************************************************
    1789          152 :    FUNCTION configure_native_grid_cuda(use_cuda, requested_device, group) RESULT(selected_device)
    1790              :       LOGICAL, INTENT(IN)                                :: use_cuda
    1791              :       INTEGER, INTENT(IN)                                :: requested_device
    1792              : 
    1793              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1794              : 
    1795              :       INTEGER                                            :: cuda_device_count, iw, pe, selected_device
    1796          152 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: selected_devices
    1797              : 
    1798          152 :       selected_device = -1
    1799              : 
    1800          152 :       IF (.NOT. use_cuda) RETURN
    1801              : 
    1802            0 :       IF (.NOT. torch_cuda_is_available()) THEN
    1803            0 :          cuda_device_count = 0
    1804              :       ELSE
    1805            0 :          cuda_device_count = torch_cuda_device_count()
    1806              :       END IF
    1807            0 :       IF (cuda_device_count > 0) THEN
    1808            0 :          IF (requested_device < 0) THEN
    1809            0 :             selected_device = MOD(group%mepos, cuda_device_count)
    1810              :          ELSE
    1811            0 :             selected_device = requested_device
    1812              :          END IF
    1813              :       END IF
    1814            0 :       IF (selected_device >= cuda_device_count) THEN
    1815              :          CALL cp_abort(__LOCATION__, &
    1816              :                        "GAUXC%NATIVE_GRID_CUDA_DEVICE selects a CUDA device outside the visible "// &
    1817            0 :                        "Torch CUDA device range.")
    1818              :       END IF
    1819            0 :       IF (selected_device >= 0) CALL offload_set_chosen_device(selected_device)
    1820              : 
    1821            0 :       ALLOCATE (selected_devices(group%num_pe))
    1822            0 :       CALL group%allgather(selected_device, selected_devices)
    1823              : 
    1824            0 :       IF (group%mepos /= 0) RETURN
    1825              :       IF (selected_device == logged_cuda_device .AND. &
    1826              :           cuda_device_count == logged_cuda_device_count .AND. &
    1827            0 :           group%num_pe == logged_cuda_nproc .AND. &
    1828              :           requested_device == logged_cuda_request) RETURN
    1829              : 
    1830            0 :       iw = cp_logger_get_default_io_unit()
    1831            0 :       IF (iw <= 0) RETURN
    1832            0 :       IF (selected_device >= 0) THEN
    1833              :          WRITE (UNIT=iw, FMT="(/,T2,A,1X,I0,1X,A,1X,I0,1X,A,1X,I0)") &
    1834            0 :             "SKALA_GPW| Native grid Torch CUDA device", selected_device, &
    1835            0 :             "of", cuda_device_count, "requested", requested_device
    1836              :       ELSE
    1837              :          WRITE (UNIT=iw, FMT="(/,T2,A)") &
    1838            0 :             "SKALA_GPW| Native grid Torch CUDA requested, but no Torch CUDA device is visible"
    1839              :       END IF
    1840              :       WRITE (UNIT=iw, FMT="(T2,A)", ADVANCE="NO") &
    1841            0 :          "SKALA_GPW| Native grid Torch CUDA rank devices"
    1842            0 :       DO pe = 1, group%num_pe
    1843            0 :          WRITE (UNIT=iw, FMT="(1X,I0,A,I0)", ADVANCE="NO") pe - 1, ":", selected_devices(pe)
    1844              :       END DO
    1845            0 :       WRITE (UNIT=iw, FMT=*)
    1846              : 
    1847            0 :       logged_cuda_device = selected_device
    1848            0 :       logged_cuda_device_count = cuda_device_count
    1849            0 :       logged_cuda_nproc = group%num_pe
    1850            0 :       logged_cuda_request = requested_device
    1851              : 
    1852          152 :    END FUNCTION configure_native_grid_cuda
    1853              : 
    1854              : ! **************************************************************************************************
    1855              : !> \brief Load and cache the TorchScript SKALA model.
    1856              : !> \param model_path ...
    1857              : !> \param cuda_device ...
    1858              : ! **************************************************************************************************
    1859          152 :    SUBROUTINE ensure_model_loaded(model_path, cuda_device)
    1860              :       CHARACTER(len=*), INTENT(IN)                       :: model_path
    1861              :       INTEGER, INTENT(IN)                                :: cuda_device
    1862              : 
    1863          152 :       IF (cached_model_loaded) THEN
    1864          108 :          IF (TRIM(cached_model_path) == TRIM(model_path) .AND. &
    1865              :              cached_model_cuda_device == cuda_device) RETURN
    1866            0 :          CALL skala_torch_model_release(cached_model)
    1867            0 :          cached_model_loaded = .FALSE.
    1868              :       END IF
    1869              : 
    1870           44 :       CALL skala_torch_model_load(cached_model, TRIM(model_path))
    1871           44 :       cached_model_path = model_path
    1872           44 :       cached_model_cuda_device = cuda_device
    1873           44 :       cached_model_loaded = .TRUE.
    1874              : 
    1875          152 :    END SUBROUTINE ensure_model_loaded
    1876              : 
    1877              : ! **************************************************************************************************
    1878              : !> \brief Resolve the SKALA TorchScript model path from the GAUXC subsection.
    1879              : !> \param xc_section ...
    1880              : !> \param model_path ...
    1881              : ! **************************************************************************************************
    1882          152 :    SUBROUTINE get_skala_model_path(xc_section, model_path)
    1883              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
    1884              :       CHARACTER(len=default_path_length), INTENT(OUT)    :: model_path
    1885              : 
    1886              :       CHARACTER(len=default_path_length)                 :: model_key
    1887              :       INTEGER                                            :: env_status
    1888              :       LOGICAL                                            :: native_grid_use_cuda
    1889              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
    1890              : 
    1891          152 :       gauxc_section => get_gauxc_section(xc_section)
    1892          152 :       IF (.NOT. ASSOCIATED(gauxc_section)) THEN
    1893            0 :          CPABORT("Native SKALA GPW requires an XC_FUNCTIONAL%GAUXC section")
    1894              :       END IF
    1895              : 
    1896          152 :       CALL section_vals_val_get(gauxc_section, "MODEL", c_val=model_path)
    1897          152 :       model_key = ADJUSTL(model_path)
    1898          152 :       CALL uppercase(model_key)
    1899          152 :       IF (TRIM(model_key) == "NONE" .OR. TRIM(model_key) == "") THEN
    1900            0 :          CPABORT("Native SKALA GPW requires GAUXC%MODEL SKALA or a TorchScript model path")
    1901          152 :       ELSE IF (TRIM(model_key) == "SKALA") THEN
    1902          152 :          CALL section_vals_val_get(gauxc_section, "NATIVE_GRID_USE_CUDA", l_val=native_grid_use_cuda)
    1903          152 :          IF (native_grid_use_cuda) THEN
    1904            0 :             CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_CUDA_MODEL", model_path, STATUS=env_status)
    1905            0 :             IF (env_status == 0 .AND. LEN_TRIM(model_path) > 0) RETURN
    1906              :          END IF
    1907          152 :          CALL GET_ENVIRONMENT_VARIABLE("GAUXC_SKALA_MODEL", model_path, STATUS=env_status)
    1908          152 :          IF (env_status /= 0 .OR. LEN_TRIM(model_path) == 0) THEN
    1909            0 :             IF (native_grid_use_cuda) THEN
    1910              :                CALL cp_abort(__LOCATION__, &
    1911            0 :                              "MODEL SKALA CUDA path requires GAUXC_SKALA_CUDA_MODEL or GAUXC_SKALA_MODEL")
    1912              :             ELSE
    1913              :                CALL cp_abort(__LOCATION__, &
    1914            0 :                              "MODEL SKALA requires the GAUXC_SKALA_MODEL environment variable")
    1915              :             END IF
    1916              :          END IF
    1917              :       END IF
    1918              : 
    1919              :    END SUBROUTINE get_skala_model_path
    1920              : 
    1921              : ! **************************************************************************************************
    1922              : !> \brief Return the first GAUXC functional subsection, if present.
    1923              : !> \param xc_section ...
    1924              : !> \return ...
    1925              : ! **************************************************************************************************
    1926       184095 :    FUNCTION get_gauxc_section(xc_section) RESULT(gauxc_section)
    1927              :       TYPE(section_vals_type), INTENT(IN), POINTER       :: xc_section
    1928              :       TYPE(section_vals_type), POINTER                   :: gauxc_section
    1929              : 
    1930              :       INTEGER                                            :: ifun
    1931              :       TYPE(section_vals_type), POINTER                   :: functionals, xc_fun
    1932              : 
    1933       184095 :       NULLIFY (gauxc_section)
    1934       184095 :       IF (.NOT. ASSOCIATED(xc_section)) RETURN
    1935              : 
    1936       184095 :       functionals => section_vals_get_subs_vals(xc_section, "XC_FUNCTIONAL")
    1937       184095 :       IF (.NOT. ASSOCIATED(functionals)) RETURN
    1938              : 
    1939       184095 :       ifun = 0
    1940              :       DO
    1941       370518 :          ifun = ifun + 1
    1942       370518 :          xc_fun => section_vals_get_subs_vals2(functionals, i_section=ifun)
    1943       370518 :          IF (.NOT. ASSOCIATED(xc_fun)) EXIT
    1944       370518 :          IF (xc_fun%section%name == "GAUXC") THEN
    1945              :             gauxc_section => xc_fun
    1946              :             EXIT
    1947              :          END IF
    1948              :       END DO
    1949              : 
    1950              :    END FUNCTION get_gauxc_section
    1951              : 
    1952              : END MODULE skala_gpw_functional
        

Generated by: LCOV version 2.0-1