
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import xarray as xr
import time
from dask.distributed import Client, LocalCluster
import dask as da


def _configure_plot_style():
    """Select the file-only 'agg' backend and set the shared font sizes."""
    # 'agg' renders straight to image files, so no display is required.
    plt.switch_backend('agg')

    smaller_size = 6
    small_size = 7
    bigger_size = 9

    plt.rc('font', size=small_size)          # default text size
    plt.rc('axes', titlesize=small_size)     # axes title
    plt.rc('axes', labelsize=small_size)     # x and y axis labels
    plt.rc('xtick', labelsize=smaller_size)  # x tick labels
    plt.rc('ytick', labelsize=smaller_size)  # y tick labels
    plt.rc('legend', fontsize=smaller_size)  # legend entries
    plt.rc('figure', titlesize=bigger_size)  # figure title


def _plot_site(site, dir_in):
    """Read the MERRA files of one site and save its mass-diagnostics figure.

    The figure has one row per species and three columns:
      1. species column mean vs. total-mass column mean,
      2. species column mean vs. SO4+OG column mean,
      3. time-mean vertical profiles of total, SO4+OG and species mass.

    The result is written to ``merra_mass_<site>.png`` in the current
    working directory.

    Args:
        site: Site code used in the input file pattern and output file name.
        dir_in: Directory holding the ``Merra_input_asmi_72lv_<site>*.nc``
            files.

    Raises:
        Whatever ``xr.open_mfdataset`` raises when no file matches the
        pattern, and ``KeyError`` if an expected variable is missing.
    """
    mass_vars = ['SO4', 'SS', 'OG', 'BC', 'DU']
    sorg_vars = ['SO4', 'OG']
    merra_pth = os.path.join(dir_in, "Merra_input_asmi_72lv_" + site + "*.nc")

    # Open the files once and close them again when the site is done; the
    # data is loaded lazily, so all plotting has to happen inside the `with`.
    with xr.open_mfdataset(merra_pth, parallel=True) as merra:
        feats_all = merra[mass_vars]

        # Total mass over all species, and the sulphate + organics subset.
        total_mass = sum(feats_all[name] for name in mass_vars)
        sorg_mass = sum(feats_all[name] for name in sorg_vars)
        print('SORG_MASS', sorg_mass)

        # Model levels are numbered 1..N, with N taken from the data.
        levs = list(range(1, total_mass.sizes['lev'] + 1))
        print('levs', levs)

        # These reductions do not depend on the species, so evaluate them
        # once here instead of once per subplot row.
        total_col = total_mass.mean('lev').compute()
        sorg_col = sorg_mass.mean('lev').compute()
        total_prof = total_mass.mean('time').compute()
        sorg_prof = sorg_mass.mean('time').compute()

        n_rows = len(mass_vars)
        n_cols = 3
        fig = plt.figure(figsize=(9, 6.5))
        try:
            for row, elem in enumerate(mass_vars):
                print('====', site, row, elem, '====')
                is_last_row = row == n_rows - 1  # only this row gets x labels
                first_pos = n_cols * row + 1

                species = feats_all[elem]
                species_col = species.mean('lev').compute()
                species_prof = species.mean('time').compute()

                # Column mean of total mass vs. species mass.
                ax = fig.add_subplot(n_rows, n_cols, first_pos)
                ax.plot(total_col, species_col, 'b.')
                ax.set_ylabel(elem + '\n Mass kg kg-1')
                if is_last_row:
                    ax.set_xlabel('Total Mass kg kg-1 \n column mean')

                # Column mean of sulphates + organics vs. species mass.
                ax = fig.add_subplot(n_rows, n_cols, first_pos + 1)
                ax.plot(sorg_col, species_col, 'r.')
                ax.set_ylabel('Mass kg kg-1')
                if is_last_row:
                    ax.set_xlabel('SO4+OG Mass kg kg-1 \n column mean')

                # Vertical profiles (mean over time) of the three masses.
                ax = fig.add_subplot(n_rows, n_cols, first_pos + 2)
                ax.plot(total_prof, levs, 'b-', label='Total')
                ax.plot(sorg_prof, levs, 'r-', label='SO4 + OG')
                ax.plot(species_prof, levs, 'k-', label=elem)
                ax.set_ylabel('Model level')
                if is_last_row:
                    ax.set_xlabel('Mass kg kg-1 \n temporal mean')
                ax.legend(loc=0)

            fig.suptitle(site, fontsize=10)
            fig.savefig('merra_mass_' + site + '.png')
        finally:
            # Release the figure so figures do not pile up across sites.
            plt.close(fig)


def plot_data(sites=('ASP', 'BIR')):
    """Plot MERRA aerosol mass diagnostics, one PNG per site.

    For every site the matching MERRA input files are read from a fixed
    input directory and a figure ``merra_mass_<site>.png`` is written to the
    current working directory (see ``_plot_site`` for the layout).

    Args:
        sites: Iterable of site codes. A single string is treated as one
            site code rather than being iterated character by character.

    Returns:
        None. The function is used for its side effects (printed progress
        and saved figures). An empty ``sites`` does nothing.
    """
    # Read in input from MERRA
    dir_in = "/gpfsm/dnb31/khbreen/ML_models/MAMnet/M2/"

    if isinstance(sites, str):
        sites = [sites]
    sites = list(sites)
    if not sites:
        return

    # Backend and font settings are global, so set them once up front.
    _configure_plot_style()

    for site in sites:
        print(site)
        _plot_site(site, dir_in)


#=========================================
#=========================================
#=========================================
if __name__ == '__main__':

    # Script entry point: draw the MERRA mass plots for every site that has
    # MERRA input available.
    print("****GET DATA****")

    # NOTE(review): chunk_size is not referenced in the visible part of this
    # script -- confirm whether it is still needed.
    chunk_size = 256

    # BOS, JRC, SSL and WAL are left out because there is no MERRA file for
    # them.
    sites = [
        'PDD', 'MPZ', 'ASP', 'BEO', 'BIR',
        'CBW', 'CMN', 'FKL', 'HPB', 'HWL',
        'JFJ', 'KPO', 'MHD', 'OBK', 'PAL',
        'PLA', 'SMR', 'VHL', 'ZEP', 'ZSF',
    ]
    plot_data(sites=sites)
