# Source code for torx.specializations.grillix.operators.parallel_operators_m

"""
Parallel operators based on CSR operator defined by the parallel map.

Note
----
Requires that you've loaded all planes of data.

The parallel operators map from the canonical grid to the staggered grid and
vice-versa. This introduces a subtlety regarding which plane of data to use in
operations. Each MPI rank in GRILLIX contains both the canonical grid and
staggered grid, which are shifted by a half-toroidal spacing.
Our 'phi' dimension is actually the MPI rank, and it doesn't distinguish between
canonical and staggered variables! Therefore, we need to be careful that we're
using the correct data.

It's best to think about this pictorially

        ---- phi ---->
Rank:    -1 |  0  |  1

    ... o x | o x | o x ...

where
o = canonical grid
x = staggered grid
| = boundary between ranks (or 'phi')
"""

import xarray as xr
import numpy as np
from .csr_operators_m import csr_operator
import warnings
from torx.autodoc_decorators_m import autodoc_function

@autodoc_function
def parallel_gradient(map, input_array: xr.DataArray) -> xr.DataArray:
    """
    Calculate parallel grad on staggered grid for a var on the canonical grid.

    ... o x | o x | o x ...
          < ^ >

    We want to calculate the parallel gradient at the staggered point ^
    We need the canonical values on the same rank < and on the forward rank >

    For the forward values, we 'roll' our array along the phi dimension. To see
    which sign to use look at the shifted coordinates
    array.roll(dict(phi=-1), roll_coords=True)
    """
    assert (
        input_array.sizes["phi"] == map.nplanes
    ), "parallel gradient requires all planes of data loaded"
    assert (
        input_array.staggered == False
    ), "Parallel gradient can only be calculated for variables on the canonical grid"

    # Forward contribution: qfwd applied to the halo plane one rank ahead
    # (phi rolled by -1 brings the forward rank's data onto this rank).
    fwd_vals = csr_operator(map.qfwd, input_array.roll(dict(phi=-1)))

    # Backward contribution: qbwd applied to the local canonical plane.
    bwd_vals = csr_operator(map.qbwd, input_array)

    gradient = fwd_vals - bwd_vals
    # Result lives on the staggered grid.
    return gradient.assign_attrs(staggered=True)
@autodoc_function
def parallel_divergence(map, input_array: xr.DataArray) -> xr.DataArray:
    """
    Calculate parallel div on canonical grid for a var on the staggered grid.

    ... o x | o x | o x ...
          < ^ >

    We want to calculate the parallel divergence at the canonical point ^
    We need the staggered values on the previous rank < and on the same rank >

    Parameters
    ----------
    map
        Parallel map object providing the CSR matrices ``pfwd``/``pbwd`` and
        the plane count ``nplanes``.
    input_array
        Staggered-grid variable with all planes loaded along 'phi'.

    Returns
    -------
    xr.DataArray
        Parallel divergence on the canonical grid (``staggered=False``).
    """
    # FIX: assert message previously said "parallel gradient" (copy-paste).
    assert (
        input_array.sizes["phi"] == map.nplanes
    ), "parallel divergence requires all planes of data loaded"
    assert (
        input_array.staggered == True
    ), "Parallel divergence can only be calculated for variables on the staggered grid"

    # pbwd * u_halo_bwd: staggered plane from the previous rank
    # (phi rolled by +1 brings the backward rank's data onto this rank).
    backward_values = csr_operator(map.pbwd, input_array.roll(dict(phi=+1)))
    # pfwd * u: staggered plane on the local rank.
    forward_values = csr_operator(map.pfwd, input_array)

    return (backward_values - forward_values).assign_attrs(staggered=False)
@autodoc_function
def map_to_staggered(map, input_array: xr.DataArray) -> xr.DataArray:
    """
    Map a variable defined on the canonical grid onto the staggered grid.

    ... o x | o x | o x ...
          < ^ >

    We want to calculate the value at the staggered point ^
    We need the canonical values on the same rank < and on the forward rank >

    Parameters
    ----------
    map
        Parallel map object providing the CSR matrices ``mfwd``/``mbwd``, the
        half-distances ``dpar_fwd_half``/``dpar_bwd_half`` and ``nplanes``.
    input_array
        Canonical-grid variable with all planes loaded along 'phi'.

    Returns
    -------
    xr.DataArray
        The variable interpolated onto the staggered grid (``staggered=True``).
    """
    # FIX: assert message previously said "parallel gradient" (copy-paste).
    assert (
        input_array.sizes["phi"] == map.nplanes
    ), "map to staggered requires all planes of data loaded"
    assert input_array.staggered == False, "input array already on staggered grid"

    # mfwd * u_halo_fwd: canonical plane from the forward rank.
    forward_values = csr_operator(map.mfwd, input_array.roll(dict(phi=-1)))
    # mbwd * u: canonical plane on the local rank.
    backward_values = csr_operator(map.mbwd, input_array)

    # Distance-weighted linear interpolation along the field line: each side is
    # weighted by the opposite half-distance, normalised by the total distance.
    stag = (
        (forward_values * map.dpar_bwd_half + backward_values * map.dpar_fwd_half)
        / (map.dpar_fwd_half + map.dpar_bwd_half).values
    ).assign_attrs(staggered=True)
    return stag
@autodoc_function
def map_to_canonical(map, input_array: xr.DataArray) -> xr.DataArray:
    """
    Map a variable defined on the staggered grid onto the canonical grid.

    ... o x | o x | o x ...
          < ^ >

    We want to calculate the value at the canonical point ^
    We need the staggered values on the previous rank < and on the same rank >

    Parameters
    ----------
    map
        Parallel map object providing the CSR matrices ``mfwd``/``mbwd``, the
        half-distances ``dpar_fwd_half``/``dpar_bwd_half`` and ``nplanes``.
    input_array
        Staggered-grid variable with all planes loaded along 'phi'.

    Returns
    -------
    xr.DataArray
        The variable interpolated onto the canonical grid (``staggered=False``).
    """
    # FIX: assert message previously said "parallel gradient" (copy-paste).
    assert (
        input_array.sizes["phi"] == map.nplanes
    ), "map to canonical requires all planes of data loaded"
    assert input_array.staggered == True, "input array already on canonical grid"

    # mfwd * u: staggered plane on the local rank.
    forward_values = csr_operator(map.mfwd, input_array)
    # mbwd * u_halo_bwd: staggered plane from the previous rank.
    backward_values = csr_operator(map.mbwd, input_array.roll(dict(phi=+1)))

    # Distance-weighted linear interpolation along the field line: each side is
    # weighted by the opposite half-distance, normalised by the total distance.
    full = (
        (forward_values * map.dpar_bwd_half + backward_values * map.dpar_fwd_half)
        / (map.dpar_fwd_half + map.dpar_bwd_half).values
    ).assign_attrs(staggered=False)
    return full
@autodoc_function
def parallel_gradient_3D(map3D, field: xr.DataArray) -> xr.DataArray:
    """
    Calculate the parallel gradient on the staggered grid for all planes.

    Works in a 3D (non-axisymmetric) equilibrium.

    Parameters
    ----------
    map3D
        Map object holding per-plane CSR matrices ``qfwd``/``qbwd``, the
        plane coordinate ``phi`` and ``nplanes``.
    field
        Canonical-grid variable containing either all planes or one more
        plane than the Maps (along 'planes' or 'phi').

    Returns
    -------
    xr.DataArray
        Parallel gradient on the staggered grid (``staggered=True``).
    """
    # --- Normalize toroidal dimension ----------------------------------------
    field = _fix_dim_names(field)

    assert not field.staggered, "Parallel gradient requires canonical grid input"
    assert (
        (field.sizes["phi"] == map3D.nplanes)
        or (field.sizes["phi"] == map3D.phi.size + 1)
    ), "Field must contain all planes or 1 more than the Maps."

    if map3D.phi.size != map3D.nplanes:
        # Extra plane sits at the high-phi end; drop it from both slices.
        remove = -1
        # FIX: the warning text previously used a backslash line continuation
        # inside the string literal, embedding a long run of indentation spaces
        # into the user-facing message. Use implicit concatenation instead.
        warnings.warn(
            "Map3D does not contain all planes of data. "
            "Resulting array will have one less point than the field in the "
            "'phi' dimension.",
            UserWarning,
        )
    else:
        remove = None

    # --- Prepare field shifted one plane forward -----------------------------
    field_p0 = field.copy()
    field_p1 = field_p0.roll(phi=-1, roll_coords=False)
    if remove is not None:
        field_p0 = field_p0.drop_isel(phi=remove)
        field_p1 = field_p1.drop_isel(phi=remove)

    # --- Extract CSR matrices ------------------------------------------------
    fwd_matrices = map3D.qfwd
    bwd_matrices = map3D.qbwd

    # --- Compute the max row count across ALL planes -------------------------
    counts = (
        [m.shape[0] for m in fwd_matrices.values]
        + [m.shape[0] for m in bwd_matrices.values]
    )
    max_counts = max(counts)
    max_counts = max(max_counts, field_p0.sizes["points"], field_p1.sizes["points"])

    # --- Helper: apply CSR and pad ------------------------------------------
    def _gradient_single_plane(f0, f1, qfwd_mat, qbwd_mat, max_counts):
        fwd_vals, bwd_vals, _, _ = _get_fwd_bwd_vals(f0, f1, qfwd_mat, qbwd_mat)
        grad = fwd_vals - bwd_vals
        # Pad at the end so every plane returns the same length.
        n = grad.sizes["points"]
        if n < max_counts:
            grad = grad.pad(points=(0, max_counts - n), constant_values=np.nan)
        return grad

    # --- Vectorized evaluation across phi -----------------------------------
    try:
        grad = xr.apply_ufunc(
            _gradient_single_plane,
            field_p0,
            field_p1,
            fwd_matrices,
            bwd_matrices,
            max_counts,
            input_core_dims=[["points"], ["points"], [], [], []],
            output_core_dims=[["points"]],
            exclude_dims=set(("points",)),
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float64],
            keep_attrs=True,
        )
    except ValueError as e:
        raise ValueError(
            "Error during parallel gradient computation.\n"
            "Please ensure that the input field has consistent dimensions "
            "and that Map3D contains the necessary planes of data.\n"
            "If using a slice of planes, ensure the Field's extra plane\n"
            "has higher phi than the last plane in the Maps."
        ) from e

    return grad.assign_attrs(staggered=True)
@autodoc_function
def parallel_divergence_3D(map3D, field: xr.DataArray) -> xr.DataArray:
    """
    Calculate the parallel divergence on the canonical grid for all planes.

    For a variable defined on the staggered grid in a 3D (non-axisymmetric)
    equilibrium.

    Parameters
    ----------
    map3D
        Map object holding per-plane CSR matrices ``pfwd``/``pbwd``, the
        plane coordinate ``phi`` and ``nplanes``.
    field
        Staggered-grid variable containing either all planes or one more
        plane than the Maps (along 'planes' or 'phi').

    Returns
    -------
    xr.DataArray
        Parallel divergence on the canonical grid (``staggered=False``).
    """
    # --- Normalize toroidal dimension ----------------------------------------
    field = _fix_dim_names(field)

    assert field.staggered, "Parallel divergence requires staggered grid input"
    assert (
        field.sizes["phi"] == map3D.nplanes
        or field.sizes["phi"] == map3D.phi.size + 1
    ), "Field must contain all planes or 1 more than the Maps."

    if map3D.phi.size != map3D.nplanes:
        # Extra plane sits at the low-phi end; drop it from both slices.
        remove = 0
        # FIX: the warning text previously used a backslash line continuation
        # inside the string literal, embedding a long run of indentation spaces
        # into the user-facing message. Use implicit concatenation instead.
        warnings.warn(
            "Map3D does not contain all planes of data. "
            "Resulting array will have one less point than the field in the "
            "'phi' dimension.",
            UserWarning,
        )
    else:
        remove = None

    # --- Prepare per-plane slices --------------------------------------------
    field_p1 = field
    field_p0 = field.roll(phi=1, roll_coords=False)
    if remove is not None:
        field_p1 = field_p1.drop_isel(phi=remove)
        field_p0 = field_p0.drop_isel(phi=remove)

    # --- Extract CSR matrices per plane -------------------------------------
    fwd_matrices = map3D.pfwd
    bwd_matrices = map3D.pbwd

    # --- Compute maximum number of rows across all CSR matrices -------------
    counts = (
        [m.shape[1] for m in fwd_matrices.values]
        + [m.shape[1] for m in bwd_matrices.values]
    )
    max_counts = max(counts)
    max_counts = max(max_counts, field_p0.sizes["points"], field_p1.sizes["points"])

    # --- Per-plane worker (gets numpy arrays for f0/f1 because vectorize=True) ---
    def _divergence_single_plane(f0, f1, pfwd_mat, pbwd_mat, max_counts):
        fwd_vals, bwd_vals, _, _ = _get_fwd_bwd_vals(f0, f1, pfwd_mat, pbwd_mat)
        # compute divergence on canonical points
        div = bwd_vals - fwd_vals
        # pad at the end so all planes return the same length
        n = div.sizes["points"]
        if n < max_counts:
            div = div.pad(points=(0, max_counts - n), constant_values=np.nan)
        return div

    # --- Vectorized evaluation across phi ------------------------------------
    try:
        divergence = xr.apply_ufunc(
            _divergence_single_plane,
            field_p0,
            field_p1,
            fwd_matrices,
            bwd_matrices,
            max_counts,
            input_core_dims=[["points"], ["points"], [], [], []],
            output_core_dims=[["points"]],
            exclude_dims=set(("points",)),
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float64],
            keep_attrs=True,
        )
    except ValueError as e:
        raise ValueError(
            "Error during parallel divergence computation.\n"
            "Please ensure that the input field has consistent dimensions "
            "and that Map3D contains the necessary planes of data.\n"
            "If using a slice of planes, ensure the Field's extra plane "
            "has lower phi than the first plane in the Maps."
        ) from e

    return divergence.assign_attrs(staggered=False)
@autodoc_function
def map_to_staggered_3D(map3D, field: xr.DataArray) -> xr.DataArray:
    """
    Map a canonical grid field onto the staggered grid for all planes.

    Works in a 3D (non-axisymmetric) equilibrium.

    Parameters
    ----------
    map3D
        Map object holding per-plane CSR matrices ``mfwd``/``mbwd``, the
        staggered half-distances ``dpar_fwd_half_stag``/``dpar_bwd_half_stag``,
        the plane coordinate ``phi`` and ``nplanes``.
    field
        Canonical-grid variable containing either all planes or one more
        plane than the Maps (along 'planes' or 'phi').

    Returns
    -------
    xr.DataArray
        The field interpolated onto the staggered grid (``staggered=True``).
    """
    # --- Normalize toroidal dimension ----------------------------------------
    field = _fix_dim_names(field)

    assert not field.staggered, "Input must be canonical grid data."
    assert (
        field.sizes["phi"] == map3D.nplanes
        or field.sizes["phi"] == map3D.phi.size + 1
    ), "Field must contain all planes or 1 more than the Maps."

    if map3D.phi.size != map3D.nplanes:
        # Extra plane sits at the high-phi end; drop it from both slices.
        remove = -1
        # FIX: the warning text previously used a backslash line continuation
        # inside the string literal, embedding a long run of indentation spaces
        # into the user-facing message. Use implicit concatenation instead.
        warnings.warn(
            "Map3D does not contain all planes of data. "
            "Resulting array will have one less point than the field in the "
            "'phi' dimension.",
            UserWarning,
        )
    else:
        remove = None

    # --- Canonical neighbors -------------------------------------------------
    # forward plane (p+1) and same plane p values
    field_p1 = field.roll(phi=-1, roll_coords=False)  # plane p+1
    field_p0 = field  # plane p
    if remove is not None:
        field_p1 = field_p1.drop_isel(phi=remove)
        field_p0 = field_p0.drop_isel(phi=remove)

    # --- Extract weights -----------------------------------------------------
    dfwd = map3D.dpar_fwd_half_stag.rename(n_points_stag="points")
    dbwd = map3D.dpar_bwd_half_stag.rename(n_points_stag="points")

    # --- CSR matrices --------------------------------------------------------
    fwd_matrices = map3D.mfwd
    bwd_matrices = map3D.mbwd

    # --- Determine padding length -------------------------------------------
    counts = (
        [m.shape[0] for m in fwd_matrices.values]
        + [m.shape[0] for m in bwd_matrices.values]
    )
    max_counts = max(counts)
    max_counts = max(
        max_counts,
        field_p0.sizes["points"],
        field_p1.sizes["points"],
        dfwd.sizes["points"],
        dbwd.sizes["points"],
    )

    # --- Per-plane worker ----------------------------------------------------
    def _map_single_plane(f0, f1, f_fwd, f_bwd, dpar_fwd, dpar_bwd, max_counts):
        fwd_vals, bwd_vals, dpar_fwd, dpar_bwd = _get_fwd_bwd_vals(
            f0, f1, f_fwd, f_bwd, dpar_fwd, dpar_bwd
        )
        # combine using the same logic as the single plane version
        mapped = (fwd_vals * dpar_bwd + bwd_vals * dpar_fwd) / (dpar_fwd + dpar_bwd)
        # pad to max_counts
        n = mapped.sizes["points"]
        if n < max_counts:
            mapped = mapped.pad(points=(0, max_counts - n), constant_values=np.nan)
        return mapped

    # --- Vectorized apply ----------------------------------------------------
    try:
        mapped = xr.apply_ufunc(
            _map_single_plane,
            field_p0,
            field_p1,
            fwd_matrices,
            bwd_matrices,
            dfwd.pad(
                points=(0, max_counts - dfwd.sizes["points"]), constant_values=np.nan
            ),
            # FIX: dbwd was previously padded using dfwd's length; if the two
            # weight arrays differ in length, dbwd would be mis-padded.
            dbwd.pad(
                points=(0, max_counts - dbwd.sizes["points"]), constant_values=np.nan
            ),
            max_counts,
            input_core_dims=[["points"], ["points"], [], [], ["points"], ["points"], []],
            output_core_dims=[["points"]],
            exclude_dims=set(("points",)),
            vectorize=True,
            dask="parallelized",
            keep_attrs=True,
            output_dtypes=[np.float64],
        )
    except ValueError as e:
        raise ValueError(
            "Error during mapping to staggered planes.\n"
            "Please ensure that the input field has consistent dimensions "
            "and that Map3D contains the necessary planes of data.\n"
            "If using a slice of planes, ensure the Field's extra plane "
            "has higher phi than the last plane in the Maps."
        ) from e

    return mapped.assign_attrs(staggered=True)
@autodoc_function
def map_to_canonical_3D(map3D, field: xr.DataArray) -> xr.DataArray:
    """
    Map a staggered grid field onto the canonical grid for all planes.

    Works in a 3D (non-axisymmetric) equilibrium.

    Parameters
    ----------
    map3D
        Map object holding per-plane CSR matrices ``mfwd_stag``/``mbwd_stag``,
        the half-distances ``dpar_fwd_half``/``dpar_bwd_half``, the plane
        coordinate ``phi`` and ``nplanes``.
    field
        Staggered-grid variable containing either all planes or one more
        plane than the Maps (along 'planes' or 'phi').

    Returns
    -------
    xr.DataArray
        The field interpolated onto the canonical grid (``staggered=False``).
    """
    # --- Normalize toroidal dimension ----------------------------------------
    field = _fix_dim_names(field)

    assert field.staggered, "Input must be staggered grid data."
    assert (
        field.sizes["phi"] == map3D.nplanes
        or field.sizes["phi"] == map3D.phi.size + 1
    ), "Field must contain all planes or 1 more than the Maps."

    if map3D.phi.size != map3D.nplanes:
        # Extra plane sits at the low-phi end; drop it from both slices.
        remove = 0
        # FIX: the warning text previously used a backslash line continuation
        # inside the string literal, embedding a long run of indentation spaces
        # into the user-facing message. Use implicit concatenation instead.
        warnings.warn(
            "Map3D does not contain all planes of data. "
            "Resulting array will have one less point than the field in the "
            "'phi' dimension.",
            UserWarning,
        )
    else:
        remove = None

    # --- Staggered neighbors -------------------------------------------------
    # For canonical at plane p, staggered needed at (p-1) and p
    field_p1 = field  # plane p
    field_p0 = field.roll(phi=1, roll_coords=False)  # plane p-1
    if remove is not None:
        field_p1 = field_p1.drop_isel(phi=remove)
        field_p0 = field_p0.drop_isel(phi=remove)

    # --- Weights -------------------------------------------------------------
    dfwd = map3D.dpar_fwd_half
    dbwd = map3D.dpar_bwd_half

    # --- CSR matrices --------------------------------------------------------
    fwd_matrices = map3D.mfwd_stag
    bwd_matrices = map3D.mbwd_stag

    # --- max padding size ----------------------------------------------------
    counts = (
        [m.shape[1] for m in fwd_matrices.values]
        + [m.shape[1] for m in bwd_matrices.values]
    )
    max_counts = max(counts)
    max_counts = max(
        max_counts,
        field_p0.sizes["points"],
        field_p1.sizes["points"],
        dfwd.sizes["points"],
        dbwd.sizes["points"],
    )

    # --- Per-plane mapping worker --------------------------------------------
    def _map_single_plane(f0, f1, f_fwd, f_bwd, dpar_fwd, dpar_bwd, max_counts):
        fwd_vals, bwd_vals, dpar_fwd, dpar_bwd = _get_fwd_bwd_vals(
            f0, f1, f_fwd, f_bwd, dpar_fwd, dpar_bwd
        )
        # Distance-weighted interpolation, same logic as the 2D version.
        mapped = (fwd_vals * dpar_bwd + bwd_vals * dpar_fwd) / (dpar_fwd + dpar_bwd)
        n = mapped.sizes["points"]
        if n < max_counts:
            mapped = mapped.pad(points=(0, max_counts - n), constant_values=np.nan)
        return mapped

    # --- Vectorized ufunc ----------------------------------------------------
    try:
        mapped = xr.apply_ufunc(
            _map_single_plane,
            field_p0,
            field_p1,
            fwd_matrices,
            bwd_matrices,
            dfwd.pad(
                points=(0, max_counts - dfwd.sizes["points"]), constant_values=np.nan
            ),
            dbwd.pad(
                points=(0, max_counts - dbwd.sizes["points"]), constant_values=np.nan
            ),
            max_counts,
            input_core_dims=[["points"], ["points"], [], [], ["points"], ["points"], []],
            output_core_dims=[["points"]],
            exclude_dims=set(("points",)),
            vectorize=True,
            dask="parallelized",
            keep_attrs=True,
            output_dtypes=[np.float64],
        )
    except ValueError as e:
        raise ValueError(
            "Error during mapping to canonical planes.\n"
            "Please ensure that the input field has consistent dimensions "
            "and that Map3D contains the necessary planes of data.\n"
            "If using a slice of planes, ensure the Field's extra plane "
            "has lower phi than the first plane in the Maps."
        ) from e

    return mapped.assign_attrs(staggered=False)
def _fix_dim_names(field: xr.DataArray) -> xr.DataArray:
    """
    Normalize dimension names to 'phi' (toroidal) and 'points'.

    Accepts the alternative names 'planes' and 'npoints_max' and renames
    them; raises KeyError when neither the canonical nor the alternative
    name is present.
    """
    dim_map = {}
    if "planes" in field.dims:
        dim_map["planes"] = "phi"
    elif "phi" not in field.dims:
        raise KeyError("Input field must contain dim 'planes' or 'phi'.")
    if "npoints_max" in field.dims:
        dim_map["npoints_max"] = "points"
    elif "points" not in field.dims:
        raise KeyError("Input field must contain dim 'npoints_max' or 'points'.")
    return field.rename(**dim_map) if dim_map else field


def _get_fwd_bwd_vals(f0, f1, mat_fwd, mat_bwd, weights_fwd=None, weights_bwd=None):
    """
    Prepare staggered/canonical data for a single plane.

    Does:
    - remove NaNs from f0/f1 (and optionally from the weight arrays)
    - multiply the cleaned f1/f0 by the forward/backward CSR matrices

    Returns:
        forward_vals, backward_vals, cleaned weights_fwd, weights_bwd
    """

    def _drop_nans(arr):
        # Keep only the finite entries (NaNs mark padding).
        return arr[~np.isnan(arr)]

    clean_f0 = _drop_nans(f0)
    clean_f1 = _drop_nans(f1)
    if weights_fwd is not None:
        weights_fwd = _drop_nans(weights_fwd)
    if weights_bwd is not None:
        weights_bwd = _drop_nans(weights_bwd)

    return (
        csr_operator(mat_fwd, clean_f1),
        csr_operator(mat_bwd, clean_f0),
        weights_fwd,
        weights_bwd,
    )