# Source code for torx.analysis.lineouts.lineout_base_m

"""Defines common functionality across all lineouts."""
import numpy as np
import xarray as xr

from abc import ABC

from torx import make_xarray
from torx.normalization.normalization_m import Normalization
from torx.equilibrium.equilibrium_m import EquilibriumBaseClass
from torx.grid.grid_2d_m import Grid2D
from torx.analysis.statistics_m import statistics_ufunc
from torx.analysis.csr_linear_interpolation_m import make_matrix_interp
from torx.vector import vector_magnitude, vector_cross, vector_dot
from torx.vector import poloidal_vector, cylindrical_vector
from torx.geometry import compute_arc_length
from torx.autodoc_decorators_m import autodoc_class

@autodoc_class
class LineoutBase(ABC):
    """
    Lineout base class which declares shared methods.

    Subclasses provide ``r_points`` and ``z_points`` (normalized to R0).
    ``make_interpolation_matrix(..., use_source_points=True)`` additionally
    reads ``r_source`` and ``z_source``, which must then be set by the
    subclass.
    """

    @property
    def r_points(self) -> np.ndarray:
        """R coordinates for the lineout (normalized to R0)."""
        return self._r_points

    @r_points.setter
    def r_points(self, values: np.ndarray):
        """Set the R coordinates of the lineout."""
        self._r_points = values

    @property
    def z_points(self) -> np.ndarray:
        """Z coordinates for the lineout (normalized to R0)."""
        return self._z_points

    @z_points.setter
    def z_points(self, values: np.ndarray):
        """Set the Z coordinates of the lineout."""
        self._z_points = values

    def make_interpolation_matrix(
        self,
        grid: Grid2D,
        use_source_points: bool = False
    ):
        """
        Build an interpolation matrix for the lineout.

        A CSR-matrix which can be multiplied onto an array of shape (points)
        and return bilinearly interpolated values along the lineout.

        Args:
            grid: Grid whose ``r_u``/``z_u`` coordinates are the source
                points of the interpolation.
            use_source_points: If True, first overwrite ``r_points`` and
                ``z_points`` with ``r_source`` and ``z_source``.

        Sets ``interpolation_matrix``, ``grid_size`` and ``curve_size`` on
        the instance.
        """
        if use_source_points:
            self.r_points = self.r_source
            self.z_points = self.z_source

        self.interpolation_matrix = make_matrix_interp(
            grid.r_u.values, grid.z_u.values, self.r_points, self.z_points
        )
        self.grid_size, self.curve_size = grid.size, self.r_points.size

    def _interpolate_core(self, input_array: np.ndarray):
        """Apply the interpolation matrix to an array with dimension points."""
        # NOTE(review): presumably interpolation_matrix is a scipy.sparse
        # matrix, for which ``*`` is a matrix-vector product — confirm.
        return self.interpolation_matrix * input_array

    def interpolate(self, input_array: xr.DataArray):
        """
        Apply the interpolation matrix to an array with parallelization.

        Use a dask-parallelized algorithm to apply the CSR-matrix
        interpolation over the input array. The core dimension "points" is
        replaced by "interp_points" (of length ``curve_size``); attributes of
        the input are copied to the result.

        Requires ``make_interpolation_matrix`` to have been called.
        """
        interpolated = xr.apply_ufunc(
            self._interpolate_core,
            # The core dimension must live in a single dask chunk
            input_array.chunk({"points": -1}),
            input_core_dims=[["points"]],
            output_core_dims=[["interp_points"]],
            vectorize=True,
            dask="parallelized",
            dask_gufunc_kwargs={"output_sizes": {"interp_points": self.curve_size}},
            output_dtypes=[np.float64],
        )

        interpolated.attrs = input_array.attrs

        return interpolated

    def statistics(
        self,
        data: xr.DataArray,
        function: str,
        exclude_dims: list = None,
    ) -> xr.DataArray:
        """
        Apply statistics methods from torx.analysis.

        Data that does not yet have an "interp_points" dimension is first
        interpolated onto the lineout.

        Args:
            data: Grid-based or already interpolated data.
            function: Name of the statistic, forwarded to statistics_ufunc.
            exclude_dims: Dimensions forwarded to statistics_ufunc; None
                (default) means an empty list.
        """
        # Fresh list per call: avoids a shared mutable default argument
        if exclude_dims is None:
            exclude_dims = []

        if "interp_points" not in data.dims:
            data = self.interpolate(data)

        return statistics_ufunc(data, function, exclude_dims)

    def rho(self, equi: EquilibriumBaseClass):
        """Return the normalized flux-surface label for the lineout points."""
        return equi.normalized_flux_surface_label(
            self.r_points, self.z_points, grid=False
        )

    def poloidal_arc_length(self, norm: Normalization = None, skipna=False):
        """
        Calculate the cumulative arc length of the lineout starting with 0.

        Uses simple finite differences (via compute_arc_length).

        Args:
            norm: Optional Normalization; if given, ``norm["R0"]`` is used as
                the unit, otherwise the string "R0".
            skipna: Forwarded to compute_arc_length.

        Returns the arc length normalized w.r.t. R0.
        """
        x = compute_arc_length(self.r_points, self.z_points, skipna=skipna)

        return make_xarray(
            x, norm=("R0" if norm is None else norm["R0"]),
            name="poloidal arc length"
        )

    def surface_area(self, rotation_axis="Z", norm: Normalization = None):
        """
        Calculate the surface area of the lineout.

        Rotates the lineout around the Z (default) or R axis and integrates
        2*pi * radius * ds with the trapezoidal rule along the arc length.
        Uses simple finite differences.

        Raises:
            ValueError: If rotation_axis is neither "Z" nor "R".

        Returns the surface area normalized w.r.t. R0**2.
        """
        # Validate first: an assert would be stripped under ``python -O`` and
        # leave ``area`` unbound.
        if rotation_axis not in ("Z", "R"):
            raise ValueError("Error: Rotation axis must be either R or Z!")

        t = self.poloidal_arc_length()
        x_diff = np.diff(self.r_points, append=self.r_points[0])
        y_diff = np.diff(self.z_points, append=self.z_points[0])
        t_diff = np.diff(t, append=t[0])
        dxdt = x_diff / t_diff
        dydt = y_diff / t_diff

        s = np.sqrt(dxdt**2 + dydt**2)

        # Distance from the rotation axis: R when rotating around Z, Z when
        # rotating around R.
        radius = self.r_points if rotation_axis == "Z" else self.z_points
        area = 2.0 * np.pi * np.trapezoid(s * radius, x=t)

        return make_xarray(
            area,
            norm=("R0**2" if norm is None else norm["R0"] ** 2),
            name="surface area",
        )

    def volume(self, rotation_axis="Z", norm: Normalization = None):
        """
        Calculate the volume of the lineout.

        Rotates the lineout around the Z (default) or R axis.
        Uses simple finite differences.

        Raises:
            ValueError: If rotation_axis is neither "Z" nor "R".

        Returns the volume normalized w.r.t. R0**3.
        """
        # Validate first: an assert would be stripped under ``python -O`` and
        # leave ``volume`` unbound.
        if rotation_axis not in ("Z", "R"):
            raise ValueError("Error: Rotation axis must be either R or Z!")

        t = self.poloidal_arc_length()
        t_diff = np.diff(t, append=t[0])

        if rotation_axis == "Z":
            y_diff = np.diff(self.z_points, append=self.z_points[0])
            dydt = y_diff / t_diff
            volume = np.pi * np.trapezoid(dydt * self.r_points**2, x=t)

            # The sign of the volume depends on the order of the z values
            if self.z_points[0] > self.z_points[-1]:
                volume = -volume

        else:
            x_diff = np.diff(self.r_points, append=self.r_points[0])
            dxdt = x_diff / t_diff
            volume = np.pi * np.trapezoid(dxdt * self.z_points**2, x=t)

            # The sign of the volume depends on the order of the r values
            if self.r_points[0] > self.r_points[-1]:
                volume = -volume

        return make_xarray(
            volume, norm=("R0**3" if norm is None else norm["R0"] ** 3), name="volume"
        )

    @property
    def tangent_unit_vector(self):
        """
        Return a tangent unit vector at each point of the lineout.

        End points use one-sided differences; interior points use the
        difference of neighbouring midpoints, i.e. (p[i+1] - p[i-1]) / 2.
        Requires at least two lineout points.
        """
        r_vec = np.zeros(len(self.r_points))
        z_vec = np.zeros(len(self.r_points))

        # Take care of borders manually
        r_vec[0] = self.r_points[1] - self.r_points[0]
        z_vec[0] = self.z_points[1] - self.z_points[0]
        r_vec[-1] = self.r_points[-1] - self.r_points[-2]
        z_vec[-1] = self.z_points[-1] - self.z_points[-2]

        # Inner points via difference of staggered points
        r_stag = self.r_points[0:-1] + np.diff(self.r_points) / 2
        z_stag = self.z_points[0:-1] + np.diff(self.z_points) / 2
        r_vec[1:-1] = np.diff(r_stag)
        z_vec[1:-1] = np.diff(z_stag)

        x = poloidal_vector(r_vec, z_vec, dims=["interp_points"])
        x = x / vector_magnitude(x)
        return x

    @property
    def normal_unit_vector(self):
        """Return a normal unit vector (e_phi x tangent) at each lineout point."""
        e_phi = cylindrical_vector(
            np.zeros(self.r_points.size),
            np.ones(self.r_points.size),
            np.zeros(self.r_points.size),
            dims=["interp_points"],
        )
        x = self.tangent_unit_vector
        n = vector_cross(e_phi, x)
        n = n / vector_magnitude(n)
        return n

    def _magnetic_field_incidence(self, equi, direction: str):
        """
        Find incidence angles between magfield and the normal.

        Args:
            equi: Equilibrium providing the normalized field vectors.
            direction: One of "parallel", "radial" or "poloidal".

        Raises:
            ValueError: If direction is not one of the supported values.

        Returns the dot product of the lineout normal with the normalized
        field vector of the chosen direction.
        """
        # Map direction -> equilibrium method name; looked up lazily so an
        # equilibrium only needs the method that is actually requested.
        method_names = {
            "parallel": "magfield_vector",
            "radial": "magfield_vector_radial",
            "poloidal": "magfield_vector_poloidal",
        }
        if direction not in method_names:
            raise ValueError(
                "Choose either 'parallel', 'radial' or 'poloidal' as projection direction"
            )

        field_vector = getattr(equi, method_names[direction])(
            self.r_points, self.z_points, normalize=True
        )

        return vector_dot(
            self.normal_unit_vector,
            field_vector.rename({"points": "interp_points"})
        )

    def parallel_incidence(self, equi):
        """Return the parallel component of magfield projected to the normal."""
        return self._magnetic_field_incidence(equi, direction="parallel")

    def radial_incidence(self, equi):
        """Return the radial component of magfield projected to the normal."""
        return self._magnetic_field_incidence(equi, direction="radial")

    def poloidal_incidence(self, equi):
        """Return the poloidal component of magfield projected to the normal."""
        return self._magnetic_field_incidence(equi, direction="poloidal")

    def __surface_integral_core(self, flux):
        """
        Calculate surface integrals (core routine).

        ``flux`` must have exactly two dimensions (points, vector). If its
        shape differs from the lineout normal it is interpolated onto the
        lineout; otherwise it is assumed to already live on the lineout.
        NaNs in the projected integrand are treated as 0. Returns
        2*pi * integral of (normal . flux) * R along the arc length.
        """
        # If called with apply_ufunc, convert back to xarray for interpolation
        if not isinstance(flux, xr.DataArray):
            assert (
                len(flux.shape) == 2
            ), "flux must only contain 2 dimensions (points, vector)!"
            flux = xr.DataArray(flux, dims=("points", "vector"))
        else:
            assert (
                len(flux.dims) == 2
            ), "flux must only contain 2 dimensions (points, vector)!"

        normal = self.normal_unit_vector
        # Treat the case where flux is defined on a vector already interpolated onto the lineout
        if flux.shape != normal.shape:
            interpolation = self.interpolate(flux)
        else:
            # Rename dimension for dot product to apply correctly
            interpolation = flux.rename({"points": "interp_points"})

        # Generate the arc length along the lineout
        x = self.poloidal_arc_length()
        y = vector_dot(normal, interpolation) * self.r_points
        y = xr.where(np.isnan(y), 0, y)

        # Perform the integration with the trapezoidal rule
        integral = 2.0 * np.pi * np.sum(np.trapezoid(y, x=x))
        return integral

    def surface_integral(self, flux: xr.DataArray, norm):
        """
        Calculate the surface integral of a flux through the lineout.

        Projects the flux onto the lineout normal and integrates toroidally
        over 2pi. Extra dimensions beyond (points, vector) are looped over
        with xr.apply_ufunc.

        Args:
            flux: Vector field with dimensions "points" and "vector" and a
                ``norm`` attribute.
            norm: Normalization object providing ``R0``.

        Returns the integral with ``norm`` attribute ``flux.norm * R0**2``.
        """
        assert "points" in flux.dims, "flux must contain dimension points!"
        assert "vector" in flux.dims, "flux must contain dimension vector!"

        # ``chunksizes`` is an empty mapping for numpy-backed data, so only
        # rechunk when "points" is actually split over several dask chunks;
        # the core routine needs all points of one sample in a single block.
        point_chunks = flux.chunksizes.get("points")
        if point_chunks is not None and len(point_chunks) > 1:
            flux = flux.chunk(chunks={"points": -1})

        if len(flux.dims) == 2:
            return make_xarray(
                self.__surface_integral_core(flux), norm=(flux.norm * norm.R0**2)
            )

        integral = xr.apply_ufunc(
            self.__surface_integral_core,
            flux,
            input_core_dims=[["points", "vector"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float64],
        )
        integral.attrs["norm"] = flux.norm * norm.R0**2
        return integral

    def setup_perpendicular_gradient(self, grid: Grid2D) -> None:
        """
        Initialize the perpendicular gradient operator.

        Naive way to setup a perpendicular gradient of grid-based data:
        Since the lineout points are generally not grid points, the nearest
        neighboring grid points are used to approximate the derivatives.
        Generally, a forward stencil is used, but if a lineout point is
        exactly on a grid point a central 1st order stencil is used.

        Sets ``R0``, ``grid_size``, ``grid_spacing``, ``nearest_grid_indices``
        and the private stencil index/factor arrays.

        Raises:
            RuntimeError: If a lineout coordinate compares neither <, > nor ==
                to its nearest grid coordinate (e.g. NaN).
        """
        self.R0 = grid.R0
        self.grid_size = grid.size
        self.grid_spacing = grid.spacing

        # Size from the current lineout points, so this does not depend on
        # make_interpolation_matrix having set ``curve_size`` first.
        n_points = np.size(self.r_points)
        self.nearest_grid_indices = np.empty(n_points, dtype=int)
        self._up_indices = np.empty(n_points, dtype=int)
        self._down_indices = np.empty(n_points, dtype=int)
        self._left_indices = np.empty(n_points, dtype=int)
        self._right_indices = np.empty(n_points, dtype=int)
        self._grad_fac_r = np.ones(n_points, dtype=float)
        self._grad_fac_z = np.ones(n_points, dtype=float)

        # Loop-invariant grid data
        r_grid = grid.r_u.values
        z_grid = grid.z_u.values
        spacing = grid.spacing

        def nearest(r_target, z_target):
            """Index of the grid point closest to (r_target, z_target)."""
            return np.argmin((r_grid - r_target)**2 + (z_grid - z_target)**2)

        def off_grid_error(i, r, z):
            """Build the error raised for a point that cannot be placed."""
            return RuntimeError(
                f"Lineout point number {i} at r,z={r, z} seems to be "
                "off-grid. Did you run 'lineout.find_points_on_grid()' before?"
            )

        for i, (r, z) in enumerate(zip(self.r_points, self.z_points)):
            nearest_index = nearest(r, z)
            self.nearest_grid_indices[i] = nearest_index

            # Find neighboring grid points in r-direction
            if r_grid[nearest_index] < r:
                left_index = nearest_index
                right_index = nearest(r + spacing, z)
            elif r_grid[nearest_index] > r:
                left_index = nearest(r - spacing, z)
                right_index = nearest_index
            elif r_grid[nearest_index] == r:
                left_index = nearest(r - spacing, z)
                right_index = nearest(r + spacing, z)
                if not (right_index == nearest_index or left_index == nearest_index):
                    # We are exactly on a grid point and not at the domain edge:
                    # Need to account for the central stencil coefficient.
                    self._grad_fac_r[i] = 0.5
            else:
                raise off_grid_error(i, r, z)

            # Find neighboring grid points in z-direction
            if z_grid[nearest_index] < z:
                up_index = nearest(r, z + spacing)
                down_index = nearest_index
            elif z_grid[nearest_index] > z:
                down_index = nearest(r, z - spacing)
                up_index = nearest_index
            elif z_grid[nearest_index] == z:
                up_index = nearest(r, z + spacing)
                down_index = nearest(r, z - spacing)
                if not (up_index == nearest_index or down_index == nearest_index):
                    # We are exactly on a grid point and not at the domain edge:
                    # Need to account for the central stencil coefficient.
                    self._grad_fac_z[i] = 0.5
            else:
                raise off_grid_error(i, r, z)

            self._up_indices[i] = up_index
            self._down_indices[i] = down_index
            self._left_indices[i] = left_index
            self._right_indices[i] = right_index

    def perpendicular_gradient(self, array: xr.DataArray) -> xr.DataArray:
        """
        Return the perpendicular gradient of the given array as a vector.

        Requires ``setup_perpendicular_gradient`` to have been called. The
        result is a poloidal vector over "interp_points"; its ``norm``
        attribute is ``array.norm / R0`` if the input has a norm, else None.
        """
        assert (
            hasattr(self, "_up_indices") and hasattr(self, "_down_indices") and
            hasattr(self, "_left_indices") and hasattr(self, "_right_indices")
        ), \
            "Run 'lineout.setup_perpendicular_gradient(grid)' before calling."

        assert isinstance(array, xr.DataArray), "Provide data as xr.DataArray."
        assert "points" in array.dims, "Provide unstructured data."
        assert array.points.size == self.grid_size, \
            "Dimensions of provided array do not match grid size."

        # Take forward stencil gradient components
        grad_r = self._grad_fac_r / self.grid_spacing * (
            array.isel(points=self._right_indices) -
            array.isel(points=self._left_indices)
        )
        grad_z = self._grad_fac_z / self.grid_spacing * (
            array.isel(points=self._up_indices) -
            array.isel(points=self._down_indices)
        )

        # If no norm attribute given, assume normalized units
        grad_norm = array.norm / self.R0 if "norm" in array.attrs else None

        return poloidal_vector(grad_r.rename({"points": "interp_points"}),
                               grad_z.rename({"points": "interp_points"}),
                               attrs={"norm": grad_norm})

    def perpendicular_gradient_component(
        self,
        array: xr.DataArray,
        component="R",
    ) -> xr.DataArray:
        """
        Return a specific component of the perpendicular gradient.

        Component can either be a string specifying "R" or "Z", or
        alternatively an xarray DataArray representing a vector where to
        project the gradient to.

        Raises:
            RuntimeError: If a string component other than "R"/"Z" is given.
            TypeError: If component is neither a str nor an xr.DataArray.
        """
        grad = self.perpendicular_gradient(array)

        if isinstance(component, str):
            if component == "R":
                return grad.sel(vector="eR")
            if component == "Z":
                return grad.sel(vector="eZ")
            raise RuntimeError(f"Component {component} is not available "
                               "for perpendicular gradient.")

        if isinstance(component, xr.DataArray):
            assert "vector" in component.dims, \
                   "Component should be a vector with dimension vector " \
                   "in the DataArray."

            grad_proj = vector_dot(grad, component)

            if "norm" in array.attrs and "norm" in component.attrs:
                # The gradient's norm is stored on ``grad``; the dot product
                # result is not guaranteed to carry attributes.
                grad_proj.attrs["norm"] = component.norm * grad.norm

            return grad_proj

        raise TypeError(
            f"Component must be a str or xr.DataArray, got {type(component)!r}."
        )

    def radial_perpendicular_gradient(
        self,
        array: xr.DataArray,
        equi: EquilibriumBaseClass,
    ) -> xr.DataArray:
        """Return the radial component of the perpendicular gradient."""
        e_rho = equi.magfield_vector_radial(self.r_points, self.z_points,
                                            normalize=True)
        e_rho = e_rho.rename({"points": "interp_points"})

        return self.perpendicular_gradient_component(array, e_rho)