#!/usr/bin/env python3

import sys
import mlatom as ml
from mlatom.constants import Angstrom2Bohr

def set_model(method="AIQM1", keywords_file="mndokw"):
    """Build the MLatom model used by ``predict``.

    The defaults reproduce the original hard-coded setup (AIQM1 with MNDO as
    the QM program), so ``set_model()`` behaves exactly as before.

    Args:
        method: method name passed to ``ml.models.methods`` (default
            ``"AIQM1"``).
        keywords_file: value passed to the MNDO interface as
            ``read_keywords_from_file`` (default ``"mndokw"``); presumably a
            path relative to the working directory -- confirm with MLatom docs.

    Returns:
        The object returned by ``ml.models.methods`` with ``qm_program="MNDO"``.
    """
    return ml.models.methods(method=method,
                             qm_program="MNDO",
                             qm_program_kwargs={'read_keywords_from_file': keywords_file})


def predict(model, folder_path, n_states, curr_state):
    """Compute multi-state energies/gradients for ``geom.xyz`` and dump them.

    Reads ``<folder_path>/geom.xyz`` and asks ``model`` for energies and
    energy gradients of ``n_states`` electronic states (no NACVs). It then
    writes two files:

    * ``<folder_path>/epot`` -- one energy per line, in state order.
    * ``<folder_path>/grad`` -- per-atom gradient rows for *every* state,
      state after state. Each row is ``gx \\t gy \\t gz``, scaled by
      ``1 / Angstrom2Bohr``.

    Args:
        model: object exposing ``predict(molecule=..., nstates=..., ...)``,
            e.g. the one returned by ``set_model()``.
        folder_path: directory containing ``geom.xyz``; outputs go there too.
        n_states: number of electronic states to compute (must be >= 1).
        curr_state: 1-based index of the current state; converted to the
            0-based ``current_state`` passed to ``model.predict``.

    Returns:
        ``{"stats": "success"}``.

    Raises:
        ValueError: if ``n_states < 1`` or ``curr_state`` is not in
            ``1..n_states``.
    """
    # Validate before doing any work. Otherwise curr_state=0 would become
    # current_state=-1, which Python-style indexing would silently accept as
    # the last state.
    if n_states < 1:
        raise ValueError(f"n_states must be >= 1, got {n_states}")
    if not 1 <= curr_state <= n_states:
        raise ValueError(
            f"curr_state must be in 1..{n_states} (1-based), got {curr_state}")

    mol = ml.data.molecule()
    mol.read_from_xyz_file(filename=f'{folder_path}/geom.xyz')

    model.predict(molecule=mol,
                  nstates=n_states,
                  current_state=curr_state - 1,  # model takes a 0-based index
                  calculate_energy=True,
                  calculate_energy_gradients=[True] * n_states,
                  calculate_nacv=False)

    # Potential energies, one state per line.
    with open(f'{folder_path}/epot', 'w') as f:
        for i in range(n_states):
            f.write(f'{mol.electronic_states[i].energy}\n')

    # Scale factor applied to every gradient component; presumably converts
    # per-Angstrom gradients to per-Bohr -- confirm MLatom's gradient units.
    const = 1 / Angstrom2Bohr

    # Gradients of ALL states (not only the current one), written state by
    # state. NOTE(review): relies on mol.state_gradients supporting
    # elementwise multiplication by a float (e.g. a numpy array) -- confirm.
    with open(f'{folder_path}/grad', 'w') as f:
        for gradient in mol.state_gradients * const:
            for atom in gradient:
                f.write(f'{atom[0]} \t {atom[1]} \t {atom[2]}\n')

    return {"stats": "success"}


if __name__ == "__main__":
    # Command line: <script> N_STATES CURR_STATE   (CURR_STATE is 1-based).
    # Fail with a usage message instead of a bare IndexError traceback.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} N_STATES CURR_STATE")
    n_states = int(sys.argv[1])
    curr_state = int(sys.argv[2])
    # NOTE(review): looks like leftover debug output; kept in case a calling
    # driver reads this script's stdout -- remove if nothing depends on it.
    print(n_states*[True])

    model = set_model()
    predict(model, './', n_states, curr_state)
