"""Contains functionality for handling arrays neatly in torx."""
from dataclasses import dataclass
from typing import Optional
from numbers import Number
import operator
import scipy
import torx
import xarray as xr
import numpy as np
from torx.autodoc_decorators_m import autodoc_class, autodoc_function
[docs]
@autodoc_class
@dataclass
class TArray:
    """
    Wrapper for ``xarray.DataArray`` with automatic unit handling.

    Values are stored in normalized form: ``data`` holds plain numbers and
    ``norm`` is a ``torx.Quantity`` holding magnitude and unit, so the
    physical values are ``data * norm``. NumPy functionality that xarray
    lacks can be added as methods.
    """

    # Normalized values; non-DataArray input is wrapped in __post_init__.
    data: xr.DataArray
    # Scale and unit of ``data``. __post_init__ also accepts a string, reads
    # ``data.attrs["norm"]`` when present and fills in a default otherwise.
    norm: Optional[torx.Quantity] = None
@property
def shape(self):
    """Tuple with the length of every dimension of the wrapped data."""
    return self.data.shape
@property
def ndim(self):
    """Number of dimensions (rank) of the wrapped data."""
    return self.data.ndim
@property
def dims(self):
    """Tuple of dimension names of the wrapped DataArray."""
    return self.data.dims
@property
def coords(self):
    """Coordinate mapping of the wrapped DataArray."""
    return self.data.coords
@property
def values(self):
    """NumPy array of the normalized data (the norm is NOT applied)."""
    return self.data.values
@property
def denormalized_values(self):
    """
    NumPy array of the physical values.

    The normalized data is multiplied by the magnitude of the norm; the
    unit itself is not attached to the returned array.
    """
    return self.__array__()
@property
def unit(self):
    """Unit part of the norm (``norm.u``)."""
    return self.norm.u
[docs]
@autodoc_function
def to(self, unit):
    """
    Return a TArray whose norm is expressed in ``unit``.

    Only the norm is converted; the normalized data is reused unchanged,
    so the physical values stay the same.
    """
    converted_norm = self.norm.to(unit)
    return TArray(self.data, converted_norm)
[docs]
@autodoc_function
def tobase(self):
    """
    Return a TArray whose norm is expressed in base units (e.g. km -> m).

    The normalized data is reused unchanged.
    """
    base_norm = self.norm.to_base_units()
    return TArray(self.data, base_norm)
[docs]
@autodoc_function
def denormalize(self):
    """
    Return an equivalent TArray whose data holds the physical values.

    The data is multiplied by the norm magnitude and the new norm is
    ``torx.Quantity(1, unit)``, so the physical values are unchanged.
    Dimensions and coordinates of the data are preserved.
    """
    # Scale the DataArray itself: going through __array__ would produce a
    # bare NumPy array and the dims/coords would be replaced by defaults.
    return TArray(self.data * self.norm.m,
                  torx.Quantity(1, self.norm.u))
[docs]
@autodoc_function
def sel(self, *args, **kwargs):
    """Label-based selection via ``DataArray.sel``; the norm is kept."""
    selected = self.data.sel(*args, **kwargs)
    return TArray(selected, self.norm)
[docs]
@autodoc_function
def isel(self, *args, **kwargs):
    """Integer (positional) selection via ``DataArray.isel``; norm kept."""
    selected = self.data.isel(*args, **kwargs)
    return TArray(selected, self.norm)
[docs]
@autodoc_function
def compute(self):
    """
    Return a TArray whose data has been computed into memory.

    Delegates to ``DataArray.compute`` (relevant for chunked/lazy data);
    the norm is kept.
    """
    loaded = self.data.compute()
    return TArray(loaded, self.norm)
@autodoc_function
def __post_init__(self):
    """
    Normalize the TArray immediately after construction.

    * Wraps ``data`` in an ``xr.DataArray`` unless it already is one.
    * Uses ``data.attrs["norm"]`` as the norm when present and raises if it
      conflicts with an explicitly passed ``norm``.
    * Strips all attributes and the name from the DataArray.
    * Converts string or plain-number norms to ``torx.Quantity`` and
      defaults a missing norm to ``torx.Quantity(1)``.
    * Generates the per-instance NumPy/xarray methods.

    Raises
    ------
    ValueError
        If ``norm`` and ``data.attrs["norm"]`` describe different quantities.
    """
    # Test the object itself; an existing DataArray is used as is, anything
    # else (NumPy arrays, lists, scalars) is wrapped.
    if not isinstance(self.data, xr.DataArray):
        self.data = xr.DataArray(self.data)
    if "norm" in self.data.attrs:
        # Read attrs explicitly: attribute access (``self.data.norm``) would
        # find a coordinate named "norm" before the attribute.
        attr_norm = self.data.attrs["norm"]
        if self.norm is not None and \
                torx.Quantity(self.norm) != torx.Quantity(attr_norm):
            raise ValueError("Conflicting norms found!\n"
                             + "data.norm: " + str(attr_norm) + "\n"
                             + "self.norm: " + str(self.norm))
        self.norm = attr_norm
    self.data = self.data.drop_attrs()
    self.data = self.data.rename()
    if isinstance(self.norm, (str, Number)):
        self.norm = torx.Quantity(self.norm)
    # Compare with None instead of relying on truthiness: a zero-magnitude
    # Quantity is falsy, and (assuming torx.Quantity is pint-based) the
    # truth value of an offset-unit quantity such as degC raises.
    if self.norm is None:
        self.norm = torx.Quantity(1)
    self._generate_funcs()
@autodoc_function
def __add__(self, other):
    """
    Add another TArray with a compatible unit.

    ``other`` is converted to the unit of ``self`` and the result carries
    the norm of ``self``.

    Raises
    ------
    ValueError
        If ``other`` is not of the same type or its unit is incompatible.
    """
    self._assert_same_type(other)
    self._assert_same_norm(other)
    own_norm = self.norm
    other_norm = other.norm.to(own_norm.u)
    # Sum the physical values, then re-express them in self's norm.
    total = self.data * own_norm.m + other.data * other_norm.m
    return TArray(total / own_norm.m, own_norm)
@autodoc_function
def __sub__(self, other):
    """
    Subtract another TArray with a compatible unit.

    Computed as ``self + other * -1``; the result has the norm of ``self``.
    """
    self._assert_same_type(other)
    self._assert_same_norm(other)
    negated = other * -1
    return self + negated
@autodoc_function
def __mul__(self, other):
    """
    Multiply by another TArray, a plain number or a quantity.

    * TArray of exactly the same type: data and norms are multiplied.
    * ``numbers.Number``: only the data is scaled.
    * ``torx.Quantity``: only the norm is scaled.

    Raises
    ------
    TypeError
        For any other operand type.
    """
    if type(other) is type(self):
        return TArray(self.data * other.data, self.norm * other.norm)
    if isinstance(other, Number):
        return TArray(self.data * other, self.norm)
    if isinstance(other, torx.Quantity):
        return TArray(self.data, self.norm * other)
    raise TypeError(
        f"Unsupported operand type(s) for {type(self)} and {type(other)}"
    )
def __rmul__(self, other):
    """Handle ``other * self``; multiplication is commutative here."""
    return self * other
@autodoc_function
def __truediv__(self, other):
    """
    Divide by another TArray, a plain number or a quantity.

    * TArray of exactly the same type: data and norms are divided.
    * ``numbers.Number``: only the data is divided.
    * ``torx.Quantity``: only the norm is divided.

    Raises
    ------
    TypeError
        For any other operand type.
    """
    if type(other) is type(self):
        return TArray(self.data / other.data, self.norm / other.norm)
    if isinstance(other, Number):
        return TArray(self.data / other, self.norm)
    if isinstance(other, torx.Quantity):
        return TArray(self.data, self.norm / other)
    raise TypeError(
        f"Unsupported operand type(s) for {type(self)} and {type(other)}"
    )
def __rtruediv__(self, other):
    """Compute ``other / self`` as ``other * self.reciprocal()``."""
    inverse = self.reciprocal()
    return other * inverse
def __neg__(self):
    """Return ``-self``, i.e. the TArray multiplied by -1."""
    return self * (-1)
def __pos__(self):
    """Return ``+self``: the very same object, not a copy."""
    return self
def __pow__(self, exponent):
    """Raise both the normalized data and the norm to ``exponent``."""
    powered_data = self.data**exponent
    powered_norm = self.norm**exponent
    return TArray(powered_data, powered_norm)
@autodoc_function
def __prod__(self, axis=None, dims=None):
    """
    Return the product of the values, scaling the norm accordingly.

    Follows ``np.prod`` semantics: NaNs propagate (use ``__nanprod__`` to
    skip them). Every output element is a product of ``n`` factors, so the
    norm is raised to the power ``n``.

    Parameters
    ----------
    axis : int or sequence of int, optional
        Axes to multiply over; ``None`` (default) multiplies all values.
    dims : str or sequence of str, optional
        Dimension names to multiply over (alternative to ``axis``).

    Returns
    -------
    TArray
    """
    axis = self._get_axis_from_dims(axis, dims)
    if axis is None:
        # skipna=False: xarray skips NaNs by default for float data, which
        # would disagree with a norm raised to the full element count.
        return TArray(self.data.prod(skipna=False),
                      self.norm**self.data.size)
    if isinstance(axis, (int, np.integer)):
        axis = (axis,)
    reduced_dims = [self.dims[ax] for ax in axis]
    n_factors = 1
    for ax in axis:
        n_factors *= self.shape[ax]
    return TArray(self.data.prod(dim=reduced_dims, skipna=False),
                  self.norm**n_factors)
@autodoc_function
def __nanprod__(self):
    """
    Return the product of all non-NaN values.

    The norm is raised to the number of values that entered the product.
    """
    valid_count = np.count_nonzero(~np.isnan(self.data))
    product = self.data.prod(skipna=True)
    return TArray(product, self.norm**valid_count)
@autodoc_function
def _get_axis_from_dims(self, axis=None, dims=None):
    """
    Resolve an axis specification given either as indices or as dim names.

    Parameters
    ----------
    axis : int or sequence of int, optional
        Returned unchanged when given.
    dims : str or sequence of str, optional
        Dimension names, translated to a tuple of axis indices.

    Returns
    -------
    The ``axis`` argument, a tuple of indices for ``dims``, or None if
    neither was given.

    Raises
    ------
    ValueError
        If both ``axis`` and ``dims`` are given, or a dim is unknown.
    """
    # Explicit raises instead of asserts: asserts vanish under ``python -O``.
    if axis is not None and dims is not None:
        raise ValueError("Specify either 'axis' or 'dims', not both.")
    if dims is None:
        return axis
    if isinstance(dims, str):
        dims = (dims,)
    axes = []
    for dim in dims:
        if dim not in self.data.dims:
            raise ValueError(
                f"Dimension {dim!r} not found in {self.data.dims}")
        axes.append(self.data.dims.index(dim))
    return tuple(axes)
@autodoc_function
def __gradient__(self, *args, axis=None, dims=None):
    """
    Differentiate the TArray with ``np.gradient``.

    Parameters
    ----------
    *args : TArray, optional
        Only the first positional argument is used: the sample points /
        spacing. Its norm divides the result norm. Without it the spacing
        is 1 (index units) and the norm is unchanged.
    axis, dims
        Axis indices or dimension names (mutually exclusive).

    Returns
    -------
    TArray with the dims and coords of ``self``.
    NOTE(review): with several axes np.gradient returns a tuple, which the
    DataArray wrapping below does not handle.
    """
    axis = self._get_axis_from_dims(axis, dims)
    if not args:
        derivative = np.gradient(self.data, axis=axis)
        norm = self.norm
    else:
        spacing = args[0]
        self._assert_same_type(spacing)
        derivative = np.gradient(self.data, spacing.data, axis=axis)
        norm = self.norm / spacing.norm
    wrapped = xr.DataArray(derivative, dims=self.dims, coords=self.coords)
    return TArray(wrapped, norm)
@autodoc_function
def __gradlog__(self, *args, axis=None, dims=None):
    """
    Return the gradient of the natural logarithm of the TArray.

    The log is taken of the normalized data only: log(data * m) and
    log(data) differ by a constant that the derivative removes. Arguments
    are forwarded to ``gradient``.
    """
    log_data = TArray(self.data).log()
    return log_data.gradient(*args, axis=axis, dims=dims)
@autodoc_function
def __trapezoid__(self, *args, dx=1.0, axis=None, dims=None,
                  cumulative=False):
    """
    Integrate the TArray with the trapezoidal rule along one or more axes.

    Parameters
    ----------
    *args : TArray, optional
        Sample points, one per integrated axis, in the order of
        ``axis``/``dims``. Each multiplies the result norm by its norm.
    dx : float
        Spacing used when no sample points are given (norm unchanged).
    axis : int or sequence of int, optional
        Axes to integrate over; defaults to the last axis.
    dims : str or sequence of str, optional
        Dimension names to integrate over (alternative to ``axis``).
    cumulative : bool
        Use ``scipy.integrate.cumulative_trapezoid``; the integrated axes
        are kept (one sample shorter) instead of removed.

    Returns
    -------
    TArray
        Keeps the remaining dimension names and the coordinates that do not
        depend on an integrated dimension.

    Raises
    ------
    ValueError
        If the number of sample-point arrays differs from the number of axes.
    """
    if cumulative:
        func = scipy.integrate.cumulative_trapezoid
    else:
        func = np.trapezoid
    axis = self._get_axis_from_dims(axis, dims)
    if axis is None:
        axis = (-1,)
    elif isinstance(axis, (int, np.integer)):
        axis = (axis,)
    # Normalize negative indices so the axes can be ordered below.
    axis = tuple(ax + self.ndim if ax < 0 else ax for ax in axis)
    if args and len(args) != len(axis):
        raise ValueError(f"Got {len(args)} sample-point arrays for "
                         f"{len(axis)} axes.")
    for x in args:
        self._assert_same_type(x)
    spacings = args if args else (None,) * len(axis)
    res = self.data
    norm = self.norm
    # Highest axis first: np.trapezoid removes the integrated axis, so this
    # keeps the indices of the axes still to be processed valid.
    ordered = sorted(zip(axis, spacings), key=lambda pair: pair[0],
                     reverse=True)
    for ax, x in ordered:
        if x is None:
            res = func(res, dx=dx, axis=ax)
        else:
            res = func(res, x=x.data, axis=ax)
            norm = norm * x.norm
    integrated = {self.dims[ax] for ax in axis}
    if cumulative:
        out_dims = self.dims
    else:
        out_dims = tuple(dim for dim in self.dims if dim not in integrated)
    # Coordinates along integrated dims no longer match the result shape.
    out_coords = {name: coord for name, coord in self.coords.items()
                  if integrated.isdisjoint(coord.dims)}
    res = xr.DataArray(res, dims=out_dims, coords=out_coords)
    return TArray(res, norm)
@autodoc_function
def __cumulative_trapezoid__(self, *args, dx=1.0, axis=None, dims=None):
    """Cumulative trapezoidal integral; see ``__trapezoid__`` for arguments."""
    return self.__trapezoid__(*args, cumulative=True,
                              dx=dx, axis=axis, dims=dims)
def __argfunc__(self, func, axis=None, dims=None):
    """
    Apply an index-returning NumPy function such as ``np.argmin``.

    Parameters
    ----------
    func : callable
        Called as ``func(values, axis=axis)`` on the normalized values.
    axis : int or single-element sequence of int, optional
    dims : str or single-element sequence of str, optional
        At most one axis may be selected; None lets ``func`` use its
        default (the flattened array for NumPy's arg* functions).

    Returns
    -------
    TArray
        The indices, with the default norm.

    Raises
    ------
    ValueError
        If more than one axis is selected.
    """
    axis = self._get_axis_from_dims(axis, dims)
    # Integer axes (e.g. from ``np.argmin(x, axis=0)``) have no len(); only
    # sequences need unwrapping.
    if axis is not None and not isinstance(axis, (int, np.integer)):
        if len(axis) != 1:
            raise ValueError("Exactly one axis must be given, got "
                             + str(tuple(axis)))
        axis = axis[0]
    return TArray(func(self.values, axis=axis))
def __argmin__(self, axis=None, dims=None):
    """Index of the minimum value (along one axis if given)."""
    return self.__argfunc__(func=np.argmin, axis=axis, dims=dims)
def __nanargmin__(self, axis=None, dims=None):
    """Index of the minimum value, ignoring NaNs (along one axis if given)."""
    return self.__argfunc__(func=np.nanargmin, axis=axis, dims=dims)
def __argmax__(self, axis=None, dims=None):
    """Index of the maximum value (along one axis if given)."""
    return self.__argfunc__(func=np.argmax, axis=axis, dims=dims)
def __nanargmax__(self, axis=None, dims=None):
    """Index of the maximum value, ignoring NaNs (along one axis if given)."""
    return self.__argfunc__(func=np.nanargmax, axis=axis, dims=dims)
def __argwhere__(self):
    """
    Return the indices of the non-zero (truthy) elements.

    The result has dims ``("matches", "ndim")``: one row per match, one
    column per array dimension, with the default norm.
    """
    indices = xr.DataArray(np.argwhere(self.values),
                           dims=("matches", "ndim"))
    return TArray(indices)
@autodoc_function
def __compare__(self, other, op, **kwargs):
    """
    Compare two TArrays through their physical values.

    ``other`` is converted to the unit of ``self`` and ``op`` is applied
    to both denormalized arrays with ``kwargs``; the result is wrapped in a
    TArray with the default norm.

    Raises
    ------
    ValueError
        If ``other`` is not of the same type or its unit is incompatible.
    """
    self._assert_same_type(other)
    self._assert_same_norm(other)
    ref_norm = self.norm
    other_norm = other.norm.to(ref_norm.u)
    outcome = op(self.data * ref_norm.m, other.data * other_norm.m,
                 **kwargs)
    return TArray(outcome)
def __lt__(self, other):
    """Elementwise ``self < other`` on physical values."""
    return self.__compare__(other, op=operator.lt)
def __gt__(self, other):
    """Elementwise ``self > other`` on physical values."""
    return self.__compare__(other, op=operator.gt)
def __le__(self, other):
    """Elementwise ``self <= other`` on physical values."""
    return self.__compare__(other, op=operator.le)
def __ge__(self, other):
    """Elementwise ``self >= other`` on physical values."""
    return self.__compare__(other, op=operator.ge)
def __eq__(self, other):
    """Elementwise ``self == other`` on physical values."""
    return self.__compare__(other, op=operator.eq)
def __ne__(self, other):
    """Elementwise ``self != other`` on physical values."""
    return self.__compare__(other, op=operator.ne)
def __isclose__(self, other, **kwargs):
    """Elementwise ``np.isclose`` on physical values; kwargs forwarded."""
    return self.__compare__(other, op=np.isclose, **kwargs)
def __allclose__(self, other, **kwargs):
    """Single ``np.allclose`` verdict on physical values; kwargs forwarded."""
    return self.__compare__(other, op=np.allclose, **kwargs)
def __logical__(self, other, op):
    """
    Apply a binary logical ``op`` to the normalized data of two TArrays.

    ``self`` must be dimensionless and ``other`` convertible to its unit;
    the result has the default norm.
    """
    self._assert_same_norm(other)
    self._assert_dimensionless()
    combined = op(self.data, other.data)
    return TArray(combined)
def __and__(self, other):
    """Elementwise logical AND (``self & other``)."""
    return self.__logical__(other, op=np.logical_and)
def __or__(self, other):
    """Elementwise logical OR (``self | other``)."""
    return self.__logical__(other, op=np.logical_or)
def __xor__(self, other):
    """Elementwise logical XOR (``self ^ other``)."""
    return self.__logical__(other, op=np.logical_xor)
def __invert__(self):
    """Elementwise logical NOT (``~self``); requires a dimensionless norm."""
    self._assert_dimensionless()
    negated = np.logical_not(self.data)
    return TArray(negated)
def _convert_test(self, unit):
    """
    Return True if the norm can be converted to ``unit``, else False.

    Any ``Exception`` raised by the conversion (e.g. incompatible
    dimensions or an unparsable unit string) counts as "not convertible".
    KeyboardInterrupt and SystemExit are no longer swallowed.
    """
    try:
        self.norm.to(unit)
    except Exception:
        return False
    return True
@autodoc_function
def _assert_same_type(self, other):
    """
    Raise ValueError unless ``other`` has exactly the type of ``self``.

    Subclass instances are rejected as well.
    """
    if type(other) is not type(self):
        raise ValueError(
            "Type of self and other must be the same!\n"
            f"self : {type(self)!s}\n"
            f"other: {type(other)!s}"
        )
@autodoc_function
def _assert_same_norm(self, other):
    """Raise ValueError unless ``other``'s unit converts to ``self``'s."""
    if not self._convert_test(other.norm.u):
        raise ValueError(
            "Units of self and other are not compatible!\n"
            f"self : {self.norm.u!s}\n"
            f"other: {other.norm.u!s}"
        )
@autodoc_function
def _assert_dimensionless(self):
    """Raise ValueError unless the norm converts to a dimensionless unit."""
    if self._convert_test(""):
        return
    raise ValueError(f"Unit is not dimensionless!\nunit: {self.norm.u!s}")
def __array__(self, dtype=None, copy=None):
    """
    Return the physical values (data times norm magnitude) as a NumPy array.

    Implements the NumPy ``__array__(dtype=None, copy=None)`` protocol.
    NumPy passes ``dtype`` positionally (e.g. ``np.asarray(x, dtype=float)``),
    which the previous ``**kwargs``-only signature could not accept.

    Parameters
    ----------
    dtype : data-type, optional
        Applied after scaling, so the result really has this dtype.
    copy : bool or None, optional
        Accepted for protocol compatibility. The scaling always allocates a
        fresh array. NOTE(review): ``copy=False`` cannot be honored; NumPy's
        protocol suggests raising ValueError then, kept permissive for
        backward compatibility.
    """
    values = self.data.values * self.norm.m
    if dtype is not None:
        values = values.astype(dtype, copy=False)
    return values
@autodoc_function
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
    """
    Dispatch NumPy ufunc calls to the TArray method of the same name.

    Only plain calls (``method == '__call__'``) without ``out=`` are
    supported. The operands other than the receiving TArray are forwarded
    to the method named after the ufunc (``np.multiply`` ->
    ``self.multiply``), which treats the TArray as the left operand.

    When a non-TArray operand comes first (e.g. ``np.divide(2, x)``),
    commutative ufuncs are dispatched as usual, ``divide`` is routed to
    ``__rtruediv__`` and every other ufunc raises TypeError instead of
    silently computing with swapped operands.

    Returns
    -------
    The method's result, or ``NotImplemented`` if an operand is neither a
    Number nor a TArray (NumPy then raises TypeError).

    Raises
    ------
    TypeError
        For unsupported ufuncs, ufunc methods, ``out=`` or operand order.
    """
    first_tarray = next((x for x in args if isinstance(x, TArray)), None)
    if first_tarray is None:
        # Should be unreachable: NumPy only calls this hook for TArrays.
        raise TypeError("No TArray found in args!")
    name = ufunc.__name__
    not_implemented = f"Function '{name}' is not implemented for TArray"
    if method != '__call__':
        raise TypeError(not_implemented)
    if kwargs.get("out") is not None:
        raise TypeError("out= keyword not implemented for TArray")
    if not hasattr(first_tarray, name):
        raise TypeError(not_implemented)
    operands = self._filter_array_args(*args)
    if operands is NotImplemented:
        return NotImplemented
    if not isinstance(args[0], TArray):
        if name == "divide":
            return self.__rtruediv__(*operands, **kwargs)
        if name not in ("add", "multiply", "equal", "not_equal"):
            raise TypeError(f"Function '{name}' is not implemented for "
                            "TArray as the right-hand operand")
    return getattr(self, name)(*operands, **kwargs)
def __array_function__(self, func, types, args, kwargs):
    """
    Dispatch NumPy functions (e.g. ``np.gradient``) to same-named methods.

    Returns
    -------
    The method's result, or ``NotImplemented`` when a foreign array type
    is involved, no callable method of that name exists, or an argument is
    neither a Number nor a TArray. NumPy turns ``NotImplemented`` into a
    TypeError; returning the ``NotImplementedError`` class instead would
    hand that class back to the caller as the function's "result".
    """
    if not all(issubclass(t, self.__class__) for t in types):
        return NotImplemented
    method = getattr(self, func.__name__, None)
    # Properties such as ``shape`` share a NumPy function name but are not
    # callable methods.
    if not callable(method):
        return NotImplemented
    operands = self._filter_array_args(*args)
    if operands is NotImplemented:
        return NotImplemented
    return method(*operands, **kwargs)
def _filter_array_args(self, *args):
    """
    Select the operands to forward to a dispatched method.

    Numbers are kept, the first TArray (the receiver) is dropped and any
    later TArrays are kept. Returns ``NotImplemented`` as soon as an
    argument of any other type is met.
    """
    kept = []
    receiver_skipped = False
    for arg in args:
        if isinstance(arg, Number):
            kept.append(arg)
            continue
        if not isinstance(arg, self.__class__):
            return NotImplemented
        if receiver_skipped:
            kept.append(arg)
        else:
            receiver_skipped = True
    return kept
[docs]
def __len__(self):
    """
    Return the total number of elements (``data.size``).

    NOTE(review): ``__iter__`` yields one item per entry of the first
    dimension, so for multi-dimensional arrays ``len()`` and the number of
    iterated items differ. Confirm which one is intended.
    """
    element_count = self.data.size
    return element_count
def __iter__(self):
    """Yield one TArray, sharing this norm, per entry of the first dimension."""
    yield from (TArray(item, self.norm) for item in self.data)
@autodoc_function
def _detect_plot(self, key):
    """
    Return True if the indexing ``key`` contains ``np.newaxis``.

    Such keys are treated as plotting access (pyplot, per the original
    intent) and get a raw NumPy result from ``__getitem__``.
    """
    if key is np.newaxis:
        return True
    if isinstance(key, (tuple, list)):
        return any(part is np.newaxis for part in key)
    return False
[docs]
def __getitem__(self, key):
    """
    Index the array.

    Returns a TArray with the same norm, or, for keys containing
    ``np.newaxis``, a NumPy array of physical values.
    """
    if not self._detect_plot(key):
        return TArray(self.data[key], self.norm)
    return self.data.values[key] * self.norm.m
def __getslice__(self, key):
    """
    Return a slice of the array (legacy hook).

    Python 3 never calls ``__getslice__``, and neither ``np.ndarray`` nor
    ``xr.DataArray`` provides it, so forwarding to them raised
    AttributeError. Explicit calls now behave like ``self[key]``.
    """
    return self.__getitem__(key)
[docs]
@autodoc_function
def sqrt(self):
    """Square root: ``np.sqrt`` of the data and the norm to the power 0.5."""
    return TArray(np.sqrt(self.data), self.norm ** 0.5)
[docs]
@autodoc_function
def reciprocal(self):
    """
    Return ``1 / self``: reciprocal data and inverted norm.

    Boolean and integer data are promoted to float64 first, because
    ``np.reciprocal`` uses integer division for integer dtypes (1/2 -> 0).
    This also affects ``number / TArray``, which relies on this method.
    """
    data = self.data
    if data.dtype.kind in "biu":
        data = data.astype(np.float64)
    return TArray(np.reciprocal(data), self.norm ** -1)
@autodoc_function
def _generate_funcs(self):
    """
    Attach NumPy- and xarray-style methods to this instance.

    Called from ``__post_init__``; every generated method is a closure over
    ``self``. Groups:

    * ``dimensionless`` (trigonometric, hyperbolic, exp, log): require a
      dimensionless norm, are applied to the physical values
      (``data * norm``) and return a TArray with the default norm.
    * ``dimension_agnostic``: applied to the normalized data, norm kept.
    * ``returns_np``: plain NumPy array scaled by the norm magnitude.
    * ``returns_dimensionless``: TArray with the default norm.
    * xarray methods (``min``, ``sum``, ``mean``, ...): positional and
      keyword arguments are forwarded to the ``xr.DataArray`` method; the
      norm is kept, except for ``var``, where it is squared.
    * ``np_mappings``: NumPy function names aliased to dunder methods so
      ``__array_ufunc__``/``__array_function__`` can dispatch by name.
    """
    supported_funcs = dict(
        dimensionless=["sin", "cos", "tan",
                       "arcsin", "arccos", "arctan",
                       "sinh", "cosh", "tanh",
                       "arcsinh", "arccosh", "arctanh",
                       "exp", "log", "log10", "log2"],
        dimension_agnostic=["real", "imag", "abs", "conjugate"],
        # argmin/argmax are supplied via np_mappings below, which would
        # overwrite any entry made here.
        returns_np=["atleast_1d", "atleast_2d", "atleast_3d"],
        returns_dimensionless=["sign", "isfinite", "isinf", "isnan",
                               "all", "any"],
    )

    def dimensionless_method(call):
        # Build a method for a function that needs a dimensionless input.
        def apply():
            # Checked on every call, against the current norm.
            self._assert_dimensionless()
            # Use the physical values: with norm 2, sin must give
            # sin(2 * data), not 2 * sin(data).
            factor = self.norm.to("").m
            return TArray(call(self.data * factor))
        return apply

    def xr_method(call, squared):
        # Build a method forwarding all arguments to an xr.DataArray method.
        def apply(*args, **kwargs):
            norm = self.norm**2 if squared else self.norm
            return TArray(call(self.data, *args, **kwargs), norm)
        return apply

    for func in supported_funcs["dimension_agnostic"]:
        call = getattr(np, func)
        setattr(self, func,
                lambda call=call: TArray(call(self.data), self.norm))
    for func in supported_funcs["dimensionless"]:
        setattr(self, func, dimensionless_method(getattr(np, func)))
    for func in supported_funcs["returns_np"]:
        call = getattr(np, func)
        setattr(self, func,
                lambda call=call: call(self.data) * self.norm.m)
    for func in supported_funcs["returns_dimensionless"]:
        call = getattr(np, func)
        setattr(self, func, lambda call=call: TArray(call(self.data)))
    xr_funcs = ["min", "max", "sum", "cumsum",
                "diff", "mean", "std", "var", "median"]
    for func in xr_funcs:
        # The variance scales with the square of the norm.
        setattr(self, func,
                xr_method(getattr(xr.DataArray, func),
                          squared=(func == "var")))
    np_mappings = dict(
        add="__add__",
        subtract="__sub__",
        positive="__pos__",
        negative="__neg__",
        multiply="__mul__",
        divide="__truediv__",
        power="__pow__",
        asarray="__array__",
        prod="__prod__",
        nanprod="__nanprod__",
        gradient="__gradient__",
        gradlog="__gradlog__",
        trapezoid="__trapezoid__",
        cumulative_trapezoid="__cumulative_trapezoid__",
        greater="__gt__",
        greater_equal="__ge__",
        less="__lt__",
        less_equal="__le__",
        equal="__eq__",
        not_equal="__ne__",
        isclose="__isclose__",
        allclose="__allclose__",
        argmin="__argmin__",
        argmax="__argmax__",
        nanargmin="__nanargmin__",
        nanargmax="__nanargmax__",
        argwhere="__argwhere__",
    )
    for func, mapping in np_mappings.items():
        setattr(self, func, getattr(self, mapping))