LCOV - code coverage report
Current view: top level - src - skala_torch_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:561f475) Lines: 61.3 % 62 38
Test Date: 2026-06-21 06:48:54 Functions: 33.3 % 9 3

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Small CP2K wrapper around the SKALA TorchScript functional protocol.
      10              : ! **************************************************************************************************
      11              : MODULE skala_torch_api
      12              : #if defined (__HAS_IEEE_EXCEPTIONS)
      13              :    USE ieee_exceptions, ONLY: ieee_all, &
      14              :                               ieee_get_halting_mode, &
      15              :                               ieee_set_halting_mode
      16              : #endif
      17              :    USE kinds, ONLY: default_string_length, &
      18              :                     dp
      19              :    USE string_utilities, ONLY: uppercase
      20              :    USE torch_api, ONLY: &
      21              :       torch_dict_type, torch_model_forward_mol_tensor, torch_model_load, &
      22              :       torch_model_read_metadata, torch_model_release, torch_model_type, &
      23              :       torch_tensor_item_double, torch_tensor_release, torch_tensor_type, &
      24              :       torch_tensor_weighted_sum
      25              : #include "./base/base_uses.f90"
      26              : 
      27              :    IMPLICIT NONE
      28              : 
      29              :    PRIVATE
      30              : 
      31              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_torch_api'
      32              : 
      33              :    PUBLIC :: skala_torch_model_type, skala_torch_model_load, skala_torch_model_release
      34              :    PUBLIC :: skala_torch_model_get_exc, skala_torch_model_get_exc_density
      35              :    PUBLIC :: skala_torch_model_needs_feature, skala_torch_model_protocol_version
      36              : 
      37              :    TYPE skala_torch_model_type
      38              :       PRIVATE
      39              :       INTEGER                                            :: protocol_version = -1
      40              :       CHARACTER(len=default_string_length), ALLOCATABLE, &
      41              :          DIMENSION(:)                                    :: features
      42              :       TYPE(torch_model_type)                             :: torch_model
      43              :    END TYPE skala_torch_model_type
      44              : 
      45              : CONTAINS
      46              : 
      47              : ! **************************************************************************************************
      48              : !> \brief Load a SKALA TorchScript model and its feature metadata.
      49              : !> \param model ...
      50              : !> \param filename ...
      51              : ! **************************************************************************************************
      52           44 :    SUBROUTINE skala_torch_model_load(model, filename)
      53              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      54              :       CHARACTER(len=*), INTENT(IN)                       :: filename
      55              : 
      56           44 :       CHARACTER(:), ALLOCATABLE                          :: features_json, protocol_string
      57              :       INTEGER                                            :: ios
      58              : 
      59           44 :       CALL torch_model_load(model%torch_model, filename)
      60           44 :       protocol_string = torch_model_read_metadata(filename, "protocol_version")
      61           44 :       features_json = torch_model_read_metadata(filename, "features")
      62           44 :       READ (protocol_string, *, IOSTAT=ios) model%protocol_version
      63           44 :       IF (ios /= 0) CPABORT("Could not parse SKALA TorchScript protocol_version metadata")
      64           44 :       IF (model%protocol_version /= 2) THEN
      65            0 :          CPABORT("Unsupported SKALA TorchScript protocol version")
      66              :       END IF
      67              : 
      68           44 :       CALL parse_feature_list(features_json, model%features)
      69              : 
      70           44 :    END SUBROUTINE skala_torch_model_load
      71              : 
      72              : ! **************************************************************************************************
      73              : !> \brief Release a loaded SKALA TorchScript model.
      74              : !> \param model ...
      75              : ! **************************************************************************************************
      76            0 :    SUBROUTINE skala_torch_model_release(model)
      77              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      78              : 
      79            0 :       CALL torch_model_release(model%torch_model)
      80            0 :       IF (ALLOCATED(model%features)) DEALLOCATE (model%features)
      81            0 :       model%protocol_version = -1
      82              : 
      83            0 :    END SUBROUTINE skala_torch_model_release
      84              : 
      85              : ! **************************************************************************************************
      86              : !> \brief Check whether a loaded SKALA model requests a feature.
      87              : !> \param model ...
      88              : !> \param feature ...
      89              : !> \return ...
      90              : ! **************************************************************************************************
      91            0 :    FUNCTION skala_torch_model_needs_feature(model, feature) RESULT(needs_feature)
      92              :       TYPE(skala_torch_model_type), INTENT(IN)           :: model
      93              :       CHARACTER(len=*), INTENT(IN)                       :: feature
      94              :       LOGICAL                                            :: needs_feature
      95              : 
      96              :       CHARACTER(len=default_string_length)               :: feature_key, model_feature
      97              :       INTEGER                                            :: i
      98              : 
      99            0 :       feature_key = ADJUSTL(feature)
     100            0 :       CALL uppercase(feature_key)
     101              : 
     102            0 :       needs_feature = .FALSE.
     103            0 :       IF (.NOT. ALLOCATED(model%features)) RETURN
     104              : 
     105            0 :       DO i = 1, SIZE(model%features)
     106            0 :          model_feature = ADJUSTL(model%features(i))
     107            0 :          CALL uppercase(model_feature)
     108            0 :          IF (TRIM(model_feature) == TRIM(feature_key)) THEN
     109            0 :             needs_feature = .TRUE.
     110              :             RETURN
     111              :          END IF
     112              :       END DO
     113              : 
     114            0 :    END FUNCTION skala_torch_model_needs_feature
     115              : 
     116              : ! **************************************************************************************************
     117              : !> \brief Return the loaded SKALA TorchScript protocol version.
     118              : !> \param model ...
     119              : !> \return ...
     120              : ! **************************************************************************************************
     121            0 :    FUNCTION skala_torch_model_protocol_version(model) RESULT(protocol_version)
     122              :       TYPE(skala_torch_model_type), INTENT(IN)           :: model
     123              :       INTEGER                                            :: protocol_version
     124              : 
     125            0 :       protocol_version = model%protocol_version
     126              : 
     127            0 :    END FUNCTION skala_torch_model_protocol_version
     128              : 
     129              : ! **************************************************************************************************
     130              : !> \brief Evaluate the SKALA exchange-correlation energy density.
     131              : !> \param model ...
     132              : !> \param inputs ...
     133              : !> \param exc_density ...
     134              : ! **************************************************************************************************
     135            0 :    SUBROUTINE skala_torch_model_get_exc_density(model, inputs, exc_density)
     136              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
     137              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     138              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_density
     139              : 
     140              : #if defined (__HAS_IEEE_EXCEPTIONS)
     141              :       LOGICAL, DIMENSION(5)                              :: ieee_halt
     142              : 
     143              :       CALL ieee_get_halting_mode(IEEE_ALL, ieee_halt)
     144              :       CALL ieee_set_halting_mode(IEEE_ALL, .FALSE.)
     145              : #endif
     146            0 :       CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, exc_density)
     147              : #if defined (__HAS_IEEE_EXCEPTIONS)
     148              :       CALL ieee_set_halting_mode(IEEE_ALL, ieee_halt)
     149              : #endif
     150              : 
     151            0 :    END SUBROUTINE skala_torch_model_get_exc_density
     152              : 
     153              : ! **************************************************************************************************
     154              : !> \brief Evaluate the weighted SKALA exchange-correlation energy.
     155              : !> \param model ...
     156              : !> \param inputs ...
     157              : !> \param grid_weights ...
     158              : !> \param exc_tensor ...
     159              : !> \param exc ...
     160              : ! **************************************************************************************************
     161          154 :    SUBROUTINE skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
     162              :       TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
     163              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     164              :       TYPE(torch_tensor_type), INTENT(IN)                :: grid_weights
     165              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_tensor
     166              :       REAL(KIND=dp), INTENT(OUT)                         :: exc
     167              : 
     168              :       TYPE(torch_tensor_type)                            :: exc_density
     169              : 
     170              : #if defined (__HAS_IEEE_EXCEPTIONS)
     171              :       LOGICAL, DIMENSION(5)                              :: ieee_halt
     172              : 
     173              :       CALL ieee_get_halting_mode(IEEE_ALL, ieee_halt)
     174              :       CALL ieee_set_halting_mode(IEEE_ALL, .FALSE.)
     175              : #endif
     176          154 :       CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, exc_density)
     177          154 :       CALL torch_tensor_weighted_sum(exc_density, grid_weights, exc_tensor)
     178          154 :       CALL torch_tensor_release(exc_density)
     179          154 :       exc = torch_tensor_item_double(exc_tensor)
     180              : #if defined (__HAS_IEEE_EXCEPTIONS)
     181              :       CALL ieee_set_halting_mode(IEEE_ALL, ieee_halt)
     182              : #endif
     183              : 
     184          154 :    END SUBROUTINE skala_torch_model_get_exc
     185              : 
     186              : ! **************************************************************************************************
     187              : !> \brief Parse a TorchScript extra_files JSON list of feature names.
     188              : !> \param features_json ...
     189              : !> \param features ...
     190              : ! **************************************************************************************************
     191           44 :    SUBROUTINE parse_feature_list(features_json, features)
     192              :       CHARACTER(len=*), INTENT(IN)                       :: features_json
     193              :       CHARACTER(len=default_string_length), &
     194              :          ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: features
     195              : 
     196              :       INTEGER                                            :: end_pos, feature_count, i, pos, quote1, &
     197              :                                                             quote2, start_pos
     198              : 
     199           44 :       feature_count = 0
     200           44 :       pos = 1
     201          396 :       DO
     202          440 :          quote1 = INDEX(features_json(pos:), '"')
     203          440 :          IF (quote1 == 0) EXIT
     204          396 :          start_pos = pos + quote1
     205          396 :          quote2 = INDEX(features_json(start_pos:), '"')
     206          396 :          IF (quote2 == 0) EXIT
     207          396 :          feature_count = feature_count + 1
     208          396 :          pos = start_pos + quote2
     209              :       END DO
     210              : 
     211           44 :       IF (feature_count == 0) CPABORT("SKALA TorchScript model does not list any features")
     212          132 :       ALLOCATE (features(feature_count))
     213          440 :       features = ""
     214              : 
     215              :       pos = 1
     216          440 :       DO i = 1, feature_count
     217          396 :          quote1 = INDEX(features_json(pos:), '"')
     218          396 :          start_pos = pos + quote1
     219          396 :          quote2 = INDEX(features_json(start_pos:), '"')
     220          396 :          end_pos = start_pos + quote2 - 2
     221          396 :          features(i) = features_json(start_pos:end_pos)
     222          440 :          pos = start_pos + quote2
     223              :       END DO
     224              : 
     225           44 :    END SUBROUTINE parse_feature_list
     226              : 
     227            0 : END MODULE skala_torch_api
        

Generated by: LCOV version 2.0-1