import numpy as np
import matplotlib.pyplot as plt
import math

EPS = 0.622  # Mw/Md: ratio of molecular weights, water vapor / dry air (dimensionless)

def es_goff_gratch_hpa(T):
    """
    Return the saturation vapor pressure over liquid water, in hPa (Goff–Gratch).

    The Goff–Gratch expression gives log10(es) as a function of the ratio
    Tst/T, referenced to the steam point (Tst = 373.15 K, 1013.25 hPa).

    T : temperature in K (scalar; evaluated with ``math.log10``).
    """
    t_steam = 373.15   # steam-point temperature, K
    e_steam = 1013.25  # saturation pressure at the steam point, hPa

    ratio = t_steam / T
    exp_b = 11.344 * (1.0 - T / t_steam)
    exp_c = -3.49149 * (ratio - 1.0)

    # Accumulate log10(es) term by term (same order as the textbook formula).
    log_es = -7.90298 * (ratio - 1.0)
    log_es += 5.02808 * math.log10(ratio)
    log_es -= 1.3816e-7 * (10.0 ** exp_b - 1.0)
    log_es += 8.1328e-3 * (10.0 ** exp_c - 1.0)
    log_es += math.log10(e_steam)
    return 10.0 ** log_es

def desdT_goff_gratch_hpa_perK(T):
    """
    Return d(es)/dT in hPa/K for the Goff–Gratch saturation vapor pressure.

    The derivative is taken analytically. Writing es = 10**f(T), where f is
    the Goff–Gratch log10 expression, gives des/dT = ln(10) * es * df/dT.
    es itself is obtained from ``es_goff_gratch_hpa`` so the value and its
    derivative always come from the same formula.

    T : temperature in K (scalar; evaluated with ``math``).
    """
    Tst = 373.15  # steam-point temperature, K (must match es_goff_gratch_hpa)

    a = Tst / T
    da_dT = -a / T  # d(Tst/T)/dT = -Tst/T**2

    b = 11.344 * (1.0 - T / Tst)
    db_dT = -11.344 / Tst

    c = -3.49149 * (a - 1.0)
    dc_dT = -3.49149 * da_dT

    ln10 = math.log(10.0)

    # f(T) = log10(es). Uses d(10**x) = ln(10) * 10**x dx and
    # d(log10 a) = da / (a ln10); the constant log10(est) term drops out.
    df_dT = (
        -7.90298 * da_dT
        + 5.02808 * (da_dT / (a * ln10))
        - 1.3816e-7 * (ln10 * (10.0 ** b) * db_dT)
        + 8.1328e-3 * (ln10 * (10.0 ** c) * dc_dT)
    )

    es = es_goff_gratch_hpa(T)

    # es = 10**f  =>  des/dT = ln(10) * es * df/dT
    return ln10 * es * df_dT

def dqs_dT_goff_gratch(T, p_pa):
    """
    Return d(qs)/dT in 1/K, for temperature T [K] and pressure p_pa [Pa].

    Saturation specific humidity is qs = eps*es / (p - (1-eps)*es). By the
    chain rule dqs/dT = (dqs/des) * (des/dT), where
    dqs/des = eps*p / (p - (1-eps)*es)**2.

    NOTE(review): no check is made that (1-eps)*es stays below p; near or past
    that point the result is not physically meaningful.
    """
    hpa_to_pa = 100.0
    e_sat = hpa_to_pa * es_goff_gratch_hpa(T)
    de_sat = hpa_to_pa * desdT_goff_gratch_hpa_perK(T)

    denominator = p_pa - (1.0 - EPS) * e_sat
    return EPS * p_pa / (denominator * denominator) * de_sat


# -----------------------------
# Plotting routine for alpha(T)
# -----------------------------
def plot_alpha_goff_gratch(
    T_min=250.0, T_max=320.0, n=500,
    p_pa=101325.0,
    L=2.4665e6,
    Cp=1004.0,
    outfile="alpha_goff_gratch.png",
    show=True
):
    """
    Plot alpha(T) = 1 / (1 + (L/Cp) * dqs/dT) and save the figure.

    dqs/dT comes from ``dqs_dT_goff_gratch``, i.e. from the Goff–Gratch
    saturation vapor pressure, evaluated at the fixed pressure ``p_pa``.

    Parameters
    ----------
    T_min, T_max : float
        Temperature range in K; both endpoints are sampled.
    n : int
        Number of evenly spaced temperatures (``np.linspace``).
    p_pa : float
        Pressure in Pa, held constant along the curve.
    L, Cp : float
        Latent heat and specific heat. Only their ratio is used, so they need
        consistent units (presumably J/kg and J/(kg K)).
    outfile : str
        Output path passed to ``savefig`` (saved at 120 dpi).
    show : bool
        If True call ``plt.show()``; otherwise close the figure after saving.

    Returns
    -------
    (Ts, alpha) : tuple of numpy.ndarray
        Sampled temperatures and the corresponding alpha values.
    """
    Ts = np.linspace(T_min, T_max, n)

    # otypes fixes the output dtype up front. Without it np.vectorize calls the
    # function an extra time on the first element to infer the dtype, and it
    # raises ValueError for size-0 input (n=0).
    dqs_vec = np.vectorize(
        lambda TT: dqs_dT_goff_gratch(float(TT), float(p_pa)),
        otypes=[float],
    )
    dqs = dqs_vec(Ts)

    alpha = 1.0 / (1.0 + (L / Cp) * dqs)

    fig, ax = plt.subplots()
    ax.plot(Ts, alpha)
    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("alpha")
    ax.set_title(f"alpha(T) using Goff–Gratch, p={p_pa:.0f} Pa")
    ax.grid(True)
    fig.savefig(outfile, dpi=120, bbox_inches="tight")

    if show:
        plt.show()
    else:
        # Close exactly the figure created here, not whatever is "current".
        plt.close(fig)

    return Ts, alpha


# Only plot when run as a script, so importing this module for the
# es / dqs helpers does not create (and possibly block on) a figure.
if __name__ == "__main__":
    plot_alpha_goff_gratch()

