Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 :
8 : ! **************************************************************************************************
9 : !> \brief Small CP2K wrapper around the SKALA TorchScript functional protocol.
10 : ! **************************************************************************************************
11 : MODULE skala_torch_api
12 : #if defined (__HAS_IEEE_EXCEPTIONS)
13 : USE ieee_exceptions, ONLY: ieee_all, &
14 : ieee_get_halting_mode, &
15 : ieee_set_halting_mode
16 : #endif
17 : USE kinds, ONLY: default_string_length, &
18 : dp
19 : USE string_utilities, ONLY: uppercase
20 : USE torch_api, ONLY: &
21 : torch_dict_type, torch_model_forward_mol_tensor, torch_model_load, &
22 : torch_model_read_metadata, torch_model_release, torch_model_type, &
23 : torch_tensor_item_double, torch_tensor_release, torch_tensor_type, &
24 : torch_tensor_weighted_sum
25 : #include "./base/base_uses.f90"
26 :
! Module-level defaults: implicit typing is off and nothing is exported unless listed below.
   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'skala_torch_api'

   PUBLIC :: skala_torch_model_type
   PUBLIC :: skala_torch_model_load, skala_torch_model_release, &
             skala_torch_model_protocol_version, skala_torch_model_needs_feature
   PUBLIC :: skala_torch_model_get_exc, skala_torch_model_get_exc_density

! **************************************************************************************************
!> \brief Handle for a SKALA TorchScript functional together with its metadata.
! **************************************************************************************************
   TYPE skala_torch_model_type
      PRIVATE
      !> integer read from the "protocol_version" metadata entry; -1 before loading and after release
      INTEGER :: protocol_version = -1
      !> feature names found in the "features" metadata entry, in the order they are listed
      CHARACTER(len=default_string_length), DIMENSION(:), ALLOCATABLE :: features
      !> TorchScript module on which the SKALA methods are invoked
      TYPE(torch_model_type) :: torch_model
   END TYPE skala_torch_model_type
44 :
45 : CONTAINS
46 :
! **************************************************************************************************
!> \brief Load a SKALA TorchScript model and its feature metadata.
!> \param model    handle that receives the TorchScript module, the protocol version and the
!>                 feature list
!> \param filename path of the TorchScript file; its "protocol_version" and "features" metadata
!>                 entries are read with torch_model_read_metadata
!> \note Aborts if "protocol_version" cannot be read as an integer, if it differs from the single
!>       version supported here, or (inside parse_feature_list) if no feature is listed.
! **************************************************************************************************
   SUBROUTINE skala_torch_model_load(model, filename)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

      ! The only SKALA TorchScript protocol version this wrapper implements.
      INTEGER, PARAMETER                                 :: supported_protocol_version = 2

      CHARACTER(:), ALLOCATABLE                          :: features_json, protocol_string
      CHARACTER(len=default_string_length)               :: msg
      INTEGER                                            :: ios

      CALL torch_model_load(model%torch_model, filename)
      protocol_string = torch_model_read_metadata(filename, "protocol_version")
      features_json = torch_model_read_metadata(filename, "features")

      READ (protocol_string, *, IOSTAT=ios) model%protocol_version
      IF (ios /= 0) CPABORT("Could not parse SKALA TorchScript protocol_version metadata")
      IF (model%protocol_version /= supported_protocol_version) THEN
         ! Report both numbers so a model/code version mismatch is obvious from the error alone.
         ! The message is built in a buffer to keep the CPABORT line short after macro expansion.
         WRITE (msg, '(A,I0,A,I0)') "Unsupported SKALA TorchScript protocol version ", &
            model%protocol_version, ", expected ", supported_protocol_version
         CPABORT(TRIM(msg))
      END IF

      CALL parse_feature_list(features_json, model%features)

   END SUBROUTINE skala_torch_model_load
71 :
! **************************************************************************************************
!> \brief Release the TorchScript module held by a SKALA model and reset its metadata.
!> \param model handle to clean up; afterwards its feature list is deallocated and its protocol
!>              version is reset to -1
! **************************************************************************************************
   SUBROUTINE skala_torch_model_release(model)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model

      model%protocol_version = -1

      IF (ALLOCATED(model%features)) THEN
         DEALLOCATE (model%features)
      END IF

      CALL torch_model_release(model%torch_model)

   END SUBROUTINE skala_torch_model_release
84 :
! **************************************************************************************************
!> \brief Report whether a loaded SKALA model lists a given feature.
!> \param model   SKALA model handle
!> \param feature feature name; compared case-insensitively, ignoring leading and trailing blanks
!>                (both sides are held in default_string_length buffers)
!> \return .TRUE. if the name matches one of the model's features; .FALSE. otherwise, including
!>         when no feature list is allocated
! **************************************************************************************************
   FUNCTION skala_torch_model_needs_feature(model, feature) RESULT(needs_feature)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      CHARACTER(len=*), INTENT(IN)                       :: feature
      LOGICAL                                            :: needs_feature

      CHARACTER(len=default_string_length)               :: candidate, wanted
      INTEGER                                            :: ifeat

      needs_feature = .FALSE.

      ! Normalise the requested name once: strip leading blanks, fold to upper case.
      wanted = ADJUSTL(feature)
      CALL uppercase(wanted)

      IF (ALLOCATED(model%features)) THEN
         DO ifeat = 1, SIZE(model%features)
            candidate = ADJUSTL(model%features(ifeat))
            CALL uppercase(candidate)
            needs_feature = (TRIM(candidate) == TRIM(wanted))
            IF (needs_feature) EXIT
         END DO
      END IF

   END FUNCTION skala_torch_model_needs_feature
115 :
! **************************************************************************************************
!> \brief Query the protocol version recorded when the SKALA model was loaded.
!> \param model SKALA model handle
!> \return the integer parsed from the "protocol_version" metadata, or -1 for a model that was
!>         never loaded or has been released
! **************************************************************************************************
   FUNCTION skala_torch_model_protocol_version(model) RESULT(version)
      TYPE(skala_torch_model_type), INTENT(IN)           :: model
      INTEGER                                            :: version

      version = model%protocol_version
   END FUNCTION skala_torch_model_protocol_version
128 :
! **************************************************************************************************
!> \brief Evaluate the SKALA exchange-correlation energy density.
!> \param model       loaded SKALA model
!> \param inputs      input dictionary handed to the model's "get_exc_density" method
!> \param exc_density receives the tensor returned by "get_exc_density"
!> \note When __HAS_IEEE_EXCEPTIONS is defined, halting on all IEEE flags is disabled for the
!>       duration of the Torch call and the caller's halting modes are restored afterwards
!>       (presumably because the Torch backend may raise floating-point flags that must not trap).
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc_density(model, inputs, exc_density)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_density

#if defined (__HAS_IEEE_EXCEPTIONS)
      LOGICAL, DIMENSION(SIZE(ieee_all))                 :: saved_halting

      ! Remember the caller's halting modes, then let every IEEE flag pass silently.
      CALL ieee_get_halting_mode(ieee_all, saved_halting)
      CALL ieee_set_halting_mode(ieee_all, .FALSE.)
#endif
      CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, exc_density)
#if defined (__HAS_IEEE_EXCEPTIONS)
      ! Put the caller's halting modes back exactly as they were.
      CALL ieee_set_halting_mode(ieee_all, saved_halting)
#endif

   END SUBROUTINE skala_torch_model_get_exc_density
152 :
! **************************************************************************************************
!> \brief Evaluate the weighted SKALA exchange-correlation energy.
!> \param model        loaded SKALA model
!> \param inputs       input dictionary handed to the model's "get_exc_density" method
!> \param grid_weights weights passed with the density to torch_tensor_weighted_sum
!> \param exc_tensor   receives the result of torch_tensor_weighted_sum
!> \param exc          scalar value of exc_tensor as a double
!> \note When __HAS_IEEE_EXCEPTIONS is defined, IEEE halting is disabled around all Torch calls
!>       and the caller's halting modes are restored before returning.
! **************************************************************************************************
   SUBROUTINE skala_torch_model_get_exc(model, inputs, grid_weights, exc_tensor, exc)
      TYPE(skala_torch_model_type), INTENT(INOUT)        :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(IN)                :: grid_weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: exc_tensor
      REAL(KIND=dp), INTENT(OUT)                         :: exc

      TYPE(torch_tensor_type)                            :: density

#if defined (__HAS_IEEE_EXCEPTIONS)
      LOGICAL, DIMENSION(SIZE(ieee_all))                 :: saved_halting

      CALL ieee_get_halting_mode(ieee_all, saved_halting)
      CALL ieee_set_halting_mode(ieee_all, .FALSE.)
#endif
      ! Model output -> weighted sum; the intermediate density tensor is released right away.
      CALL torch_model_forward_mol_tensor(model%torch_model, "get_exc_density", inputs, density)
      CALL torch_tensor_weighted_sum(density, grid_weights, exc_tensor)
      CALL torch_tensor_release(density)
      exc = torch_tensor_item_double(exc_tensor)
#if defined (__HAS_IEEE_EXCEPTIONS)
      CALL ieee_set_halting_mode(ieee_all, saved_halting)
#endif

   END SUBROUTINE skala_torch_model_get_exc
185 :
! **************************************************************************************************
!> \brief Parse a TorchScript extra_files JSON list of feature names.
!> \param features_json JSON text of the "features" metadata entry, e.g. ["density", "grad"]
!> \param features      one entry per complete double-quoted string in features_json, in order
!> \note This is a minimal scanner, not a full JSON parser: every complete double-quoted string is
!>       taken as a feature name. A backslash escapes the following character (JSON string
!>       syntax), so an escaped quote no longer terminates a name; escape sequences are copied
!>       verbatim rather than decoded. Names longer than default_string_length are truncated on
!>       assignment, and an unterminated trailing string is ignored. Aborts if no name is found.
! **************************************************************************************************
   SUBROUTINE parse_feature_list(features_json, features)
      CHARACTER(len=*), INTENT(IN)                       :: features_json
      CHARACTER(len=default_string_length), &
         ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: features

      INTEGER                                            :: feature_count, i, name_first, &
                                                            name_last, pos

      ! First pass: count complete strings so that the result is allocated exactly once.
      feature_count = 0
      pos = 1
      DO
         CALL next_json_string(features_json, pos, name_first, name_last)
         IF (name_first == 0) EXIT
         feature_count = feature_count + 1
      END DO

      IF (feature_count == 0) CPABORT("SKALA TorchScript model does not list any features")
      ALLOCATE (features(feature_count))
      features = ""

      ! Second pass: the scanner is deterministic, so it finds the same strings again.
      pos = 1
      DO i = 1, feature_count
         CALL next_json_string(features_json, pos, name_first, name_last)
         ! For an empty string "" name_last = name_first - 1, giving a zero-length substring.
         features(i) = features_json(name_first:name_last)
      END DO

   CONTAINS

! **************************************************************************************************
!> \brief Locate the next complete double-quoted string starting at or after position pos.
!> \param text  text to scan
!> \param pos   on entry, first position to scan; on success, the position after the closing quote
!> \param first position of the first character inside the quotes, or 0 if no complete string
!> \param last  position of the last character inside the quotes (first - 1 for an empty string)
! **************************************************************************************************
      SUBROUTINE next_json_string(text, pos, first, last)
         CHARACTER(len=*), INTENT(IN)                    :: text
         INTEGER, INTENT(INOUT)                          :: pos
         INTEGER, INTENT(OUT)                            :: first, last

         ! ACHAR(92) instead of a '\' literal: some compilers treat backslashes in literals as
         ! C-style escapes.
         CHARACTER, PARAMETER                            :: backslash = ACHAR(92), quote = '"'

         INTEGER                                         :: k, open_quote

         first = 0
         last = -1
         IF (pos > LEN(text)) RETURN
         open_quote = INDEX(text(pos:), quote)
         IF (open_quote == 0) RETURN

         k = pos + open_quote ! first character after the opening quote
         DO WHILE (k <= LEN(text))
            IF (text(k:k) == backslash) THEN
               k = k + 2 ! skip the escaped character, whatever it is
            ELSE IF (text(k:k) == quote) THEN
               first = pos + open_quote
               last = k - 1
               pos = k + 1
               RETURN
            ELSE
               k = k + 1
            END IF
         END DO
      END SUBROUTINE next_json_string

   END SUBROUTINE parse_feature_list
226 :
227 0 : END MODULE skala_torch_api
|