# Source code for neutralocean.lib

"""Library of simple functions for neutralocean"""

import numpy as np
import numba as nb
import xarray as xr


def find_first_nan(a, axis=-1):
    """The index to the first NaN along a given axis

    Parameters
    ----------
    a : ndarray
        Input array possibly containing some NaN elements

    axis : int, Default -1
        Axis along which to find the first NaN

    Returns
    -------
    k : ndarray of int
        Has the shape of `a` with dimension `axis` removed.
        For each 1D array `a1` of `a` in the dimension `axis`, `k` gives:
        the index to the first NaN of `a1`, or
        `a.shape[axis]` if `a1` has no NaNs.
        In particular, `k` is 0 wherever `a1` is all NaN (or begins with NaN).
        For example, with `a` being 3D and `axis` = -1:
        If all `a[i,j,:]` are NaN, then `k[i,j] = 0`.
        If all `a[i,j,:]` are not NaN, then `k[i,j] = a.shape[-1]`.
        Otherwise, `K = k[i,j]` is the smallest int such that `a[i,j,K]` is
        NaN; hence `a[i,j,:K]` are all not NaN.

    Notes
    -----
    This is similar to (and faster than) `numpy.argmax(numpy.isnan(a), axis)`,
    but differs for indices along which `a` has no NaN: this output contains
    `a.shape[axis]` whereas the `argmax` approach contains 0.
    """
    a = np.asarray(a)  # also accept lists and other array-likes
    axis = axis % a.ndim  # handle when axis < 0

    # Shapes of a before and after axis
    shape1 = a.shape[0:axis]
    shape3 = a.shape[axis + 1 :]

    # Number of elements before, at, and after axis
    n1 = np.prod(shape1, dtype=int)
    n2 = a.shape[axis]
    n3 = np.prod(shape3, dtype=int)

    # Reshape a to 3D, use helper function, then reshape output back
    a = np.reshape(a, (n1, n2, n3))
    out = _find_first_nan(a)
    return np.reshape(out, (*shape1, *shape3))
@nb.njit
def _find_first_nan(a):
    """Index of the first NaN along axis 1 of the 3D array `a`.

    Returns an int64 array of shape (a.shape[0], a.shape[2]). Where the slice
    `a[i, :, k]` contains no NaN, the entry is `a.shape[1]`.
    """
    ni, nj, nk = a.shape
    first = np.empty((ni, nk), dtype=np.int64)
    for i in range(ni):
        for k in range(nk):
            # Advance j past the leading non-NaN values of this slice.
            j = 0
            while j < nj and not np.isnan(a[i, j, k]):
                j += 1
            first[i, k] = j  # equals nj when no NaN was encountered
    return first
@nb.njit
def take_fill(a, idx, fillval=np.nan):
    """
    Like numpy.take but fills with `fillval` when indices are out of range

    Parameters
    ----------
    a : ndarray
        input data

    idx : 1d array
        linear indices to elements of `a`

    fillval : scalar, Default np.nan
        value placed in the output wherever `idx` is out of range for `a`

    Returns
    -------
    b : ndarray
        The i'th element of b (in linear order) is the `idx[i]`'th element of
        `a`, (in linear order), or `fillval` if `idx[i] < 0` or
        `idx[i] >= a.size`. Same shape as `idx`.
    """
    a_ = a.reshape(-1)
    n = a_.size
    b = np.empty(idx.size, dtype=a.dtype)
    for i in range(b.size):
        j = idx[i]
        # Check both bounds: njit performs no bounds checking by default, so
        # an index >= n would otherwise silently read past the end of `a_`.
        if j >= 0 and j < n:
            b[i] = a_[j]
        else:
            b[i] = fillval
    return b
@nb.njit
def aggsum(a, idx, n):
    """
    Sum the elements of `a` within groups labelled by `idx`.

    Parameters
    ----------
    a : array
        Data to be grouped and summed.

    idx : array of int
        Group label of each element of `a`; must be the same size as `a`.
        A negative label excludes that element from every group.

    n : int
        Number of groups (empty groups included); this is the length of the
        output, so it must satisfy `n >= np.max(idx) + 1`.

    Returns
    -------
    b : array
        Entry `g` is the sum of all `a[i]` with `idx[i] == g`; groups with no
        members sum to zero.

    Notes
    -----
    A minimal version of `numpy_groupies.aggregate`, see
    https://github.com/ml31415/numpy-groupies/
    """
    sums = np.zeros(n, dtype=a.dtype)
    for i, group in enumerate(idx):
        if group < 0:
            continue  # excluded element
        sums[group] += a[i]
    return sums
def val_at(T, k):
    """Evaluate nD array at given indices along its last dimension

    Parameters
    ----------
    T : ndarray
        Input array. Can be 1D or nD.

    k : int or ndarray
        Index at which to evaluate `T` along its last dimension.
        Can be an int (Python or numpy integer) or (n-1)D.
        Negative entries of `k` mark invalid locations.

    Returns
    -------
    Tk : ndarray
        The input `T` evaluated with its last index equal to `k`, and NaN
        wherever `k` is negative. `T` itself is never modified.

    Raises
    ------
    ValueError
        If `T` is not 1D and does not have exactly one more dimension than `k`.

    Notes
    -----
    If `T` is 3D and `k` is 2D, then `Tk[i,j] = T[i,j,k[i,j]]` for each valid `(i,j)`.

    If `T` is 1D and `k` is 2D, then `Tk[i,j] = T[k[i,j]]` for each valid `(i,j)`.

    If `T` is 3D and `k` is an int, then `Tk[i,j] = T[i,j,k]` for each valid `(i,j)`.

    Examples
    --------
    Evaluate temperature, having data in each water column, at the bottom grid cell

    >>> T = np.empty((3, 2, 10))  # (longitude, latitude, depth), let us say
    >>> T[..., :] = np.arange(10, 0, -1)  # decreasing along depth dim from 10 to 1
    >>> T[0, 0, :] = np.nan  # make cast (0,0) be land
    >>> T[0, 1, 3:] = np.nan  # make cast (0,1) be only 3 ocean cells deep
    >>> n_good = find_first_nan(T)
    >>> val_at(T, n_good - 1)
    array([[nan,  8.],
           [ 1.,  1.],
           [ 1.,  1.]])

    Evaluate the depth at the bottom grid cell

    >>> Z = np.linspace(0, 4500, 10)  # grid cell centres are at depths 0, 500, 1000, ..., 4500.
    >>> val_at(Z, n_good - 1)  # Z doesn't have NaN structure, so use n_good from T as above
    array([[  nan, 1000.],
           [4500., 4500.],
           [4500., 4500.]])
    """
    if np.ndim(k) == 0 or T.ndim == 1:
        # Select the k'th element along the last dimension of T. For a scalar
        # k this is a view into T, so it must not be written to below.
        Tk = T[..., k]
    elif T.ndim == np.ndim(k) + 1:
        # Gather one element per cast. Drop only the trailing length-1 axis
        # (not every singleton axis, as .squeeze() would) so that Tk has
        # exactly the shape of k.
        Tk = np.take_along_axis(T, np.asarray(k)[..., None], -1)[..., 0]
    else:
        raise ValueError("T must be 1 dimensional or have 1 more dimension than k")

    # Set to NaN any place where k is negative (e.g. k = n_good - 1 with
    # n_good == 0 for a land cast). np.where builds a new array, so T is
    # never modified through a view.
    invalid = np.asarray(k) < 0
    if np.any(invalid):
        Tk = np.where(invalid, np.nan, Tk)
    return Tk
def xr_to_np(S):
    """Unwrap an xarray object into its numpy data.

    Anything exposing a ``.values`` attribute is replaced by that attribute;
    any other input (e.g. a plain ndarray) is passed through unchanged.
    """
    return getattr(S, "values", S)
def _xr_in(S, vert_dim): # Prepare xarray container for output: like input S but without dimension # labelled `drop_dim` if isinstance(S, xr.core.dataarray.DataArray): if vert_dim is None: return xr.full_like(S, 0) elif isinstance(vert_dim, int): vert_dim = S.dims[vert_dim] # convert to str return xr.full_like(S.isel({vert_dim: 0}).drop_vars(vert_dim), 0) else: return None def _xrs_in(S, T, P, vert_dim): # Prepare xarray containers for output: like inputs S, T, P but without # the dimension labelled `vert_dim`. Doing S, T, P together allows for # pxr to be an xarray even if P is an ndarray -- it just won't have attributes. sxr, txr = (_xr_in(X, vert_dim) for X in (S, T)) if sxr is None: pxr = None else: pxr = sxr.copy() try: pxr.attrs.update(P.attrs) pxr.name = P.name except: pxr.attrs.clear() pxr.name = None return sxr, txr, pxr def _xr_out(s, sxr): # Return xarrays if inputs were xarrays if isinstance(sxr, xr.core.dataarray.DataArray): sxr.data = s return sxr else: return s def _process_vert_dim(vert_dim, S): """Convert `vert_dim` as a str naming a dimension in `S` or a (possibly negative) int into an int between 0 and S.ndim-1.""" if isinstance(vert_dim, str) and hasattr(S, "dims"): try: vert_dim = S.dims.index(vert_dim) except: raise ValueError(f"vert_dim = {vert_dim} not found in S.dims") return np.mod(vert_dim, S.ndim) def _contiguous_casts(S, vert_dim=-1): """Make individual casts contiguous in memory Parameters ---------- S : ndarray ocean data such as salinity, temperature, or pressure vert_dim : int, Default -1 Specifies which dimension of `S` is vertical. 
Returns ------- S : ndarray input data, possibly re-arranged to have `vert_dim` the last dimension """ if S.ndim > 1 and vert_dim not in (-1, S.ndim - 1): S = np.moveaxis(S, vert_dim, -1) return np.require(S, dtype=np.float64, requirements="C") def _process_casts(S, T, P, vert_dim): """Make individual casts contiguous in memory and extract numpy array from xarray""" vert_dim = _process_vert_dim(vert_dim, S) S, T, P = (xr_to_np(x) for x in (S, T, P)) # Broadcast a 1D vector for P into a ND array like S if P.ndim < S.ndim: # First make P a 3D array with its non-singleton dimension be `vert_dim` P = np.reshape(P, tuple(-1 if x == vert_dim else 1 for x in range(S.ndim))) P = np.broadcast_to(P, S.shape) S, T, P = (_contiguous_casts(x, vert_dim) for x in (S, T, P)) return S, T, P def _interp_casts(S, T, P, interp_fn, Sppc=None, Tppc=None): # Compute interpolants for S and T casts (unless already provided) ni, nj, nk = S.shape if Sppc is None or Sppc.shape[0:-1] != (ni, nj, nk - 1): Sppc = interp_fn(P, S) if Tppc is None or Tppc.shape[0:-1] != (ni, nj, nk - 1): Tppc = interp_fn(P, T) return Sppc, Tppc def _process_wrap(wrap, s=None, diags=False): """Convert to a tuple of `int`s specifying which horizontal dimensions are periodic""" if wrap is None: if diags: raise ValueError( "wrap must be given for omega surfaces, or when `diags` is True" ) else: return wrap if isinstance(wrap, str): wrap = (wrap,) # Convert single string to tuple if not isinstance(wrap, (tuple, list)): raise TypeError("If given, wrap must be a tuple or list or str") if all(isinstance(x, str) for x in wrap): try: # Convert dim names to tuple of bool wrap = tuple(x in wrap for x in s.dims) except: raise TypeError( "With wrap provided as strings, s must have a .dims attribute" ) # type checking on final value if not (isinstance(wrap, (tuple, list)) and len(wrap) == 2): raise TypeError( "wrap must be a two element (logical) array " "or a string (or array of strings) referring to dimensions in xarray S" ) 
return wrap def _process_pin_cast(pin_cast, S): """ If pinning cast is a dict: convert from a coordinate representation, suitable for `S.sel(pin_cast)` where S is an xarray, into an index representation, suitable for `S[pin_cast]` where S is an ndarray. If pinning cast is an int: wrap it into a 1-element tuple, so np.ravel_multi_index works Otherwise, just return the input `pin_cast`. """ # TODO: There must be a better way of doing this... # One issue is this always rounds one way, whereas a "nearest" neighbour # type behaviour would be preferred, as in xr.DataArray.sel if isinstance(pin_cast, dict): return tuple(int(S.get_index(k).searchsorted(v)) for (k, v) in pin_cast.items()) elif isinstance(pin_cast, int): return (pin_cast,) else: return pin_cast def local_functions(_locals, _name): """List of public functions defined in the local scope. This excludes functions beginning with an "_" as well as imported functions. At the end of a module, use `__all__ = local_functions(locals(), __name__)` so that `from mymodule import *` will only import that module's public and locally-defined functions. """ return [ k for (k, v) in _locals.items() if callable(v) and v.__module__ == _name and not k.startswith("_") ] __all__ = local_functions(locals(), __name__)