import numpy as np
import pyterpol3
import sys

def load_observations(f):
    """
    Read a table of observed spectra and wrap it in a pyterpol3 ObservedList.

    The table is whitespace-separated text. Column 0 is the HJD, column 1 is
    the value passed as the observation ``error``, and column 2 is the
    spectrum file name. Every row gets its own ``rv`` group (``rv=row index``),
    so each spectrum receives an independent set of radial velocities.

    :param f: path of the observation table; the file is read twice, so it
        must be a path rather than an already-open stream
    :return: tuple ``(ol, flist, hjd, rms)`` -- the populated
        ``pyterpol3.ObservedList`` followed by 1-D arrays of file names,
        HJDs and errors, in file order
    """
    # ndmin keeps every column 1-D even for a one-row table; without it
    # loadtxt squeezes the result to 0-d scalars and the iteration/indexing
    # below fails.
    flist = np.loadtxt(f, usecols=[2], unpack=True, dtype=str, ndmin=1)
    hjd, rms = np.loadtxt(f, usecols=[0, 1], unpack=True, ndmin=2)

    # One observation record per row; the row index doubles as the rv group.
    obs = [
        dict(filename=sf, error=err, group=dict(rv=i), hjd=t)
        for i, (sf, err, t) in enumerate(zip(flist, rms, hjd))
    ]

    ol = pyterpol3.ObservedList()
    ol.add_observations(obs)

    return ol, flist, hjd, rms

def optimize_rv(itf, *, pri_bounds=(-100., 180.), sec_bounds=(-300., 200.)):
    """
    Fit the radial velocities of the observed spectra one group at a time.

    All parameter types are first switched off. Then, for each rv group, only
    that group's primary and secondary RVs are fitted. The fit uses just the
    comparisons belonging to that group. The RVs are switched off again
    before moving on, so every group is fitted on its own.

    :param itf: a set-up ``pyterpol3.Interface``
    :param pri_bounds: ``(vmin, vmax)`` limits for the primary's RVs
    :param sec_bounds: ``(vmin, vmax)`` limits for the secondary's RVs
    :return: the same interface object, now holding the fitted RVs
    """
    pri_min, pri_max = pri_bounds
    sec_min, sec_max = sec_bounds

    # Turn off every parameter type so only the RVs enabled below are fitted.
    for p in itf.sl.get_parameter_types():
        itf.set_parameter(parname=p, fitted=False)

    # Assumes rv groups are numbered 0..N-1, one per observed spectrum (as
    # load_observations sets them up) -- TODO confirm for other inputs.
    for i in range(itf.get_observed_spectra_number()):
        itf.set_parameter(parname='rv', component='pri', fitted=True, group=i, vmin=pri_min, vmax=pri_max)
        itf.set_parameter(parname='rv', component='sec', fitted=True, group=i, vmin=sec_min, vmax=sec_max)

        # Only this group's comparisons depend on its RVs -- keeps the fit fast.
        comparisons = itf.get_comparisons(rv=i)
        itf.run_fit(l=comparisons)

        # Switch this group's RVs off again; otherwise they remain free (yet
        # unconstrained by the selected comparisons) in every later fit.
        itf.set_parameter(parname='rv', fitted=False, group=i)

    return itf

def optimize_non_rv(itf, *, bounds=None):
    """
    Fit the stellar parameters of both components, keeping all RVs fixed.

    RV fitting is switched off explicitly, so the result does not depend on
    which parameters an earlier step left enabled. Each parameter in
    ``bounds`` is then enabled with its limits, and a single fit over all
    comparisons is run.

    :param itf: a set-up ``pyterpol3.Interface``
    :param bounds: optional iterable of ``(component, parname, vmin, vmax)``
        tuples; defaults to teff/logg/vrot/lr of 'pri' and 'sec' with the
        limits listed below
    :return: the same interface object, now holding the fitted parameters
    """
    if bounds is None:
        bounds = (
            ('pri', 'teff', 26000., 32000.),
            ('pri', 'logg', 3.2, 3.9),
            ('pri', 'vrot', 35., 90.),
            ('pri', 'lr', 0.6, 0.9),
            ('sec', 'vrot', 30., 80.),
            ('sec', 'lr', 0.05, 0.35),
            ('sec', 'logg', 3.8, 4.5),
            ('sec', 'teff', 24000., 29000.),
        )

    # Keep the RVs fixed, as the function's contract promises.
    itf.set_parameter(parname='rv', fitted=False)

    for component, parname, vmin, vmax in bounds:
        itf.set_parameter(component=component, parname=parname, fitted=True, vmin=vmin, vmax=vmax)

    # Fit over all comparisons at once.
    itf.run_fit()

    return itf


def main():
    """
    Run the full two-component fit and store its results.

    Builds the fitted region, loads the observations listed in 'seznam' and
    sets up the 'pri' and 'sec' components. It then alternates RV and non-RV
    optimisation for a fixed number of cycles. Outputs: RVs in 'kxvel.rv',
    the saved interface in 'fin.sav', and comparison plots named 'fin'.

    :return: None
    """
    n_cycles = 5  # number of RV / non-RV optimisation rounds

    # Fitted wavelength region(s), sharing light-ratio group 0.
    # (Another region, e.g. wmin=6340, wmax=6725 with groups=dict(lr=1),
    # can be added the same way.)
    regions = pyterpol3.RegionList()
    regions.add_region(wmin=4515, wmax=5600, groups=dict(lr=0))

    # Observed spectra -- only the ObservedList is needed here.
    observed, *_ = load_observations('seznam')

    # The two stellar components with their starting parameters.
    stars = pyterpol3.StarList()
    stars.add_component('pri', teff=29500., logg=3.40, lr=0.80, vrot=50., z=1.0)
    stars.add_component('sec', teff=27000., logg=4.00, lr=0.20, vrot=50., z=1.0)

    # Interface tying components, regions and observations together.
    itf = pyterpol3.Interface(sl=stars, rl=regions, ol=observed)
    itf.set_grid_properties(order=3, step=0.05)
    itf.setup()

    # Fitting backend, then a summary of the initial state.
    itf.choose_fitter('nlopt_nelder_mead', ftol=5e-7)
    print(itf)
    print("Degrees of freedom: ", itf.get_degrees_of_freedom())

    # Alternate RV and stellar-parameter optimisation.
    for _ in range(n_cycles):
        itf = optimize_rv(itf)
        itf = optimize_non_rv(itf)

    # Persist the results.
    itf.write_rvs('kxvel.rv')
    itf.save('fin.sav')
    itf.plot_all_comparisons(figname='fin')

if __name__ == '__main__':
    # Run the fitting pipeline only when executed as a script, not on import.
    main()
