import numpy as np
import os
import glob
import pickle
import xarray as xr
import time
import xesmf as xe

 
def get_random_files(pth, nts):
  """Return ``nts`` file paths drawn at random (with replacement) from the
  files matching the glob pattern ``pth``.

  pth -- glob pattern to search
  nts -- number of paths to return (duplicates possible)

  Raises ValueError if the pattern matches no files.
  """
  lall = glob.glob(pth)
  if not lall:
      raise ValueError("no files match pattern: " + pth)
  # randint's high bound is exclusive, so len(lall) makes every file
  # selectable (len(lall)-1 could never pick the last file and crashed
  # with ValueError when exactly one file matched)
  f_ind = np.random.randint(0, len(lall), nts)
  print(f_ind)
  return [lall[i] for i in f_ind]

#'/css/g5nr/Ganymed/7km/c1440_NR/DATA/0.0625_deg/inst/inst30mn_3d_W_Nv/Y2006/M05/D07/c1440_NR.inst30mn_3d_W_Nv.20060507_2230z.nc4'
#'/css/g5nr/Ganymed/7km/c1440_NR/DATA/0.5000_deg/tavg/tavg01hr_3d_T_Cv/Y2006/M05/D07/c1440_NR.tavg01hr_3d_T_Cv.20060507_0830z.nc4'

def switch_var(V, fils):
  """Translate high-res W output paths into the matching half-degree
  tavg paths for variable ``V``.

  Returns the translated list wrapped in an outer single-element list
  (callers unwrap with ``[0]``).
  """
  substitutions = (
      ('0.0625_deg/inst/inst30mn_3d_W_Nv',
       '0.5000_deg/tavg/tavg01hr_3d_' + V + '_Cv'),
      ('inst30mn_3d_W_Nv',
       'tavg01hr_3d_' + V + '_Cv'),
  )
  flsv = list(fils)
  for old, new in substitutions:
      flsv = [p.replace(old, new) for p in flsv]
  return [flsv]

def dens(ds):
    """Overwrite PL in-place with air density PL / (287 * T) and return ds.

    287.0 is the dry-air gas constant R_d [J kg-1 K-1]; after this the
    caller renames PL to AIRD.
    """
    rho = ds.PL / 287.0 / ds.T
    ds.PL.data = rho
    return ds

def QCT(ds):
    """Overwrite QL in-place with total condensate QL + QI and return ds."""
    total_condensate = ds.QL + ds.QI
    ds.QL.data = total_condensate
    return ds
     
class get_dts():
    """Loads a set of randomly chosen NR time steps and prepares matched
    input/output datasets.

    Target (``dat_out``): high-resolution (0.0625 deg) W files, coarsened
    8x8 to ~half degree with the standard deviation as the lumping
    function, then regridded onto the input grid ("Wstd").
    Inputs (``dat_in``): the half-degree tavg datasets for each variable in
    ``vars_in``, with PL converted in-place to air density ("AIRD").
    """

    def __init__(self, ndts=1, nam="def", batch_size=32000):
        """Create a handler for ``ndts`` random time-step files.

        ndts       -- number of random time steps (files) to draw
        nam        -- name tag for this dataset object
        batch_size -- default batch size stored on the object
        """
        yr = "Y2006/"
        mo = "M*/"
        dy = "D*/*30z*"  # e.g. _[01]230z* to narrow to specific half hours

        self.batch_size = batch_size
        self.lev1 = 1           # first model level to keep
        self.lev2 = 72          # last model level to keep
        self.vars_in = ['T', 'PL', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']
        self.feats = len(self.vars_in) - 1
        # one whole horizontal/vertical slab per time step -- needed so we
        # can regrid
        self.chk = {"lat": -1, "lon": -1, "lev": -1, "time": 1}
        self.in_dir = "/css/g5nr/Ganymed/7km/c1440_NR/DATA/0.5000_deg/tavg/tavg01hr_3d_"
        self.out_dir = "/css/g5nr/Ganymed/7km/c1440_NR/DATA/0.0625_deg/inst/inst30mn_3d_W_Nv/"
        self.path_out = self.out_dir + yr + mo + dy
        self.create_regridder = True
        self.name = nam
        self.printvar = 1
        # draw ndts random output files
        self.fls = get_random_files(self.path_out, ndts)

    def get_fls_batch(self, dt_batch_size):
        """Yield successive slices of ``self.fls`` of length ``dt_batch_size``
        (the final slice may be shorter)."""
        for i in range(0, len(self.fls), dt_batch_size):
            yield self.fls[i:i + dt_batch_size]
            print('====get fls batch', self.fls[i:i + dt_batch_size])

    def get_data(self, this_fls):
        """Load output (Wstd) and input (feature) datasets for ``this_fls``.

        Sets ``self.dat_out``, ``self.dat_in``, ``self.levs``,
        ``self.n_features_in_`` and ``self.feats``.
        """
        dat_out = xr.open_mfdataset(this_fls, chunks=self.chk, parallel=True,
                                    engine='h5netcdf').sel(lev=slice(self.lev1, self.lev2))
        dat_out = dat_out.rename({"W": "Wstd"})
        # coarsen to about half degree using standard deviation as the
        # lumping function
        self.dat_out = dat_out.coarsen(lat=8, lon=8, boundary="trim").std()
        self.levs = len(self.dat_out['lev'])

        print(this_fls)
        # loop through features and load the same time steps
        vars_in = self.vars_in
        self.n_features_in_ = len(vars_in) * self.levs
        self.feats = len(vars_in)
        dat_in = []
        m = 0
        for v in vars_in:
            flsv = switch_var(v, this_fls)
            flsv = flsv[0]  # switch_var wraps its result in an outer list
            dat = xr.open_mfdataset(flsv, chunks=self.chk, parallel=True,
                                    engine='h5netcdf').sel(lev=slice(self.lev1, self.lev2))
            if m == 0:
                dat_in = dat
                m = 1
            else:
                dat_in = xr.merge([dat_in, dat])
        # NOTE(review): this sits outside the loop, so only the dataset
        # opened for the last variable is closed; earlier handles stay open
        dat.close()
        # calculate density lazily per dask block: PL becomes AIRD
        dat_in = dat_in.unify_chunks()
        da = xr.map_blocks(dens, dat_in, template=dat_in)
        dat_in = da.rename({"PL": "AIRD"})

        self.dat_in = dat_in[['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']]
        dat_in.close()

    def regrid_out(self, create_regridder=True):
        """Regrid ``self.dat_out`` bilinearly onto the ``self.dat_in`` grid.

        NOTE: the ``create_regridder`` argument is unused (kept for
        backward compatibility); the instance flag
        ``self.create_regridder`` controls (re)creation instead.
        """
        if self.create_regridder:
            # make sure they are exactly the same grid
            self.regridder = xe.Regridder(self.dat_out, self.dat_in,
                                          'bilinear', periodic=True)
            # this is done because there is a bug in xesmf when two
            # consecutive regridders are created on the fly
            self.create_regridder = False
        self.dat_out = self.regridder(self.dat_out)

    def get_Xy(self, make_y_array=True, batch_size=512, test=False, x_reshape=False):
        """Return (Xall, yall): the feature dataset and the regridded target.

        make_y_array / x_reshape -- currently unused, kept for
        backward compatibility.
        test -- when True, (re)load the data from ``self.fls`` first.
        """
        self.batch_size = batch_size
        if test:
            self.get_data(self.fls)

        self.regrid_out()

        # note: indexing and selection in xarray is full of bugs, so this
        # is the only way to do this
        Xall = self.dat_in
        levs = Xall.coords['lev'].values
        nlev = len(levs)

        yall = self.dat_out

        print('Xall=======', Xall)
        print('yall=======', yall)

        return Xall, yall

#=========================================
#=========================================
#=========================================
if __name__ == '__main__':

    # training configuration
    bsz = 512                    # actual batch size
    n_train_steps = 10           # random time steps to sample
    step_batch = n_train_steps   # redundant: one batch holds all steps

    train_data = get_dts(ndts=n_train_steps, nam='train_data', batch_size=bsz)

    nlevs = train_data.lev2 - train_data.lev1 + 1

    # pull the first (and only) batch of files and build X/y
    file_batches = train_data.get_fls_batch(step_batch)
    train_data.get_data(this_fls=next(file_batches))
    X_train, y_train = train_data.get_Xy(batch_size=bsz)

    # write inputs, then append the target to the same file
    X_train.to_netcdf("RIPS_Wnet_data.nc", "w")
    y_train.to_netcdf("RIPS_Wnet_data.nc", "a")

    print('\n\n\n\n\n\n')

    exit()
