"""Low-level interface to the PARALLAX multigrid."""
import numpy as np
import xarray as xr
from pathlib import Path
from scipy.sparse import csr_matrix
from .mesh_m import Mesh
from .grid_2d_m import Grid2D
from .multigrid_operations_m import vector_to_matrix, matrix_to_vector
import torx
from torx.fileio import multigrid_reader
from torx.autodoc_decorators_m import autodoc_class
[docs]
@autodoc_class
class Multigrid2D(Grid2D):
"""Represents a PARALLAX multigrid type (2D)."""
[docs]
def __init__(
    self,
    filepath: Path,
    multigrid_group: str,
    load_all_levels: bool = True,
    R0: torx.Quantity = torx.Quantity(1, "m"),
):
    """
    Initialize the multigrid object.

    Parameters
    ----------
    - `filepath` (Path): Path to the NetCDF file containing the multigrid data.
    - `multigrid_group` (str): Group inside the NetCDF file where the multigrid
      is stored.
    - `load_all_levels` (bool, optional): If False only load grid of level 0,
      default is True.
    - `R0` (torx.Quantity, optional): Grid length normalization, default is
      1 meter.  NOTE(review): the default is evaluated once at function
      definition time; assumed safe because it is only read — confirm that
      torx.Quantity is not mutated by Grid2D.
    """
    # Read only the finest level up front; the remaining levels are
    # loaded lazily by _load_multigrid_levels().
    mesh_lvl, _, _, _, nlvls = multigrid_reader(
        filepath, multigrid_group=multigrid_group, load_all_levels=False
    )
    self.filepath = filepath
    self.multigrid_group = multigrid_group
    self.R0 = R0
    self.nlvls = nlvls
    # Pre-allocate per-level storage; entries are filled on first full load.
    self._mesh_levels = np.empty(nlvls, dtype=Mesh)
    self._grids = np.empty(nlvls, dtype=Grid2D)
    self._restriction_matrices = np.empty(nlvls - 1, dtype=csr_matrix)
    self._prolongation_matrices = np.empty(nlvls - 1, dtype=csr_matrix)
    # The finest grid (level 0 internally, level 1 in the public API)
    # is always available.
    finest_grid = Grid2D(mesh_lvl[0].r_u, mesh_lvl[0].z_u)
    finest_grid.R0 = self.R0
    self._grids[0] = finest_grid
    self._all_levels_loaded = False
    if load_all_levels:
        self._load_multigrid_levels()
[docs]
def get_grid(self, lvl: int):
"""
Return a 2D grid for a given PARALLAX multigrid level.
Params
------
- `lvl`: Requested multigrid level.
Return
------
- `grid`: torx.Grid2D of the requested multigrid level.
"""
self._load_multigrid_levels()
assert (1 <= lvl <= self.nlvls), \
f"Requested level exceeds number of multigrid levels ({self.nlvls})"
return self._grids[lvl - 1]
[docs]
def get_mesh(self, lvl: int):
"""
Return a low level mesh object for a given PARALLAX multigrid level.
Params
------
- `lvl`: Requested multigrid level.
Return
------
- `grid`: torx.Mesh of the requested multigrid level.
"""
self._load_multigrid_levels()
assert (1 <= lvl <= self.nlvls), \
f"Requested level exceeds number of multigrid levels ({self.nlvls})"
return self._mesh_levels[lvl - 1]
def _matrix_vector_multiply(
self,
vector: np.ndarray,
sparse_matrix: csr_matrix,
) -> np.ndarray:
"""Matrix vector multiplication core function."""
return sparse_matrix @ vector
[docs]
def restrict(
    self,
    input_array: xr.DataArray,
    target_lvl: int = None
) -> xr.DataArray:
    """
    Restrict an unstructured array to a higher (coarser) multigrid level.

    Params
    ------
    - `input_array` (xr.DataArray): Data to be restricted; must carry the
      unstructured 'points' dimension.
    - `target_lvl` (int, optional): Multigrid level to be restricted to,
      default is the next higher (coarser) level.

    Return
    ------
    - `restricted_array` (xr.DataArray): Restricted array, where the
      restriction level is stored in attribute 'multigrid_level'.
    """
    self._load_multigrid_levels()
    # The source level is deduced from the length of the 'points' dimension.
    source_lvl = self._infer_level_from_shape(input_array)
    if target_lvl is None:
        target_lvl = source_lvl + 1
    assert "points" in input_array.dims, \
        "Restriction only works in unstructured data, i.e. in 'points'."
    assert (1 <= source_lvl <= self.nlvls) and (1 <= target_lvl <= self.nlvls), \
        f"Restriction level exceeds number of multigrid levels ({self.nlvls})"
    assert target_lvl > source_lvl, \
        "Restriction must be to higher (coarser) level."
    # Apply the restriction operators one level at a time.
    restricted_array = input_array
    for lvl in range(source_lvl, target_lvl):
        # Dictionaries for apply_ufunc:
        # _restriction_matrices[lvl - 1] maps level lvl -> level lvl + 1.
        kwargs_dict = dict(
            sparse_matrix = self._restriction_matrices[lvl - 1],
        )
        # Dask needs the output size of the new core dimension up front;
        # it equals the point count of the coarser level.
        dask_gufunc_kwargs_dict = dict(
            output_sizes = {
                "restricted_points":
                    int(self.grid_sizes.sel(multigrid_level=lvl+1))
            }
        )
        # 'points' is a core dimension, so it must live in a single dask
        # chunk; chunk({"points": -1}) enforces that before the ufunc runs.
        restricted_array = xr.apply_ufunc(
            self._matrix_vector_multiply,
            restricted_array.chunk(chunks={"points": -1}),
            input_core_dims = [["points"]],
            output_core_dims = [["restricted_points"]],
            kwargs = kwargs_dict,
            dask_gufunc_kwargs = dask_gufunc_kwargs_dict,
            output_dtypes = [np.float64],
            keep_attrs = True,
            vectorize = True,
            dask = "parallelized",
        ).rename(
            {"restricted_points": "points"}
        )
    return restricted_array.assign_attrs({"multigrid_level": target_lvl})
[docs]
def prolong(
    self,
    input_array: xr.DataArray,
    target_lvl: int = None
) -> xr.DataArray:
    """
    Prolong the data to a lower (finer) multigrid level.

    Parameters
    ----------
    - `input_array` (xr.DataArray): Data to be prolonged; must carry the
      unstructured 'points' dimension.
    - `target_lvl` (int, optional): Multigrid level to be prolonged to,
      default is the next lower, i.e. finer, level.

    Return
    ------
    - `prolonged_array` (xr.DataArray): Prolonged array, where the
      prolongation level is stored in attribute 'multigrid_level'.
    """
    self._load_multigrid_levels()
    # The source level is deduced from the length of the 'points' dimension.
    source_lvl = self._infer_level_from_shape(input_array)
    if target_lvl is None:
        target_lvl = source_lvl - 1
    assert "points" in input_array.dims, \
        "Prolongation only works in unstructured data, i.e. in 'points'."
    assert (1 <= source_lvl <= self.nlvls) and (1 <= target_lvl <= self.nlvls), \
        f"Prolongation level exceeds number of multigrid levels ({self.nlvls})"
    assert target_lvl < source_lvl, \
        "Prolongation must be to lower (finer) level."
    # Apply the prolongation operators one level at a time, walking from
    # the coarse source level down to the finer target level.
    prolonged_array = input_array
    for lvl in range(source_lvl, target_lvl, -1):
        # Dictionaries for apply_ufunc:
        # _prolongation_matrices[lvl - 2] maps level lvl -> level lvl - 1.
        kwargs_dict = dict(
            sparse_matrix = self._prolongation_matrices[lvl - 2],
        )
        # Dask needs the output size of the new core dimension up front;
        # it equals the point count of the finer level.
        dask_gufunc_kwargs = dict(
            output_sizes = {
                "prolonged_points":
                    int(self.grid_sizes.sel(multigrid_level=lvl-1))
            }
        )
        # 'points' is a core dimension, so it must live in a single dask
        # chunk; chunk({"points": -1}) enforces that before the ufunc runs.
        prolonged_array = xr.apply_ufunc(
            self._matrix_vector_multiply,
            prolonged_array.chunk(chunks={"points": -1}),
            input_core_dims = [["points"]],
            output_core_dims = [["prolonged_points"]],
            kwargs = kwargs_dict,
            dask_gufunc_kwargs = dask_gufunc_kwargs,
            output_dtypes = [np.float64],
            keep_attrs = True,
            vectorize = True,
            dask = "parallelized",
        ).rename(
            {"prolonged_points": "points"}
        )
    return prolonged_array.assign_attrs({"multigrid_level": target_lvl})
[docs]
def vector_to_matrix(self, input_array: xr.DataArray) -> xr.DataArray:
    """
    Convert logically unstructured data to structured (R,Z) data.

    Non-existing points are filled with NaNs.

    Parameters
    ----------
    - `input_array` (xr.DataArray): Data to be converted to 2D.

    Return
    ------
    - `structured_array` (xr.DataArray): 2D structured representation of
      the input_array.
    """
    self._load_multigrid_levels()
    # Pick the mesh of the level whose point count matches the input.
    source_lvl = self._infer_level_from_shape(input_array)
    source_mesh = self._mesh_levels[source_lvl - 1]
    # Delegates to the module-level helper (not this method).
    return vector_to_matrix(source_mesh, input_array)
[docs]
def matrix_to_vector(self, input_array: xr.DataArray) -> xr.DataArray:
    """
    Convert structured (R,Z) data to logically unstructured data.

    Parameters
    ----------
    - `input_array` (xr.DataArray): 2D data to be converted to unstructured.

    Return
    ------
    - `unstructured_array` (xr.DataArray): Filled, unstructured array of
      the input data.
    """
    self._load_multigrid_levels()
    # Pick the mesh of the level whose (R, Z) coordinates match the input.
    source_lvl = self._infer_level_from_shape(input_array)
    source_mesh = self._mesh_levels[source_lvl - 1]
    # Delegates to the module-level helper (not this method).
    return matrix_to_vector(source_mesh, input_array)
def _load_multigrid_levels(self) -> None:
    """
    Load all multigrid levels into memory (idempotent).

    On the first call the meshes, transfer matrices and grids of all
    levels are read from `self.filepath`; subsequent calls only verify
    that the cached data is complete, tracked via `_all_levels_loaded`.
    """
    assert isinstance(self._all_levels_loaded, bool), \
        "Something went wrong in the initialization."
    if self._all_levels_loaded:
        # Verify the cache; report the MISSING levels (the original code
        # reported the loaded levels as "not loaded").
        # Check grid levels
        missing = np.where(
            [not isinstance(x, Grid2D) for x in self._grids]
        )[0]
        assert missing.size == 0, \
            f"Grid of level {missing + 1} not loaded."
        # Check restriction matrices
        missing = np.where(
            [not isinstance(x, csr_matrix)
             for x in self._restriction_matrices]
        )[0]
        assert missing.size == 0, \
            f"Restriction matrix of level {missing + 1} not loaded."
        # Check prolongation matrices
        missing = np.where(
            [not isinstance(x, csr_matrix)
             for x in self._prolongation_matrices]
        )[0]
        assert missing.size == 0, \
            f"Prolongation matrix of level {missing + 1} not loaded."
        return None
    else:
        meshes, restriction, _, prolongation, nlvls = multigrid_reader(
            self.filepath, multigrid_group=self.multigrid_group
        )
        self._mesh_levels = meshes
        self._restriction_matrices = restriction
        self._prolongation_matrices = prolongation
        # Level 0's grid was already built in __init__; build the rest.
        for lvl in range(1, nlvls):
            self._grids[lvl] = Grid2D(meshes[lvl].r_u, meshes[lvl].z_u)
            self._grids[lvl].R0 = self.R0
        # Point count per level, indexed by the 1-based level coordinate.
        self.grid_sizes = xr.DataArray(
            np.array([self._grids[lvl].size for lvl in range(self.nlvls)]),
            dims = "multigrid_level",
            coords = {"multigrid_level": np.arange(1, self.nlvls + 1)}
        )
        self._all_levels_loaded = True
        return None
def _infer_level_from_shape(self, input_array) -> int:
    """
    Infer the 1-indexed multigrid level from the shape of the input array.

    Either the unstructured `points` dimension or the structured `R`, `Z`
    dimensions are matched against the known levels.

    Raises
    ------
    - `ValueError`: If the array carries neither 'points' nor both 'R'
      and 'Z' dimensions, or if no level matches.  (BUG FIX: these error
      messages were previously dead bare-string statements, so an
      unmatched array silently returned level 0.)
    """
    self._load_multigrid_levels()
    lvl = -1
    if "points" in input_array.dims:
        # Exact size match against the per-level point counts.
        sizes = self.grid_sizes.values
        if input_array.points.size in sizes:
            lvl = int(np.argmin(np.abs(sizes - input_array.points.size)))
    elif "R" in input_array.dims and "Z" in input_array.dims:
        # Compare the structured coordinates against each level's axes.
        for candidate in range(self.nlvls):
            grid = self._grids[candidate]
            if (np.allclose(grid.r_s, input_array.R)
                    and np.allclose(grid.z_s, input_array.Z)):
                lvl = candidate
                break
    else:
        raise ValueError(
            "Neither 'points' nor 'R' and 'Z' found in input_array."
        )
    if lvl == -1:
        raise ValueError(
            "No multigrid level with matching grid size found."
        )
    return lvl + 1