! Copyright (C) 2022  Light and Molecules Group

! This program is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! This program is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.

module mod_statistics
  !! author: Baptiste Demoulin <baptiste.demoulin@univ-amu.fr>
  !! date: 2020-09-30
  !!
  !! Module for statistical analysis of NX results.
  !!
  !! This module gathers all routines related to the
  !! statistical analysis of Newton-X trajectories, either in plain
  !! text format, or in the H5MD format.
  !!
  !! Functions and routines that are specific to the plain text format
  !! are prefixed with ``txt``, while those related to H5MD are
  !! prefixed with ``h5``.
  use mod_kinds, only: dp
  use mod_constants
#ifdef USE_HDF5
  use hdf5
  use mod_h5md
#endif
  implicit none

  private

  public :: txt_compare_nstatdyn, txt_get_nr_of_fields
  public :: txt_update_avg_var
  public :: pretty_print_stats
  public :: pretty_print_fractraj

#ifdef USE_HDF5
  public :: h5_update_avg_var, h5_get_time_step
  public :: h5_stats_energies
  public :: h5_compare_nstatdyn

  interface h5_update_avg_var
     module procedure h5_update_avg_var_d1
     module procedure h5_update_avg_var_d2
  end interface h5_update_avg_var
#endif

contains

  subroutine txt_compare_nstatdyn(filename, numtraj, offset, nlines,&
       & nstatdyn_init, time)
    !! Track the state occupied by one trajectory and update the
    !! per-state trajectory counts accordingly.
    !!
    !! ``nlines`` records are read from ``filename``, starting at
    !! line ``offset``. The first field of every record is stored in
    !! ``time``, the second one is converted to an integer and taken
    !! as the current state. Whenever that state differs from the
    !! previous one (initially ``nstatdyn_init``), one count is moved
    !! from the previous state to the new state in ``numtraj``, for
    !! the current line and all the following ones.
    !!
    !! NOTE(review): ``numtraj`` is only incremented / decremented
    !! here, so it presumably has to be initialised by the caller
    !! with every trajectory in its initial state -- confirm.
    character(len=MAX_STR_SIZE), intent(in) :: filename
    !! Text file to read.
    integer, dimension(:, :), intent(inout) :: numtraj
    !! Number of trajectories in each state (first index: line /
    !! step, second index: state).
    integer, intent(in) :: offset
    !! Line at which to start reading ``filename``.
    integer, intent(in) :: nlines
    !! Number of lines to read from the file.
    integer, intent(in) :: nstatdyn_init
    !! Initial energy surface
    real(dp), dimension(:), intent(inout) :: time
    !! Receives the first field of each line read.

    integer :: unit_in, iskip, irow
    integer :: state_now, state_before
    integer :: ncols
    real(dp), dimension(:), allocatable :: fields

    ! One record holds as many values as the line at ``offset``.
    ncols = txt_get_nr_of_fields(filename, offset)
    allocate(fields(ncols))

    open(newunit=unit_in, file=filename, action='read')

    ! Skip everything located before the first line of interest.
    do iskip = 1, offset - 1
       read(unit_in, *)
    end do

    state_before = nstatdyn_init
    do irow = 1, nlines
       read(unit_in, *) fields
       time(irow) = fields(1)
       state_now = int(fields(2))

       ! Nothing to update as long as the trajectory stays put.
       if (state_now == state_before) cycle

       ! A hop occurred: from this line on, the trajectory counts for
       ! the new state and no longer for the previous one.
       numtraj(irow:nlines, state_before) = numtraj(irow:nlines, state_before) - 1
       numtraj(irow:nlines, state_now) = numtraj(irow:nlines, state_now) + 1
       state_before = state_now
    end do

    close(unit_in)
  end subroutine txt_compare_nstatdyn


  function txt_get_nr_of_fields(filename, offset) result(nfields)
    !! Determine the total number of fields in file ``filename``.
    !!
    !! The ``offset`` parameter is used to specify which line to
    !! read to determine the number of fields (i.e. if the first few
    !! lines are comments / titles, it may be necessary to go further
    !! in the file).
    !!
    !! For now this function only works with space-separated files,
    !! and only the first 1024 characters of the line are inspected.
    !! A line containing only blanks yields 0 fields.
    character(len=*), intent(in) :: filename
    !! Name of the file to read.
    integer, intent(in) :: offset
    !! Line to read from.
    integer :: nfields
    !! Number of fields determined.

    integer :: i, u
    integer :: strlen
    character(len=1024) :: buf

    ! Determination of the number of fields:
    ! One line is read and stored in a CHAR buffer. The useful length
    ! of the buffer is its length without trailing whitespaces.
    ! Then, a new field is found if a whitespace, NOT FOLLOWED by
    ! another whitespace is found.
    open(newunit=u, file=filename, action='read')

    do i=1, offset-1
       read(u, *)
    end do

    read(u, '(a)') buf
    close(u)

    ! len_trim returns 0 for a blank line, where a manual backward
    ! scan would run past the beginning of the buffer.
    strlen = len_trim(buf)

    nfields = 0
    if (strlen == 0) return

    ! If the line starts with a whitespace, the next loop will catch
    ! the first field. If not, we already have a field.
    if (buf(1:1) /= ' ') nfields = 1

    ! buf(strlen:strlen) is never blank, so stopping at strlen-1
    ! counts the same transitions while keeping buf(i+1:i+1) inside
    ! the buffer even when the line fills it completely.
    do i=1, strlen-1
       if ((buf(i:i) == ' ') .and. &
            & (buf(i+1:i+1) /= ' ')) then
          nfields = nfields + 1
       end if
    end do
  end function txt_get_nr_of_fields


  subroutine txt_update_avg_var(filename, avg, var, offset, nlines,&
       & istep, usecols)
    !! Update the average and variance from ``filename``.
    !!
    !! This will read ``nlines`` lines from ``filename`` starting at
    !! line ``offset`` (corresponding to trajectory nr. ``istep``) and
    !! directly update the running average and sample variance stored
    !! in ``avg`` and ``var`` respectively.
    !!
    !! Row 1 of the arrays is treated as the time axis: it is only
    !! set when ``istep == 1`` (``avg(1, :)`` from the first selected
    !! column, ``var(1, :)`` from column 1 of the file) and is left
    !! untouched for the following trajectories.
    !!
    !! The subroutine implements the Welford's method, as presented by
    !! D. Knuth in Art of Computer Programming, Vol 2, page 232, 3rd
    !! edition. ``var`` holds the variance itself (not the sum of
    !! squared deviations), so the stored value is scaled back by
    !! ``istep - 2`` before the new contribution is added.
    !!
    !! Lines that cannot be read are reported on the error unit, and
    !! the corresponding entries of ``avg`` and ``var`` are left
    !! unchanged.
    use, intrinsic :: iso_fortran_env, only: error_unit
    character(len=MAX_STR_SIZE), intent(in) :: filename
    !! Text file to read.
    real(dp), dimension(:, :), allocatable, intent(inout) :: avg
    !! Running average (first index: property, second index: line).
    real(dp), dimension(:, :), allocatable, intent(inout) :: var
    !! Running sample variance, same layout as ``avg``.
    integer, intent(in) :: offset
    !! Line at which to start reading ``filename``.
    integer, intent(in) :: nlines
    !! Number of lines to read from the file.
    integer, intent(in) :: istep
    !! Number of trajectories included once this one is added
    !! (1 for the first trajectory).
    integer, dimension(:), intent(in), optional :: usecols
    !! Columns of the file to use; all columns when absent.

    integer :: u, ll, k, io, i, col
    integer :: nFields
    real(dp), dimension(:, :), allocatable :: rbuf
    real(dp), dimension(:), allocatable :: delta
    logical, dimension(:), allocatable :: ok

    nFields = txt_get_nr_of_fields(filename, offset)

    allocate(rbuf(nlines, nFields))
    allocate(delta(nlines))
    allocate(ok(nlines))
    write(*, *) 'Size of rbuf: ', nlines, nFields

    open(newunit=u, file=filename, action='read')
    do i=1, offset-1
       read(u, *)
    end do
    do ll=1, nlines
       read(u, *, iostat=io) rbuf(ll, :)
       ok(ll) = (io == 0)
       ! Keep the buffer defined even for rows that are masked out
       ! below.
       if (.not. ok(ll)) rbuf(ll, :) = 0.0_dp
    end do
    close(u)

    if (.not. all(ok)) then
       write(error_unit, '(a,i0,a,i0,2a)') &
            & 'txt_update_avg_var: ', count(.not. ok), ' of ', nlines, &
            & ' lines could not be read from ', trim(filename)
    end if

    if (present(usecols)) then
       k = size(usecols)
    else
       k = nFields
    end if

    if (istep == 1) then
       ! First trajectory: the average is the data itself and the
       ! variance is zero.
       do i=1, k
          if (present(usecols)) then
             col = usecols(i)
          else
             col = i
          end if
          where (ok)
             avg(i, :) = rbuf(:, col)
             var(i, :) = 0.0_dp
          end where
       end do
       ! Row 1 carries the time axis in both arrays.
       where (ok) var(1, :) = rbuf(:, 1)
    else
       do i=2, k
          if (present(usecols)) then
             col = usecols(i)
          else
             col = i
          end if
          where (ok)
             ! Deviation from the average of the previous trajectories.
             delta = rbuf(:, col) - avg(i, :)
             avg(i, :) = avg(i, :) + delta / istep
             ! (istep - 2) * var is the sum of squared deviations of
             ! the previous trajectories.
             var(i, :) = (var(i, :) * (istep - 2) &
                  & + delta * (rbuf(:, col) - avg(i, :))) / (istep - 1)
          end where
       end do
    end if
  end subroutine txt_update_avg_var


#ifdef USE_HDF5
  subroutine h5_update_avg_var_d1(loc_id, gname, avg, var, itraj)
    !! Update the average and variance from ``gname`` (rank-1 data).
    !!
    !! The routine reads the ``nx_h5md_ele_t`` element corresponding
    !! to ``gname``, situated at ``loc_id``, and updates the content
    !! of ``avg`` and ``var`` arrays with the new average and
    !! sample variance.
    !!
    !! The subroutine implements the Welford's method, as presented by
    !! D. Knuth in Art of Computer Programming, Vol 2, page 232, 3rd
    !! edition. ``var`` holds the variance itself (not the sum of
    !! squared deviations), so the stored value is scaled back by
    !! ``itraj - 2`` before the new contribution is added.
    integer(hid_t), intent(in) :: loc_id
    !! Location inside the H5MD file to look for ``gname``.
    character(len=*), intent(in) :: gname
    !! Group for which the average and variance are wanted.
    real(dp), dimension(:), intent(inout) :: avg
    !! Average.
    real(dp), dimension(:), intent(inout) :: var
    !! Variance.
    integer, intent(in) :: itraj
    !! Trajectory index (1 for the first trajectory).

    type(nx_h5md_ele_t) :: h5md
    real(dp), dimension(:), allocatable :: buf
    real(dp), dimension(:), allocatable :: delta

    call h5md%open(gname, loc_id)
    ! Read the data in ``buf`` (``buf`` is allocated by this routine)
    call h5md%read(buf)

    if (itraj == 1) then
       avg(:) = buf(:)
       var(:) = 0.0_dp
    else
       allocate(delta(size(avg)))
       ! Deviation from the average of the previous trajectories.
       delta(:) = buf - avg
       avg(:) = avg + delta / itraj
       ! (itraj - 2) * var is the sum of squared deviations of the
       ! previous trajectories.
       var(:) = (var * (itraj - 2) + delta * (buf - avg)) / (itraj - 1)
    end if

    call h5md%close()

  end subroutine h5_update_avg_var_d1


  subroutine h5_update_avg_var_d2(loc_id, gname, avg, var, itraj)
    !! Update the average and variance from ``gname`` (rank-2 data).
    !!
    !! The routine reads the ``nx_h5md_ele_t`` element corresponding
    !! to ``gname``, situated at ``loc_id``, and updates the content
    !! of ``avg`` and ``var`` arrays with the new average and
    !! sample variance.
    !!
    !! The subroutine implements the Welford's method, as presented by
    !! D. Knuth in Art of Computer Programming, Vol 2, page 232, 3rd
    !! edition. ``var`` holds the variance itself (not the sum of
    !! squared deviations), so the stored value is scaled back by
    !! ``itraj - 2`` before the new contribution is added.
    integer(hid_t), intent(in) :: loc_id
    !! Location inside the H5MD file to look for ``gname``.
    character(len=*), intent(in) :: gname
    !! Group for which the average and variance are wanted.
    real(dp), dimension(:, :), intent(inout) :: avg
    !! Average.
    real(dp), dimension(:, :), intent(inout) :: var
    !! Variance.
    integer, intent(in) :: itraj
    !! Trajectory index (1 for the first trajectory).

    type(nx_h5md_ele_t) :: h5md
    real(dp), dimension(:, :), allocatable :: buf
    real(dp), dimension(:, :), allocatable :: delta

    call h5md%open(gname, loc_id)
    ! Read the data in ``buf`` (``buf`` is allocated by this routine)
    call h5md%read(buf)

    if (itraj == 1) then
       avg = buf
       var = 0.0_dp
    else
       allocate(delta(size(avg, 1), size(avg, 2)))
       ! Deviation from the average of the previous trajectories.
       delta(:, :) = buf - avg
       avg = avg + delta / itraj
       ! (itraj - 2) * var is the sum of squared deviations of the
       ! previous trajectories.
       var = (var * (itraj - 2) + delta * (buf - avg)) / (itraj - 1)
    end if

    call h5md%close()

  end subroutine h5_update_avg_var_d2



  subroutine h5_get_time_step(loc_id, gname, time, step)
    !! Fetch the ``step`` and ``time`` arrays attached to ``gname``.
    !!
    !! The ``nx_h5md_ele_t`` element ``gname`` is opened at
    !! ``loc_id`` and read together with its ``step`` and ``time``
    !! information. The values of the element itself are read as
    !! well, but they are not returned to the caller.
    !!
    !! NOTE(review): the exact content of ``step`` and ``time`` is
    !! defined by the ``read`` procedure of ``nx_h5md_ele_t``, which
    !! is not visible here -- confirm against ``mod_h5md``.
    integer(hid_t), intent(in) :: loc_id
    !! Location inside the H5MD file to look for ``gname``.
    character(len=*), intent(in) :: gname
    !! Element from which the step and time are read.
    real(dp), dimension(:), allocatable, intent(out) :: time
    !! Third argument filled by the element ``read`` call.
    integer, dimension(:), allocatable, intent(out) :: step
    !! Second argument filled by the element ``read`` call.

    real(dp), dimension(:), allocatable :: discarded_values
    type(nx_h5md_ele_t) :: element

    call element%open(gname, loc_id)
    ! Only ``step`` and ``time`` are of interest; the values land in
    ! a local buffer that is dropped on return.
    call element%read(discarded_values, step, time)
    call element%close()

  end subroutine h5_get_time_step


  subroutine h5_stats_energies(loc_id, avg, var, itraj)
    !! Fold the energies of trajectory ``itraj`` into the statistics.
    !!
    !! Three elements are read from ``loc_id``: ``total_energy``,
    !! ``kinetic_energy`` and ``potential_energy``. Each of them is
    !! passed to ``h5_update_avg_var`` to update the matching rows of
    !! ``avg`` and ``var``:
    !!
    !! - row 2: total energy;
    !! - row 3: kinetic energy;
    !! - rows 4 to the last one: potential energies.
    !!
    !! Row 1 is not modified by this routine. According to the
    !! intended layout used for plotting,
    !!
    !! ``Time    Etot     Ekin     Epot(1)     Epot(2) ... ``
    !!
    !! it is reserved for the time.
    integer(hid_t), intent(in) :: loc_id
    !! Location from where to read the different energies.
    real(dp), dimension(:, :),  intent(inout) :: avg
    !! Array for holding the average of energies.
    real(dp), dimension(:, :),  intent(inout) :: var
    !! Array for holding the variance of energies.
    integer, intent(in) :: itraj
    !! Trajectory identifier.

    integer :: last_row

    ! The potential energies fill everything below the kinetic energy
    ! row; the extent of ``avg`` is used for both arrays.
    last_row = size(avg, 1)

    call h5_update_avg_var( &
         & loc_id, 'total_energy', avg(2, :), var(2, :), itraj)

    call h5_update_avg_var( &
         & loc_id, 'kinetic_energy', avg(3, :), var(3, :), itraj)

    call h5_update_avg_var( &
         & loc_id, 'potential_energy', &
         & avg(4:last_row, :), var(4:last_row, :), itraj)

  end subroutine h5_stats_energies


  subroutine h5_compare_nstatdyn(loc_id, fractraj, nstatdyn_init)
    !! Update the per-state trajectory counts from the ``nstatdyn``
    !! element found at ``loc_id``.
    !!
    !! The element is read as a list of states, one per step.
    !! Whenever the state differs from the previous one (initially
    !! ``nstatdyn_init``), one count is moved from the previous state
    !! to the new state in ``fractraj``, for the current step and all
    !! the following ones.
    !!
    !! NOTE(review): ``fractraj`` is only incremented / decremented
    !! here, so it presumably has to be initialised by the caller
    !! with every trajectory in its initial state -- confirm.
    integer(hid_t), intent(in) :: loc_id
    !! Location from where to read ``nstatdyn``.
    integer, dimension(:, :), intent(inout) :: fractraj
    !! Number of trajectories in each state (first index: step,
    !! second index: state).
    integer, intent(in) :: nstatdyn_init
    !! Initial energy surface.

    integer, dimension(:), allocatable :: states
    type(nx_h5md_ele_t) :: element
    integer :: istep, last_step
    integer :: state_before

    call element%open('nstatdyn', loc_id)
    call element%read(states)
    call element%close()

    last_step = size(states)
    state_before = nstatdyn_init

    do istep = 1, last_step
       ! Nothing to update as long as the trajectory stays put.
       if (states(istep) == state_before) cycle

       ! A hop occurred: from this step on, the trajectory counts for
       ! the new state and no longer for the previous one.
       fractraj(istep:last_step, state_before) = &
            & fractraj(istep:last_step, state_before) - 1
       fractraj(istep:last_step, states(istep)) = &
            & fractraj(istep:last_step, states(istep)) + 1
       state_before = states(istep)
    end do
  end subroutine h5_compare_nstatdyn



#endif


  subroutine pretty_print_stats(stat, filename, type)
    !! Pretty print the ``stat`` array into ``filename``.
    !!
    !! The stat array is expected to be 2-D, indexed as:
    !!
    !! ``time(1)     time(2)      time(3)   ... ``
    !! ``stat(1,1)   stat(1,2)    stat(1, 3) ... ``
    !! ``stat(2,1)   stat(2,2)    stat(2, 3) ... ``
    !! ``...``
    !!
    !! This corresponds to the Fortran ordering that we use in the
    !! H5MD format. However, it is more convenient for plotting to
    !! have the time in the first column, and the properties as the
    !! following columns. Thus, we will print it as:
    !!
    !! ``time(1)     stat(1,1)      stat(2,1)   ... ``
    !! ``time(2)     stat(1,2)      stat(2,2)   ... ``
    !! ``time(3)     stat(1,3)      stat(2,3)   ... ``
    !!
    !! A header line is written first when ``type`` is ``energies``
    !! or ``populations``; any other value produces no header.
    real(dp), dimension(:, :), intent(in) :: stat
    !! Array to print (first row: time, other rows: properties).
    character(len=*), intent(in) :: filename
    !! File to write.
    character(len=*), intent(in) :: type
    !! Kind of data, used to select the header line.

    integer :: u, i

    open(newunit=u, file=filename, action='write')

    select case (type)
    case ('energies')
       write(u, '(A10,5A20)') &
            & 'Time', 'Etot', 'Ekin', 'Epot(1)', 'Epot(2)', '...'
    case ('populations')
       write(u, '(A10,3A20)') &
            & 'Time', 'State(1)', 'State(2)', '...'
    end select

    ! The unlimited repeat count adapts to the number of rows of
    ! ``stat``; an explicit count of ``size(stat, 1) - 1`` would be
    ! an invalid zero when only the time row is present.
    do i=1, size(stat, 2)
       write(u, '(F10.3,*(F20.12))') stat(:, i)
    end do
    close(u)
  end subroutine pretty_print_stats


  subroutine pretty_print_fractraj(fractraj, time, filename)
    !! Pretty print the per-state trajectory counts into ``filename``.
    !!
    !! A header line is written first, followed by one line per row
    !! of ``fractraj``: the matching entry of ``time``, then the
    !! counts of every state:
    !!
    !! ``time(1)     fractraj(1,1)      fractraj(1,2)   ... ``
    !! ``time(2)     fractraj(2,1)      fractraj(2,2)   ... ``
    integer, dimension(:, :), intent(in) :: fractraj
    !! Number of trajectories in each state (first index: line,
    !! second index: state).
    real(dp), dimension(:), intent(in) :: time
    !! Time printed in the first column; one entry is needed for
    !! every row of ``fractraj``.
    character(len=*), intent(in) :: filename
    !! File to write.

    integer :: u, i

    open(newunit=u, file=filename, action='write')

    ! The header uses the same field widths as the data below
    ! (10 characters for the time, 20 for each state).
    write(u, '(A10,3A20)') 'Time', 'State(1)', 'State(2)', '...'

    ! The unlimited repeat count adapts to the number of states; an
    ! explicit count would be an invalid zero for an array without
    ! any state column.
    do i=1, size(fractraj, 1)
       write(u, '(F10.3,*(I20))') time(i), fractraj(i, :)
    end do
    close(u)

  end subroutine pretty_print_fractraj





end module mod_statistics
