"""Classes for tracing vector fields."""
import numpy as np
import xarray as xr
from abc import ABC
from typing import Union, List, Callable
from scipy.integrate import solve_ivp
from scipy.integrate._ivp.ivp import OdeResult
from .spline_interpolator_2d_m import SplineInterpolator2D
from torx.geometry.polygon_2d_m import Polygon2D
from torx.autodoc_decorators_m import autodoc_class
@autodoc_class
class VectorFieldTracer(ABC):
    """
    Allows to trace a given vector field.

    The vector field is an ``xr.DataArray`` with a ``vector`` dimension that
    holds the toroidal components ``eR``, ``eZ`` and ``ePhi`` on an (R, Z)
    grid. Every component is wrapped in its own ``SplineInterpolator2D``, and
    traces are computed with ``scipy.integrate.solve_ivp``.
    """

    def __init__(
        self,
        vector_field: xr.DataArray,
    ):
        """
        Initialize the vectorfield tracer.

        Initialize separate SplineInterpolator2D instances for all
        components of the provided vector field.

        Args:
            vector_field: Vector field with a ``vector`` dimension containing
                the entries ``eR``, ``eZ`` and ``ePhi``, and with ``R`` and
                ``Z`` coordinates that span the spline domain.
        """
        assert "vector" in vector_field.dims, \
            "Dimension 'vector' is missing in DataArray."
        # Require every toroidal component to be present. A plain
        # ``list in DataArray`` test compares element-wise and already passes
        # when a single position matches, so use set inclusion instead.
        available_components = set(vector_field.vector.values.tolist())
        assert {"eR", "eZ", "ePhi"} <= available_components, \
            "Provided vector field is not in toroidal coordinates."

        self.vector_field = vector_field
        self.r_basis = self.vector_field.R.values
        self.z_basis = self.vector_field.Z.values
        self.r_min, self.r_max = self.r_basis.min(), self.r_basis.max()
        self.z_min, self.z_max = self.z_basis.min(), self.z_basis.max()
        # Rectangular (R, Z) region covered by the splines; used by the
        # out-of-domain event.
        self.spline_domain_polygon = Polygon2D(
            [self.r_min, self.r_max, self.r_max, self.r_min],
            [self.z_min, self.z_min, self.z_max, self.z_max]
        )

        self.vector_field_r = vector_field.sel(vector="eR")
        self.vector_field_z = vector_field.sel(vector="eZ")
        self.vector_field_phi = vector_field.sel(vector="ePhi")
        self.interpolator_r = SplineInterpolator2D(self.vector_field_r)
        self.interpolator_z = SplineInterpolator2D(self.vector_field_z)
        self.interpolator_phi = SplineInterpolator2D(self.vector_field_phi)

        # Default to forward tracing so that ``direction`` and the
        # integration equations are usable before any integration ran.
        self.direction = "fwd"

    @property
    def direction(self) -> str:
        """Direction to trace the vector field, 'fwd' or 'bwd'."""
        return self._direction

    @direction.setter
    def direction(self, direction: str):
        """
        Set the trace direction and the matching sign factor.

        Raises:
            ValueError: If ``direction`` is neither 'fwd' nor 'bwd'.
        """
        if direction == "fwd":
            self._dir_fac = 1
        elif direction == "bwd":
            self._dir_fac = -1
        else:
            raise ValueError("Choose 'fwd' or 'bwd' as direction!")
        self._direction = direction

    def toroidal_integration(
        self,
        r_initial: float,
        z_initial: float,
        phi_initial: float,
        phi_to_max: float,
        tolerance: float = 1e-4,
        method: str = "DOP853",
        direction: str = "fwd",
        check_out_of_domain: bool = True,
        events: Union[List[Callable], Callable] = None,
        **integrator_kwargs
    ) -> OdeResult:
        """
        Perform an integration with phi as the independent variable.

        Args:
            r_initial: Initial R position (a single point).
            z_initial: Initial Z position (a single point).
            phi_initial: Initial toroidal angle.
            phi_to_max: Toroidal angle covered by the trace. The integration
                runs up to ``phi_initial + phi_to_max`` for 'fwd' and
                ``phi_initial - phi_to_max`` for 'bwd'.
            tolerance: Used as both ``rtol`` and ``atol`` of ``solve_ivp``.
            method: Integration method passed to ``solve_ivp``.
            direction: 'fwd' or 'bwd'.
            check_out_of_domain: If True, stop the trace once it leaves the
                rectangular (R, Z) domain of the splines.
            events: Additional ``solve_ivp`` event(s). A list passed here is
                not modified.
            **integrator_kwargs: Forwarded to ``solve_ivp``.

        Returns:
            The ``solve_ivp`` result (with dense output). ``y[0]`` and
            ``y[1]`` hold the R and Z positions along the trace. ``y[2]``
            holds the accumulated fieldline length; its sign depends on the
            sign of the toroidal component and on the integration direction.

        Raises:
            RuntimeError: If the initial state is not a single point.
            ValueError: If ``direction`` is invalid.
        """
        self.direction = direction
        self._check_initial_state(r_initial, z_initial, phi_initial)

        # Handle out_of_bounds event and check format of the provided events
        if check_out_of_domain:
            events = self._add_event(events, self._out_of_domain)

        solution = solve_ivp(
            fun=self.toroidal_integration_equation,
            t_span=[phi_initial, phi_initial + self._dir_fac * phi_to_max],
            y0=np.array([r_initial, z_initial, 0.0]),
            events=events,
            dense_output=True,
            rtol=tolerance,
            atol=tolerance,
            method=method,
            **integrator_kwargs,
        )
        return solution

    def toroidal_integration_equation(self, phi, state: np.ndarray) -> np.ndarray:
        """
        Equation that is required to integrate toroidally around the torus.

        With phi as the independent variable this returns
        ``(R v_R / v_phi, R v_Z / v_phi, R |v| / v_phi)``. The third
        element is the rate of change of the fieldline length.

        Raises:
            ZeroDivisionError: If the toroidal component vanishes at the
                current (R, Z) position.
        """
        r, z = state[0], state[1]
        vec_r = self.interpolator_r(r, z, grid=False)
        vec_z = self.interpolator_z(r, z, grid=False)
        vec_phi = self.interpolator_phi(r, z, grid=False)

        if np.isclose(vec_phi, 0.0):
            raise ZeroDivisionError(
                f"Vanishing toroidal component of the vector field at "
                f"{r, z, phi}, tracing not possible."
            )

        d_state = np.zeros_like(state)
        d_state[0] = r * vec_r / vec_phi
        d_state[1] = r * vec_z / vec_phi
        d_state[2] = r / vec_phi * np.sqrt(vec_r**2 + vec_z**2 + vec_phi**2)
        return d_state

    def rz_integration(
        self,
        r_initial: float,
        z_initial: float,
        t_max: float,
        tolerance: float = 1e-4,
        method: str = "DOP853",
        direction: str = "fwd",
        check_out_of_domain: bool = True,
        events: Union[List[Callable], Callable] = None,
        **integrator_kwargs,
    ) -> OdeResult:
        """
        Perform a parametric integration in the RZ plane.

        Args:
            r_initial: Initial R position (a single point).
            z_initial: Initial Z position (a single point).
            t_max: Upper bound of the integration parameter; must be > 0.
            tolerance: Used as both ``rtol`` and ``atol`` of ``solve_ivp``.
            method: Integration method passed to ``solve_ivp``.
            direction: 'fwd' follows the in-plane field, 'bwd' opposes it.
            check_out_of_domain: If True, stop the trace once it leaves the
                rectangular (R, Z) domain of the splines.
            events: Additional ``solve_ivp`` event(s). A list passed here is
                not modified.
            **integrator_kwargs: Forwarded to ``solve_ivp``.

        Returns:
            The ``solve_ivp`` result (with dense output). ``y[0]`` and
            ``y[1]`` hold the R and Z positions along the trace, and ``y[2]``
            the accumulated in-plane fieldline length.

        Raises:
            RuntimeError: If the initial state is not a single point.
            ValueError: If ``direction`` is invalid.
            AssertionError: If ``t_max`` is not positive.
        """
        self.direction = direction
        self._check_initial_state(r_initial, z_initial, 0.0)
        assert t_max > 0.0, \
            "Should always use a positive value for t_max"

        eps = np.finfo(np.float64).eps

        # Returns -1 if the in-plane magnitude of the vector field at the
        # current location is sufficiently close to zero, otherwise 1.
        # The magnitude is evaluated from the field itself: state[2] is the
        # accumulated length, which is zero only at the start and never
        # decreases, so an event on it could never trigger.
        def vanishing_vector_magnitude(_, state):
            r, z = state[0], state[1]
            vec_r = self.interpolator_r(r, z, grid=False)
            vec_z = self.interpolator_z(r, z, grid=False)
            magnitude = np.sqrt(vec_r**2 + vec_z**2)
            return -1 if np.isclose(magnitude, 0.0, atol=eps) else 1

        vanishing_vector_magnitude.terminal = True
        vanishing_vector_magnitude.direction = -1.0

        # Handle events to be passed and check format
        if check_out_of_domain:
            new_events = [self._out_of_domain, vanishing_vector_magnitude]
        else:
            new_events = [vanishing_vector_magnitude]
        events = self._add_event(events, new_events)

        solution = solve_ivp(
            fun=self.rz_integration_equation,
            t_span=[0.0, t_max],
            y0=np.array([r_initial, z_initial, 0.0]),
            events=events,
            dense_output=True,
            rtol=tolerance,
            atol=tolerance,
            method=method,
            **integrator_kwargs,
        )
        return solution

    def rz_integration_equation(self, _, state: np.ndarray) -> np.ndarray:
        """
        Equation that is required to integrate parametrically in the RZ plane.

        Uses parameter t for integration. Returns the in-plane field
        components (sign-flipped for 'bwd') and, as third element, the
        in-plane magnitude, which integrates to the fieldline length.
        """
        r, z = state[0], state[1]
        vec_r = self.interpolator_r(r, z, grid=False)
        vec_z = self.interpolator_z(r, z, grid=False)
        vec_magnitude = np.sqrt(vec_r**2 + vec_z**2)

        d_state = np.zeros_like(state)
        d_state[0] = self._dir_fac * vec_r
        d_state[1] = self._dir_fac * vec_z
        d_state[2] = vec_magnitude
        return d_state

    def _check_initial_state(
        self, r_initial: float, z_initial: float, phi_initial: float
    ) -> None:
        """
        Make sure that the initial state is a single point.

        Raises:
            RuntimeError: If any coordinate holds more than one value.
        """
        if (np.array(r_initial).size == 1) and \
                (np.array(z_initial).size == 1) and \
                (np.array(phi_initial).size == 1):
            return None
        else:
            raise RuntimeError("Tracing can only be called for a single point!")

    def _add_event(self, events, new_events):
        """
        Add new event(s) for 'solve_ivp' in the correct format.

        Args:
            events: None, a single callable, or a list/tuple of callables.
            new_events: A callable or a list/tuple of callables to add.

        Returns:
            A single callable if exactly one event results, otherwise a new
            list. A list passed as ``events`` is never modified, so reusing
            it across calls does not accumulate internal events.

        Raises:
            TypeError: If ``events`` is not None, a callable, or a
                list/tuple.
        """
        if not isinstance(new_events, (list, tuple)):
            new_events = [new_events]

        for new_event in new_events:
            if events is None:
                events = new_event
            elif isinstance(events, Callable):
                events = [events, new_event]
            elif isinstance(events, (list, tuple)):
                # Build a new list instead of appending in place.
                events = [*events, new_event]
            else:
                raise TypeError(
                    "'events' must be either a callable or a list of callables"
                )
        return events

    def _out_of_domain(self, _, state):
        """
        Check if integrator state is out of bounds.

        Return 1 if the state of the integrator is inside the boundary
        polygon, otherwise -1.
        """
        r, z = state[0], state[1]
        return 1 if self.spline_domain_polygon.point_inside(r, z) else -1

    # solve_ivp reads these attributes (through the bound method): stop the
    # integration when the value crosses from inside (1) to outside (-1).
    _out_of_domain.terminal = True
    _out_of_domain.direction = -1