Source code for torx.specializations.grillix.trunk.map3d_m

"""Contains functionality to interface map 3D."""
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.sparse import csr_matrix

from torx import combine_obj_list
from ..trunk.map_m import Map
from ..operators.parallel_operators_m import (
    parallel_gradient_3D,
    parallel_divergence_3D,
    map_to_staggered_3D,
    map_to_canonical_3D
)
from torx.autodoc_decorators_m import autodoc_class
from functools import cached_property

@autodoc_class
class Map3D:
    """
    High-level interface to multiple grillix/trunk/map_plane*.nc files.

    One `Map` object is created per selected file. The objects are kept in an
    object-dtype `xr.DataArray` (`self.maps`) whose only dimension is `phi`.
    The `phi` coordinate is taken from `nplanes` evenly spaced values in
    [0, 2*pi), where `nplanes` is the number of files found in the directory.

    Parameters
    ----------
    - `dirpath` (Path or str): Directory containing map files
    - `planes` (int, slice, or list of int): Planes to load (0-indexed).
      If None, load all planes.
    """

    def __init__(self, dirpath: Path, planes=None):
        """Initialize the map 3D.

        Raises
        ------
        - `FileNotFoundError`: if no `map_plane*.nc` file exists in `dirpath`.
        """
        # Accept plain strings as well as Path objects.
        dirpath = Path(dirpath)
        self.dirpath = dirpath

        # A slice is stored untouched; anything else becomes a 1-D array so
        # that a single int still selects a length-1 `phi` axis.
        # NOTE(review): for planes=None this stores array([None], dtype=object);
        # it is never used for indexing in that case.
        if isinstance(planes, slice):
            self.planes = planes
        else:
            self.planes = np.atleast_1d(planes)

        # Find all files matching the pattern.
        # NOTE(review): the sort is lexicographic on the path; plane order is
        # only numeric if the file names are zero-padded - confirm.
        path_list = sorted(dirpath.glob("map_plane*.nc"))
        if not path_list:
            raise FileNotFoundError(
                "Error: No files matching 'map_plane*.nc' "
                "were found in the directory."
            )
        all_filepaths = np.atleast_1d(path_list)

        # Total number of planes on disk (not the number of selected planes).
        self.nplanes = len(all_filepaths)
        all_phi = np.linspace(0, 2 * np.pi, self.nplanes, endpoint=False)

        # Filter files based on planes.
        if planes is None:
            self.phi = all_phi
            selected_filepaths = all_filepaths
        else:
            self.phi = all_phi[self.planes]
            selected_filepaths = all_filepaths[self.planes]
        self.filepaths = xr.DataArray(
            selected_filepaths, dims="phi", coords={"phi": self.phi}
        )

        # Load maps: one Map object per selected file.
        self.maps = xr.apply_ufunc(
            self._create_map,
            self.filepaths,
            input_core_dims=[[]],
            output_core_dims=[[]],
            vectorize=True,
            output_dtypes=[object],
        ).assign_coords(phi=self.phi)

        # These values are read from the first loaded plane only.
        first_map = self.maps.isel(phi=0).item()
        self.dphi = first_map.dphi
        self.intorder = first_map.intorder
        self.xorder = first_map.xorder

    def _create_map(self, filepath: str) -> Map:
        """Create Map object from filepath (anything accepted by `Path`)."""
        return Map(Path(filepath))

    def get_plane(self, plane: int) -> Map:
        """Return map at plane (positional index along the loaded `phi` axis)."""
        return self.maps.isel(phi=plane).item()

    def __getitem__(self, plane: int) -> Map:
        """Allow indexing: map3d[0]."""
        return self.get_plane(plane)

    def __len__(self) -> int:
        """Return number of loaded planes."""
        return len(self.maps)

    def _combine_map_attribute(self, attribute_name: str) -> xr.DataArray:
        """Pass attribute `attribute_name` of every loaded map to `combine_obj_list`.

        The per-plane Map objects, the attribute name, the dimension name
        "phi" and the `phi` coordinate values are handed over unchanged.
        """
        return combine_obj_list(self.maps.values, attribute_name, "phi", self.phi)

    def _build_parallel_csr(self, attribute_name: str) -> xr.DataArray:
        """Apply the CSR building function in a vectorized manner.

        Returns an object-dtype DataArray along `phi`; each element is the
        result of `build_csr_from_map(map, attribute_name)` for one plane.
        """
        return xr.apply_ufunc(
            build_csr_from_map,
            self.maps,
            attribute_name,
            input_core_dims=[[], []],
            output_core_dims=[[]],
            vectorize=True,
            output_dtypes=[object],
        )

    @cached_property
    def dpar_fwd_half(self) -> xr.DataArray:
        """Parallel distance to the half forward toroidal plane."""
        return self._combine_map_attribute("dpar_fwd_half")

    @cached_property
    def dpar_bwd_half(self) -> xr.DataArray:
        """Parallel distance to the half reverse toroidal plane."""
        return self._combine_map_attribute("dpar_bwd_half")

    @cached_property
    def dpar_fwd_full(self) -> xr.DataArray:
        """Parallel distance to the forward toroidal plane."""
        return self._combine_map_attribute("dpar_fwd_full")

    @cached_property
    def dpar_bwd_full(self) -> xr.DataArray:
        """Parallel distance to the reverse toroidal plane."""
        return self._combine_map_attribute("dpar_bwd_full")

    @cached_property
    def fluxbox_vol(self) -> xr.DataArray:
        """Volume of the flux box defined around the point (+/- 1/2 phi)."""
        return self._combine_map_attribute("fluxbox_vol")

    @cached_property
    def dpar_fwd_half_stag(self) -> xr.DataArray:
        """Parallel distance to the half forward toroidal plane (`_stag` variant)."""
        return self._combine_map_attribute("dpar_fwd_half_stag")

    @cached_property
    def dpar_bwd_half_stag(self) -> xr.DataArray:
        """Parallel distance to the half reverse toroidal plane (`_stag` variant)."""
        return self._combine_map_attribute("dpar_bwd_half_stag")

    @cached_property
    def dpar_fwd_full_stag(self) -> xr.DataArray:
        """Parallel distance to the forward toroidal plane (`_stag` variant)."""
        return self._combine_map_attribute("dpar_fwd_full_stag")

    @cached_property
    def dpar_bwd_full_stag(self) -> xr.DataArray:
        """Parallel distance to the reverse toroidal plane (`_stag` variant)."""
        return self._combine_map_attribute("dpar_bwd_full_stag")

    @cached_property
    def fluxbox_vol_stag(self) -> xr.DataArray:
        """Volume of the flux box defined around the point (`_stag` variant)."""
        return self._combine_map_attribute("fluxbox_vol_stag")

    @cached_property
    def mfwd(self) -> xr.DataArray:
        """Map matrix in forward direction (trace of +dphi/2)."""
        return self._build_parallel_csr("mfwd")

    @cached_property
    def mbwd(self) -> xr.DataArray:
        """Map matrix in backward direction (trace of -dphi/2)."""
        return self._build_parallel_csr("mbwd")

    @cached_property
    def mfwd_stag(self) -> xr.DataArray:
        """Map matrix in forward direction (trace of +dphi/2) for a plane."""
        return self._build_parallel_csr("mfwd_stag")

    @cached_property
    def mbwd_stag(self) -> xr.DataArray:
        """Map matrix in backward direction (trace of -dphi/2)."""
        return self._build_parallel_csr("mbwd_stag")

    @cached_property
    def qfwd(self) -> xr.DataArray:
        """Parallel gradient matrix in forward direction (trace +dphi/2)."""
        return self._build_parallel_csr("qfwd")

    @cached_property
    def qbwd(self) -> xr.DataArray:
        """Parallel gradient matrix in backward direction (trace -dphi/2)."""
        return self._build_parallel_csr("qbwd")

    @cached_property
    def pfwd(self) -> xr.DataArray:
        """Parallel divergence matrix in forward(phi/2) direction."""
        return self._build_parallel_csr("pfwd")

    @cached_property
    def pbwd(self) -> xr.DataArray:
        """Parallel divergence matrix in backward(-phi/2) direction."""
        return self._build_parallel_csr("pbwd")

    @cached_property
    def mfwd_full(self) -> xr.DataArray:
        """Map matrix in forward(phi) direction (trace of dphi)."""
        return self._build_parallel_csr("mfwd_full")

    @cached_property
    def mbwd_full(self) -> xr.DataArray:
        """Map matrix in backward(phi) direction (trace of -dphi)."""
        return self._build_parallel_csr("mbwd_full")

    @cached_property
    def mfwd_full_stag(self) -> xr.DataArray:
        """Map matrix in forward(phi) direction (trace of dphi), `_stag` variant."""
        return self._build_parallel_csr("mfwd_full_stag")

    @cached_property
    def mbwd_full_stag(self) -> xr.DataArray:
        """Map matrix in backward(phi) direction (trace of -dphi), `_stag` variant."""
        return self._build_parallel_csr("mbwd_full_stag")

    def clear_cached_properties(self):
        """Clear all attributes decorated with @cached_property.

        The whole MRO is inspected so that cached properties inherited by a
        subclass are cleared too. Properties that were never computed are
        simply skipped.
        """
        for klass in type(self).__mro__:
            for name, value in vars(klass).items():
                if isinstance(value, cached_property):
                    # A computed cached_property lives in the instance
                    # __dict__; removing it forces recomputation on next use.
                    self.__dict__.pop(name, None)

    def parallel_gradient(self, input_array: xr.DataArray) -> xr.DataArray:
        """Parallel gradient (delegates to `parallel_gradient_3D`)."""
        return parallel_gradient_3D(self, input_array)

    def parallel_divergence(self, input_array: xr.DataArray) -> xr.DataArray:
        """Parallel divergence (delegates to `parallel_divergence_3D`)."""
        return parallel_divergence_3D(self, input_array)

    def map_to_staggered(self, input_array: xr.DataArray) -> xr.DataArray:
        """Map to staggered plane (delegates to `map_to_staggered_3D`)."""
        return map_to_staggered_3D(self, input_array)

    def map_to_canonical(self, input_array: xr.DataArray) -> xr.DataArray:
        """Map to canonical plane (delegates to `map_to_canonical_3D`)."""
        return map_to_canonical_3D(self, input_array)
def build_csr_from_map(map_object, attr_name: str) -> csr_matrix:
    """Build a SciPy CSR matrix from a matrix attribute of a map object.

    Parameters
    ----------
    - `map_object`: Object exposing the attribute named `attr_name`.
    - `attr_name` (str): Name of the matrix attribute to convert.

    The attribute must provide `val` (stored values), `j` (column indices),
    `i` (row pointer), `ndim` (number of rows) and `ncol` (number of columns).
    `i` and `j` are shifted by -1 before being handed to `csr_matrix`, i.e.
    they are treated as 1-based.

    Returns
    -------
    - `csr_matrix` of shape `(ndim, ncol)`.

    Raises
    ------
    - `AttributeError`: if `map_object` has no attribute `attr_name`.
    """
    # Access the required matrix dynamically.
    try:
        matrix = getattr(map_object, attr_name)
    except AttributeError as err:
        # Chain explicitly so the original lookup failure stays visible.
        raise AttributeError(
            f"Map instance does not have attribute: {attr_name}"
        ) from err

    values = matrix.val
    # np.asarray lets plain sequences work, not only array-like objects that
    # already support `- 1`.
    column_indices = np.asarray(matrix.j) - 1
    row_pointer = np.asarray(matrix.i) - 1
    shape = (matrix.ndim, matrix.ncol)

    return csr_matrix((values, column_indices, row_pointer), shape=shape)