Source code for torx.fileio.multigrid_reader_m

"""Reader for PARALLAX multigrids in NetCDF files."""
import warnings
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.sparse import csr_matrix

from torx.grid.mesh_m import Mesh
from torx.autodoc_decorators_m import autodoc_function

@autodoc_function
def multigrid_reader(
    filepath: Path,
    multigrid_group: str = "",
    load_all_levels: bool = True,
):
    """
    Reader for PARALLAX multigrids, stored as a group in a NetCDF file.

    Parameters
    ----------
    filepath : Path
        Path to the NetCDF file.
    multigrid_group : str, optional
        Prefix prepended verbatim to every group name that is read
        (e.g. ``f"{multigrid_group}mesh_lvl_001"``). No separator is
        inserted, so a nested group presumably needs a trailing ``"/"``.
    load_all_levels : bool, optional
        If False, only the level-1 mesh (group ``mesh_lvl_001``) is built;
        all other entries of the returned arrays are left as ``None``.

    Returns
    -------
    mesh_lvls : np.ndarray of Mesh, shape (nlvls,)
        Mesh of each level; entry ``k`` is built from ``mesh_lvl_{k+1:03d}``.
    restriction : np.ndarray of csr_matrix, shape (nlvls - 1,)
        Entry ``k`` is read from ``restriction_lvl_{k+1:03d}_to_{k+2:03d}``.
    restriction_inner : np.ndarray of csr_matrix, shape (nlvls - 1,)
        Entry ``k`` is read from
        ``restriction_inner_lvl_{k+1:03d}_to_{k+2:03d}``.
    prolongation : np.ndarray of csr_matrix, shape (nlvls - 1,)
        Entry ``k`` is read from ``prolongation_lvl_{k+2:03d}_to_{k+1:03d}``.
    nlvls : int
        Number of levels, read from the ``nlvls`` entry of the file's
        root group (not from ``multigrid_group``).

    Note
    ----
    The reader does not directly return the multigrid data stored in
    CSR-format, but builds the meshes and interpolation matrices.
    """
    # Only the level count is needed from the root group: open it lazily
    # and close it again instead of loading the whole root into memory.
    # int() forces the read while the file is still open.
    with xr.open_dataset(filepath) as root:
        nlvls = int(root.nlvls)

    # Object arrays (entries default to None) holding Mesh / csr_matrix.
    mesh_lvls = np.empty(nlvls, dtype=object)
    restriction = np.empty(nlvls - 1, dtype=object)
    restriction_inner = np.empty(nlvls - 1, dtype=object)
    prolongation = np.empty(nlvls - 1, dtype=object)

    mesh_lvls[0] = Mesh(
        _load_multigrid_dataset(filepath, group=f"{multigrid_group}mesh_lvl_001")
    )

    if load_all_levels:
        for idx in range(nlvls - 1):
            # Group names use 1-based, zero-padded level numbers.
            cur, nxt = idx + 1, idx + 2

            # Build the mesh of the next multigrid level
            mesh_lvls[idx + 1] = Mesh(
                _load_multigrid_dataset(
                    filepath, group=f"{multigrid_group}mesh_lvl_{nxt:03d}"
                )
            )

            # Build the sparse (CSR) interpolation matrices between the
            # two levels.
            restriction[idx] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}restriction_lvl_"
                f"{cur:03d}_to_{nxt:03d}",
            )
            restriction_inner[idx] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}restriction_inner_lvl_"
                f"{cur:03d}_to_{nxt:03d}",
            )
            prolongation[idx] = build_csr_matrix(
                filepath,
                group=f"{multigrid_group}prolongation_lvl_"
                f"{nxt:03d}_to_{cur:03d}",
            )

    return mesh_lvls, restriction, restriction_inner, prolongation, nlvls
def _load_multigrid_dataset(filepath: Path, group: str) -> xr.Dataset:
    """Load one group of a multigrid NetCDF file and preprocess it.

    The stored ``index_neighbor`` variable presumably uses the dimension
    ``size_neighbor`` twice, which xarray reports with a ``UserWarning``
    mentioning a duplicate dimension. That single warning is suppressed
    here because the array is rewritten below with distinct dimensions.
    Every other warning raised while loading is passed on to the caller.

    Parameters
    ----------
    filepath : Path
        Path to the NetCDF file.
    group : str
        Name of the NetCDF group to load.

    Returns
    -------
    xr.Dataset
        The loaded dataset. If it contains ``index_neighbor``, that
        variable is replaced by one with dimensions
        ``("n_points", "dy", "dx")`` and integer displacement coordinates
        on ``dy`` and ``dx``; the ``size_neighbor`` dimension is dropped.
    """
    # Record warnings instead of escalating them to errors. This way the
    # file is read exactly once, and warning categories other than
    # UserWarning (FutureWarning, RuntimeWarning, ...) cannot surface as
    # uncaught exceptions.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ds = xr.load_dataset(filepath, group=group)

    # Re-emit everything except the known duplicate-dimension warning,
    # which is resolved below. This runs outside catch_warnings so the
    # caller's warning filters apply.
    filter_keys = ("duplicate", "dimension", "size_neighbor")
    for record in caught:
        text = str(record.message)
        is_known = issubclass(record.category, UserWarning) and all(
            key in text for key in filter_keys
        )
        if not is_known:
            warnings.warn(record.message, record.category, stacklevel=2)

    if "index_neighbor" not in ds:
        return ds

    # NOTE: We need to have distinct named dimensions to use the
    # xr.DataArray.isel method. Therefore, rename the existing
    # "size_neighbor" to dx and dy displacements. Then, assign
    # coordinates so that we can access the elements intuitively
    # using xr.DataArray.sel(dx=dx, dy=dy)
    half_width = (int(ds.sizes["size_neighbor"]) - 1) // 2
    displacements = np.arange(-half_width, half_width + 1, 1)
    index_neighbor = xr.DataArray(
        ds.index_neighbor.values,
        dims=("n_points", "dy", "dx"),
        coords={"dx": displacements, "dy": displacements},
    )

    # drop_dims also removes the original index_neighbor; re-add the
    # renamed version afterwards.
    ds = ds.drop_dims("size_neighbor")
    ds["index_neighbor"] = index_neighbor
    return ds


@autodoc_function
def build_csr_matrix(filepath: Path, group: str = "") -> csr_matrix:
    """Build a scipy sparse matrix from PARALLAX CSR data.

    Parameters
    ----------
    filepath : Path
        Path to the NetCDF file.
    group : str, optional
        Group that holds the CSR arrays ``val``, ``j`` and ``i`` and the
        shape entries ``ndim`` and ``ncol``.

    Returns
    -------
    csr_matrix
        Sparse matrix of shape ``(ndim, ncol)``.
    """
    csr_dataset = _load_multigrid_dataset(filepath, group)

    values = csr_dataset.val.values
    # The stored column indices and row pointers are 1-based; scipy expects
    # 0-based.
    indices = csr_dataset.j.values - 1
    indptr = csr_dataset.i.values - 1
    # NOTE(review): "ndim"/"ncol" are resolved by attribute lookup on the
    # dataset (stored attr or variable); presumably row and column counts.
    # int() accepts either a numpy scalar attribute or a 0-d variable.
    shape = (int(csr_dataset.ndim), int(csr_dataset.ncol))
    return csr_matrix((values, indices, indptr), shape=shape)