#!/usr/bin/env python

import sys

if len(sys.argv) != 2 and len(sys.argv) != 3:
  print "Computes velocity power spectrum (autocorrelation Fourier-transform"
  print "  magnitude) as a function of frequency"
  print "Usage: %s infile -raw" % sys.argv[0]
  sys.exit(1)

import numpy as np
from quippy import *

# Optional second argument: only "-raw" is recognised.  It selects the plain
# text velocity reader; any other value aborts with exit status 2.
raw = False
if len(sys.argv) == 3:
    if sys.argv[2] != "-raw":
        sys.stderr.write("Got unknown last argument %s\n" % sys.argv[2])
        sys.exit(2)
    raw = True

# Load the velocity time series into `vels`, a (3*n_frames, 3*n_atoms) array.
# Frame number k (counted from 1) is written to row n_frames+k, so the rows
# before and after the data stay zero.  Columns hold the three velocity
# components of each atom, interleaved.  Also sets n_frames, n_atoms and the
# time step dt.
if raw:
    # Raw text input: the first line holds the atom count; every following
    # non-blank line holds "vx vy vz" for one atom, frames back to back.
    with open(sys.argv[1],"r") as f:
        n_atoms = int(next(f).rstrip())
        n_frames = 0
        vels_cur = fzeros((3,n_atoms))
        i_line = 0
        vels_raw=[]
        for l in f:
            fields = l.split()
            if not fields:
                # blank lines (e.g. a trailing newline) carry no data
                continue
            (vx,vy,vz) = fields
            vels_cur[:,i_line+1] = [float(vx), float(vy), float(vz)]
            i_line += 1
            if i_line == n_atoms:
                # a complete frame has been read; start a fresh buffer
                vels_raw.append(vels_cur)
                vels_cur = fzeros((3,n_atoms))
                n_frames += 1
                i_line = 0
                if n_frames%10 == 0:
                    print >> sys.stderr, "read config %d" % n_frames
    vels = fzeros((3*n_frames, 3*n_atoms))
    i_frame = 1
    for at_velo in vels_raw:
        # One strided assignment per Cartesian component fills all atoms at
        # once, so no per-atom loop is needed.
        vels[i_frame+n_frames, 1:3*n_atoms:3]   = at_velo[1,1:n_atoms]
        vels[i_frame+n_frames, 2:1+3*n_atoms:3] = at_velo[2,1:n_atoms]
        vels[i_frame+n_frames, 3:2+3*n_atoms:3] = at_velo[3,1:n_atoms]
        i_frame += 1
        if i_frame%10 == 0:
            print >> sys.stderr, "proc config %d" % i_frame
    # the raw format carries no time information; use a unit time step
    dt=1.0
else:
    ar = AtomsReader(sys.argv[1])
    n_frames = len(ar)
    # NOTE(review): confirm whether AtomsReader indexing is 0- or 1-based;
    # if it is 0-based this reads the atom count from the second frame.
    n_atoms = ar[1].n
    vels = fzeros((3*n_frames, 3*n_atoms))
    i_frame = 1
    prev_at = None
    dt = None
    for at in ar:
        if i_frame%10 == 0:
            print >> sys.stderr, "config %d" % i_frame
        if prev_at is not None:
            # time step taken from the 'time' params of consecutive frames
            dt = at.params['time'] - prev_at.params['time']
        if 'velo' in at.properties.keys():
            vels[i_frame+n_frames, 1:3*at.n:3]   = at.velo[1,1:at.n]
            vels[i_frame+n_frames, 2:1+3*at.n:3] = at.velo[2,1:at.n]
            vels[i_frame+n_frames, 3:2+3*at.n:3] = at.velo[3,1:at.n]
        elif prev_at is not None:
            # No stored velocities: estimate them by finite differences,
            # i.e. displacement since the previous frame divided by dt.
            # The first frame has no predecessor and keeps zero velocity.
            vels[i_frame+n_frames, 1:3*at.n:3]   = (at.pos[1,1:at.n] - prev_at.pos[1,1:at.n])/dt
            vels[i_frame+n_frames, 2:1+3*at.n:3] = (at.pos[2,1:at.n] - prev_at.pos[2,1:at.n])/dt
            vels[i_frame+n_frames, 3:2+3*at.n:3] = (at.pos[3,1:at.n] - prev_at.pos[3,1:at.n])/dt
        prev_at = at.copy()
        i_frame += 1
    if dt is None:
        # fewer than two frames: the time step cannot be determined
        sys.stderr.write("Need at least 2 frames to determine the time step\n")
        sys.exit(3)
    print >> sys.stderr, "dt %f n_frames %d" % (dt, n_frames)

print "first frame ", vels[n_frames+1,1:3]
print "second frame ", vels[n_frames+2,1:3]

# Accumulate the squared magnitude of the real FFT of every velocity column
# (each Cartesian component of each atom) into mean_vel_fft.
# For n input samples rfft returns n//2 + 1 frequency bins, whether n is even
# or odd, so a single expression sizes the accumulator.
n_samples = len(vels[:,1])
mean_vel_fft = np.zeros(n_samples//2 + 1)

for i in frange(n_atoms):
  print >> sys.stderr, "fft atom %d" % i
  for j in frange(3):
    # column 3*(i-1)+j holds component j of atom i (both counted from 1)
    vel_fft = np.fft.rfft(vels[:,3*(i-1)+j])
    # abs(z*z) equals |z|^2
    mean_vel_fft[:] += abs(vel_fft[:]*vel_fft[:])

print "# nu(1/fs)  |v(nu)^2|"
mean_vel_fft[:] /= 3*n_atoms
ii = 0
for vel_nu in mean_vel_fft:
  print "%f %.10f" % (ii/(3*n_frames*float(dt)), vel_nu)
  ii += 1
