LCOV - code coverage report
Current view: top level - src - skala_gpw_features.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:561f475) Lines: 84.9 % 1335 1133
Test Date: 2026-06-21 06:48:54 Functions: 83.3 % 36 30

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Build SKALA TorchScript feature dictionaries from CP2K GPW real-space grids.
      10              : ! **************************************************************************************************
      11              : MODULE skala_gpw_features
      12              :    USE cell_types,                      ONLY: cell_type,&
      13              :                                               pbc
      14              :    USE cp_array_utils,                  ONLY: cp_3d_r_cp_type
      15              :    USE kinds,                           ONLY: dp,&
      16              :                                               int_8
      17              :    USE message_passing,                 ONLY: mp_comm_type
      18              :    USE particle_types,                  ONLY: particle_type
      19              :    USE pw_grid_types,                   ONLY: pw_grid_type
      20              :    USE pw_types,                        ONLY: pw_r3d_rs_type
      21              :    USE torch_api,                       ONLY: &
      22              :         torch_dict_clone, torch_dict_create, torch_dict_insert, torch_dict_release, &
      23              :         torch_dict_type, torch_tensor_expand_dim, torch_tensor_from_array, torch_tensor_narrow, &
      24              :         torch_tensor_release, torch_tensor_reset_from_array, torch_tensor_to_device_leaf, &
      25              :         torch_tensor_type
      26              :    USE xc_rho_set_types,                ONLY: xc_rho_set_get,&
      27              :                                               xc_rho_set_type
      28              : #include "./base/base_uses.f90"
      29              : 
      30              :    IMPLICIT NONE
      31              : 
      32              :    PRIVATE
      33              : 
      34              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_gpw_features'
                           ! layout_tol: small real tolerance; its use is not visible in this section of the
                           ! file (presumably for comparing a cached layout against the current grid/cell --
                           ! TODO confirm).
      35              :    REAL(KIND=dp), PARAMETER, PRIVATE    :: layout_tol = 1.0E-12_dp
                           ! Number of values stored per grid point in the flat "dynamic" buffers:
                           !   ndynamic_per_point     = 10: two densities, two 3-component gradients and two
                           !                            kinetic-energy densities, in that order (see the
                           !                            packing loop of skala_gpw_feature_build).
                           !   nrks_dynamic_per_point = 5: collapsed single-spin layout (one density, one
                           !                            3-component gradient, one kinetic-energy density).
                           !   nstatic_per_point, ngrad_per_point: not used in the visible part of the file.
      36              :    INTEGER, PARAMETER, PRIVATE          :: ndynamic_per_point = 10, nrks_dynamic_per_point = 5, &
      37              :                                            nstatic_per_point = 4, ngrad_per_point = 10
                           ! Public selectors for the atom_partition argument of skala_gpw_feature_build.
                           ! That routine defaults to the hard partition and aborts on any other value.
      38              :    INTEGER, PARAMETER, PUBLIC           :: skala_gpw_atom_partition_hard = 1, &
      39              :                                            skala_gpw_atom_partition_smooth = 2
                           ! smooth_partition_eps: use not visible here (presumably a guard value for the
                           ! smooth partition weights -- TODO confirm).
      40              :    REAL(KIND=dp), PARAMETER, PRIVATE    :: smooth_partition_eps = 1.0E-12_dp
      41              : 
      42              :    PUBLIC :: skala_gpw_atom_subchunk_count, skala_gpw_feature_build, &
      43              :              skala_gpw_feature_build_atom_subchunk, skala_gpw_feature_release, &
      44              :              skala_gpw_feature_type, skala_gpw_smooth_partition_derivatives
      45              : 
      46              :    TYPE skala_gpw_layout_cache_type
                              ! Module-private container for grid-layout data that is kept between calls; the
                              ! module variable cached_layout is declared with this type.
                              ! NOTE(review): only the components read by the visible part of
                              ! skala_gpw_feature_build are described as fact below; everything else is either
                              ! hedged or left to the routines that fill the cache.

                              ! Scalar sizes and ranges. As used by skala_gpw_feature_build:
                              !   nflat               - number of rows of the flat feature arrays
                              !                         (first extent of density/grad/kin).
                              !   npoint              - number of points in the gathered dynamic buffer
                              !                         (buffer length is ndynamic_per_point*npoint).
                              !   chunk_feature_count - must be > 0 for the atom-chunk path to be taken.
                              !   atom_partition      - defaults to skala_gpw_atom_partition_hard.
                              ! The chunk_atom_begin/end defaults (1, 0) and chunk_feature_begin/count
                              ! defaults (1, 0) describe empty ranges.
      47              :       INTEGER                                            :: chunk_atom_begin = 1, chunk_atom_end = 0, &
      48              :                                                             chunk_feature_begin = 1, &
      49              :                                                             chunk_feature_count = 0, chunk_natom = 0, &
      50              :                                                             natom = 0, nflat = 0, nflat_local = 0, &
      51              :                                                             npoint = 0, nproc = 0, &
      52              :                                                             atom_partition = skala_gpw_atom_partition_hard
                              ! Grid bounds and point counts, each (2, 3) / (3); presumably copies of the
                              ! pw_grid bounds for which the cache was built -- TODO confirm.
      53              :       INTEGER, DIMENSION(2, 3)                           :: bo = 0, bounds = 0
      54              :       INTEGER, DIMENSION(3)                              :: npts = 0
                              ! Integer index / count / displacement tables. As used by
                              ! skala_gpw_feature_build:
                              !   dynamic_counts, dynamic_displs - counts and displacements handed to the
                              !                                    allgatherv of the per-point dynamic buffer.
                              !   feature_source_points(row)     - point index in the gathered buffer that
                              !                                    supplies feature row `row`.
                              ! The route_* and chunk_* tables are not read in the visible code (presumably
                              ! send/receive bookkeeping for atom-chunk routing -- TODO confirm).
      55              :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dynamic_counts, dynamic_displs, &
      56              :                                                             chunk_feature_counts, chunk_feature_displs, &
      57              :                                                             chunk_grad_counts, chunk_grad_displs, &
      58              :                                                             feature_counts, feature_displs, &
      59              :                                                             feature_source_points, global_to_feature, &
      60              :                                                             local_feature_counts, local_feature_offsets, &
      61              :                                                             local_feature_points, local_feature_rows, &
      62              :                                                             route_grad_return_recv_counts, &
      63              :                                                             route_grad_return_recv_displs, &
      64              :                                                             route_grad_return_send_counts, &
      65              :                                                             route_grad_return_send_displs, &
      66              :                                                             route_local_dest, chunk_return_positions, &
      67              :                                                             route_point_recv_counts, &
      68              :                                                             route_point_recv_displs, &
      69              :                                                             route_point_send_counts, &
      70              :                                                             route_point_send_displs, &
      71              :                                                             route_send_local_rows
      72              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: feature_index
                              ! 64-bit integer tables (int_8); presumably the host-side data behind the
                              ! like-named *_t torch tensors below -- TODO confirm.
      73              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: atomic_grid_sizes, chunk_atomic_grid_sizes, &
      74              :                                                             chunk_feature_indices
      75              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: local_feature_indices
      76              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :)  :: atomic_grid_size_bound_shape, &
      77              :                                                             chunk_atomic_grid_size_bound_shape
                              ! Torch dictionaries held by the cache (full-grid and chunk variants).
      78              :       TYPE(torch_dict_type)                              :: chunk_inputs
      79              :       TYPE(torch_dict_type)                              :: chunk_static_inputs
      80              :       TYPE(torch_dict_type)                              :: inputs
      81              :       TYPE(torch_dict_type)                              :: static_inputs
                              ! Torch tensor handles held by the cache. The "_t" suffix mirrors the name of
                              ! a host array where one exists; the exact ownership rules live in routines
                              ! outside this section.
      82              :       TYPE(torch_tensor_type)                            :: atomic_grid_size_bound_shape_t
      83              :       TYPE(torch_tensor_type)                            :: atomic_grid_sizes_t
      84              :       TYPE(torch_tensor_type)                            :: atomic_grid_weights_t
      85              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_size_bound_shape_t
      86              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_sizes_t
      87              :       TYPE(torch_tensor_type)                            :: chunk_atomic_grid_weights_t
      88              :       TYPE(torch_tensor_type)                            :: chunk_coarse_0_atomic_coords_t
      89              :       TYPE(torch_tensor_type)                            :: chunk_density_t
      90              :       TYPE(torch_tensor_type)                            :: chunk_density_input_t
      91              :       TYPE(torch_tensor_type)                            :: chunk_feature_indices_t
      92              :       TYPE(torch_tensor_type)                            :: chunk_grad_t
      93              :       TYPE(torch_tensor_type)                            :: chunk_grad_input_t
      94              :       TYPE(torch_tensor_type)                            :: chunk_grid_coords_t
      95              :       TYPE(torch_tensor_type)                            :: chunk_grid_weights_t
      96              :       TYPE(torch_tensor_type)                            :: chunk_kin_t
      97              :       TYPE(torch_tensor_type)                            :: chunk_kin_input_t
      98              :       TYPE(torch_tensor_type)                            :: coarse_0_atomic_coords_t
      99              :       TYPE(torch_tensor_type)                            :: density_t
     100              :       TYPE(torch_tensor_type)                            :: grid_coords_t
     101              :       TYPE(torch_tensor_type)                            :: grid_weights_t
     102              :       TYPE(torch_tensor_type)                            :: grad_t
     103              :       TYPE(torch_tensor_type)                            :: kin_t
     104              :       TYPE(torch_tensor_type)                            :: local_feature_indices_t
                              ! Real scalars and 3x3 matrices, all default 0; presumably the grid volume
                              ! element, weight checksums and cell / grid-step matrices recorded when the
                              ! cache was built -- TODO confirm.
     105              :       REAL(KIND=dp)                                      :: dvol = 0.0_dp, weight_sum = 0.0_dp, &
     106              :                                                             weight_sumsq = 0.0_dp
     107              :       REAL(KIND=dp), DIMENSION(3, 3)                     :: cell_hmat = 0.0_dp, dh = 0.0_dp
                              ! Real work arrays. chunk_grid_weights is the weight vector that
                              ! skala_gpw_feature_build multiplies with the chunk densities when it sums the
                              ! electron count and spin moment on the atom-chunk path.
     108              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: atomic_grid_weights, chunk_atomic_grid_weights, &
     109              :                                                             chunk_grid_weights, grid_weights
     110              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords, chunk_coarse_0_atomic_coords, &
     111              :                                                             chunk_grid_coords, coarse_0_atomic_coords, &
     112              :                                                             grid_coords
                              ! State flags, all .FALSE. by default (presumably record which of the handles
                              ! above are currently live -- confirm against the cache release routine).
     113              :       LOGICAL                                            :: active = .FALSE., has_weights = .FALSE., &
     114              :                                                             chunk_dynamic_input_views_active = .FALSE., &
     115              :                                                             chunk_dynamic_tensors_active = .FALSE., &
     116              :                                                             chunk_inputs_active = .FALSE., &
     117              :                                                             chunk_inputs_use_collapsed_rks = .FALSE., &
     118              :                                                             chunk_static_tensors_active = .FALSE., &
     119              :                                                             dynamic_tensors_active = .FALSE., &
     120              :                                                             inputs_active = .FALSE., &
     121              :                                                             static_tensors_active = .FALSE.
     122              :    END TYPE skala_gpw_layout_cache_type
     123              : 
     124              :    TYPE skala_gpw_feature_type
                              ! Public feature container: the INTENT(INOUT) `features` argument of
                              ! skala_gpw_feature_build, which first passes it to skala_gpw_feature_release
                              ! and then refills it.

                              ! Sizes and partition mode; atom_partition defaults to the hard partition.
     125              :       INTEGER                                            :: chunk_feature_count = 0, nflat = 0, &
     126              :                                                             nflat_local = 0, &
     127              :                                                             atom_partition = skala_gpw_atom_partition_hard
                              ! Torch input dictionary and tensor handles. Whether this object owns them is
                              ! tracked by the owns_* flags at the end of the type (release logic is outside
                              ! this section).
     128              :       TYPE(torch_dict_type)                             :: inputs
     129              :       TYPE(torch_tensor_type)                           :: atomic_grid_size_bound_shape_t
     130              :       TYPE(torch_tensor_type)                           :: atomic_grid_sizes_t
     131              :       TYPE(torch_tensor_type)                           :: atomic_grid_weights_t
     132              :       TYPE(torch_tensor_type)                           :: coarse_0_atomic_coords_t
     133              :       TYPE(torch_tensor_type)                           :: density_input_t
     134              :       TYPE(torch_tensor_type)                           :: density_t
     135              :       TYPE(torch_tensor_type)                           :: grad_t
     136              :       TYPE(torch_tensor_type)                           :: grad_input_t
     137              :       TYPE(torch_tensor_type)                           :: grid_coords_t
     138              :       TYPE(torch_tensor_type)                           :: grid_weights_t
     139              :       TYPE(torch_tensor_type)                           :: kin_input_t
     140              :       TYPE(torch_tensor_type)                           :: kin_t
     141              :       TYPE(torch_tensor_type)                           :: local_feature_indices_t
                              ! Integer count / displacement / index tables; the same names exist in
                              ! skala_gpw_layout_cache_type (presumably copied from the cache by
                              ! copy_cached_layout -- TODO confirm).
     142              :       INTEGER, ALLOCATABLE, DIMENSION(:)                :: chunk_grad_counts, chunk_grad_displs, &
     143              :                                                            local_feature_counts, local_feature_offsets, &
     144              :                                                            local_feature_rows, &
     145              :                                                            chunk_return_positions, &
     146              :                                                            route_grad_return_recv_counts, &
     147              :                                                            route_grad_return_recv_displs, &
     148              :                                                            route_grad_return_send_counts, &
     149              :                                                            route_grad_return_send_displs, &
     150              :                                                            route_point_recv_counts, &
     151              :                                                            route_point_recv_displs, &
     152              :                                                            route_point_send_counts, &
     153              :                                                            route_point_send_displs, &
     154              :                                                            route_send_local_rows
     155              :       INTEGER, ALLOCATABLE, DIMENSION(:, :, :)          :: feature_index
     156              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)    :: atomic_grid_sizes
     157              :       INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:, :) :: atomic_grid_size_bound_shape
     158              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)          :: atomic_grid_weights, grid_weights
                              ! Dynamic feature arrays. On the non-routed path skala_gpw_feature_build
                              ! allocates density(nflat, 2), kin(nflat, 2) and (below) grad(nflat, 3, 2):
                              ! the last index is the spin channel, the middle index of grad the Cartesian
                              ! component. chunk_density / chunk_kin / chunk_grad are the atom-chunk
                              ! counterparts; chunk_density(:, 1) and (:, 2) enter the electron-count and
                              ! spin-moment sums on that path.
     159              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)       :: chunk_density, chunk_kin, &
     160              :                                                            coarse_0_atomic_coords, density, &
     161              :                                                            grid_coords, kin
     162              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)    :: chunk_grad, grad
                              ! Integrated diagnostics filled at the end of skala_gpw_feature_build: the
                              ! weighted sum and the weighted difference of the two spin densities.
     163              :       REAL(KIND=dp)                                      :: electron_count = 0.0_dp, &
     164              :                                                             grid_weight_sum = 0.0_dp, &
     165              :                                                             spin_moment = 0.0_dp
                              ! State flags. Note the mixed defaults: owns_dynamic_tensors, owns_inputs and
                              ! owns_static_tensors start .TRUE., all other flags start .FALSE.
                              ! skala_gpw_feature_build sets uses_atom_chunk_routing and uses_atom_chunks to
                              ! .TRUE. on the routed atom-chunk path and reads uses_collapsed_rks_dynamic to
                              ! choose the single-column electron-count formula.
     166              :       LOGICAL                                            :: active = .FALSE., owns_coordinate_tensor = .FALSE., &
     167              :                                                             owns_grid_coordinate_tensor = .FALSE., &
     168              :                                                             owns_weight_tensors = .FALSE., &
     169              :                                                             owns_dynamic_tensors = .TRUE., &
     170              :                                                             owns_inputs = .TRUE., &
     171              :                                                             owns_static_tensors = .TRUE., &
     172              :                                                             uses_atom_chunk_routing = .FALSE., &
     173              :                                                             uses_atom_chunks = .FALSE., &
     174              :                                                             uses_collapsed_rks_dynamic = .FALSE.
     175              :    END TYPE skala_gpw_feature_type
     176              : 
     177              :    TYPE(skala_gpw_layout_cache_type), SAVE               :: cached_layout
                           ! Module-level layout cache (SAVE, not in the PUBLIC list). skala_gpw_feature_build
                           ! calls ensure_layout_cache and afterwards reads nflat, npoint, chunk_feature_count,
                           ! dynamic_counts/dynamic_displs, feature_source_points and chunk_grid_weights from it.
                           ! NOTE(review): being shared module state, it is presumably not safe to use from
                           ! several threads at once -- confirm with callers.
     178              : 
     179              : CONTAINS
     180              : 
     181              : ! **************************************************************************************************
     182              : !> \brief Build a flat SKALA molecular feature dictionary from a local GPW grid.
     183              : !> \param features ...
     184              : !> \param rho_set ...
     185              : !> \param rho_r ...
     186              : !> \param particle_set ...
     187              : !> \param cell ...
     188              : !> \param requires_grad ...
     189              : !> \param weights ...
     190              : !> \param requires_coordinate_grad ...
     191              : !> \param requires_stress_grad ...
     192              : !> \param use_atom_chunks ...
     193              : !> \param route_atom_chunks ...
     194              : !> \param atom_partition ...
     195              : ! **************************************************************************************************
     196          144 :    SUBROUTINE skala_gpw_feature_build(features, rho_set, rho_r, particle_set, cell, &
     197              :                                       requires_grad, weights, requires_coordinate_grad, &
     198              :                                       requires_stress_grad, use_atom_chunks, route_atom_chunks, &
     199              :                                       atom_partition)
     200              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     201              :       TYPE(xc_rho_set_type), INTENT(IN)                  :: rho_set
     202              :       TYPE(pw_r3d_rs_type), DIMENSION(:), INTENT(IN)     :: rho_r
     203              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     204              :       TYPE(cell_type), POINTER                           :: cell
     205              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_grad
     206              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     207              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_coordinate_grad, &
     208              :                                                             requires_stress_grad, use_atom_chunks, &
     209              :                                                             route_atom_chunks
     210              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     211              : 
     212              :       INTEGER :: handle, i, ipt, ispin, j, k, local_row, my_atom_partition, &
     213              :          ndynamic_local_per_point, nflat, nflat_local, nspins, phase_handle, real_base, row
     214              :       INTEGER, DIMENSION(2, 3)                           :: bo
     215              :       LOGICAL :: can_use_atom_chunks, collapse_spin_dynamics, my_requires_coordinate_grad, &
     216              :          my_requires_grad, my_requires_stress_grad, my_route_atom_chunks, my_use_atom_chunks
     217          144 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: global_dynamic, local_dynamic
     218          144 :       REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: rho, rhoa, rhob, tau_a, tau_b, tau_total
     219         1728 :       TYPE(cp_3d_r_cp_type), DIMENSION(3)                :: drho, drhoa, drhob
     220              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     221              : 
     222          144 :       CALL timeset("skala_gpw_feature_build", handle)
     223              : 
     224          144 :       my_requires_grad = .FALSE.
     225          144 :       IF (PRESENT(requires_grad)) my_requires_grad = requires_grad
     226          144 :       my_requires_coordinate_grad = .FALSE.
     227          144 :       IF (PRESENT(requires_coordinate_grad)) &
     228          144 :          my_requires_coordinate_grad = requires_coordinate_grad
     229          144 :       my_requires_stress_grad = .FALSE.
     230          144 :       IF (PRESENT(requires_stress_grad)) my_requires_stress_grad = requires_stress_grad
     231          144 :       my_use_atom_chunks = .FALSE.
     232          144 :       IF (PRESENT(use_atom_chunks)) my_use_atom_chunks = use_atom_chunks
     233          144 :       my_route_atom_chunks = .FALSE.
     234          144 :       IF (PRESENT(route_atom_chunks)) my_route_atom_chunks = route_atom_chunks
     235          144 :       my_atom_partition = skala_gpw_atom_partition_hard
     236          144 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     237          144 :       IF (my_atom_partition /= skala_gpw_atom_partition_hard .AND. &
     238              :           my_atom_partition /= skala_gpw_atom_partition_smooth) THEN
     239            0 :          CALL cp_abort(__LOCATION__, "Unknown native SKALA atom-partition mode.")
     240              :       END IF
     241          144 :       CPASSERT(ASSOCIATED(cell))
     242          144 :       CPASSERT(ASSOCIATED(particle_set))
     243          144 :       CPASSERT(SIZE(rho_r) == 1 .OR. SIZE(rho_r) == 2)
     244          144 :       CPASSERT(ASSOCIATED(rho_r(1)%pw_grid))
     245          144 :       pw_grid => rho_r(1)%pw_grid
     246              : 
     247          144 :       nspins = SIZE(rho_r)
     248         1440 :       bo = pw_grid%bounds_local
     249          144 :       nflat_local = pw_grid%ngpts_local
     250              : 
     251          144 :       CALL timeset("skala_gpw_pre_release", phase_handle)
     252          144 :       CALL skala_gpw_feature_release(features)
     253          144 :       CALL timestop(phase_handle)
     254              : 
     255          144 :       CALL timeset("skala_gpw_layout_cache", phase_handle)
     256          144 :       CALL ensure_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
     257          144 :       CALL timestop(phase_handle)
     258          144 :       nflat = cached_layout%nflat
     259          144 :       can_use_atom_chunks = my_use_atom_chunks .AND. cached_layout%chunk_feature_count > 0
     260          144 :       IF (my_requires_stress_grad .AND. can_use_atom_chunks) THEN
     261              :          CALL cp_abort(__LOCATION__, &
     262            0 :                        "Native SKALA analytical stress is not implemented with atom-chunk routing yet.")
     263              :       END IF
     264              :       collapse_spin_dynamics = (nspins == 1 .AND. can_use_atom_chunks .AND. &
     265          144 :                                 my_route_atom_chunks)
     266          144 :       ndynamic_local_per_point = ndynamic_per_point
     267          144 :       IF (collapse_spin_dynamics) ndynamic_local_per_point = nrks_dynamic_per_point
     268          432 :       ALLOCATE (local_dynamic(ndynamic_local_per_point*nflat_local))
     269          144 :       local_dynamic = 0.0_dp
     270              : 
     271          144 :       CALL timeset("skala_gpw_pack_local", phase_handle)
     272          144 :       IF (nspins == 1) THEN
     273           90 :          CALL xc_rho_set_get(rho_set, rho=rho, drho=drho, tau=tau_total)
     274              :       ELSE
     275              :          CALL xc_rho_set_get(rho_set, rhoa=rhoa, rhob=rhob, drhoa=drhoa, drhob=drhob, &
     276           54 :                              tau_a=tau_a, tau_b=tau_b)
     277              :       END IF
     278              : 
     279          144 :       local_row = 0
     280         3260 :       DO k = bo(1, 3), bo(2, 3)
     281        86728 :          DO j = bo(1, 2), bo(2, 2)
     282      1453202 :             DO i = bo(1, 1), bo(2, 1)
     283      1366618 :                local_row = local_row + 1
     284      1366618 :                real_base = ndynamic_local_per_point*(local_row - 1)
     285              : 
     286      1450086 :                IF (nspins == 1) THEN
     287      1012243 :                   IF (collapse_spin_dynamics) THEN
     288        64000 :                      local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
     289        64000 :                      local_dynamic(real_base + 2) = 0.5_dp*drho(1)%array(i, j, k)
     290        64000 :                      local_dynamic(real_base + 3) = 0.5_dp*drho(2)%array(i, j, k)
     291        64000 :                      local_dynamic(real_base + 4) = 0.5_dp*drho(3)%array(i, j, k)
     292        64000 :                      local_dynamic(real_base + 5) = 0.5_dp*tau_total(i, j, k)
     293              :                   ELSE
     294       948243 :                      local_dynamic(real_base + 1) = 0.5_dp*rho(i, j, k)
     295       948243 :                      local_dynamic(real_base + 2) = 0.5_dp*rho(i, j, k)
     296      2844729 :                      DO ispin = 1, 2
     297              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 1) = &
     298      1896486 :                            0.5_dp*drho(1)%array(i, j, k)
     299              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 2) = &
     300      1896486 :                            0.5_dp*drho(2)%array(i, j, k)
     301              :                         local_dynamic(real_base + 2 + 3*(ispin - 1) + 3) = &
     302      1896486 :                            0.5_dp*drho(3)%array(i, j, k)
     303      2844729 :                         local_dynamic(real_base + 8 + ispin) = 0.5_dp*tau_total(i, j, k)
     304              :                      END DO
     305              :                   END IF
     306              :                ELSE
     307       354375 :                   local_dynamic(real_base + 1) = rhoa(i, j, k)
     308       354375 :                   local_dynamic(real_base + 2) = rhob(i, j, k)
     309       354375 :                   local_dynamic(real_base + 3) = drhoa(1)%array(i, j, k)
     310       354375 :                   local_dynamic(real_base + 4) = drhoa(2)%array(i, j, k)
     311       354375 :                   local_dynamic(real_base + 5) = drhoa(3)%array(i, j, k)
     312       354375 :                   local_dynamic(real_base + 6) = drhob(1)%array(i, j, k)
     313       354375 :                   local_dynamic(real_base + 7) = drhob(2)%array(i, j, k)
     314       354375 :                   local_dynamic(real_base + 8) = drhob(3)%array(i, j, k)
     315       354375 :                   local_dynamic(real_base + 9) = tau_a(i, j, k)
     316       354375 :                   local_dynamic(real_base + 10) = tau_b(i, j, k)
     317              :                END IF
     318              :             END DO
     319              :          END DO
     320              :       END DO
     321          144 :       CALL timestop(phase_handle)
     322              : 
     323          144 :       CALL timeset("skala_gpw_copy_layout", phase_handle)
     324              :       CALL copy_cached_layout(features, my_requires_coordinate_grad .OR. my_requires_stress_grad, &
     325              :                               my_requires_stress_grad .OR. &
     326              :                               (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
     327          278 :                                (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
     328          144 :       CALL timestop(phase_handle)
     329              : 
     330          144 :       IF (can_use_atom_chunks .AND. my_route_atom_chunks) THEN
     331            2 :          CALL timeset("skala_gpw_route_dyn", phase_handle)
     332              :          CALL route_atom_chunk_dynamics(features, local_dynamic, pw_grid%para%group, &
     333            2 :                                         collapse_spin_dynamics)
     334            2 :          features%uses_atom_chunk_routing = .TRUE.
     335            2 :          features%uses_atom_chunks = .TRUE.
     336            2 :          CALL timestop(phase_handle)
     337              :       ELSE
     338          426 :          ALLOCATE (global_dynamic(ndynamic_per_point*cached_layout%npoint))
     339          142 :          CALL timeset("skala_gpw_allgatherv", phase_handle)
     340              :          CALL pw_grid%para%group%allgatherv(local_dynamic, global_dynamic, &
     341              :                                             cached_layout%dynamic_counts, &
     342          142 :                                             cached_layout%dynamic_displs)
     343          142 :          CALL timestop(phase_handle)
     344              : 
     345          142 :          CALL timeset("skala_gpw_reorder_dyn", phase_handle)
     346            0 :          ALLOCATE (features%density(nflat, 2), features%grad(nflat, 3, 2), &
     347          994 :                    features%kin(nflat, 2))
     348      5485978 :          features%density = 0.0_dp
     349     16457934 :          features%grad = 0.0_dp
     350      5485978 :          features%kin = 0.0_dp
     351              : 
     352      2742918 :          DO row = 1, nflat
     353      2742776 :             ipt = cached_layout%feature_source_points(row)
     354      2742776 :             real_base = ndynamic_per_point*(ipt - 1)
     355      8228328 :             features%density(row, :) = global_dynamic(real_base + 1:real_base + 2)
     356      2742776 :             features%grad(row, 1, 1) = global_dynamic(real_base + 3)
     357      2742776 :             features%grad(row, 2, 1) = global_dynamic(real_base + 4)
     358      2742776 :             features%grad(row, 3, 1) = global_dynamic(real_base + 5)
     359      2742776 :             features%grad(row, 1, 2) = global_dynamic(real_base + 6)
     360      2742776 :             features%grad(row, 2, 2) = global_dynamic(real_base + 7)
     361      2742776 :             features%grad(row, 3, 2) = global_dynamic(real_base + 8)
     362      8228470 :             features%kin(row, :) = global_dynamic(real_base + 9:real_base + 10)
     363              :          END DO
     364          426 :          CALL timestop(phase_handle)
     365              :       END IF
     366              : 
     367          144 :       CALL timeset("skala_gpw_feature_sums", phase_handle)
     368          144 :       IF (features%uses_atom_chunks) THEN
     369            2 :          IF (features%uses_collapsed_rks_dynamic) THEN
     370              :             features%electron_count = SUM(2.0_dp*features%chunk_density(:, 1)* &
     371        64002 :                                           cached_layout%chunk_grid_weights)
     372            2 :             features%spin_moment = 0.0_dp
     373              :          ELSE
     374              :             features%electron_count = SUM((features%chunk_density(:, 1) + &
     375              :                                            features%chunk_density(:, 2))* &
     376            0 :                                           cached_layout%chunk_grid_weights)
     377              :             features%spin_moment = SUM((features%chunk_density(:, 1) - &
     378              :                                         features%chunk_density(:, 2))* &
     379            0 :                                        cached_layout%chunk_grid_weights)
     380              :          END IF
     381            2 :          CALL pw_grid%para%group%sum(features%electron_count)
     382            2 :          CALL pw_grid%para%group%sum(features%spin_moment)
     383              :       ELSE
     384              :          features%electron_count = SUM((features%density(:, 1) + features%density(:, 2))* &
     385      2742918 :                                        features%grid_weights)
     386              :          features%spin_moment = SUM((features%density(:, 1) - features%density(:, 2))* &
     387      2742918 :                                     features%grid_weights)
     388              :       END IF
     389      2870920 :       features%grid_weight_sum = SUM(features%grid_weights)
     390          144 :       CALL timestop(phase_handle)
     391              : 
     392          144 :       CALL timeset("skala_gpw_tensor_update", phase_handle)
     393          144 :       IF (can_use_atom_chunks .AND. .NOT. features%uses_atom_chunks) THEN
     394            0 :          CALL extract_atom_chunk_dynamics(features)
     395            0 :          features%uses_atom_chunks = .TRUE.
     396              :       END IF
     397              :       CALL add_feature_tensors(features, my_requires_grad, my_requires_coordinate_grad, &
     398              :                                my_requires_stress_grad, &
     399              :                                features%uses_atom_chunks, &
     400              :                                requires_weight_grad= &
     401              :                                (my_atom_partition == skala_gpw_atom_partition_smooth .AND. &
     402          284 :                                 (my_requires_coordinate_grad .OR. my_requires_stress_grad)))
     403          144 :       CALL timestop(phase_handle)
     404          144 :       features%active = .TRUE.
     405              : 
     406          144 :       IF (ALLOCATED(global_dynamic)) DEALLOCATE (global_dynamic)
     407          144 :       DEALLOCATE (local_dynamic)
     408          144 :       CALL timestop(handle)
     409              : 
     410         1152 :    END SUBROUTINE skala_gpw_feature_build
     411              : 
     412              : ! **************************************************************************************************
     413              : !> \brief Ensure that static grid-to-atom layout data is cached for the current grid/geometry: keep the cache when layout_cache_matches reports a match, otherwise call rebuild_layout_cache.
     414              : !> \param pw_grid grid descriptor forwarded to layout_cache_matches and rebuild_layout_cache
     415              : !> \param particle_set atoms forwarded to the match check and to the rebuild
     416              : !> \param cell cell forwarded to the match check and to the rebuild
     417              : !> \param weights optional integration weights; forwarded to the callees only when present
     418              : !> \param atom_partition optional partition selector; skala_gpw_atom_partition_hard is used when absent
     419              : ! **************************************************************************************************
     420          144 :    SUBROUTINE ensure_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
     421              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     422              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     423              :       TYPE(cell_type), POINTER                           :: cell
     424              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     425              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     426              : 
     427              :       INTEGER                                            :: my_atom_partition, phase_handle
     428              :       LOGICAL                                            :: cache_matches
     429              : 
     430          144 :       my_atom_partition = skala_gpw_atom_partition_hard  ! default used when atom_partition is absent
     431          144 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     432          144 :       IF (PRESENT(weights)) THEN  ! weights supplied: pass them on to both callees
     433          144 :          CALL timeset("skala_gpw_layout_match", phase_handle)
     434              :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell, weights, &
     435          144 :                                               my_atom_partition)
     436          144 :          CALL timestop(phase_handle)  ! timer is closed before the early return below
     437          144 :          IF (cache_matches) RETURN  ! cache is reusable: nothing to rebuild
     438           52 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     439           52 :          CALL rebuild_layout_cache(pw_grid, particle_set, cell, weights, my_atom_partition)
     440           52 :          CALL timestop(phase_handle)
     441              :       ELSE  ! weights absent: same match/rebuild sequence, called without the weights argument
     442            0 :          CALL timeset("skala_gpw_layout_match", phase_handle)  ! NOTE(review): an absent OPTIONAL may be forwarded to an OPTIONAL dummy, so this branch looks mergeable with the one above - confirm
     443              :          cache_matches = layout_cache_matches(pw_grid, particle_set, cell, &
     444            0 :                                               atom_partition=my_atom_partition)
     445            0 :          CALL timestop(phase_handle)
     446            0 :          IF (cache_matches) RETURN  ! cache is reusable: nothing to rebuild
     447            0 :          CALL timeset("skala_gpw_layout_rebuild", phase_handle)
     448              :          CALL rebuild_layout_cache(pw_grid, particle_set, cell, &
     449            0 :                                    atom_partition=my_atom_partition)
     450            0 :          CALL timestop(phase_handle)
     451              :       END IF
     452              : 
     453              :    END SUBROUTINE ensure_layout_cache
     454              : 
     455              : ! **************************************************************************************************
     456              : !> \brief Check whether the current static layout cache can be reused, by comparing the module-level cached_layout against the grid, cell, atoms and optional weights passed in.
     457              : !> \param pw_grid grid whose ngpts_local, para%group%num_pe, bounds_local, bounds, npts, dvol and dh are compared with the cached values
     458              : !> \param particle_set atoms; their number and each position r are compared with the cached natom and atom_coords
     459              : !> \param cell cell whose hmat is compared with the cached cell_hmat
     460              : !> \param weights optional integration weights; forwarded to layout_weights_match when present
     461              : !> \param atom_partition optional partition selector; skala_gpw_atom_partition_hard is used when absent
     462              : !> \return .TRUE. only when every comparison below passes; .FALSE. on the first mismatch
     463              : ! **************************************************************************************************
     464          144 :    FUNCTION layout_cache_matches(pw_grid, particle_set, cell, weights, atom_partition) RESULT(matches)
     465              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     466              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     467              :       TYPE(cell_type), POINTER                           :: cell
     468              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     469              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     470              :       LOGICAL                                            :: matches
     471              : 
     472              :       INTEGER                                            :: iatom, my_atom_partition
     473              :       LOGICAL                                            :: weights_match
     474              : 
     475          144 :       my_atom_partition = skala_gpw_atom_partition_hard  ! default used when atom_partition is absent
     476          144 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     477          144 :       matches = .FALSE.  ! every early RETURN below therefore reports a mismatch
     478          144 :       IF (.NOT. cached_layout%active) RETURN  ! cache is not active
     479          100 :       IF (cached_layout%atom_partition /= my_atom_partition) RETURN
     480          100 :       IF (cached_layout%natom /= SIZE(particle_set)) RETURN
     481          100 :       IF (cached_layout%nflat_local /= pw_grid%ngpts_local) RETURN
     482          100 :       IF (cached_layout%nproc /= pw_grid%para%group%num_pe) RETURN
     483         1000 :       IF (ANY(cached_layout%bo /= pw_grid%bounds_local)) RETURN  ! integer quantities are compared exactly
     484         1000 :       IF (ANY(cached_layout%bounds /= pw_grid%bounds)) RETURN
     485          400 :       IF (ANY(cached_layout%npts /= pw_grid%npts)) RETURN
     486          100 :       IF (ABS(cached_layout%dvol - pw_grid%dvol) > layout_tol) RETURN  ! real quantities are compared within layout_tol
     487         1300 :       IF (ANY(ABS(cached_layout%dh - pw_grid%dh) > layout_tol)) RETURN
     488         1300 :       IF (ANY(ABS(cached_layout%cell_hmat - cell%hmat) > layout_tol)) RETURN
     489          100 :       IF (.NOT. ALLOCATED(cached_layout%atom_coords)) RETURN  ! guard before indexing atom_coords in the loop below
     490              : 
     491          292 :       DO iatom = 1, SIZE(particle_set)  ! any coordinate differing by more than layout_tol is a mismatch
     492          884 :          IF (ANY(ABS(cached_layout%atom_coords(:, iatom) - particle_set(iatom)%r) > layout_tol)) RETURN
     493              :       END DO
     494              : 
     495           92 :       IF (PRESENT(weights)) THEN  ! weights are checked last, via layout_weights_match
     496           92 :          weights_match = layout_weights_match(pw_grid, weights)
     497              :       ELSE
     498            0 :          weights_match = layout_weights_match(pw_grid)
     499              :       END IF
     500           92 :       IF (.NOT. weights_match) RETURN
     501              : 
     502          144 :       matches = .TRUE.  ! all checks passed
     503              : 
     504              :    END FUNCTION layout_cache_matches
     505              : 
     506              : ! **************************************************************************************************
     507              : !> \brief Check whether current optional integration weights match the cached static tensors, by comparing the has_weights flag and the weight_sum / weight_sumsq values returned by weights_signature with the cached ones.
     508              : !> \param pw_grid not used in this function (only passed to MARK_USED). NOTE(review): no communicator is used here, so if weights_signature is rank-local the result may differ between ranks - confirm
     509              : !> \param weights optional integration weights; forwarded to weights_signature when present
     510              : !> \return .TRUE. when has_weights agrees with the cache and both signature values are within layout_tol of the cached values
     511              : ! **************************************************************************************************
     512           92 :    FUNCTION layout_weights_match(pw_grid, weights) RESULT(matches)
     513              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     514              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     515              :       LOGICAL                                            :: matches
     516              : 
     517              :       LOGICAL                                            :: has_weights
     518              :       REAL(KIND=dp)                                      :: weight_sum, weight_sumsq
     519              : 
     520           92 :       matches = .FALSE.  ! every early RETURN below therefore reports a mismatch
     521              :       MARK_USED(pw_grid)
     522           92 :       IF (PRESENT(weights)) THEN  ! obtain the signature of the current weights, with or without the optional argument
     523           92 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     524              :       ELSE
     525              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     526            0 :                                 weight_sumsq=weight_sumsq)
     527              :       END IF
     528              : 
     529           92 :       IF (cached_layout%has_weights .NEQV. has_weights) RETURN  ! presence flag differs from the cache
     530           92 :       IF (ABS(cached_layout%weight_sum - weight_sum) > layout_tol) RETURN  ! NOTE(review): only two scalars are compared, so distinct weights with equal values within layout_tol count as a match - confirm this is acceptable
     531           92 :       IF (ABS(cached_layout%weight_sumsq - weight_sumsq) > layout_tol) RETURN
     532              : 
     533           92 :       matches = .TRUE.  ! flag and both signature values agree with the cache
     534              : 
     535              :    END FUNCTION layout_weights_match
     536              : 
     537              : ! **************************************************************************************************
     538              : !> \brief Build the static SKALA layout cache.
     539              : !> \param pw_grid ...
     540              : !> \param particle_set ...
     541              : !> \param cell ...
     542              : !> \param weights ...
     543              : !> \param atom_partition ...
     544              : ! **************************************************************************************************
     545           52 :    SUBROUTINE rebuild_layout_cache(pw_grid, particle_set, cell, weights, atom_partition)
     546              :       TYPE(pw_grid_type), POINTER                        :: pw_grid
     547              :       TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
     548              :       TYPE(cell_type), POINTER                           :: cell
     549              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
     550              :       INTEGER, INTENT(IN), OPTIONAL                      :: atom_partition
     551              : 
     552              :       INTEGER :: feature_local, i, iatom, ipt, j, k, local_row, max_grid_size, max_local_features, &
     553              :          my_atom_partition, natom, nfeature_local, nflat, nflat_local, npoint, nproc, owner, pe, &
     554              :          pe_index, phase_handle, row, source_global, source_local, static_base
     555           52 :       INTEGER, ALLOCATABLE, DIMENSION(:) :: atom_offset, atom_position, chunk_atom_begin, &
     556           52 :          chunk_atom_end, cursor, feature_counts, feature_displs, global_owner, &
     557           52 :          global_source_points, local_feature_counts_tmp, local_owner, local_source_global, &
     558           52 :          local_source_points, point_counts, point_displs, static_counts, static_displs
     559              :       INTEGER, DIMENSION(2, 3)                           :: bo
     560              :       LOGICAL                                            :: has_weights
     561              :       REAL(KIND=dp)                                      :: base_weight, included_sum, &
     562              :                                                             partition_weight, weight_sum, &
     563              :                                                             weight_sumsq
     564           52 :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: distances, global_static, local_static, &
     565              :                                                             partition_weights
     566              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: atom_coords_pbc, atom_image_coords
     567              :       REAL(KIND=dp), DIMENSION(3)                        :: grid_point, owner_coord
     568              : 
     569           52 :       CALL release_layout_cache(cached_layout)
     570              : 
     571           52 :       my_atom_partition = skala_gpw_atom_partition_hard
     572           52 :       IF (PRESENT(atom_partition)) my_atom_partition = atom_partition
     573           52 :       natom = SIZE(particle_set)
     574          520 :       bo = pw_grid%bounds_local
     575           52 :       nflat_local = pw_grid%ngpts_local
     576           52 :       nproc = pw_grid%para%group%num_pe
     577           52 :       pe_index = pw_grid%para%group%mepos + 1
     578              : 
     579           52 :       IF (PRESENT(weights)) THEN
     580           52 :          CALL weights_signature(weights, has_weights, weight_sum, weight_sumsq)
     581              :       ELSE
     582              :          CALL weights_signature(has_weights=has_weights, weight_sum=weight_sum, &
     583            0 :                                 weight_sumsq=weight_sumsq)
     584              :       END IF
     585              : 
     586           52 :       max_local_features = nflat_local
     587           52 :       IF (my_atom_partition == skala_gpw_atom_partition_smooth) &
     588            6 :          max_local_features = nflat_local*natom
     589              :       ALLOCATE (local_owner(max_local_features), &
     590              :                 local_source_points(max_local_features), &
     591              :                 local_static(nstatic_per_point*max_local_features), &
     592              :                 local_feature_counts_tmp(nflat_local), feature_counts(nproc), &
     593              :                 feature_displs(nproc), point_counts(nproc), point_displs(nproc), &
     594              :                 static_counts(nproc), static_displs(nproc), atom_coords_pbc(3, natom), &
     595         1092 :                 atom_image_coords(3, natom), distances(natom), partition_weights(natom))
     596            0 :       ALLOCATE (cached_layout%feature_index(bo(1, 1):bo(2, 1), &
     597              :                                             bo(1, 2):bo(2, 2), &
     598          260 :                                             bo(1, 3):bo(2, 3)))
     599      1208910 :       cached_layout%feature_index = 0
     600           52 :       local_static = 0.0_dp
     601           52 :       local_feature_counts_tmp = 0
     602          184 :       DO iatom = 1, natom
     603          184 :          atom_coords_pbc(:, iatom) = pbc(particle_set(iatom)%r, cell, positive_range=.TRUE.)
     604              :       END DO
     605              : 
     606           52 :       CALL timeset("skala_gpw_layout_local", phase_handle)
     607           52 :       local_row = 0
     608           52 :       nfeature_local = 0
     609         1680 :       DO k = bo(1, 3), bo(2, 3)
     610        60236 :          DO j = bo(1, 2), bo(2, 2)
     611      1208858 :             DO i = bo(1, 1), bo(2, 1)
     612      1148674 :                local_row = local_row + 1
     613      4594696 :                grid_point = grid_coordinate(pw_grid, [i, j, k])
     614      1148674 :                base_weight = pw_grid%dvol
     615      1148674 :                IF (PRESENT(weights)) THEN
     616      1148674 :                   IF (ASSOCIATED(weights)) base_weight = base_weight*weights%array(i, j, k)
     617              :                END IF
     618      1148674 :                cached_layout%feature_index(i, j, k) = local_row
     619              : 
     620      1207230 :                IF (my_atom_partition == skala_gpw_atom_partition_hard) THEN
     621      1107202 :                   owner = nearest_atom(grid_point, atom_coords_pbc, cell)
     622      4428808 :                   owner_coord = atom_coords_pbc(:, owner)
     623      1107202 :                   nfeature_local = nfeature_local + 1
     624      1107202 :                   local_feature_counts_tmp(local_row) = 1
     625      1107202 :                   local_owner(nfeature_local) = owner
     626      1107202 :                   local_source_points(nfeature_local) = local_row
     627      1107202 :                   static_base = nstatic_per_point*(nfeature_local - 1)
     628              :                   local_static(static_base + 1:static_base + 3) = &
     629      1107202 :                      nearest_image_coordinate(owner_coord, grid_point, cell)
     630      1107202 :                   local_static(static_base + 4) = base_weight
     631              :                ELSE
     632              :                   CALL smooth_atom_partition(grid_point, atom_coords_pbc, cell, &
     633        41472 :                                              partition_weights, atom_image_coords, distances)
     634       124416 :                   included_sum = SUM(partition_weights, MASK=partition_weights > smooth_partition_eps)
     635        41472 :                   IF (included_sum <= 0.0_dp) THEN
     636            0 :                      owner = nearest_atom(grid_point, atom_coords_pbc, cell)
     637            0 :                      partition_weights = 0.0_dp
     638            0 :                      partition_weights(owner) = 1.0_dp
     639            0 :                      included_sum = 1.0_dp
     640              :                   END IF
     641       124416 :                   DO iatom = 1, natom
     642        82944 :                      IF (partition_weights(iatom) <= smooth_partition_eps) CYCLE
     643        82734 :                      partition_weight = partition_weights(iatom)/included_sum
     644        82734 :                      nfeature_local = nfeature_local + 1
     645              :                      local_feature_counts_tmp(local_row) = &
     646        82734 :                         local_feature_counts_tmp(local_row) + 1
     647        82734 :                      local_owner(nfeature_local) = iatom
     648        82734 :                      local_source_points(nfeature_local) = local_row
     649        82734 :                      static_base = nstatic_per_point*(nfeature_local - 1)
     650              :                      local_static(static_base + 1:static_base + 3) = &
     651       330936 :                         atom_image_coords(:, iatom)
     652       124416 :                      local_static(static_base + 4) = base_weight*partition_weight
     653              :                   END DO
     654              :                END IF
     655              :             END DO
     656              :          END DO
     657              :       END DO
     658           52 :       CALL timestop(phase_handle)
     659              : 
     660              :       ! SKALA groups all grid points by atom. This ordering is static while the
     661              :       ! grid, cell, atom positions, and optional integration weights are unchanged.
     662           52 :       CALL timeset("skala_gpw_layout_gather", phase_handle)
     663           52 :       CALL pw_grid%para%group%allgather(nflat_local, point_counts)
     664           52 :       CALL counts_to_displs(point_counts, point_displs)
     665          156 :       npoint = SUM(point_counts)
     666           52 :       CALL pw_grid%para%group%allgather(nfeature_local, feature_counts)
     667           52 :       CALL counts_to_displs(feature_counts, feature_displs)
     668          156 :       DO pe = 1, nproc
     669          104 :          static_counts(pe) = nstatic_per_point*feature_counts(pe)
     670          156 :          static_displs(pe) = nstatic_per_point*feature_displs(pe)
     671              :       END DO
     672          156 :       nflat = SUM(feature_counts)
     673              :       ALLOCATE (global_owner(nflat), global_source_points(nflat), &
     674          416 :                 global_static(nstatic_per_point*nflat), local_source_global(nfeature_local))
     675      1189988 :       DO feature_local = 1, nfeature_local
     676      1189988 :          local_source_global(feature_local) = point_displs(pe_index) + local_source_points(feature_local)
     677              :       END DO
     678              :       CALL pw_grid%para%group%allgatherv(local_owner(1:nfeature_local), global_owner, feature_counts, &
     679           52 :                                          feature_displs)
     680              :       CALL pw_grid%para%group%allgatherv(local_source_global, global_source_points, feature_counts, &
     681           52 :                                          feature_displs)
     682              :       CALL pw_grid%para%group%allgatherv(local_static(1:nstatic_per_point*nfeature_local), &
     683              :                                          global_static, static_counts, &
     684           52 :                                          static_displs)
     685           52 :       CALL timestop(phase_handle)
     686              : 
     687            0 :       ALLOCATE (cached_layout%chunk_feature_counts(nproc), &
     688            0 :                 cached_layout%chunk_feature_displs(nproc), &
     689            0 :                 cached_layout%chunk_grad_counts(nproc), cached_layout%chunk_grad_displs(nproc), &
     690            0 :                 cached_layout%feature_counts(nproc), cached_layout%feature_displs(nproc), &
     691            0 :                 cached_layout%dynamic_counts(nproc), cached_layout%dynamic_displs(nproc), &
     692            0 :                 cached_layout%route_grad_return_recv_counts(nproc), &
     693            0 :                 cached_layout%route_grad_return_recv_displs(nproc), &
     694            0 :                 cached_layout%route_grad_return_send_counts(nproc), &
     695            0 :                 cached_layout%route_grad_return_send_displs(nproc), &
     696            0 :                 cached_layout%route_point_recv_counts(nproc), &
     697            0 :                 cached_layout%route_point_recv_displs(nproc), &
     698            0 :                 cached_layout%route_point_send_counts(nproc), &
     699            0 :                 cached_layout%route_point_send_displs(nproc), &
     700            0 :                 cached_layout%feature_source_points(nflat), &
     701            0 :                 cached_layout%global_to_feature(npoint), cached_layout%atomic_grid_sizes(natom), &
     702            0 :                 cached_layout%local_feature_counts(nflat_local), &
     703            0 :                 cached_layout%local_feature_offsets(nflat_local + 1), &
     704            0 :                 cached_layout%local_feature_rows(nfeature_local), &
     705            0 :                 cached_layout%local_feature_points(nfeature_local), &
     706            0 :                 cached_layout%local_feature_indices(nfeature_local), atom_offset(natom + 1), &
     707              :                 atom_position(natom), chunk_atom_begin(nproc), chunk_atom_end(nproc), &
     708         1820 :                 cursor(nflat_local))
     709          156 :       cached_layout%feature_counts(:) = feature_counts
     710          156 :       cached_layout%feature_displs(:) = feature_displs
     711          156 :       cached_layout%dynamic_counts(:) = ndynamic_per_point*point_counts
     712          156 :       cached_layout%dynamic_displs(:) = ndynamic_per_point*point_displs
     713          184 :       cached_layout%atomic_grid_sizes = 0_int_8
     714      2297400 :       cached_layout%global_to_feature = 0
     715      1148726 :       cached_layout%local_feature_counts(:) = local_feature_counts_tmp
     716           52 :       cached_layout%local_feature_offsets(1) = 1
     717      1148726 :       DO local_row = 1, nflat_local
     718              :          cached_layout%local_feature_offsets(local_row + 1) = &
     719              :             cached_layout%local_feature_offsets(local_row) + &
     720      1148726 :             cached_layout%local_feature_counts(local_row)
     721              :       END DO
     722      1148726 :       cursor(:) = cached_layout%local_feature_offsets(1:nflat_local)
     723              : 
     724           52 :       CALL timeset("skala_gpw_layout_atom_sort", phase_handle)
     725      2379924 :       DO ipt = 1, nflat
     726              :          cached_layout%atomic_grid_sizes(global_owner(ipt)) = &
     727      2379924 :             cached_layout%atomic_grid_sizes(global_owner(ipt)) + 1_int_8
     728              :       END DO
     729           52 :       atom_offset(1) = 1
     730          184 :       DO iatom = 1, natom
     731          184 :          atom_offset(iatom + 1) = atom_offset(iatom) + INT(cached_layout%atomic_grid_sizes(iatom))
     732              :       END DO
     733          184 :       DO iatom = 1, natom
     734          184 :          atom_position(iatom) = atom_offset(iatom)
     735              :       END DO
     736          184 :       max_grid_size = MAXVAL(INT(cached_layout%atomic_grid_sizes))
     737              :       CALL build_atom_chunks(cached_layout%atomic_grid_sizes, atom_offset, nproc, &
     738              :                              chunk_atom_begin, chunk_atom_end, &
     739              :                              cached_layout%chunk_feature_counts, &
     740           52 :                              cached_layout%chunk_feature_displs)
     741          156 :       cached_layout%chunk_grad_counts(:) = ngrad_per_point*cached_layout%chunk_feature_counts
     742          156 :       cached_layout%chunk_grad_displs(:) = ngrad_per_point*cached_layout%chunk_feature_displs
     743           52 :       cached_layout%chunk_atom_begin = chunk_atom_begin(pe_index)
     744           52 :       cached_layout%chunk_atom_end = chunk_atom_end(pe_index)
     745           52 :       cached_layout%chunk_feature_begin = cached_layout%chunk_feature_displs(pe_index) + 1
     746           52 :       cached_layout%chunk_feature_count = cached_layout%chunk_feature_counts(pe_index)
     747              :       cached_layout%chunk_natom = cached_layout%chunk_atom_end - &
     748           52 :                                   cached_layout%chunk_atom_begin + 1
     749              : 
     750            0 :       ALLOCATE (cached_layout%grid_coords(3, nflat), cached_layout%grid_weights(nflat), &
     751            0 :                 cached_layout%atomic_grid_weights(nflat), &
     752            0 :                 cached_layout%coarse_0_atomic_coords(3, natom), &
     753            0 :                 cached_layout%atomic_grid_size_bound_shape(0, max_grid_size), &
     754          468 :                 cached_layout%atom_coords(3, natom))
     755      9519540 :       cached_layout%grid_coords = 0.0_dp
     756      2379924 :       cached_layout%grid_weights = 0.0_dp
     757      2379924 :       cached_layout%atomic_grid_weights = 0.0_dp
     758       980456 :       cached_layout%atomic_grid_size_bound_shape = 0_int_8
     759              : 
     760          184 :       DO iatom = 1, natom
     761          528 :          cached_layout%atom_coords(:, iatom) = particle_set(iatom)%r
     762          580 :          cached_layout%coarse_0_atomic_coords(:, iatom) = atom_coords_pbc(:, iatom)
     763              :       END DO
     764              : 
     765      2379924 :       DO ipt = 1, nflat
     766      2379872 :          owner = global_owner(ipt)
     767      2379872 :          row = atom_position(owner)
     768      2379872 :          atom_position(owner) = atom_position(owner) + 1
     769      2379872 :          source_global = global_source_points(ipt)
     770      2379872 :          cached_layout%feature_source_points(row) = source_global
     771      2379872 :          IF (cached_layout%global_to_feature(source_global) == 0) &
     772      2297348 :             cached_layout%global_to_feature(source_global) = row
     773      2379872 :          static_base = nstatic_per_point*(ipt - 1)
     774      9519488 :          cached_layout%grid_coords(:, row) = global_static(static_base + 1:static_base + 3)
     775      2379872 :          cached_layout%grid_weights(row) = global_static(static_base + 4)
     776      2379872 :          cached_layout%atomic_grid_weights(row) = cached_layout%grid_weights(row)
     777      2379872 :          source_local = source_global - point_displs(pe_index)
     778      2379924 :          IF (source_local >= 1 .AND. source_local <= nflat_local) THEN
     779      1189936 :             feature_local = cursor(source_local)
     780      1189936 :             cursor(source_local) = cursor(source_local) + 1
     781      1189936 :             cached_layout%local_feature_rows(feature_local) = row
     782      1189936 :             cached_layout%local_feature_points(feature_local) = source_local
     783              :          END IF
     784              :       END DO
     785              : 
     786      2297400 :       CPASSERT(ALL(cached_layout%global_to_feature > 0))
     787      1189988 :       CPASSERT(ALL(cached_layout%local_feature_rows > 0))
     788      1189988 :       CPASSERT(ALL(cached_layout%local_feature_points > 0))
     789         1680 :       DO k = bo(1, 3), bo(2, 3)
     790        60236 :          DO j = bo(1, 2), bo(2, 2)
     791      1208858 :             DO i = bo(1, 1), bo(2, 1)
     792      1148674 :                local_row = cached_layout%feature_index(i, j, k)
     793              :                cached_layout%feature_index(i, j, k) = &
     794      1207230 :                   cached_layout%local_feature_rows(cached_layout%local_feature_offsets(local_row))
     795              :             END DO
     796              :          END DO
     797              :       END DO
     798      1189988 :       DO feature_local = 1, nfeature_local
     799              :          cached_layout%local_feature_indices(feature_local) = &
     800      1189988 :             INT(cached_layout%local_feature_rows(feature_local) - 1, KIND=int_8)
     801              :       END DO
     802           52 :       CALL timestop(phase_handle)
     803           52 :       CALL timeset("skala_gpw_layout_chunk_routes", phase_handle)
     804              :       CALL build_atom_chunk_routes(cached_layout, cached_layout%local_feature_rows, &
     805           52 :                                    pw_grid%para%group)
     806           52 :       CALL build_atom_chunk_layout(cached_layout)
     807           52 :       CALL timestop(phase_handle)
     808              : 
     809           52 :       cached_layout%natom = natom
     810           52 :       cached_layout%nflat = nflat
     811           52 :       cached_layout%nflat_local = nflat_local
     812           52 :       cached_layout%npoint = npoint
     813           52 :       cached_layout%nproc = nproc
     814           52 :       cached_layout%atom_partition = my_atom_partition
     815          520 :       cached_layout%bo = bo
     816          520 :       cached_layout%bounds = pw_grid%bounds
     817          208 :       cached_layout%npts = pw_grid%npts
     818           52 :       cached_layout%dvol = pw_grid%dvol
     819          676 :       cached_layout%dh = pw_grid%dh
     820          676 :       cached_layout%cell_hmat = cell%hmat
     821           52 :       cached_layout%weight_sum = weight_sum
     822           52 :       cached_layout%weight_sumsq = weight_sumsq
     823           52 :       cached_layout%has_weights = has_weights
     824           52 :       CALL timeset("skala_gpw_layout_tensors", phase_handle)
     825           52 :       CALL build_static_layout_tensors(cached_layout)
     826           52 :       CALL timestop(phase_handle)
     827           52 :       cached_layout%active = .TRUE.
     828              : 
     829            0 :       DEALLOCATE (atom_coords_pbc, atom_image_coords, atom_offset, atom_position, &
     830            0 :                   chunk_atom_begin, chunk_atom_end, cursor, feature_counts, feature_displs, &
     831            0 :                   global_owner, global_source_points, global_static, local_feature_counts_tmp, &
     832            0 :                   distances, local_owner, local_source_global, local_source_points, &
     833            0 :                   local_static, partition_weights, point_counts, point_displs, static_counts, &
     834           52 :                   static_displs)
     835              : 
     836          260 :    END SUBROUTINE rebuild_layout_cache
     837              : 
     838              : ! **************************************************************************************************
     839              : !> \brief Build cached Torch tensors for static SKALA inputs.
     840              : !> \param cache ...
     841              : ! **************************************************************************************************
     842           52 :    SUBROUTINE build_static_layout_tensors(cache)
     843              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
     844              : 
     845           52 :       CPASSERT(.NOT. cache%static_tensors_active)
     846              : 
     847           52 :       CALL torch_tensor_from_array(cache%grid_coords_t, cache%grid_coords)
     848           52 :       CALL torch_tensor_to_device_leaf(cache%grid_coords_t, .FALSE.)
     849           52 :       CALL torch_tensor_from_array(cache%grid_weights_t, cache%grid_weights)
     850           52 :       CALL torch_tensor_to_device_leaf(cache%grid_weights_t, .FALSE.)
     851           52 :       CALL torch_tensor_from_array(cache%atomic_grid_weights_t, cache%atomic_grid_weights)
     852           52 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_weights_t, .FALSE.)
     853           52 :       CALL torch_tensor_from_array(cache%atomic_grid_sizes_t, cache%atomic_grid_sizes)
     854           52 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_sizes_t, .FALSE.)
     855           52 :       CALL torch_tensor_from_array(cache%coarse_0_atomic_coords_t, cache%coarse_0_atomic_coords)
     856           52 :       CALL torch_tensor_to_device_leaf(cache%coarse_0_atomic_coords_t, .FALSE.)
     857              :       CALL torch_tensor_from_array(cache%atomic_grid_size_bound_shape_t, &
     858           52 :                                    cache%atomic_grid_size_bound_shape)
     859           52 :       CALL torch_tensor_to_device_leaf(cache%atomic_grid_size_bound_shape_t, .FALSE.)
     860           52 :       CALL torch_tensor_from_array(cache%local_feature_indices_t, cache%local_feature_indices)
     861           52 :       CALL torch_tensor_to_device_leaf(cache%local_feature_indices_t, .FALSE.)
     862              : 
     863           52 :       CALL torch_dict_create(cache%static_inputs)
     864           52 :       CALL torch_dict_insert(cache%static_inputs, "grid_coords", cache%grid_coords_t)
     865           52 :       CALL torch_dict_insert(cache%static_inputs, "grid_weights", cache%grid_weights_t)
     866              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_weights", &
     867           52 :                              cache%atomic_grid_weights_t)
     868              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_sizes", &
     869           52 :                              cache%atomic_grid_sizes_t)
     870              :       CALL torch_dict_insert(cache%static_inputs, "atomic_grid_size_bound_shape", &
     871           52 :                              cache%atomic_grid_size_bound_shape_t)
     872           52 :       cache%static_tensors_active = .TRUE.
     873              : 
     874           52 :       IF (cache%chunk_feature_count > 0) THEN
     875           52 :          CPASSERT(.NOT. cache%chunk_static_tensors_active)
     876           52 :          CALL torch_tensor_from_array(cache%chunk_grid_coords_t, cache%chunk_grid_coords)
     877           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_coords_t, .FALSE.)
     878           52 :          CALL torch_tensor_from_array(cache%chunk_grid_weights_t, cache%chunk_grid_weights)
     879           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_grid_weights_t, .FALSE.)
     880              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_weights_t, &
     881           52 :                                       cache%chunk_atomic_grid_weights)
     882           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_weights_t, .FALSE.)
     883              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_sizes_t, &
     884           52 :                                       cache%chunk_atomic_grid_sizes)
     885           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_sizes_t, .FALSE.)
     886              :          CALL torch_tensor_from_array(cache%chunk_coarse_0_atomic_coords_t, &
     887           52 :                                       cache%chunk_coarse_0_atomic_coords)
     888           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_coarse_0_atomic_coords_t, .FALSE.)
     889              :          CALL torch_tensor_from_array(cache%chunk_atomic_grid_size_bound_shape_t, &
     890           52 :                                       cache%chunk_atomic_grid_size_bound_shape)
     891           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_atomic_grid_size_bound_shape_t, .FALSE.)
     892           52 :          CALL torch_tensor_from_array(cache%chunk_feature_indices_t, cache%chunk_feature_indices)
     893           52 :          CALL torch_tensor_to_device_leaf(cache%chunk_feature_indices_t, .FALSE.)
     894              : 
     895           52 :          CALL torch_dict_create(cache%chunk_static_inputs)
     896              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_coords", &
     897           52 :                                 cache%chunk_grid_coords_t)
     898              :          CALL torch_dict_insert(cache%chunk_static_inputs, "grid_weights", &
     899           52 :                                 cache%chunk_grid_weights_t)
     900              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_weights", &
     901           52 :                                 cache%chunk_atomic_grid_weights_t)
     902              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_sizes", &
     903           52 :                                 cache%chunk_atomic_grid_sizes_t)
     904              :          CALL torch_dict_insert(cache%chunk_static_inputs, "atomic_grid_size_bound_shape", &
     905           52 :                                 cache%chunk_atomic_grid_size_bound_shape_t)
     906           52 :          cache%chunk_static_tensors_active = .TRUE.
     907              :       END IF
     908              : 
     909           52 :    END SUBROUTINE build_static_layout_tensors
     910              : 
     911              : ! **************************************************************************************************
     912              : !> \brief Copy static cached layout arrays into a feature bundle.
     913              : !> \param features ...
     914              : !> \param needs_coordinate_array ...
     915              : !> \param needs_grid_coordinate_array ...
     916              : ! **************************************************************************************************
     917          144 :    SUBROUTINE copy_cached_layout(features, needs_coordinate_array, needs_grid_coordinate_array)
     918              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
     919              :       LOGICAL, INTENT(IN)                                :: needs_coordinate_array, &
     920              :                                                             needs_grid_coordinate_array
     921              : 
     922          144 :       CPASSERT(cached_layout%active)
     923              : 
     924            0 :       ALLOCATE (features%feature_index(LBOUND(cached_layout%feature_index, 1): &
     925              :                                        UBOUND(cached_layout%feature_index, 1), &
     926              :                                        LBOUND(cached_layout%feature_index, 2): &
     927              :                                        UBOUND(cached_layout%feature_index, 2), &
     928              :                                        LBOUND(cached_layout%feature_index, 3): &
     929          720 :                                        UBOUND(cached_layout%feature_index, 3)))
     930          432 :       ALLOCATE (features%grid_weights(cached_layout%nflat))
     931            0 :       ALLOCATE (features%local_feature_counts(cached_layout%nflat_local), &
     932            0 :                 features%local_feature_offsets(cached_layout%nflat_local + 1), &
     933         1008 :                 features%local_feature_rows(SIZE(cached_layout%local_feature_rows)))
     934              : 
     935      1453346 :       features%feature_index(:, :, :) = cached_layout%feature_index
     936      2870920 :       features%grid_weights(:) = cached_layout%grid_weights
     937      1366762 :       features%local_feature_counts(:) = cached_layout%local_feature_counts
     938      1366906 :       features%local_feature_offsets(:) = cached_layout%local_feature_offsets
     939      1435532 :       features%local_feature_rows(:) = cached_layout%local_feature_rows
     940          144 :       features%nflat = cached_layout%nflat
     941          144 :       features%nflat_local = cached_layout%nflat_local
     942          144 :       features%chunk_feature_count = cached_layout%chunk_feature_count
     943          144 :       features%atom_partition = cached_layout%atom_partition
     944          432 :       ALLOCATE (features%atomic_grid_sizes(cached_layout%natom))
     945          460 :       features%atomic_grid_sizes(:) = cached_layout%atomic_grid_sizes
     946          144 :       IF (needs_grid_coordinate_array) THEN
     947           30 :          ALLOCATE (features%grid_coords(3, cached_layout%nflat))
     948           20 :          ALLOCATE (features%atomic_grid_weights(cached_layout%nflat))
     949       773034 :          features%grid_coords(:, :) = cached_layout%grid_coords
     950       193266 :          features%atomic_grid_weights(:) = cached_layout%atomic_grid_weights
     951              :       END IF
     952            0 :       ALLOCATE (features%chunk_grad_counts(cached_layout%nproc), &
     953            0 :                 features%chunk_grad_displs(cached_layout%nproc), &
     954            0 :                 features%route_grad_return_recv_counts(cached_layout%nproc), &
     955            0 :                 features%route_grad_return_recv_displs(cached_layout%nproc), &
     956            0 :                 features%route_grad_return_send_counts(cached_layout%nproc), &
     957            0 :                 features%route_grad_return_send_displs(cached_layout%nproc), &
     958            0 :                 features%route_point_recv_counts(cached_layout%nproc), &
     959            0 :                 features%route_point_recv_displs(cached_layout%nproc), &
     960            0 :                 features%route_point_send_counts(cached_layout%nproc), &
     961            0 :                 features%route_point_send_displs(cached_layout%nproc), &
     962         2016 :                 features%route_send_local_rows(SIZE(cached_layout%route_send_local_rows)))
     963          432 :       features%chunk_grad_counts(:) = cached_layout%chunk_grad_counts
     964          432 :       features%chunk_grad_displs(:) = cached_layout%chunk_grad_displs
     965          432 :       features%route_grad_return_recv_counts(:) = cached_layout%route_grad_return_recv_counts
     966          432 :       features%route_grad_return_recv_displs(:) = cached_layout%route_grad_return_recv_displs
     967          432 :       features%route_grad_return_send_counts(:) = cached_layout%route_grad_return_send_counts
     968          432 :       features%route_grad_return_send_displs(:) = cached_layout%route_grad_return_send_displs
     969          432 :       features%route_point_recv_counts(:) = cached_layout%route_point_recv_counts
     970          432 :       features%route_point_recv_displs(:) = cached_layout%route_point_recv_displs
     971          432 :       features%route_point_send_counts(:) = cached_layout%route_point_send_counts
     972          432 :       features%route_point_send_displs(:) = cached_layout%route_point_send_displs
     973      1435532 :       features%route_send_local_rows(:) = cached_layout%route_send_local_rows
     974          144 :       IF (needs_coordinate_array) THEN
     975           48 :          ALLOCATE (features%coarse_0_atomic_coords(3, cached_layout%natom))
     976          144 :          features%coarse_0_atomic_coords(:, :) = cached_layout%coarse_0_atomic_coords
     977              :       END IF
     978              : 
     979          144 :    END SUBROUTINE copy_cached_layout
     980              : 
     981              : ! **************************************************************************************************
     982              : !> \brief Split the atom-ordered feature rows into contiguous atom chunks.
     983              : !> \param atomic_grid_sizes ...
     984              : !> \param atom_offset ...
     985              : !> \param nproc ...
     986              : !> \param chunk_atom_begin ...
     987              : !> \param chunk_atom_end ...
     988              : !> \param chunk_feature_counts ...
     989              : !> \param chunk_feature_displs ...
     990              : ! **************************************************************************************************
     991           52 :    SUBROUTINE build_atom_chunks(atomic_grid_sizes, atom_offset, nproc, chunk_atom_begin, &
     992           52 :                                 chunk_atom_end, chunk_feature_counts, chunk_feature_displs)
     993              :       INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
     994              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: atom_offset
     995              :       INTEGER, INTENT(IN)                                :: nproc
     996              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: chunk_atom_begin, chunk_atom_end, &
     997              :                                                             chunk_feature_counts, &
     998              :                                                             chunk_feature_displs
     999              : 
    1000              :       INTEGER :: best_limit, count, displ, end_atom, lower_limit, max_end_atom, midpoint, natom, &
    1001              :          next_atom, next_count, pe, ranks_left, target_chunks, total_count, upper_limit
    1002              : 
    1003           52 :       natom = SIZE(atomic_grid_sizes)
    1004          156 :       chunk_atom_begin = natom + 1
    1005          156 :       chunk_atom_end = natom
    1006          156 :       chunk_feature_counts = 0
    1007          156 :       chunk_feature_displs = 0
    1008           52 :       IF (natom == 0) RETURN
    1009              : 
    1010           52 :       target_chunks = MIN(nproc, natom)
    1011           52 :       total_count = atom_offset(natom + 1) - 1
    1012          184 :       lower_limit = MAXVAL(INT(atomic_grid_sizes))
    1013           52 :       lower_limit = MAX(lower_limit, (total_count + target_chunks - 1)/target_chunks)
    1014           52 :       upper_limit = total_count
    1015           52 :       best_limit = upper_limit
    1016          738 :       DO WHILE (lower_limit <= upper_limit)
    1017          686 :          midpoint = (lower_limit + upper_limit)/2
    1018          738 :          IF (atom_chunks_fit_limit(atomic_grid_sizes, midpoint, target_chunks)) THEN
    1019          574 :             best_limit = midpoint
    1020          574 :             upper_limit = midpoint - 1
    1021              :          ELSE
    1022          112 :             lower_limit = midpoint + 1
    1023              :          END IF
    1024              :       END DO
    1025              : 
    1026              :       displ = 0
    1027              :       next_atom = 1
    1028          156 :       DO pe = 1, nproc
    1029          104 :          chunk_feature_displs(pe) = displ
    1030          104 :          IF (pe > target_chunks .OR. next_atom > natom) CYCLE
    1031              : 
    1032          104 :          ranks_left = target_chunks - pe + 1
    1033          104 :          chunk_atom_begin(pe) = next_atom
    1034          104 :          max_end_atom = natom - ranks_left + 1
    1035          104 :          end_atom = next_atom
    1036          104 :          count = INT(atomic_grid_sizes(end_atom))
    1037          132 :          DO WHILE (end_atom < max_end_atom)
    1038           38 :             next_count = count + INT(atomic_grid_sizes(end_atom + 1))
    1039           38 :             IF (next_count > best_limit) EXIT
    1040              :             end_atom = end_atom + 1
    1041          104 :             count = next_count
    1042              :          END DO
    1043              : 
    1044          104 :          chunk_atom_end(pe) = end_atom
    1045          104 :          chunk_feature_counts(pe) = atom_offset(end_atom + 1) - atom_offset(next_atom)
    1046          104 :          displ = displ + chunk_feature_counts(pe)
    1047          156 :          next_atom = end_atom + 1
    1048              :       END DO
    1049              : 
    1050           52 :       CPASSERT(displ == atom_offset(natom + 1) - 1)
    1051              : 
    1052              :    END SUBROUTINE build_atom_chunks
    1053              : 
    1054              : ! **************************************************************************************************
    1055              : !> \brief Check if contiguous atom chunks can stay below a feature-count limit.
    1056              : !> \param atomic_grid_sizes ...
    1057              : !> \param limit ...
    1058              : !> \param nchunks ...
    1059              : !> \return ...
    1060              : ! **************************************************************************************************
    1061          686 :    FUNCTION atom_chunks_fit_limit(atomic_grid_sizes, limit, nchunks) RESULT(fits)
    1062              :       INTEGER(KIND=int_8), DIMENSION(:), INTENT(IN)      :: atomic_grid_sizes
    1063              :       INTEGER, INTENT(IN)                                :: limit, nchunks
    1064              :       LOGICAL                                            :: fits
    1065              : 
    1066              :       INTEGER                                            :: atom_count, chunk_count, iatom, &
    1067              :                                                             used_chunks
    1068              : 
    1069          686 :       fits = .FALSE.
    1070          686 :       IF (SIZE(atomic_grid_sizes) == 0) THEN
    1071          686 :          fits = .TRUE.
    1072              :          RETURN
    1073              :       END IF
    1074              : 
    1075         2478 :       used_chunks = 1
    1076         2478 :       chunk_count = 0
    1077         2478 :       DO iatom = 1, SIZE(atomic_grid_sizes)
    1078         1792 :          atom_count = INT(atomic_grid_sizes(iatom))
    1079         1792 :          IF (atom_count > limit) RETURN
    1080         2478 :          IF (chunk_count + atom_count > limit) THEN
    1081          798 :             used_chunks = used_chunks + 1
    1082          798 :             chunk_count = atom_count
    1083              :          ELSE
    1084              :             chunk_count = chunk_count + atom_count
    1085              :          END IF
    1086              :       END DO
    1087          686 :       fits = used_chunks <= nchunks
    1088              : 
    1089          686 :    END FUNCTION atom_chunks_fit_limit
    1090              : 
    1091              : ! **************************************************************************************************
    1092              : !> \brief Return the MPI rank owning an atom-ordered feature row.
    1093              : !> \param row ...
    1094              : !> \param counts ...
    1095              : !> \param displs ...
    1096              : !> \return ...
    1097              : ! **************************************************************************************************
    1098      1189936 :    FUNCTION feature_row_chunk_owner(row, counts, displs) RESULT(owner)
    1099              :       INTEGER, INTENT(IN)                                :: row
    1100              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts, displs
    1101              :       INTEGER                                            :: owner
    1102              : 
    1103              :       INTEGER                                            :: pe
    1104              : 
    1105      1189936 :       owner = 0
    1106      1735872 :       DO pe = 1, SIZE(counts)
    1107      1735872 :          IF (row > displs(pe) .AND. row <= displs(pe) + counts(pe)) THEN
    1108      1189936 :             owner = pe
    1109              :             RETURN
    1110              :          END IF
    1111              :       END DO
    1112              : 
    1113              :    END FUNCTION feature_row_chunk_owner
    1114              : 
    1115              : ! **************************************************************************************************
    1116              : !> \brief Build zero-based displacement arrays from per-rank counts.
    1117              : !> \param counts ...
    1118              : !> \param displs ...
    1119              : ! **************************************************************************************************
    1120          208 :    SUBROUTINE counts_to_displs(counts, displs)
    1121              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: counts
    1122              :       INTEGER, DIMENSION(:), INTENT(OUT)                 :: displs
    1123              : 
    1124              :       INTEGER                                            :: pe
    1125              : 
    1126          208 :       displs(1) = 0
    1127          416 :       DO pe = 2, SIZE(counts)
    1128          416 :          displs(pe) = displs(pe - 1) + counts(pe - 1)
    1129              :       END DO
    1130              : 
    1131          208 :    END SUBROUTINE counts_to_displs
    1132              : 
    1133              : ! **************************************************************************************************
    1134              : !> \brief Precompute all-to-all routing between local grid rows and atom chunks.
    1135              : !> \param cache ...
    1136              : !> \param local_to_global ...
    1137              : !> \param group ...
    1138              : ! **************************************************************************************************
    1139           52 :    SUBROUTINE build_atom_chunk_routes(cache, local_to_global, group)
    1140              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1141              :       INTEGER, DIMENSION(:), INTENT(IN)                  :: local_to_global
    1142              : 
    1143              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1144              : 
    1145              :       INTEGER                                            :: chunk_row, dest, local_feature, point_pos, row
    1146           52 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_meta, send_meta
    1147              : 
    1148            0 :       ALLOCATE (cache%route_local_dest(SIZE(local_to_global)), &
    1149            0 :                 cache%route_send_local_rows(SIZE(local_to_global)), &
    1150            0 :                 cache%chunk_return_positions(cache%chunk_feature_count), &
    1151          416 :                 cursor(SIZE(cache%route_point_send_counts)))
    1152          156 :       cache%route_point_send_counts = 0
    1153      1189988 :       cache%route_send_local_rows = 0
    1154      1189988 :       cache%chunk_return_positions = 0
    1155      1189988 :       DO local_feature = 1, SIZE(local_to_global)
    1156              :          dest = feature_row_chunk_owner(local_to_global(local_feature), &
    1157              :                                         cache%chunk_feature_counts, &
    1158      1189936 :                                         cache%chunk_feature_displs)
    1159      1189936 :          CPASSERT(dest > 0)
    1160      1189936 :          cache%route_local_dest(local_feature) = dest
    1161      1189988 :          cache%route_point_send_counts(dest) = cache%route_point_send_counts(dest) + 1
    1162              :       END DO
    1163           52 :       CALL counts_to_displs(cache%route_point_send_counts, cache%route_point_send_displs)
    1164          156 :       cursor(:) = cache%route_point_send_displs + 1
    1165      1189988 :       DO local_feature = 1, SIZE(local_to_global)
    1166      1189936 :          dest = cache%route_local_dest(local_feature)
    1167      1189936 :          point_pos = cursor(dest)
    1168      1189936 :          cursor(dest) = cursor(dest) + 1
    1169      1189988 :          cache%route_send_local_rows(point_pos) = cache%local_feature_points(local_feature)
    1170              :       END DO
    1171           52 :       CALL group%alltoall(cache%route_point_send_counts, cache%route_point_recv_counts, 1)
    1172           52 :       CALL counts_to_displs(cache%route_point_recv_counts, cache%route_point_recv_displs)
    1173              : 
    1174          208 :       ALLOCATE (send_meta(SIZE(local_to_global)), recv_meta(cache%chunk_feature_count))
    1175          156 :       cursor(:) = cache%route_point_send_displs + 1
    1176      1189988 :       DO local_feature = 1, SIZE(local_to_global)
    1177      1189936 :          dest = cache%route_local_dest(local_feature)
    1178      1189936 :          point_pos = cursor(dest)
    1179      1189936 :          cursor(dest) = cursor(dest) + 1
    1180      1189988 :          send_meta(point_pos) = local_to_global(local_feature)
    1181              :       END DO
    1182              :       CALL group%alltoall(send_meta, cache%route_point_send_counts, &
    1183              :                           cache%route_point_send_displs, recv_meta, &
    1184              :                           cache%route_point_recv_counts, &
    1185           52 :                           cache%route_point_recv_displs)
    1186      1189988 :       DO point_pos = 1, cache%chunk_feature_count
    1187      1189936 :          row = recv_meta(point_pos)
    1188      1189936 :          chunk_row = row - cache%chunk_feature_begin + 1
    1189      1189936 :          CPASSERT(chunk_row >= 1 .AND. chunk_row <= cache%chunk_feature_count)
    1190      1189988 :          cache%chunk_return_positions(chunk_row) = point_pos
    1191              :       END DO
    1192              : 
    1193          156 :       cache%route_grad_return_send_counts(:) = ngrad_per_point*cache%route_point_recv_counts
    1194          156 :       cache%route_grad_return_send_displs(:) = ngrad_per_point*cache%route_point_recv_displs
    1195          156 :       cache%route_grad_return_recv_counts(:) = ngrad_per_point*cache%route_point_send_counts
    1196          156 :       cache%route_grad_return_recv_displs(:) = ngrad_per_point*cache%route_point_send_displs
    1197              : 
    1198          156 :       CPASSERT(SUM(cache%route_point_send_counts) == SIZE(local_to_global))
    1199          156 :       CPASSERT(SUM(cache%route_point_recv_counts) == cache%chunk_feature_count)
    1200      1189988 :       CPASSERT(ALL(cache%route_send_local_rows > 0))
    1201      1189988 :       CPASSERT(ALL(cache%chunk_return_positions > 0))
    1202              : 
    1203           52 :       DEALLOCATE (cursor, recv_meta, send_meta)
    1204              : 
    1205           52 :    END SUBROUTINE build_atom_chunk_routes
    1206              : 
    1207              : ! **************************************************************************************************
    1208              : !> \brief Materialize the static layout of the current rank atom chunk from the full cached arrays.
    1209              : !> \param cache layout cache; its chunk_* arrays are allocated and filled here
    1210              : ! **************************************************************************************************
    1211           52 :    SUBROUTINE build_atom_chunk_layout(cache)
    1212              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1213              : 
    1214              :       INTEGER                                            :: irow, max_grid_size, row_begin, row_end
    1215              : 
    1216           52 :       IF (cache%chunk_feature_count <= 0 .OR. cache%chunk_natom <= 0) RETURN  ! empty chunk
    1217              : 
    1218           52 :       row_begin = cache%chunk_feature_begin
    1219           52 :       row_end = row_begin + cache%chunk_feature_count - 1  ! inclusive last row
    1220            0 :       ALLOCATE (cache%chunk_grid_coords(3, cache%chunk_feature_count), &
    1221            0 :                 cache%chunk_grid_weights(cache%chunk_feature_count), &
    1222            0 :                 cache%chunk_atomic_grid_weights(cache%chunk_feature_count), &
    1223            0 :                 cache%chunk_atomic_grid_sizes(cache%chunk_natom), &
    1224            0 :                 cache%chunk_coarse_0_atomic_coords(3, cache%chunk_natom), &
    1225          572 :                 cache%chunk_feature_indices(cache%chunk_feature_count))
    1226      4759796 :       cache%chunk_grid_coords(:, :) = cache%grid_coords(:, row_begin:row_end)
    1227      1189988 :       cache%chunk_grid_weights(:) = cache%grid_weights(row_begin:row_end)
    1228      1189988 :       cache%chunk_atomic_grid_weights(:) = cache%atomic_grid_weights(row_begin:row_end)
    1229              :       cache%chunk_atomic_grid_sizes(:) = &
    1230          118 :          cache%atomic_grid_sizes(cache%chunk_atom_begin:cache%chunk_atom_end)
    1231              :       cache%chunk_coarse_0_atomic_coords(:, :) = &
    1232          316 :          cache%coarse_0_atomic_coords(:, cache%chunk_atom_begin:cache%chunk_atom_end)
    1233              : 
    1234          118 :       max_grid_size = MAXVAL(INT(cache%chunk_atomic_grid_sizes))
    1235          104 :       ALLOCATE (cache%chunk_atomic_grid_size_bound_shape(0, max_grid_size))  ! first extent is 0
    1236       959187 :       cache%chunk_atomic_grid_size_bound_shape = 0_int_8
    1237      1189988 :       DO irow = 1, cache%chunk_feature_count
    1238      1189988 :          cache%chunk_feature_indices(irow) = INT(irow - 1, KIND=int_8)  ! zero-based row index
    1239              :       END DO
    1240              : 
    1241              :    END SUBROUTINE build_atom_chunk_layout
    1242              : 
    1243              : ! **************************************************************************************************
    1244              : !> \brief Send local dynamic feature rows to their atom-chunk owner ranks.
    1245              : !> \param features bundle receiving chunk_density, chunk_grad, chunk_kin and chunk_return_positions
    1246              : !> \param local_dynamic packed dynamic values, a fixed number of consecutive entries per local point
    1247              : !> \param group communicator used for the all-to-all exchange
    1248              : !> \param collapse_spin_dynamics route nrks_dynamic_per_point values per point into one spin channel
    1249              : ! **************************************************************************************************
    1250            2 :    SUBROUTINE route_atom_chunk_dynamics(features, local_dynamic, group, collapse_spin_dynamics)
    1251              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1252              :       REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: local_dynamic
    1253              : 
    1254              :       CLASS(mp_comm_type), INTENT(IN)                    :: group
    1255              :       LOGICAL, INTENT(IN)                                :: collapse_spin_dynamics
    1256              : 
    1257              :       INTEGER                                            :: chunk_row, dest, dyn_base, local_feature, local_row, &
    1258              :                                                             ndynamic_route_per_point, nrecv, nsend, &
    1259              :                                                             point_pos, src_base
    1260            2 :       INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cursor, recv_counts, recv_displs, &
    1261              :                                                             send_counts, send_displs
    1262              :       REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: recv_dynamic, send_dynamic
    1263              : 
    1264            2 :       CPASSERT(cached_layout%chunk_feature_count > 0)
    1265            2 :       nsend = SIZE(cached_layout%route_local_dest)
    1266            6 :       nrecv = SUM(cached_layout%route_point_recv_counts)
    1267            2 :       CPASSERT(nsend == SIZE(cached_layout%local_feature_rows))
    1268            2 :       CPASSERT(nrecv == cached_layout%chunk_feature_count)
    1269            2 :       ndynamic_route_per_point = ndynamic_per_point
    1270            2 :       IF (collapse_spin_dynamics) ndynamic_route_per_point = nrks_dynamic_per_point
    1271              : 
    1272              :       ALLOCATE (send_dynamic(ndynamic_route_per_point*nsend), &
    1273              :                 recv_dynamic(ndynamic_route_per_point*nrecv), &
    1274              :                 cursor(cached_layout%nproc), send_counts(cached_layout%nproc), &
    1275              :                 send_displs(cached_layout%nproc), recv_counts(cached_layout%nproc), &
    1276           22 :                 recv_displs(cached_layout%nproc))
    1277            6 :       send_counts(:) = ndynamic_route_per_point*cached_layout%route_point_send_counts
    1278            6 :       send_displs(:) = ndynamic_route_per_point*cached_layout%route_point_send_displs
    1279            6 :       recv_counts(:) = ndynamic_route_per_point*cached_layout%route_point_recv_counts
    1280            6 :       recv_displs(:) = ndynamic_route_per_point*cached_layout%route_point_recv_displs
    1281            6 :       cursor(:) = cached_layout%route_point_send_displs + 1  ! next one-based send slot per destination
    1282        64002 :       DO local_feature = 1, nsend
    1283        64000 :          dest = cached_layout%route_local_dest(local_feature)
    1284        64000 :          point_pos = cursor(dest)
    1285        64000 :          cursor(dest) = cursor(dest) + 1
    1286        64000 :          dyn_base = ndynamic_route_per_point*(point_pos - 1)  ! zero-based offset in send_dynamic
    1287        64000 :          local_row = cached_layout%local_feature_points(local_feature)
    1288        64000 :          src_base = ndynamic_route_per_point*(local_row - 1)  ! zero-based offset in local_dynamic
    1289              :          send_dynamic(dyn_base + 1:dyn_base + ndynamic_route_per_point) = &
    1290       384002 :             local_dynamic(src_base + 1:src_base + ndynamic_route_per_point)
    1291              :       END DO
    1292              : 
    1293              :       CALL group%alltoall(send_dynamic, send_counts, send_displs, recv_dynamic, recv_counts, &
    1294            2 :                           recv_displs)
    1295              : 
    1296            2 :       IF (collapse_spin_dynamics) THEN
    1297            0 :          ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 1), &
    1298            0 :                    features%chunk_grad(cached_layout%chunk_feature_count, 3, 1), &
    1299            0 :                    features%chunk_kin(cached_layout%chunk_feature_count, 1), &
    1300           16 :                    features%chunk_return_positions(cached_layout%chunk_feature_count))
    1301            2 :          features%uses_collapsed_rks_dynamic = .TRUE.
    1302              :       ELSE
    1303            0 :          ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1304            0 :                    features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1305            0 :                    features%chunk_kin(cached_layout%chunk_feature_count, 2), &
    1306            0 :                    features%chunk_return_positions(cached_layout%chunk_feature_count))
    1307              :       END IF
    1308        64002 :       features%chunk_return_positions(:) = cached_layout%chunk_return_positions
    1309              : 
    1310        64002 :       DO chunk_row = 1, cached_layout%chunk_feature_count
    1311        64000 :          point_pos = cached_layout%chunk_return_positions(chunk_row)  ! slot of this row in recv_dynamic
    1312        64000 :          CPASSERT(point_pos >= 1 .AND. point_pos <= cached_layout%chunk_feature_count)
    1313        64000 :          dyn_base = ndynamic_route_per_point*(point_pos - 1)
    1314        64002 :          IF (collapse_spin_dynamics) THEN  ! packed as density, grad(1:3), kin
    1315        64000 :             features%chunk_density(chunk_row, 1) = recv_dynamic(dyn_base + 1)
    1316        64000 :             features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 2)
    1317        64000 :             features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 3)
    1318        64000 :             features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 4)
    1319        64000 :             features%chunk_kin(chunk_row, 1) = recv_dynamic(dyn_base + 5)
    1320              :          ELSE  ! packed as density(1:2), grad(1:3,1), grad(1:3,2), kin(1:2)
    1321            0 :             features%chunk_density(chunk_row, :) = recv_dynamic(dyn_base + 1:dyn_base + 2)
    1322            0 :             features%chunk_grad(chunk_row, 1, 1) = recv_dynamic(dyn_base + 3)
    1323            0 :             features%chunk_grad(chunk_row, 2, 1) = recv_dynamic(dyn_base + 4)
    1324            0 :             features%chunk_grad(chunk_row, 3, 1) = recv_dynamic(dyn_base + 5)
    1325            0 :             features%chunk_grad(chunk_row, 1, 2) = recv_dynamic(dyn_base + 6)
    1326            0 :             features%chunk_grad(chunk_row, 2, 2) = recv_dynamic(dyn_base + 7)
    1327            0 :             features%chunk_grad(chunk_row, 3, 2) = recv_dynamic(dyn_base + 8)
    1328            0 :             features%chunk_kin(chunk_row, :) = recv_dynamic(dyn_base + 9:dyn_base + 10)
    1329              :          END IF
    1330              :       END DO
    1331        64002 :       CPASSERT(ALL(features%chunk_return_positions > 0))
    1332              : 
    1333            0 :       DEALLOCATE (cursor, recv_counts, recv_displs, recv_dynamic, send_counts, send_displs, &
    1334            2 :                   send_dynamic)
    1335              : 
    1336            2 :    END SUBROUTINE route_atom_chunk_dynamics
    1337              : 
    1338              : ! **************************************************************************************************
    1339              : !> \brief Extract the current rank's atom chunk from the global dynamic feature arrays.
    1340              : !> \param features bundle whose density, grad and kin chunk rows are copied into its chunk_* arrays
    1341              : ! **************************************************************************************************
    1342            0 :    SUBROUTINE extract_atom_chunk_dynamics(features)
    1343              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1344              : 
    1345              :       INTEGER                                            :: row_begin, row_end
    1346              : 
    1347            0 :       CPASSERT(cached_layout%chunk_feature_count > 0)
    1348            0 :       row_begin = cached_layout%chunk_feature_begin
    1349            0 :       row_end = row_begin + cached_layout%chunk_feature_count - 1  ! inclusive last row
    1350            0 :       ALLOCATE (features%chunk_density(cached_layout%chunk_feature_count, 2), &
    1351            0 :                 features%chunk_grad(cached_layout%chunk_feature_count, 3, 2), &
    1352            0 :                 features%chunk_kin(cached_layout%chunk_feature_count, 2))
    1353            0 :       features%chunk_density(:, :) = features%density(row_begin:row_end, :)
    1354            0 :       features%chunk_grad(:, :, :) = features%grad(row_begin:row_end, :, :)
    1355            0 :       features%chunk_kin(:, :) = features%kin(row_begin:row_end, :)
    1356              : 
    1357            0 :    END SUBROUTINE extract_atom_chunk_dynamics
    1358              : 
    1359              : ! **************************************************************************************************
    1360              : !> \brief Compute a local signature for optional integration weights.
    1361              : !> \param weights optional pointer to the weight grid; may be absent or unassociated
    1362              : !> \param has_weights true only if weights is present and associated
    1363              : !> \param weight_sum sum over weights%array, zero if there are no weights
    1364              : !> \param weight_sumsq sum of squared entries of weights%array, zero if there are no weights
    1365              : ! **************************************************************************************************
    1366          144 :    SUBROUTINE weights_signature(weights, has_weights, weight_sum, weight_sumsq)
    1367              :       TYPE(pw_r3d_rs_type), OPTIONAL, POINTER            :: weights
    1368              :       LOGICAL, INTENT(OUT)                               :: has_weights
    1369              :       REAL(KIND=dp), INTENT(OUT)                         :: weight_sum, weight_sumsq
    1370              : 
    1371          144 :       has_weights = .FALSE.
    1372          144 :       weight_sum = 0.0_dp
    1373          144 :       weight_sumsq = 0.0_dp
    1374          144 :       IF (PRESENT(weights)) THEN
    1375          144 :          IF (ASSOCIATED(weights)) THEN  ! nested so ASSOCIATED is never applied to an absent argument
    1376            0 :             has_weights = .TRUE.
    1377            0 :             weight_sum = SUM(weights%array)
    1378            0 :             weight_sumsq = SUM(weights%array*weights%array)
    1379              :          END IF
    1380              :       END IF
    1381              : 
    1382          144 :    END SUBROUTINE weights_signature
    1383              : 
    1384              : ! **************************************************************************************************
    1385              : !> \brief Release the Torch objects and arrays held by a layout cache and reset its scalar state.
    1386              : !> \param cache layout cache to clear; every release is guarded by an active flag or ALLOCATED check
    1387              : ! **************************************************************************************************
    1388           52 :    SUBROUTINE release_layout_cache(cache)
    1389              :       TYPE(skala_gpw_layout_cache_type), INTENT(INOUT)   :: cache
    1390              : 
    1391           52 :       IF (cache%inputs_active) THEN
    1392            8 :          CALL torch_dict_release(cache%inputs)
    1393            8 :          cache%inputs_active = .FALSE.
    1394              :       END IF
    1395              : 
    1396           52 :       IF (cache%chunk_inputs_active) THEN
    1397            0 :          CALL torch_dict_release(cache%chunk_inputs)
    1398            0 :          cache%chunk_inputs_active = .FALSE.
    1399              :       END IF
    1400              : 
    1401           52 :       IF (cache%dynamic_tensors_active) THEN
    1402            8 :          CALL torch_tensor_release(cache%density_t)
    1403            8 :          CALL torch_tensor_release(cache%grad_t)
    1404            8 :          CALL torch_tensor_release(cache%kin_t)
    1405            8 :          cache%dynamic_tensors_active = .FALSE.
    1406              :       END IF
    1407              : 
    1408           52 :       IF (cache%chunk_dynamic_tensors_active) THEN
    1409            0 :          IF (cache%chunk_dynamic_input_views_active) THEN  ! *_input_t handles go before the chunk tensors
    1410            0 :             CALL torch_tensor_release(cache%chunk_density_input_t)
    1411            0 :             CALL torch_tensor_release(cache%chunk_grad_input_t)
    1412            0 :             CALL torch_tensor_release(cache%chunk_kin_input_t)
    1413            0 :             cache%chunk_dynamic_input_views_active = .FALSE.
    1414              :          END IF
    1415            0 :          CALL torch_tensor_release(cache%chunk_density_t)
    1416            0 :          CALL torch_tensor_release(cache%chunk_grad_t)
    1417            0 :          CALL torch_tensor_release(cache%chunk_kin_t)
    1418            0 :          cache%chunk_dynamic_tensors_active = .FALSE.
    1419              :       END IF
    1420              : 
    1421           52 :       IF (cache%static_tensors_active) THEN
    1422            8 :          CALL torch_tensor_release(cache%grid_coords_t)
    1423            8 :          CALL torch_tensor_release(cache%grid_weights_t)
    1424            8 :          CALL torch_tensor_release(cache%atomic_grid_weights_t)
    1425            8 :          CALL torch_tensor_release(cache%atomic_grid_sizes_t)
    1426            8 :          CALL torch_tensor_release(cache%coarse_0_atomic_coords_t)
    1427            8 :          CALL torch_tensor_release(cache%atomic_grid_size_bound_shape_t)
    1428            8 :          CALL torch_tensor_release(cache%local_feature_indices_t)
    1429            8 :          CALL torch_dict_release(cache%static_inputs)
    1430            8 :          cache%static_tensors_active = .FALSE.
    1431              :       END IF
    1432              : 
    1433           52 :       IF (cache%chunk_static_tensors_active) THEN
    1434            8 :          CALL torch_tensor_release(cache%chunk_grid_coords_t)
    1435            8 :          CALL torch_tensor_release(cache%chunk_grid_weights_t)
    1436            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_weights_t)
    1437            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_sizes_t)
    1438            8 :          CALL torch_tensor_release(cache%chunk_coarse_0_atomic_coords_t)
    1439            8 :          CALL torch_tensor_release(cache%chunk_atomic_grid_size_bound_shape_t)
    1440            8 :          CALL torch_tensor_release(cache%chunk_feature_indices_t)
    1441            8 :          CALL torch_dict_release(cache%chunk_static_inputs)
    1442              :          cache%chunk_static_tensors_active = .FALSE.
    1443              :       END IF
    1444              : 
    1445           52 :       IF (ALLOCATED(cache%chunk_feature_counts)) DEALLOCATE (cache%chunk_feature_counts)
    1446           52 :       IF (ALLOCATED(cache%chunk_feature_displs)) DEALLOCATE (cache%chunk_feature_displs)
    1447           52 :       IF (ALLOCATED(cache%chunk_grad_counts)) DEALLOCATE (cache%chunk_grad_counts)
    1448           52 :       IF (ALLOCATED(cache%chunk_grad_displs)) DEALLOCATE (cache%chunk_grad_displs)
    1449           52 :       IF (ALLOCATED(cache%route_grad_return_recv_counts)) &
    1450            8 :          DEALLOCATE (cache%route_grad_return_recv_counts)
    1451           52 :       IF (ALLOCATED(cache%route_grad_return_recv_displs)) &
    1452            8 :          DEALLOCATE (cache%route_grad_return_recv_displs)
    1453           52 :       IF (ALLOCATED(cache%route_grad_return_send_counts)) &
    1454            8 :          DEALLOCATE (cache%route_grad_return_send_counts)
    1455           52 :       IF (ALLOCATED(cache%route_grad_return_send_displs)) &
    1456            8 :          DEALLOCATE (cache%route_grad_return_send_displs)
    1457           52 :       IF (ALLOCATED(cache%route_local_dest)) DEALLOCATE (cache%route_local_dest)
    1458           52 :       IF (ALLOCATED(cache%chunk_return_positions)) DEALLOCATE (cache%chunk_return_positions)
    1459           52 :       IF (ALLOCATED(cache%route_point_recv_counts)) DEALLOCATE (cache%route_point_recv_counts)
    1460           52 :       IF (ALLOCATED(cache%route_point_recv_displs)) DEALLOCATE (cache%route_point_recv_displs)
    1461           52 :       IF (ALLOCATED(cache%route_point_send_counts)) DEALLOCATE (cache%route_point_send_counts)
    1462           52 :       IF (ALLOCATED(cache%route_point_send_displs)) DEALLOCATE (cache%route_point_send_displs)
    1463           52 :       IF (ALLOCATED(cache%route_send_local_rows)) DEALLOCATE (cache%route_send_local_rows)
    1464           52 :       IF (ALLOCATED(cache%dynamic_counts)) DEALLOCATE (cache%dynamic_counts)
    1465           52 :       IF (ALLOCATED(cache%dynamic_displs)) DEALLOCATE (cache%dynamic_displs)
    1466           52 :       IF (ALLOCATED(cache%feature_counts)) DEALLOCATE (cache%feature_counts)
    1467           52 :       IF (ALLOCATED(cache%feature_displs)) DEALLOCATE (cache%feature_displs)
    1468           52 :       IF (ALLOCATED(cache%feature_source_points)) DEALLOCATE (cache%feature_source_points)
    1469           52 :       IF (ALLOCATED(cache%global_to_feature)) DEALLOCATE (cache%global_to_feature)
    1470           52 :       IF (ALLOCATED(cache%feature_index)) DEALLOCATE (cache%feature_index)
    1471           52 :       IF (ALLOCATED(cache%atomic_grid_sizes)) DEALLOCATE (cache%atomic_grid_sizes)
    1472           52 :       IF (ALLOCATED(cache%chunk_atomic_grid_sizes)) DEALLOCATE (cache%chunk_atomic_grid_sizes)
    1473           52 :       IF (ALLOCATED(cache%chunk_feature_indices)) DEALLOCATE (cache%chunk_feature_indices)
    1474           52 :       IF (ALLOCATED(cache%local_feature_counts)) DEALLOCATE (cache%local_feature_counts)
    1475           52 :       IF (ALLOCATED(cache%local_feature_indices)) DEALLOCATE (cache%local_feature_indices)
    1476           52 :       IF (ALLOCATED(cache%local_feature_offsets)) DEALLOCATE (cache%local_feature_offsets)
    1477           52 :       IF (ALLOCATED(cache%local_feature_points)) DEALLOCATE (cache%local_feature_points)
    1478           52 :       IF (ALLOCATED(cache%local_feature_rows)) DEALLOCATE (cache%local_feature_rows)
    1479           52 :       IF (ALLOCATED(cache%atomic_grid_size_bound_shape)) &
    1480            8 :          DEALLOCATE (cache%atomic_grid_size_bound_shape)
    1481           52 :       IF (ALLOCATED(cache%chunk_atomic_grid_size_bound_shape)) &
    1482            8 :          DEALLOCATE (cache%chunk_atomic_grid_size_bound_shape)
    1483           52 :       IF (ALLOCATED(cache%atomic_grid_weights)) DEALLOCATE (cache%atomic_grid_weights)
    1484           52 :       IF (ALLOCATED(cache%chunk_atomic_grid_weights)) DEALLOCATE (cache%chunk_atomic_grid_weights)
    1485           52 :       IF (ALLOCATED(cache%chunk_grid_weights)) DEALLOCATE (cache%chunk_grid_weights)
    1486           52 :       IF (ALLOCATED(cache%grid_weights)) DEALLOCATE (cache%grid_weights)
    1487           52 :       IF (ALLOCATED(cache%atom_coords)) DEALLOCATE (cache%atom_coords)
    1488           52 :       IF (ALLOCATED(cache%chunk_coarse_0_atomic_coords)) &
    1489            8 :          DEALLOCATE (cache%chunk_coarse_0_atomic_coords)
    1490           52 :       IF (ALLOCATED(cache%coarse_0_atomic_coords)) DEALLOCATE (cache%coarse_0_atomic_coords)
    1491           52 :       IF (ALLOCATED(cache%chunk_grid_coords)) DEALLOCATE (cache%chunk_grid_coords)
    1492           52 :       IF (ALLOCATED(cache%grid_coords)) DEALLOCATE (cache%grid_coords)
    1493              : 
    1494           52 :       cache%chunk_atom_begin = 1
    1495           52 :       cache%chunk_atom_end = 0  ! end < begin denotes an empty atom range
    1496           52 :       cache%chunk_feature_begin = 1
    1497           52 :       cache%chunk_feature_count = 0
    1498           52 :       cache%chunk_natom = 0
    1499           52 :       cache%natom = 0
    1500           52 :       cache%nflat = 0
    1501           52 :       cache%nflat_local = 0
    1502           52 :       cache%npoint = 0
    1503           52 :       cache%nproc = 0
    1504           52 :       cache%atom_partition = skala_gpw_atom_partition_hard
    1505          520 :       cache%bo = 0
    1506          520 :       cache%bounds = 0
    1507          208 :       cache%npts = 0
    1508           52 :       cache%dvol = 0.0_dp
    1509           52 :       cache%weight_sum = 0.0_dp
    1510           52 :       cache%weight_sumsq = 0.0_dp
    1511          676 :       cache%cell_hmat = 0.0_dp
    1512          676 :       cache%dh = 0.0_dp
    1513           52 :       cache%active = .FALSE.  ! all state flags are cleared unconditionally from here on
    1514           52 :       cache%has_weights = .FALSE.
    1515           52 :       cache%chunk_dynamic_tensors_active = .FALSE.
    1516           52 :       cache%chunk_dynamic_input_views_active = .FALSE.
    1517           52 :       cache%chunk_inputs_active = .FALSE.
    1518           52 :       cache%chunk_inputs_use_collapsed_rks = .FALSE.
    1519           52 :       cache%chunk_static_tensors_active = .FALSE.
    1520           52 :       cache%dynamic_tensors_active = .FALSE.
    1521           52 :       cache%inputs_active = .FALSE.
    1522           52 :       cache%static_tensors_active = .FALSE.
    1523              : 
    1524           52 :    END SUBROUTINE release_layout_cache
    1525              : 
    1526              : ! **************************************************************************************************
    1527              : !> \brief Release Torch objects and backing arrays owned by a feature bundle.
    1528              : !> \param features bundle to clear; Torch handles are released only if active and owned per owns_* flags
    1529              : ! **************************************************************************************************
    1530          296 :    SUBROUTINE skala_gpw_feature_release(features)
    1531              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1532              : 
    1533          296 :       IF (features%active) THEN
    1534          148 :          IF (features%owns_dynamic_tensors) THEN
    1535            4 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1536            4 :                CALL torch_tensor_release(features%density_input_t)
    1537            4 :                CALL torch_tensor_release(features%grad_input_t)
    1538            4 :                CALL torch_tensor_release(features%kin_input_t)
    1539              :             END IF
    1540            4 :             CALL torch_tensor_release(features%density_t)
    1541            4 :             CALL torch_tensor_release(features%grad_t)
    1542            4 :             CALL torch_tensor_release(features%kin_t)
    1543              :          END IF
    1544          148 :          IF (features%owns_static_tensors) THEN
    1545            4 :             CALL torch_tensor_release(features%grid_coords_t)
    1546            4 :             CALL torch_tensor_release(features%grid_weights_t)
    1547            4 :             CALL torch_tensor_release(features%atomic_grid_weights_t)
    1548            4 :             CALL torch_tensor_release(features%atomic_grid_sizes_t)
    1549            4 :             CALL torch_tensor_release(features%atomic_grid_size_bound_shape_t)
    1550              :          END IF
    1551          148 :          IF (features%owns_grid_coordinate_tensor) THEN  ! NOTE(review): repeats release above if both set
    1552            8 :             CALL torch_tensor_release(features%grid_coords_t)
    1553              :          END IF
    1554          148 :          IF (features%owns_weight_tensors) THEN  ! NOTE(review): repeats release above if both set
    1555           10 :             CALL torch_tensor_release(features%grid_weights_t)
    1556           10 :             CALL torch_tensor_release(features%atomic_grid_weights_t)
    1557              :          END IF
    1558          148 :          IF (features%owns_static_tensors .OR. features%owns_coordinate_tensor) THEN
    1559           20 :             CALL torch_tensor_release(features%coarse_0_atomic_coords_t)
    1560              :          END IF
    1561          148 :          IF (features%owns_inputs) CALL torch_dict_release(features%inputs)
    1562          148 :          features%active = .FALSE.
    1563          148 :          features%owns_coordinate_tensor = .FALSE.
    1564          148 :          features%owns_grid_coordinate_tensor = .FALSE.
    1565          148 :          features%owns_weight_tensors = .FALSE.
    1566          148 :          features%owns_dynamic_tensors = .TRUE.  ! NOTE(review): presumably the type default; confirm
    1567          148 :          features%owns_inputs = .TRUE.
    1568          148 :          features%owns_static_tensors = .TRUE.
    1569              :          features%uses_atom_chunk_routing = .FALSE.
    1570          148 :          features%uses_atom_chunks = .FALSE.
    1571              :          features%uses_collapsed_rks_dynamic = .FALSE.
    1572              :       END IF
    1573              : 
    1574          296 :       IF (ALLOCATED(features%chunk_density)) DEALLOCATE (features%chunk_density)
    1575          296 :       IF (ALLOCATED(features%chunk_grad)) DEALLOCATE (features%chunk_grad)
    1576          296 :       IF (ALLOCATED(features%chunk_kin)) DEALLOCATE (features%chunk_kin)
    1577          296 :       IF (ALLOCATED(features%density)) DEALLOCATE (features%density)
    1578          296 :       IF (ALLOCATED(features%grad)) DEALLOCATE (features%grad)
    1579          296 :       IF (ALLOCATED(features%kin)) DEALLOCATE (features%kin)
    1580          296 :       IF (ALLOCATED(features%chunk_grad_counts)) DEALLOCATE (features%chunk_grad_counts)
    1581          296 :       IF (ALLOCATED(features%chunk_grad_displs)) DEALLOCATE (features%chunk_grad_displs)
    1582          296 :       IF (ALLOCATED(features%chunk_return_positions)) DEALLOCATE (features%chunk_return_positions)
    1583          296 :       IF (ALLOCATED(features%route_grad_return_recv_counts)) &
    1584          144 :          DEALLOCATE (features%route_grad_return_recv_counts)
    1585          296 :       IF (ALLOCATED(features%route_grad_return_recv_displs)) &
    1586          144 :          DEALLOCATE (features%route_grad_return_recv_displs)
    1587          296 :       IF (ALLOCATED(features%route_grad_return_send_counts)) &
    1588          144 :          DEALLOCATE (features%route_grad_return_send_counts)
    1589          296 :       IF (ALLOCATED(features%route_grad_return_send_displs)) &
    1590          144 :          DEALLOCATE (features%route_grad_return_send_displs)
    1591          296 :       IF (ALLOCATED(features%route_point_recv_counts)) &
    1592          144 :          DEALLOCATE (features%route_point_recv_counts)
    1593          296 :       IF (ALLOCATED(features%route_point_recv_displs)) &
    1594          144 :          DEALLOCATE (features%route_point_recv_displs)
    1595          296 :       IF (ALLOCATED(features%route_point_send_counts)) &
    1596          144 :          DEALLOCATE (features%route_point_send_counts)
    1597          296 :       IF (ALLOCATED(features%route_point_send_displs)) &
    1598          144 :          DEALLOCATE (features%route_point_send_displs)
    1599          296 :       IF (ALLOCATED(features%route_send_local_rows)) DEALLOCATE (features%route_send_local_rows)
    1600          296 :       IF (ALLOCATED(features%feature_index)) DEALLOCATE (features%feature_index)
    1601          296 :       IF (ALLOCATED(features%local_feature_counts)) DEALLOCATE (features%local_feature_counts)
    1602          296 :       IF (ALLOCATED(features%local_feature_offsets)) DEALLOCATE (features%local_feature_offsets)
    1603          296 :       IF (ALLOCATED(features%local_feature_rows)) DEALLOCATE (features%local_feature_rows)
    1604          296 :       IF (ALLOCATED(features%grid_coords)) DEALLOCATE (features%grid_coords)
    1605          296 :       IF (ALLOCATED(features%grid_weights)) DEALLOCATE (features%grid_weights)
    1606          296 :       IF (ALLOCATED(features%atomic_grid_weights)) DEALLOCATE (features%atomic_grid_weights)
    1607          296 :       IF (ALLOCATED(features%atomic_grid_sizes)) DEALLOCATE (features%atomic_grid_sizes)
    1608          296 :       IF (ALLOCATED(features%coarse_0_atomic_coords)) DEALLOCATE (features%coarse_0_atomic_coords)
    1609          296 :       IF (ALLOCATED(features%atomic_grid_size_bound_shape)) &
    1610            4 :          DEALLOCATE (features%atomic_grid_size_bound_shape)
    1611          296 :       features%chunk_feature_count = 0  ! scalar state below is reset even for inactive bundles
    1612          296 :       features%nflat = 0
    1613          296 :       features%nflat_local = 0
    1614          296 :       features%atom_partition = skala_gpw_atom_partition_hard
    1615          296 :       features%uses_atom_chunk_routing = .FALSE.
    1616          296 :       features%uses_collapsed_rks_dynamic = .FALSE.
    1617              : 
    1618          296 :    END SUBROUTINE skala_gpw_feature_release
    1619              : 
    1620              : ! **************************************************************************************************
    1621              : !> \brief Return how many atom-contiguous subchunks the cached rank chunk needs.
    1622              : !> \param max_rows row limit per subchunk; atoms are never split, so a single atom may exceed it
    1623              : !> \return subchunk count (>= 1); 1 if max_rows <= 0, layout inactive, or chunk has no atoms
    1624              : ! **************************************************************************************************
    1625            4 :    FUNCTION skala_gpw_atom_subchunk_count(max_rows) RESULT(nsubchunks)
    1626              :       INTEGER, INTENT(IN)                                :: max_rows
    1627              :       INTEGER                                            :: nsubchunks
    1628              : 
    1629              :       INTEGER                                            :: atom_rows, iatom, rows
    1630              : 
    1631            4 :       nsubchunks = 1   ! value returned by the early exits below
    1632            4 :       IF (max_rows <= 0) RETURN
    1633            4 :       IF (.NOT. cached_layout%active) RETURN
    1634            4 :       IF (cached_layout%chunk_natom <= 0) RETURN
    1635              : 
    1636              :       nsubchunks = 0
    1637              :       rows = 0   ! rows accumulated in the currently open subchunk
    1638           12 :       DO iatom = 1, cached_layout%chunk_natom
    1639            8 :          atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
    1640            8 :          IF (rows > 0 .AND. rows + atom_rows > max_rows) THEN   ! atom does not fit: close subchunk
    1641            4 :             nsubchunks = nsubchunks + 1
    1642            4 :             rows = 0
    1643              :          END IF
    1644           12 :          rows = rows + atom_rows
    1645              :       END DO
    1646            4 :       IF (rows > 0) nsubchunks = nsubchunks + 1   ! count the subchunk still open after the loop
    1647            4 :       nsubchunks = MAX(1, nsubchunks)
    1648              : 
    1649            4 :    END FUNCTION skala_gpw_atom_subchunk_count
    1650              : 
    1651              : ! **************************************************************************************************
    1652              : !> \brief Build an atom-contiguous subchunk feature bundle from a rank-local atom chunk.
    1653              : !> \param parent feature bundle of the rank-local atom chunk (asserted to have uses_atom_chunks set)
    1654              : !> \param features bundle to rebuild; it is released first and marked active on return
    1655              : !> \param subchunk_index index of the requested subchunk, counted from 1
    1656              : !> \param max_rows row limit per subchunk, forwarded to atom_subchunk_bounds
    1657              : !> \param requires_grad currently unused here (only passed to MARK_USED)
    1658              : ! **************************************************************************************************
    1659            4 :    SUBROUTINE skala_gpw_feature_build_atom_subchunk(parent, features, subchunk_index, &
    1660              :                                                     max_rows, requires_grad)
    1661              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
    1662              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1663              :       INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
    1664              :       LOGICAL, INTENT(IN)                                :: requires_grad
    1665              : 
    1666              :       INTEGER                                            :: atom_begin, atom_count, atom_end, &
    1667              :                                                             max_grid_size, row_begin, row_count, &
    1668              :                                                             row_end
    1669              : 
    1670            4 :       CALL skala_gpw_feature_release(features)
    1671            4 :       CPASSERT(parent%uses_atom_chunks)
    1672              :       CALL atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
    1673            4 :                                 row_begin, row_end)
    1674            4 :       atom_count = atom_end - atom_begin + 1   ! bounds are inclusive
    1675            4 :       row_count = row_end - row_begin + 1
    1676            4 :       CPASSERT(atom_count > 0)
    1677            4 :       CPASSERT(row_count > 0)
    1678              :       MARK_USED(requires_grad)
    1679            8 :       max_grid_size = MAXVAL(INT(cached_layout%chunk_atomic_grid_sizes(atom_begin:atom_end)))
    1680              :       ! Zero-extent first dimension: no elements are stored, the shape is (0, max_grid_size).
    1681            8 :       ALLOCATE (features%atomic_grid_size_bound_shape(0, max_grid_size))
    1682        64004 :       features%atomic_grid_size_bound_shape = 0_int_8
    1683              :       ! NOTE(review): presumably only the shape of this array is consumed downstream -- confirm.
    1684            4 :       features%chunk_feature_count = row_count
    1685            4 :       features%nflat = parent%nflat
    1686            4 :       features%nflat_local = parent%nflat_local
    1687        64004 :       features%grid_weight_sum = SUM(cached_layout%chunk_grid_weights(row_begin:row_end))
    1688            4 :       features%uses_atom_chunks = .TRUE.
    1689            4 :       features%uses_atom_chunk_routing = parent%uses_atom_chunk_routing
    1690              :       CALL add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
    1691            4 :                                         row_count)
    1692            4 :       features%active = .TRUE.
    1693              : 
    1694            4 :    END SUBROUTINE skala_gpw_feature_build_atom_subchunk
    1695              : 
    1696              : ! **************************************************************************************************
    1697              : !> \brief Return atom and row bounds for an atom-contiguous rank-local subchunk.
    1698              : !> \param subchunk_index index of the requested subchunk, counted from 1 (asserted > 0)
    1699              : !> \param max_rows row limit per subchunk (asserted > 0); atoms are never split
    1700              : !> \param atom_begin first atom of the subchunk, 1-based within the cached chunk
    1701              : !> \param atom_end last atom of the subchunk (inclusive)
    1702              : !> \param row_begin first row of the subchunk, 1-based within the cached chunk
    1703              : !> \param row_end last row of the subchunk (inclusive)
    1704              : ! **************************************************************************************************
    1705            4 :    SUBROUTINE atom_subchunk_bounds(subchunk_index, max_rows, atom_begin, atom_end, &
    1706              :                                    row_begin, row_end)
    1707              :       INTEGER, INTENT(IN)                                :: subchunk_index, max_rows
    1708              :       INTEGER, INTENT(OUT)                               :: atom_begin, atom_end, row_begin, row_end
    1709              : 
    1710              :       INTEGER                                            :: atom_rows, current_subchunk, iatom, &
    1711              :                                                             row_cursor, rows
    1712              : 
    1713            4 :       CPASSERT(subchunk_index > 0)
    1714            4 :       CPASSERT(max_rows > 0)
    1715            4 :       CPASSERT(cached_layout%chunk_natom > 0)
    1716              :       ! Start with subchunk 1 opened at atom 1 / row 1 and still empty (end < begin).
    1717            4 :       atom_begin = 1
    1718            4 :       atom_end = 0
    1719            4 :       row_begin = 1
    1720            4 :       row_end = 0
    1721            4 :       current_subchunk = 1
    1722            4 :       row_cursor = 1   ! row at which atom iatom starts
    1723            4 :       rows = 0   ! rows accumulated in the currently open subchunk
    1724           10 :       DO iatom = 1, cached_layout%chunk_natom
    1725            8 :          atom_rows = INT(cached_layout%chunk_atomic_grid_sizes(iatom))
    1726            8 :          IF (rows > 0 .AND. rows + atom_rows > max_rows) THEN   ! atom iatom opens a new subchunk
    1727            4 :             IF (current_subchunk == subchunk_index) THEN
    1728            2 :                atom_end = iatom - 1
    1729            2 :                row_end = row_cursor - 1
    1730            2 :                RETURN
    1731              :             END IF
    1732            2 :             current_subchunk = current_subchunk + 1
    1733            2 :             atom_begin = iatom
    1734            2 :             row_begin = row_cursor
    1735            2 :             rows = 0
    1736              :          END IF
    1737            6 :          rows = rows + atom_rows
    1738            8 :          row_cursor = row_cursor + atom_rows
    1739              :       END DO
    1740              :       ! The loop ended inside the last subchunk; it extends to the final atom and row.
    1741            2 :       IF (current_subchunk == subchunk_index) THEN
    1742            2 :          atom_end = cached_layout%chunk_natom
    1743            2 :          row_end = row_cursor - 1
    1744            2 :          RETURN
    1745              :       END IF
    1746              :       ! Reached only when subchunk_index exceeds the number of subchunks.
    1747            0 :       CPABORT("Requested native SKALA atom subchunk does not exist.")
    1748              : 
    1749              :    END SUBROUTINE atom_subchunk_bounds
    1750              : 
    1751              : ! **************************************************************************************************
    1752              : !> \brief Insert a subchunk into a Torch dictionary using static views of the cached chunk tensors.
    1753              : !> \param parent active parent bundle whose density_t, grad_t and kin_t are narrowed
    1754              : !> \param features subchunk bundle to fill; atomic_grid_size_bound_shape must be allocated
    1755              : !> \param atom_begin first atom of the subchunk, 1-based
    1756              : !> \param atom_count number of atoms in the subchunk
    1757              : !> \param row_begin first row of the subchunk, 1-based
    1758              : !> \param row_count number of rows in the subchunk
    1759              : ! **************************************************************************************************
    1760            4 :    SUBROUTINE add_subchunk_feature_tensors(parent, features, atom_begin, atom_count, row_begin, &
    1761              :                                            row_count)
    1762              :       TYPE(skala_gpw_feature_type), INTENT(IN)           :: parent
    1763              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1764              :       INTEGER, INTENT(IN)                                :: atom_begin, atom_count, row_begin, &
    1765              :                                                             row_count
    1766              : 
    1767            4 :       CPASSERT(cached_layout%chunk_static_tensors_active)
    1768            4 :       CPASSERT(parent%active)
    1769            4 :       CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
    1770              : 
    1771            4 :       features%owns_coordinate_tensor = .FALSE.
    1772            4 :       features%owns_dynamic_tensors = .TRUE.
    1773            4 :       features%owns_inputs = .TRUE.
    1774            4 :       features%owns_static_tensors = .TRUE.
    1775            4 :       features%uses_collapsed_rks_dynamic = parent%uses_collapsed_rks_dynamic
    1776              :       ! NOTE(review): narrow args look like (source, dim, zero-based start, length, result) -- confirm.
    1777              :       CALL torch_tensor_narrow(cached_layout%chunk_grid_coords_t, 0, row_begin - 1, &
    1778            4 :                                row_count, features%grid_coords_t)
    1779              :       CALL torch_tensor_narrow(cached_layout%chunk_grid_weights_t, 0, row_begin - 1, &
    1780            4 :                                row_count, features%grid_weights_t)
    1781              :       CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_weights_t, 0, row_begin - 1, &
    1782            4 :                                row_count, features%atomic_grid_weights_t)
    1783              :       CALL torch_tensor_narrow(cached_layout%chunk_atomic_grid_sizes_t, 0, atom_begin - 1, &
    1784            4 :                                atom_count, features%atomic_grid_sizes_t)
    1785              :       CALL torch_tensor_narrow(cached_layout%chunk_coarse_0_atomic_coords_t, 0, &
    1786            4 :                                atom_begin - 1, atom_count, features%coarse_0_atomic_coords_t)
    1787              :       CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
    1788            4 :                                    features%atomic_grid_size_bound_shape)
    1789            4 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)
    1790              :       CALL torch_tensor_narrow(parent%density_t, 1, row_begin - 1, row_count, &
    1791            4 :                                features%density_t)
    1792            4 :       CALL torch_tensor_narrow(parent%grad_t, 2, row_begin - 1, row_count, features%grad_t)
    1793            4 :       CALL torch_tensor_narrow(parent%kin_t, 1, row_begin - 1, row_count, features%kin_t)
    1794            4 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1795            4 :          CALL torch_tensor_expand_dim(features%density_t, 0, 2, features%density_input_t)
    1796            4 :          CALL torch_tensor_expand_dim(features%grad_t, 0, 2, features%grad_input_t)
    1797            4 :          CALL torch_tensor_expand_dim(features%kin_t, 0, 2, features%kin_input_t)
    1798              :       END IF
    1799              :       ! Collapsed RKS inserts the *_input_t tensors below, otherwise the narrowed tensors directly.
    1800            4 :       CALL torch_dict_create(features%inputs)
    1801            4 :       CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    1802            4 :       CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    1803              :       CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    1804            4 :                              features%atomic_grid_weights_t)
    1805              :       CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    1806            4 :                              features%atomic_grid_sizes_t)
    1807              :       CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    1808            4 :                              features%atomic_grid_size_bound_shape_t)
    1809            4 :       IF (features%uses_collapsed_rks_dynamic) THEN
    1810            4 :          CALL torch_dict_insert(features%inputs, "density", features%density_input_t)
    1811            4 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_input_t)
    1812            4 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_input_t)
    1813              :       ELSE
    1814            0 :          CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1815            0 :          CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1816            0 :          CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1817              :       END IF
    1818              :       CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1819            4 :                              features%coarse_0_atomic_coords_t)
    1820              : 
    1821            4 :    END SUBROUTINE add_subchunk_feature_tensors
    1822              : 
    1823              : ! **************************************************************************************************
    1824              : !> \brief Insert owned subchunk arrays into a Torch dictionary.
    1825              : !> \param features bundle whose chunk and grid arrays must already be allocated (asserted)
    1826              : !> \param requires_grad forwarded to torch_tensor_to_device_leaf for density, grad and kin only
    1827              : ! **************************************************************************************************
    1828            0 :    SUBROUTINE add_owned_feature_tensors(features, requires_grad)
    1829              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1830              :       LOGICAL, INTENT(IN)                                :: requires_grad
    1831              : 
    1832            0 :       CPASSERT(ALLOCATED(features%chunk_density))
    1833            0 :       CPASSERT(ALLOCATED(features%chunk_grad))
    1834            0 :       CPASSERT(ALLOCATED(features%chunk_kin))
    1835            0 :       CPASSERT(ALLOCATED(features%grid_coords))
    1836            0 :       CPASSERT(ALLOCATED(features%grid_weights))
    1837            0 :       CPASSERT(ALLOCATED(features%atomic_grid_weights))
    1838            0 :       CPASSERT(ALLOCATED(features%atomic_grid_sizes))
    1839            0 :       CPASSERT(ALLOCATED(features%atomic_grid_size_bound_shape))
    1840            0 :       CPASSERT(ALLOCATED(features%coarse_0_atomic_coords))
    1841              : 
    1842            0 :       features%owns_coordinate_tensor = .FALSE.
    1843            0 :       features%owns_dynamic_tensors = .TRUE.
    1844            0 :       features%owns_inputs = .TRUE.
    1845            0 :       features%owns_static_tensors = .TRUE.
    1846              :       ! Grid and atom tensors are made leaves with .FALSE.; density, grad and kin use requires_grad.
    1847            0 :       CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
    1848            0 :       CALL torch_tensor_to_device_leaf(features%grid_coords_t, .FALSE.)
    1849            0 :       CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
    1850            0 :       CALL torch_tensor_to_device_leaf(features%grid_weights_t, .FALSE.)
    1851            0 :       CALL torch_tensor_from_array(features%atomic_grid_weights_t, features%atomic_grid_weights)
    1852            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .FALSE.)
    1853            0 :       CALL torch_tensor_from_array(features%atomic_grid_sizes_t, features%atomic_grid_sizes)
    1854            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_sizes_t, .FALSE.)
    1855              :       CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
    1856            0 :                                    features%coarse_0_atomic_coords)
    1857            0 :       CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .FALSE.)
    1858              :       CALL torch_tensor_from_array(features%atomic_grid_size_bound_shape_t, &
    1859            0 :                                    features%atomic_grid_size_bound_shape)
    1860            0 :       CALL torch_tensor_to_device_leaf(features%atomic_grid_size_bound_shape_t, .FALSE.)
    1861            0 :       CALL torch_tensor_from_array(features%density_t, features%chunk_density)
    1862            0 :       CALL torch_tensor_to_device_leaf(features%density_t, requires_grad)
    1863            0 :       CALL torch_tensor_from_array(features%grad_t, features%chunk_grad)
    1864            0 :       CALL torch_tensor_to_device_leaf(features%grad_t, requires_grad)
    1865            0 :       CALL torch_tensor_from_array(features%kin_t, features%chunk_kin)
    1866            0 :       CALL torch_tensor_to_device_leaf(features%kin_t, requires_grad)
    1867              :       ! All tensors created above are registered in a freshly created dictionary.
    1868            0 :       CALL torch_dict_create(features%inputs)
    1869            0 :       CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    1870            0 :       CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    1871              :       CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    1872            0 :                              features%atomic_grid_weights_t)
    1873              :       CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    1874            0 :                              features%atomic_grid_sizes_t)
    1875              :       CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    1876            0 :                              features%atomic_grid_size_bound_shape_t)
    1877            0 :       CALL torch_dict_insert(features%inputs, "density", features%density_t)
    1878            0 :       CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    1879            0 :       CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    1880              :       CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    1881            0 :                              features%coarse_0_atomic_coords_t)
    1882              : 
    1883            0 :    END SUBROUTINE add_owned_feature_tensors
    1884              : 
    1885              : ! **************************************************************************************************
    1886              : !> \brief Insert all SKALA feature tensors into the Torch dictionary.
    1887              : !> \param features ...
    1888              : !> \param requires_grad ...
    1889              : !> \param requires_coordinate_grad ...
    1890              : !> \param requires_stress_grad ...
    1891              : !> \param use_atom_chunks ...
    1892              : !> \param requires_weight_grad ...
    1893              : ! **************************************************************************************************
    1894          144 :    SUBROUTINE add_feature_tensors(features, requires_grad, requires_coordinate_grad, &
    1895              :                                   requires_stress_grad, use_atom_chunks, requires_weight_grad)
    1896              :       TYPE(skala_gpw_feature_type), INTENT(INOUT)        :: features
    1897              :       LOGICAL, INTENT(IN)                                :: requires_grad, requires_coordinate_grad, &
    1898              :                                                             requires_stress_grad, use_atom_chunks
    1899              :       LOGICAL, INTENT(IN), OPTIONAL                      :: requires_weight_grad
    1900              : 
    1901              :       LOGICAL                                            :: my_requires_weight_grad
    1902              : 
    1903          144 :       my_requires_weight_grad = .FALSE.
    1904          144 :       IF (PRESENT(requires_weight_grad)) my_requires_weight_grad = requires_weight_grad
    1905              : 
    1906          144 :       CPASSERT(cached_layout%static_tensors_active)
    1907          144 :       features%owns_static_tensors = .FALSE.
    1908          144 :       features%owns_coordinate_tensor = .FALSE.
    1909          144 :       features%owns_grid_coordinate_tensor = .FALSE.
    1910          144 :       features%owns_weight_tensors = .FALSE.
    1911          144 :       features%owns_dynamic_tensors = .FALSE.
    1912          144 :       features%owns_inputs = .TRUE.
    1913          144 :       IF (use_atom_chunks) THEN
    1914            2 :          IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
    1915              :             CALL cp_abort(__LOCATION__, &
    1916              :                           "Native SKALA analytical stress/SMOOTH derivatives are not implemented "// &
    1917            0 :                           "with atom-chunk tensors yet.")
    1918              :          END IF
    1919            2 :          CPASSERT(cached_layout%chunk_static_tensors_active)
    1920            2 :          features%grid_coords_t = cached_layout%chunk_grid_coords_t
    1921            2 :          features%grid_weights_t = cached_layout%chunk_grid_weights_t
    1922            2 :          features%atomic_grid_weights_t = cached_layout%chunk_atomic_grid_weights_t
    1923            2 :          features%atomic_grid_sizes_t = cached_layout%chunk_atomic_grid_sizes_t
    1924              :          features%atomic_grid_size_bound_shape_t = &
    1925            2 :             cached_layout%chunk_atomic_grid_size_bound_shape_t
    1926            2 :          features%local_feature_indices_t = cached_layout%chunk_feature_indices_t
    1927              : 
    1928            2 :          IF (cached_layout%chunk_inputs_active .AND. &
    1929              :              (cached_layout%chunk_inputs_use_collapsed_rks .NEQV. &
    1930              :               features%uses_collapsed_rks_dynamic)) THEN
    1931            0 :             CALL torch_dict_release(cached_layout%chunk_inputs)
    1932            0 :             cached_layout%chunk_inputs_active = .FALSE.
    1933              :          END IF
    1934            2 :          IF (.NOT. features%uses_collapsed_rks_dynamic .AND. &
    1935              :              cached_layout%chunk_dynamic_input_views_active) THEN
    1936            0 :             CALL torch_tensor_release(cached_layout%chunk_density_input_t)
    1937            0 :             CALL torch_tensor_release(cached_layout%chunk_grad_input_t)
    1938            0 :             CALL torch_tensor_release(cached_layout%chunk_kin_input_t)
    1939            0 :             cached_layout%chunk_dynamic_input_views_active = .FALSE.
    1940              :          END IF
    1941              : 
    1942              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_density_t, &
    1943            2 :                                             features%chunk_density, requires_grad=requires_grad)
    1944            2 :          features%density_t = cached_layout%chunk_density_t
    1945              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_grad_t, features%chunk_grad, &
    1946            2 :                                             requires_grad=requires_grad)
    1947            2 :          features%grad_t = cached_layout%chunk_grad_t
    1948              :          CALL torch_tensor_reset_from_array(cached_layout%chunk_kin_t, features%chunk_kin, &
    1949            2 :                                             requires_grad=requires_grad)
    1950            2 :          features%kin_t = cached_layout%chunk_kin_t
    1951            2 :          cached_layout%chunk_dynamic_tensors_active = .TRUE.
    1952              : 
    1953            2 :          IF (features%uses_collapsed_rks_dynamic .AND. &
    1954              :              .NOT. cached_layout%chunk_dynamic_input_views_active) THEN
    1955              :             CALL torch_tensor_expand_dim(cached_layout%chunk_density_t, 0, 2, &
    1956            2 :                                          cached_layout%chunk_density_input_t)
    1957              :             CALL torch_tensor_expand_dim(cached_layout%chunk_grad_t, 0, 2, &
    1958            2 :                                          cached_layout%chunk_grad_input_t)
    1959              :             CALL torch_tensor_expand_dim(cached_layout%chunk_kin_t, 0, 2, &
    1960            2 :                                          cached_layout%chunk_kin_input_t)
    1961            2 :             cached_layout%chunk_dynamic_input_views_active = .TRUE.
    1962              :          END IF
    1963            2 :          IF (features%uses_collapsed_rks_dynamic) THEN
    1964            2 :             features%density_input_t = cached_layout%chunk_density_input_t
    1965            2 :             features%grad_input_t = cached_layout%chunk_grad_input_t
    1966            2 :             features%kin_input_t = cached_layout%chunk_kin_input_t
    1967              :          END IF
    1968              : 
    1969            2 :          IF (.NOT. cached_layout%chunk_inputs_active) THEN
    1970            2 :             CALL torch_dict_clone(cached_layout%chunk_static_inputs, cached_layout%chunk_inputs)
    1971            2 :             IF (features%uses_collapsed_rks_dynamic) THEN
    1972              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
    1973            2 :                                       features%density_input_t)
    1974              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
    1975            2 :                                       features%grad_input_t)
    1976              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
    1977            2 :                                       features%kin_input_t)
    1978              :             ELSE
    1979              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "density", &
    1980            0 :                                       cached_layout%chunk_density_t)
    1981              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "grad", &
    1982            0 :                                       cached_layout%chunk_grad_t)
    1983              :                CALL torch_dict_insert(cached_layout%chunk_inputs, "kin", &
    1984            0 :                                       cached_layout%chunk_kin_t)
    1985              :             END IF
    1986              :             CALL torch_dict_insert(cached_layout%chunk_inputs, "coarse_0_atomic_coords", &
    1987            2 :                                    cached_layout%chunk_coarse_0_atomic_coords_t)
    1988            2 :             cached_layout%chunk_inputs_use_collapsed_rks = features%uses_collapsed_rks_dynamic
    1989            2 :             cached_layout%chunk_inputs_active = .TRUE.
    1990              :          END IF
    1991            2 :          features%inputs = cached_layout%chunk_inputs
    1992            2 :          features%owns_inputs = .FALSE.
    1993            2 :          features%coarse_0_atomic_coords_t = cached_layout%chunk_coarse_0_atomic_coords_t
    1994              :       ELSE
    1995          142 :          IF (.NOT. requires_stress_grad .AND. .NOT. my_requires_weight_grad) THEN
    1996          132 :             features%grid_coords_t = cached_layout%grid_coords_t
    1997          132 :             features%grid_weights_t = cached_layout%grid_weights_t
    1998          132 :             features%atomic_grid_weights_t = cached_layout%atomic_grid_weights_t
    1999              :          END IF
    2000          142 :          features%atomic_grid_sizes_t = cached_layout%atomic_grid_sizes_t
    2001          142 :          features%atomic_grid_size_bound_shape_t = cached_layout%atomic_grid_size_bound_shape_t
    2002          142 :          features%local_feature_indices_t = cached_layout%local_feature_indices_t
    2003              : 
    2004              :          CALL torch_tensor_reset_from_array(cached_layout%density_t, features%density, &
    2005          142 :                                             requires_grad=requires_grad)
    2006          142 :          features%density_t = cached_layout%density_t
    2007              :          CALL torch_tensor_reset_from_array(cached_layout%grad_t, features%grad, &
    2008          142 :                                             requires_grad=requires_grad)
    2009          142 :          features%grad_t = cached_layout%grad_t
    2010              :          CALL torch_tensor_reset_from_array(cached_layout%kin_t, features%kin, &
    2011          142 :                                             requires_grad=requires_grad)
    2012          142 :          features%kin_t = cached_layout%kin_t
    2013          142 :          cached_layout%dynamic_tensors_active = .TRUE.
    2014              : 
    2015          142 :          IF (requires_coordinate_grad .OR. requires_stress_grad .OR. my_requires_weight_grad) THEN
    2016           16 :             IF (requires_stress_grad .OR. my_requires_weight_grad) THEN
    2017           10 :                CALL torch_dict_create(features%inputs)
    2018           10 :                IF (requires_stress_grad) THEN
    2019            8 :                   CALL torch_tensor_from_array(features%grid_coords_t, features%grid_coords)
    2020            8 :                   CALL torch_tensor_to_device_leaf(features%grid_coords_t, .TRUE.)
    2021            8 :                   CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    2022            8 :                   features%owns_grid_coordinate_tensor = .TRUE.
    2023              :                ELSE
    2024            2 :                   features%grid_coords_t = cached_layout%grid_coords_t
    2025            2 :                   CALL torch_dict_insert(features%inputs, "grid_coords", features%grid_coords_t)
    2026              :                END IF
    2027           10 :                CALL torch_tensor_from_array(features%grid_weights_t, features%grid_weights)
    2028           10 :                CALL torch_tensor_to_device_leaf(features%grid_weights_t, .TRUE.)
    2029              :                CALL torch_tensor_from_array(features%atomic_grid_weights_t, &
    2030           10 :                                             features%atomic_grid_weights)
    2031           10 :                CALL torch_tensor_to_device_leaf(features%atomic_grid_weights_t, .TRUE.)
    2032           10 :                CALL torch_dict_insert(features%inputs, "grid_weights", features%grid_weights_t)
    2033              :                CALL torch_dict_insert(features%inputs, "atomic_grid_weights", &
    2034           10 :                                       features%atomic_grid_weights_t)
    2035              :                CALL torch_dict_insert(features%inputs, "atomic_grid_sizes", &
    2036           10 :                                       features%atomic_grid_sizes_t)
    2037              :                CALL torch_dict_insert(features%inputs, "atomic_grid_size_bound_shape", &
    2038           10 :                                       features%atomic_grid_size_bound_shape_t)
    2039           10 :                features%owns_weight_tensors = .TRUE.
    2040              :             ELSE
    2041            6 :                CALL torch_dict_clone(cached_layout%static_inputs, features%inputs)
    2042              :             END IF
    2043           16 :             CALL torch_dict_insert(features%inputs, "density", features%density_t)
    2044           16 :             CALL torch_dict_insert(features%inputs, "grad", features%grad_t)
    2045           16 :             CALL torch_dict_insert(features%inputs, "kin", features%kin_t)
    2046              :          ELSE
    2047          126 :             IF (.NOT. cached_layout%inputs_active) THEN
    2048           50 :                CALL torch_dict_clone(cached_layout%static_inputs, cached_layout%inputs)
    2049           50 :                CALL torch_dict_insert(cached_layout%inputs, "density", cached_layout%density_t)
    2050           50 :                CALL torch_dict_insert(cached_layout%inputs, "grad", cached_layout%grad_t)
    2051           50 :                CALL torch_dict_insert(cached_layout%inputs, "kin", cached_layout%kin_t)
    2052              :                CALL torch_dict_insert(cached_layout%inputs, "coarse_0_atomic_coords", &
    2053           50 :                                       cached_layout%coarse_0_atomic_coords_t)
    2054           50 :                cached_layout%inputs_active = .TRUE.
    2055              :             END IF
    2056          126 :             features%inputs = cached_layout%inputs
    2057          126 :             features%owns_inputs = .FALSE.
    2058          126 :             features%coarse_0_atomic_coords_t = cached_layout%coarse_0_atomic_coords_t
    2059              :          END IF
    2060              :       END IF
    2061              : 
    2062          144 :       IF (requires_coordinate_grad .OR. requires_stress_grad) THEN
    2063           16 :          CPASSERT(.NOT. use_atom_chunks)
    2064              :          CALL torch_tensor_from_array(features%coarse_0_atomic_coords_t, &
    2065           16 :                                       features%coarse_0_atomic_coords)
    2066           16 :          CALL torch_tensor_to_device_leaf(features%coarse_0_atomic_coords_t, .TRUE.)
    2067              :          CALL torch_dict_insert(features%inputs, "coarse_0_atomic_coords", &
    2068           16 :                                 features%coarse_0_atomic_coords_t)
    2069           16 :          features%owns_coordinate_tensor = .TRUE.
    2070              :       END IF
    2071              : 
    2072          144 :    END SUBROUTINE add_feature_tensors
    2073              : 
    ! **************************************************************************************************
    !> \brief Convert the integer indices of a regular GPW grid point into a Cartesian position.
    !> \param pw_grid grid descriptor; only bounds(1, :) and the 3x3 step matrix dh are read
    !> \param index grid indices of the requested point, one per grid direction
    !> \return sum over the three directions of (index - bounds(1, :)) times the matching column of dh
    ! **************************************************************************************************
    FUNCTION grid_coordinate(pw_grid, index) RESULT(coord)
       TYPE(pw_grid_type), POINTER                        :: pw_grid
       INTEGER, DIMENSION(3), INTENT(IN)                  :: index
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       REAL(KIND=dp), DIMENSION(3)                        :: steps

       ! Number of grid steps separating the point from the first bound, per direction.
       steps = REAL(index - pw_grid%bounds(1, :), KIND=dp)

       ! Each column of dh is the Cartesian displacement of one step along that direction.
       ASSOCIATE (dh => pw_grid%dh)
          coord = steps(1)*dh(:, 1) + &
                  steps(2)*dh(:, 2) + &
                  steps(3)*dh(:, 3)
       END ASSOCIATE

    END FUNCTION grid_coordinate
    2093              : 
    ! **************************************************************************************************
    !> \brief Evaluate Becke-like smooth atom weights at a single native-grid point.
    !> \param grid_point Cartesian position of the grid point
    !> \param atom_coords atom positions, one column per atom
    !> \param cell simulation cell used to select periodic images
    !> \param weights normalized atom weights at the grid point (they sum to one)
    !> \param atom_image_coords per atom, the result of nearest_image_coordinate for (atom, grid_point)
    !> \param distances per atom, distance between the grid point and the atom image used for the weights
    !> \note If the product of pair switches vanishes for every atom, the whole weight is given to the
    !>       atom with the smallest distance (the first one in case of a tie).
    ! **************************************************************************************************
    SUBROUTINE smooth_atom_partition(grid_point, atom_coords, cell, weights, atom_image_coords, &
                                     distances)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
       REAL(KIND=dp), DIMENSION(:, :), INTENT(OUT)        :: atom_image_coords
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: distances

       INTEGER                                            :: closest, ia, ib, n_atoms
       REAL(KIND=dp)                                      :: cell_fn, mu, pair_len, shortest, &
                                                             weight_sum
       REAL(KIND=dp), DIMENSION(3)                        :: sep
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: images

       n_atoms = SIZE(atom_coords, 2)
       CPASSERT(SIZE(weights) == n_atoms)
       CPASSERT(SIZE(atom_image_coords, 1) == 3)
       CPASSERT(SIZE(atom_image_coords, 2) == n_atoms)
       CPASSERT(SIZE(distances) == n_atoms)

       ! Periodic images and point-atom distances.
       DO ia = 1, n_atoms
          atom_image_coords(:, ia) = &
             nearest_image_coordinate(atom_coords(:, ia), grid_point, cell)
          images(:, ia) = &
             nearest_atom_image_coordinate(atom_coords(:, ia), grid_point, cell)
          sep = grid_point - images(:, ia)
          distances(ia) = SQRT(SUM(sep**2))
       END DO

       ! Every atom pair contributes one switch factor to each of its two members.
       weights = 1.0_dp
       DO ia = 1, n_atoms - 1
          DO ib = ia + 1, n_atoms
             sep = images(:, ia) - images(:, ib)
             pair_len = SQRT(SUM(sep**2))
             ! Coinciding images define no dividing surface: leave both weights untouched.
             IF (pair_len <= layout_tol) CYCLE
             mu = (distances(ia) - distances(ib))/pair_len
             mu = MAX(-1.0_dp, MIN(1.0_dp, mu))
             cell_fn = 0.5_dp*(1.0_dp - becke_shape(mu))
             weights(ia) = weights(ia)*cell_fn
             weights(ib) = weights(ib)*(1.0_dp - cell_fn)
          END DO
       END DO

       weight_sum = SUM(weights)
       IF (weight_sum > 0.0_dp) THEN
          weights = weights/weight_sum
          RETURN
       END IF

       ! Degenerate case: no atom kept a positive weight, so the nearest atom owns the point.
       shortest = HUGE(1.0_dp)
       closest = 1
       DO ia = 1, n_atoms
          IF (distances(ia) < shortest) THEN
             shortest = distances(ia)
             closest = ia
          END IF
       END DO
       weights = 0.0_dp
       weights(closest) = 1.0_dp

    END SUBROUTINE smooth_atom_partition
    2163              : 
    ! **************************************************************************************************
    !> \brief Evaluate the smooth atom weights at one grid point together with their derivatives with
    !>        respect to the atom positions and to a cell deformation.
    !> \param grid_point Cartesian position of the grid point
    !> \param atom_coords atom positions, one column per atom
    !> \param cell simulation cell used to select the periodic atom images
    !> \param weights weights renormalized over the included atoms; zero for all other atoms
    !> \param included .TRUE. for atoms whose share of the raw weight exceeds smooth_partition_eps
    !> \param dweights_datom dweights_datom(:, j, i) is the derivative of weights(i) with respect to
    !>                       the position of atom j
    !> \param dweights_dstrain dweights_dstrain(:, :, i) is the derivative of weights(i) with respect to
    !>                         the 3x3 deformation
    !> \note If no atom keeps a positive weight, the nearest atom (first one in case of a tie) receives
    !>       weight one and all derivatives are returned as zero.
    ! **************************************************************************************************
    SUBROUTINE skala_gpw_smooth_partition_derivatives(grid_point, atom_coords, cell, &
                                                      weights, included, dweights_datom, &
                                                      dweights_dstrain)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(:), INTENT(OUT)           :: weights
       LOGICAL, DIMENSION(:), INTENT(OUT)                 :: included
       REAL(KIND=dp), DIMENSION(:, :, :), INTENT(OUT)     :: dweights_datom, dweights_dstrain

       INTEGER                                            :: closest, ia, ib, ic, jc, n_atoms
       LOGICAL                                            :: single_owner
       REAL(KIND=dp)                                      :: cell_a, cell_b, dcell_dmu, gap, &
                                                             kept_sum, mu, mu_unclamped, pair_len, &
                                                             raw_sum, shortest
       REAL(KIND=dp), DIMENSION(3)                        :: dcell_da, dcell_db, dmu_da, dmu_db, &
                                                             pair_dir, pair_vec
       REAL(KIND=dp), DIMENSION(3, 3)                     :: avg_strain, dcell_strain, dmu_strain
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2), &
          SIZE(atom_coords, 2))                           :: dlog_atom
       REAL(KIND=dp), DIMENSION(3, SIZE(atom_coords, 2))  :: avg_atom, images, to_point, &
                                                             to_point_dir
       REAL(KIND=dp), &
          DIMENSION(3, 3, SIZE(atom_coords, 2))           :: dlog_strain
       REAL(KIND=dp), DIMENSION(SIZE(atom_coords, 2))     :: distances, products

       n_atoms = SIZE(atom_coords, 2)
       CPASSERT(SIZE(weights) == n_atoms)
       CPASSERT(SIZE(included) == n_atoms)
       CPASSERT(SIZE(dweights_datom, 1) == 3)
       CPASSERT(SIZE(dweights_datom, 2) == n_atoms)
       CPASSERT(SIZE(dweights_datom, 3) == n_atoms)
       CPASSERT(SIZE(dweights_dstrain, 1) == 3)
       CPASSERT(SIZE(dweights_dstrain, 2) == 3)
       CPASSERT(SIZE(dweights_dstrain, 3) == n_atoms)

       weights = 0.0_dp
       included = .FALSE.
       dweights_datom = 0.0_dp
       dweights_dstrain = 0.0_dp
       products = 1.0_dp
       dlog_atom = 0.0_dp
       dlog_strain = 0.0_dp

       ! Atom images, vectors from each image to the grid point, and their unit directions.
       DO ia = 1, n_atoms
          images(:, ia) = &
             nearest_atom_image_coordinate(atom_coords(:, ia), grid_point, cell)
          to_point(:, ia) = grid_point - images(:, ia)
          distances(ia) = SQRT(SUM(to_point(:, ia)**2))
          IF (distances(ia) > layout_tol) THEN
             to_point_dir(:, ia) = to_point(:, ia)/distances(ia)
          ELSE
             ! The direction is undefined when the point sits on the atom.
             to_point_dir(:, ia) = 0.0_dp
          END IF
       END DO

       ! Accumulate the pair switch products and the derivatives of their logarithms:
       !   dlog_atom(:, k, a)  = d ln(products(a)) / d (position of atom k)
       !   dlog_strain(:, :, a) = d ln(products(a)) / d (deformation)
       DO ia = 1, n_atoms - 1
          DO ib = ia + 1, n_atoms
             pair_vec = images(:, ia) - images(:, ib)
             pair_len = SQRT(SUM(pair_vec**2))
             IF (pair_len <= layout_tol) CYCLE
             pair_dir = pair_vec/pair_len
             gap = distances(ia) - distances(ib)
             mu_unclamped = gap/pair_len
             mu = MAX(-1.0_dp, MIN(1.0_dp, mu_unclamped))
             cell_a = 0.5_dp*(1.0_dp - becke_shape(mu))
             cell_b = 1.0_dp - cell_a

             ! Where mu was clamped the switch is treated as constant.
             dcell_dmu = 0.0_dp
             IF (ABS(mu_unclamped) < 1.0_dp) dcell_dmu = -0.5_dp*becke_shape_derivative(mu)

             ! Skip the logarithmic derivatives when either factor is too small to divide by.
             IF (ABS(dcell_dmu) > 0.0_dp .AND. cell_a > TINY(1.0_dp) .AND. &
                 cell_b > TINY(1.0_dp)) THEN
                dmu_da = (-to_point_dir(:, ia)*pair_len - gap*pair_dir)/pair_len**2
                dmu_db = (to_point_dir(:, ib)*pair_len + gap*pair_dir)/pair_len**2
                dcell_da = dcell_dmu*dmu_da
                dcell_db = dcell_dmu*dmu_db
                dlog_atom(:, ia, ia) = dlog_atom(:, ia, ia) + dcell_da/cell_a
                dlog_atom(:, ia, ib) = dlog_atom(:, ia, ib) - dcell_da/cell_b
                dlog_atom(:, ib, ia) = dlog_atom(:, ib, ia) + dcell_db/cell_a
                dlog_atom(:, ib, ib) = dlog_atom(:, ib, ib) - dcell_db/cell_b

                DO jc = 1, 3
                   DO ic = 1, 3
                      dmu_strain(ic, jc) = &
                         ((to_point_dir(ic, ia)*to_point(jc, ia) - &
                           to_point_dir(ic, ib)*to_point(jc, ib))*pair_len - &
                          gap*pair_dir(ic)*pair_vec(jc))/pair_len**2
                   END DO
                END DO
                dcell_strain = dcell_dmu*dmu_strain
                dlog_strain(:, :, ia) = dlog_strain(:, :, ia) + dcell_strain/cell_a
                dlog_strain(:, :, ib) = dlog_strain(:, :, ib) - dcell_strain/cell_b
             END IF

             products(ia) = products(ia)*cell_a
             products(ib) = products(ib)*cell_b
          END DO
       END DO

       ! Decide whether a proper partition exists or a single atom has to own the point.
       raw_sum = SUM(products)
       single_owner = .NOT. (raw_sum > 0.0_dp)
       IF (.NOT. single_owner) THEN
          included = (products/raw_sum > smooth_partition_eps)
          kept_sum = SUM(products, MASK=included)
          single_owner = (kept_sum <= 0.0_dp)
       END IF

       IF (single_owner) THEN
          shortest = HUGE(1.0_dp)
          closest = 1
          DO ia = 1, n_atoms
             IF (distances(ia) < shortest) THEN
                shortest = distances(ia)
                closest = ia
             END IF
          END DO
          included = .FALSE.
          included(closest) = .TRUE.
          weights = 0.0_dp
          weights(closest) = 1.0_dp
          RETURN
       END IF

       ! Renormalize over the included atoms only; all other weights stay zero.
       WHERE (included) weights = products/kept_sum

       ! Weighted averages of the logarithmic derivatives (derivative of the normalization).
       avg_atom = 0.0_dp
       avg_strain = 0.0_dp
       DO ia = 1, n_atoms
          IF (.NOT. included(ia)) CYCLE
          avg_strain = avg_strain + weights(ia)*dlog_strain(:, :, ia)
          avg_atom = avg_atom + weights(ia)*dlog_atom(:, :, ia)
       END DO

       ! d w_a = w_a * (d ln(products(a)) - weighted average of d ln(products)).
       DO ia = 1, n_atoms
          IF (.NOT. included(ia)) CYCLE
          dweights_dstrain(:, :, ia) = weights(ia)*(dlog_strain(:, :, ia) - avg_strain)
          dweights_datom(:, :, ia) = weights(ia)*(dlog_atom(:, :, ia) - avg_atom)
       END DO

    END SUBROUTINE skala_gpw_smooth_partition_derivatives
    2343              : 
    ! **************************************************************************************************
    !> \brief Becke fuzzy-cell shape function: the polynomial p(x) = x*(3 - x**2)/2 applied three
    !>        times in succession.
    !> \param mu argument of the shape function; callers in this module clamp it to [-1, 1]
    !> \return p(p(p(mu)))
    ! **************************************************************************************************
    PURE FUNCTION becke_shape(mu) RESULT(val)
       REAL(KIND=dp), INTENT(IN)                          :: mu
       REAL(KIND=dp)                                      :: val

       INTEGER, PARAMETER                                 :: n_smoothing = 3

       INTEGER                                            :: istep

       val = mu
       DO istep = 1, n_smoothing
          val = 0.5_dp*val*(3.0_dp - val*val)
       END DO

    END FUNCTION becke_shape
    2361              : 
    ! **************************************************************************************************
    !> \brief Derivative of the Becke fuzzy-cell shape function, i.e. of the polynomial
    !>        p(x) = x*(3 - x**2)/2 applied three times in succession.
    !> \param mu point at which the derivative is evaluated
    !> \return chain-rule product of p'(x) = 1.5*(1 - x**2) taken at mu, p(mu) and p(p(mu))
    ! **************************************************************************************************
    PURE FUNCTION becke_shape_derivative(mu) RESULT(val)
       REAL(KIND=dp), INTENT(IN)                          :: mu
       REAL(KIND=dp)                                      :: val

       INTEGER, PARAMETER                                 :: n_smoothing = 3

       INTEGER                                            :: istep
       REAL(KIND=dp)                                      :: arg

       arg = mu
       val = 1.0_dp
       DO istep = 1, n_smoothing
          ! Multiply by p'(arg) first, then advance arg to the next nesting level.
          val = val*1.5_dp*(1.0_dp - arg*arg)
          arg = 0.5_dp*arg*(3.0_dp - arg*arg)
       END DO

    END FUNCTION becke_shape_derivative
    2382              : 
    ! **************************************************************************************************
    !> \brief Return the periodic image of an atom that lies nearest to a regular-grid point.
    !> \param atom_coord Cartesian position of the atom
    !> \param grid_point Cartesian position of the grid point
    !> \param cell simulation cell defining the periodicity
    !> \return grid_point plus the periodically wrapped vector pointing from grid_point to atom_coord
    !> \note This is the same operation as nearest_image_coordinate with the roles of the two points
    !>       exchanged (the reference point is the grid point, the wrapped point is the atom).
    ! **************************************************************************************************
    FUNCTION nearest_atom_image_coordinate(atom_coord, grid_point, cell) RESULT(coord)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: atom_coord, grid_point
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       coord = nearest_image_coordinate(grid_point, atom_coord, cell)

    END FUNCTION nearest_atom_image_coordinate
    2410              : 
    ! **************************************************************************************************
    !> \brief Return the periodic image of a grid point that lies nearest to its owning atom.
    !> \param owner_coord Cartesian position of the owning atom (reference point)
    !> \param grid_point Cartesian position of the grid point to be wrapped
    !> \param cell simulation cell defining the periodicity
    !> \return owner_coord plus the periodically wrapped vector pointing from owner_coord to grid_point
    ! **************************************************************************************************
    FUNCTION nearest_image_coordinate(owner_coord, grid_point, cell) RESULT(coord)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: owner_coord, grid_point
       TYPE(cell_type), POINTER                           :: cell
       REAL(KIND=dp), DIMENSION(3)                        :: coord

       INTEGER                                            :: idir
       REAL(KIND=dp), DIMENSION(3)                        :: shift

       ! General cells are delegated to the generic pbc routine.
       IF (.NOT. cell%orthorhombic) THEN
          coord = owner_coord + pbc(owner_coord, grid_point, cell)
          RETURN
       END IF

       ! Orthorhombic cells: wrap every Cartesian component independently. The factor perd
       ! multiplies the correction, so a zero entry leaves that direction unwrapped.
       DO idir = 1, 3
          shift(idir) = grid_point(idir) - owner_coord(idir)
          shift(idir) = shift(idir) - &
                        cell%hmat(idir, idir)*cell%perd(idir)*ANINT(cell%h_inv(idir, idir)*shift(idir))
       END DO
       coord = owner_coord + shift

    END FUNCTION nearest_image_coordinate
    2438              : 
    ! **************************************************************************************************
    !> \brief Find the atom whose periodic image is closest to a grid point.
    !> \param grid_point Cartesian position of the grid point
    !> \param atom_coords atom positions, one column per atom
    !> \param cell simulation cell defining the periodicity
    !> \return index of the closest atom; ties go to the lowest index, and 1 is returned when
    !>         atom_coords holds no atoms
    ! **************************************************************************************************
    FUNCTION nearest_atom(grid_point, atom_coords, cell) RESULT(owner)
       REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: grid_point
       REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: atom_coords
       TYPE(cell_type), POINTER                           :: cell
       INTEGER                                            :: owner

       INTEGER                                            :: ia, idir
       LOGICAL                                            :: diagonal_cell
       REAL(KIND=dp)                                      :: dist2, dist2_min
       REAL(KIND=dp), DIMENSION(3)                        :: box, inv_box, sep

       owner = 1
       dist2_min = HUGE(1.0_dp)
       diagonal_cell = cell%orthorhombic

       IF (diagonal_cell) THEN
          ! Per-direction period (zero where perd is zero) and inverse cell length.
          DO idir = 1, 3
             box(idir) = cell%hmat(idir, idir)*cell%perd(idir)
             inv_box(idir) = cell%h_inv(idir, idir)
          END DO
       END IF

       DO ia = 1, SIZE(atom_coords, 2)
          IF (diagonal_cell) THEN
             ! Orthorhombic cells: wrap every Cartesian component independently.
             sep = grid_point - atom_coords(:, ia)
             sep = sep - box*ANINT(inv_box*sep)
             dist2 = sep(1)*sep(1) + sep(2)*sep(2) + sep(3)*sep(3)
          ELSE
             sep = pbc(grid_point, atom_coords(:, ia), cell)
             dist2 = SUM(sep**2)
          END IF
          ! Strict comparison keeps the first atom among equally distant candidates.
          IF (dist2 < dist2_min) THEN
             dist2_min = dist2
             owner = ia
          END IF
       END DO

    END FUNCTION nearest_atom
    2484              : 
    2485            0 : END MODULE skala_gpw_features
        

Generated by: LCOV version 2.0-1