LCOV - code coverage report
Current view: top level - src - torch_api.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:561f475) Lines: 97.4 % 193 188
Test Date: 2026-06-21 06:48:54 Functions: 71.7 % 60 43

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : MODULE torch_api
       8              :    USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
       9              :                             C_BOOL, &
      10              :                             C_CHAR, &
      11              :                             C_FLOAT, &
      12              :                             C_DOUBLE, &
      13              :                             C_F_POINTER, &
      14              :                             C_INT, &
      15              :                             C_NULL_CHAR, &
      16              :                             C_NULL_PTR, &
      17              :                             C_PTR, &
      18              :                             C_INT32_T, &
      19              :                             C_INT64_T
      20              : 
      21              :    USE kinds, ONLY: sp, int_4, int_8, dp, default_string_length
      22              : 
      23              : #include "./base/base_uses.f90"
      24              : 
      25              :    IMPLICIT NONE
      26              : 
      27              :    PRIVATE
      28              : 
      29              :    TYPE torch_tensor_type  ! Opaque tensor handle; its component is hidden from module users.
      30              :       PRIVATE
      31              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR  ! Null until a torch_c_* routine sets it.
      32              :    END TYPE torch_tensor_type
      33              : 
      34              :    TYPE torch_dict_type  ! Opaque dict handle; its component is hidden from module users.
      35              :       PRIVATE
      36              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR  ! Default-initialized to the null C pointer.
      37              :    END TYPE torch_dict_type
      38              : 
      39              :    TYPE torch_model_type  ! Opaque model handle; its component is hidden from module users.
      40              :       PRIVATE
      41              :       TYPE(C_PTR)                          :: c_ptr = C_NULL_PTR  ! Default-initialized to the null C pointer.
      42              :    END TYPE torch_model_type
      43              : 
      44              :    #:set max_dim = 3
      45              :    INTERFACE torch_tensor_from_array  ! Generic over int32/float/int64/double sources of rank 1 to max_dim.
      46              :       #:for ndims  in range(1, max_dim+1)
      47              :          MODULE PROCEDURE torch_tensor_from_array_int32_${ndims}$d
      48              :          MODULE PROCEDURE torch_tensor_from_array_float_${ndims}$d
      49              :          MODULE PROCEDURE torch_tensor_from_array_int64_${ndims}$d
      50              :          MODULE PROCEDURE torch_tensor_from_array_double_${ndims}$d
      51              :       #:endfor
      52              :    END INTERFACE torch_tensor_from_array
      53              : 
      54              :    INTERFACE torch_tensor_reset_from_array  ! Generic over double sources of rank 1 to max_dim.
      55              :       #:for ndims  in range(1, max_dim+1)
      56              :          MODULE PROCEDURE torch_tensor_reset_from_array_double_${ndims}$d
      57              :       #:endfor
      58              :    END INTERFACE torch_tensor_reset_from_array
      59              : 
      60              :    INTERFACE torch_tensor_data_ptr  ! Generic over int32/float/int64/double pointers of rank 1 to max_dim.
      61              :       #:for ndims  in range(1, max_dim+1)
      62              :          MODULE PROCEDURE torch_tensor_data_ptr_int32_${ndims}$d
      63              :          MODULE PROCEDURE torch_tensor_data_ptr_float_${ndims}$d
      64              :          MODULE PROCEDURE torch_tensor_data_ptr_int64_${ndims}$d
      65              :          MODULE PROCEDURE torch_tensor_data_ptr_double_${ndims}$d
      66              :       #:endfor
      67              :    END INTERFACE torch_tensor_data_ptr
      68              : 
      69              :    INTERFACE torch_model_get_attr  ! Generic name for the string, double, int64, int32 and strlist specifics.
      70              :       MODULE PROCEDURE torch_model_get_attr_string
      71              :       MODULE PROCEDURE torch_model_get_attr_double
      72              :       MODULE PROCEDURE torch_model_get_attr_int64
      73              :       MODULE PROCEDURE torch_model_get_attr_int32
      74              :       MODULE PROCEDURE torch_model_get_attr_strlist
      75              :    END INTERFACE torch_model_get_attr
      76              : 
      77              :    PUBLIC :: torch_tensor_type, torch_tensor_expand_dim, torch_tensor_from_array, &
      78              :              torch_tensor_narrow, torch_tensor_release
      79              :    PUBLIC :: torch_tensor_reset_from_array
      80              :    PUBLIC :: torch_tensor_data_ptr, torch_tensor_backward, torch_tensor_backward_scalar
      81              :    PUBLIC :: torch_tensor_grad
      82              :    PUBLIC :: torch_tensor_to_device_leaf
      83              :    PUBLIC :: torch_tensor_item_double, torch_tensor_weighted_sum
      84              :    PUBLIC :: torch_dict_type, torch_dict_clone, torch_dict_create, torch_dict_insert
      85              :    PUBLIC :: torch_dict_get, torch_dict_release
      86              :    PUBLIC :: torch_model_type, torch_model_load, torch_model_forward, torch_model_release
      87              :    PUBLIC :: torch_model_forward_mol_tensor
      88              :    PUBLIC :: torch_model_get_attr, torch_model_read_metadata
      89              :    PUBLIC :: torch_cuda_device_count, torch_cuda_is_available
      90              :    PUBLIC :: torch_allow_tf32, torch_model_freeze, torch_use_cuda
      91              : 
      92              : CONTAINS
      93              : 
      94              :    #:set typenames = ['int32', 'float', 'int64', 'double']
      95              :    #:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
      96              :    #:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']
      97              : 
      98              :    #:for ndims in range(1, max_dim+1)
      99              :       #:for typename, type_f, type_c in zip(typenames, types_f, types_c)
     100              : 
     101              : ! **************************************************************************************************
     102              : !> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
     103              : !>        The source must be an ALLOCATABLE to prevent passing a temporary array.
     104              : !> \author Ole Schuett
     105              : ! **************************************************************************************************
     106         1100 :          SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
     107              :             TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor  ! Must hold a null handle on entry.
     108              :             #:set arraydims = ", ".join(":" for i in range(ndims))
     109              :             ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
     110              :             LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad  ! Treated as .FALSE. if absent.
     111              : 
     112              : #if defined(__LIBTORCH)
     113              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c  ! Extents of source, dimension order reversed.
     114              :             LOGICAL                                            :: my_req_grad  ! Effective value of requires_grad.
     115              : 
     116              :             INTERFACE
     117              :                SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
     118              :                   BIND(C, name="torch_c_tensor_from_array_${typename}$")
     119              :                   IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
     120              :                   TYPE(C_PTR)                                  :: tensor  ! By reference: receives the new handle.
     121              :                   LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     122              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     123              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     124              :                   ${type_c}$, DIMENSION(*)                     :: source
     125              :                END SUBROUTINE torch_c_tensor_from_array_${typename}$
     126              :             END INTERFACE
     127              : 
     128         1100 :             my_req_grad = .FALSE.
     129         1100 :             IF (PRESENT(requires_grad)) my_req_grad = requires_grad
     130              : 
     131              :             #:for axis in range(ndims)
     132         1100 :                sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
     133              :             #:endfor
     134              : 
     135         1100 :             CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     136              :             CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
     137              :                                                          req_grad=LOGICAL(my_req_grad, C_BOOL), &
     138              :                                                          ndims=${ndims}$, &
     139              :                                                          sizes=sizes_c, &
     140         1100 :                                                          source=source)
     141         1100 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     142              : #else
     143              :             CPABORT("CP2K compiled without the Torch library.")
     144              :             MARK_USED(tensor)
     145              :             MARK_USED(source)
     146              :             MARK_USED(requires_grad)
     147              : #endif
     148         1100 :          END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d
     149              : 
     150              : ! **************************************************************************************************
     151              : !> \brief Returns a Fortran pointer to the data of a Torch tensor.
     152              : !>        The returned pointer is only valid during the tensor's lifetime!
     153              : !> \author Ole Schuett
     154              : ! **************************************************************************************************
     155          586 :          SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
     156              :             TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Must hold a non-null handle.
     157              :             #:set arraydims = ", ".join(":" for i in range(ndims))
     158              :             ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr  ! Must be disassociated on entry.
     159              : 
     160              : #if defined(__LIBTORCH)
     161              :             INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c  ! Fortran-order and C-order extents.
     162              :             TYPE(C_PTR)                                        :: data_ptr_c  ! Raw address reported by the C call.
     163              : 
     164              :             INTERFACE
     165              :                SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
     166              :                   BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
     167              :                   IMPORT :: C_CHAR, C_PTR, C_INT, C_INT32_T, C_INT64_T
     168              :                   TYPE(C_PTR), VALUE                           :: tensor
     169              :                   INTEGER(kind=C_INT), VALUE                   :: ndims
     170              :                   INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     171              :                   TYPE(C_PTR)                                  :: data_ptr  ! By reference: output of the C call.
     172              :                END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
     173              :             END INTERFACE
     174              : 
     175         1940 :             sizes_c(:) = -1  ! Placeholder; presumably overwritten by the C call below (confirm on the C side).
     176          586 :             data_ptr_c = C_NULL_PTR
     177          586 :             CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     178          586 :             CPASSERT(.NOT. ASSOCIATED(data_ptr))
     179              :             CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
     180              :                                                        ndims=${ndims}$, &
     181              :                                                        sizes=sizes_c, &
     182          586 :                                                        data_ptr=data_ptr_c)
     183              : 
     184              :             #:for axis in range(ndims)
     185          586 :                sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
     186              :             #:endfor
     187              : 
     188         1940 :             IF (ALL(sizes_f /= 0)) THEN  ! Torch returns null pointer for zero-sized tensors; data_ptr then stays unset.
     189          586 :                CPASSERT(C_ASSOCIATED(data_ptr_c))
     190         1940 :                CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
     191              :             END IF
     192              : #else
     193              :             CPABORT("CP2K compiled without the Torch library.")
     194              :             MARK_USED(tensor)
     195              :             MARK_USED(data_ptr)
     196              : #endif
     197          586 :          END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d
     198              : 
     199              :       #:endfor
     200              :    #:endfor
     201              : 
     202              :    #:for ndims in range(1, max_dim+1)
     203              : 
     204              : ! **************************************************************************************************
     205              : !> \brief Reuses or creates a device leaf tensor and copies data into it.
     206              : !>        The source must be an ALLOCATABLE to prevent passing a temporary array.
     207              : ! **************************************************************************************************
     208          432 :       SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d(tensor, source, requires_grad)
     209              :          TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor  ! No entry check here; must be set on exit.
     210              :          #:set arraydims = ", ".join(":" for i in range(ndims))
     211              :          REAL(dp), DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN)  :: source
     212              :          LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad  ! Treated as .FALSE. if absent.
     213              : 
     214              : #if defined(__LIBTORCH)
     215              :          INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c  ! Extents of source, dimension order reversed.
     216              :          LOGICAL                                            :: my_req_grad  ! Effective value of requires_grad.
     217              : 
     218              :          INTERFACE
     219              :             SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
     220              :                BIND(C, name="torch_c_tensor_reset_from_array_double")
     221              :                IMPORT :: C_PTR, C_INT, C_INT64_T, C_DOUBLE, C_BOOL
     222              :                TYPE(C_PTR)                                  :: tensor  ! By reference: the C side can set the handle.
     223              :                LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     224              :                INTEGER(kind=C_INT), VALUE                   :: ndims
     225              :                INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
     226              :                REAL(kind=C_DOUBLE), DIMENSION(*)            :: source
     227              :             END SUBROUTINE torch_c_tensor_reset_from_array_double
     228              :          END INTERFACE
     229              : 
     230          432 :          my_req_grad = .FALSE.
     231          432 :          IF (PRESENT(requires_grad)) my_req_grad = requires_grad
     232              : 
     233              :          #:for axis in range(ndims)
     234          432 :             sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$) ! C arrays are stored row-major.
     235              :          #:endfor
     236              : 
     237              :          CALL torch_c_tensor_reset_from_array_double(tensor=tensor%c_ptr, &
     238              :                                                      req_grad=LOGICAL(my_req_grad, C_BOOL), &
     239              :                                                      ndims=${ndims}$, &
     240              :                                                      sizes=sizes_c, &
     241          432 :                                                      source=source)
     242          432 :          CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     243              : #else
     244              :          CPABORT("CP2K compiled without the Torch library.")
     245              :          MARK_USED(tensor)
     246              :          MARK_USED(source)
     247              :          MARK_USED(requires_grad)
     248              : #endif
     249          432 :       END SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d
     250              : 
     251              :    #:endfor
     252              : 
     253              : ! **************************************************************************************************
     254              : !> \brief Creates an expanded tensor view along one singleton dimension.
     255              : ! **************************************************************************************************
     256           18 :    SUBROUTINE torch_tensor_expand_dim(tensor, dim, extent, result)
     257              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Source tensor; must hold a non-null handle.
     258              :       INTEGER, INTENT(IN)                                :: dim, extent  ! Both are asserted to be >= 0.
     259              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: result  ! Must hold a null handle on entry.
     260              : 
     261              : #if defined(__LIBTORCH)
     262              :       INTERFACE
     263              :          SUBROUTINE torch_c_tensor_expand_dim(tensor, dim, extent, result) &
     264              :             BIND(C, name="torch_c_tensor_expand_dim")
     265              :             IMPORT :: C_INT64_T, C_PTR
     266              :             TYPE(C_PTR), VALUE                           :: tensor
     267              :             INTEGER(kind=C_INT64_T), VALUE               :: dim, extent
     268              :             TYPE(C_PTR)                                  :: result  ! By reference: receives the new handle.
     269              :          END SUBROUTINE torch_c_tensor_expand_dim
     270              :       END INTERFACE
     271              : 
     272           18 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     273           18 :       CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
     274           18 :       CPASSERT(dim >= 0)
     275           18 :       CPASSERT(extent >= 0)
     276              :       CALL torch_c_tensor_expand_dim(tensor=tensor%c_ptr, &
     277              :                                      dim=INT(dim, C_INT64_T), &
     278              :                                      extent=INT(extent, C_INT64_T), &
     279           18 :                                      result=result%c_ptr)
     280           18 :       CPASSERT(C_ASSOCIATED(result%c_ptr))
     281              : #else
     282              :       CPABORT("CP2K compiled without the Torch library.")
     283              :       MARK_USED(tensor)
     284              :       MARK_USED(dim)
     285              :       MARK_USED(extent)
     286              :       MARK_USED(result)
     287              : #endif
     288           18 :    END SUBROUTINE torch_tensor_expand_dim
     289              : 
     290              : ! **************************************************************************************************
     291              : !> \brief Creates a view of a contiguous tensor slice.
     292              : ! **************************************************************************************************
     293           32 :    SUBROUTINE torch_tensor_narrow(tensor, dim, start_index, length, result)
     294              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Source tensor; must hold a non-null handle.
     295              :       INTEGER, INTENT(IN)                                :: dim, start_index, length  ! All are asserted to be >= 0.
     296              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: result  ! Must hold a null handle on entry.
     297              : 
     298              : #if defined(__LIBTORCH)
     299              :       INTERFACE
     300              :          SUBROUTINE torch_c_tensor_narrow(tensor, dim, start_index, length, result) &
     301              :             BIND(C, name="torch_c_tensor_narrow")
     302              :             IMPORT :: C_INT64_T, C_PTR
     303              :             TYPE(C_PTR), VALUE                           :: tensor
     304              :             INTEGER(kind=C_INT64_T), VALUE               :: dim, start_index, length
     305              :             TYPE(C_PTR)                                  :: result  ! By reference: receives the new handle.
     306              :          END SUBROUTINE torch_c_tensor_narrow
     307              :       END INTERFACE
     308              : 
     309           32 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     310           32 :       CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
     311           32 :       CPASSERT(dim >= 0)
     312           32 :       CPASSERT(start_index >= 0)
     313           32 :       CPASSERT(length >= 0)
     314              :       CALL torch_c_tensor_narrow(tensor=tensor%c_ptr, &
     315              :                                  dim=INT(dim, C_INT64_T), &
     316              :                                  start_index=INT(start_index, C_INT64_T), &
     317              :                                  length=INT(length, C_INT64_T), &
     318           32 :                                  result=result%c_ptr)
     319           32 :       CPASSERT(C_ASSOCIATED(result%c_ptr))
     320              : #else
     321              :       CPABORT("CP2K compiled without the Torch library.")
     322              :       MARK_USED(tensor)
     323              :       MARK_USED(dim)
     324              :       MARK_USED(start_index)
     325              :       MARK_USED(length)
     326              :       MARK_USED(result)
     327              : #endif
     328           32 :    END SUBROUTINE torch_tensor_narrow
     329              : 
     330              : ! **************************************************************************************************
     331              : !> \brief Runs autograd on a Torch tensor.
     332              : !> \author Ole Schuett
     333              : ! **************************************************************************************************
     334            6 :    SUBROUTINE torch_tensor_backward(tensor, outer_grad)
     335              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Must hold a non-null handle.
     336              :       TYPE(torch_tensor_type), INTENT(IN)                :: outer_grad  ! Must hold a non-null handle.
     337              : 
     338              : #if defined(__LIBTORCH)
     339              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_tensor_backward'
     340              :       INTEGER                                            :: handle  ! Set by timeset, passed back to timestop.
     341              : 
     342              :       INTERFACE
     343              :          SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
     344              :             BIND(C, name="torch_c_tensor_backward")
     345              :             IMPORT :: C_CHAR, C_PTR
     346              :             TYPE(C_PTR), VALUE                           :: tensor
     347              :             TYPE(C_PTR), VALUE                           :: outer_grad
     348              :          END SUBROUTINE torch_c_tensor_backward
     349              :       END INTERFACE
     350              : 
     351            6 :       CALL timeset(routineN, handle)
     352            6 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     353            6 :       CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
     354            6 :       CALL torch_c_tensor_backward(tensor=tensor%c_ptr, outer_grad=outer_grad%c_ptr)
     355            6 :       CALL timestop(handle)
     356              : #else
     357              :       CPABORT("CP2K compiled without the Torch library.")
     358              :       MARK_USED(tensor)
     359              :       MARK_USED(outer_grad)
     360              : #endif
     361            6 :    END SUBROUTINE torch_tensor_backward
     362              : 
     363              : ! **************************************************************************************************
     364              : !> \brief Runs autograd on a scalar Torch tensor.
     365              : ! **************************************************************************************************
     366          154 :    SUBROUTINE torch_tensor_backward_scalar(tensor)
     367              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Must hold a non-null handle.
     368              : 
     369              : #if defined(__LIBTORCH)
     370              :       INTERFACE
     371              :          SUBROUTINE torch_c_tensor_backward_scalar(tensor) &
     372              :             BIND(C, name="torch_c_tensor_backward_scalar")
     373              :             IMPORT :: C_PTR
     374              :             TYPE(C_PTR), VALUE                           :: tensor
     375              :          END SUBROUTINE torch_c_tensor_backward_scalar
     376              :       END INTERFACE
     377              : 
     378          154 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     379          154 :       CALL torch_c_tensor_backward_scalar(tensor=tensor%c_ptr)  ! Unlike torch_tensor_backward, takes no outer gradient.
     380              : #else
     381              :       CPABORT("CP2K compiled without the Torch library.")
     382              :       MARK_USED(tensor)
     383              : #endif
     384          154 :    END SUBROUTINE torch_tensor_backward_scalar
     385              : 
     386              : ! **************************************************************************************************
     387              : !> \brief Moves a tensor to the active Torch device and makes it an autograd leaf.
     388              : ! **************************************************************************************************
     389          848 :    SUBROUTINE torch_tensor_to_device_leaf(tensor, requires_grad)
     390              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor  ! Non-null on entry and asserted non-null on exit.
     391              :       LOGICAL, INTENT(IN)                                :: requires_grad  ! Mandatory here; forwarded as C_BOOL.
     392              : 
     393              : #if defined(__LIBTORCH)
     394              :       INTERFACE
     395              :          SUBROUTINE torch_c_tensor_to_device_leaf(tensor, req_grad) &
     396              :             BIND(C, name="torch_c_tensor_to_device_leaf")
     397              :             IMPORT :: C_BOOL, C_PTR
     398              :             TYPE(C_PTR)                                  :: tensor  ! By reference: the C side can replace the handle.
     399              :             LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
     400              :          END SUBROUTINE torch_c_tensor_to_device_leaf
     401              :       END INTERFACE
     402              : 
     403          848 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     404              :       CALL torch_c_tensor_to_device_leaf(tensor=tensor%c_ptr, &
     405          848 :                                          req_grad=LOGICAL(requires_grad, C_BOOL))
     406          848 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     407              : #else
     408              :       CPABORT("CP2K compiled without the Torch library.")
     409              :       MARK_USED(tensor)
     410              :       MARK_USED(requires_grad)
     411              : #endif
     412          848 :    END SUBROUTINE torch_tensor_to_device_leaf
     413              : 
     414              : ! **************************************************************************************************
     415              : !> \brief Select whether Torch wrappers should use CUDA when available.
     416              : ! **************************************************************************************************
     417          304 :    SUBROUTINE torch_use_cuda(use_cuda)
     418              :       LOGICAL, INTENT(IN)                                :: use_cuda  ! Ignored, without abort, if built without __LIBTORCH.
     419              : 
     420              : #if defined(__LIBTORCH)
     421              :       INTERFACE
     422              :          SUBROUTINE torch_c_use_cuda(use_cuda) BIND(C, name="torch_c_use_cuda")
     423              :             IMPORT :: C_BOOL
     424              :             LOGICAL(kind=C_BOOL), VALUE                  :: use_cuda
     425              :          END SUBROUTINE torch_c_use_cuda
     426              :       END INTERFACE
     427              : 
     428          304 :       CALL torch_c_use_cuda(use_cuda=LOGICAL(use_cuda, C_BOOL))  ! Convert default LOGICAL to the C_BOOL kind.
     429              : #else
     430              :       MARK_USED(use_cuda)
     431              : #endif
     432          304 :    END SUBROUTINE torch_use_cuda
     433              : 
     434              : ! **************************************************************************************************
     435              : !> \brief Returns the gradient of a Torch tensor which was computed by autograd.
     436              : !> \author Ole Schuett
     437              : ! **************************************************************************************************
     438          514 :    SUBROUTINE torch_tensor_grad(tensor, grad)
     439              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor  ! Must hold a non-null handle.
     440              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: grad  ! Null on entry; asserted non-null after the C call.
     441              : 
     442              : #if defined(__LIBTORCH)
     443              :       INTERFACE
     444              :          SUBROUTINE torch_c_tensor_grad(tensor, grad) &
     445              :             BIND(C, name="torch_c_tensor_grad")
     446              :             IMPORT :: C_PTR
     447              :             TYPE(C_PTR), VALUE                           :: tensor
     448              :             TYPE(C_PTR)                                  :: grad  ! By reference: receives the gradient handle.
     449              :          END SUBROUTINE torch_c_tensor_grad
     450              :       END INTERFACE
     451              : 
     452          514 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     453          514 :       CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
     454          514 :       CALL torch_c_tensor_grad(tensor=tensor%c_ptr, grad=grad%c_ptr)
     455          514 :       CPASSERT(C_ASSOCIATED(grad%c_ptr))
     456              : #else
     457              :       CPABORT("CP2K compiled without the Torch library.")
     458              :       MARK_USED(tensor)
     459              :       MARK_USED(grad)
     460              : #endif
     461          514 :    END SUBROUTINE torch_tensor_grad
     462              : 
     463              : ! **************************************************************************************************
     464              : !> \brief Returns the weighted sum of two Torch tensors.
     465              : ! **************************************************************************************************
     466          154 :    SUBROUTINE torch_tensor_weighted_sum(values, weights, result)
     467              :       TYPE(torch_tensor_type), INTENT(IN)                :: values, weights
     468              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: result
     469              : 
     470              : #if defined(__LIBTORCH)
     471              :       INTERFACE
     472              :          SUBROUTINE torch_c_tensor_weighted_sum(values, weights, result) &
     473              :             BIND(C, name="torch_c_tensor_weighted_sum")
     474              :             IMPORT :: C_PTR
     475              :             TYPE(C_PTR), VALUE                           :: values
     476              :             TYPE(C_PTR), VALUE                           :: weights
     477              :             TYPE(C_PTR)                                  :: result
     478              :          END SUBROUTINE torch_c_tensor_weighted_sum
     479              :       END INTERFACE
     480              : 
     481          154 :       CPASSERT(C_ASSOCIATED(values%c_ptr))
     482          154 :       CPASSERT(C_ASSOCIATED(weights%c_ptr))
     483          154 :       CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
     484          154 :       CALL torch_c_tensor_weighted_sum(values=values%c_ptr, weights=weights%c_ptr, result=result%c_ptr)
     485          154 :       CPASSERT(C_ASSOCIATED(result%c_ptr))
     486              : #else
     487              :       CPABORT("CP2K compiled without the Torch library.")
     488              :       MARK_USED(values)
     489              :       MARK_USED(weights)
     490              :       MARK_USED(result)
     491              : #endif
     492          154 :    END SUBROUTINE torch_tensor_weighted_sum
     493              : 
     494              : ! **************************************************************************************************
     495              : !> \brief Returns a scalar double value from a Torch tensor.
     496              : ! **************************************************************************************************
     497          154 :    FUNCTION torch_tensor_item_double(tensor) RESULT(value)
     498              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     499              :       REAL(KIND=dp)                                      :: value
     500              : 
     501              : #if defined(__LIBTORCH)
     502              :       INTERFACE
     503              :          FUNCTION torch_c_tensor_item_double(tensor) RESULT(value) &
     504              :             BIND(C, name="torch_c_tensor_item_double")
     505              :             IMPORT :: C_DOUBLE, C_PTR
     506              :             TYPE(C_PTR), VALUE                           :: tensor
     507              :             REAL(KIND=C_DOUBLE)                          :: value
     508              :          END FUNCTION torch_c_tensor_item_double
     509              :       END INTERFACE
     510              : 
     511          154 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     512          154 :       value = torch_c_tensor_item_double(tensor=tensor%c_ptr)
     513              : #else
     514              :       value = 0.0_dp
     515              :       CPABORT("CP2K compiled without the Torch library.")
     516              :       MARK_USED(tensor)
     517              : #endif
     518          154 :    END FUNCTION torch_tensor_item_double
     519              : 
     520              : ! **************************************************************************************************
     521              : !> \brief Releases a Torch tensor and all its ressources.
     522              : !> \author Ole Schuett
     523              : ! **************************************************************************************************
     524         1446 :    SUBROUTINE torch_tensor_release(tensor)
     525              :       TYPE(torch_tensor_type), INTENT(INOUT)               :: tensor
     526              : 
     527              : #if defined(__LIBTORCH)
     528              :       INTERFACE
     529              :          SUBROUTINE torch_c_tensor_release(tensor) BIND(C, name="torch_c_tensor_release")
     530              :             IMPORT :: C_PTR
     531              :             TYPE(C_PTR), VALUE                        :: tensor
     532              :          END SUBROUTINE torch_c_tensor_release
     533              :       END INTERFACE
     534              : 
     535         1446 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     536         1446 :       CALL torch_c_tensor_release(tensor=tensor%c_ptr)
     537         1446 :       tensor%c_ptr = C_NULL_PTR
     538              : #else
     539              :       CPABORT("CP2K was compiled without Torch library.")
     540              :       MARK_USED(tensor)
     541              : #endif
     542         1446 :    END SUBROUTINE torch_tensor_release
     543              : 
     544              : ! **************************************************************************************************
     545              : !> \brief Creates an empty Torch dictionary.
     546              : !> \author Ole Schuett
     547              : ! **************************************************************************************************
     548          246 :    SUBROUTINE torch_dict_create(dict)
     549              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     550              : 
     551              : #if defined(__LIBTORCH)
     552              :       INTERFACE
     553              :          SUBROUTINE torch_c_dict_create(dict) BIND(C, name="torch_c_dict_create")
     554              :             IMPORT :: C_PTR
     555              :             TYPE(C_PTR)                               :: dict
     556              :          END SUBROUTINE torch_c_dict_create
     557              :       END INTERFACE
     558              : 
     559          246 :       CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
     560          246 :       CALL torch_c_dict_create(dict=dict%c_ptr)
     561          246 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     562              : #else
     563              :       CPABORT("CP2K was compiled without Torch library.")
     564              :       MARK_USED(dict)
     565              : #endif
     566          246 :    END SUBROUTINE torch_dict_create
     567              : 
     568              : ! **************************************************************************************************
     569              : !> \brief Clones a Torch dictionary.
     570              : ! **************************************************************************************************
     571           58 :    SUBROUTINE torch_dict_clone(source, target)
     572              :       TYPE(torch_dict_type), INTENT(IN)                  :: source
     573              :       TYPE(torch_dict_type), INTENT(INOUT)               :: target
     574              : 
     575              : #if defined(__LIBTORCH)
     576              :       INTERFACE
     577              :          SUBROUTINE torch_c_dict_clone(source, target) BIND(C, name="torch_c_dict_clone")
     578              :             IMPORT :: C_PTR
     579              :             TYPE(C_PTR), VALUE                        :: source
     580              :             TYPE(C_PTR)                               :: target
     581              :          END SUBROUTINE torch_c_dict_clone
     582              :       END INTERFACE
     583              : 
     584           58 :       CPASSERT(C_ASSOCIATED(source%c_ptr))
     585           58 :       CPASSERT(.NOT. C_ASSOCIATED(target%c_ptr))
     586           58 :       CALL torch_c_dict_clone(source=source%c_ptr, target=target%c_ptr)
     587           58 :       CPASSERT(C_ASSOCIATED(target%c_ptr))
     588              : #else
     589              :       CPABORT("CP2K was compiled without Torch library.")
     590              :       MARK_USED(source)
     591              :       MARK_USED(target)
     592              : #endif
     593           58 :    END SUBROUTINE torch_dict_clone
     594              : 
     595              : ! **************************************************************************************************
     596              : !> \brief Inserts a Torch tensor into a Torch dictionary.
     597              : !> \author Ole Schuett
     598              : ! **************************************************************************************************
     599         1196 :    SUBROUTINE torch_dict_insert(dict, key, tensor)
     600              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     601              :       CHARACTER(len=*), INTENT(IN)                       :: key
     602              :       TYPE(torch_tensor_type), INTENT(IN)                :: tensor
     603              : 
     604              : #if defined(__LIBTORCH)
     605              : 
     606              :       INTERFACE
     607              :          SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
     608              :             BIND(C, name="torch_c_dict_insert")
     609              :             IMPORT :: C_CHAR, C_PTR
     610              :             TYPE(C_PTR), VALUE                           :: dict
     611              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     612              :             TYPE(C_PTR), VALUE                           :: tensor
     613              :          END SUBROUTINE torch_c_dict_insert
     614              :       END INTERFACE
     615              : 
     616         1196 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     617         1196 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     618         1196 :       CALL torch_c_dict_insert(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     619              : #else
     620              :       CPABORT("CP2K compiled without the Torch library.")
     621              :       MARK_USED(dict)
     622              :       MARK_USED(key)
     623              :       MARK_USED(tensor)
     624              : #endif
     625         1196 :    END SUBROUTINE torch_dict_insert
     626              : 
     627              : ! **************************************************************************************************
     628              : !> \brief Retrieves a Torch tensor from a Torch dictionary.
     629              : !> \author Ole Schuett
     630              : ! **************************************************************************************************
     631           72 :    SUBROUTINE torch_dict_get(dict, key, tensor)
     632              :       TYPE(torch_dict_type), INTENT(IN)                  :: dict
     633              :       CHARACTER(len=*), INTENT(IN)                       :: key
     634              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
     635              : 
     636              : #if defined(__LIBTORCH)
     637              : 
     638              :       INTERFACE
     639              :          SUBROUTINE torch_c_dict_get(dict, key, tensor) &
     640              :             BIND(C, name="torch_c_dict_get")
     641              :             IMPORT :: C_CHAR, C_PTR
     642              :             TYPE(C_PTR), VALUE                           :: dict
     643              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
     644              :             TYPE(C_PTR)                                  :: tensor
     645              :          END SUBROUTINE torch_c_dict_get
     646              :       END INTERFACE
     647              : 
     648           72 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     649           72 :       CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
     650           72 :       CALL torch_c_dict_get(dict=dict%c_ptr, key=TRIM(key)//C_NULL_CHAR, tensor=tensor%c_ptr)
     651           72 :       CPASSERT(C_ASSOCIATED(tensor%c_ptr))
     652              : 
     653              : #else
     654              :       CPABORT("CP2K compiled without the Torch library.")
     655              :       MARK_USED(dict)
     656              :       MARK_USED(key)
     657              :       MARK_USED(tensor)
     658              : #endif
     659           72 :    END SUBROUTINE torch_dict_get
     660              : 
     661              : ! **************************************************************************************************
     662              : !> \brief Releases a Torch dictionary and all its ressources.
     663              : !> \author Ole Schuett
     664              : ! **************************************************************************************************
     665          172 :    SUBROUTINE torch_dict_release(dict)
     666              :       TYPE(torch_dict_type), INTENT(INOUT)               :: dict
     667              : 
     668              : #if defined(__LIBTORCH)
     669              :       INTERFACE
     670              :          SUBROUTINE torch_c_dict_release(dict) BIND(C, name="torch_c_dict_release")
     671              :             IMPORT :: C_PTR
     672              :             TYPE(C_PTR), VALUE                        :: dict
     673              :          END SUBROUTINE torch_c_dict_release
     674              :       END INTERFACE
     675              : 
     676          172 :       CPASSERT(C_ASSOCIATED(dict%c_ptr))
     677          172 :       CALL torch_c_dict_release(dict=dict%c_ptr)
     678          172 :       dict%c_ptr = C_NULL_PTR
     679              : #else
     680              :       CPABORT("CP2K was compiled without Torch library.")
     681              :       MARK_USED(dict)
     682              : #endif
     683          172 :    END SUBROUTINE torch_dict_release
     684              : 
     685              : ! **************************************************************************************************
     686              : !> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
     687              : !> \author Ole Schuett
     688              : ! **************************************************************************************************
     689           58 :    SUBROUTINE torch_model_load(model, filename)
     690              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     691              :       CHARACTER(len=*), INTENT(IN)                       :: filename
     692              : 
     693              : #if defined(__LIBTORCH)
     694              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_load'
     695              :       INTEGER                                            :: handle
     696              : 
     697              :       INTERFACE
     698              :          SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
     699              :             IMPORT :: C_PTR, C_CHAR
     700              :             TYPE(C_PTR)                               :: model
     701              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename
     702              :          END SUBROUTINE torch_c_model_load
     703              :       END INTERFACE
     704              : 
     705           58 :       CALL timeset(routineN, handle)
     706           58 :       CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
     707           58 :       CALL torch_c_model_load(model=model%c_ptr, filename=TRIM(filename)//C_NULL_CHAR)
     708           58 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     709           58 :       CALL timestop(handle)
     710              : #else
     711              :       CPABORT("CP2K was compiled without Torch library.")
     712              :       MARK_USED(model)
     713              :       MARK_USED(filename)
     714              : #endif
     715           58 :    END SUBROUTINE torch_model_load
     716              : 
     717              : ! **************************************************************************************************
     718              : !> \brief Evaluates the given Torch model.
     719              : !> \author Ole Schuett
     720              : ! **************************************************************************************************
     721           60 :    SUBROUTINE torch_model_forward(model, inputs, outputs)
     722              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     723              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     724              :       TYPE(torch_dict_type), INTENT(INOUT)               :: outputs
     725              : 
     726              : #if defined(__LIBTORCH)
     727              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward'
     728              :       INTEGER                                            :: handle
     729              : 
     730              :       INTERFACE
     731              :          SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
     732              :             IMPORT :: C_PTR
     733              :             TYPE(C_PTR), VALUE                        :: model
     734              :             TYPE(C_PTR), VALUE                        :: inputs
     735              :             TYPE(C_PTR), VALUE                        :: outputs
     736              :          END SUBROUTINE torch_c_model_forward
     737              :       END INTERFACE
     738              : 
     739           60 :       CALL timeset(routineN, handle)
     740           60 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     741           60 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     742           60 :       CPASSERT(C_ASSOCIATED(outputs%c_ptr))
     743           60 :       CALL torch_c_model_forward(model=model%c_ptr, inputs=inputs%c_ptr, outputs=outputs%c_ptr)
     744           60 :       CALL timestop(handle)
     745              : #else
     746              :       CPABORT("CP2K was compiled without Torch library.")
     747              :       MARK_USED(model)
     748              :       MARK_USED(inputs)
     749              :       MARK_USED(outputs)
     750              : #endif
     751           60 :    END SUBROUTINE torch_model_forward
     752              : 
     753              : ! **************************************************************************************************
     754              : !> \brief Evaluates a TorchScript model method expecting keyword argument "mol".
     755              : ! **************************************************************************************************
     756          154 :    SUBROUTINE torch_model_forward_mol_tensor(model, method_name, inputs, output)
     757              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     758              :       CHARACTER(len=*), INTENT(IN)                       :: method_name
     759              :       TYPE(torch_dict_type), INTENT(IN)                  :: inputs
     760              :       TYPE(torch_tensor_type), INTENT(INOUT)             :: output
     761              : 
     762              : #if defined(__LIBTORCH)
     763              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_forward_mol_tensor'
     764              :       INTEGER                                            :: handle
     765              : 
     766              :       INTERFACE
     767              :          SUBROUTINE torch_c_model_forward_mol_tensor(model, method_name, inputs, output) &
     768              :             BIND(C, name="torch_c_model_forward_mol_tensor")
     769              :             IMPORT :: C_CHAR, C_PTR
     770              :             TYPE(C_PTR), VALUE                           :: model
     771              :             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: method_name
     772              :             TYPE(C_PTR), VALUE                           :: inputs
     773              :             TYPE(C_PTR)                                  :: output
     774              :          END SUBROUTINE torch_c_model_forward_mol_tensor
     775              :       END INTERFACE
     776              : 
     777          154 :       CALL timeset(routineN, handle)
     778          154 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     779          154 :       CPASSERT(C_ASSOCIATED(inputs%c_ptr))
     780          154 :       CPASSERT(.NOT. C_ASSOCIATED(output%c_ptr))
     781              :       CALL torch_c_model_forward_mol_tensor(model=model%c_ptr, &
     782              :                                             method_name=TRIM(method_name)//C_NULL_CHAR, &
     783              :                                             inputs=inputs%c_ptr, &
     784          154 :                                             output=output%c_ptr)
     785          154 :       CPASSERT(C_ASSOCIATED(output%c_ptr))
     786          154 :       CALL timestop(handle)
     787              : #else
     788              :       CPABORT("CP2K was compiled without Torch library.")
     789              :       MARK_USED(model)
     790              :       MARK_USED(method_name)
     791              :       MARK_USED(inputs)
     792              :       MARK_USED(output)
     793              : #endif
     794          154 :    END SUBROUTINE torch_model_forward_mol_tensor
     795              : 
     796              : ! **************************************************************************************************
     797              : !> \brief Releases a Torch model and all its ressources.
     798              : !> \author Ole Schuett
     799              : ! **************************************************************************************************
     800           14 :    SUBROUTINE torch_model_release(model)
     801              :       TYPE(torch_model_type), INTENT(INOUT)              :: model
     802              : 
     803              : #if defined(__LIBTORCH)
     804              :       INTERFACE
     805              :          SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
     806              :             IMPORT :: C_PTR
     807              :             TYPE(C_PTR), VALUE                        :: model
     808              :          END SUBROUTINE torch_c_model_release
     809              :       END INTERFACE
     810              : 
     811           14 :       CPASSERT(C_ASSOCIATED(model%c_ptr))
     812           14 :       CALL torch_c_model_release(model=model%c_ptr)
     813           14 :       model%c_ptr = C_NULL_PTR
     814              : #else
     815              :       CPABORT("CP2K was compiled without Torch library.")
     816              :       MARK_USED(model)
     817              : #endif
     818           14 :    END SUBROUTINE torch_model_release
     819              : 
     820              : ! **************************************************************************************************
     821              : !> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
     822              : !> \author Ole Schuett
     823              : ! **************************************************************************************************
     824          116 :    FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
     825              :       CHARACTER(len=*), INTENT(IN)                       :: filename, key
     826              :       CHARACTER(:), ALLOCATABLE                           :: res
     827              : 
     828              : #if defined(__LIBTORCH)
     829              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_read_metadata'
     830              :       INTEGER                                            :: handle
     831              : 
     832              :       INTEGER                                            :: length
     833              :       TYPE(C_PTR)                                        :: content_c
     834              : 
     835              :       INTERFACE
     836              :          SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
     837              :             BIND(C, name="torch_c_model_read_metadata")
     838              :             IMPORT :: C_CHAR, C_PTR, C_INT
     839              :             CHARACTER(kind=C_CHAR), DIMENSION(*)      :: filename, key
     840              :             TYPE(C_PTR)                               :: content
     841              :             INTEGER(kind=C_INT)                       :: length
     842              :          END SUBROUTINE torch_c_model_read_metadata
     843              :       END INTERFACE
     844              : 
     845          116 :       CALL timeset(routineN, handle)
     846          116 :       content_c = C_NULL_PTR
     847          116 :       length = -1
     848              :       CALL torch_c_model_read_metadata(filename=TRIM(filename)//C_NULL_CHAR, &
     849              :                                        key=TRIM(key)//C_NULL_CHAR, &
     850              :                                        content=content_c, &
     851          116 :                                        length=length)
     852          116 :       CALL c_string_to_allocatable(content_c, length, res)
     853          116 :       CALL timestop(handle)
     854              : #else
     855              :       res = ""
     856              :       MARK_USED(filename)
     857              :       MARK_USED(key)
     858              :       CPABORT("CP2K was compiled without Torch library.")
     859              : #endif
     860          116 :    END FUNCTION torch_model_read_metadata
     861              : 
     862              : ! **************************************************************************************************
     863              : !> \brief Move a C-allocated null-terminated string into an allocatable Fortran string.
     864              : ! **************************************************************************************************
     865          116 :    SUBROUTINE c_string_to_allocatable(content_c, length, res)
     866              :       TYPE(C_PTR), INTENT(INOUT)                         :: content_c
     867              :       INTEGER, INTENT(IN)                                :: length
     868              :       CHARACTER(:), ALLOCATABLE, INTENT(OUT)             :: res
     869              : 
     870              : #if defined(__LIBTORCH)
     871              :       CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
     872          116 :          POINTER                                         :: content_f
     873              :       INTEGER                                            :: i
     874              : 
     875              :       INTERFACE
     876              :          SUBROUTINE torch_c_free_string(content) BIND(C, name="torch_c_free_string")
     877              :             IMPORT :: C_PTR
     878              :             TYPE(C_PTR), VALUE                        :: content
     879              :          END SUBROUTINE torch_c_free_string
     880              :       END INTERFACE
     881              : 
     882            0 :       CPASSERT(C_ASSOCIATED(content_c))
     883          116 :       CPASSERT(length >= 0)
     884              : 
     885          232 :       CALL C_F_POINTER(content_c, content_f, shape=[length + 1])
     886          116 :       CPASSERT(content_f(length + 1) == C_NULL_CHAR)
     887              : 
     888          116 :       ALLOCATE (CHARACTER(LEN=length) :: res)
     889         7232 :       DO i = 1, length
     890         7116 :          CPASSERT(content_f(i) /= C_NULL_CHAR)
     891         7232 :          res(i:i) = content_f(i)
     892              :       END DO
     893              : 
     894          116 :       NULLIFY (content_f)
     895          116 :       CALL torch_c_free_string(content_c)
     896          116 :       content_c = C_NULL_PTR
     897              : 
     898              : #else
     899              :       res = ""
     900              :       MARK_USED(content_c)
     901              :       MARK_USED(length)
     902              :       CPABORT("CP2K was compiled without Torch library.")
     903              : #endif
     904          116 :    END SUBROUTINE c_string_to_allocatable
     905              : 
     906              : ! **************************************************************************************************
     907              : !> \brief Returns true iff the Torch CUDA backend is available.
     908              : !> \author Ole Schuett
     909              : ! **************************************************************************************************
     910            2 :    FUNCTION torch_cuda_is_available() RESULT(res)
     911              :       LOGICAL                                            :: res
     912              : 
     913              : #if defined(__LIBTORCH)
     914              :       INTERFACE
     915              :          FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
     916              :             IMPORT :: C_BOOL
     917              :             LOGICAL(C_BOOL)                           :: torch_c_cuda_is_available
     918              :          END FUNCTION torch_c_cuda_is_available
     919              :       END INTERFACE
     920              : 
     921            2 :       res = torch_c_cuda_is_available()
     922              : #else
     923              :       CPABORT("CP2K was compiled without Torch library.")
     924              :       res = .FALSE.
     925              : #endif
     926            2 :    END FUNCTION torch_cuda_is_available
     927              : 
     ! **************************************************************************************************
     !> \brief Queries the number of CUDA devices visible to Torch.
     !> \return device count as reported by the Torch C wrapper (torch_c_cuda_device_count)
     !> \note Aborts via CPABORT when CP2K was built without __LIBTORCH.
     ! **************************************************************************************************
     FUNCTION torch_cuda_device_count() RESULT(num_devices)
        INTEGER                                            :: num_devices

#if defined(__LIBTORCH)
        INTERFACE
           FUNCTION cuda_device_count_c() RESULT(n) &
              BIND(C, name="torch_c_cuda_device_count")
              IMPORT :: C_INT
              INTEGER(C_INT)                            :: n
           END FUNCTION cuda_device_count_c
        END INTERFACE

        ! Convert the C int into a default-kind INTEGER.
        num_devices = INT(cuda_device_count_c())
#else
        CPABORT("CP2K was compiled without Torch library.")
        ! Not reached; keeps the function result formally defined.
        num_devices = 0
#endif
     END FUNCTION torch_cuda_device_count
     948              : 
     ! **************************************************************************************************
     !> \brief Sets whether Torch may use TF32.
     !>        Needed because the default changed between pytorch 1.7, 1.11 and >=1.12,
     !>        see https://pytorch.org/docs/stable/notes/cuda.html
     !> \param allow_tf32 flag that is forwarded to the Torch C wrapper (torch_c_allow_tf32)
     !> \note Aborts via CPABORT when CP2K was built without __LIBTORCH.
     !> \author Gabriele Tocci
     ! **************************************************************************************************
     SUBROUTINE torch_allow_tf32(allow_tf32)
        LOGICAL, INTENT(IN)                                  :: allow_tf32

#if defined(__LIBTORCH)
        LOGICAL(C_BOOL)                                      :: allow_c

        INTERFACE
           SUBROUTINE allow_tf32_c(flag) BIND(C, name="torch_c_allow_tf32")
              IMPORT :: C_BOOL
              LOGICAL(C_BOOL), VALUE                  :: flag
           END SUBROUTINE allow_tf32_c
        END INTERFACE

        ! The C side takes an interoperable boolean by value.
        allow_c = LOGICAL(allow_tf32, KIND=C_BOOL)
        CALL allow_tf32_c(allow_c)
#else
        CPABORT("CP2K was compiled without Torch library.")
        MARK_USED(allow_tf32)
#endif
     END SUBROUTINE torch_allow_tf32
     972              : 
     ! **************************************************************************************************
     !> \brief Freezes the given Torch model, which applies generic optimizations that speed it up.
     !>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
     !> \param model Torch model; its C handle must be associated
     !> \note Aborts via CPABORT when CP2K was built without __LIBTORCH.
     !> \author Gabriele Tocci
     ! **************************************************************************************************
     SUBROUTINE torch_model_freeze(model)
        TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
        CHARACTER(len=*), PARAMETER                        :: routineN = 'torch_model_freeze'
        INTEGER                                            :: timer

        INTERFACE
           SUBROUTINE model_freeze_c(model_handle) BIND(C, name="torch_c_model_freeze")
              IMPORT :: C_PTR
              TYPE(C_PTR), VALUE                        :: model_handle
           END SUBROUTINE model_freeze_c
        END INTERFACE

        CALL timeset(routineN, timer)
        ! Refuse to hand a null handle to the C wrapper.
        CPASSERT(C_ASSOCIATED(model%c_ptr))
        CALL model_freeze_c(model%c_ptr)
        CALL timestop(timer)
#else
        CPABORT("CP2K was compiled without Torch library.")
        MARK_USED(model)
#endif
     END SUBROUTINE torch_model_freeze
    1001              : 
    #:set typenames = ['int64', 'double', 'string']
    #:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
    #:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
    #:set zeros_f = ['0', '0.0_dp', '""']

    #:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model whose attribute is read; its C handle must be associated
!> \param key name of the attribute, passed to C as a NUL-terminated copy of TRIM(key)
!> \param dest receives the attribute value; it is reset to the type's zero/blank value first
!> \note Aborts via CPABORT when CP2K was built without __LIBTORCH.
!> \author Ole Schuett
! **************************************************************************************************
       SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
          TYPE(torch_model_type), INTENT(IN)                 :: model
          CHARACTER(len=*), INTENT(IN)                       :: key
          ${type_f}$, INTENT(OUT)                            :: dest

#if defined(__LIBTORCH)

          INTERFACE
             SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
                BIND(C, name="torch_c_model_get_attr_${typename}$")
                IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
                TYPE(C_PTR), VALUE                           :: model
                CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
                ${type_c}$                                   :: dest
             END SUBROUTINE torch_c_model_get_attr_${typename}$
          END INTERFACE

          ! An INTENT(OUT) dummy is undefined on entry. Define it before the C call so that
          ! nothing is left undefined if the C routine overwrites only part of it (relevant
          ! for the fixed-length string variant). This mirrors the blank-fill done in
          ! torch_model_get_attr_strlist.
          ! NOTE(review): the C implementation is not visible here - confirm whether it pads.
          dest = ${zero_f}$

          ! Same guard as torch_model_freeze: never pass a null handle to the C wrapper.
          CPASSERT(C_ASSOCIATED(model%c_ptr))

          CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                                    key=TRIM(key)//C_NULL_CHAR, &
                                                    dest=dest)
#else
          dest = ${zero_f}$
          MARK_USED(model)
          MARK_USED(key)
          CPABORT("CP2K compiled without the Torch library.")
#endif
       END SUBROUTINE torch_model_get_attr_${typename}$
    #:endfor
    1040              : 
    ! **************************************************************************************************
    !> \brief Retrieves an integer attribute from a Torch model as a default-kind INTEGER.
    !>        Must be called before torch_model_freeze.
    !> \param model Torch model whose attribute is read
    !> \param key name of the attribute
    !> \param dest receives the attribute value, read as 64 bit and narrowed to default INTEGER
    !> \note Asserts that the 64-bit value lies within [-HUGE(dest), HUGE(dest)].
    !> \author Ole Schuett
    ! **************************************************************************************************
    SUBROUTINE torch_model_get_attr_int32(model, key, dest)
       TYPE(torch_model_type), INTENT(IN)                 :: model
       CHARACTER(len=*), INTENT(IN)                       :: key
       INTEGER, INTENT(OUT)                               :: dest

       INTEGER(kind=int_8)                                :: temp

       CALL torch_model_get_attr_int64(model, key, temp)

       ! Compare against both bounds directly instead of using ABS(temp): the magnitude of
       ! the most negative int_8 value is not representable, so ABS() could yield a value
       ! that slips through the check. The inclusive bounds also accept +/-HUGE(dest),
       ! which fit into dest.
       CPASSERT(temp >= -HUGE(dest) .AND. temp <= HUGE(dest))
       dest = INT(temp, KIND=KIND(dest))
    END SUBROUTINE torch_model_get_attr_int32
    1055              : 
    ! **************************************************************************************************
    !> \brief Retrieves a list attribute from a Torch model. Must be called before torch_model_freeze.
    !> \param model Torch model whose attribute is read; its C handle must be associated
    !> \param key name of the list attribute, passed to C as a NUL-terminated copy of TRIM(key)
    !> \param dest newly allocated with one blank-initialised entry per list item, each then
    !>             filled by the C wrapper; any previous allocation is released on entry
    !> \note Aborts via CPABORT when CP2K was built without __LIBTORCH.
    !> \author Ole Schuett
    ! **************************************************************************************************
    SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
       TYPE(torch_model_type), INTENT(IN)                 :: model
       CHARACTER(len=*), INTENT(IN)                       :: key
       ! INTENT(OUT) deallocates an already allocated actual argument on entry, so the
       ! routine can be called repeatedly with the same array without failing in ALLOCATE.
       CHARACTER(LEN=default_string_length), &
          ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest

#if defined(__LIBTORCH)

       CHARACTER(kind=C_CHAR, len=:), ALLOCATABLE         :: key_c
       INTEGER                                            :: i
       INTEGER(kind=C_INT)                                :: num_items

       INTERFACE
          SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
             BIND(C, name="torch_c_model_get_attr_list_size")
             IMPORT :: C_PTR, C_CHAR, C_INT
             TYPE(C_PTR), VALUE                           :: model
             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
             INTEGER(kind=C_INT)                          :: size
          END SUBROUTINE torch_c_model_get_attr_list_size
       END INTERFACE

       INTERFACE
          SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
             BIND(C, name="torch_c_model_get_attr_strlist")
             IMPORT :: C_PTR, C_CHAR, C_INT
             TYPE(C_PTR), VALUE                           :: model
             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
             INTEGER(kind=C_INT), VALUE                   :: index
             CHARACTER(kind=C_CHAR), DIMENSION(*)         :: dest
          END SUBROUTINE torch_c_model_get_attr_strlist
       END INTERFACE

       ! Same guard as torch_model_freeze: never pass a null handle to the C wrapper.
       CPASSERT(C_ASSOCIATED(model%c_ptr))

       ! Build the NUL-terminated key once instead of on every loop iteration.
       key_c = TRIM(key)//C_NULL_CHAR

       CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                             key=key_c, &
                                             size=num_items)
       ALLOCATE (dest(num_items))
       ! Blank-fill every entry before the C wrapper writes into it.
       dest(:) = ""

       DO i = 1, num_items
          ! The C wrapper is given a zero-based index.
          CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                              key=key_c, &
                                              index=INT(i - 1, KIND=C_INT), &
                                              dest=dest(i))
       END DO
#else
       CPABORT("CP2K compiled without the Torch library.")
       MARK_USED(model)
       MARK_USED(key)
       MARK_USED(dest)
#endif

    END SUBROUTINE torch_model_get_attr_strlist
    1112              : 
    1113            0 : END MODULE torch_api
        

Generated by: LCOV version 2.0-1