import xarray as xr
import pandas as pd
import numpy as np
import datetime

def _build_target_times(start_date, end_date):
    """Return a DatetimeIndex of the 15th of each month spanning the data range.

    The monthly grid is generated one month past the observed range, then
    trimmed so no target falls more than 30 days outside the original
    observations.
    """
    # 'MS' = Month Start; generate firsts-of-month covering the full span.
    monthly_starts = pd.date_range(
        start=start_date.replace(day=1),
        end=end_date.replace(day=1) + pd.DateOffset(months=1),
        freq='MS'
    )
    # Shift each month start to the 15th.
    targets = monthly_starts + pd.Timedelta(days=14)
    # Drop targets that stray too far beyond the observed range.
    return targets[
        (targets >= start_date - pd.Timedelta(days=30)) &
        (targets <= end_date + pd.Timedelta(days=30))
    ]


def _interp_with_gap_mask(ds, target_times, tolerance):
    """Linearly interpolate *ds* onto *target_times*, masking large gaps.

    A value at a target time is kept only when a real observation exists
    within *tolerance* of it; otherwise it is NaN, so interpolation never
    bridges long runs of missing months.
    """
    # Linear interpolation computes values at each 15th from its neighbours.
    ds_interp = ds.interp(time=target_times, method='linear')
    # Nearest-neighbour reindex with a tolerance returns NaN wherever no
    # original observation lies within `tolerance` of the target. Use that
    # as a validity mask so gaps (e.g. the 2017/2018 missing months) stay NaN.
    ds_mask = ds.reindex(time=target_times, method='nearest', tolerance=tolerance)
    return ds_interp.where(ds_mask.notnull())


def process_grace_for_grads(input_file, output_file, tolerance='20 days'):
    """Regularize a GRACE NetCDF file's time axis for use in GrADS.

    Opens *input_file*, repairs the non-CF capitalized ``Units`` attribute on
    the time coordinate, interpolates all variables onto the 15th of each
    month (masking months with no real observation within *tolerance*), and
    writes a compressed NetCDF4 file to *output_file*.

    Parameters
    ----------
    input_file : str
        Path to the source GRACE mascon NetCDF file.
    output_file : str
        Path for the regularized NetCDF output.
    tolerance : str or pandas.Timedelta, optional
        Maximum distance between a target 15th and the nearest real
        observation before that month is masked as missing (default: 20 days).

    Returns
    -------
    None. Progress is printed; returns early (without writing) if the time
    axis cannot be CF-decoded.
    """
    print(f"Opening input file: {input_file}")

    # 1. Load WITHOUT decoding so we can first fix the non-standard 'Units'
    # (capitalized) attribute, which prevents xarray from decoding time.
    ds = xr.open_dataset(input_file, decode_times=False)
    try:
        # CF conventions require lowercase 'units' on the time coordinate.
        if 'Units' in ds.time.attrs and 'units' not in ds.time.attrs:
            print("Fixing metadata: Renaming capitalized 'Units' to 'units' for time decoding.")
            ds.time.attrs['units'] = ds.time.attrs['Units']

        # Now that attributes are fixed, explicitly decode CF dates/times.
        try:
            ds = xr.decode_cf(ds)
        except Exception as e:
            print(f"Error decoding time: {e}")
            print("Check if the 'units' string in the NetCDF is valid (e.g., 'days since ...').")
            return

        # Drop time bounds: they become invalid after interpolation and
        # confuse GrADS.
        if 'time_bounds' in ds:
            ds = ds.drop_vars('time_bounds')
        if 'timebound' in ds.dims:
            ds = ds.drop_dims('timebound')
        # Remove the dangling 'bounds' attribute since the bounds var is gone.
        if 'bounds' in ds.time.attrs:
            del ds.time.attrs['bounds']

        # 2. Construct the target time axis (15th of every month).
        start_date = pd.Timestamp(ds.time.values.min())
        end_date = pd.Timestamp(ds.time.values.max())
        print(f"Original Time Range: {start_date} to {end_date}")

        target_times = _build_target_times(start_date, end_date)
        print(f"Target Time Steps generated: {len(target_times)}")

        # 3. Interpolate with gap detection.
        tol = pd.Timedelta(tolerance)
        ds_final = _interp_with_gap_mask(ds, target_times, tol)

        # 4. Metadata for GrADS/CF compliance: ensure GrADS reads time as a
        # linear axis.
        ds_final.time.attrs['standard_name'] = "time"
        ds_final.time.attrs['long_name'] = "Time"
        ds_final.time.attrs['axis'] = "T"

        # Copy (not alias) the global attrs: assigning ds.attrs directly would
        # share the dict, and appending to 'history' below would then mutate
        # the source dataset's attributes as well.
        ds_final.attrs = dict(ds.attrs)
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        history_msg = (f"\n{current_time}: Time axis regularized to the 15th of each month "
                       "using Python xarray. Input 'Units' fixed to 'units'. "
                       f"Interpolated linearly with a {tol.days}-day tolerance mask.")
        if 'history' in ds_final.attrs:
            ds_final.attrs['history'] += history_msg
        else:
            ds_final.attrs['history'] = history_msg

        # 5. Write to NetCDF. Only request encoding for variables actually
        # present: to_netcdf raises ValueError if the encoding dict names a
        # variable missing from the dataset.
        print("Writing output file (this may take a moment)...")
        encoding = {
            'time': {'units': 'days since 2002-01-01 00:00:00', 'calendar': 'gregorian'}
        }
        if 'lwe_thickness' in ds_final:
            # zlib compresses the file; _FillValue makes GrADS see 'undef'
            # correctly in the masked gaps.
            encoding['lwe_thickness'] = {'zlib': True, '_FillValue': -99999.0, 'dtype': 'float32'}

        ds_final.to_netcdf(output_file, format='NETCDF4', encoding=encoding)
        print(f"Success! File saved to: {output_file}")
    finally:
        # Release the underlying file handle even on error or early return.
        ds.close()

if __name__ == "__main__":
    # Input/output file names for this conversion run — edit as needed.
    input_path = 'CSR_GRACE_GRACE-FO_RL0603_Mascons_all-corrections.nc'
    output_path = 'CSR_GRACE_Regridded_15th_For_GrADS.nc'

    try:
        process_grace_for_grads(input_path, output_path)
    except FileNotFoundError:
        # Most common failure: the source NetCDF is not in the working dir.
        print(f"Error: Could not find input file: {input_path}")
    except Exception as err:
        # Boundary handler: report anything else without a traceback.
        print(f"An error occurred: {err}")


