"""Parametric bi-spline interpolator class."""
import numpy as np
import xarray as xr
from typing import Union
from scipy.interpolate import bisplrep, bisplev, RectBivariateSpline
from torx.autodoc_decorators_m import autodoc_class
@autodoc_class
class SplineInterpolator2D:
    """
    Spline interpolation in 2D for structured or unstructured grids.

    A wrapper class around either RectBivariateSpline (for data on a
    rectangular R-Z grid) or a combination of bisplrep and bisplev (for
    scattered data) from 'scipy.interpolate'.

    Attributes set by the constructor:
        data: the data values as a numpy array; on a rectangular grid it is
            stored with shape (n_z, n_r).
        data_shape, data_size: shape and number of elements of the input data.
        is_structured: True if the coordinates were given as 1D arrays,
            False if they were given as 2D arrays.
        rectangular_grid: True if the data lies on a rectangular R-Z grid.
        r_basis, z_basis: ascending grid basis vectors (rectangular only).
        interpolator: the RectBivariateSpline instance (rectangular only).
        r_u, z_u, data_u: flattened coordinates and data (scattered only).
        tck: spline representation returned by bisplrep (scattered only).
    """

    def __init__(
        self,
        data: Union[xr.DataArray, np.ndarray],
        r_points: np.ndarray = None,
        z_points: np.ndarray = None,
        kr: int = 3,
        kz: int = 3,
        smoothing: float = None,
    ):
        """
        Initialize the spline interpolator.

        There are three foreseen ways to initialize the class:
        1) Provide data as a 2D xarray with 'R' and 'Z' as dimensions (and
           no 'r_points'/'z_points').
        2) Provide the data as 2D xarray or numpy array plus either separate
           1D basis vectors 'r_points' and 'z_points', or 2D 'r_points' and
           'z_points' arrays with the same shape as the data.
        3) Provide the data as 1D unstructured xarray or numpy array plus
           'r_points' and 'z_points' giving the R and Z coordinates at each
           point of the data array.

        Args:
            data: values to interpolate.
            r_points, z_points: R and Z coordinates (see above).
            kr, kz: spline degrees in the R and Z directions.
            smoothing: smoothing factor passed to scipy as ``s``. If None,
                RectBivariateSpline is called with s=0 (exact interpolation)
                and bisplrep with s=None (bisplrep's own default).

        Raises:
            TypeError: if 'data' is neither an xr.DataArray nor np.ndarray.
            AssertionError: for missing or inconsistent coordinates.
            RuntimeError: if coordinate and data sizes do not match.
        """
        # Handle different datatypes allowed for data input
        if isinstance(data, xr.DataArray):
            if data.ndim == 2 and r_points is None and z_points is None:
                # Test each name separately: `"R" and "Z" in data.dims`
                # would only check for "Z".
                assert "R" in data.dims and "Z" in data.dims, \
                    "Assign R and Z coordinates to the 2D DataArray."
                # Order the data as (Z, R) by dimension name, as expected by
                # RectBivariateSpline below; the shape-based transpose
                # further down cannot tell the axes of a square grid apart.
                self.data = data.transpose("Z", "R").values
                r_points = data.R.values
                z_points = data.Z.values
            else:
                assert (r_points is not None) and (z_points is not None), \
                    "Provide r and z points for the data (see doc-string)"
                self.data = data.values
        elif isinstance(data, np.ndarray):
            assert (r_points is not None) and (z_points is not None), \
                "Provide r and z points for the data (see doc-string)"
            self.data = data
        else:
            raise TypeError("Provide 'data' as xr.DataArray or as np.ndarray")

        # Accept any array-like coordinates (lists, DataArrays, ...)
        r_points = np.asarray(r_points)
        z_points = np.asarray(z_points)

        # Make sure dimensions of coordinate arrays are consistent
        assert r_points.ndim == z_points.ndim, \
            "Inputs 'r_points' and 'z_points' must be of equal rank"
        assert 1 <= r_points.ndim <= 2, \
            "Inputs 'r_points' and 'z_points' must be arrays of rank 1 or 2."

        self.data_shape = self.data.shape
        self.data_size = self.data.size
        # True for 1D coordinate arrays (grid basis vectors or per-point
        # coordinates of scattered data), False for 2D coordinate arrays.
        self.is_structured = r_points.ndim == 1

        # Distinguish between 2D rectangular grid and unstructured 2D
        # surface; each helper sets self.rectangular_grid (and the basis
        # vectors when a rectangular grid is found).
        if self.is_structured:
            self._setup_unstructured(r_points, z_points)
        else:
            self._setup_structured(r_points, z_points)

        # Set up the interpolator depending on the grid type
        if self.rectangular_grid:
            # RectBivariateSpline needs data of shape (n_z, n_r); numpy data
            # given as (n_r, n_z) is transposed here. For square grids this
            # cannot be detected and (n_z, n_r) order is assumed.
            if (self.z_basis.size, self.r_basis.size) != self.data.shape:
                self.data = self.data.T
            # The spline's x-axis is Z and its y-axis is R, so the Z degree
            # goes to kx and the R degree to ky.
            self.interpolator = RectBivariateSpline(
                self.z_basis,
                self.r_basis,
                self.data,
                kx=kz,
                ky=kr,
                s=0.0 if smoothing is None else smoothing,
            )
        else:
            # bisplrep fits scattered points given as flat lists; ravel is
            # a no-op reshape for 1D inputs.
            self.r_u = np.ravel(r_points)
            self.z_u = np.ravel(z_points)
            self.data_u = np.ravel(self.data)
            self.tck = bisplrep(
                self.r_u, self.z_u, self.data_u, kx=kr, ky=kz, s=smoothing
            )

    def __call__(
        self,
        r_eval: np.ndarray,
        z_eval: np.ndarray,
        dr: int = 0,
        dz: int = 0,
        grid: bool = True,
    ):
        """
        Return the spline evaluations at points given.

        Args:
            r_eval, z_eval: R and Z coordinates at which to evaluate.
            dr, dz: order of the derivative in the R and Z directions.
            grid: if True, evaluate on the tensor grid spanned by 'r_eval'
                and 'z_eval' (scipy requires both to be sorted ascending);
                if False, evaluate at the points (r_eval[i], z_eval[i]),
                following numpy broadcasting.

        Returns:
            With grid=True on a rectangular grid: values with Z along the
            first axis and R along the second (RectBivariateSpline order).
            With grid=True on scattered data: the bisplev result, with R
            along the first axis and Z along the second.
            With grid=False: an array of the broadcast shape of the inputs.
        """
        if self.rectangular_grid:
            # The spline's x-axis is Z, so dx is the Z derivative order.
            return self.interpolator(z_eval, r_eval, dx=dz, dy=dr, grid=grid)
        if grid:
            return bisplev(r_eval, z_eval, self.tck, dx=dr, dy=dz)
        # bisplev only evaluates on tensor grids; evaluate point by point.
        r_pts, z_pts = np.broadcast_arrays(
            np.asarray(r_eval, dtype=float), np.asarray(z_eval, dtype=float)
        )
        values = [
            bisplev(r, z, self.tck, dx=dr, dy=dz)
            for r, z in zip(r_pts.ravel(), z_pts.ravel())
        ]
        return np.asarray(values, dtype=float).reshape(r_pts.shape)

    def _setup_structured(self, r_points, z_points):
        """
        Classify 2D coordinate arrays as a rectangular grid or scattered data.

        Called for rank-2 'r_points'/'z_points', which must have the same
        shape as the data. If they equal the meshgrid of their sorted unique
        values, either in 'xy' layout (R varying along axis 1) or in 'ij'
        layout (R varying along axis 0), the grid is rectangular and
        'r_basis'/'z_basis' are set; for 'ij' layout the data is transposed
        to (n_z, n_r). Otherwise the points are treated as scattered data.
        """
        assert (r_points.shape == z_points.shape == self.data.shape), \
            "Shape mismatch between 'r_points', 'z_points' and 'data'."
        # Default to scattered data; upgraded only if a grid is detected.
        self.rectangular_grid = False
        r_unique = np.unique(r_points)
        z_unique = np.unique(z_points)
        if r_unique.size * z_unique.size != self.data_size:
            return
        for indexing in ("xy", "ij"):
            RR, ZZ = np.meshgrid(r_unique, z_unique, indexing=indexing)
            # Skip layouts whose shape differs, to avoid broadcast errors.
            if RR.shape != r_points.shape:
                continue
            if np.all(np.isclose(r_points, RR) & np.isclose(z_points, ZZ)):
                self.rectangular_grid = True
                self.r_basis = r_unique
                self.z_basis = z_unique
                if indexing == "ij":
                    self.data = self.data.T
                return

    def _setup_unstructured(self, r_points, z_points):
        """
        Classify 1D coordinate arrays as scattered points or grid basis vectors.

        If 'r_points' and 'z_points' have the same shape as the data, they
        are the coordinates of scattered data points. Otherwise, for 2D data
        with len(r_points) * len(z_points) elements, they are the basis
        vectors of a rectangular grid and must be strictly ascending.

        Raises:
            AssertionError: if a basis vector is not strictly ascending.
            RuntimeError: if the sizes of coordinates and data do not match.
        """
        # Unstructured grid if coords and data are equal sized 1D arrays
        if r_points.shape == z_points.shape == self.data_shape:
            self.rectangular_grid = False
        # For 2D data check if the basis vectors are strictly ascending
        elif (
            self.data.ndim == 2
            and r_points.size * z_points.size == self.data_size
        ):
            assert np.all(np.diff(r_points) > 0), \
                "Basis vector in r-direction is not strictly ascending."
            assert np.all(np.diff(z_points) > 0), \
                "Basis vector in z-direction is not strictly ascending."
            self.r_basis = r_points
            self.z_basis = z_points
            self.rectangular_grid = True
        else:
            raise RuntimeError(
                "Dimension mismatch between 'r_points', 'z_points' and 'data'."
            )