import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import netCDF4 as nc
from matplotlib.tri import Triangulation, LinearTriInterpolator
import os
from datetime import datetime
import matplotlib.ticker as mticker
from matplotlib.colors import BoundaryNorm, ListedColormap


# Input: GEOSldas CN5.1 output file for 1980-07 (per the path below).
# Edit the path to plot a different file.
cn51_file = "/discover/nobackup/jkolassa/CN51_global/GEOSldas_CN51_280_global_c2/output/SMAP_EASEv2_M36/cat/ens0000/Y1980/M07/GEOSldas_CN51_280_global_c2.tavg24_1d_lnd_Nt.monthly.198007.nc4"
dataset = nc.Dataset(cn51_file, mode='r')

# Optional: list the variable names stored in the file, to check the names
# used when reading variables below.
print("Variables in the dataset:", list(dataset.variables))

# Read the point coordinates and the plotted variable. To plot another
# variable, change 'CNTLAI' here and the label/title/filename further down.
lats = dataset.variables['lat'][:]
lons = dataset.variables['lon'][:]
data = dataset.variables['CNTLAI'][:]  # Replace with your variable of interest

# If the variable has extra dimensions (e.g. time), select a single slice
# instead, e.g. data = dataset.variables['CNTLAI'][0, ...] when time is the
# leading dimension (check the variable's dimensions first).


def _to_flat_float(values):
    """Return *values* as a 1-D float ndarray with masked entries set to NaN.

    netCDF4 auto-masks values equal to the variable's fill value and returns
    a masked array; a plain np.asarray() would drop that mask and expose the
    raw fill values (typically huge numbers) as if they were data.
    """
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan).ravel()


lons = _to_flat_float(lons)
lats = _to_flat_float(lats)
data = _to_flat_float(data)

# Every point needs exactly one lon, one lat and one value. Fail here with a
# clear message instead of deep inside the triangulation/scatter code.
if not (lons.size == lats.size == data.size):
    raise ValueError(
        f"Size mismatch: lon={lons.size}, lat={lats.size}, data={data.size}; "
        "select a single slice (e.g. one time step) of the data variable"
    )

# Build a regular lat/lon grid from the distinct coordinate values
# (np.unique returns them sorted ascending). Nothing later in this script
# reads these grids; they are only inspected via the debug prints below.
unique_lats = np.unique(lats)
unique_lons = np.unique(lons)
# np.meshgrid(x, y) returns arrays of shape (len(y), len(x)), so both grids
# here have shape (unique_lons.size, unique_lats.size).
lats_grid, lons_grid = np.meshgrid(unique_lats, unique_lons)


# DEBUGGING: Print array shapes
print(f"Shape of unique_lons array: {unique_lons.shape}")
print(f"Shape of unique_lats array: {unique_lats.shape}")
print(f"Shape of lats_grid array: {lats_grid.shape}")
print(f"Shape of lons_grid array: {lons_grid.shape}")


# map data points to regular grid: show where the first few points fall.
# Because the unique arrays are sorted, searchsorted gives each coordinate's
# position in them. min() guards against inputs with fewer than 5 points.
for i in range(min(5, lats.size)):
    lat_index = np.searchsorted(unique_lats, lats[i])
    lon_index = np.searchsorted(unique_lons, lons[i])
    # Format the i-th element, not the whole array: formatting an ndarray
    # with ':.2f' raises TypeError.
    print(f"  lats {i}: {lats[i]:.2f} (unique_lats index {lat_index})")
    print(f"  lons {i}: {lons[i]:.2f} (unique_lons index {lon_index})")

# Colour-scale limits taken from the data. nanmin/nanmax skip NaN entries, so a
# few missing values do not force the hard-coded fallbacks. The fallbacks are
# used only when the data minimum/maximum is not finite (e.g. all values NaN).
data_min = np.nanmin(data) if data.size else np.nan
data_max = np.nanmax(data) if data.size else np.nan
cbar_min = data_min if np.isfinite(data_min) else 0  # Set your minimum value here
cbar_max = data_max if np.isfinite(data_max) else 20  # Set your maximum value here
# A constant field would give identical level boundaries, which BoundaryNorm
# rejects (boundaries must be increasing).
if cbar_max <= cbar_min:
    cbar_max = cbar_min + 1


# OPTION A: Create discrete colorbar with evenly spaced intervals
# Define the number of discrete colors/levels
num_colors = 10

# Create discrete bin edges
levels = np.linspace(cbar_min, cbar_max, num_colors + 1)

# PRINT THE LEVELS TO INSPECT VALUES
print("\nLevel boundaries:")
for i, level in enumerate(levels):
    print(f"  Level {i}: {level:.2f}")


# Create a BoundaryNorm to map data values to discrete colors.
# ncolors must be the size of the colormap the norm is paired with ('viridis'
# in the scatter fallback below). BoundaryNorm returns integer indices
# 0..ncolors-1, and a colormap indexes its lookup table directly with integers.
# Passing ncolors=num_colors would therefore use only the first 10 entries of
# the colormap (nearly identical colours) instead of spreading the 10 bins
# across its full range.
norm = BoundaryNorm(levels, ncolors=plt.get_cmap('viridis').N)

# Single map axis in a Plate Carree (plain lon/lat) projection.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

# Background features (drawn in this order), then labelled gridlines.
ax.coastlines()
for feature, feature_kwargs in (
    (cfeature.BORDERS, {'linestyle': ':'}),
    (cfeature.LAND, {'alpha': 0.5}),
):
    ax.add_feature(feature, **feature_kwargs)
ax.gridlines(draw_labels=True)

# Option 1 (disabled): direct scatter plot of the irregular points - good for
# visualizing actual measurement locations. Uncomment one variant to use it.
#
# Continuous colorbar with defined min/max:
# scatter = ax.scatter(lons, lats, c=data, cmap='YlGn',
#                      transform=ccrs.PlateCarree(), s=5, alpha=0.8,
#                      vmin=cbar_min, vmax=cbar_max)
#
# Discrete colorbar:
# scatter = ax.scatter(lons, lats, c=data, cmap='YlGn',
#                      transform=ccrs.PlateCarree(), s=5, alpha=0.8,
#                      norm=norm)


# Option 2: interpolate the scattered points onto a regular lon/lat grid via a
# Delaunay triangulation and draw filled contours. If any step fails, fall
# back to a scatter plot of the raw points. Either way `plot_obj` ends up
# holding the artist the colorbar is built from.
# NOTE(review): linear interpolation fills every triangle, including large ones
# that bridge gaps between distant points; check whether such gaps should be
# masked for this data set.

# TRIANGULATION APPROACH WITH ERROR HANDLING
try:
    # Points with a NaN/inf coordinate or value cannot be triangulated or
    # interpolated, so drop them first.
    finite = np.isfinite(lons) & np.isfinite(lats) & np.isfinite(data)
    n_invalid = np.sum(~finite)
    if n_invalid:
        print(f"Filtering out {n_invalid} invalid points")
        lons, lats, data = lons[finite], lats[finite], data[finite]

    print("Creating triangulation...")
    triang = Triangulation(lons, lats)

    # Regular target grid spanning the bounding box of the points.
    grid_size = 500
    lon_axis = np.linspace(lons.min(), lons.max(), grid_size)
    lat_axis = np.linspace(lats.min(), lats.max(), grid_size)
    lon_mesh, lat_mesh = np.meshgrid(lon_axis, lat_axis)

    print("Interpolating data...")
    zi_grid = LinearTriInterpolator(triang, data)(lon_mesh, lat_mesh)

    print("Creating contour plot...")
    contour = ax.contourf(
        lon_axis, lat_axis, zi_grid,
        transform=ccrs.PlateCarree(),
        cmap='viridis',
        levels=15,
    )
    plot_obj = contour  # For colorbar reference

    print("Triangulation and interpolation successful!")

except Exception as err:
    # Best-effort fallback: report why, then plot the raw points instead.
    print(f"Error in triangulation: {err}")
    print("Falling back to scatter plot method...")

    scatter = ax.scatter(
        lons, lats, c=data, cmap='viridis',
        transform=ccrs.PlateCarree(), s=5, alpha=0.8,
        norm=norm,
    )
    plot_obj = scatter  # For colorbar reference

# Alternative for data on a regular lat-lon grid (kept for reference; the
# flattened point data read above would first need reshaping to 2-D):
# mesh = plt.pcolormesh(lons, lats, data, transform=ccrs.PlateCarree(), cmap='viridis')

# Continuous colorbar for whichever artist was drawn above, ticks every 5 units.
cbar = plt.colorbar(plot_obj, shrink=0.6, ticks=mticker.MultipleLocator(5))

# Alternative discrete colorbar with edges at `levels` (requires the discrete
# `scatter` option above to be enabled):
# cbar = plt.colorbar(scatter, shrink=0.6,
#                     ticks=levels,            # ticks at bin edges
#                     spacing='proportional',  # bin length proportional to its value range
#                     boundaries=levels)       # bin boundaries
cbar.set_label('LAI [-]')  # Change to match your variable
plt.title('LAI 1980-07')  # Change to match your variable

# Close the dataset. The variables were sliced with [:] above, so their
# values are already in memory and the file is no longer needed.
dataset.close()

# Tighten the layout before saving (the plot is shown at the very end).
plt.tight_layout()

# Create output directory if it doesn't exist
output_dir = '/discover/nobackup/jkolassa/CN51_plots/python_test_plots/'
os.makedirs(output_dir, exist_ok=True)

# Output filename: <variable>_<YYYYMM>.png. It carries no run timestamp, so
# re-running the script overwrites the previous figure. os.path.join avoids
# the double slash that string concatenation produced with the trailing '/'.
variable_name = 'CNTLAI'  # Change to match your variable
save_filename = os.path.join(output_dir, f"{variable_name}_198007.png")

# Save the figure
plt.savefig(save_filename, dpi=300, bbox_inches='tight')
print(f"Figure saved as: {save_filename}")

# Show the plot
plt.show()
