LCOV - code coverage report
Current view: top level - src - pao_main.F (source / functions) Coverage Total Hit
Test: CP2K Regtests (git:07c9450) Lines: 100.0 % 141 141
Test Date: 2025-12-13 06:52:47 Functions: 100.0 % 5 5

            Line data    Source code
       1              : !--------------------------------------------------------------------------------------------------!
       2              : !   CP2K: A general program to perform molecular dynamics simulations                              !
       3              : !   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
       4              : !                                                                                                  !
       5              : !   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
       6              : !--------------------------------------------------------------------------------------------------!
       7              : 
       8              : ! **************************************************************************************************
       9              : !> \brief Main module for the PAO method
      10              : !> \author Ole Schuett
      11              : ! **************************************************************************************************
      12              : MODULE pao_main
      13              :    USE bibliography,                    ONLY: Schuett2018,&
      14              :                                               cite_reference
      15              :    USE cp_dbcsr_api,                    ONLY: dbcsr_add,&
      16              :                                               dbcsr_copy,&
      17              :                                               dbcsr_create,&
      18              :                                               dbcsr_p_type,&
      19              :                                               dbcsr_release,&
      20              :                                               dbcsr_set,&
      21              :                                               dbcsr_type
      22              :    USE cp_dbcsr_contrib,                ONLY: dbcsr_reserve_diag_blocks
      23              :    USE cp_external_control,             ONLY: external_control
      24              :    USE dm_ls_scf_types,                 ONLY: ls_mstruct_type,&
      25              :                                               ls_scf_env_type
      26              :    USE input_section_types,             ONLY: section_vals_get_subs_vals,&
      27              :                                               section_vals_type
      28              :    USE kinds,                           ONLY: dp
      29              :    USE linesearch,                      ONLY: linesearch_finalize,&
      30              :                                               linesearch_init,&
      31              :                                               linesearch_reset,&
      32              :                                               linesearch_step
      33              :    USE machine,                         ONLY: m_walltime
      34              :    USE pao_input,                       ONLY: parse_pao_section
      35              :    USE pao_io,                          ONLY: pao_read_restart,&
      36              :                                               pao_write_hcore_matrix_csr,&
      37              :                                               pao_write_ks_matrix_csr,&
      38              :                                               pao_write_p_matrix_csr,&
      39              :                                               pao_write_restart,&
      40              :                                               pao_write_s_matrix_csr
      41              :    USE pao_methods,                     ONLY: &
      42              :         pao_add_forces, pao_build_core_hamiltonian, pao_build_diag_distribution, &
      43              :         pao_build_matrix_X, pao_build_orthogonalizer, pao_build_selector, pao_calc_energy, &
      44              :         pao_check_grad, pao_check_trace_ps, pao_guess_initial_P, pao_init_kinds, &
      45              :         pao_print_atom_info, pao_store_P, pao_test_convergence
      46              :    USE pao_ml,                          ONLY: pao_ml_init,&
      47              :                                               pao_ml_predict
      48              :    USE pao_model,                       ONLY: pao_model_predict
      49              :    USE pao_optimizer,                   ONLY: pao_opt_finalize,&
      50              :                                               pao_opt_init,&
      51              :                                               pao_opt_new_dir
      52              :    USE pao_param,                       ONLY: pao_calc_AB,&
      53              :                                               pao_param_finalize,&
      54              :                                               pao_param_init,&
      55              :                                               pao_param_initial_guess
      56              :    USE pao_types,                       ONLY: pao_env_type
      57              :    USE qs_environment_types,            ONLY: get_qs_env,&
      58              :                                               qs_environment_type
      59              : #include "./base/base_uses.f90"
      60              : 
      61              :    IMPLICIT NONE
      62              : 
      63              :    PRIVATE
      64              : 
      65              :    CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'pao_main'
      66              : 
      67              :    PUBLIC :: pao_init, pao_update, pao_post_scf, pao_optimization_start, pao_optimization_end
      68              : 
      69              : CONTAINS
      70              : 
      71              : ! **************************************************************************************************
      72              : !> \brief Initialize the PAO environment
      73              : !> \param qs_env ...
      74              : !> \param ls_scf_env ...
      75              : ! **************************************************************************************************
      76          434 :    SUBROUTINE pao_init(qs_env, ls_scf_env)
      77              :       TYPE(qs_environment_type), POINTER                 :: qs_env
      78              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
      79              : 
      80              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_init'
      81              : 
      82              :       INTEGER                                            :: handle
      83          336 :       TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s
      84              :       TYPE(pao_env_type), POINTER                        :: pao
      85              :       TYPE(section_vals_type), POINTER                   :: input
      86              : 
      87          238 :       IF (.NOT. ls_scf_env%do_pao) RETURN
      88              : 
      89           98 :       CALL timeset(routineN, handle)
      90           98 :       CALL cite_reference(Schuett2018)
      91           98 :       pao => ls_scf_env%pao_env
      92           98 :       CALL get_qs_env(qs_env=qs_env, input=input, matrix_s=matrix_s)
      93              : 
      94              :       ! parse input
      95           98 :       CALL parse_pao_section(pao, input)
      96              : 
      97           98 :       CALL pao_init_kinds(pao, qs_env)
      98              : 
      99              :       ! train machine learning
     100           98 :       CALL pao_ml_init(pao, qs_env)
     101              : 
     102           98 :       CALL timestop(handle)
     103          336 :    END SUBROUTINE pao_init
     104              : 
     105              : ! **************************************************************************************************
     106              : !> \brief Start a PAO optimization run.
     107              : !> \param qs_env ...
     108              : !> \param ls_scf_env ...
     109              : ! **************************************************************************************************
     110          892 :    SUBROUTINE pao_optimization_start(qs_env, ls_scf_env)
     111              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     112              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     113              : 
     114              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_start'
     115              : 
     116              :       INTEGER                                            :: handle
     117              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     118              :       TYPE(pao_env_type), POINTER                        :: pao
     119              :       TYPE(section_vals_type), POINTER                   :: input, section
     120              : 
     121          598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     122              : 
     123          294 :       CALL timeset(routineN, handle)
     124          294 :       CALL get_qs_env(qs_env, input=input)
     125          294 :       pao => ls_scf_env%pao_env
     126          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     127              : 
     128              :       ! reset state
     129          294 :       pao%step_start_time = m_walltime()
     130          294 :       pao%istep = 0
     131          294 :       pao%matrix_P_ready = .FALSE.
     132              : 
     133              :       ! ready stuff that does not depend on atom positions
     134          294 :       IF (.NOT. pao%constants_ready) THEN
     135           98 :          CALL pao_build_diag_distribution(pao, qs_env)
     136           98 :          CALL pao_build_orthogonalizer(pao, qs_env)
     137           98 :          CALL pao_build_selector(pao, qs_env)
     138           98 :          CALL pao_build_core_hamiltonian(pao, qs_env)
     139           98 :          pao%constants_ready = .TRUE.
     140              :       END IF
     141              : 
     142          294 :       CALL pao_param_init(pao, qs_env)
     143              : 
     144              :       ! ready PAO parameter matrix_X
     145          294 :       IF (.NOT. pao%matrix_X_ready) THEN
     146           98 :          CALL pao_build_matrix_X(pao, qs_env)
     147           98 :          CALL pao_print_atom_info(pao)
     148           98 :          IF (LEN_TRIM(pao%restart_file) > 0) THEN
     149            8 :             CALL pao_read_restart(pao, qs_env)
     150           90 :          ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     151           18 :             CALL pao_ml_predict(pao, qs_env)
     152           72 :          ELSE IF (ALLOCATED(pao%models)) THEN
     153            4 :             CALL pao_model_predict(pao, qs_env)
     154              :          ELSE
     155           68 :             CALL pao_param_initial_guess(pao, qs_env)
     156              :          END IF
     157           98 :          pao%matrix_X_ready = .TRUE.
     158          196 :       ELSE IF (SIZE(pao%ml_training_set) > 0) THEN
     159          120 :          CALL pao_ml_predict(pao, qs_env)
     160           76 :       ELSE IF (ALLOCATED(pao%models)) THEN
     161           12 :          CALL pao_model_predict(pao, qs_env)
     162              :       ELSE
     163           64 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| reusing matrix_X from previous optimization"
     164              :       END IF
     165              : 
     166              :       ! init line-search
     167          294 :       section => section_vals_get_subs_vals(input, "DFT%LS_SCF%PAO%LINE_SEARCH")
     168          294 :       CALL linesearch_init(pao%linesearch, section, "PAO|")
     169              : 
     170              :       ! create some more matrices
     171          294 :       CALL dbcsr_copy(pao%matrix_G, pao%matrix_X)
     172          294 :       CALL dbcsr_set(pao%matrix_G, 0.0_dp)
     173              : 
     174          294 :       CALL dbcsr_create(ls_mstruct%matrix_A, template=pao%matrix_Y)
     175          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_A)
     176          294 :       CALL dbcsr_create(ls_mstruct%matrix_B, template=pao%matrix_Y)
     177          294 :       CALL dbcsr_reserve_diag_blocks(ls_mstruct%matrix_B)
     178              : 
     179              :       ! fill PAO transformation matrices
     180          294 :       CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.FALSE.)
     181              : 
     182          294 :       CALL timestop(handle)
     183              :    END SUBROUTINE pao_optimization_start
     184              : 
     185              : ! **************************************************************************************************
     186              : !> \brief Called after the SCF optimization, updates the PAO basis.
     187              : !> \param qs_env ...
     188              : !> \param ls_scf_env ...
     189              : !> \param pao_is_done ...
     190              : ! **************************************************************************************************
     191         1062 :    SUBROUTINE pao_update(qs_env, ls_scf_env, pao_is_done)
     192              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     193              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     194              :       LOGICAL, INTENT(OUT)                               :: pao_is_done
     195              : 
     196              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_update'
     197              : 
     198              :       INTEGER                                            :: handle, icycle
     199              :       LOGICAL                                            :: cycle_converged, do_mixing, should_stop
     200              :       REAL(KIND=dp)                                      :: energy, penalty
     201              :       TYPE(dbcsr_type)                                   :: matrix_X_mixing
     202              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     203              :       TYPE(pao_env_type), POINTER                        :: pao
     204              : 
     205          816 :       IF (.NOT. ls_scf_env%do_pao) THEN
     206          304 :          pao_is_done = .TRUE.
     207          570 :          RETURN
     208              :       END IF
     209              : 
     210          512 :       ls_mstruct => ls_scf_env%ls_mstruct
     211          512 :       pao => ls_scf_env%pao_env
     212              : 
     213          512 :       IF (.NOT. pao%matrix_P_ready) THEN
     214          294 :          CALL pao_guess_initial_P(pao, qs_env, ls_scf_env)
     215          294 :          pao%matrix_P_ready = .TRUE.
     216              :       END IF
     217              : 
     218          512 :       IF (pao%max_pao == 0) THEN
     219          218 :          pao_is_done = .TRUE.
     220          218 :          RETURN
     221              :       END IF
     222              : 
     223          294 :       IF (pao%need_initial_scf) THEN
     224           48 :          pao_is_done = .FALSE.
     225           48 :          pao%need_initial_scf = .FALSE.
     226           48 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Performing initial SCF optimization."
     227           48 :          RETURN
     228              :       END IF
     229              : 
     230          246 :       CALL timeset(routineN, handle)
     231              : 
     232              :       ! perform mixing once we are well into the optimization
     233          246 :       do_mixing = pao%mixing /= 1.0_dp .AND. pao%istep > 1
     234              :       IF (do_mixing) THEN
     235          128 :          CALL dbcsr_copy(matrix_X_mixing, pao%matrix_X)
     236              :       END IF
     237              : 
     238          246 :       cycle_converged = .FALSE.
     239          246 :       icycle = 0
     240          246 :       CALL linesearch_reset(pao%linesearch)
     241          246 :       CALL pao_opt_init(pao)
     242              : 
     243        20024 :       DO WHILE (.TRUE.)
     244        10126 :          pao%istep = pao%istep + 1
     245              : 
     246        15189 :          IF (pao%iw > 0) WRITE (pao%iw, "(A,I9,A)") " PAO| ======================= Iteration: ", &
     247        10126 :             pao%istep, " ============================="
     248              : 
     249              :          ! calc energy and check trace_PS
     250        10126 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     251        10126 :          CALL pao_check_trace_PS(ls_scf_env)
     252              : 
     253        10126 :          IF (pao%linesearch%starts) THEN
     254         2616 :             icycle = icycle + 1
     255              :             ! calc new gradient including penalty terms
     256         2616 :             CALL pao_calc_AB(pao, qs_env, ls_scf_env, gradient=.TRUE., penalty=penalty)
     257         2616 :             CALL pao_check_grad(pao, qs_env, ls_scf_env)
     258              : 
     259              :             ! calculate new direction for line-search
     260         2616 :             CALL pao_opt_new_dir(pao, icycle)
     261              : 
     262              :             !backup X
     263         2616 :             CALL dbcsr_copy(pao%matrix_X_orig, pao%matrix_X)
     264              : 
     265              :             ! print info and convergence test
     266         2616 :             CALL pao_test_convergence(pao, ls_scf_env, energy, cycle_converged)
     267         2616 :             IF (cycle_converged) THEN
     268          210 :                pao_is_done = icycle < 3
     269          210 :                IF (pao_is_done .AND. pao%iw > 0) WRITE (pao%iw, *) "PAO| converged after ", pao%istep, " steps :-)"
     270              :                EXIT
     271              :             END IF
     272              : 
     273              :             ! if we have reached the maximum number of cycles exit in order
     274              :             ! to restart with a fresh hamiltonian
     275         2406 :             IF (icycle >= pao%max_cycles) THEN
     276           18 :                IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| CG not yet converged after ", icycle, " cylces."
     277           18 :                pao_is_done = .FALSE.
     278           18 :                EXIT
     279              :             END IF
     280              : 
     281         2388 :             IF (MOD(icycle, pao%write_cycles) == 0) &
     282            8 :                CALL pao_write_restart(pao, qs_env, energy) ! write an intermediate restart file
     283              :          END IF
     284              : 
     285              :          ! check for early abort without convergence?
     286         9898 :          CALL external_control(should_stop, "PAO", start_time=qs_env%start_time, target_time=qs_env%target_time)
     287         9898 :          IF (should_stop .OR. pao%istep >= pao%max_pao) THEN
     288           18 :             CPWARN("PAO not converged!")
     289           18 :             pao_is_done = .TRUE.
     290           18 :             EXIT
     291              :          END IF
     292              : 
     293              :          ! perform line-search step
     294         9880 :          CALL linesearch_step(pao%linesearch, energy=energy, slope=pao%norm_G**2)
     295              : 
     296         9880 :          IF (pao%linesearch%step_size < 1e-9_dp) CPABORT("PAO gradient is wrong.")
     297              : 
     298         9880 :          CALL dbcsr_copy(pao%matrix_X, pao%matrix_X_orig) !restore X
     299         9880 :          CALL dbcsr_add(pao%matrix_X, pao%matrix_D, 1.0_dp, pao%linesearch%step_size)
     300              :       END DO
     301              : 
     302              :       ! perform mixing of matrix_X
     303          246 :       IF (do_mixing) THEN
     304          128 :          CALL dbcsr_add(pao%matrix_X, matrix_X_mixing, pao%mixing, 1.0_dp - pao%mixing)
     305          128 :          CALL dbcsr_release(matrix_X_mixing)
     306          128 :          IF (pao%iw > 0) WRITE (pao%iw, *) "PAO| Recalculating energy after mixing."
     307          128 :          CALL pao_calc_energy(pao, qs_env, ls_scf_env, energy)
     308              :       END IF
     309              : 
     310          246 :       CALL pao_write_restart(pao, qs_env, energy)
     311          246 :       CALL pao_opt_finalize(pao)
     312              : 
     313          246 :       CALL timestop(handle)
     314          816 :    END SUBROUTINE pao_update
     315              : 
     316              : ! **************************************************************************************************
     317              : !> \brief Calculate PAO forces and store density matrix for future ASPC extrapolations
     318              : !> \param qs_env ...
     319              : !> \param ls_scf_env ...
     320              : !> \param pao_is_done ...
     321              : ! **************************************************************************************************
     322         1110 :    SUBROUTINE pao_post_scf(qs_env, ls_scf_env, pao_is_done)
     323              :       TYPE(qs_environment_type), POINTER                 :: qs_env
     324              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     325              :       LOGICAL, INTENT(IN)                                :: pao_is_done
     326              : 
     327              :       CHARACTER(len=*), PARAMETER                        :: routineN = 'pao_post_scf'
     328              : 
     329              :       INTEGER                                            :: handle
     330              : 
     331         1034 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     332          512 :       IF (.NOT. pao_is_done) RETURN
     333              : 
     334          294 :       CALL timeset(routineN, handle)
     335              : 
     336              :       ! print out the matrices here before pao_store_P converts them back into matrices in
     337              :       ! terms of the primary basis
     338          294 :       CALL pao_write_ks_matrix_csr(qs_env, ls_scf_env)
     339          294 :       CALL pao_write_s_matrix_csr(qs_env, ls_scf_env)
     340          294 :       CALL pao_write_hcore_matrix_csr(qs_env, ls_scf_env)
     341          294 :       CALL pao_write_p_matrix_csr(qs_env, ls_scf_env)
     342              : 
     343          294 :       CALL pao_store_P(qs_env, ls_scf_env)
     344          294 :       IF (ls_scf_env%calculate_forces) CALL pao_add_forces(qs_env, ls_scf_env)
     345              : 
     346          294 :       CALL timestop(handle)
     347              :    END SUBROUTINE pao_post_scf
     348              : 
     349              : ! **************************************************************************************************
     350              : !> \brief Finish a PAO optimization run.
     351              : !> \param ls_scf_env ...
     352              : ! **************************************************************************************************
     353          892 :    SUBROUTINE pao_optimization_end(ls_scf_env)
     354              :       TYPE(ls_scf_env_type), TARGET                      :: ls_scf_env
     355              : 
     356              :       CHARACTER(len=*), PARAMETER :: routineN = 'pao_optimization_end'
     357              : 
     358              :       INTEGER                                            :: handle
     359              :       TYPE(ls_mstruct_type), POINTER                     :: ls_mstruct
     360              :       TYPE(pao_env_type), POINTER                        :: pao
     361              : 
     362          598 :       IF (.NOT. ls_scf_env%do_pao) RETURN
     363              : 
     364          294 :       pao => ls_scf_env%pao_env
     365          294 :       ls_mstruct => ls_scf_env%ls_mstruct
     366              : 
     367          294 :       CALL timeset(routineN, handle)
     368              : 
     369          294 :       CALL pao_param_finalize(pao)
     370              : 
     371              :       ! We keep pao%matrix_X for next scf-run, e.g. during MD or GEO-OPT
     372          294 :       CALL dbcsr_release(pao%matrix_X_orig)
     373          294 :       CALL dbcsr_release(pao%matrix_G)
     374          294 :       CALL dbcsr_release(ls_mstruct%matrix_A)
     375          294 :       CALL dbcsr_release(ls_mstruct%matrix_B)
     376              : 
     377          294 :       CALL linesearch_finalize(pao%linesearch)
     378              : 
     379          294 :       CALL timestop(handle)
     380              :    END SUBROUTINE pao_optimization_end
     381              : 
     382              : END MODULE pao_main
        

Generated by: LCOV version 2.0-1