#!/usr/local/other/python/GEOSpyD/2019.03_py3.7/2019-04-22/bin/python
import argparse
import sys
import numpy as np
import datetime as dt
from calendar import monthrange
import os
from netCDF4 import Dataset,num2date
import pandas as pd
import glob
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.dates as mdates
from matplotlib.cm import get_cmap
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LightSource
from matplotlib.cm import get_cmap
import cartopy.crs as ccrs
import cartopy.feature as cf
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
#from cartopy.io.img_tiles import Stamen
from pandas.plotting import register_matplotlib_converters
import xarray as xr
register_matplotlib_converters()

sys.path.insert(1,'/discover/nobackup/cakelle2/GEOS_CF/Apps/cftools/external/cmcrameri')
from cmcrameri import cm

# Month to plot; main() uses it to build the input file paths, the figure
# title and the output filename.
idate = dt.datetime(year=2018, month=5, day=1)
# True: latitude/level zonal-mean cross sections; False: lon/lat maps of a
# single model level.
zonal = True

def main(args):
    """Plot a reference vs. spinup GEOS-CF comparison for one month.

    Opens the reference (``CFv2_c90rc1``) and GEOS-CF v1.1 spinup monthly
    files for the month of the module-level ``idate``, extracts one species,
    and writes a three-panel PNG under ``png/``:

      * top-left     -- reference field (labelled 'ref')
      * top-right    -- spinup field (labelled 'spinup')
      * bottom-right -- relative difference (spinup - ref) / ref

    When the module-level ``zonal`` flag is set, each field is averaged over
    longitude for model levels ``lev1:lev2`` and plotted as latitude vs.
    level. Otherwise level index ``ilev`` is plotted as a lon/lat map.

    Args:
        args: parsed command-line options. NOTE(review): currently unused.
            Species, variable names and scale factor are hard-coded below,
            so ``--var``/``--var2``/``--name``/``--scal`` have no effect.
    """
    # Input files. NOTE: the spinup path hard-codes year 2017; only the
    # month follows idate.
    ref_file = idate.strftime('../CFv2_c90rc1/monthly/CFv2_c90rc1.geosgcm_chm.%Y%m.nc4')
    spn_file = idate.strftime('/discover/nobackup/projects/gmao/geos_cf/priv/GEOS-CF_v1-1/ana/Y2017/M%m/GEOS-CF_v1-1Spinup.chm_inst_1hr_g1440x721_v72.monthly.2017%m.nc4')
    lab_v1r = 'ref'
    lab_v20 = 'spinup'

    # Vertical window for the zonal cross sections (python slice, so lev2 is
    # exclusive; lev1 = 37 was used previously). ilev is the level index
    # plotted in map mode, presumably the lowest model level given the
    # 'sfc' output prefix -- verify.
    lev1 = 20
    lev2 = 71
    ilev = 71

    # Species to compare. Other configurations used previously
    # (spec, var1, var2, scal):
    #   ('NOy',  'NOy',           'NOy',            1.0e9)
    #   ('CO',   'COdry',         'SpeciesConc_CO', 1.0e9)
    #   ('PM25', 'PM25_RH35_GCC', 'PM25',           1.0)
    spec = 'O3'
    var1 = 'SpeciesConc_O3'  # variable name in the reference file
    var2 = 'O3'              # variable name in the spinup file
    scal = 1.0e9

    # Read everything up front; the context managers close both files even
    # if a variable lookup fails.
    # NOTE(review): coordinates come from the reference file only, so both
    # files are assumed to share one lat/lon/lev grid (the spinup filename
    # says g1440x721) -- verify.
    with Dataset(ref_file, 'r') as fid_v1r, Dataset(spn_file, 'r') as fid_v20:
        levs = fid_v1r.variables['lev'][lev1:lev2]
        lons = fid_v1r.variables['lon'][:]
        lats = fid_v1r.variables['lat'][:]
        if zonal:
            # Averaging the last (longitude) axis leaves a (lev, lat) field.
            v1r = fid_v1r.variables[var1][0, lev1:lev2, :, :].mean(axis=2) * scal
            v20 = fid_v20.variables[var2][0, lev1:lev2, :, :].mean(axis=2) * scal
        else:
            v1r = fid_v1r.variables[var1][0, ilev, :, :] * scal
            v20 = fid_v20.variables[var2][0, ilev, :, :] * scal

    # Relative difference to the reference. np.ma.divide masks points where
    # the reference is zero instead of producing inf/nan.
    v20_rat = np.ma.divide(v20 - v1r, v1r)

    # Colormaps: sequential for absolute values, diverging for the ratio.
    # Alternatives tried for cmap1: cm.batlow, cm.lajolla, cm.lisbon,
    # cm.davos_r, cm.lapaz, cm.oslo_r, cm.acton_r, get_cmap("rainbow").
    cmap1 = cm.tokyo_r
    cmap2 = get_cmap("seismic")
    flev1 = np.arange(1.0, 70.0, 1.0)
    # Step 0.02, symmetric about zero, so the diverging colormap's midpoint
    # marks "no difference". np.arange(-0.5, 0.5, 0.02) stopped at 0.48.
    flev2 = np.linspace(-0.5, 0.5, 51)
    shrnk = 1.0

    fig = plt.figure(figsize=(11, 5))
    gs = GridSpec(2, 2)

    # Zonal: latitude vs. (reversed) level coordinate; map: lon vs. lat.
    d1 = lats if zonal else lons
    d2 = levs[::-1] if zonal else lats

    _make_plot(fig, gs[0, 0], d1, d2, v1r, lab_v1r, cmap1, flev1, shrnk, zonal, extend='both')
    _make_plot(fig, gs[0, 1], d1, d2, v20, lab_v20, cmap1, flev1, shrnk, zonal, extend='both')
    # gs[1, 0] is left empty.
    _make_plot(fig, gs[1, 1], d1, d2, v20_rat, '(spinup - ref) / ref', cmap2, flev2, shrnk, zonal, extend='both')

    fig.suptitle(idate.strftime(spec + ' (%Y-%m)'))
    fig.tight_layout()

    prefix = 'zonal' if zonal else 'sfc'
    ofile = idate.strftime('png/' + prefix + '_spinup_' + spec + '_%Y%m.png')
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(ofile), exist_ok=True)
    fig.savefig(ofile)
    plt.close(fig)
    print('Figure written to ' + ofile)


def _make_plot(fig, igs, dim1, dim2, arr, title, cmap, flev, shrnk, zonal,
               *, yticks=(68, 58, 48, 38), yticklabels=(42, 52, 62, 72),
               **kwargs):
    """Add one filled-contour panel to ``fig`` and return its axes.

    Args:
        fig: figure to draw into.
        igs: GridSpec cell (e.g. ``gs[0, 1]``) that receives the new axes.
        dim1: x coordinates of ``arr``.
        dim2: y coordinates of ``arr``.
        arr: 2-D field; ``contourf`` expects shape (len(dim2), len(dim1)).
        title: panel title.
        cmap: colormap for the filled contours.
        flev: contour levels.
        shrnk: colorbar ``shrink`` factor.
        zonal: if True, draw on plain axes (cross section). Otherwise draw
            on a PlateCarree map with coastlines.
        yticks: y-tick positions, applied only when ``zonal`` is True.
        yticklabels: labels for ``yticks``, applied only when ``zonal`` is True.
        **kwargs: forwarded to ``contourf`` (e.g. ``extend='both'``).

    Returns:
        The created axes.
    """
    if zonal:
        ax = fig.add_subplot(igs)
        cp = ax.contourf(dim1, dim2, arr, cmap=cmap, levels=flev, **kwargs)
        # NOTE(review): the default tick labels are fixed numbers, not derived
        # from dim2. They look tuned for one specific level window; verify
        # they still match when the caller's level range changes.
        ax.set_yticks(list(yticks))
        ax.set_yticklabels(list(yticklabels))
    else:
        proj = ccrs.PlateCarree()
        ax = fig.add_subplot(igs, projection=proj)
        ax.coastlines()
        cp = ax.contourf(dim1, dim2, arr, transform=proj, cmap=cmap, levels=flev, **kwargs)
    # Attach the colorbar to the figure we were given, not pyplot's current one.
    fig.colorbar(cp, ax=ax, shrink=shrnk)
    ax.set_title(title)
    return ax


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with attributes ``var``, ``var2``, ``name`` and
        ``scal``. NOTE(review): ``main`` does not currently read any of them.
    """
    p = argparse.ArgumentParser(description='Compare GEOS-CF reference and spinup output')
    p.add_argument('-v', '--var', type=str, help='variable', default='O3dry')
    p.add_argument('-v2', '--var2', type=str, help='variable 2', default=None)
    p.add_argument('-n', '--name', type=str, help='name', default=None)
    p.add_argument('-s', '--scal', type=float, help='scale factor applied to the data', default=1.0e9)
    return p.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: parse the CLI, then build and save the figure.
    cli_args = parse_args()
    main(cli_args)
