"""Library of simple functions for neutral_surfaces"""

import numpy as np
import numba as nb
import xarray as xr

from .eos import make_eos, make_eos_s_t


@nb.njit
def find_first_nan(a):
    """The index to the first NaN along the last axis

    Parameters
    ----------
    a : ndarray
        Input array possibly containing some NaN elements

    Returns
    -------
    k : ndarray of int
        The index to the first NaN along each 1D array making up `a`, as in
        the following example with `a` being 3D.
        If all `a[i,j,:]` are NaN, then `k[i,j] = 0`.
        If all `a[i,j,:]` are not NaN, then `k[i,j] = a.shape[-1]`.
        Otherwise, `K = k[i,j]` is the smallest int such that `a[i,j,K-1]`
        is not NaN, but `a[i,j,K]` is NaN.
    """
    nk = a.shape[-1]
    k = np.full(a.shape[:-1], nk, dtype=np.int_)
    for n in np.ndindex(a.shape[0:-1]):
        for i in range(nk):
            if np.isnan(a[n][i]):
                k[n] = i
                break
    return k
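

# Illustrative usage sketch for `find_first_nan` (not part of the library; a
# minimal doctest-style example added for clarity):
#
#   >>> a = np.array([[1.0, 2.0, np.nan], [1.0, 2.0, 3.0]])
#   >>> find_first_nan(a)
#   array([2, 3])
#
# The first cast hits NaN at index 2; the second has no NaN, so its entry is
# a.shape[-1] == 3.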


@nb.njit
def take_fill(a, idx, fillval=np.nan):
    """
    Like numpy.take but fills with nan when indices are out of range

    Parameters
    ----------
    a : ndarray
        input data
    idx : 1d array
        linear indices to elements of `a`
    fillval : scalar, Default np.nan
        value placed in `b` wherever `idx` is negative

    Returns
    -------
    b : ndarray
        The i'th element of `b` (in linear order) is the `idx[i]`'th element
        of `a` (in linear order), or `fillval` if `idx[i] < 0`.
        Same shape as `idx`.
    """
    b = np.empty(idx.size, dtype=a.dtype)
    a_ = a.reshape(-1)
    for i in range(len(b)):
        if idx[i] >= 0:
            b[i] = a_[idx[i]]
        else:
            b[i] = fillval
    return b


@nb.njit
def aggsum(a, idx, n):
    """
    Aggregate data into groups and then sum each group.

    Parameters
    ----------
    a : array
        Input data to be aggregated into groups and summed.
    idx : array of int
        Group label for each element of `a`.  To exclude element `i` of `a`
        from any group, let `idx[i]` be a negative int.  Must be the same
        size as `a`.
    n : int
        Number of groups, including empty groups.  As this is also the
        length of `b`, must satisfy `n >= np.max(idx) + 1`.

    Returns
    -------
    b : array
        The sum of each group of data from `a`.

    Notes
    -----
    This is a simple implementation of `numpy_groupies.aggregate`.
    See https://github.com/ml31415/numpy-groupies/
    """
    b = np.zeros(n, dtype=a.dtype)
    for i in range(len(idx)):
        if idx[i] >= 0:
            b[idx[i]] += a[i]
    return b
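

# Illustrative usage sketches for `take_fill` and `aggsum` (not part of the
# library; minimal doctest-style examples added for clarity):
#
#   >>> a = np.array([[10.0, 20.0], [30.0, 40.0]])
#   >>> take_fill(a, np.array([0, 3, -1]))  # linear indices into `a`
#   array([10., 40., nan])
#
#   >>> x = np.array([1.0, 2.0, 3.0, 4.0])
#   >>> grp = np.array([0, 1, 0, -1])  # element 3 excluded via negative label
#   >>> aggsum(x, grp, 2)
#   array([4., 2.])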


def val_at(T, k):
    """Evaluate nD array at given indices along its last dimension

    Parameters
    ----------
    T : ndarray
        Input array.  Can be 1D or nD.
    k : int or ndarray
        Index at which to evaluate `T` along its last dimension.
        Can be an int or (n-1)D.

    Returns
    -------
    Tk : ndarray
        The input `T` evaluated with its last index equal to `k`.

    Notes
    -----
    If `T` is 3D and `k` is 2D, then `Tk[i,j] = T[i,j,k[i,j]]` for each valid `(i,j)`.

    If `T` is 1D and `k` is 2D, then `Tk[i,j] = T[k[i,j]]` for each valid `(i,j)`.

    If `T` is 3D and `k` is an int, then `Tk[i,j] = T[i,j,k]` for each valid `(i,j)`.

    Examples
    --------
    Evaluate temperature, having data in each water column, at the bottom grid cell

    >>> T = np.empty((3, 2, 10))  # (longitude, latitude, depth), let us say
    >>> T[..., :] = np.arange(10, 0, -1)  # decreasing along depth dim from 10 to 1
    >>> T[0, 0, :] = np.nan  # make cast (0,0) be land
    >>> T[0, 1, 3:] = np.nan  # make cast (0,1) be only 3 ocean cells deep
    >>> n_good = find_first_nan(T)
    >>> val_at(T, n_good - 1)
    array([[nan,  8.],
           [ 1.,  1.],
           [ 1.,  1.]])

    Evaluate the depth at the bottom grid cell

    >>> Z = np.linspace(0, 4500, 10)  # grid cell centres are at depths 0, 500, 1000, ..., 4500.
    >>> val_at(Z, n_good - 1)  # Z doesn't have NaN structure, so use n_good from T as above
    array([[  nan, 1000.],
           [4500., 4500.],
           [4500., 4500.]])
    """
    if isinstance(k, int) or T.ndim == 1:
        # select the k'th element along the last dimension of T
        Tk = T[..., k]
    elif T.ndim == k.ndim + 1:
        # A negative k[i,j] (e.g. -1 for a land cast, from `n_good - 1`) indexes
        # T[i,j,-1] here; such entries are overwritten with NaN below.
        Tk = np.take_along_axis(T, k[..., None], -1).squeeze()
    else:
        raise ValueError(
            "T must be 1 dimensional or have 1 more dimension than k"
        )

    # Set to NaN any place where k is negative
    Tk[k < 0] = np.nan

    return Tk


def xr_to_np(S):
    """Convert xarray into numpy array"""
    if hasattr(S, "values"):
        S = S.values
    return S


def _xr_in(S, vert_dim):
    # Prepare xarray container for output: like input S but without the
    # dimension labelled `vert_dim`
    if isinstance(S, xr.core.dataarray.DataArray):
        if vert_dim is None:
            return xr.full_like(S, 0)
        elif isinstance(vert_dim, int):
            vert_dim = S.dims[vert_dim]  # convert to str
        return xr.full_like(S.isel({vert_dim: 0}).drop_vars(vert_dim), 0)
    else:
        return None


def _xrs_in(S, T, P, vert_dim):
    # Prepare xarray containers for output: like inputs S, T, P but without
    # the dimension labelled `vert_dim`.  Doing S, T, P together allows for
    # pxr to be an xarray even if P is an ndarray -- it just won't have
    # attributes.
    sxr, txr = (_xr_in(X, vert_dim) for X in (S, T))
    if sxr is None:
        pxr = None
    else:
        pxr = sxr.copy()
        try:
            pxr.attrs.update(P.attrs)
            pxr.name = P.name
        except:
            pxr.attrs.clear()
            pxr.name = None
    return sxr, txr, pxr


def _xr_out(s, sxr):
    # Return xarrays if inputs were xarrays
    if isinstance(sxr, xr.core.dataarray.DataArray):
        sxr.data = s
        return sxr
    else:
        return s


def _process_vert_dim(vert_dim, S):
    """Convert `vert_dim` as a str naming a dimension in `S` or a (possibly
    negative) int into an int between 0 and S.ndim-1."""
    if isinstance(vert_dim, str) and hasattr(S, "dims"):
        try:
            vert_dim = S.dims.index(vert_dim)
        except:
            raise ValueError(f"vert_dim = {vert_dim} not found in S.dims")
    return np.mod(vert_dim, S.ndim)


def _contiguous_casts(S, vert_dim=-1):
    """Make individual casts contiguous in memory

    Parameters
    ----------
    S : ndarray
        ocean data such as salinity, temperature, or pressure
    vert_dim : int, Default -1
        Specifies which dimension of `S` is vertical.

    Returns
    -------
    S : ndarray
        input data, possibly re-arranged to have `vert_dim` the last dimension
    """
    if S.ndim > 1 and vert_dim not in (-1, S.ndim - 1):
        S = np.moveaxis(S, vert_dim, -1)
    return np.require(S, dtype=np.float64, requirements="C")


def _process_casts(S, T, P, vert_dim):
    """Make individual casts contiguous in memory and extract numpy arrays
    from xarrays"""
    vert_dim = _process_vert_dim(vert_dim, S)
    S, T, P = (xr_to_np(x) for x in (S, T, P))

    # Broadcast a 1D vector for P into an ND array like S
    if P.ndim < S.ndim:
        # First reshape P so its only non-singleton dimension is `vert_dim`
        P = np.reshape(
            P, tuple(-1 if x == vert_dim else 1 for x in range(S.ndim))
        )
        P = np.broadcast_to(P, S.shape)
    S, T, P = (_contiguous_casts(x, vert_dim) for x in (S, T, P))
    return S, T, P


def _interp_casts(S, T, P, interp_fn, Sppc=None, Tppc=None):
    # Compute interpolants for S and T casts (unless already provided)
    ni, nj, nk = S.shape
    if Sppc is None or Sppc.shape[0:-1] != (ni, nj, nk - 1):
        Sppc = interp_fn(P, S)
    if Tppc is None or Tppc.shape[0:-1] != (ni, nj, nk - 1):
        Tppc = interp_fn(P, T)
    return Sppc, Tppc


def _process_wrap(wrap, s=None, diags=False):
    """Convert to a tuple of `bool`s specifying which horizontal dimensions
    are periodic"""
    if wrap is None:
        if diags:
            raise ValueError(
                "wrap must be given for omega surfaces, or when `diags` is True"
            )
        else:
            return wrap

    if isinstance(wrap, str):
        wrap = (wrap,)  # Convert single string to tuple

    if not isinstance(wrap, (tuple, list)):
        raise TypeError("If given, wrap must be a tuple or list or str")

    if all(isinstance(x, str) for x in wrap):
        try:  # Convert dim names to tuple of bool
            wrap = tuple(x in wrap for x in s.dims)
        except:
            raise TypeError(
                "With wrap provided as strings, s must have a .dims attribute"
            )

    # type checking on final value
    if not (isinstance(wrap, (tuple, list)) and len(wrap) == 2):
        raise TypeError(
            "wrap must be a two element (logical) array "
            "or a string (or array of strings) referring to dimensions in xarray S"
        )
    return wrap


def _process_pin_cast(pin_cast, S):
    """
    If pinning cast is a dict: convert from a coordinate representation,
    suitable for `S.sel(pin_cast)` where S is an xarray, into an index
    representation, suitable for `S[pin_cast]` where S is an ndarray.

    If pinning cast is an int: wrap it into a 1-element tuple, so
    np.ravel_multi_index works.

    Otherwise, just return the input `pin_cast`.
    """
    # TODO: There must be a better way of doing this...
    # One issue is this always rounds one way, whereas a "nearest" neighbour
    # type behaviour would be preferred, as in xr.DataArray.sel
    if isinstance(pin_cast, dict):
        return tuple(
            int(S.get_index(k).searchsorted(v)) for (k, v) in pin_cast.items()
        )
    elif isinstance(pin_cast, int):
        return (pin_cast,)
    else:
        return pin_cast


def _process_eos(eos, grav=None, rho_c=None, need_s_t=False):
    # Process equation of state argument and make cached functions
    eos_s_t = None
    if isinstance(eos, str):
        if need_s_t:
            eos_s_t = make_eos_s_t(eos, grav, rho_c)
        eos = make_eos(eos, grav, rho_c)
    else:
        if need_s_t:
            if isinstance(eos, (tuple, list)) and len(eos) == 2:
                eos_s_t = eos[1]
                eos = eos[0]
            if not callable(eos) or not callable(eos_s_t):
                raise ValueError(
                    "If `eos` is not a str, expected a tuple of length two"
                    " containing an eos function and an eos_s_t function."
                )
        else:
            if isinstance(eos, (tuple, list)) and len(eos) >= 1:
                eos = eos[0]
            if not callable(eos):
                raise ValueError("If `eos` is not a str, expected a function.")
    return eos, eos_s_t
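

# Illustrative usage sketches for the private helpers above (not part of the
# library; added for clarity).  The eos name "jmd95" is assumed here to be one
# of the names accepted by `make_eos` / `make_eos_s_t`.
#
#   >>> s = xr.DataArray(np.zeros((4, 8)), dims=("lat", "lon"))
#   >>> _process_wrap("lon", s)  # periodic in longitude only
#   (False, True)
#
#   >>> eos, eos_s_t = _process_eos("jmd95", grav=9.81, rho_c=1027.5, need_s_t=True)
#   >>> callable(eos) and callable(eos_s_t)
#   True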