#!/discover/nobackup/cakelle2/conda/envs/gcpy_env/bin/python
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cf
import xarray as xr
import matplotlib.colors as colors
import numpy as np

def _plot_segments(xlon, ylat, proj, max_jump=10.0, **kwargs):
    """Plot a polyline, breaking it wherever two consecutive points jump.

    A break is inserted between points k and k+1 whenever their longitude or
    latitude differs (in absolute value) by more than ``max_jump``. Pieces with
    fewer than two points are skipped because they cannot form a segment. If
    there is no jump at all, the points are plotted as one line regardless of
    how many there are.

    Parameters
    ----------
    xlon, ylat : array-like
        1-D longitudes/latitudes of the points, in drawing order.
    proj :
        Passed to ``plt.plot`` as ``transform``.
    max_jump : float
        Largest step (same units as the inputs) still drawn as a connection.
    **kwargs
        Forwarded to ``plt.plot``.
    """
    xlon = np.asarray(xlon)
    ylat = np.asarray(ylat)
    # Element k of np.diff is the step from point k to point k+1; np.abs is
    # applied before comparing so that large negative steps count as jumps.
    breaks = np.flatnonzero(
        (np.abs(np.diff(xlon)) > max_jump) | (np.abs(np.diff(ylat)) > max_jump)
    )
    if breaks.size == 0:
        _ = plt.plot(xlon, ylat, transform=proj, **kwargs)
        return
    # Each piece runs from just after the previous break up to AND including
    # point k of the next break (the slice stop is k + 1).
    starts = np.concatenate(([0], breaks + 1))
    stops = np.concatenate((breaks + 1, [xlon.size]))
    for start, stop in zip(starts, stops):
        if stop - start >= 2:
            _ = plt.plot(xlon[start:stop], ylat[start:stop],
                         transform=proj, **kwargs)


def _add_line(xlon, ylat, proj, max_jump=10.0, **kwargs):
    """Plot one grid line given by its point longitudes and latitudes.

    Longitudes above 180 are shifted down by 360. If the shifted line has both
    positive and negative longitudes, its negative-longitude points and its
    non-negative-longitude points are drawn as two separate lines, so nothing
    is drawn across the dateline. Each of those two halves is additionally
    broken wherever consecutive points differ by more than ``max_jump``
    degrees in longitude or latitude. Selecting points by sign can place
    non-neighbouring points next to each other, and this break keeps them
    from being joined.

    Parameters
    ----------
    xlon, ylat : xarray.DataArray or array-like
        1-D longitudes/latitudes (degrees) of the line's points, in order.
        They are read only; the caller's data is never modified.
    proj :
        Passed to ``plt.plot`` as ``transform`` (callers in this file pass
        ``ccrs.PlateCarree()``).
    max_jump : float, default 10.0
        Step size in degrees above which a split half is broken into pieces.
    **kwargs
        Forwarded to ``plt.plot`` (color, linewidth, zorder, ...).

    Notes
    -----
    The split is triggered by any sign change, so a line crossing the prime
    meridian (0 deg) is split too, leaving a small gap there.
    """
    xlon = np.asarray(xlon)
    ylat = np.asarray(ylat)
    if xlon.size == 0:
        return
    # np.where builds a new array, so the caller's coordinates stay intact
    # (the previous in-place assignment could write through a view).
    xlon = np.where(xlon > 180.0, xlon - 360.0, xlon)
    if xlon.max() > 0.0 and xlon.min() < 0.0:
        west = xlon < 0.0
        _plot_segments(xlon[west], ylat[west], proj, max_jump=max_jump, **kwargs)
        _plot_segments(xlon[~west], ylat[~west], proj, max_jump=max_jump, **kwargs)
    else:
        _ = plt.plot(xlon, ylat, transform=proj, **kwargs)
    return


# Load model output. Other files this script has been pointed at:
#   /discover/nobackup/projects/gmao/geos_cf_dev/cakelle2/v11/c270_test/holding/gcc_dev_cs/c270_test.gcc_dev_cs.20200415_1500z.nc4
#   /discover/nobackup/projects/gmao/geos_cf_dev/cakelle2/v11/c540_test/holding/geoscf_jedi/c540_test.geoscf_jedi.20200803_0300z.nc4
#   geoscf_jedi_2021_nostretch/c540_test.geoscf_jedi.20210810_0300z.nc4
infile = '/discover/nobackup/projects/gmao/geos_cf_dev/cakelle2/JEDI/runs/c540/geoscf_jedi_2021/c540_test.geoscf_jedi.20210810_1500z.nc4'
ds = xr.open_dataset(infile)

# Field to plot: first time step, level index 71, scaled by 1e9
# (presumably mol/mol -> ppbv; confirm against the file's units).
# Other variable names used before: 'O3', 'SpeciesConc_O3'.
varname = 'NO2'
scale = 1.0e9
data = ds[varname].isel(time=0, lev=71).squeeze() * scale

# Colour map and limits for the field above.
# Settings used for O3: cmap='Purples', vmin=2.0, vmax=80.0
cmap, vmin, vmax = 'pink_r', 0.01, 10.0



# Setup axes: near-side perspective view centred at 98.35 W, 39.5 N.
map_proj = ccrs.NearsidePerspective(central_longitude=-98.35, central_latitude=39.5)
# map_proj = ccrs.PlateCarree()  # alternative flat global view
ax = plt.axes(projection=map_proj)
ax.set_global()
ax.add_feature(cf.STATES, linewidth=0.5)
ax.coastlines(zorder=10)
ax.add_feature(cf.BORDERS, zorder=20)

# Every data artist below is given a PlateCarree source transform, i.e. the
# corner coordinates are treated as plain lon/lat.
proj = ccrs.PlateCarree()

PLOT_DATA = False       # True: also draw ``data`` on a log colour scale
MAIN_FACE = 5           # face whose cells are all marked inside the domain
EXT_WIDTH = 50          # width (cells) of the extension strip on neighbour faces
GRID_COLOR = 'gainsboro'
EDGE_COLOR = 'dimgray'

for face_idx in range(6):
    x = ds.corner_lons.isel(nf=face_idx)
    y = ds.corner_lats.isel(nf=face_idx)
    v = data.isel(nf=face_idx)
    if PLOT_DATA:
        # The LogNorm already carries vmin/vmax; matplotlib rejects a norm
        # passed together with explicit vmin/vmax.
        pcm = plt.pcolormesh(
            x, y, v,
            transform=proj,
            norm=colors.LogNorm(vmin=vmin, vmax=vmax),
            cmap=cmap,
            zorder=1,
        )
    # Domain mask: +0.6 (red end of 'bwr') inside the TEMPO domain, -0.6
    # (blue end) elsewhere. A fresh array is used so ``data`` is not overwritten.
    mask = np.full(v.shape, 0.6 if face_idx == MAIN_FACE else -0.6)
    # Further adjustments to match the TEMPO domain on neighbouring faces.
    if face_idx in (0, 1):      # south / west of main face
        mask[:EXT_WIDTH, :] = 0.6
    elif face_idx in (3, 4):    # north / east of main face
        # ``-EXT_WIDTH:`` covers the last EXT_WIDTH columns including the
        # final one; ``-EXT_WIDTH:-1`` would stop one column short.
        mask[:, -EXT_WIDTH:] = 0.6
    pcm = plt.pcolormesh(x, y, mask, transform=proj, cmap="bwr", vmin=-1, vmax=1)

    # Grid lines: every 10th corner row and column, starting at index 20.
    for i in range(20, x.shape[0], 10):
        _add_line(x[i, :], y[i, :], proj, color=GRID_COLOR, linewidth=0.25, zorder=25)
        _add_line(x[:, i], y[:, i], proj, color=GRID_COLOR, linewidth=0.25, zorder=25)

    # Face outline: last row, first row, first column, last column.
    for xe, ye in ((x[-1, :], y[-1, :]), (x[0, :], y[0, :]),
                   (x[:, 0], y[:, 0]), (x[:, -1], y[:, -1])):
        _add_line(xe, ye, proj, color=EDGE_COLOR, linestyle='-', linewidth=2.0, zorder=30)

# ax.set_extent((-130, -65, 10, 50))
# plt.colorbar(pcm, orientation='horizontal', shrink=0.8, pad=0.02)
# plt.title('Nitrogen Dioxide [ppbv]')
# tight_layout only affects later draws, so it must run before savefig
# for the written image to use the adjusted layout.
plt.tight_layout()
plt.savefig('c540_stretched_grid_TEMPOextension.png', bbox_inches='tight', dpi=200)
plt.close()

