"""Basic statistics routines."""
import numpy as np
from scipy import stats
from torx import Dimensionless
from typing import Union
import xarray as xr
# Module-level side effect: ask xarray to keep attrs on the results of
# operations by default (global option, applies to the whole process).
xr.set_options(keep_attrs=True)
from torx.autodoc_decorators_m import autodoc_function
def _statistics_core(input_array: np.ndarray, function: str):
"""
Apply statistics function given by name to the input array.
This allows for convenient parallelization using ufunc.
"""
if function == "avg":
return np.nanmean(input_array)
elif function == "std":
return np.sqrt(np.nanvar(input_array))
# Ignore NaNs
input_array = input_array[~np.isnan(input_array)]
if function == "rms":
return np.sqrt(np.vdot(input_array, input_array)/input_array.size)
elif function == "skew":
return stats.skew(input_array)
elif function == "kurtosis":
return stats.kurtosis(input_array, fisher=False)
elif function == "excess_kurtosis":
return stats.kurtosis(input_array, fisher=True)
else:
raise ValueError("Choose from 'avg', 'std', 'skew', 'kurtosis' \
and 'excess_kurtosis'")
def statistics_ufunc(
    data: Union[xr.Dataset, xr.DataArray],
    function: str,
    exclude_dims: Union[np.ndarray, list, str, None] = None
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Apply xarray ufunc for a statistics function given by name.

    Wrapper performing statistics function over xarrays along all dimensions
    not in exclude_dims.

    The functions can be chosen from:
    - 'rms'
    - 'avg'
    - 'std'
    - 'skew'
    - 'kurtosis'
    - 'excess_kurtosis'

    Parameters
    ----------
    data : xr.Dataset or xr.DataArray
        Data to reduce.
    function : str
        Name of the statistic, forwarded to ``_statistics_core``.
    exclude_dims : str, sequence of str or None, optional
        Dimension name(s) that are kept in the result. ``None`` (the
        default) or an empty sequence reduces over every dimension.

    Returns
    -------
    xr.Dataset or xr.DataArray
        The statistic evaluated over all non-excluded dimensions. For
        'skew', 'kurtosis' and 'excess_kurtosis' the attributes ``norm`` and
        ``units`` are overwritten with ``Dimensionless`` and ``""``.
    """
    # Normalise exclude_dims to a list of native Python dimension names.
    # None must be handled explicitly: np.atleast_1d(None) would yield an
    # object array holding None, i.e. a bogus dimension called None.
    if exclude_dims is None:
        excluded = []
    else:
        excluded = np.atleast_1d(exclude_dims).tolist()

    # The statistic is evaluated over all dims not in exclude_dims
    input_core_dims = [dim for dim in data.dims if dim not in excluded]

    # Core dimensions must live in a single chunk because rechunking is
    # disabled in apply_ufunc below.
    # NOTE(review): the excluded dimensions are merged into a single chunk as
    # well, exactly as the previous two consecutive chunk() calls did; this
    # presumably limits dask parallelism -- confirm whether it is required.
    single_chunk = {dim: -1 for dim in (*input_core_dims, *excluded)}
    data = data.chunk(chunks=single_chunk)

    # Apply statistics core function over xarray data
    output_data = xr.apply_ufunc(
        _statistics_core,
        data,
        input_core_dims=[input_core_dims],
        output_core_dims=[[]],
        kwargs=dict(function=function),
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.float64],
        dask_gufunc_kwargs=dict(allow_rechunk=False),
    )

    # These statistics are normalised moments and therefore carry no units
    if function in ["skew", "kurtosis", "excess_kurtosis"]:
        output_data.attrs["norm"] = Dimensionless
        output_data.attrs["units"] = ""

    return output_data
@autodoc_function
def rms(data: Union[xr.Dataset, xr.DataArray], exclude_dims: list=[]):
    """
    Calculate root-mean-square (rms) with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="rms".
    """
    return statistics_ufunc(data, function="rms", exclude_dims=exclude_dims)
@autodoc_function
def avg(data: Union[xr.Dataset, xr.DataArray], exclude_dims: list=[]):
    """
    Calculate average/mean with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="avg".
    """
    return statistics_ufunc(data, function="avg", exclude_dims=exclude_dims)
@autodoc_function
def std(data: Union[xr.Dataset, xr.DataArray], exclude_dims: list=[]):
    """
    Calculate standard deviation with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="std".
    """
    return statistics_ufunc(data, function="std", exclude_dims=exclude_dims)
@autodoc_function
def skew(data: Union[xr.Dataset, xr.DataArray], exclude_dims: list=[]):
    """
    Calculate skew with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="skew".
    """
    return statistics_ufunc(data, function="skew", exclude_dims=exclude_dims)
@autodoc_function
def kurtosis(data: Union[xr.Dataset, xr.DataArray], exclude_dims: list=[]):
    """
    Calculate kurtosis with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="kurtosis".
    """
    return statistics_ufunc(
        data,
        function="kurtosis",
        exclude_dims=exclude_dims,
    )
@autodoc_function
def excess_kurtosis(data: Union[xr.Dataset, xr.DataArray],
                    exclude_dims: list=[]):
    """
    Calculate excess/Fisher kurtosis with parallelized algorithm.

    Performs operation over xarrays along all dimensions not in exclude_dims.

    data is the Dataset or DataArray to reduce; exclude_dims lists the
    dimension names to keep (default: none, i.e. reduce over everything).
    Returns the result of statistics_ufunc with function="excess_kurtosis".
    """
    return statistics_ufunc(
        data,
        function="excess_kurtosis",
        exclude_dims=exclude_dims,
    )