# Source code for torx.fileio.multigrid_reader_m
"""Reader for PARALLAX multigrids in NetCDF files."""
import warnings
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.sparse import csr_matrix
from torx.grid.mesh_m import Mesh
from torx.autodoc_decorators_m import autodoc_function
@autodoc_function
def multigrid_reader(
    filepath: Path,
    multigrid_group: str = "",
    load_all_levels: bool = True,
):
    """
    Reader for PARALLAX multigrids, stored as a group in a NetCDF file.

    Parameters
    ----------
    filepath:
        Path to the NetCDF file containing the multigrid groups.
    multigrid_group:
        Optional group prefix under which the multigrid data is stored
        (prepended verbatim to the per-level group names).
    load_all_levels:
        If True, build the meshes and transfer operators for every level;
        otherwise only the finest mesh (level 1) is constructed and the
        operator arrays are returned unfilled.

    Returns
    -------
    tuple
        ``(mesh_lvls, restriction, restriction_inner, prolongation, nlvls)``
        where ``mesh_lvls`` holds ``Mesh`` objects and the three operator
        arrays hold ``scipy.sparse.csr_matrix`` objects, indexed by level.

    Note
    ----
    The reader does not directly return the multigrid data stored in
    CSR-format, but builds the meshes and interpolation matrices.
    """
    nlvls = xr.load_dataset(filepath).nlvls
    # Object arrays: meshes hold Mesh instances, the transfer operators
    # hold scipy.sparse.csr_matrix instances — not xr.Dataset, as the
    # previous dtype suggested (numpy maps any class to object dtype).
    mesh_lvls = np.empty(nlvls, dtype=object)
    restriction = np.empty(nlvls - 1, dtype=object)
    restriction_inner = np.empty(nlvls - 1, dtype=object)
    prolongation = np.empty(nlvls - 1, dtype=object)
    # The finest mesh (level 1) is always loaded.
    mesh_dataset = _load_multigrid_dataset(
        filepath, group=f"{multigrid_group}mesh_lvl_001")
    mesh_lvls[0] = Mesh(mesh_dataset)
    if load_all_levels:
        for lvl in range(nlvls - 1):
            # Build the mesh of the next-coarser multigrid level
            # (levels are 1-based in the file: lvl+2 is the coarse side).
            mesh_dataset = _load_multigrid_dataset(
                filepath,
                group=f"{multigrid_group}mesh_lvl_{lvl+2:03d}")
            mesh_lvls[lvl + 1] = Mesh(mesh_dataset)
            # Construct sparse interpolation matrices from CSR storage
            restriction[lvl] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}restriction_lvl_"
                      f"{lvl+1:03d}_to_{lvl+2:03d}",
            )
            restriction_inner[lvl] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}restriction_inner_lvl_"
                      f"{lvl+1:03d}_to_{lvl+2:03d}",
            )
            prolongation[lvl] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}prolongation_lvl_"
                      f"{lvl+2:03d}_to_{lvl+1:03d}",
            )
    return mesh_lvls, restriction, restriction_inner, prolongation, nlvls
def _load_multigrid_dataset(filepath, group) -> xr.Dataset:
    """Load a multigrid group from *filepath* and preprocess it.

    The duplicate-dimension warning that xarray emits for the
    ``size_neighbor`` layout is filtered out; every other warning raised
    while loading is re-issued.  If the group contains an
    ``index_neighbor`` variable, its generic ``size_neighbor`` dimensions
    are renamed to ``dx``/``dy`` displacement coordinates so elements can
    be selected via ``xr.DataArray.sel(dx=..., dy=...)``.
    """
    # Keywords identifying the expected size-neighbor warning that is
    # handled by the renaming below; anything else gets re-emitted.
    filter_keys = ("duplicate", "dimension", "size_neighbor")
    # Record warnings instead of promoting the first one to an error:
    # this reads the file only once (the old error/retry scheme loaded it
    # twice) and inspects EVERY warning, not just the first.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ds = xr.load_dataset(filepath, group=group)
    for caught_warning in caught:
        message = str(caught_warning.message)
        if not all(key in message for key in filter_keys):
            warnings.warn(caught_warning.message)
    if "index_neighbor" not in ds.keys():
        return ds
    # NOTE: We need to have distinct named dimensions to use the
    # xr.DataArray.isel method. Therefore, rename the existing
    # "size_neighbor" to dx and dy displacements. Then, assign
    # coordinates so that we can access the elements intuitively
    # using xr.DataArray.sel(dx=dx, dy=dy)
    new_dims = ("n_points", "dy", "dx")
    index_neighbor = xr.DataArray(ds.index_neighbor.values, dims=new_dims)
    # Half-width of the neighborhood stencil: full size is 2*half + 1.
    half_width = int(ds.size_neighbor.size - 1) // 2
    displacements = np.arange(-half_width, half_width + 1, 1)
    index_neighbor = index_neighbor.assign_coords(dx=displacements,
                                                  dy=displacements)
    ds = ds.drop_dims("size_neighbor")
    ds["index_neighbor"] = index_neighbor
    return ds
@autodoc_function
def build_csr_matrix(filepath: Path, group: str = "") -> csr_matrix:
    """Build a scipy sparse CSR matrix from PARALLAX CSR data.

    Parameters
    ----------
    filepath:
        Path to the NetCDF file.
    group:
        Group in the NetCDF file holding the CSR arrays
        (``val``, ``j``, ``i``) and the matrix extents
        (``ndim``, ``ncol``).

    Returns
    -------
    csr_matrix
        Sparse matrix of shape ``(ndim, ncol)``.
    """
    csr_dataset = _load_multigrid_dataset(filepath, group)
    # Extracting values for CSR matrix construction.
    # PARALLAX stores Fortran-style 1-based column indices (j) and row
    # pointers (i); scipy expects 0-based, hence the -1 shift.
    values = csr_dataset.val.values
    indices = csr_dataset.j.values - 1
    indptr = csr_dataset.i.values - 1
    # Cast extents to plain ints for the shape tuple — presumably they
    # come back as 0-d array-likes from the dataset (TODO confirm).
    shape = (int(csr_dataset.ndim), int(csr_dataset.ncol))
    return csr_matrix((values, indices, indptr), shape=shape)