# Source code for torx.specializations.grillix.operators.csr_operators_m

"""
Operators defined as CSR matrices which can be multiplied onto a state vector.

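These are stored in NetCDF groups with
i: row pointers (CSR ``indptr``), i.e. offsets into j/val where each row starts
j: column indices of the non-zero matrix elements
val: values of the non-zero matrix elements
ndim: number of matrix rows, i.e. dimension of the output vector
ncol: number of matrix columns, i.e. dimension of the input vector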

IMPORTANT NOTE: The CSR matrices are stored with Fortran indexing (starting at
1), but scipy needs Python indexing (starting at 0), so 1 is subtracted from
i and j when the matrix is built.
"""
import numpy as np
from scipy.sparse import csr_matrix
import xarray as xr
from typing import Union
from torx.autodoc_decorators_m import autodoc_function

def convert_ncgroup_to_matrix(matrix_group: xr.Dataset) -> csr_matrix:
    """Return a CSR matrix corresponding to a NetCDF group.

    Args:
        matrix_group: Group providing ``val`` (non-zero values), ``j``
            (1-based column indices), ``i`` (1-based CSR row pointers) and the
            matrix dimensions ``ndim`` (number of rows) and ``ncol`` (number
            of columns).

    Returns:
        A ``scipy.sparse.csr_matrix`` of shape ``(ndim, ncol)`` that uses
        0-based (Python) indexing.
    """
    # Hand plain numpy arrays to scipy rather than xarray objects.
    data = np.asarray(matrix_group.val)
    # Shift the Fortran 1-based indices to Python 0-based indices.
    indices = np.asarray(matrix_group.j) - 1
    indptr = np.asarray(matrix_group.i) - 1
    # scipy validates every shape entry with operator.index; coercing to int
    # works whether ndim/ncol come back as numpy scalars or as 0-d xarray
    # variables (the latter may not implement __index__).
    shape = (int(matrix_group.ndim), int(matrix_group.ncol))
    return csr_matrix((data, indices, indptr), shape=shape)
def matrix_vector_multiply(vector, matrix):
    """Apply a matrix operator onto a vector.

    The ``@`` (matmul) operator computes the matrix-vector product for scipy
    sparse matrices, scipy sparse arrays and dense numpy arrays alike. ``*``
    would be element-wise for the latter two.

    Args:
        vector: 1-D array of length ``matrix.shape[1]``.
        matrix: 2-D operator of shape ``(n_out, n_in)``.

    Returns:
        The product ``matrix @ vector`` (length ``matrix.shape[0]``).
    """
    return matrix @ vector
@autodoc_function
def csr_operator(
    matrix: Union[csr_matrix, xr.Dataset], vector: xr.DataArray
) -> xr.DataArray:
    """Apply a CSR matrix to a vector defined as an xarray.

    Args:
        matrix: The operator, either as a ``csr_matrix`` or as a NetCDF group
            (converted with ``convert_ncgroup_to_matrix``).
        vector: Input whose ``points`` dimension is the one the matrix acts
            on; the product is evaluated separately for every combination of
            the remaining dimensions.

    Returns:
        The result, with the non-``points`` dimensions of ``vector`` followed
        by a ``points`` dimension of length ``matrix.shape[0]``. Attributes of
        ``vector`` are kept.

    Raises:
        ValueError: If the ``points`` dimension of ``vector`` does not match
            the number of matrix columns.
    """
    if isinstance(matrix, xr.Dataset):
        matrix = convert_ncgroup_to_matrix(matrix)

    n_rows, n_cols = matrix.shape

    # Fail early with a clear message. With dask="parallelized" a mismatch
    # would otherwise only surface as a generic scipy error at compute time.
    if isinstance(vector, xr.DataArray) and "points" in vector.dims:
        n_points = vector.sizes["points"]
        if n_points != n_cols:
            raise ValueError(
                f"Cannot apply a matrix with {n_cols} columns to a vector "
                f"with {n_points} points."
            )

    result = xr.apply_ufunc(
        matrix_vector_multiply,
        vector,
        kwargs=dict(matrix=matrix),
        keep_attrs=True,
        input_core_dims=[["points"]],
        output_core_dims=[["output_points"]],
        # Loop the 1-D matrix-vector product over all non-core dimensions.
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs={"output_sizes": {"output_points": n_rows}},
        output_dtypes=[np.float64],
    )

    if not isinstance(result, xr.DataArray):
        # apply_ufunc returns a bare array for non-xarray input (e.g. numpy).
        return xr.DataArray(result, dims=["points"])
    # The output axis is indexed by matrix rows; expose it as "points" again.
    return result.rename(output_points="points")