#!/usr/bin/env python3

import sys
import numpy as np
import pyterpol3





def load_observations(f):
    """
    Load the list of observed spectra and wrap it in a pyterpol3 ObservedList.

    The list is a whitespace-separated table of which the first three
    columns are used: column 0 -> ``hjd`` (presumably the observation
    epochs), column 1 -> ``rms`` (passed to pyterpol3 as the spectrum
    ``error``), column 2 -> spectrum file name. Every spectrum gets its own
    RV group whose index equals its row number, so RVs can later be fitted
    one spectrum at a time.

    :param f: file name or open file object of the observation list
    :return: tuple ``(ol, flist, hjd, rms)`` -- the ``pyterpol3.ObservedList``
        followed by the file-name, hjd and rms columns as numpy arrays
    """
    # Parse the table in a single pass. Three separate np.loadtxt calls
    # parse the file three times and, when ``f`` is an open file object,
    # the 2nd and 3rd calls start at EOF and see no data at all.
    table = np.loadtxt(f, usecols=(0, 1, 2), dtype=str, ndmin=2)
    # guarantee an (nrows, 3) shape, also for an empty observation list
    table = table.reshape(-1, 3)

    flist = table[:, 2]
    hjd = table[:, 0].astype(float)
    rms = table[:, 1].astype(float)

    # one observation per spectrum; RV group index == row index
    obs = [
        dict(filename=sf, error=err, group=dict(rv=i))
        for i, (sf, err) in enumerate(zip(flist, rms))
    ]

    # create ObservedList
    ol = pyterpol3.ObservedList()
    ol.add_observations(obs)

    return ol, flist, hjd, rms

def optimize_rv(itf, rv_bounds=None):
    """
    Fit the radial velocities spectrum by spectrum.

    First every parameter type known to ``itf.sl`` is marked as not fitted.
    Then, for each observed spectrum index ``i``, the RVs of RV group ``i``
    are switched on for the listed components, a fit is run on only the
    comparisons belonging to that group (keeps each fit small and fast),
    and the RVs of that group are switched off again.

    :param itf: pyterpol3.Interface that has already been set up
    :param rv_bounds: optional mapping ``component -> (vmin, vmax)`` giving
        the RV search interval of every component whose RV is fitted.
        Defaults to pri/sec: (-400, 300), tri: (-300, 120).
    :return: the same ``itf`` instance
    """
    if rv_bounds is None:
        rv_bounds = {
            'pri': (-400., 300.),
            'sec': (-400., 300.),
            'tri': (-300., 120.),
        }

    # turn off all parameters
    for p in itf.sl.get_parameter_types():
        itf.set_parameter(parname=p, fitted=False)

    for group in range(itf.get_observed_spectra_number()):

        # fit one set of RVs
        for component, (vmin, vmax) in rv_bounds.items():
            itf.set_parameter(parname='rv', component=component, fitted=True,
                              group=group, vmin=vmin, vmax=vmax)

        # select only the part of the comparison list for this group
        comparisons = itf.get_comparisons(rv=group)

        # run the fit
        itf.run_fit(l=comparisons)

        # turn off fitting of this set (no component given -> presumably
        # applies to all components; confirm against pyterpol3)
        itf.set_parameter(parname='rv', fitted=False, group=group)

    return itf

def optimize_non_rv(itf):
    """
    Fit the non-RV parameters of all three components in one run.

    For each of 'pri', 'sec' and 'tri', the parameters teff, logg, vrot and
    lr are marked as fitted within the bounds listed below. For z only the
    bounds are set, and it stays fixed. The RV parameters are not touched
    here, so their fitted flags are left as they were. Then
    ``itf.run_fit()`` is called without an explicit comparison list.

    :param itf: pyterpol3.Interface that has already been set up
    :return: the same ``itf`` instance after the fit
    """
    # (component, parameter, fitted, vmin, vmax) -- applied in this order
    settings = (
        ('pri', 'teff', True, 29000., 33000.),
        ('pri', 'logg', True, 3.00, 4.40),
        ('pri', 'vrot', True, 80., 250.),
        ('pri', 'lr', True, 0.40, 0.85),
        ('pri', 'z', False, 0.5, 2.0),
        ('sec', 'teff', True, 18000., 25000.),
        ('sec', 'logg', True, 3.00, 4.75),
        ('sec', 'vrot', True, 60., 200.),
        ('sec', 'lr', True, 0.20, 0.45),
        ('sec', 'z', False, 0.5, 2.0),
        ('tri', 'teff', True, 15000., 17000.),
        ('tri', 'logg', True, 3.00, 4.20),
        ('tri', 'vrot', True, 60., 250.),
        ('tri', 'lr', True, 0.03, 0.20),
        ('tri', 'z', False, 0.5, 2.0),
    )
    for component, parname, fitted, vmin, vmax in settings:
        itf.set_parameter(component=component, parname=parname,
                          fitted=fitted, vmin=vmin, vmax=vmax)

    itf.run_fit()
    return itf


def main():
    """
    Fit a three-component (pri/sec/tri) model to the spectra in 'seznam'.

    Steps:
    1. build the region list, the observed list and the star list;
    2. set up the interface and the Nelder-Mead fitter;
    3. alternate RV and non-RV optimisation a fixed number of times;
    4. write the RVs to 'hd.rvs', save the interface to 'final.sav' and
       plot all comparisons under the figure name 'new'.
    """
    n_iterations = 5

    # 1a) spectral region -- one window, light-ratio group 0
    regions = pyterpol3.RegionList()
    regions.add_region(wmin=4630, wmax=5428, groups={'lr': 0})
    # optional (disabled): a second window with its own light-ratio group
    # regions.add_region(wmin=4990, wmax=5428, groups={'lr': 1})

    # 1b) observed spectra -- only the ObservedList is needed here
    observed, *_ = load_observations('seznam')

    # 1c) components: (name, teff, logg, lr, vrot), all with z=1.0
    stars = pyterpol3.StarList()
    initial_guesses = (
        ('pri', 31400., 3.87, 0.60, 115.),
        ('sec', 22000., 3.92, 0.32, 140.),
        ('tri', 16000., 3.27, 0.08, 235.),
    )
    for name, teff, logg, lr, vrot in initial_guesses:
        stars.add_component(name, teff=teff, logg=logg, lr=lr, vrot=vrot, z=1.0)

    # 2) interface and fitting environment
    itf = pyterpol3.Interface(sl=stars, rl=regions, ol=observed)
    itf.set_grid_properties(order=3, step=0.05)
    # optional (disabled): itf.set_one_for_all(True)
    itf.setup()
    itf.choose_fitter('nlopt_nelder_mead', ftol=1e-7)
    # optional (disabled): itf.plot_all_comparisons(figname='initial')
    print(itf)
    print("Degrees of freedom: ", itf.get_degrees_of_freedom())

    # 3) alternate between RVs and the remaining parameters
    for _ in range(n_iterations):
        itf = optimize_rv(itf)
        itf = optimize_non_rv(itf)

    # 4) outputs: RVs, saved interface, comparison plots
    itf.write_rvs('hd.rvs')
    itf.save('final.sav')
    itf.plot_all_comparisons(figname='new')

if __name__ == '__main__':
    main()
