Line data Source code
1 : !--------------------------------------------------------------------------------------------------!
2 : ! CP2K: A general program to perform molecular dynamics simulations !
3 : ! Copyright 2000-2026 CP2K developers group <https://cp2k.org> !
4 : ! !
5 : ! SPDX-License-Identifier: GPL-2.0-or-later !
6 : !--------------------------------------------------------------------------------------------------!
7 : MODULE torch_api
8 : USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
9 : C_BOOL, &
10 : C_CHAR, &
11 : C_FLOAT, &
12 : C_DOUBLE, &
13 : C_F_POINTER, &
14 : C_INT, &
15 : C_NULL_CHAR, &
16 : C_NULL_PTR, &
17 : C_PTR, &
18 : C_INT32_T, &
19 : C_INT64_T
20 :
21 : USE kinds, ONLY: sp, int_4, int_8, dp, default_string_length
22 :
23 : #include "./base/base_uses.f90"
24 :
25 : IMPLICIT NONE
26 :
27 : PRIVATE
28 :
! **************************************************************************************************
!> \brief Opaque handle to a Torch tensor, accessed only through the torch_c_* BIND(C) routines.
!>        c_ptr is C_NULL_PTR while no tensor is attached (creation asserts this, release restores it).
! **************************************************************************************************
   TYPE torch_tensor_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_tensor_type

! **************************************************************************************************
!> \brief Opaque handle to a Torch dictionary (string keys mapping to tensors).
!>        c_ptr is C_NULL_PTR while no dictionary is attached.
! **************************************************************************************************
   TYPE torch_dict_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_dict_type

! **************************************************************************************************
!> \brief Opaque handle to a loaded Torch model (a "module" in Torch terminology).
!>        c_ptr is C_NULL_PTR while no model is attached.
! **************************************************************************************************
   TYPE torch_model_type
      PRIVATE
      TYPE(C_PTR) :: c_ptr = C_NULL_PTR
   END TYPE torch_model_type
43 :
#:set max_dim = 3
   ! Generic entry points. The specific procedures are generated by fypp further down for every
   ! element type (int32, float, int64, double) and every rank from 1 to max_dim.
   INTERFACE torch_tensor_from_array
#:for ndims in range(1, max_dim+1)
#:for tname in ['int32', 'float', 'int64', 'double']
      MODULE PROCEDURE torch_tensor_from_array_${tname}$_${ndims}$d
#:endfor
#:endfor
   END INTERFACE torch_tensor_from_array

   ! In-place refills are only provided for double precision data.
   INTERFACE torch_tensor_reset_from_array
#:for ndims in range(1, max_dim+1)
      MODULE PROCEDURE torch_tensor_reset_from_array_double_${ndims}$d
#:endfor
   END INTERFACE torch_tensor_reset_from_array

   INTERFACE torch_tensor_data_ptr
#:for ndims in range(1, max_dim+1)
#:for tname in ['int32', 'float', 'int64', 'double']
      MODULE PROCEDURE torch_tensor_data_ptr_${tname}$_${ndims}$d
#:endfor
#:endfor
   END INTERFACE torch_tensor_data_ptr

   ! Model attribute getters for several value types; the specific procedures live later in this module.
   INTERFACE torch_model_get_attr
      MODULE PROCEDURE torch_model_get_attr_string, &
         torch_model_get_attr_double, &
         torch_model_get_attr_int64, &
         torch_model_get_attr_int32, &
         torch_model_get_attr_strlist
   END INTERFACE torch_model_get_attr
76 :
! Public API, grouped by the kind of object it operates on.
   ! Tensors
   PUBLIC :: torch_tensor_type
   PUBLIC :: torch_tensor_from_array, torch_tensor_reset_from_array, torch_tensor_data_ptr
   PUBLIC :: torch_tensor_expand_dim, torch_tensor_narrow, torch_tensor_to_device_leaf
   PUBLIC :: torch_tensor_backward, torch_tensor_backward_scalar, torch_tensor_grad
   PUBLIC :: torch_tensor_item_double, torch_tensor_weighted_sum, torch_tensor_release
   ! Dictionaries
   PUBLIC :: torch_dict_type, torch_dict_create, torch_dict_clone
   PUBLIC :: torch_dict_insert, torch_dict_get, torch_dict_release
   ! Models
   PUBLIC :: torch_model_type, torch_model_load, torch_model_freeze, torch_model_release
   PUBLIC :: torch_model_forward, torch_model_forward_mol_tensor
   PUBLIC :: torch_model_get_attr, torch_model_read_metadata
   ! Global settings and device queries
   PUBLIC :: torch_cuda_device_count, torch_cuda_is_available
   PUBLIC :: torch_use_cuda, torch_allow_tf32
91 :
92 : CONTAINS
93 :
#:set typenames = ['int32', 'float', 'int64', 'double']
#:set types_f = ['INTEGER(kind=int_4)', 'REAL(sp)', 'INTEGER(kind=int_8)', 'REAL(dp)']
#:set types_c = ['INTEGER(kind=C_INT32_T)', 'REAL(kind=C_FLOAT)', 'INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)']

#:for ndims in range(1, max_dim+1)
#:for typename, type_f, type_c in zip(typenames, types_f, types_c)

! **************************************************************************************************
!> \brief Creates a Torch tensor from an array. The passed array has to outlive the tensor!
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Handle that receives the new tensor; must not be associated on entry.
!> \param source Array providing the tensor's data and shape. Its extents are handed to Torch in
!>               reversed order because C arrays are stored row-major.
!> \param requires_grad Whether the tensor should require gradients (defaults to .FALSE.).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_c
      LOGICAL                                            :: my_req_grad

      INTERFACE
         SUBROUTINE torch_c_tensor_from_array_${typename}$ (tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_from_array_${typename}$")
            IMPORT :: C_PTR, C_INT, C_INT32_T, C_INT64_T, C_FLOAT, C_DOUBLE, C_BOOL
            TYPE(C_PTR)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            ${type_c}$, DIMENSION(*)                     :: source
         END SUBROUTINE torch_c_tensor_from_array_${typename}$
      END INTERFACE

      my_req_grad = .FALSE.
      IF (PRESENT(requires_grad)) my_req_grad = requires_grad

      ! Query extents directly as int_8 so that dimensions larger than HUGE(0) cannot overflow
      ! the default-integer result of SIZE before being widened.
#:for axis in range(ndims)
      sizes_c(${axis + 1}$) = SIZE(source, ${ndims - axis}$, KIND=int_8) ! C arrays are stored row-major.
#:endfor

      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_from_array_${typename}$ (tensor=tensor%c_ptr, &
                                                  req_grad=LOGICAL(my_req_grad, C_BOOL), &
                                                  ndims=${ndims}$, &
                                                  sizes=sizes_c, &
                                                  source=source)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(source)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_from_array_${typename}$_${ndims}$d

! **************************************************************************************************
!> \brief Returns a Fortran pointer to the data of a Torch tensor. No data is copied!
!>        The returned pointer is only valid during the tensor's lifetime!
!> \param tensor Tensor whose storage is exposed; must be associated.
!> \param data_ptr Must be disassociated (e.g. NULLIFY'd) on entry. On return it aliases the
!>                 tensor's storage, with the extents reversed because C arrays are stored
!>                 row-major. For zero-sized tensors Torch returns a null pointer and data_ptr is
!>                 left disassociated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d(tensor, data_ptr)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      ${type_f}$, DIMENSION(${arraydims}$), POINTER      :: data_ptr

#if defined(__LIBTORCH)
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: sizes_f, sizes_c
      TYPE(C_PTR)                                        :: data_ptr_c

      INTERFACE
         SUBROUTINE torch_c_tensor_data_ptr_${typename}$ (tensor, ndims, sizes, data_ptr) &
            BIND(C, name="torch_c_tensor_data_ptr_${typename}$")
            IMPORT :: C_PTR, C_INT, C_INT64_T
            TYPE(C_PTR), VALUE                           :: tensor
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            TYPE(C_PTR)                                  :: data_ptr
         END SUBROUTINE torch_c_tensor_data_ptr_${typename}$
      END INTERFACE

      sizes_c(:) = -1
      data_ptr_c = C_NULL_PTR
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. ASSOCIATED(data_ptr))
      CALL torch_c_tensor_data_ptr_${typename}$ (tensor=tensor%c_ptr, &
                                                ndims=${ndims}$, &
                                                sizes=sizes_c, &
                                                data_ptr=data_ptr_c)

#:for axis in range(ndims)
      sizes_f(${axis + 1}$) = sizes_c(${ndims - axis}$) ! C arrays are stored row-major.
#:endfor

      IF (ALL(sizes_f /= 0)) THEN ! Torch returns null pointer for zero-sized tensors.
         CPASSERT(C_ASSOCIATED(data_ptr_c))
         CALL C_F_POINTER(data_ptr_c, data_ptr, shape=sizes_f)
      END IF
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(data_ptr)
#endif
   END SUBROUTINE torch_tensor_data_ptr_${typename}$_${ndims}$d

#:endfor
#:endfor
201 :
#:for ndims in range(1, max_dim+1)

! **************************************************************************************************
!> \brief Reuses or creates a device leaf tensor and copies data into it.
!>        The source must be an ALLOCATABLE to prevent passing a temporary array.
!> \param tensor Tensor handle, passed by reference so the C side may reuse or (re)create it;
!>               it is guaranteed to be associated on return.
!> \param source Double precision data; extents are handed to Torch in reversed (row-major) order.
!> \param requires_grad Whether the tensor should require gradients (defaults to .FALSE.).
! **************************************************************************************************
   SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d(tensor, source, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
#:set arraydims = ", ".join(":" for i in range(ndims))
      REAL(dp), DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source
      LOGICAL, OPTIONAL, INTENT(IN)                      :: requires_grad

#if defined(__LIBTORCH)
      INTEGER                                            :: axis
      INTEGER(kind=int_8), DIMENSION(${ndims}$)          :: shape_rowmajor
      LOGICAL(kind=C_BOOL)                               :: req_grad_c

      INTERFACE
         SUBROUTINE torch_c_tensor_reset_from_array_double(tensor, req_grad, ndims, sizes, source) &
            BIND(C, name="torch_c_tensor_reset_from_array_double")
            IMPORT :: C_PTR, C_INT, C_INT64_T, C_DOUBLE, C_BOOL
            TYPE(C_PTR)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
            INTEGER(kind=C_INT), VALUE                   :: ndims
            INTEGER(kind=C_INT64_T), DIMENSION(*)        :: sizes
            REAL(kind=C_DOUBLE), DIMENSION(*)            :: source
         END SUBROUTINE torch_c_tensor_reset_from_array_double
      END INTERFACE

      ! Gradient tracking is off unless explicitly requested.
      req_grad_c = LOGICAL(.FALSE., C_BOOL)
      IF (PRESENT(requires_grad)) req_grad_c = LOGICAL(requires_grad, C_BOOL)

      ! Torch wants the extents in C (row-major) order, i.e. reversed w.r.t. Fortran.
      DO axis = 1, ${ndims}$
         shape_rowmajor(axis) = SIZE(source, ${ndims}$ + 1 - axis)
      END DO

      CALL torch_c_tensor_reset_from_array_double(tensor%c_ptr, req_grad_c, ${ndims}$, shape_rowmajor, source)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(source)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_reset_from_array_double_${ndims}$d

#:endfor
252 :
! **************************************************************************************************
!> \brief Creates an expanded tensor view along one singleton dimension.
!> \param tensor Source tensor; must be associated.
!> \param dim Non-negative dimension index, passed through to Torch unchanged.
!> \param extent Non-negative new extent of that dimension.
!> \param result Receives the view; must not be associated on entry.
! **************************************************************************************************
   SUBROUTINE torch_tensor_expand_dim(tensor, dim, extent, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER, INTENT(IN)                                :: dim, extent
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTEGER(kind=C_INT64_T)                            :: dim_c, extent_c

      INTERFACE
         SUBROUTINE torch_c_tensor_expand_dim(tensor, dim, extent, result) &
            BIND(C, name="torch_c_tensor_expand_dim")
            IMPORT :: C_INT64_T, C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
            INTEGER(kind=C_INT64_T), VALUE               :: dim, extent
            TYPE(C_PTR)                                  :: result
         END SUBROUTINE torch_c_tensor_expand_dim
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
      CPASSERT(dim >= 0)
      CPASSERT(extent >= 0)

      ! Widen to the 64-bit integers expected by the C interface.
      dim_c = INT(dim, kind=C_INT64_T)
      extent_c = INT(extent, kind=C_INT64_T)
      CALL torch_c_tensor_expand_dim(tensor%c_ptr, dim_c, extent_c, result%c_ptr)
      CPASSERT(C_ASSOCIATED(result%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(dim)
      MARK_USED(extent)
      MARK_USED(result)
#endif
   END SUBROUTINE torch_tensor_expand_dim
289 :
! **************************************************************************************************
!> \brief Creates a view of a contiguous tensor slice.
!> \param tensor Source tensor; must be associated.
!> \param dim Non-negative dimension index along which to slice, passed through to Torch unchanged.
!> \param start_index Non-negative start offset of the slice along dim.
!> \param length Non-negative number of elements in the slice.
!> \param result Receives the view; must not be associated on entry.
! **************************************************************************************************
   SUBROUTINE torch_tensor_narrow(tensor, dim, start_index, length, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      INTEGER, INTENT(IN)                                :: dim, start_index, length
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTEGER(kind=C_INT64_T)                            :: dim_c, length_c, start_c

      INTERFACE
         SUBROUTINE torch_c_tensor_narrow(tensor, dim, start_index, length, result) &
            BIND(C, name="torch_c_tensor_narrow")
            IMPORT :: C_INT64_T, C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
            INTEGER(kind=C_INT64_T), VALUE               :: dim, start_index, length
            TYPE(C_PTR)                                  :: result
         END SUBROUTINE torch_c_tensor_narrow
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
      CPASSERT(dim >= 0)
      CPASSERT(start_index >= 0)
      CPASSERT(length >= 0)

      ! Widen to the 64-bit integers expected by the C interface.
      dim_c = INT(dim, kind=C_INT64_T)
      start_c = INT(start_index, kind=C_INT64_T)
      length_c = INT(length, kind=C_INT64_T)
      CALL torch_c_tensor_narrow(tensor%c_ptr, dim_c, start_c, length_c, result%c_ptr)
      CPASSERT(C_ASSOCIATED(result%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(dim)
      MARK_USED(start_index)
      MARK_USED(length)
      MARK_USED(result)
#endif
   END SUBROUTINE torch_tensor_narrow
329 :
! **************************************************************************************************
!> \brief Runs autograd on a Torch tensor.
!> \param tensor Tensor to back-propagate from; must be associated.
!> \param outer_grad Outer gradient handed to the backward call; must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward(tensor, outer_grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor, outer_grad

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_tensor_backward'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_tensor_backward(tensor, outer_grad) &
            BIND(C, name="torch_c_tensor_backward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: tensor, outer_grad
         END SUBROUTINE torch_c_tensor_backward
      END INTERFACE

      CALL timeset(routineN, timer_handle)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(C_ASSOCIATED(outer_grad%c_ptr))
      CALL torch_c_tensor_backward(tensor%c_ptr, outer_grad%c_ptr)
      CALL timestop(timer_handle)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(outer_grad)
#endif
   END SUBROUTINE torch_tensor_backward
362 :
! **************************************************************************************************
!> \brief Runs autograd on a scalar Torch tensor.
!> \param tensor Scalar tensor to back-propagate from; must be associated.
! **************************************************************************************************
   SUBROUTINE torch_tensor_backward_scalar(tensor)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_backward_scalar(scalar) &
            BIND(C, name="torch_c_tensor_backward_scalar")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: scalar
         END SUBROUTINE torch_c_tensor_backward_scalar
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_backward_scalar(tensor%c_ptr)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_tensor_backward_scalar
385 :
! **************************************************************************************************
!> \brief Moves a tensor to the active Torch device and makes it an autograd leaf.
!> \param tensor Tensor handle; must be associated. It is passed by reference, so the C side may
!>               replace the handle; it is guaranteed to be associated on return.
!> \param requires_grad Whether the resulting leaf should require gradients.
! **************************************************************************************************
   SUBROUTINE torch_tensor_to_device_leaf(tensor, requires_grad)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor
      LOGICAL, INTENT(IN)                                :: requires_grad

#if defined(__LIBTORCH)
      LOGICAL(kind=C_BOOL)                               :: req_grad_c

      INTERFACE
         SUBROUTINE torch_c_tensor_to_device_leaf(tensor, req_grad) &
            BIND(C, name="torch_c_tensor_to_device_leaf")
            IMPORT :: C_BOOL, C_PTR
            TYPE(C_PTR)                                  :: tensor
            LOGICAL(kind=C_BOOL), VALUE                  :: req_grad
         END SUBROUTINE torch_c_tensor_to_device_leaf
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      req_grad_c = LOGICAL(requires_grad, C_BOOL)
      CALL torch_c_tensor_to_device_leaf(tensor%c_ptr, req_grad_c)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(requires_grad)
#endif
   END SUBROUTINE torch_tensor_to_device_leaf
413 :
! **************************************************************************************************
!> \brief Select whether Torch wrappers should use CUDA when available.
!> \param use_cuda Requested setting. Without __LIBTORCH the call is silently ignored.
! **************************************************************************************************
   SUBROUTINE torch_use_cuda(use_cuda)
      LOGICAL, INTENT(IN)                                :: use_cuda

#if defined(__LIBTORCH)
      LOGICAL(kind=C_BOOL)                               :: use_cuda_c

      INTERFACE
         SUBROUTINE torch_c_use_cuda(flag) BIND(C, name="torch_c_use_cuda")
            IMPORT :: C_BOOL
            LOGICAL(kind=C_BOOL), VALUE                  :: flag
         END SUBROUTINE torch_c_use_cuda
      END INTERFACE

      use_cuda_c = LOGICAL(use_cuda, C_BOOL)
      CALL torch_c_use_cuda(use_cuda_c)
#else
      ! Unlike the other wrappers there is no CPABORT here: without Torch this is a no-op.
      MARK_USED(use_cuda)
#endif
   END SUBROUTINE torch_use_cuda
433 :
! **************************************************************************************************
!> \brief Returns the gradient of a Torch tensor which was computed by autograd.
!> \param tensor Tensor whose gradient is requested; must be associated.
!> \param grad Receives the gradient tensor; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_grad(tensor, grad)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      TYPE(torch_tensor_type), INTENT(INOUT)             :: grad

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_grad(tensor, grad) &
            BIND(C, name="torch_c_tensor_grad")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
            TYPE(C_PTR)                                  :: grad
         END SUBROUTINE torch_c_tensor_grad
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(grad%c_ptr))
      CALL torch_c_tensor_grad(tensor%c_ptr, grad%c_ptr)
      CPASSERT(C_ASSOCIATED(grad%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
      MARK_USED(grad)
#endif
   END SUBROUTINE torch_tensor_grad
462 :
! **************************************************************************************************
!> \brief Returns the weighted sum of two Torch tensors.
!> \param values Tensor of values; must be associated.
!> \param weights Tensor of weights; must be associated.
!> \param result Receives the resulting tensor; must not be associated on entry.
! **************************************************************************************************
   SUBROUTINE torch_tensor_weighted_sum(values, weights, result)
      TYPE(torch_tensor_type), INTENT(IN)                :: values, weights
      TYPE(torch_tensor_type), INTENT(INOUT)             :: result

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_weighted_sum(values, weights, result) &
            BIND(C, name="torch_c_tensor_weighted_sum")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: values, weights
            TYPE(C_PTR)                                  :: result
         END SUBROUTINE torch_c_tensor_weighted_sum
      END INTERFACE

      CPASSERT(C_ASSOCIATED(values%c_ptr))
      CPASSERT(C_ASSOCIATED(weights%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(result%c_ptr))
      CALL torch_c_tensor_weighted_sum(values%c_ptr, weights%c_ptr, result%c_ptr)
      CPASSERT(C_ASSOCIATED(result%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(values)
      MARK_USED(weights)
      MARK_USED(result)
#endif
   END SUBROUTINE torch_tensor_weighted_sum
493 :
! **************************************************************************************************
!> \brief Returns a scalar double value from a Torch tensor.
!> \param tensor Scalar tensor to read; must be associated.
!> \return The tensor's value as REAL(dp).
! **************************************************************************************************
   FUNCTION torch_tensor_item_double(tensor) RESULT(value)
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor
      REAL(KIND=dp)                                      :: value

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_tensor_item_double(tensor) RESULT(item) &
            BIND(C, name="torch_c_tensor_item_double")
            IMPORT :: C_DOUBLE, C_PTR
            TYPE(C_PTR), VALUE                           :: tensor
            REAL(KIND=C_DOUBLE)                          :: item
         END FUNCTION torch_c_tensor_item_double
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      value = REAL(torch_c_tensor_item_double(tensor%c_ptr), KIND=dp)
#else
      ! Define the result before aborting so the function never returns an undefined value.
      value = 0.0_dp
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(tensor)
#endif
   END FUNCTION torch_tensor_item_double
519 :
! **************************************************************************************************
!> \brief Releases a Torch tensor and all its resources.
!> \param tensor Tensor to release; must be associated and is reset to C_NULL_PTR afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_tensor_release(tensor)
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_tensor_release(handle_c) BIND(C, name="torch_c_tensor_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: handle_c
         END SUBROUTINE torch_c_tensor_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      CALL torch_c_tensor_release(tensor%c_ptr)
      ! Mark the handle as empty so it can be reused for a new tensor.
      tensor%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_tensor_release
543 :
! **************************************************************************************************
!> \brief Creates an empty Torch dictionary.
!> \param dict Handle that receives the new dictionary; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_create(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_create(new_dict) BIND(C, name="torch_c_dict_create")
            IMPORT :: C_PTR
            TYPE(C_PTR)                                  :: new_dict
         END SUBROUTINE torch_c_dict_create
      END INTERFACE

      CPASSERT(.NOT. C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_create(dict%c_ptr)
      CPASSERT(C_ASSOCIATED(dict%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_create
567 :
! **************************************************************************************************
!> \brief Clones a Torch dictionary.
!> \param source Dictionary to clone; must be associated.
!> \param target Receives the clone; must not be associated on entry.
! **************************************************************************************************
   SUBROUTINE torch_dict_clone(source, target)
      TYPE(torch_dict_type), INTENT(IN)                  :: source
      TYPE(torch_dict_type), INTENT(INOUT)               :: target

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_clone(original, copy) BIND(C, name="torch_c_dict_clone")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: original
            TYPE(C_PTR)                                  :: copy
         END SUBROUTINE torch_c_dict_clone
      END INTERFACE

      CPASSERT(C_ASSOCIATED(source%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(target%c_ptr))
      CALL torch_c_dict_clone(source%c_ptr, target%c_ptr)
      CPASSERT(C_ASSOCIATED(target%c_ptr))
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(source)
      MARK_USED(target)
#endif
   END SUBROUTINE torch_dict_clone
594 :
! **************************************************************************************************
!> \brief Inserts a Torch tensor into a Torch dictionary.
!> \param dict Dictionary to insert into; must be associated.
!> \param key Key under which the tensor is stored (trailing blanks are trimmed).
!> \param tensor Tensor to insert; must be associated.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_insert(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(IN)                :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_insert(dict, key, tensor) &
            BIND(C, name="torch_c_dict_insert")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                           :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            TYPE(C_PTR), VALUE                           :: tensor
         END SUBROUTINE torch_c_dict_insert
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
      ! The C side expects a NUL-terminated string.
      CALL torch_c_dict_insert(dict%c_ptr, TRIM(key)//C_NULL_CHAR, tensor%c_ptr)
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_insert
626 :
! **************************************************************************************************
!> \brief Retrieves a Torch tensor from a Torch dictionary.
!> \param dict Dictionary to look up; must be associated.
!> \param key Key to look up (trailing blanks are trimmed).
!> \param tensor Receives the tensor; must not be associated on entry.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_get(dict, key, tensor)
      TYPE(torch_dict_type), INTENT(IN)                  :: dict
      CHARACTER(len=*), INTENT(IN)                       :: key
      TYPE(torch_tensor_type), INTENT(INOUT)             :: tensor

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_get(dict, key, tensor) &
            BIND(C, name="torch_c_dict_get")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                           :: dict
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: key
            TYPE(C_PTR)                                  :: tensor
         END SUBROUTINE torch_c_dict_get
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(tensor%c_ptr))
      ! The C side expects a NUL-terminated string.
      CALL torch_c_dict_get(dict%c_ptr, TRIM(key)//C_NULL_CHAR, tensor%c_ptr)
      CPASSERT(C_ASSOCIATED(tensor%c_ptr))
#else
      CPABORT("CP2K compiled without the Torch library.")
      MARK_USED(dict)
      MARK_USED(key)
      MARK_USED(tensor)
#endif
   END SUBROUTINE torch_dict_get
660 :
! **************************************************************************************************
!> \brief Releases a Torch dictionary and all its resources.
!> \param dict Dictionary to release; must be associated and is reset to C_NULL_PTR afterwards.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_dict_release(dict)
      TYPE(torch_dict_type), INTENT(INOUT)               :: dict

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_dict_release(handle_c) BIND(C, name="torch_c_dict_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: handle_c
         END SUBROUTINE torch_c_dict_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(dict%c_ptr))
      CALL torch_c_dict_release(dict%c_ptr)
      ! Mark the handle as empty so it can be reused for a new dictionary.
      dict%c_ptr = C_NULL_PTR
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(dict)
#endif
   END SUBROUTINE torch_dict_release
684 :
! **************************************************************************************************
!> \brief Loads a Torch model from given "*.pth" file. (In Torch lingo models are called modules)
!> \param model Handle that receives the model; must not be associated on entry.
!> \param filename Path of the model file (trailing blanks are trimmed).
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_load(model, filename)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: filename

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_load'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_load(model, filename) BIND(C, name="torch_c_model_load")
            IMPORT :: C_PTR, C_CHAR
            TYPE(C_PTR)                                  :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: filename
         END SUBROUTINE torch_c_model_load
      END INTERFACE

      CALL timeset(routineN, timer_handle)
      CPASSERT(.NOT. C_ASSOCIATED(model%c_ptr))
      ! The C side expects a NUL-terminated path.
      CALL torch_c_model_load(model%c_ptr, TRIM(filename)//C_NULL_CHAR)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL timestop(timer_handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(filename)
#endif
   END SUBROUTINE torch_model_load
716 :
! **************************************************************************************************
!> \brief Evaluates the given Torch model.
!> \param model Loaded model; must be associated.
!> \param inputs Dictionary of input tensors; must be associated.
!> \param outputs Dictionary handed to the forward call for the results; must already be created.
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_forward(model, inputs, outputs)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_dict_type), INTENT(INOUT)               :: outputs

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_forward(model, inputs, outputs) BIND(C, name="torch_c_model_forward")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                           :: model, inputs, outputs
         END SUBROUTINE torch_c_model_forward
      END INTERFACE

      CALL timeset(routineN, timer_handle)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(C_ASSOCIATED(outputs%c_ptr))
      CALL torch_c_model_forward(model%c_ptr, inputs%c_ptr, outputs%c_ptr)
      CALL timestop(timer_handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(inputs)
      MARK_USED(outputs)
#endif
   END SUBROUTINE torch_model_forward
752 :
! **************************************************************************************************
!> \brief Evaluates a TorchScript model method expecting keyword argument "mol".
!> \param model Loaded model; must be associated.
!> \param method_name Name of the method to call (trailing blanks are trimmed).
!> \param inputs Dictionary of input tensors; must be associated.
!> \param output Receives the resulting tensor; must not be associated on entry.
! **************************************************************************************************
   SUBROUTINE torch_model_forward_mol_tensor(model, method_name, inputs, output)
      TYPE(torch_model_type), INTENT(INOUT)              :: model
      CHARACTER(len=*), INTENT(IN)                       :: method_name
      TYPE(torch_dict_type), INTENT(IN)                  :: inputs
      TYPE(torch_tensor_type), INTENT(INOUT)             :: output

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_forward_mol_tensor'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_forward_mol_tensor(model, method_name, inputs, output) &
            BIND(C, name="torch_c_model_forward_mol_tensor")
            IMPORT :: C_CHAR, C_PTR
            TYPE(C_PTR), VALUE                           :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)         :: method_name
            TYPE(C_PTR), VALUE                           :: inputs
            TYPE(C_PTR)                                  :: output
         END SUBROUTINE torch_c_model_forward_mol_tensor
      END INTERFACE

      CALL timeset(routineN, timer_handle)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CPASSERT(C_ASSOCIATED(inputs%c_ptr))
      CPASSERT(.NOT. C_ASSOCIATED(output%c_ptr))
      ! The C side expects a NUL-terminated method name.
      CALL torch_c_model_forward_mol_tensor(model%c_ptr, TRIM(method_name)//C_NULL_CHAR, &
                                            inputs%c_ptr, output%c_ptr)
      CPASSERT(C_ASSOCIATED(output%c_ptr))
      CALL timestop(timer_handle)
#else
      CPABORT("CP2K was compiled without Torch library.")
      MARK_USED(model)
      MARK_USED(method_name)
      MARK_USED(inputs)
      MARK_USED(output)
#endif
   END SUBROUTINE torch_model_forward_mol_tensor
795 :
! **************************************************************************************************
!> \brief Releases a Torch model and all its resources.
!> \param model model to release; its handle must be associated and is reset to C_NULL_PTR
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_release(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      INTERFACE
         SUBROUTINE torch_c_model_release(model) BIND(C, name="torch_c_model_release")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: model
         END SUBROUTINE torch_c_model_release
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_release(model%c_ptr)
      ! Forget the handle so the object reads as released afterwards.
      model%c_ptr = C_NULL_PTR
#else
      MARK_USED(model)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END SUBROUTINE torch_model_release
819 :
! **************************************************************************************************
!> \brief Reads metadata entry from given "*.pth" file. (In Torch lingo they are called extra files)
!> \param filename path of the model file
!> \param key name of the metadata entry
!> \return content of the entry as a deferred-length string
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_model_read_metadata(filename, key) RESULT(res)
      CHARACTER(len=*), INTENT(IN)                       :: filename, key
      CHARACTER(:), ALLOCATABLE                          :: res

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_read_metadata'

      INTEGER                                            :: content_len, timer_handle
      TYPE(C_PTR)                                        :: buffer_c

      INTERFACE
         SUBROUTINE torch_c_model_read_metadata(filename, key, content, length) &
            BIND(C, name="torch_c_model_read_metadata")
            IMPORT :: C_CHAR, C_PTR, C_INT
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: filename, key
            TYPE(C_PTR)                                   :: content
            INTEGER(kind=C_INT)                           :: length
         END SUBROUTINE torch_c_model_read_metadata
      END INTERFACE

      CALL timeset(routineN, timer_handle)

      ! Sentinel values which the C routine is expected to overwrite;
      ! c_string_to_allocatable asserts on them.
      buffer_c = C_NULL_PTR
      content_len = -1

      CALL torch_c_model_read_metadata(TRIM(filename)//C_NULL_CHAR, TRIM(key)//C_NULL_CHAR, &
                                       buffer_c, content_len)

      ! Copy the text into res and release the C buffer.
      CALL c_string_to_allocatable(buffer_c, content_len, res)

      CALL timestop(timer_handle)
#else
      MARK_USED(filename)
      MARK_USED(key)
      res = ""
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_model_read_metadata
861 :
! **************************************************************************************************
!> \brief Move a C-allocated null-terminated string into an allocatable Fortran string.
!> \param content_c C buffer holding length characters plus a terminating null character;
!>        it is released via torch_c_free_string and reset to C_NULL_PTR
!> \param length number of characters before the terminating null character
!> \param res receives a copy of the characters
! **************************************************************************************************
   SUBROUTINE c_string_to_allocatable(content_c, length, res)
      TYPE(C_PTR), INTENT(INOUT)                         :: content_c
      INTEGER, INTENT(IN)                                :: length
      CHARACTER(:), ALLOCATABLE, INTENT(OUT)             :: res

#if defined(__LIBTORCH)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), &
         POINTER                                         :: chars
      INTEGER                                            :: k

      INTERFACE
         SUBROUTINE torch_c_free_string(content) BIND(C, name="torch_c_free_string")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: content
         END SUBROUTINE torch_c_free_string
      END INTERFACE

      CPASSERT(C_ASSOCIATED(content_c))
      CPASSERT(length >= 0)

      ! View the payload together with its terminator.
      CALL C_F_POINTER(content_c, chars, shape=[length + 1])
      CPASSERT(chars(length + 1) == C_NULL_CHAR)
      ! The payload itself must not contain embedded null characters.
      CPASSERT(ALL(chars(1:length) /= C_NULL_CHAR))

      ALLOCATE (CHARACTER(LEN=length) :: res)
      DO k = 1, length
         res(k:k) = chars(k)
      END DO

      ! Hand the buffer back to the C side and drop our references to it.
      NULLIFY (chars)
      CALL torch_c_free_string(content_c)
      content_c = C_NULL_PTR
#else
      MARK_USED(content_c)
      MARK_USED(length)
      res = ""
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END SUBROUTINE c_string_to_allocatable
905 :
! **************************************************************************************************
!> \brief Returns true iff the Torch CUDA backend is available.
!> \return .TRUE. if Torch reports that CUDA is available
!> \author Ole Schuett
! **************************************************************************************************
   FUNCTION torch_cuda_is_available() RESULT(res)
      LOGICAL                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_is_available() BIND(C, name="torch_c_cuda_is_available")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL)                               :: torch_c_cuda_is_available
         END FUNCTION torch_c_cuda_is_available
      END INTERFACE

      ! Convert the C_BOOL result into a default LOGICAL.
      res = LOGICAL(torch_c_cuda_is_available())
#else
      res = .FALSE.
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_cuda_is_available
927 :
! **************************************************************************************************
!> \brief Return the number of CUDA devices visible to Torch.
!> \return number of devices reported by Torch
! **************************************************************************************************
   FUNCTION torch_cuda_device_count() RESULT(res)
      INTEGER                                            :: res

#if defined(__LIBTORCH)
      INTERFACE
         FUNCTION torch_c_cuda_device_count() BIND(C, name="torch_c_cuda_device_count")
            IMPORT :: C_INT
            INTEGER(C_INT)                                :: torch_c_cuda_device_count
         END FUNCTION torch_c_cuda_device_count
      END INTERFACE

      ! The result variable is named res rather than count to avoid shadowing the COUNT intrinsic.
      res = INT(torch_c_cuda_device_count())
#else
      res = 0
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END FUNCTION torch_cuda_device_count
948 :
! **************************************************************************************************
!> \brief Set whether to allow the use of TF32.
!>        Needed due to changes in defaults from pytorch 1.7 to 1.11 to >=1.12
!>        See https://pytorch.org/docs/stable/notes/cuda.html
!> \param allow_tf32 whether TF32 may be used
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_allow_tf32(allow_tf32)
      LOGICAL, INTENT(IN)                                :: allow_tf32

#if defined(__LIBTORCH)
      LOGICAL(C_BOOL)                                    :: flag_c

      INTERFACE
         SUBROUTINE torch_c_allow_tf32(allow_tf32) BIND(C, name="torch_c_allow_tf32")
            IMPORT :: C_BOOL
            LOGICAL(C_BOOL), VALUE                        :: allow_tf32
         END SUBROUTINE torch_c_allow_tf32
      END INTERFACE

      ! The C routine takes a C_BOOL by value, so convert the default LOGICAL first.
      flag_c = LOGICAL(allow_tf32, C_BOOL)
      CALL torch_c_allow_tf32(flag_c)
#else
      MARK_USED(allow_tf32)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END SUBROUTINE torch_allow_tf32
972 :
! **************************************************************************************************
!> \brief Freeze the given Torch model: applies generic optimizations that speed up the model.
!>        See https://pytorch.org/docs/stable/generated/torch.jit.freeze.html
!> \param model model to freeze; its handle must be associated
!> \author Gabriele Tocci
! **************************************************************************************************
   SUBROUTINE torch_model_freeze(model)
      TYPE(torch_model_type), INTENT(INOUT)              :: model

#if defined(__LIBTORCH)
      CHARACTER(len=*), PARAMETER :: routineN = 'torch_model_freeze'

      INTEGER                                            :: timer_handle

      INTERFACE
         SUBROUTINE torch_c_model_freeze(model) BIND(C, name="torch_c_model_freeze")
            IMPORT :: C_PTR
            TYPE(C_PTR), VALUE                            :: model
         END SUBROUTINE torch_c_model_freeze
      END INTERFACE

      CALL timeset(routineN, timer_handle)
      CPASSERT(C_ASSOCIATED(model%c_ptr))
      CALL torch_c_model_freeze(model%c_ptr)
      CALL timestop(timer_handle)
#else
      MARK_USED(model)
      CPABORT("CP2K was compiled without Torch library.")
#endif
   END SUBROUTINE torch_model_freeze
1001 :
#:set typenames = ['int64', 'double', 'string']
#:set types_f = ['INTEGER(kind=int_8)', 'REAL(dp)', 'CHARACTER(LEN=default_string_length)']
#:set types_c = ['INTEGER(kind=C_INT64_T)', 'REAL(kind=C_DOUBLE)', 'CHARACTER(kind=C_CHAR), DIMENSION(*)']
#:set zeros_f = ['0', '0.0_dp', '""']

#:for typename, type_f, type_c, zero_f in zip(typenames, types_f, types_c, zeros_f)
! **************************************************************************************************
!> \brief Retrieves an attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model whose handle must be associated
!> \param key name of the attribute
!> \param dest receives the attribute value
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_${typename}$ (model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      ${type_f}$, INTENT(OUT) :: dest

#if defined(__LIBTORCH)

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_${typename}$ (model, key, dest) &
            BIND(C, name="torch_c_model_get_attr_${typename}$")
            IMPORT :: C_PTR, C_CHAR, C_INT64_T, C_DOUBLE
            TYPE(C_PTR), VALUE                            :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: key
            ${type_c}$ :: dest
         END SUBROUTINE torch_c_model_get_attr_${typename}$
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))

      ! INTENT(OUT) leaves dest undefined on entry. Give it a defined value first so that, for the
      ! string variant, positions not written by the C routine are blanks. This mirrors the
      ! blank-initialisation done in torch_model_get_attr_strlist.
      dest = ${zero_f}$

      CALL torch_c_model_get_attr_${typename}$ (model=model%c_ptr, &
                                                key=TRIM(key)//C_NULL_CHAR, &
                                                dest=dest)
#else
      dest = ${zero_f}$
      MARK_USED(model)
      MARK_USED(key)
      CPABORT("CP2K compiled without the Torch library.")
#endif
   END SUBROUTINE torch_model_get_attr_${typename}$
#:endfor
1040 :
! **************************************************************************************************
!> \brief Retrieves an integer attribute from a Torch model. Must be called before torch_model_freeze.
!> \param model Torch model
!> \param key name of the attribute
!> \param dest receives the value; it must fit into a default INTEGER
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_int32(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      INTEGER, INTENT(OUT)                               :: dest

      ! Largest magnitude of a default INTEGER (the kind of dest), widened to int_8.
      INTEGER(kind=int_8), PARAMETER                     :: limit = HUGE(0)

      INTEGER(kind=int_8)                                :: temp

      CALL torch_model_get_attr_int64(model, key, temp)
      ! Use a two-sided comparison rather than ABS(temp): ABS would overflow for the most negative
      ! int_8 value. The bound is inclusive because HUGE(dest) itself is representable.
      CPASSERT(-limit <= temp .AND. temp <= limit)
      dest = INT(temp)
   END SUBROUTINE torch_model_get_attr_int32
1055 :
! **************************************************************************************************
!> \brief Retrieves a list-of-strings attribute from a Torch model.
!>        Must be called before torch_model_freeze.
!> \param model Torch model whose handle must be associated
!> \param key name of the attribute
!> \param dest (re)allocated to the number of list items and filled with them; any previous
!>        allocation is released on entry (INTENT(OUT))
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE torch_model_get_attr_strlist(model, key, dest)
      TYPE(torch_model_type), INTENT(IN)                 :: model
      CHARACTER(len=*), INTENT(IN)                       :: key
      CHARACTER(LEN=default_string_length), &
         ALLOCATABLE, DIMENSION(:), INTENT(OUT)          :: dest

#if defined(__LIBTORCH)

      INTEGER                                            :: i, num_items

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_list_size(model, key, size) &
            BIND(C, name="torch_c_model_get_attr_list_size")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE                            :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: key
            INTEGER(kind=C_INT)                           :: size
         END SUBROUTINE torch_c_model_get_attr_list_size
      END INTERFACE

      INTERFACE
         SUBROUTINE torch_c_model_get_attr_strlist(model, key, index, dest) &
            BIND(C, name="torch_c_model_get_attr_strlist")
            IMPORT :: C_PTR, C_CHAR, C_INT
            TYPE(C_PTR), VALUE                            :: model
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: key
            INTEGER(kind=C_INT), VALUE                    :: index
            CHARACTER(kind=C_CHAR), DIMENSION(*)          :: dest
         END SUBROUTINE torch_c_model_get_attr_strlist
      END INTERFACE

      CPASSERT(C_ASSOCIATED(model%c_ptr))

      CALL torch_c_model_get_attr_list_size(model=model%c_ptr, &
                                            key=TRIM(key)//C_NULL_CHAR, &
                                            size=num_items)
      CPASSERT(num_items >= 0)

      ALLOCATE (dest(num_items))
      ! Blank-fill first so that positions not written by the C routine are blanks.
      dest(:) = ""

      ! Items are requested with a 0-based index (i - 1).
      DO i = 1, num_items
         CALL torch_c_model_get_attr_strlist(model=model%c_ptr, &
                                             key=TRIM(key)//C_NULL_CHAR, &
                                             index=i - 1, &
                                             dest=dest(i))
      END DO
#else
      ALLOCATE (dest(0))
      MARK_USED(model)
      MARK_USED(key)
      CPABORT("CP2K compiled without the Torch library.")
#endif

   END SUBROUTINE torch_model_get_attr_strlist
1112 :
1113 0 : END MODULE torch_api
|