"""Routines for performing phase shift analysis on data."""
import xarray as xr
import numpy as np
from torx.autodoc_decorators_m import autodoc_function
def _phase_histogram(
    ampl: np.ndarray,
    ref_ampl: np.ndarray,
    mode_number: np.ndarray,
    n_bins: int,
    weighted: bool,
    normed: bool,
) -> np.ndarray:
    """
    Return the 2D (phase shift, mode number) histogram of one spectrum.

    Thin wrapper around ``np.histogram2d`` that strips the bin edges from the
    return value, because the xarray ufunc machinery is used here with a
    single return variable.

    Parameters
    ----------
    ampl
        Complex amplitudes, one per entry of ``mode_number``.
    ref_ampl
        Complex amplitudes of the reference, same length as ``ampl``.
    mode_number
        Mode number associated with each amplitude.
    n_bins
        Number of bins along the phase-shift axis, which spans [-pi, pi].
    weighted
        If True, each sample is weighted by
        ``mode_number * |ampl| * |ref_ampl|``; otherwise every sample
        has weight 1.
    normed
        Forwarded to ``np.histogram2d`` as ``density``.

    Returns
    -------
    np.ndarray
        Histogram of shape ``(n_bins, len(mode_number))``.
    """
    # The imaginary part of the complex log is the argument of
    # conj(ampl) * ref_ampl, i.e. the phase of the reference relative to ampl.
    phase_shift = np.imag(np.log(np.conj(ampl) * ref_ampl))
    valid = ~np.isnan(phase_shift)

    if weighted:
        weights = mode_number * np.abs(ampl) * np.abs(ref_ampl)
    else:
        weights = np.ones(len(mode_number))

    phase_range = [-np.pi, np.pi]
    mode_range = [np.min(mode_number), np.max(mode_number)]

    # The weights must be masked exactly like the samples: histogram2d
    # requires one weight per sample and raises ValueError otherwise.
    H, _, _ = np.histogram2d(
        phase_shift[valid],
        mode_number[valid],
        bins=[n_bins, len(mode_number)],
        range=[phase_range, mode_range],
        weights=weights[valid],
        density=normed,
    )
    return H
@autodoc_function
def phase_shift_analysis(
    fourier_ds: xr.Dataset,
    variables: str,
    reference: str,
    n_bins: int,
    weighted: bool = False,
    normed: bool = False,
    allow_rechunk: bool = True,
):
    """
    Perform a phase shift analysis between variables and a reference.

    The dataset is assumed to contain already Fourier analyzed variables.
    ``variables`` may be a single variable name or a list of names;
    ``reference`` is the name of the variable the phase is measured against.

    Additionally a 2D histogram count in the mode_number and phase_shift
    variables (the size of the latter one is given by n_bins) is performed
    and the histogram data is saved into a variable with an _hist suffix.
    The count is done over the time and toroidal angle dimensions.

    If "weighted" is True, then the phase shifts are weighted by the
    product of mode number and amplitudes of the variables and reference,
    i.e. the amplitude of the cross-field transport. "normed" is passed to
    the histogram as its density flag, and "allow_rechunk" is forwarded to
    the dask gufunc machinery.

    Returns a dataset holding one ``<name>_hist`` histogram per analysed
    variable, merged with the phase shift of each variable (stored under the
    variable's original name).
    """
    # A bare name would select a DataArray, which has no data_vars to rename
    # further down; always select a Dataset instead.
    if isinstance(variables, str):
        variables = [variables]
    spectra = fourier_ds[variables]
    ref_spectrum = fourier_ds[reference]

    # Compute the phase shift for each phi & tau
    phase_shift = np.imag(np.log(np.conj(spectra) * ref_spectrum))

    # Dictionaries for apply_ufunc
    kwargs_dict = dict(
        mode_number=fourier_ds.mode_number.values,
        n_bins=n_bins,
        weighted=weighted,
        normed=normed,
    )
    dask_gufunc_kwargs_dict = dict(
        allow_rechunk=allow_rechunk,
        output_sizes=dict(
            phase_shift=n_bins,
            mode_number=fourier_ds.mode_number.size,
        ),
    )
    # NOTE(review): the histogram uses n_bins equal-width bins on [-pi, pi],
    # whereas this coordinate holds n_bins points including both end points,
    # so the values are neither bin edges nor bin centres. Left unchanged to
    # keep the output coordinates as they were - confirm what is intended.
    coords_dict = dict(
        phase_shift=np.linspace(-np.pi, np.pi, n_bins),
    )
    attrs_dict = dict(
        description=(f"Fourier phase shift of variables w.r.t. {reference}" +
                     f", on flux surface rho={fourier_ds.rho}." +
                     (f" Histogram weighted by k*|var_amplitude|*|ref_amplitude|."
                      if weighted else "")),
        weighted=str(weighted),
        rho=fourier_ds.rho,
    )

    # Compute a 2D-histogram for each phi & tau. The histogram is a single,
    # real-valued output, so exactly one float dtype is declared.
    H = xr.apply_ufunc(
        _phase_histogram,
        spectra.chunk({"mode_number": -1}),
        ref_spectrum.chunk({"mode_number": -1}),
        input_core_dims=[["mode_number"], ["mode_number"]],
        output_core_dims=[["phase_shift", "mode_number"]],
        kwargs=kwargs_dict,
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs=dask_gufunc_kwargs_dict,
        output_dtypes=[np.float64],
    ).assign_coords(coords_dict).assign_attrs(attrs_dict)

    # Fill the 2D histograms by summing over all phi & tau. skipna drops the
    # NaN entries a normalised histogram produces for empty slices.
    sum_dims = ["phi", "tau"] if "phi" in H.dims else "tau"
    H = H.sum(dim=sum_dims, keep_attrs=True, skipna=True)

    H = H.rename_vars(
        name_dict={name: f"{name}_hist" for name in H.data_vars}
    )
    return H.merge(phase_shift)