"""Contains functionality to interface map 3D."""
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.sparse import csr_matrix
from torx import combine_obj_list
from ..trunk.map_m import Map
from ..operators.parallel_operators_m import (
parallel_gradient_3D,
parallel_divergence_3D,
map_to_staggered_3D,
map_to_canonical_3D
)
from torx.autodoc_decorators_m import autodoc_class
from functools import cached_property
@autodoc_class
class Map3D:
    """
    High-level interface to multiple grillix/trunk/map_plane*.nc files.

    Each selected file is wrapped in a `Map` and the maps are stored in an
    object-dtype ``xr.DataArray`` along a ``phi`` dimension. Toroidal angles
    are ``linspace(0, 2*pi, nplanes, endpoint=False)`` over *all* files
    found, so a plane keeps its angle even when only a subset is loaded.

    Parameters
    ----------
    - `dirpath` (Path or str): Directory containing map files
    - `planes` (int, slice, or list of int): Planes to load (0-indexed).
      If None, load all planes.
    """

    def __init__(self, dirpath: Path, planes=None):
        """Initialize the map 3D.

        Raises
        ------
        FileNotFoundError
            If `dirpath` contains no ``map_plane*.nc`` file.
        IndexError
            If `planes` is out of range or selects no plane at all.
        """
        # Accept plain strings as well; Path(Path(...)) is a no-op.
        self.dirpath = Path(dirpath)
        # Slices are kept as-is; everything else (ints, lists, None) is made
        # 1-D so that indexing always yields a 1-D selection along ``phi``.
        if isinstance(planes, slice):
            self.planes = planes
        else:
            self.planes = np.atleast_1d(planes)

        all_filepaths = self._find_map_files()
        self.nplanes = len(all_filepaths)  # Total number of planes on disk
        all_phi = np.linspace(0, 2 * np.pi, self.nplanes, endpoint=False)

        # Select the requested planes (all of them when planes is None).
        selection = slice(None) if planes is None else self.planes
        self.phi = all_phi[selection]
        self.filepaths = xr.DataArray(
            all_filepaths[selection],
            dims="phi",
            coords={"phi": self.phi},
        )
        if self.filepaths.size == 0:
            raise IndexError(
                f"No planes selected from {self.dirpath} (planes={planes!r})."
            )

        # Load one Map per selected file (object-dtype array of Map objects).
        self.maps = xr.apply_ufunc(
            self._create_map,
            self.filepaths,
            input_core_dims=[[]],
            output_core_dims=[[]],
            vectorize=True,
            output_dtypes=[object],
        ).assign_coords(phi=self.phi)

        # Discretisation parameters are read from the first loaded plane only.
        first_map = self.maps.isel(phi=0).item()
        self.dphi = first_map.dphi
        self.intorder = first_map.intorder
        self.xorder = first_map.xorder

    @staticmethod
    def _plane_sort_key(path: Path) -> tuple:
        """Sort key ordering map files by the integer following ``map_plane``.

        A plain lexicographic sort puts ``map_plane10.nc`` before
        ``map_plane2.nc`` when names are not zero-padded; for zero-padded names
        both orderings agree. Names without a purely numeric suffix sort after
        the numbered ones, by name.
        """
        suffix = path.stem[len("map_plane"):]
        if suffix.isdecimal():
            return (0, int(suffix), path.name)
        return (1, 0, path.name)

    def _find_map_files(self) -> np.ndarray:
        """Return the ``map_plane*.nc`` files of ``self.dirpath`` in plane order."""
        path_list = sorted(
            self.dirpath.glob("map_plane*.nc"), key=self._plane_sort_key
        )
        if not path_list:
            raise FileNotFoundError(
                "Error: No files matching 'map_plane*.nc' were found in the "
                f"directory {self.dirpath}."
            )
        return np.atleast_1d(path_list)

    def _create_map(self, filepath: str) -> Map:
        """Create Map object from filepath (a str or Path)."""
        return Map(Path(filepath))

    def get_plane(self, plane: int) -> Map:
        """Return the map at position `plane` among the *loaded* planes.

        Note that this is a positional index into the selection made with
        `planes` at construction, not the global plane number on disk.
        """
        return self.maps.isel(phi=plane).item()

    def __getitem__(self, plane: int) -> Map:
        """Allow indexing: map3d[0] (same semantics as `get_plane`)."""
        return self.get_plane(plane)

    def __len__(self) -> int:
        """Return number of loaded planes."""
        return len(self.maps)

    def _build_parallel_csr(self, attribute_name: str) -> xr.DataArray:
        """Build one CSR matrix per loaded plane from Map attribute `attribute_name`."""
        return xr.apply_ufunc(
            build_csr_from_map,
            self.maps,
            attribute_name,
            input_core_dims=[[], []],
            output_core_dims=[[]],
            vectorize=True,
            output_dtypes=[object],
        )

    def _combine_map_attribute(self, attribute_name: str) -> xr.DataArray:
        """Stack attribute `attribute_name` of every loaded Map along ``phi``."""
        return combine_obj_list(self.maps.values, attribute_name, "phi", self.phi)

    @cached_property
    def dpar_fwd_half(self) -> xr.DataArray:
        """Parallel distance to the half forward toroidal plane."""
        return self._combine_map_attribute("dpar_fwd_half")

    @cached_property
    def dpar_bwd_half(self) -> xr.DataArray:
        """Parallel distance to the half reverse toroidal plane."""
        return self._combine_map_attribute("dpar_bwd_half")

    @cached_property
    def dpar_fwd_full(self) -> xr.DataArray:
        """Parallel distance to the forward toroidal plane."""
        return self._combine_map_attribute("dpar_fwd_full")

    @cached_property
    def dpar_bwd_full(self) -> xr.DataArray:
        """Parallel distance to the reverse toroidal plane."""
        return self._combine_map_attribute("dpar_bwd_full")

    @cached_property
    def fluxbox_vol(self) -> xr.DataArray:
        """Volume of the flux box defined around the point (+/- 1/2 phi)."""
        return self._combine_map_attribute("fluxbox_vol")

    @cached_property
    def dpar_fwd_half_stag(self) -> xr.DataArray:
        """Staggered-grid counterpart of `dpar_fwd_half`."""
        return self._combine_map_attribute("dpar_fwd_half_stag")

    @cached_property
    def dpar_bwd_half_stag(self) -> xr.DataArray:
        """Staggered-grid counterpart of `dpar_bwd_half`."""
        return self._combine_map_attribute("dpar_bwd_half_stag")

    @cached_property
    def dpar_fwd_full_stag(self) -> xr.DataArray:
        """Staggered-grid counterpart of `dpar_fwd_full`."""
        return self._combine_map_attribute("dpar_fwd_full_stag")

    @cached_property
    def dpar_bwd_full_stag(self) -> xr.DataArray:
        """Staggered-grid counterpart of `dpar_bwd_full`."""
        return self._combine_map_attribute("dpar_bwd_full_stag")

    @cached_property
    def fluxbox_vol_stag(self) -> xr.DataArray:
        """Staggered-grid counterpart of `fluxbox_vol` (+/- 1/2 phi)."""
        return self._combine_map_attribute("fluxbox_vol_stag")

    @cached_property
    def mfwd(self) -> xr.DataArray:
        """Map matrix in forward direction (trace of +dphi/2)."""
        return self._build_parallel_csr("mfwd")

    @cached_property
    def mbwd(self) -> xr.DataArray:
        """Map matrix in backward direction (trace of -dphi/2)."""
        return self._build_parallel_csr("mbwd")

    @cached_property
    def mfwd_stag(self) -> xr.DataArray:
        """Map matrix in forward direction (trace of +dphi/2), staggered grid."""
        return self._build_parallel_csr("mfwd_stag")

    @cached_property
    def mbwd_stag(self) -> xr.DataArray:
        """Map matrix in backward direction (trace of -dphi/2), staggered grid."""
        return self._build_parallel_csr("mbwd_stag")

    @cached_property
    def qfwd(self) -> xr.DataArray:
        """Parallel gradient matrix in forward direction (trace +dphi/2)."""
        return self._build_parallel_csr("qfwd")

    @cached_property
    def qbwd(self) -> xr.DataArray:
        """Parallel gradient matrix in backward direction (trace -dphi/2)."""
        return self._build_parallel_csr("qbwd")

    @cached_property
    def pfwd(self) -> xr.DataArray:
        """Parallel divergence matrix in forward(phi/2) direction."""
        return self._build_parallel_csr("pfwd")

    @cached_property
    def pbwd(self) -> xr.DataArray:
        """Parallel divergence matrix in backward(-phi/2) direction."""
        return self._build_parallel_csr("pbwd")

    @cached_property
    def mfwd_full(self) -> xr.DataArray:
        """Map matrix in forward(phi) direction (trace of dphi)."""
        return self._build_parallel_csr("mfwd_full")

    @cached_property
    def mbwd_full(self) -> xr.DataArray:
        """Map matrix in backward(phi) direction (trace of -dphi)."""
        return self._build_parallel_csr("mbwd_full")

    @cached_property
    def mfwd_full_stag(self) -> xr.DataArray:
        """Map matrix in forward(phi) direction (trace of dphi), staggered grid."""
        return self._build_parallel_csr("mfwd_full_stag")

    @cached_property
    def mbwd_full_stag(self) -> xr.DataArray:
        """Map matrix in backward(phi) direction (trace of -dphi), staggered grid."""
        return self._build_parallel_csr("mbwd_full_stag")

    def clear_cached_properties(self):
        """Clear all attributes decorated with @cached_property.

        The whole MRO is walked so that cached properties defined on Map3D are
        also cleared for instances of subclasses (``type(self).__dict__`` only
        holds the subclass's own members).
        """
        for klass in type(self).__mro__:
            for name, member in vars(klass).items():
                if isinstance(member, cached_property):
                    # A computed cached_property is stored in the instance __dict__;
                    # dropping it forces recomputation on next access.
                    self.__dict__.pop(name, None)

    def parallel_gradient(self, input_array: xr.DataArray) -> xr.DataArray:
        """Parallel gradient."""
        return parallel_gradient_3D(self, input_array)

    def parallel_divergence(self, input_array: xr.DataArray) -> xr.DataArray:
        """Parallel divergence."""
        return parallel_divergence_3D(self, input_array)

    def map_to_staggered(self, input_array: xr.DataArray) -> xr.DataArray:
        """Map to staggered plane."""
        return map_to_staggered_3D(self, input_array)

    def map_to_canonical(self, input_array: xr.DataArray) -> xr.DataArray:
        """Map to canonical plane."""
        return map_to_canonical_3D(self, input_array)


def build_csr_from_map(map_object, attr_name: str) -> csr_matrix:
    """Build a SciPy CSR matrix from a sparse-matrix attribute of a Map.

    The attribute named `attr_name` must expose 1-based CSR arrays ``val``
    (non-zero values), ``j`` (column indices) and ``i`` (row pointers), plus
    ``ndim`` (number of rows) and ``ncol`` (number of columns).

    Parameters
    ----------
    map_object : Map
        Object holding the sparse matrix as attribute `attr_name`.
    attr_name : str
        Name of the attribute to convert, e.g. ``"mfwd"``.

    Returns
    -------
    csr_matrix
        Matrix of shape ``(ndim, ncol)`` using SciPy's 0-based indices.

    Raises
    ------
    AttributeError
        If `map_object` has no attribute `attr_name`.
    """
    try:
        matrix = getattr(map_object, attr_name)
    except AttributeError as err:
        raise AttributeError(
            f"Map instance does not have attribute: {attr_name}"
        ) from err

    # Shift the 1-based index arrays to SciPy's 0-based convention.
    # np.asarray lets plain sequences work as well as NumPy/xarray arrays.
    values = np.asarray(matrix.val)
    indices_0based = np.asarray(matrix.j) - 1
    indptr_0based = np.asarray(matrix.i) - 1
    rows = matrix.ndim
    cols = matrix.ncol
    return csr_matrix(
        (values, indices_0based, indptr_0based), shape=(rows, cols)
    )