"""Defines 3D grids made up of collections of 2D grids."""
from abc import ABC
import numpy as np
import xarray as xr
from pathlib import Path
from typing import Optional, Sequence, Union
from torx import Quantity, combine_obj_list
from torx.units_handling_m import check_units
from torx.autodoc_decorators_m import autodoc_class
from torx.grid.grid_2d_m import Grid2D
from functools import cached_property
@autodoc_class
class Grid3D(ABC):
    """Represents a 3D grid, composed of an array of 2D grids at each phi.

    The grid is held in ``grid_array``, an object-dtype
    :class:`xarray.DataArray` whose elements are the per-plane 2D grids.
    Attribute and item lookups that fail normally are forwarded to that
    array (see ``__getattr__`` / ``__getitem__``), so e.g.
    ``grid.isel(phi=0)`` works directly on the 3D grid.
    """

    def __init__(self,
                 grid: xr.DataArray):
        """
        Initialize the 3D grid.

        Parameters
        ----------
        grid : xr.DataArray
            Object-dtype array of per-plane 2D grids. Every element must
            expose ``size`` and ``shape`` (used for ``npoints``/``shape``).
        """
        self.grid_array = grid
        # Snapshot the xarray metadata so the object can be inspected like a
        # DataArray. NOTE(review): these are not refreshed if ``grid_array``
        # is reassigned later.
        self.coords = grid.coords
        self.dims = grid.dims
        self.attrs = grid.attrs
        self.name = 'Grid3D'
        self.encoding = grid.encoding
        self.sizes = grid.sizes
        # Number of points and shape of each 2D grid, laid out like ``grid``.
        self.npoints = xr.apply_ufunc(
            lambda obj: obj.size,
            grid,
            vectorize=True,
            dask="parallelized",  # only matters if the grid objects are dask-backed
            output_dtypes=[int],
        )
        self.shape = xr.apply_ufunc(
            lambda obj: obj.shape,
            grid,
            vectorize=True,
            dask="parallelized",  # only matters if the grid objects are dask-backed
            output_dtypes=[tuple],
        )

    @classmethod
    def from_rz(cls,
                r_unstructured,
                z_unstructured):
        """
        Create the 3D grid from xarrays of unstructured R and Z grids.

        Parameters
        ----------
        r_unstructured, z_unstructured : xr.DataArray
            Unstructured R and Z values. Both must carry a ``dim_RZ``
            dimension; one :class:`Grid2D` is built for every combination of
            the remaining dimensions.

        Raises
        ------
        TypeError
            If either input is not an :class:`xarray.DataArray`.
        """
        # Explicit check rather than ``assert`` so validation survives
        # ``python -O``.
        if not (isinstance(r_unstructured, xr.DataArray)
                and isinstance(z_unstructured, xr.DataArray)):
            raise TypeError("Unstructured grids must be Xarray DataArrays!")
        grid = xr.apply_ufunc(Grid2D, r_unstructured, z_unstructured,
                              input_core_dims=[["dim_RZ"], ["dim_RZ"]],
                              output_core_dims=[[]],
                              vectorize=True,
                              output_dtypes=[object])
        return cls(grid)

    @classmethod
    def from_multigrid_files(cls,
                             dirpath: Union[str, Path],
                             planes: Optional[Union[int, Sequence[int]]] = None,
                             staggered: bool = False
                             ):
        """
        Create a 3D grid object from a directory of multigrid files.

        Files matching ``multigrids_plane*.nc`` are ordered by the numbers in
        their names, and each is assigned a toroidal angle ``phi`` evenly
        spaced over ``[0, 2*pi)``.

        Parameters
        ----------
        dirpath : str or Path
            Path to the directory containing multigrid files.
        planes : int, sequence of int, or None
            Index or indices (positions in the ordered file list) of the
            planes to load. If None, all planes are loaded.
        staggered : bool
            Passed unchanged to ``grid_2d_from_multigrid_file`` for every
            plane.

        Raises
        ------
        FileNotFoundError
            If no ``multigrids_plane*.nc`` file exists in ``dirpath``.
        """
        import re
        from torx.specializations.grillix import grid_2d_from_multigrid_file

        def _natural_key(path):
            # Compare digit runs numerically so "plane10" sorts after
            # "plane2"; for zero-padded names this matches plain sorting.
            # The file name itself breaks ties (e.g. "01" vs "1").
            parts = re.split(r"(\d+)", path.name)
            numeric = [int(part) if i % 2 else part
                       for i, part in enumerate(parts)]
            return numeric, path.name

        # Accept plain strings as well as Path objects.
        dirpath = Path(dirpath)
        path_list = sorted(dirpath.glob("multigrids_plane*.nc"), key=_natural_key)
        if not path_list:
            raise FileNotFoundError("Error: No files matching 'multigrids_plane*.nc' were found in the directory.")
        all_filepaths = np.atleast_1d(path_list)
        all_phi = np.linspace(0, 2*np.pi, len(all_filepaths), endpoint=False)
        # Select the requested planes (all of them by default).
        if planes is None:
            phi = all_phi
            selected = all_filepaths
        else:
            planes = np.atleast_1d(planes)
            phi = all_phi[planes]
            selected = all_filepaths[planes]
        filepaths = xr.DataArray(selected, dims="phi")
        grid = xr.apply_ufunc(grid_2d_from_multigrid_file, filepaths, staggered,
                              input_core_dims=[[], []],
                              output_core_dims=[[]],
                              vectorize=True,
                              output_dtypes=[object]).assign_coords(phi=phi)
        return cls(grid)

    @property
    def grid_array(self):
        """Return the arrays of values in the grid."""
        return self._grid_array

    @grid_array.setter
    def grid_array(self, value):
        """Set the arrays of values in the grid and drop values cached from it."""
        self._grid_array = value
        # r_u, z_u, r_s and z_s are derived from grid_array; drop stale copies.
        self.clear_cached_properties()

    def _combine_planes(self, attribute: str) -> xr.DataArray:
        """Combine ``attribute`` of every 2D grid along a ``phi`` dimension."""
        return combine_obj_list(self.grid_array.values, attribute, "phi",
                                self.grid_array.coords["phi"].values)

    @cached_property
    def r_u(self) -> xr.DataArray:
        """Unstructured R values."""
        return self._combine_planes("r_u")

    @cached_property
    def z_u(self) -> xr.DataArray:
        """Unstructured Z values."""
        return self._combine_planes("z_u")

    @cached_property
    def r_s(self) -> xr.DataArray:
        """Structured R values."""
        return self._combine_planes("r_s")

    @cached_property
    def z_s(self) -> xr.DataArray:
        """Structured Z values."""
        return self._combine_planes("z_s")

    def clear_cached_properties(self):
        """Clear all attributes decorated with @cached_property.

        The whole MRO is scanned, so cached properties defined on a base
        class are cleared when this is called on a subclass instance.
        """
        instance_dict = vars(self)
        seen = set()
        for klass in type(self).__mro__:
            for name, value in vars(klass).items():
                # Only the first (effective) definition along the MRO counts.
                if name in seen:
                    continue
                seen.add(name)
                # A computed cached_property lives in the instance __dict__.
                if isinstance(value, cached_property):
                    instance_dict.pop(name, None)

    def __getattr__(self, name):
        """
        Allow use of isel/sel directly on the 3D grid.

        Only called when normal attribute lookup fails. ``isel``/``sel`` are
        forwarded to the underlying array; a dimension or coordinate name
        returns the full array selected over that name. One can add any
        attribute here for convenience.
        """
        # Read the backing array straight from the instance dict: if it is
        # missing (e.g. during copy/unpickling, before __init__ has run),
        # ``self._grid_array`` would re-enter __getattr__ and recurse forever.
        try:
            grid_array = self.__dict__["_grid_array"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None
        if name == 'isel':
            return grid_array.isel
        if name == 'sel':
            return grid_array.sel
        if name in grid_array.dims:
            return grid_array.isel(**{name: slice(None)})
        if name in grid_array.coords:
            return grid_array.sel(**{name: slice(None)})
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __getitem__(self, key):
        """
        Enable isel and sel to be used directly on the 3D grid.

        Use with (phi = number) for example, instead of a dictionary.

        Raises
        ------
        KeyError
            If ``key`` is not a dimension or coordinate name.
        """
        if isinstance(key, str):
            if key in self._grid_array.dims:
                return self._grid_array.isel(**{key: slice(None)})
            if key in self._grid_array.coords:
                return self._grid_array.sel(**{key: slice(None)})
        raise KeyError(key)

    def set_R0(self, value: Quantity):
        """Set ``R0`` (checked to have units of length) on every 2D grid."""
        check_units(value, {"[length]": 1}, "R0")
        # Walk the flattened object array: iterating the DataArray itself
        # only covers its first dimension and fails for 0-d arrays.
        for grid in np.ravel(self.grid_array.values):
            grid.R0 = value

    def _remove_nans(self, grid: Grid2D):
        """
        Remove NaN values in grid from ghost and filler points.

        Every public ndarray/DataArray attribute of ``grid`` that contains NaN
        is replaced by its non-NaN entries. Boolean indexing is used, so a
        multi-dimensional array that contains NaN comes back flattened.
        Attributes without NaNs (including dtypes that cannot hold NaN) are
        left untouched, so a grid with no NaNs is returned unchanged.
        ``grid.size`` is updated to ``grid.r_u.size``.

        NOTE(review): ``grid`` is modified in place and also returned.
        """
        for attribute in dir(grid):
            if attribute.startswith("_"):
                continue
            var = getattr(grid, attribute)
            if not isinstance(var, (np.ndarray, xr.DataArray)):
                continue
            # np.isnan only supports float/complex/datetime/timedelta dtypes;
            # other dtypes cannot contain NaN and would raise TypeError.
            if var.dtype.kind not in "fcmM":
                continue
            mask = ~np.isnan(var)
            if bool(mask.all()):
                continue
            setattr(grid, attribute, var[mask])
        grid.size = grid.r_u.size
        return grid

    def _prepare_plane(self, grid: Grid2D) -> Grid2D:
        """Strip NaNs from ``grid`` and reset its vector-to-matrix flag."""
        grid = self._remove_nans(grid)
        # Array sizes may have changed; presumably this forces the
        # vector<->matrix mapping to be rebuilt on next use — TODO confirm.
        grid._vector_to_matrix_initialized = False
        return grid

    def sel_phi(self, phi_value=0) -> Grid2D:
        """Return a 2D grid for the selected phi value with NaNs removed.

        NOTE(review): ``.item()`` returns the Grid2D stored in
        ``grid_array`` (not a copy), so the NaN removal also alters the
        stored plane; ``npoints``/``shape`` and cached values are not updated.
        """
        return self._prepare_plane(self.sel(phi=phi_value).item())

    def isel_phi(self, phi_index=0) -> Grid2D:
        """Return a 2D grid for the selected phi index with NaNs removed.

        NOTE(review): ``.item()`` returns the Grid2D stored in
        ``grid_array`` (not a copy), so the NaN removal also alters the
        stored plane; ``npoints``/``shape`` and cached values are not updated.
        """
        return self._prepare_plane(self.isel(phi=phi_index).item())