"""Access data stored in GENE-X diagnostic files."""
import xarray as xr
import numpy as np
import os
import re
from pathlib import Path
from glob import glob
from torx.fileio import filepath_resolver
from .grid_helpers_m import mask_ghost_and_filler
from torx.autodoc_decorators_m import autodoc_function
from torx.performance import auto_chunk
# Not referenced in this module; presumably shared state for external callers — TODO confirm.
dataset_storage = {}
# Valid variable names for the 0d (time-trace) moment diagnostics in mom_0d.nc.
keys_mom_0d = ["n", "u_par", "E_par", "E_perp", "E_es", "p_phi"]
# Valid variable names for the 2d moment diagnostics in mom_2d.nc.
keys_mom_2d = ["n", "u_par", "E_par", "E_perp", "Q_par", "Q_perp"]
# Species-independent electromagnetic field variables in em_fields.nc.
keys_em_fields = ["es_pot", "A_par"]
# Parallel-connection / target-mask variables (par_con.nc, or mesh.nc group "parcon").
keys_par_con = ["in_target", "par_con_positive", "par_con_negative"]
def has_parts(directory_path: Path):
    """
    Check whether the given output features the GENE-X version with parts.

    Returns True if ``directory_path`` contains at least one subdirectory
    whose name starts with ``part_<number>`` (e.g. ``part_001``).
    """
    # BUGFIX: the previous pattern "part_*" treated "_*" as "zero or more
    # underscores", so any directory merely starting with "part" (e.g.
    # "particles") matched. Require an explicit numeric suffix, consistent
    # with the extraction pattern used by latest_part().
    regex = re.compile(r"part_\d+")
    for elem in directory_path.iterdir():
        if elem.is_dir() and regex.match(elem.name):
            return True
    return False
def find_all(directory_path: Path, filename: str):
    """Find all files containing the given filename in the directory given."""
    # rglob searches the directory tree recursively; materialize the
    # generator so callers can sort and index the result.
    return list(directory_path.rglob(filename))
def latest_part(paths: list):
    """
    Return the latest part that is contained in the list of paths.

    It is determined by the biggest number in part_xxx.

    Parameters
    ----------
    paths : list
        Paths (Path or str) that each contain a ``part_<number>`` component.

    Returns
    -------
    int
        The largest part number found.
    """
    # max() over a generator avoids allocating a throwaway numpy array;
    # the returned plain int formats identically to np.int64 in f-strings.
    return max(int(re.findall(r'part_(\d+)', str(p))[0]) for p in paths)
def load_from_par_con_file(directory_path: Path, variable: str):
    """
    Load a given variable from the 'par_con.nc' file.

    Falls back to the "parcon" group of mesh.nc for outputs that do not
    ship a standalone par_con.nc file.

    Parameters
    ----------
    directory_path : Path
        GENE-X output directory.
    variable : str
        Name of the variable to load (e.g. "in_target").

    Returns
    -------
    The requested variable with dimensions renamed to ("phi", "points").
    """
    try:
        path = filepath_resolver(directory_path, "par_con.nc")
        dataset = xr.open_dataset(path)
        dataset = dataset.rename({"dim_phi_grid": "phi",
                                  "dim_RZ_grid": "points"})
    # BUGFIX: a bare "except:" also swallowed KeyboardInterrupt/SystemExit;
    # catch Exception so only real load/rename failures trigger the fallback.
    except Exception:
        dataset = xr.open_dataset(
            filepath_resolver(directory_path, "mesh.nc"),
            group = "parcon"
        )
        dataset = dataset.rename({"dim_phi": "phi",
                                  "dim_RZ": "points"})
    phi = dataset["phi"]
    dataset = dataset.assign_coords({"phi": phi})
    # Update for not_in_target mask backwards compatible with in_target mask
    if "not_in_target" in dataset.data_vars and variable == "in_target":
        in_target = 1.0 - dataset["not_in_target"]
    else:
        in_target = dataset[variable]
    return in_target
def load_em_fields_file(path: Path):
    """Load data from the em_fields.nc file."""
    ds = xr.open_dataset(path, chunks="auto")
    # Map raw NetCDF dimension/variable names onto the torx convention.
    name_map = {"dim_time": "tau", "dim_phi": "phi",
                "dim_RZ": "points", "time": "tau"}
    ds = ds.rename(name_map)
    # Promote the renamed time and toroidal-angle variables to coordinates.
    return ds.assign_coords({"tau": ds["tau"], "phi": ds["phi"]})
def load_from_em_fields_file(directory_path: Path, variable: str):
    """Load a given variable from the em_fields.nc file."""
    if has_parts(directory_path):
        found = find_all(directory_path, "em_fields.nc")
        # Natural sort on the trailing part number of the parent directory.
        found.sort(key=lambda p: int(p.parent.stem.split("_")[-1]))
        dataset = xr.combine_by_coords([load_em_fields_file(p) for p in found])
    else:
        path = filepath_resolver(directory_path, "em_fields.nc")
        if not path.exists():
            raise FileNotFoundError("There is no em_fields.nc in the given directory!")
        dataset = load_em_fields_file(path)
    return dataset[variable]
def load_mom_2d_group(path: Path, group: str, variable: str):
    """Load data from a group in the mom_2d.nc file."""
    # NOTE: Time is only present in the file if no group is specified, thus
    # fetch it first before data load
    time_values = xr.open_dataset(path, chunks="auto")["time"].values
    ds = xr.open_dataset(path, group=group, chunks="auto")
    ds = ds.rename({"dim_time": "tau", "dim_phi": "phi", "dim_RZ": "points"})
    # Attach the root-level time axis and the toroidal angle as coordinates.
    return ds.assign_coords({"tau": time_values, "phi": ds["phi"]})
def load_from_mom_2d_file(directory_path: Path, species: str, variable: str):
    """Load a given variable from a group in the mom_2d.nc file."""
    if has_parts(directory_path):
        found = find_all(directory_path, "mom_2d.nc")
        # Natural sort on the trailing part number of the parent directory.
        found.sort(key=lambda p: int(p.parent.stem.split("_")[-1]))
        dataset = xr.combine_by_coords(
            [load_mom_2d_group(p, species, variable) for p in found]
        )
    else:
        path = filepath_resolver(directory_path, "mom_2d.nc")
        if not path.exists():
            raise FileNotFoundError("There is no mom_2d.nc in the "\
                                    "given directory!")
        dataset = load_mom_2d_group(path, species, variable)
    return dataset[variable]
def load_mom_0d_group(path: Path, group: str, variable: str):
    """Load data from a group in the mom_0d.nc file."""
    # The time axis lives at the file root, not inside the species group,
    # so open the root first to fetch it.
    root = xr.open_dataset(path).rename({"time": "tau", "dim_time": "tau"})
    tau = root["tau"]
    group_ds = xr.open_dataset(path, group=group).rename({"dim_time": "tau"})
    group_ds = group_ds.assign_coords({"tau": tau})
    return group_ds[variable]
def load_from_mom_0d_file(directory_path: Path, species: str, variable: str):
    """Load a given variable from a group in the mom_0d.nc file."""
    if has_parts(directory_path):
        found = find_all(directory_path, "mom_0d.nc")
        # Natural sort on the trailing part number of the parent directory.
        found.sort(key=lambda p: int(p.parent.stem.split("_")[-1]))
        pieces = [load_mom_0d_group(p, species, variable) for p in found]
        dataset = xr.concat(pieces, dim="tau")
    else:
        path = filepath_resolver(directory_path, "mom_0d.nc")
        if not path.exists():
            raise FileNotFoundError("There is no mom_0d.nc in the given directory!")
        dataset = load_mom_0d_group(path, species, variable)
    dataset = dataset.assign_coords({"tau": dataset["tau"]})
    dataset = dataset.sortby("tau")
    dataset.attrs["name"] = variable
    return dataset
def load_checkpoint(path: Path, chunks="auto"):
    """Load data from the checkpoints file."""
    renaming_dict = {"dim_time": "tau",
                     "dim_phi": "phi",
                     "dim_RZ": "points",
                     "dim_sp": "spec",
                     "dim_vp": "vp",
                     "dim_mu": "mu"}
    if chunks == "auto":
        chunks_local = chunks
    else:
        # The caller specifies chunks with torx dimension names; translate
        # them back to the NetCDF names by inverting the renaming dict.
        # Build a fresh dict so the caller's chunks mapping stays untouched.
        inverse = {torx_name: nc_name for nc_name, torx_name in renaming_dict.items()}
        chunks_local = {inverse[name]: size for name, size in chunks.items()}
    # Open dataset and rename the dimensions to match torx convention.
    dataset = xr.open_dataset(path, chunks=chunks_local)
    dataset = dataset.rename(renaming_dict)
    # lb/ub describe this checkpoint's index slab within the global grid
    # (ub is inclusive in the file, hence the +1 for np.arange).
    lb = dataset.data_vars["lb_stripped"].values
    ub = dataset.data_vars["ub_stripped"].values + 1
    coords = {
        "tau": dataset["time"],
        "points": np.arange(lb[0], ub[0], dtype=int),
        "phi": np.arange(lb[1], ub[1], dtype=int),
        "vp": np.arange(lb[2], ub[2], dtype=int),
        "mu": np.arange(lb[3], ub[3], dtype=int),
        "spec": np.arange(lb[4], ub[4], dtype=int),
    }
    dataset = dataset.assign_coords(coords)
    return dataset["dist_func"]
def load_from_checkpoint_file(directory_path: Path, chunks="auto", latest=True, which=None):
    """
    Load the distribution function from the checkpoints file.

    Parameters
    ----------
    directory_path : Path
        GENE-X output directory.
    chunks : str or dict
        Chunking forwarded to load_checkpoint (torx dimension names).
    latest : bool
        When the output is split into parts, only load the latest part.
    which : str or None
        Name of a specific part (e.g. "part_003") to load instead.

    Returns
    -------
    The distribution function "dist_func", sorted by "tau".
    """
    if not has_parts(directory_path):
        path = filepath_resolver(directory_path, "checkpoint.nc")
        if not path.exists():
            raise FileNotFoundError("There is no checkpoint.nc in the given directory!")
        # BUGFIX: the chunks argument was previously dropped in this branch.
        # Also, load_checkpoint returns a DataArray, so indexing it with
        # "dist_func" at the end (as the old code did) raised KeyError.
        dist_func = load_checkpoint(path, chunks=chunks)
    else:
        all_paths = find_all(directory_path, "checkpoint_*.nc")
        # Exclude neutral-species checkpoints.
        paths = [p for p in all_paths if "neut_" not in p.stem]
        if which is None:
            # Filter to only load latest part if specified
            if latest and len(paths) > 1:
                latest_num = latest_part(paths)
                part_name = "part_" + str(latest_num)
                print(f" Latest checkpoint found in {part_name} ...")
                paths = [p for p in paths if part_name in str(p)]
        else:
            part_name = which
            print(f" Checkpoint found in {part_name} ...")
            paths = [p for p in paths if part_name in str(p)]
        # Sort paths naturally by the appending number
        paths = sorted(paths, key=lambda x: int(x.parent.stem.split("_")[-1]))
        data = [load_checkpoint(path, chunks=chunks) for path in paths]
        dataset = xr.combine_by_coords(data)
        dist_func = dataset["dist_func"]
    dist_func = dist_func.sortby("tau")
    # BUGFIX: set the name attribute on the returned DataArray itself; the
    # old code set it on the combined Dataset, so it was lost on extraction.
    dist_func.attrs["name"] = "dist_func"
    return dist_func
# [docs] -- Sphinx cross-reference artifact left over from HTML extraction
@autodoc_function
def load_trace_genex(directory_path: Path, species: str, variable: str):
    """
    Load a single 0d variable from the GENE-X diagnostic files.

    Uses the directory path of the GENE-X output directory, the species name
    and the variable name.

    Raises
    ------
    ValueError
        If the variable is not a known 0d key.
    """
    if variable in keys_mom_0d:
        return load_from_mom_0d_file(directory_path, species, variable)
    # BUGFIX: previously fell through and returned None silently for unknown
    # variables; fail loudly, consistent with load_snaps_genex.
    raise ValueError(f"Variable {variable} is not a valid key!")
# [docs] -- Sphinx cross-reference artifact left over from HTML extraction
@autodoc_function
def load_snaps_genex(
    directory_path: Path, species: str, variable: str, mask_ghost: bool=True,
    **loader_kwargs
):
    """
    Load a single 2d variable from the GENE-X diagnostic files.

    Uses the directory path of the GENE-X output directory, the species name
    and the variable name.
    If species is None variables not related to species like 'es_pot' or
    'A_par' can be loaded.

    Parameters
    ----------
    directory_path : Path
        GENE-X output directory.
    species : str or None
        Species group name, or None for species-independent variables.
    variable : str
        Variable name; must be in keys_par_con, keys_em_fields or keys_mom_2d.
    mask_ghost : bool
        Whether to mask ghost and filler points.

    Raises
    ------
    ValueError
        If the variable is not a valid key for the chosen species mode.
    """
    if species is None:
        if variable in keys_par_con:
            ds = load_from_par_con_file(directory_path, variable)
        elif variable in keys_em_fields:
            ds = load_from_em_fields_file(directory_path, variable)
        else:
            raise ValueError(f"Variable {variable} is not a valid key!")
    else:
        if variable in keys_mom_2d:
            ds = load_from_mom_2d_file(directory_path, species, variable)
        else:
            raise ValueError(f"Variable {variable} is not a valid key!")
    # BUGFIX: mask_ghost was accepted but ignored; masking was always applied.
    if mask_ghost:
        ds = mask_ghost_and_filler(directory_path, ds)
    return auto_chunk(ds)
# [docs] -- Sphinx cross-reference artifact left over from HTML extraction
@autodoc_function
def load_checkpoint_genex(
    directory_path: Path, chunks="auto", latest=True, which=None,
    mask_ghost: bool=True,**loader_kwargs
):
    """
    Load the 5d distribution function from a GENE-X checkpoint file.

    Uses the directory path of the GENE-X output directory, the species
    name and the variable name.
    May specify the exact chunking using the torx dimension names and may
    specify only to load the latest checkpoint in the part structure. Or
    choose which checkpoint to load specifically.

    Parameters
    ----------
    directory_path : Path
        GENE-X output directory.
    chunks : str or dict
        Chunking specification using torx dimension names.
    latest : bool
        Only load the latest checkpoint part.
    which : str or None
        Name of a specific part to load instead.
    mask_ghost : bool
        Whether to mask ghost and filler points.
    """
    ds = load_from_checkpoint_file(directory_path, chunks=chunks,
                                   latest=latest, which=which)
    # BUGFIX: mask_ghost was accepted but ignored; masking was always applied.
    if mask_ghost:
        ds = mask_ghost_and_filler(directory_path, ds)
    return auto_chunk(ds)