#!/usr/local/other/python/GEOSpyD/2019.03_py3.7/2019-04-22/bin/python

import sys
import numpy as np
import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LightSource
from matplotlib.cm import get_cmap
import cartopy.crs as ccrs
import cartopy.feature as cf
import xarray as xr
import imageio
sys.path.insert(1,'/discover/nobackup/cakelle2/GEOS_CF/Apps/cftools/external/cmcrameri')
from cmcrameri import cm

# Directories holding the output of the two runs being compared
# (relaxed-tolerance run and control run, going by the directory names).
path_relax_tol = '/discover/nobackup/projects/gmao/geos_cf_dev/psturm/CFv2_c90/holding_relax_tol/gcc_kpp'
path_control = '/discover/nobackup/projects/gmao/geos_cf_dev/psturm/CFv2_c90/holding_control/gcc_kpp'

# Open the identically named netCDF file (stamped 20181005_2100z) from each
# directory so the two runs can be compared field by field.
snapshot_name = 'CFv2_ctrl_c90.gcc_kpp.20181005_2100z.nc4'
nc_relax_tol = xr.open_dataset(f'{path_relax_tol}/{snapshot_name}')
nc_control = xr.open_dataset(f'{path_control}/{snapshot_name}')

def _make_subplot(fig, igs, longitude, latitude, concentration, species,
                  title=None, ylabel = None, cmap=cm.batlow, levels=None,
                  cbar_label=None, cbar_ticks=None):
    """Add one filled-contour map with coastlines and a colorbar to ``fig``.

    Parameters
    ----------
    fig : figure on which ``add_subplot`` is called and that owns the colorbar.
    igs : subplot position handed to ``fig.add_subplot`` (callers in this
        file pass a single GridSpec cell).
    longitude, latitude : coordinates passed to ``contourf`` as x and y.
    concentration : field to contour.
    species : name used only to build the default colorbar label.
    title : optional axes title; no title is set when ``None``.
    ylabel : optional y-axis label; when given, the y ticks are also cleared.
    cmap, levels : forwarded to ``contourf``.
    cbar_label : colorbar label; defaults to ``'<species> [ppb]'``.
    cbar_ticks : tick positions forwarded to the colorbar.

    Returns
    -------
    The axes that were created.
    """
    if cbar_label is None:
        cbar_label = f'{species} [ppb]'

    ax = fig.add_subplot(igs, projection=ccrs.PlateCarree())
    ax.coastlines()
    filled = ax.contourf(longitude, latitude, concentration, cmap=cmap, levels=levels)

    # Create the colorbar through the figure that owns `ax` rather than through
    # pyplot's "current figure", so the result does not depend on which figure
    # happens to be active when this helper is called.
    fig.colorbar(filled, ax=ax, orientation='vertical', label=cbar_label,
                 ticks=cbar_ticks, fraction=0.031, pad=0.04)

    if title is not None:
        ax.set_title(title, fontsize=18)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=17)
        ax.set_yticks([])
    return ax

def plot_concentration(fig, igs, column, nc_control, nc_relax_tol, species, level, title,
                       max_abs_bias=None, axes_range=None,
                       units=None, time_index=0):
    """Fill one column of ``igs`` with control, relaxed-tolerance and difference maps.

    Row 0 shows ``nc_control``, row 1 shows ``nc_relax_tol`` and row 2 shows
    ``nc_relax_tol - nc_control`` for ``species`` at ``level`` and
    ``time_index``. Rows 0 and 1 share one colour scale; row 2 uses a scale
    symmetric about zero.

    Parameters
    ----------
    fig : figure that receives the three axes.
    igs : grid indexed as ``igs[row, column]`` to place each map.
    column : grid column to draw into; row labels are drawn only for column 0.
    nc_control, nc_relax_tol : datasets indexed as
        ``ds[species][time, level, :, :]``; longitude/latitude are read from
        ``nc_control.variables['lon']`` and ``['lat']``.
    species : variable name to plot.
    level : index along the second axis of the variable.
    title : title of the top map in the column.
    max_abs_bias : half-width of the difference colour scale in unscaled data
        units; when ``None`` it is computed over all times at ``level``.
    axes_range : ``(min, max)`` for the shared colour scale in unscaled data
        units; when ``None`` it is computed over all times of both datasets.
    units : colorbar label. Values are multiplied by 1e9 unless ``units`` is
        given and does not contain ``'ppb'`` (presumably converting a mixing
        ratio to ppb, as the original comment assumed).
    time_index : index along the first axis of the variable.

    Returns
    -------
    ``fig``.
    """
    lons = nc_control.variables['lon'][:]
    lats = nc_control.variables['lat'][:]

    scale = 1e9 if (units is None or 'ppb' in units) else 1

    # Shared colour scale for the two concentration maps.
    if axes_range is None:
        control_all_times = nc_control[species][:, level, :, :]
        relax_all_times = nc_relax_tol[species][:, level, :, :]
        axes_min = np.min([np.min(control_all_times), np.min(relax_all_times)])
        axes_max = np.max([np.max(control_all_times), np.max(relax_all_times)])
    else:
        axes_min, axes_max = axes_range
    axes_min = axes_min * scale
    axes_max = axes_max * scale
    levels = np.linspace(axes_min, axes_max, 100)   # 100 contour levels
    cbar_ticks = np.linspace(axes_min, axes_max, 5)  # 5 colorbar ticks

    # Row labels are only shown on the left-most column.
    if column == 0:
        row_labels = ("Default Tolerances", "Relaxed Tolerances", "Difference")
    else:
        row_labels = (None, None, None)

    control_field = nc_control[species][time_index, level, :, :]
    relax_field = nc_relax_tol[species][time_index, level, :, :]

    _make_subplot(fig, igs[0, column], lons, lats, control_field * scale, species,
                  title=title, ylabel=row_labels[0], cmap=cm.batlow, levels=levels,
                  cbar_ticks=cbar_ticks, cbar_label=units)
    _make_subplot(fig, igs[1, column], lons, lats, relax_field * scale, species,
                  title=None, ylabel=row_labels[1], cmap=cm.batlow, levels=levels,
                  cbar_ticks=cbar_ticks, cbar_label=units)

    # Symmetric colour scale for the difference map.
    if max_abs_bias is None:
        max_abs_bias = np.max(np.abs(nc_relax_tol[species][:, level, :, :]
                                     - nc_control[species][:, level, :, :]))
    max_abs_bias = abs(float(max_abs_bias)) * scale
    if not (np.isfinite(max_abs_bias) and max_abs_bias > 0):
        # Identical (or all-NaN) fields give a zero-width range, for which no
        # increasing set of contour levels exists; use a unit half-width so the
        # (flat) difference map can still be drawn.
        max_abs_bias = 1.0
    # linspace pins both endpoints and the element count exactly, unlike a
    # float-step arange whose length can vary with rounding.
    bias_levels = np.linspace(-max_abs_bias, max_abs_bias, 201)
    bias_cbar_ticks = np.linspace(-max_abs_bias, max_abs_bias, 5)

    _make_subplot(fig, igs[2, column], lons, lats, (relax_field - control_field) * scale, species,
                  title=None, ylabel=row_labels[2], cmap=cm.vik, levels=bias_levels,
                  cbar_ticks=bias_cbar_ticks, cbar_label=units)
    return fig

def plot_multiple_concentrations(nc_control, nc_relax_tol, species_level_list,
                                 max_abs_bias=None, axes_range=None,
                                 time_index =0,name='concentration_plots.png'):
    """Draw one three-row comparison column per entry and save the figure.

    Parameters
    ----------
    nc_control, nc_relax_tol : datasets forwarded to ``plot_concentration``;
        ``nc_control.time[time_index]`` is used as the figure super-title.
    species_level_list : iterable of ``(species, level, title, units)``
        tuples, one per column.
    max_abs_bias : optional per-column sequence of difference half-widths,
        indexed by column; ``None`` lets each column compute its own.
    axes_range : optional per-column sequence of ``(min, max)`` pairs, indexed
        by column; ``None`` lets each column compute its own.
    time_index : time step to plot.
    name : destination passed to ``fig.savefig``.

    The grid has as many columns as there are entries and the figure is
    6 inches wide per column, which gives the 18x9 inch, 3x3 layout for
    three entries. The figure is always closed, even when plotting or saving
    fails, so repeated calls do not accumulate open figures.
    """
    column_specs = list(species_level_list)
    n_columns = max(len(column_specs), 1)  # a grid needs at least one column

    fig = plt.figure(figsize=(6 * n_columns, 9))
    try:
        igs = GridSpec(3, n_columns)
        for i, (species, level, title, units) in enumerate(column_specs):
            column_max_abs_bias = max_abs_bias[i] if max_abs_bias is not None else None
            column_axes_range = axes_range[i] if axes_range is not None else None
            plot_concentration(fig, igs, i, nc_control, nc_relax_tol, species, level, title,
                               units=units, time_index=time_index,
                               max_abs_bias=column_max_abs_bias, axes_range=column_axes_range)
        fig.tight_layout()  # keep the axes from overlapping
        fig.suptitle(f'{nc_control.time[time_index].values}', fontsize=12)
        fig.subplots_adjust(top=0.93)  # leave room for the super-title
        fig.savefig(name, dpi=80)
    finally:
        plt.close(fig)

# Static comparison figures drawn from the datasets opened above. Each entry
# pairs the per-column (variable, level index, title, colorbar label) specs
# with the file the figure is saved to.
snapshot_figures = [
    # O3 at three levels
    ([('SpeciesConc_O3', -1, 'Surface', "Concentration [ppb]"),
      ('SpeciesConc_O3', -48, 'L48 (10 hPa)', "Concentration [ppb]"),
      ('SpeciesConc_O3', -56, 'L56 (1.9 hPa)', "Concentration [ppb]")],
     'figs/ozone_levels_day5.png'),
    # The same three levels for NO2
    ([('SpeciesConc_NO2', -1, 'Surface', "NO2 Concentration [ppb]"),
      ('SpeciesConc_NO2', -48, 'L48 (10 hPa)', " NO2 Concentration [ppb]"),
      ('SpeciesConc_NO2', -56, 'L56 (1.9 hPa)', "NO2 Concentration [ppb]")],
     'figs/no2_levels_day5.png'),
    # O3 and NO2 at level index -48 next to the KppCPUSteps field
    ([('SpeciesConc_O3', -48, 'O3 at 10 hPa', "Concentration [ppb]"),
      ('SpeciesConc_NO2', -48, 'NO2 at 10 hPa', "Concentration [ppb]"),
      ('KppCPUSteps', -1, 'Total CPU Steps', "Total Internal Steps")],
     'figs/ozone_no2_cpu.png'),
]

for column_specs, output_name in snapshot_figures:
    plot_multiple_concentrations(nc_control, nc_relax_tol, column_specs,
                                 name=output_name)

# Stand-in for TotCol: sum O3 and NO2 vertically
# nc_control["TotCol_O3"] = nc_control["SpeciesConc_O3"]
# nc_control["TotCol_NO2"] = nc_control["SpeciesConc_NO2"]
# nc_relax_tol["TotCol_O3"] = nc_relax_tol["SpeciesConc_O3"]
# nc_relax_tol["TotCol_NO2"] = nc_relax_tol["SpeciesConc_NO2"]
# # sum across levels
# nc_control["TotCol_O3"][:,-1,:,:] = nc_control["TotCol_O3"].sum(dim='lev')
# nc_control["TotCol_NO2"][:,-1,:,:] = nc_control["TotCol_NO2"].sum(dim='lev')
# nc_relax_tol["TotCol_O3"][:,-1,:,:] = nc_relax_tol["TotCol_O3"].sum(dim='lev')
# nc_relax_tol["TotCol_NO2"][:,-1,:,:] = nc_relax_tol["TotCol_NO2"].sum(dim='lev')
# plot_multiple_concentrations(nc_control, nc_relax_tol,
#                                 [('TotCol_O3', -1, 'Total O3 Column Stand-in',"Fake weird concentration units"),
#                                 ('TotCol_NO2', -1, 'Total NO2 Column Stand-in',"Fake weird concentration units"),
#                                 ('KppCPUSteps', -1, 'Total CPU Steps',"Total Internal Steps")],
#                                 name='figs/ozone_no2_cpu_totcol.png')


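# Next: plot every time step of the merged netCDF files in path_relax_tol and path_control,
# save each frame as an image in a folder called gif_frames,
# and then make a GIF from those frames.
# The helper below computes colour-scale bounds shared by all frames; the merged files are loaded after it.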
def calculate_axes_bounds(nc_control, nc_relax_tol, args):
    """Compute colour-scale bounds for each ``(species, level, _, _)`` entry.

    For every entry, the data are taken as ``ds[species][:, level, :, :]`` from
    both datasets and reduced over everything that remains.

    Parameters
    ----------
    nc_control, nc_relax_tol : mappings from variable name to a 4-D array.
    args : iterable of 4-tuples; only the first two items (variable name and
        level index) are used.

    Returns
    -------
    (max_abs_biases, axes_ranges) : two lists aligned with ``args``. The first
        holds the largest absolute difference between the two datasets, the
        second holds ``(min, max)`` over both datasets.
    """
    bias_limits = []
    value_ranges = []
    for spec in args:
        species, level, _title, _units = spec
        # Progress message: this scans every time step of both datasets.
        print(f'Calculating axes bounds for {species} at level {level}')

        ctrl = nc_control[species][:, level, :, :]
        relax = nc_relax_tol[species][:, level, :, :]

        # Largest absolute difference between the two runs.
        bias_limits.append(np.max(np.abs(relax - ctrl)))

        # Smallest and largest value found in either run.
        lowest = np.min([np.min(ctrl), np.min(relax)])
        highest = np.max([np.max(ctrl), np.max(relax)])
        value_ranges.append((lowest, highest))
    return bias_limits, value_ranges

# Switch from the single-snapshot files to the time-merged files. The
# snapshot datasets are closed first so their file handles are released
# before the names are rebound.
nc_control.close()
nc_relax_tol.close()
nc_control = xr.open_dataset(path_control + '/mergetime.nc4')
nc_relax_tol = xr.open_dataset(path_relax_tol + '/mergetime.nc4')

# One (variable, level index, title, colorbar label) tuple per figure column.
figure_column_specs = [('SpeciesConc_O3', -48, 'O3 at 10 hPa',"Concentration [ppb]"),
        ('SpeciesConc_NO2', -48, 'NO2 at 10 hPa',"Concentration [ppb]"),
        ('KppCPUSteps', -1, 'Total CPU Steps',"Total Internal Steps")]

# Compute the colour-scale bounds once over all time steps so that every
# frame of the animation uses the same scales.
max_abs_bias, axes_range = calculate_axes_bounds(nc_control, nc_relax_tol, figure_column_specs)

# Render one PNG per time step, read it back, and collect the frames.
images = []
total_timesteps_start = 0
# Plot at most 480 steps, but never index past the end of either merged file.
total_timesteps_end = min(480, nc_control.sizes['time'], nc_relax_tol.sizes['time'])
for time_index in range(total_timesteps_start, total_timesteps_end):
    print(f'Plotting time index {time_index}')
    frame_path = f'gif_frames/ozone_no2_cpu_{time_index}.png'
    plot_multiple_concentrations(nc_control, nc_relax_tol,
                                 figure_column_specs,
                                 time_index=time_index,
                                 max_abs_bias=max_abs_bias, axes_range=axes_range,
                                 name=frame_path)
    images.append(imageio.imread(frame_path))

# Assemble the collected frames into the animated GIF.
imageio.mimsave('ozone_no2_cpu_L48_5days.gif', images, duration=.08)
