import numpy as np

#: Offset for converting temperatures from degrees Celsius to Kelvin.
#: NOTE: the exact offset is 273.15; 273.16 is the value used in this module.
CtoK = 273.16  # 273.15

#: von Karman's constant
kappa = 0.4  # NOTE: 0.41
# ---------------------------------------------------------------------


def charnock_C35(wind, u10n, usr, seastate, waveage, wcp, sigH, lat):
    """ Calculates Charnock number following Edson et al. 2013 based on
    C35 matlab code (coare35vn.m)

    Parameters
    ----------
    wind : float or array_like
        wind speed (m/s); only its shape is used, to size the output
    u10n : float
        neutral 10m wind speed (m/s)
    usr  : float
        friction velocity (m/s)
    seastate : bool
        0 or 1; with ``waveage`` selects the sea-state based value
    waveage  : bool
        0 or 1; selects a wave-based value instead of the wind-speed one
    wcp      : float
        phase speed of dominant waves (m/s)
    sigH     : float
        significant wave height (m)
    lat      : float
        latitude (deg)

    Returns
    -------
    ac : ndarray
        array of shape ``np.shape(wind) + (3,)``:
        ``ac[..., 0]`` the selected Charnock number (sea-state based if
        ``waveage`` and ``seastate``, wave-age based if only ``waveage``,
        otherwise the wind-speed based value), ``ac[..., 1]`` the sea-state
        based value and ``ac[..., 2]`` the wave-age based value
    """
    g = gc(lat, None)
    # wind-speed dependent value: linear in u10n, held constant above 19 m/s
    a1, a2 = 0.0017, -0.0050
    charnC = np.where(u10n > 19, a1*19+a2, a1*u10n+a2)
    A, B = 0.114, 0.622  # wave-age dependent coefficients
    Ad, Bd = 0.091, 2.0  # Sea-state/wave-age dependent coefficients
    charnW = A*(usr/wcp)**B
    zoS = sigH*Ad*(usr/wcp)**Bd
    charnS = (zoS*g)/usr**2
    if waveage:
        charn = charnS if seastate else charnW
    else:
        charn = charnC
    # shape-based allocation also works for scalar / multi-dimensional input
    ac = np.zeros(np.shape(wind) + (3,))
    ac[..., 0] = charn
    ac[..., 1] = charnS
    ac[..., 2] = charnW
    return ac
# ---------------------------------------------------------------------


def cd_C35(u10n, wind, usr, charn, monob, Ta, hh_in, lat):
    """ Calculates exchange coefficients following Edson et al. 2013 based on
    C35 matlab code (coare35vn.m)

    Parameters
    ----------
    u10n : float
        neutral 10m wind speed (m/s); not used by this function
    wind : float
        wind speed (m/s); not used by this function
    usr : float
        friction velocity (m/s)
    charn : float
        Charnock number
    monob : float
        Monin-Obukhov stability length
    Ta    : float
        air temperature (K)
    hh_in : array_like
        input sensor heights (m): [wind, temperature, humidity]
    lat      : float
        latitude (deg)

    Returns
    -------
    zo : float
        surface roughness (m)
    cdhf : float
        kappa/(ln(hh_in[0]/zo) - psiu_26(hh_in[0]/monob)), i.e. the square
        root of the drag coefficient
    cthf : float
        kappa/(ln(hh_in[1]/zot) - psit_26(hh_in[1]/monob)) (heat)
    cqhf : float
        kappa/(ln(hh_in[2]/zoq) - psit_26(hh_in[2]/monob)) (moisture)
    """
    g = gc(lat, None)
    visa = visc_air(Ta)  # kinematic viscosity, needed twice below
    # surface roughness: Charnock term plus smooth-flow term
    zo = charn*usr**2/g+0.11*visa/usr
    rr = zo*usr/visa  # roughness Reynolds number
    # These thermal roughness lengths give Stanton and Dalton numbers that
    # closely approximate COARE 3.0; zoq is capped at 1.6e-4 m
    zoq = np.minimum(5.8e-5/rr**0.72, 1.6e-4)
    zot = zoq
    cdhf = kappa/(np.log(hh_in[0]/zo)-psiu_26(hh_in[0]/monob))
    cthf = kappa/(np.log(hh_in[1]/zot)-psit_26(hh_in[1]/monob))
    cqhf = kappa/(np.log(hh_in[2]/zoq)-psit_26(hh_in[2]/monob))
    return zo, cdhf, cthf, cqhf
# ---------------------------------------------------------------------


def cdn_calc(u10n, Ta, Tp, method="S80"):
    """ Calculates neutral drag coefficient

    Parameters
    ----------
    u10n : float
        neutral 10m wind speed (m/s)
    Ta   : float
        air temperature (K); only used by the roughness based methods
        ("S88", "UA", "ERA5")
    Tp   : float
        wave period; not used
    method : str
        one of "S80", "LP82", "S88", "UA", "ERA5", "YT96", "LY04"

    Returns
    -------
    cdn : float
        neutral drag coefficient

    Raises
    ------
    ValueError
        if ``method`` is not recognised
    """
    if method == "S80":
        cdn = np.where(u10n <= 3, (0.61+0.567/u10n)*0.001,
                       (0.61+0.063*u10n)*0.001)
    elif method == "LP82":
        cdn = np.where((u10n < 11) & (u10n >= 4), 1.2*0.001,
                       np.where((u10n <= 25) & (u10n >= 11),
                       (0.49+0.065*u10n)*0.001, 1.14*0.001))
    elif method in ("S88", "UA", "ERA5"):
        cdn = cdn_from_roughness(u10n, Ta, None, method)
    elif method == "YT96":
        # for u<3 same as S80
        # NOTE(review): the fallback below is also used for u10n > 26,
        # where it gives the S80 low-wind form -- confirm this is intended
        cdn = np.where((u10n < 6) & (u10n >= 3),
                       (0.29+3.1/u10n+7.7/u10n**2)*0.001,
                       np.where((u10n <= 26) & (u10n >= 6),
                       (0.60 + 0.070*u10n)*0.001, (0.61+0.567/u10n)*0.001))
    elif method == "LY04":
        cdn = np.where(u10n >= 0.5,
                       (0.142+(2.7/u10n)+(u10n/13.09))*0.001, np.nan)
    else:
        # previously printed and then failed with UnboundLocalError
        raise ValueError("unknown method cdn: "+method)
    return cdn
# ---------------------------------------------------------------------


def cdn_from_roughness(u10n, Ta, Tp, method="S88"):
    """ Calculates neutral drag coefficient from roughness length

    The roughness length is zo = alpha*usr**2/g + 0.11*nu/usr (Charnock
    term plus smooth-flow term), with usr = sqrt(cdn)*u10n, and
    cdn = (kappa/ln(10/zo))**2; the two are iterated a fixed 5 times
    starting from cdn = (0.61+0.063*u10n)*0.001.

    Parameters
    ----------
    u10n : float
        neutral 10m wind speed (m/s)
    Ta   : float
        air temperature (K), used for the air viscosity
    Tp   : float
        wave period; not used
    method : str
        "S88" (alpha=0.011), "UA" (alpha=0.013) or "ERA5" (alpha=0.018)

    Returns
    -------
    cdn : float
        neutral drag coefficient after 5 iterations

    Raises
    ------
    ValueError
        if ``method`` is not recognised
    """
    # Charnock constant per method:
    # S88  -- equn 4 in Smith 88 (zo = zc + zs, equns 6-8)
    # UA   -- Zeng et al. 1998 (24), valid for 0<u<18m/s
    # ERA5 -- eq. (3.26) p.40 over sea IFS Documentation cy46r1
    charnock = {"S88": 0.011, "UA": 0.013, "ERA5": 0.018}
    if method not in charnock:
        raise ValueError("unknown method for cdn_from_roughness "+method)
    alpha = charnock[method]
    g = 9.812
    visa = visc_air(Ta)  # loop invariant
    cdnn = (0.61+0.063*u10n)*0.001  # first guess
    # NOTE(review): a fixed number of iterations is used and the result is
    # returned whether or not it has converged
    for _ in range(5):
        usr = np.sqrt(cdnn*u10n**2)
        zo = alpha*np.power(usr, 2)/g + 0.11*visa/usr
        cdnn = (kappa/np.log(10/zo))**2
    return cdnn
# ---------------------------------------------------------------------


def ctcqn_calc(zol, cdn, u10n, zo, Ta, method="S80"):
    """ Calculates neutral heat and moisture exchange coefficients

    Parameters
    ----------
    zol  : float
        height over MO length (used by LP82 and LY04)
    cdn  : float
        neutral drag coefficient (used by LY04, UA and ERA5)
    u10n : float
        neutral 10m wind speed (m/s)
    zo   : float
        surface roughness (m) (used by UA and ERA5)
    Ta   : float
        air temperature (K); sets the output shape for S80/S88/YT96 and
        enters the air viscosity for UA/ERA5
    method : str
        one of "S80", "S88", "YT96", "LP82", "LY04", "UA", "ERA5"

    Returns
    -------
    ctn : float
        neutral heat exchange coefficient
    cqn : float
        neutral moisture exchange coefficient

    Raises
    ------
    ValueError
        if ``method`` is not recognised
    """
    if method in ("S80", "S88", "YT96"):
        cqn = np.ones(np.shape(Ta))*1.20*0.001  # from S88
        ctn = np.ones(np.shape(Ta))*1.00*0.001
    elif method == "LP82":
        # moisture coefficient only defined for zol <= 0 and 4 < u10n < 14
        cqn = np.where((zol <= 0) & (u10n > 4) & (u10n < 14), 1.15*0.001,
                       np.nan)
        ctn = np.where((zol <= 0) & (u10n > 4) & (u10n < 25), 1.13*0.001,
                       0.66*0.001)
    elif method == "LY04":
        cqn = 34.6*0.001*cdn**0.5
        ctn = np.where(zol <= 0, 32.7*0.001*cdn**0.5, 18*0.001*cdn**0.5)
    elif method == "UA":
        usr = (cdn * u10n**2)**0.5
        # Zeng et al. 1998 (25); the thermal roughness is taken equal to zoq,
        # so heat and moisture coefficients coincide
        zoq = zo*np.exp(-(2.67*np.power(usr*zo/visc_air(Ta), 1/4)-2.57))
        cn = np.where((u10n > 0.5) & (u10n < 18), np.power(kappa, 2) /
                      (np.log(10/zo)*np.log(10/zoq)), np.nan)
        cqn, ctn = cn, np.copy(cn)  # independent arrays, as before
    elif method == "ERA5":
        # eq. (3.26) p.40 over sea IFS Documentation cy46r1
        usr = np.sqrt(cdn*np.power(u10n, 2))
        zot = 0.40*visc_air(Ta)/usr
        zoq = 0.62*visc_air(Ta)/usr
        cqn = kappa**2/np.log(10/zo)/np.log(10/zoq)
        ctn = kappa**2/np.log(10/zo)/np.log(10/zot)
    else:
        # previously printed and then failed with UnboundLocalError
        raise ValueError("unknown method ctcqn: "+method)
    return ctn, cqn
# ---------------------------------------------------------------------


def cd_calc(cdn, height, ref_ht, psim):
    """ Calculates drag coefficient at reference height

    Evaluates cdn*(1 + sqrt(cdn)/kappa*(ln(height/ref_ht) - psim))**-2.

    Parameters
    ----------
    cdn : float
        neutral drag coefficient
    height : float
        original sensor height (m)
    ref_ht : float
        reference height (m)
    psim : float
        momentum stability function

    Returns
    -------
    cd : float
        drag coefficient
    """
    inv_kappa = np.power(kappa, -1)
    height_term = np.log(height/ref_ht)-psim
    correction = 1+np.sqrt(cdn)*inv_kappa*height_term
    return cdn*np.power(correction, -2)
# ---------------------------------------------------------------------


def ctcq_calc(cdn, cd, ctn, cqn, h_t, h_q, ref_ht, psit, psiq):
    """ Calculates heat and moisture exchange coefficients at reference height

    Each coefficient is cn*sqrt(cd/cdn) /
    (1 + cn*(ln(h/ref_ht) - psi)/(kappa*sqrt(cdn))), with cn, h and psi
    taken from the heat or moisture inputs.

    Parameters
    ----------
    cdn : float
        neutral drag coefficient
    cd  : float
        drag coefficient at reference height
    ctn : float
        neutral heat exchange coefficient
    cqn : float
        neutral moisture exchange coefficient
    h_t : float
        original temperature sensor height (m)
    h_q : float
        original moisture sensor height (m)
    ref_ht : float
        reference height (m)
    psit : float
        heat stability function
    psiq : float
        moisture stability function

    Returns
    -------
    ct : float
        heat exchange coefficient
    cq : float
        moisture exchange coefficient
    """
    drag_ratio = (cd/cdn)**0.5
    scale = kappa*cdn**0.5

    def _adjust(cn, h, psi):
        # shift a neutral coefficient from sensor height h to ref_ht
        return cn*drag_ratio/(1+cn*((np.log(h/ref_ht)-psi)/scale))

    return _adjust(ctn, h_t, psit), _adjust(cqn, h_q, psiq)
# ---------------------------------------------------------------------


def psim_calc(zol, method="S80"):
    """ Calculates momentum stability function

    Uses psim_conv where zol < 0 and, elsewhere, psim_stab_era5 for method
    "ERA5" or psim_stab for any other method.

    Parameters
    ----------
    zol : float
        height over MO length
    method : str
        method name passed to get_stabco

    Returns
    -------
    psim : float
    """
    stab = get_stabco(method)
    alpha, beta, gamma = stab[0], stab[1], stab[2]
    stable_fn = psim_stab_era5 if method == "ERA5" else psim_stab
    unstable = psim_conv(zol, alpha, beta, gamma)
    stable = stable_fn(zol, alpha, beta, gamma)
    return np.where(zol < 0, unstable, stable)
# ---------------------------------------------------------------------


def psit_calc(zol, method="S80"):
    """ Calculates heat stability function

    Uses psi_conv where zol < 0 and, elsewhere, psi_stab_era5 for method
    "ERA5" or psi_stab for any other method.

    Parameters
    ----------
    zol : float
        height over MO length
    method : str
        method name passed to get_stabco

    Returns
    -------
    psit : float
    """
    stab = get_stabco(method)
    alpha, beta, gamma = stab[0], stab[1], stab[2]
    stable_fn = psi_stab_era5 if method == "ERA5" else psi_stab
    unstable = psi_conv(zol, alpha, beta, gamma)
    stable = stable_fn(zol, alpha, beta, gamma)
    return np.where(zol < 0, unstable, stable)
# ---------------------------------------------------------------------


def get_stabco(method="S80"):
    """ Gives the coefficients alpha, beta, gamma for stability functions

    Parameters
    ----------
    method : str
        one of "S80", "S88", "LY04", "UA", "ERA5", "LP82", "YT96"

    Returns
    -------
    coeffs : ndarray
        float array ``[alpha, beta, gamma]``

    Raises
    ------
    ValueError
        if ``method`` is not recognised
    """
    if method in ("S80", "S88", "LY04", "UA", "ERA5"):
        alpha, beta, gamma = 16, 0.25, 5  # Smith 1980, from Dyer (1974)
    elif method == "LP82":
        alpha, beta, gamma = 16, 0.25, 7
    elif method == "YT96":
        alpha, beta, gamma = 20, 0.25, 5
    else:
        # previously printed and then failed with UnboundLocalError
        raise ValueError("unknown method stabco: "+method)
    return np.array([alpha, beta, gamma], dtype=float)
# ---------------------------------------------------------------------


def psi_stab_era5(zol, alpha, beta, gamma):
    """ Calculates heat stability function for stable conditions
        for method ERA5 (eq (3.22) p. 39 IFS Documentation cy46r1)

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco; not used by this formula

    Returns
    -------
    psit : float
    """
    a, b, c, d = 1, 2/3, 5, 0.35
    decay_term = -b*(zol-c/d)*np.exp(-d*zol)
    power_term = np.power(1+(2/3)*a*zol, 1.5)
    return decay_term-power_term-(b*c)/d+1
# ---------------------------------------------------------------------

def psi_conv(zol, alpha, beta, gamma):
    """ Calculates heat stability function for unstable conditions

    Evaluates 2*ln((1 + x**2)/2) with x = (1 - alpha*zol)**beta.

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco (gamma is not used)

    Returns
    -------
    psit : float
    """
    root = (1-alpha*zol)**beta
    half_sum = (1+root**2)*0.5
    return 2*np.log(half_sum)
# ---------------------------------------------------------------------


def psi_stab(zol, alpha, beta, gamma):
    """ Calculates heat stability function for stable conditions

    Linear form -gamma*zol.

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco (only gamma is used)

    Returns
    -------
    psit : float
    """
    return -(gamma*zol)
# ---------------------------------------------------------------------


def psit_26(zol):
    """ Computes temperature structure function as in C35

    Parameters
    ----------
    zol : float or array_like
        height over MO length

    Returns
    -------
    psi : float or ndarray
        same shape as ``zol`` (a NumPy scalar for scalar input)
    """
    zol = np.asarray(zol)
    dzol = np.where(0.35*zol > 50, 50, 0.35*zol)  # stable, exponent capped
    # The stable form is evaluated everywhere and then overwritten where
    # zol < 0; (1+0.6667*zol)**1.5 is NaN for zol below about -1.5, so the
    # "invalid value" warning from those discarded points is suppressed.
    with np.errstate(invalid="ignore"):
        psi = -((1+0.6667*zol)**1.5+0.6667*(zol-14.28)*np.exp(-dzol)+8.525)
    # writable ndarray copy (arithmetic on 0-d input yields a scalar)
    psi = np.array(psi)
    k = zol < 0  # unstable
    zk = zol[k]
    x = (1-15*zk)**0.5
    psik = 2*np.log((1+x)/2)
    x = (1-34.15*zk)**0.3333
    psic = (1.5*np.log((1+x+x**2)/3)-np.sqrt(3)*np.arctan((1+2*x) /
            np.sqrt(3))+4*np.arctan(1)/np.sqrt(3))
    f = zk**2/(1+zk**2)
    psi[k] = (1-f)*psik+f*psic  # blend psik and psic with weight f
    return psi[()]
# ---------------------------------------------------------------------


def psim_stab_era5(zol, alpha, beta, gamma):
    """ Calculates momentum stability function for stable conditions
        for method ERA5 (eq (3.22) p. 39 IFS Documentation cy46r1)

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco; not used by this formula

    Returns
    -------
    psim : float
    """
    a, b, c, d = 1, 2/3, 5, 0.35
    decay_term = -b*(zol-c/d)*np.exp(-d*zol)
    return decay_term-a*zol-(b*c)/d
# ---------------------------------------------------------------------


def psim_conv(zol, alpha, beta, gamma):
    """ Calculates momentum stability function for unstable conditions

    With x = (1 - alpha*zol)**beta, evaluates
    2*ln((1+x)/2) + ln((1+x**2)/2) - 2*arctan(x) + pi/2.

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco (gamma is not used)

    Returns
    -------
    psim : float
    """
    x = (1-alpha*zol)**beta
    log_terms = 2*np.log((1+x)*0.5)+np.log((1+x**2)*0.5)
    return log_terms-2*np.arctan(x)+np.pi/2
# ---------------------------------------------------------------------


def psim_stab(zol, alpha, beta, gamma):
    """ Calculates momentum stability function for stable conditions

    Linear form -gamma*zol.

    Parameters
    ----------
    zol : float
        height over MO length
    alpha, beta, gamma : float
        constants given by get_stabco (only gamma is used)

    Returns
    -------
    psim : float
    """
    return -(gamma*zol)
# ---------------------------------------------------------------------


def _psiu_blend(zol, a, alpha_k, alpha_c):
    """ Shared velocity structure function used by psiu_26 and psiu_40.

    Stable (and default) form: -(a*zol + b*(zol-c/d)*exp(-dzol) + b*c/d)
    with b, c, d = 3/4, 5, 0.35 and dzol = min(50, 0.35*zol). Where zol < 0
    it is replaced by (1-f)*psik + f*psic, f = zol**2/(1+zol**2), with psik
    built from x = (1-alpha_k*zol)**0.25 and psic from
    x = (1-alpha_c*zol)**0.3333. Accepts scalar or array input.
    """
    zol = np.asarray(zol)
    dzol = np.where(0.35*zol > 50, 50, 0.35*zol)  # stable, exponent capped
    b, c, d = 3/4, 5, 0.35
    # writable ndarray copy (arithmetic on 0-d input yields a scalar)
    psi = np.array(-(a*zol+b*(zol-c/d)*np.exp(-dzol)+b*c/d))
    k = zol < 0  # unstable
    zk = zol[k]
    x = (1-alpha_k*zk)**0.25
    psik = 2*np.log((1+x)/2)+np.log((1+x**2)/2)-2*np.arctan(x)+2*np.arctan(1)
    x = (1-alpha_c*zk)**0.3333
    psic = (1.5*np.log((1+x+x**2)/3)-np.sqrt(3)*np.arctan((1+2*x)/np.sqrt(3)) +
            4*np.arctan(1)/np.sqrt(3))
    f = zk**2/(1+zk**2)
    psi[k] = (1-f)*psik+f*psic
    return psi[()]


def psiu_26(zol):
    """ Computes velocity structure function C35

    Parameters
    ----------
    zol : float or array_like
        height over MO length

    Returns
    -------
    psi : float or ndarray
        same shape as ``zol`` (a NumPy scalar for scalar input)
    """
    return _psiu_blend(zol, 0.7, 15, 10.15)
# ------------------------------------------------------------------------------


def psiu_40(zol):
    """ Computes velocity structure function; same form as psiu_26 but with
    a = 1 and unstable coefficients 18 and 10 (instead of 0.7, 15, 10.15)

    Parameters
    ----------
    zol : float or array_like
        height over MO length

    Returns
    -------
    psi : float or ndarray
        same shape as ``zol`` (a NumPy scalar for scalar input)
    """
    return _psiu_blend(zol, 1, 18, 10)
# ---------------------------------------------------------------------


def get_skin(sst, qsea, rho, Rl, Rs, Rnl, cp, lv, usr, tsr, qsr, lat):
    """ Computes cool skin

    Parameters
    ----------
    sst : float
        sea surface temperature (degrees C; converted from K when its
        minimum exceeds 200)
    qsea : float
        specific humidity over sea; NOTE(review): documented as g/kg, but it
        enters wetc = 0.622*lv*qsea/(287.1*T**2), which presumably expects
        kg/kg -- confirm
    rho : float
        density of air (kg/m^3)
    Rl : float
        downward longwave radiation (W/m^2); not used by this function
    Rs : float
        downward shortwave radiation (W/m^2)
    Rnl : float
        net longwave radiation (W/m^2), added to the sensible and latent
        heat fluxes to form the surface cooling flux
    cp : float
       specific heat of air at constant pressure
    lv : float
       latent heat of vaporization
    usr : float
       friction velocity
    tsr : float
       star temperature
    qsr : float
       star humidity
    lat : float
       latitude

    Returns
    -------
    dter : float
        cool-skin temperature difference, qcol*tkt/tcw
    dqer : float
        corresponding humidity difference, wetc*dter
    """
    # coded following Saunders (1967) with lambda = 6
    g = gc(lat, None)
    if (np.nanmin(sst) > 200):  # if sst in Kelvin convert to Celsius
        sst = sst-273.16
    # ************  cool skin constants  *******
    # density of water, specific heat capacity of water, water viscosity,
    # thermal conductivity of water
    rhow, cpw, visw, tcw = 1022, 4000, 1e-6, 0.6
    Al = 2.1e-5*(sst+3.2)**0.79
    be = 0.026
    bigc = 16*g*cpw*(rhow*visw)**3/(tcw*tcw*rho*rho)
    wetc = 0.622*lv*qsea/(287.1*(sst+273.16)**2)
    Rns = 0.945*Rs  # albedo correction
    hsb = -rho*cp*usr*tsr
    hlb = -rho*lv*usr*qsr
    qout = Rnl+hsb+hlb
    tkt = 0.001*np.ones(np.shape(sst))  # first guess for tkt
    dels = Rns*(0.065+11*tkt-6.6e-5/tkt*(1-np.exp(-tkt/8.0e-4)))
    qcol = qout-dels
    alq = Al*qcol+be*hlb*cpw/lv
    # the power term is only used where alq > 0; elsewhere it may be NaN
    with np.errstate(invalid="ignore"):
        xlamx = np.where(alq > 0, 6/(1+(bigc*alq/usr**4)**0.75)**0.333, 6)
    tkt = xlamx*visw/(np.sqrt(rho/rhow)*usr)
    # As in the COARE 3.0/3.5 cool-skin scheme, tkt is capped at 0.01 only
    # where alq <= 0 (xlamx = 6); the alq > 0 branch is left uncapped.
    # (Previously the cap was applied to the alq > 0 branch instead.)
    tkt = np.where(alq > 0, tkt, np.minimum(tkt, 0.01))
    dter = qcol*tkt/tcw
    dqer = wetc*dter
    return dter, dqer
# ---------------------------------------------------------------------


def get_gust(beta, Ta, usr, tsrv, zi, lat):
    """ Computes gustiness

    Parameters
    ----------
    beta : float
        constant
    Ta : float
        air temperature (K); converted from Celsius when every non-NaN
        value is below 200
    usr : float
        friction velocity (m/s)
    tsrv : float
        star virtual temperature of air (K)
    zi : int
        scale height of the boundary layer depth (m); NaN or None uses 600
    lat : float
        latitude

    Returns
    -------
    ug : float
        beta*(Bf*zi)**(1/3) where Bf = -g/Ta*usr*tsrv > 0, otherwise 0.2
    """
    # nanmax: a NaN in Ta must not prevent the Celsius->Kelvin conversion
    if np.nanmax(Ta) < 200:  # convert to K if in Celsius
        Ta = Ta+273.16
    if zi is None or np.isnan(zi):
        zi = 600
    g = gc(lat, None)
    Bf = -g/Ta*usr*tsrv  # buoyancy flux
    # cube root of a negative value is NaN; those points get 0.2 anyway
    with np.errstate(invalid="ignore"):
        ug = np.where(Bf > 0, beta*np.power(Bf*zi, 1/3), 0.2)
    return ug
# ---------------------------------------------------------------------


def get_heights(h):
    """ Reads input heights for velocity, temperature and humidity

    Parameters
    ----------
    h : float or sequence
        input heights (m): a single height (Python or NumPy scalar) used for
        all three, a one-element sequence (same), [wind, temp/humidity], or
        [wind, temperature, humidity] (extra entries are ignored)

    Returns
    -------
    hh : ndarray
        float array of the three heights [wind, temperature, humidity]
    """
    hh = np.zeros(3)
    if np.ndim(h) == 0:
        # also accepts NumPy scalars, which fail a type(h) == float check
        hh[:] = h
    elif len(h) == 1:
        hh[:] = h[0]
    elif len(h) == 2:
        hh[0], hh[1], hh[2] = h[0], h[1], h[1]
    else:
        hh[0], hh[1], hh[2] = h[0], h[1], h[2]
    return hh
# ---------------------------------------------------------------------


def svp_calc(T):
    """ Calculates saturation vapour pressure

    Parameters
    ----------
    T : float
        temperature (K); treated as Celsius and shifted by 273.16 when its
        minimum is below 200

    Returns
    -------
    svp : float
        in mb, pure water; NaN where T is NaN
    """
    T = T+273.16 if np.nanmin(T) < 200 else T
    saturation = 2.1718e08*np.exp(-4157/(T-33.91-0.16))
    return np.where(np.isnan(T), np.nan, saturation)
# ---------------------------------------------------------------------


def qsea_calc(sst, pres):
    """ Computes specific humidity of the sea surface air

    Parameters
    ----------
    sst : float
        sea surface temperature (K); treated as Celsius and shifted by
        273.16 when its minimum is below 200
    pres : float
        pressure (mb)

    Returns
    -------
    qsea : float
        (kg/kg); NaN where sst or pres is NaN
    """
    if np.nanmin(sst) < 200:  # Celsius input: convert to Kelvin
        sst = sst+273.16
    e = 0.98*svp_calc(sst)  # reduced saturation vapour pressure over sea
    humidity = (0.622*e)/(pres-0.378*e)
    return np.where(np.isnan(sst+pres), np.nan, humidity)
# ---------------------------------------------------------------------


def q_calc(Ta, rh, pres):
    """ Computes specific humidity following Haltiner and Martin p.24

    Parameters
    ----------
    Ta : float
        air temperature (K); treated as Celsius and converted to Kelvin when
        its minimum is below 200
    rh : float
        relative humidity (%)
    pres : float
        air pressure (mb)

    Returns
    -------
    qair : float, (kg/kg)
        NaN where any input is NaN
    """
    if (np.nanmin(Ta) < 200):  # if Ta in Celsius convert to Kelvin
        # same 273.16 offset as svp_calc/qsea_calc, so a given Celsius
        # temperature yields the same vapour pressure in every helper
        Ta = Ta+273.16
    e = np.where(np.isnan(Ta+rh+pres), np.nan, svp_calc(Ta)*rh*0.01)
    qair = np.where(np.isnan(e), np.nan, ((0.62197*e)/(pres-0.378*e)))
    return qair
# ------------------------------------------------------------------------------


def bucksat(T, P):
    """ Computes saturation vapor pressure (mb) as in C35

    Parameters
    ----------
    T : float
        temperature (degrees C); shifted by CtoK when its minimum exceeds 200
    P : float
        pressure (mb)

    Returns
    -------
    exx : float
        6.1121*exp(17.502*T/(T+240.97))*(1.0007+3.46e-6*P)
    """
    temp = np.asarray(T)
    if np.nanmin(temp) > 200:  # Kelvin input: convert to Celsius
        temp = temp-CtoK
    pressure_factor = 1.0007+3.46e-6*P
    return 6.1121*np.exp(17.502*temp/(temp+240.97))*pressure_factor
# ------------------------------------------------------------------------------


def qsat26sea(T, P):
    """ Computes surface saturation specific humidity (g/kg) as in C35

    Parameters
    ----------
    T : float
        temperature (degrees C); shifted by CtoK when its minimum exceeds 200
    P : float
        pressure (mb)

    Returns
    -------
    qs : float
        622*es/(P-0.378*es) with es = 0.98*bucksat(T, P)
    """
    temp = np.asarray(T)
    if np.nanmin(temp) > 200:  # Kelvin input: convert to Celsius
        temp = temp-CtoK
    es = 0.98*bucksat(temp, P)  # reduction at sea surface
    return 622*es/(P-0.378*es)
# ------------------------------------------------------------------------------


def qsat26air(T, P, rh):
    """ Computes specific humidity (g/kg) as in C35

    Parameters
    ----------
    T : float
        temperature (degrees C); shifted by CtoK when its minimum exceeds 200
    P : float
        pressure (mb)
    rh : float
        relative humidity (%)

    Returns
    -------
    q : float
        622*em/(P-0.378*em)
    em : float
        vapour pressure (mb), 0.01*rh*bucksat(T, P)
    """
    temp = np.asarray(T)
    if np.nanmin(temp) > 200:  # Kelvin input: convert to Celsius
        temp = temp-CtoK
    em = 0.01*rh*bucksat(temp, P)
    q = 622*em/(P-0.378*em)
    return q, em
# ---------------------------------------------------------------------


def gc(lat, lon=None):
    """ Computes gravity relative to latitude

    Evaluates gamma*(1 + c1*s**2 + c2*s**4 + c3*s**6 + c4*s**8) with
    s = sin(latitude).

    Parameters
    ----------
    lat : float
        latitude (degrees)
    lon : float
        longitude (degrees, optional); when given, the result is evaluated on
        the np.meshgrid(lon, lat) grid

    Returns
    -------
    gc : float
        gravity constant (m/s^2)
    """
    gamma = 9.7803267715  # value at the equator (s = 0)
    c1, c2, c3, c4 = 0.0052790414, 0.0000232718, 0.0000001262, 0.0000000007
    lat_grid = lat if lon is None else np.meshgrid(lon, lat)[1]
    s = np.sin(lat_grid*np.pi/180.)
    series = (1+c1*np.power(s, 2)+c2*np.power(s, 4)+c3*np.power(s, 6) +
              c4*np.power(s, 8))
    return gamma*series
# ---------------------------------------------------------------------


def visc_air(Ta):
    """ Computes the kinematic viscosity of dry air as a function of air temp.
    following Andreas (1989), CRREL Report 89-11.

    Parameters
    ----------
    Ta : float
        air temperature (degrees C); shifted by 273.16 when its minimum
        exceeds 200

    Returns
    -------
    visa : float
        kinematic viscosity (m^2/s)
    """
    t = np.asarray(Ta)
    if np.nanmin(t) > 200:  # Kelvin input: convert to Celsius
        t = t-273.16
    poly = 1 + 6.542e-3*t + 8.301e-6*t**2 - 4.84e-9*t**3
    return 1.326e-5 * poly