# Source code for torx.specializations.grillix.operators.csr_operators_m
"""
Operators defined as CSR matrices which can be multiplied onto a state vector.

These are stored in NetCDF groups with
i: row pointers (the CSR ``indptr`` array)
j: column index of each stored element
val: value of each stored matrix element
ndim: number of matrix rows, i.e. the dimension of the output vector
ncol: number of matrix columns, i.e. the dimension of the input vector

IMPORTANT NOTE: The CSR matrices are stored with Fortran indexing (starting at
1), but we need them with Python indexing (starting at 0).
"""
import numpy as np
from scipy.sparse import csr_matrix
import xarray as xr
from typing import Union
from torx.autodoc_decorators_m import autodoc_function
def _scalar_to_int(value) -> int:
    """Return a scalar-like value (Python/NumPy scalar, size-1 array, 0-d DataArray) as an int."""
    return int(np.asarray(value).item())


def convert_ncgroup_to_matrix(matrix_group: xr.Dataset) -> csr_matrix:
    """Return a CSR matrix corresponding to a NetCDF group.

    The group stores the matrix in CSR layout with Fortran (1-based) indexing:

    - ``val``: value of each stored element,
    - ``j``: column index of each stored element,
    - ``i``: row pointers (the CSR ``indptr`` array),
    - ``ndim``: number of rows, ``ncol``: number of columns.

    ``ndim`` and ``ncol`` are read via attribute access. On an xarray Dataset,
    that can resolve to either a NetCDF attribute or a 0-d data variable.
    Both are normalized to plain ints, because scipy validates the shape with
    ``operator.index``, which a 0-d DataArray does not support.
    """
    values = np.asarray(matrix_group.val)
    # Shift the Fortran (1-based) indices to Python (0-based) indices.
    column_indices = np.asarray(matrix_group.j) - 1
    row_pointers = np.asarray(matrix_group.i) - 1
    shape = (_scalar_to_int(matrix_group.ndim), _scalar_to_int(matrix_group.ncol))
    return csr_matrix((values, column_indices, row_pointers), shape=shape)
def matrix_vector_multiply(vector, matrix):
    """Apply a matrix operator onto a vector and return ``matrix @ vector``.

    The vector comes first so that this function can be passed to
    ``xr.apply_ufunc`` with the matrix bound as a keyword argument.

    ``@`` is used rather than ``*``. For ``scipy.sparse`` matrices (spmatrix)
    both operators are the matrix product. For scipy sparse arrays and dense
    NumPy arrays, however, ``*`` is elementwise multiplication.
    """
    return matrix @ vector
@autodoc_function
def csr_operator(
    matrix: Union[csr_matrix, xr.Dataset], vector: xr.DataArray
) -> xr.DataArray:
    """Apply a CSR matrix to a vector defined as an xarray.

    Parameters
    ----------
    matrix
        The operator. Either a ready-made sparse matrix, or a NetCDF group in
        CSR layout, which is converted with ``convert_ncgroup_to_matrix``.
    vector
        Values along the ``"points"`` dimension. The operator is applied to
        every 1-D slice along ``"points"``; all other dimensions are looped
        over. Dask-backed inputs are handled via ``dask="parallelized"``.

    Returns
    -------
    xr.DataArray
        The result. Its ``"points"`` dimension now has length
        ``matrix.shape[0]``; as an apply_ufunc core dimension, it ends up last.
        Attributes of ``vector`` are kept. If ``apply_ufunc`` does not return
        a DataArray (e.g. for a plain NumPy input), the raw result is wrapped
        in a DataArray with the single dimension ``"points"``. That wrapping
        only works for 1-D results.

    Raises
    ------
    ValueError
        If ``vector`` is a DataArray whose ``"points"`` length differs from
        the number of matrix columns.
    """
    if isinstance(matrix, xr.Dataset):
        matrix = convert_ncgroup_to_matrix(matrix)

    n_rows = matrix.shape[0]
    n_columns = matrix.shape[1]
    # Check the input length up front. Otherwise a mismatch only surfaces as an
    # obscure scipy error inside the vectorized loop (and, for dask inputs,
    # only once the result is computed). A missing "points" dimension is left
    # to apply_ufunc, which reports it itself.
    if isinstance(vector, xr.DataArray) and "points" in vector.dims:
        n_points = vector.sizes["points"]
        if n_points != n_columns:
            raise ValueError(
                f"Vector has {n_points} points along 'points' but the matrix "
                f"has {n_columns} columns (matrix shape {matrix.shape})."
            )

    result = xr.apply_ufunc(
        matrix_vector_multiply,
        vector,
        kwargs={"matrix": matrix},
        keep_attrs=True,
        input_core_dims=[["points"]],
        output_core_dims=[["output_points"]],
        # Call matrix_vector_multiply once per 1-D slice along "points".
        vectorize=True,
        dask="parallelized",
        # Dask cannot infer the length of the new core dimension by itself.
        dask_gufunc_kwargs={"output_sizes": {"output_points": n_rows}},
        output_dtypes=[np.float64],
    )
    if not isinstance(result, xr.DataArray):
        return xr.DataArray(result, dims=["points"])
    # The output is indexed by the matrix rows; expose it under the input's
    # dimension name again.
    return result.rename(output_points="points")