# conda activate /discover/nobackup/cakelle2/conda/envs/myviz
import argparse
import sys
import numpy as np
import datetime as dt
import os
import pandas as pd
import glob
import xarray as xr
import matplotlib.pyplot as plt
import geopandas as gpd
import contextily as ctx
import cartopy
from cartopy import crs as ccrs
from cartopy.io import shapereader
from shapely.geometry import Polygon
import pyproj

sys.path.insert(1,'/discover/nobackup/cakelle2/GEOS_CF/Apps/cftools/external/cmcrameri')
from cmcrameri import cm


# Plot domain in degrees: (lon1, lon2) = west/east edge, (lat1, lat2) =
# south/north edge.  Only one box is active at a time; boxes used earlier:
#   lon -125.0 .. -70.0, lat 25.0 .. 49.0
#   lon  -85.0 .. -70.0, lat 25.0 .. 45.0
lon1, lon2 = -85.0, -70.0
lat1, lat2 = 36.0, 45.0


def main(args):
    """Draw a 2x2 panel of NOx maps, one panel per resolution, and save it.

    The panels show 'c90', 'c360', 'c720' and 'c1440' in row-major order and
    the figure is written to 'NOx_USne.pdf' in the current directory.

    Args:
        args: Parsed command-line namespace.  It is accepted so the script
            entry point can pass it in, but it is not consulted here.
    """
    # Settings shared by all panels.  They are passed to _make_fig()
    # explicitly because locals of main() are not visible inside it.
    scal = 1.0e9    # mol mol-1 -> ppbv
    minval = 1.     # lower colour limit; cells below it are not drawn
    maxval = 10.    # upper colour limit
    proj = ccrs.epsg('3857')
    cmap = cm.lajolla
    fig, axs = plt.subplots(2, 2, subplot_kw={'projection': proj}, figsize=(29, 20))
    # axs.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1]
    for res, ax in zip(('c90', 'c360', 'c720', 'c1440'), axs.flat):
        _make_fig(res, cmap, ax, legend=True,
                  scal=scal, minval=minval, maxval=maxval)
    fig.tight_layout()
    fig.savefig('NOx_USne.pdf')
    plt.close(fig)


def _load_nox_cells(res, scal):
    """Read one resolution's file and return its grid cells as polygons.

    Args:
        res: Resolution tag inserted into the input file name.
        scal: Factor applied to NO2 + NO (1.0e9 converts mol mol-1 to ppbv).

    Returns:
        GeoDataFrame in EPSG:3857 with one rectangle per grid cell inside the
        module-level lon1/lon2/lat1/lat2 box and a 'NOx' column.
    """
    ifile = f'nc4/GCv13_{res}r4.geosgcm_chm.20180703.nc4'
    # The context manager closes the file once the table has been extracted.
    with xr.open_dataset(ifile) as ds:
        # Half the grid spacing, taken from the last two coordinate values;
        # used below to build a rectangle around each cell centre.
        dx = (ds.lon.values[-1] - ds.lon.values[-2]) / 2.
        dy = (ds.lat.values[-1] - ds.lat.values[-2]) / 2.
        # NOTE(review): lev=72 is presumably the surface layer -- confirm.
        ids = ds.sel(lev=72, lon=slice(lon1, lon2), lat=slice(lat1, lat2))
        df = ids.to_dataframe().reset_index()
    geom = [
        Polygon([[x - dx, y - dy], [x - dx, y + dy], [x + dx, y + dy], [x + dx, y - dy]])
        for x, y in zip(df['lon'], df['lat'])
    ]
    gdf = gpd.GeoDataFrame(df, geometry=geom).set_crs(epsg=4326)
    cells = gdf.to_crs(epsg=3857)
    # convert mol mol-1 to ppbv
    cells['NOx'] = (cells['SpeciesConc_NO2'] + cells['SpeciesConc_NO']) * scal
    return cells


def _make_fig(res, cmap, ax, title=None, pdf=None, alpha=0.6, legend=False,
              scal=1.0e9, minval=1., maxval=10.):
    """Plot the NOx grid cells of one resolution on top of a basemap.

    Args:
        res: Resolution tag; selects the input file and is the default title.
        cmap: Colormap used for the 'NOx' column.
        ax: Axes to draw on.  It must provide add_feature() and set_extent()
            (the caller creates it with a cartopy projection).
        title: Panel title; ``res`` is used when this is None.
        pdf: GeoDataFrame returned by an earlier call, to skip re-reading the
            file.  It must already contain the 'NOx' column.
        alpha: Opacity of the plotted cells.
        legend: Whether to draw the colour bar.
        scal: Factor applied to NO2 + NO when the file is read; ignored when
            ``pdf`` is supplied.
        minval: Lower colour limit; cells with NOx below it are not drawn.
        maxval: Upper colour limit.

    Returns:
        Tuple ``(ax, pdf, im)``: the axes passed in, the full (unfiltered)
        GeoDataFrame, and the return value of ``GeoDataFrame.plot``.
    """
    if pdf is None:
        pdf = _load_nox_cells(res, scal)
    # limit to grid cells above minimum value
    dat = pdf.loc[pdf['NOx'] >= minval]
    ax.add_feature(cartopy.feature.COASTLINE, edgecolor='grey', facecolor='none')
    ax.add_feature(cartopy.feature.BORDERS, edgecolor='grey', facecolor='none')
    ax.add_feature(cartopy.feature.NaturalEarthFeature(
        'cultural', 'admin_1_states_provinces_lines', '50m',
        edgecolor='lightgray', facecolor='none'))
    ax.set_extent([lon1, lon2, lat1, lat2], crs=ccrs.PlateCarree())
    ax.set_title(res if title is None else title)
    im = dat.plot(column='NOx', ax=ax, alpha=alpha, vmin=minval, vmax=maxval,
                  cmap=cmap, legend=legend,
                  legend_kwds={'extend': 'both', 'label': 'NO$_{x}$ [ppbv]'})
    # Other contextily providers (e.g. Esri.WorldTerrain, CartoDB.Positron)
    # can be substituted for the basemap source here.
    ctx.add_basemap(ax, source=ctx.providers.OpenStreetMap.Mapnik,
                    crs=dat.crs.to_string())
    return ax, pdf, im


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with the attributes ``var``, ``var2``, ``name``
        and ``scal``.
    """
    # NOTE(review): main() in this file does not read any of these options.
    p = argparse.ArgumentParser(description='Plot NOx maps at several model resolutions')
    p.add_argument('-v', '--var', type=str, help='variable', default='O3dry')
    p.add_argument('-v2', '--var2', type=str, help='variable 2', default=None)
    p.add_argument('-n', '--name', type=str, help='name', default=None)
    p.add_argument('-s', '--scal', type=float, help='dimension', default=1.0e9)
    return p.parse_args(argv)

# Script entry point: parse the command line, then build the figure.
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
