import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import netCDF4 as nc
from matplotlib.tri import Triangulation, LinearTriInterpolator
import os
from datetime import datetime
import matplotlib.ticker as mticker
from matplotlib.colors import BoundaryNorm, ListedColormap


# Open the NetCDF file (monthly CN51 output for 1980-07, per the filename)
cn51_file = "/discover/nobackup/jkolassa/CN51_global/GEOSldas_CN51_280_global_c2/output/SMAP_EASEv2_M36/cat/ens0000/Y1980/M07/GEOSldas_CN51_280_global_c2.tavg24_1d_lnd_Nt.monthly.198007.nc4"
dataset = nc.Dataset(cn51_file, 'r')

# Print basic information about the dataset (optional)
print("Variables in the dataset:", list(dataset.variables.keys()))

# Read latitude, longitude, and the data variable
lats = dataset.variables['lat'][:]
lons = dataset.variables['lon'][:]
data = dataset.variables['CNTLAI'][:]  # Replace with your variable of interest

# If data is multi-dimensional (e.g. it has a time axis), you might need to
# select a specific slice, for example the first time step:
# data = dataset.variables['CNTLAI'][0, :]


def _as_flat_float(values):
    """Return *values* as a flat float64 ndarray with masked entries set to NaN.

    netCDF4 returns numpy masked arrays. Plain ``np.asarray()`` silently drops
    the mask and exposes the raw fill values as if they were real data, so
    they would survive the later ``np.isfinite`` filtering and distort the
    colour range. Filling with NaN lets that filtering remove them instead.
    """
    return np.ma.filled(np.ma.asarray(values, dtype=float), np.nan).ravel()


lons = _as_flat_float(lons)
lats = _as_flat_float(lats)
data = _as_flat_float(data)


# DEBUGGING: Print array shapes
print(f"Shape of lons array: {lons.shape}")
print(f"Shape of lats array: {lats.shape}")
print(f"Shape of data array: {data.shape}")

# lons, lats and data must describe the same points: one value per lat/lon
# location. The arrays are already flat 1-D at this point, so a mismatch
# cannot be repaired by reshaping here. For example, a data variable with
# several time steps flattens to n_steps * len(lons) values. Stop now with a
# clear message instead of failing later inside Triangulation/scatter.
if not (len(lons) == len(lats) == len(data)):
    raise SystemExit(
        "ERROR: Input arrays have different lengths "
        f"(Lons: {len(lons)}, Lats: {len(lats)}, Data: {len(data)}). "
        "Select a single slice of the data variable (e.g. one time step) "
        "so that it has exactly one value per lat/lon point."
    )

print(f"All arrays now have the same length: {len(lons)}")

# Colour range. Use only the finite values: np.min/np.max return NaN as soon
# as a single NaN is present, which would throw away the real data range.
# Override cbar_min / cbar_max here if you want a fixed range.
finite_data = data[np.isfinite(data)]
if finite_data.size:
    cbar_min = float(finite_data.min())
    cbar_max = float(finite_data.max())
else:
    cbar_min, cbar_max = 0.0, 20.0  # fallback when there is no valid data
if cbar_max <= cbar_min:
    # A constant field would produce identical bin edges; widen the range.
    cbar_max = cbar_min + 1.0


# OPTION A: Create discrete colorbar with evenly spaced intervals
# Define the number of discrete colors/levels
num_colors = 10

# Create discrete bin edges
levels = np.linspace(cbar_min, cbar_max, num_colors + 1)

# PRINT THE LEVELS TO INSPECT VALUES
print("\nLevel boundaries:")
for i, level in enumerate(levels):
    print(f"  Level {i}: {level:.2f}")


# BoundaryNorm maps data values to discrete colors. Its `ncolors` argument is
# the number of colors in the *colormap*, not the number of bins. Passing
# num_colors would squeeze every bin onto the first 10 entries of the
# colormap (almost identical shades). Using the colormap size spreads the
# bins evenly across it. 'viridis' is the cmap the scatter plot uses with
# this norm.
norm = BoundaryNorm(levels, ncolors=plt.get_cmap('viridis').N)

# Create a figure with a single PlateCarree (plain lat/lon) map axis
plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.PlateCarree())

# Map decorations: coastlines, dotted country borders, semi-transparent land,
# then labelled lat/lon gridlines
ax.coastlines()
map_features = [
    (cfeature.BORDERS, dict(linestyle=':')),
    (cfeature.LAND, dict(alpha=0.5)),
]
for feature, feature_kwargs in map_features:
    ax.add_feature(feature, **feature_kwargs)
ax.gridlines(draw_labels=True)

# Option 1 (disabled): direct scatter plot of the irregular points -- good for
# visualizing the actual point locations. Two variants:
#
# continuous colorbar with defined min/max:
# scatter = ax.scatter(lons, lats, c=data, cmap='YlGn',
#                      transform=ccrs.PlateCarree(), s=5, alpha=0.8,
#                      vmin=cbar_min, vmax=cbar_max)
#
# discrete colorbar (uses the BoundaryNorm `norm`):
# scatter = ax.scatter(lons, lats, c=data, cmap='YlGn',
#                      transform=ccrs.PlateCarree(), s=5, alpha=0.8,
#                      norm=norm)


# Option 2: Triangulate the irregular points, interpolate onto a regular grid
# and plot a filled contour. If triangulation/interpolation fails, fall back
# to a scatter plot of the points.

# Filter out any invalid values before triangulation
valid_mask = np.isfinite(lons) & np.isfinite(lats) & np.isfinite(data)
if not np.all(valid_mask):
    print(f"Filtering out {np.sum(~valid_mask)} invalid points")
    lons = lons[valid_mask]
    lats = lats[valid_mask]
    data = data[valid_mask]

# Triangles with any edge longer than this (in degrees) are masked out.
# Delaunay triangulation covers the whole convex hull of the points, so
# without masking it bridges large gaps between points (e.g. oceans when the
# output is land-only). Interpolating over those triangles paints fictitious
# values there. Increase this if real data regions show holes.
max_edge_deg = 5.0

# TRIANGULATION APPROACH WITH ERROR HANDLING
try:
    # Create a triangulation of the irregular points
    print("Creating triangulation...")
    triang = Triangulation(lons, lats)

    # Mask long, gap-spanning triangles (see max_edge_deg above)
    tri_lon = lons[triang.triangles]
    tri_lat = lats[triang.triangles]
    edge_len = np.hypot(tri_lon - np.roll(tri_lon, 1, axis=1),
                        tri_lat - np.roll(tri_lat, 1, axis=1))
    triang.set_mask(edge_len.max(axis=1) > max_edge_deg)

    # Define the plotting grid (regular)
    grid_size = 500
    xi = np.linspace(np.min(lons), np.max(lons), grid_size)
    yi = np.linspace(np.min(lats), np.max(lats), grid_size)
    xi_grid, yi_grid = np.meshgrid(xi, yi)

    # Interpolate data onto the regular grid; grid points outside the
    # unmasked triangles come back masked and are left unfilled.
    print("Interpolating data...")
    interp = LinearTriInterpolator(triang, data)
    zi_grid = interp(xi_grid, yi_grid)

    # Plot as a filled contour using the same bin edges (`levels`) as the
    # discrete colorbar / scatter fallback, so both paths share one scale.
    print("Creating contour plot...")
    contour = ax.contourf(xi, yi, zi_grid,
                          transform=ccrs.PlateCarree(),
                          cmap='viridis', levels=levels)
    plot_obj = contour  # For colorbar reference

    print("Triangulation and interpolation successful!")

except Exception as e:  # deliberate best-effort: any failure -> scatter plot
    print(f"Error in triangulation: {e}")
    print("Falling back to scatter plot method...")

    # Fall back to scatter plot if triangulation fails
    scatter = ax.scatter(lons, lats, c=data, cmap='viridis',
                         transform=ccrs.PlateCarree(), s=5, alpha=0.8,
                         norm=norm)
    plot_obj = scatter  # For colorbar reference

# Alternative for data on a regular lat-lon grid (not used here):
# mesh = plt.pcolormesh(lons, lats, data, transform=ccrs.PlateCarree(), cmap='viridis')

# Add colorbar and labels.
# Put a tick at every bin edge in `levels` so the labels line up with the
# discrete colour bins, whichever plotting path produced plot_obj. A fixed
# MultipleLocator(5) gives at most one or two ticks whenever the data range
# is below ~10, and it ignores the bins entirely.
cbar = plt.colorbar(plot_obj, shrink=0.6, ticks=levels, format='%.2f')

# Alternative: discrete colorbar for the Option 1 scatter plot
# cbar = plt.colorbar(scatter, shrink=0.6,
#                     ticks=levels,             # Show ticks at bin edges
#                     spacing='proportional',   # Bin width proportional to value range
#                     boundaries=levels)        # Set bin boundaries

cbar.set_label('LAI [-]')  # Change to match your variable
plt.title('LAI 1980-07')  # Change to match your variable

# Close the dataset (the arrays were already read into memory above)
dataset.close()

# Tighten the layout before saving
plt.tight_layout()

# Create output directory if it doesn't exist
output_dir = '/discover/nobackup/jkolassa/CN51_plots/python_test_plots/'
os.makedirs(output_dir, exist_ok=True)

# Build the output filename from the variable name and the plotted month.
# os.path.join avoids the double slash that plain string concatenation
# produced, since output_dir already ends with '/'.
variable_name = 'CNTLAI'  # Change to match your variable
save_filename = os.path.join(output_dir, f"{variable_name}_198007.png")

# Save the figure
plt.savefig(save_filename, dpi=300, bbox_inches='tight')
print(f"Figure saved as: {save_filename}")

# Show the plot
plt.show()
