from matplotlib import pyplot as plt
import numpy as np
from scipy import optimize, interpolate
from datetime import datetime, timedelta
import os, tqdm, calendar, itertools
from netCDF4 import Dataset
from dateutil.relativedelta import relativedelta

class MBL_GR_from_DEI(object):
    """Estimate annual global-MBL CO2 growth rates from DEI output files.

    A DEI result directory is located by formatting a result date with
    ``self.output_dir_pattern`` (``strftime`` syntax). Inside it,
    ``zone_gl.mbl.co2`` (and, for one plotting method, ``zone_gl.mbl.tr.co2``)
    is read as two columns: decimal year and CO2.
    """

    def __init__(self, *args, **kwargs):
        """Create the analyser.

        Positional arguments are accepted for backward compatibility and
        ignored. The optional keyword ``output_dir_pattern`` overrides the
        default ``strftime`` pattern used to locate a DEI result directory;
        other keywords are ignored.
        """
        super(MBL_GR_from_DEI, self).__init__()
        self.output_dir_pattern = kwargs.get('output_dir_pattern', 'output/co2/results.web.%Y-%m')

    def harmonic_function(self, t, coeffs):
        """Evaluate a sum of harmonics with unit fundamental period.

        With ``coeffs = [s1, c1, s2, c2, ...]`` this returns
        ``sum_k s_k*sin(2*pi*k*t) + c_k*cos(2*pi*k*t)`` for
        ``k = 1 .. len(coeffs)//2``.

        Parameters
        ----------
        t : array_like
            Evaluation points (decimal years in this module).
        coeffs : sequence of float
            Interleaved sine/cosine amplitudes; must have even length.

        Returns
        -------
        numpy.ndarray
            float64 values with the shape of ``t``.

        Raises
        ------
        RuntimeError
            If ``coeffs`` has an odd number of elements.
        """
        if len(coeffs) % 2 != 0:
            raise RuntimeError('coeffs must have an even number of elements')
        t = np.asarray(t, dtype=np.float64)
        ret_y = np.zeros_like(t)
        for i in range(len(coeffs) // 2):
            phase = 2*np.pi*t*(i+1)
            ret_y = ret_y + coeffs[2*i] * np.sin(phase) + coeffs[2*i+1] * np.cos(phase)
        return ret_y

    def polynomial_function(self, t, coeffs):
        """Evaluate a polynomial (highest power first, as in ``np.polyval``)."""
        return np.polyval(coeffs, t)

    def total_function(self, t, coeffs, num_poly):
        """Evaluate polynomial + harmonics.

        The first ``num_poly`` entries of ``coeffs`` are polynomial
        coefficients; the remainder are harmonic amplitudes.
        """
        poly_coeffs = coeffs[:num_poly]
        harm_coeffs = coeffs[num_poly:]
        return self.polynomial_function(t, poly_coeffs) + self.harmonic_function(t, harm_coeffs)

    def fit_poly_plus_harmonic(self, t, y, num_harmonics, num_poly):
        """Least-squares fit of a degree-``num_poly`` polynomial plus ``num_harmonics`` harmonics.

        The fit is done in two stages so ``curve_fit`` starts from a good guess:
        first a polynomial, then harmonics on the polynomial residual, then
        both together.

        Returns
        -------
        dict
            ``'poly_fit'``, ``'harmonic_fit'`` and ``'total_fit'`` evaluated at ``t``.
        """
        # first, fit a polynomial to get a rough idea of the polynomial coefficients
        poly_prelim = np.polyfit(t, y, deg=num_poly)
        # now fit harmonics to y - polynomial to get a rough idea of the harmonic coefficients
        y_detrended = y - np.polyval(poly_prelim, t)
        # robust peak-to-peak amplitude (1st-99th percentile) as the initial amplitude scale
        y_amp = np.percentile(y_detrended, 99.0) - np.percentile(y_detrended, 1.0)
        harm_prelim_prior = 0.5 * y_amp * np.ones(2*num_harmonics, np.float64)
        harm_prelim, _ = optimize.curve_fit(lambda x, *p: self.harmonic_function(x, p), t, y_detrended, p0=harm_prelim_prior)
        # now fit the original data starting from poly_prelim and harm_prelim
        coeff_prelim = np.concatenate((poly_prelim, harm_prelim))
        num_poly_coeffs = len(poly_prelim)  # degree + 1
        coeffs, _ = optimize.curve_fit(
            lambda x, *p: self.total_function(x, p, num_poly_coeffs),
            t, y, p0=coeff_prelim,
            )
        return {
            'poly_fit'      : self.polynomial_function(t, coeffs[:num_poly_coeffs]),
            'harmonic_fit'  : self.harmonic_function(t, coeffs[num_poly_coeffs:]),
            'total_fit'     : self.total_function(t, coeffs, num_poly_coeffs),
            }

    def _load_dei_series(self, result_date, file_name):
        """Load ``file_name`` from the DEI result directory for ``result_date``.

        Returns ``(year, co2)`` arrays with non-positive CO2 values (treated as
        missing) removed.
        """
        co2_file = os.path.join(result_date.strftime(self.output_dir_pattern), file_name)
        year, co2 = np.loadtxt(co2_file, unpack=True)
        valid = co2 > 0.0
        return year[valid], co2[valid]

    def _annual_differences(self, t, values):
        """Difference a series across consecutive whole years.

        ``values`` is linearly interpolated in ``t``. ``t`` must be strictly
        increasing, as ``InterpolatedUnivariateSpline`` requires. The
        interpolant ``f`` is then evaluated at every integer year ``Y`` with
        ``min(t) <= Y`` and ``Y + 1 <= max(t)``.

        Returns
        -------
        (years, differences) : tuple of numpy.ndarray
            Integer years and ``f(Y + 1) - f(Y)`` for each; both empty if no
            whole year fits inside ``t``.
        """
        t = np.asarray(t, dtype=np.float64)
        spl = interpolate.InterpolatedUnivariateSpline(t, values, k=1)
        first = int(np.ceil(t.min()))
        last = int(np.floor(t.max())) - 1  # last Y with Y + 1 <= max(t)
        years = np.arange(first, last + 1)
        if years.size == 0:
            return years, np.zeros(0, np.float64)
        return years, spl(years + 1) - spl(years)

    def _relative_to_asymptote(self, growth_rates):
        """Subtract the mean of the last 20% of ``growth_rates`` from every value.

        At least the final value is used, so short series are measured
        against their latest estimate rather than their overall mean.
        An empty input is returned unchanged.
        """
        growth_rates = np.asarray(growth_rates, dtype=np.float64)
        if growth_rates.size == 0:
            return growth_rates
        num_tail = max(1, len(growth_rates) // 5)
        return growth_rates - np.average(growth_rates[-num_tail:])

    def subtract_avg_seasonal_cycle(self, result_date, **kwargs):
        """Deseasonalize the global MBL CO2 record with moving-window harmonic fits.

        Windows of ``num_years`` years start at the first sample and step by
        one year. A window is used while it still spans at least
        ``min_fit_years`` of data; the last one is truncated at the end of the
        record. In each window a polynomial + harmonic model is fitted. Each
        sample's seasonal cycle is the average harmonic fit over all windows
        containing it.

        Keyword arguments: ``num_harmonics`` (3), ``poly_deg`` (2),
        ``num_years`` (7), ``min_fit_years`` (``num_years - 1``); others are
        ignored.

        Returns
        -------
        dict
            ``'year'``, ``'co2'`` (valid samples) and ``'trend'`` (co2 minus
            the averaged seasonal cycle).

        Raises
        ------
        RuntimeError
            If the record is shorter than ``min_fit_years`` so no window fits.
        """
        num_harmonics = kwargs.get('num_harmonics', 3)
        poly_deg = kwargs.get('poly_deg', 2)
        num_years = kwargs.get('num_years', 7)
        min_fit_years = kwargs.get('min_fit_years', num_years-1)

        year, co2 = self._load_dei_series(result_date, 'zone_gl.mbl.co2')

        # NOTE: every sample is covered by some window when num_years >= min_fit_years + 1
        # (true for the defaults); otherwise trailing samples can get a NaN trend.
        year_boundaries = []
        window_start = year.min()
        while window_start + min_fit_years <= year.max():
            year_boundaries.append((window_start, min(window_start + num_years, year.max())))
            window_start += 1
        if not year_boundaries:
            raise RuntimeError(
                'zone_gl.mbl.co2 for %s spans fewer than min_fit_years=%s years; no fitting window'
                % (result_date.strftime('%Y-%m'), min_fit_years))

        nsamples = np.zeros_like(year)
        co2_harm = np.zeros_like(co2)
        for min_year, max_year in tqdm.tqdm(year_boundaries, desc='Calculating moving average'):
            in_window = np.logical_and(year >= min_year, year <= max_year)
            fit_dict = self.fit_poly_plus_harmonic(year[in_window], co2[in_window], num_harmonics, poly_deg)
            nsamples[in_window] += 1.0
            co2_harm[in_window] += fit_dict['harmonic_fit']

        co2_harm = co2_harm/nsamples
        return {
            'year': year,
            'co2': co2,
            'trend': co2 - co2_harm,
            }

    def calculate_growth_rate(self, result_date, **kwargs):
        """Annual growth rate of the deseasonalized global MBL CO2 record.

        For each complete year Y, the Dec/Jan average around Y is subtracted
        from the Dec/Jan average around Y+1. A "month" here is 4 DEI weeks
        (48 weeks per year), averaged with trapezoidal weights. Keyword
        arguments are forwarded to ``subtract_avg_seasonal_cycle``.

        Returns
        -------
        dict
            ``'year'`` (int array) and ``'growth_rate'`` (CO2 units per year).
        """
        co2_data = self.subtract_avg_seasonal_cycle(result_date, **kwargs)
        dei_years = co2_data['year']
        t_min = dei_years.min()
        t_max = dei_years.max()

        week = 1./12./4.  # 1 year = 48 weeks according to DEI
        num_weeks = 4     # monthly avg is avg over 4 weeks
        month = num_weeks*week

        # A year Y is complete when the record covers the month before Y and
        # the month after Y+1, so both boundary averages use only real data.
        complete_years = []
        candidate = int(t_min)
        if candidate < t_min:
            candidate += 1
        while candidate < t_max:
            if candidate - month >= t_min and candidate + 1 + month <= t_max:
                complete_years.append(candidate)
            candidate += 1

        spl = interpolate.InterpolatedUnivariateSpline(dei_years, co2_data['trend'], k=1)

        # trapezoidal weights over num_weeks+1 weekly samples spanning one month
        weights = 2*np.ones(num_weeks+1, np.float64)
        weights[0] = 1
        weights[-1] = 1
        weights = weights/weights.sum()

        growth_rate = []
        for year in tqdm.tqdm(complete_years, desc='Calculating growth rate'):
            next_year = year+1
            beg_dec_avg = 0.0
            beg_jan_avg = 0.0
            end_dec_avg = 0.0
            end_jan_avg = 0.0
            for iw in range(num_weeks+1):
                beg_dec_avg += weights[iw] * spl(year-iw*week)
                beg_jan_avg += weights[iw] * spl(year+iw*week)
                end_dec_avg += weights[iw] * spl(next_year-iw*week)
                end_jan_avg += weights[iw] * spl(next_year+iw*week)
            beg_avg = 0.5*(beg_dec_avg+beg_jan_avg)
            end_avg = 0.5*(end_dec_avg+end_jan_avg)
            growth_rate.append(end_avg-beg_avg)

        return {
            'year': np.array(complete_years),
            'growth_rate' : np.array(growth_rate),
            }

    def stretch_fn(self, x, tp=36.0):
        """Compress an axis: square-root-like up to ``tp``, linear beyond.

        f(x) = a * sqrt(x)            for x <= tp
             = (x - tp) + a*sqrt(tp)  for x > tp
        with a = 2*sqrt(tp), chosen so f is continuous and differentiable at tp
        (slope 1 on both sides).

        Parameters
        ----------
        x : array_like
            Non-negative values (negative values give NaN from ``sqrt``).
        tp : float
            Transition point.

        Returns
        -------
        numpy.ndarray
            float64 array with the shape of ``x``.
        """
        x = np.asarray(x, dtype=np.float64)
        a = 2.0 * np.sqrt(tp)
        fn = np.zeros(x.shape, np.float64)
        below = x <= tp
        above = x > tp
        fn[below] = a * np.sqrt(x[below])
        fn[above] = x[above]-tp + a*np.sqrt(tp)
        return fn

    def plot_growth_rate_convergence(self, **kwargs):
        """Plot how each year's growth-rate estimate evolves with later DEI releases.

        For every monthly result directory from Feb 2015 through Mar 2022, the
        growth rates are recomputed. Each tracked year's estimate is plotted
        against the months of data after that year ended, on an x axis
        compressed by ``stretch_fn``.

        Keyword arguments: ``method`` (only ``'harmonics'``), ``years``
        (default 2014-2020), ``relative`` (if True, plot each estimate minus
        the mean of its last 20%). All keywords are also forwarded to
        ``calculate_growth_rate``.

        Raises
        ------
        RuntimeError
            If ``method`` is not ``'harmonics'``.
        """
        method = kwargs.get('method', 'harmonics')
        years = kwargs.get('years', np.arange(2014, 2021))
        sub_asymptote = kwargs.get('relative', False)
        if method != 'harmonics':
            raise RuntimeError('Other methods not yet implemented')

        markers = dict(zip(years, itertools.cycle(['o', 's', 'D', 'p', '^', 'v', '<', '>'])))
        growth_rate_dict = {year: {'months elapsed': [], 'growth rate': []} for year in years}

        dp = datetime(2015, 2, 1)
        max_web_results_date = datetime(2022, 3, 1)
        while dp <= max_web_results_date:
            co2_data = self.calculate_growth_rate(dp, **kwargs)
            gr_dict = dict(zip(co2_data['year'], co2_data['growth_rate']))
            for year in years:
                if year not in gr_dict:
                    continue
                months_elapsed = (dp.year-year-1) * 12 + dp.month
                growth_rate_dict[year]['months elapsed'].append(months_elapsed)
                growth_rate_dict[year]['growth rate'].append(gr_dict[year])
            dp = dp + relativedelta(months=1)

        plt.figure()
        ax = plt.gca()

        mel_max = 0.0
        for year, gr_data in growth_rate_dict.items():
            if not gr_data['months elapsed']:
                # no release produced an estimate for this year; nothing to plot
                continue
            sort_order = np.argsort(gr_data['months elapsed'])
            mel = np.array(gr_data['months elapsed'])[sort_order]
            grr = np.array(gr_data['growth rate'])[sort_order]

            if sub_asymptote:
                grr = self._relative_to_asymptote(grr)

            ax.plot(self.stretch_fn(mel), grr, marker=markers[year], ls='-', lw=1.5, ms=6, mew=1, label='%4i'%year)
            mel_max = max(mel_max, mel.max())

        xticklocs = np.array([2] + list(np.arange(6, mel_max, 6)))
        ax.set_xticks(self.stretch_fn(xticklocs))
        ax.set_xticklabels(['%i'%x for x in xticklocs])

        plt.setp(ax.get_xticklabels(), size=14)
        plt.setp(ax.get_yticklabels(), size=14)
        ax.set_xlabel('Months of data after year end', size=16)
        if sub_asymptote:
            ax.set_ylabel(u'CO\u2082 growth rate \u2212 final value', size=16)
        else:
            ax.set_ylabel(u'Estimated CO\u2082 growth rate (ppm/year)', size=16)
        ax.grid(True, ls='--')

        leg = plt.legend(loc='upper right', ncol=2, fontsize=14, handlelength=1.2, labelspacing=0.25, numpoints=1)
        leg.set_draggable(True)

        plt.subplots_adjust(0.115,0.12,0.98,0.97)

    def plot_growth_rate_by_end_date(self, **kwargs):
        """Plot annual growth rates computed from several DEI releases.

        Keyword ``method`` selects the estimator:

        - ``'harmonics'`` (default): ``calculate_growth_rate``, with all
          keywords forwarded to it.
        - ``'mbl trend'``: year-to-year differences of ``zone_gl.mbl.tr.co2``.
        - ``'mbl average'``: year-to-year differences of ``zone_gl.mbl.co2``.

        Raises
        ------
        RuntimeError
            If ``method`` is not one of the above.
        """
        method = kwargs.get('method', 'harmonics')
        # (file name, legend prefix) for the annual-differencing methods
        differencing_methods = {
            'mbl trend': ('zone_gl.mbl.tr.co2', 'Trend'),
            'mbl average': ('zone_gl.mbl.co2', 'MBL avg'),
            }
        if method != 'harmonics' and method not in differencing_methods:
            raise RuntimeError('Unknown method %r' % (method,))

        dates_to_plot = [datetime(2021, month, 1) for month in range(2, 11)] + [datetime(2020, 10, 1)]

        plt.figure()
        ax = plt.gca()

        for dp in dates_to_plot:
            label_date = dp.strftime('%b %Y')
            if method == 'harmonics':
                co2_data = self.calculate_growth_rate(dp, **kwargs)
                ax.plot(co2_data['year'], co2_data['growth_rate'], '-o', lw=1.5, ms=6, mew=1, label='Harmonic with %s data'%label_date)
            else:
                file_name, label_prefix = differencing_methods[method]
                year, co2 = self._load_dei_series(dp, file_name)
                years, gr = self._annual_differences(year, co2)
                ax.plot(years, gr, '-', lw=2, label='%s with %s data'%(label_prefix, label_date))

        ax.set_xlim(left=1978.5, right=2020.5)
        plt.setp(ax.get_xticklabels(), size=14)
        plt.setp(ax.get_yticklabels(), size=14)
        ax.set_xlabel('Year', size=16)
        ax.set_ylabel(u'CO\u2082 MBL growth rate', size=16)
        ax.grid(True, ls='--')

        leg = plt.legend(loc='best', fontsize=14, handlelength=1.2, labelspacing=0.25)
        leg.set_draggable(True)

        plt.subplots_adjust(0.11,0.12,0.97,0.97)

    def print_table(self, title, year_arr, gr_arr, **kwargs):
        """Print an underlined title, then ``year  growth-rate`` rows for years >= ``min_year`` (default 1980)."""
        min_year = kwargs.get('min_year', 1980)
        print(title)
        print('='*len(title))
        for year, gr in zip(year_arr, gr_arr):
            if year >= min_year:
                print('%04i    %5.2f'%(year,gr))
        print()

    def plot_different_growth_rates(self, result_date, **kwargs):
        """Compare growth-rate estimates from several sources on one plot and print tables.

        Sources:
        - the NOAA website table ``co2_gr_gl_20210421.txt``;
        - ``calculate_growth_rate(result_date, **kwargs)``;
        - the GEOS run budget (``get_geos_run_budget``);
        - the GEOS run co-sampled at NOAA MBL samples, read from a DEI output
          delivered under ``output/from_brad``;
        - the Global Carbon Project 2020 value.
        """
        plt.figure()
        ax = plt.gca()

        # NOAA website values
        year, gr = np.loadtxt('co2_gr_gl_20210421.txt', usecols=(0,1), unpack=True)
        idx = year > 1978.5
        year = year[idx]
        gr = gr[idx]
        ax.plot(year, gr, '-o', lw=1.5, ms=6, mew=1, label='NOAA website')
        self.print_table('From NOAA trends website', year, gr, min_year=2010)

        # NOTE: year-differenced DEI trend / MBL averages can be compared via
        # plot_growth_rate_by_end_date(method='mbl trend' | 'mbl average').

        # moving-window harmonic deseasonalization
        co2_data = self.calculate_growth_rate(result_date, **kwargs)
        ax.plot(co2_data['year'], co2_data['growth_rate'], '-o', lw=1.5, ms=6, mew=1, label='My calculation')
        self.print_table('My calculation of NOAA growth rate from NOAA data', co2_data['year'], co2_data['growth_rate'], min_year=2010)

        # GEOS model run: whole-atmosphere and remote-MBL budgets
        co2_data = self.get_geos_run_budget()
        ax.plot(co2_data['year'], co2_data['co2'], '-s', lw=2, label=u'GEOS (exact)')
        self.print_table('Whole atmosphere budget from GEOS', co2_data['year'], co2_data['co2'], min_year=2010)
        ax.plot(co2_data['year'], co2_data['co2mbl'], '-d', lw=2, label=u'GEOS (remote MBL)')
        self.print_table('Remote MBL budget from GEOS', co2_data['year'], co2_data['co2mbl'], min_year=2010)

        # GEOS run co-sampled with NOAA MBL samples (DEI output delivered by Brad)
        co2_file = os.path.join('output/from_brad/2021-06-13/co2', 'results.2021-06-13', 'zone_gl.mbl.co2')
        year, co2 = np.loadtxt(co2_file, unpack=True)
        valid_idx = np.logical_and(co2 > 0.0, year > 2014.9)
        years, gr = self._annual_differences(year[valid_idx], co2[valid_idx])
        ax.plot(years, gr, '-d', lw=2, label=u'GEOS (MBL samples)')
        self.print_table('GEOS run co-sampled with NOAA MBL samples', years, gr, min_year=2010)

        # Global Carbon Project growth rate: 2.54 +/- 0.08 in 2020
        ax.errorbar(2020.0, 2.54, yerr=0.08, ecolor='k', elinewidth=1, capsize=3, capthick=1)
        ax.plot(2020.0, 2.54, 'D', ms=6, mew=0.5, mec='k', label='GCP')

        ax.set_xlim(left=1978.5, right=2020.5)
        plt.setp(ax.get_xticklabels(), size=14)
        plt.setp(ax.get_yticklabels(), size=14)
        ax.set_xlabel('Year', size=16)
        ax.set_ylabel(u'CO\u2082 MBL growth rate', size=16)
        ax.grid(True, ls='--')

        leg = plt.legend(loc='best', fontsize=14, handlelength=1.2, labelspacing=0.25)
        leg.set_draggable(True)

        plt.subplots_adjust(0.11,0.12,0.97,0.97)

    def get_geos_run_budget(self):
        """Annual growth of CO2, CO2SIM and CO2_MBL from the GEOS run file.

        Reads ``co2_molefrac_2014102600_2021010103.nc``. The variables are
        scaled by 1e6, presumably mole fraction -> ppm (TODO confirm). Timestamps
        come from ``time_comps`` rows passed to ``datetime``, and are converted
        to decimal years before differencing across whole years.

        Returns
        -------
        dict
            ``'year'``, ``'co2'``, ``'co2sim'``, ``'co2mbl'`` arrays.
        """
        with Dataset('co2_molefrac_2014102600_2021010103.nc', 'r') as fid:
            times = [datetime(*d) for d in fid.variables['time_comps'][:]]
            co2 = 1.0E6 * fid.variables['CO2'][:]
            co2_sim = 1.0E6 * fid.variables['CO2SIM'][:]
            co2_mbl = 1.0E6 * fid.variables['CO2_MBL'][:]

        decimal_date = []
        for t in times:
            seconds_elapsed = (t-datetime(t.year,1,1)).total_seconds()
            seconds_in_year = 86400.0 * (365 + int(calendar.isleap(t.year)))
            decimal_date.append(t.year + seconds_elapsed/seconds_in_year)
        decimal_date = np.array(decimal_date)

        years, gr = self._annual_differences(decimal_date, co2)
        _, gr_sim = self._annual_differences(decimal_date, co2_sim)
        _, gr_mbl = self._annual_differences(decimal_date, co2_mbl)

        return {
            'year'  : np.array(years),
            'co2'   : np.array(gr),
            'co2sim': np.array(gr_sim),
            'co2mbl': np.array(gr_mbl),
            }
