# Source code for torx.specializations.genex.snaps_access_m

"""Access data stored in GENE-X diagnostic files."""
import xarray as xr
import numpy as np
import os
import re
from pathlib import Path
from glob import glob

from torx.fileio import filepath_resolver
from .grid_helpers_m import mask_ghost_and_filler
from torx.autodoc_decorators_m import autodoc_function
from torx.performance import auto_chunk

# Module-level dictionary; none of the functions shown in this module read or
# write it.
dataset_storage = dict()

# Variables accepted by load_trace_genex; each is read from a species group of
# mom_0d.nc.
keys_mom_0d = [
    "n",
    "u_par",
    "E_par",
    "E_perp",
    "E_es",
    "p_phi",
]

# Species variables accepted by load_snaps_genex; each is read from a species
# group of mom_2d.nc.
keys_mom_2d = [
    "n",
    "u_par",
    "E_par",
    "E_perp",
    "Q_par",
    "Q_perp",
]

# Species-independent variables accepted by load_snaps_genex (species=None);
# read from em_fields.nc.
keys_em_fields = [
    "es_pot",
    "A_par",
]

# Species-independent variables accepted by load_snaps_genex (species=None);
# read by load_from_par_con_file.
keys_par_con = [
    "in_target",
    "par_con_positive",
    "par_con_negative",
]


def has_parts(directory_path: Path):
    """
    Check whether the given output features the GENE-X version with parts.

    Output written in parts contains sub-directories named ``part_<digits>``
    directly inside ``directory_path``. Returns True if at least one such
    sub-directory exists, False otherwise. Plain files with such a name are
    ignored.
    """
    # This is a regular expression, not a glob: "part_*" would mean "part"
    # followed by zero or more underscores and thus also match names such as
    # "partial" or "parts". Require the underscore and at least one digit,
    # which the part-number parsing elsewhere in this module relies on.
    part_dir_pattern = re.compile(r"part_\d+")
    return any(
        elem.is_dir() and part_dir_pattern.match(elem.name) is not None
        for elem in directory_path.iterdir()
    )

def find_all(directory_path: Path, filename: str):
    """
    Recursively collect every path below ``directory_path`` matching ``filename``.

    ``filename`` is handed to ``Path.rglob`` and is therefore a glob pattern
    (wildcards such as ``"checkpoint_*.nc"`` are allowed). The matches are
    returned as a list in the order ``rglob`` yields them, without sorting.
    """
    return list(directory_path.rglob(filename))

def latest_part(paths: list):
    """
    Return the latest part that is contained in the list of paths.

    It is determined by the biggest number in part_xxx. For every path the
    innermost (last) ``part_<digits>`` occurrence is used, so a directory
    further up the tree that happens to contain ``part_`` in its name does not
    shadow the actual part directory.

    Raises ValueError if ``paths`` is empty or a path contains no
    ``part_<digits>``.
    """
    part_number_pattern = re.compile(r"part_(\d+)")
    numbers = []
    for p in paths:
        matches = part_number_pattern.findall(str(p))
        if not matches:
            raise ValueError(f"Path {p} does not contain a 'part_<number>' component.")
        # Take the innermost occurrence: it belongs to the closest part directory.
        numbers.append(int(matches[-1]))
    if not numbers:
        raise ValueError("Cannot determine the latest part of an empty list of paths.")
    return max(numbers)

def load_from_par_con_file(directory_path: Path, variable: str):
    """
    Load a given variable from the 'par_con.nc' file.

    The standalone 'par_con.nc' file is tried first (dimensions
    'dim_phi_grid'/'dim_RZ_grid' renamed to 'phi'/'points'). If it cannot be
    resolved, opened or renamed, the 'parcon' group of 'mesh.nc' is used
    instead (dimensions 'dim_phi'/'dim_RZ' renamed to 'phi'/'points').

    For ``variable == "in_target"`` files that only store ``not_in_target``
    are supported by returning ``1.0 - not_in_target``.

    Returns the requested DataArray with 'phi' assigned as a coordinate.
    """
    dataset = None
    try:
        path = filepath_resolver(directory_path, "par_con.nc")
        dataset = xr.open_dataset(path)
        dataset = dataset.rename({"dim_phi_grid": "phi",
                                  "dim_RZ_grid": "points"})
    except Exception:
        # par_con.nc is missing or not in the expected layout. If it was opened
        # but could not be renamed, release its file handle before falling back.
        if dataset is not None:
            dataset.close()
        dataset = xr.open_dataset(
            filepath_resolver(directory_path, "mesh.nc"),
            group="parcon"
        )
        dataset = dataset.rename({"dim_phi": "phi",
                                  "dim_RZ": "points"})

    dataset = dataset.assign_coords({"phi": dataset["phi"]})

    # Update for not_in_target mask backwards compatible with in_target mask
    if "not_in_target" in dataset.data_vars and variable == "in_target":
        result = 1.0 - dataset["not_in_target"]
    else:
        result = dataset[variable]
    return result

def load_em_fields_file(path: Path):
    """
    Open a single em_fields.nc file and convert it to torx naming.

    The file is opened with ``chunks="auto"``. The dimensions 'dim_time',
    'dim_phi' and 'dim_RZ' become 'tau', 'phi' and 'points', and the 'time'
    variable is renamed to 'tau' as well. Afterwards 'tau' and 'phi' are
    (re-)assigned as coordinates. Returns the resulting Dataset.
    """
    to_torx_names = {"dim_time": "tau", "dim_phi": "phi", "dim_RZ": "points", "time": "tau"}
    renamed = xr.open_dataset(path, chunks="auto").rename(to_torx_names)

    # Both coordinate arrays are taken from the renamed dataset before any
    # coordinate assignment happens.
    time_coord = renamed["tau"]
    phi_coord = renamed["phi"]

    with_time = renamed.assign_coords({"tau": time_coord})
    return with_time.assign_coords({"phi": phi_coord})

def load_from_em_fields_file(directory_path: Path, variable: str):
    """
    Load a given variable from the em_fields.nc file.

    Without parts the single ``em_fields.nc`` resolved via filepath_resolver is
    used. With parts, every ``em_fields.nc`` below ``directory_path`` is loaded
    (sorted by the number of its parent ``part_<n>`` directory) and merged with
    ``xr.combine_by_coords``.

    Raises FileNotFoundError if no em_fields.nc file can be found.
    """
    if not has_parts(directory_path):
        path = filepath_resolver(directory_path, "em_fields.nc")
        if not path.exists():
            raise FileNotFoundError("There is no em_fields.nc in the given directory!")
        dataset = load_em_fields_file(path)
    else:
        paths = find_all(directory_path, "em_fields.nc")
        if not paths:
            # combine_by_coords([]) would silently yield an empty Dataset and
            # only fail later with an unhelpful KeyError.
            raise FileNotFoundError(
                "There is no em_fields.nc in any part of the given directory!"
            )
        # Sort paths naturally by the appending number
        paths = sorted(paths, key=lambda x: int(x.parent.stem.split("_")[-1]))
        dataset = xr.combine_by_coords([load_em_fields_file(path) for path in paths])

    return dataset[variable]

def load_mom_2d_group(path: Path, group: str, variable: str):
    """
    Load one group of a mom_2d.nc file as a Dataset.

    ``group`` is the NetCDF group to read (load_from_mom_2d_file passes the
    species name). The dimensions 'dim_time', 'dim_phi' and 'dim_RZ' are
    renamed to 'tau', 'phi' and 'points'; 'tau' is filled with the time values
    from the file's root group and 'phi' is assigned as a coordinate.

    ``variable`` is not used here (the caller selects the variable from the
    returned Dataset); it is kept for signature compatibility.
    """
    # NOTE: Time is only present in the file if no group is specified, thus
    #       fetch it first before data load. The root-group dataset is needed
    #       for nothing else, so close it right away instead of leaving its
    #       file handle open. `.values` loads the data before closing.
    with xr.open_dataset(path, chunks="auto") as root_dataset:
        time = root_dataset["time"].values

    dataset = xr.open_dataset(path, group=group, chunks="auto")
    dataset = dataset.rename(
        {"dim_time": "tau", "dim_phi": "phi", "dim_RZ": "points"}
    )

    dataset = dataset.assign_coords({"tau": time})
    dataset = dataset.assign_coords({"phi": dataset["phi"]})
    return dataset

def load_from_mom_2d_file(directory_path: Path, species: str, variable: str):
    """
    Load a given variable from a group in the mom_2d.nc file.

    ``species`` selects the NetCDF group. Without parts the single
    ``mom_2d.nc`` resolved via filepath_resolver is used. With parts, every
    ``mom_2d.nc`` below ``directory_path`` is loaded (sorted by the number of
    its parent ``part_<n>`` directory) and merged with ``xr.combine_by_coords``.

    Raises FileNotFoundError if no mom_2d.nc file can be found.
    """
    if not has_parts(directory_path):
        path = filepath_resolver(directory_path, "mom_2d.nc")
        if not path.exists():
            raise FileNotFoundError("There is no mom_2d.nc in the "
                                    "given directory!")

        dataset = load_mom_2d_group(path, species, variable)
    else:
        paths = find_all(directory_path, "mom_2d.nc")
        if not paths:
            # combine_by_coords([]) would silently yield an empty Dataset and
            # only fail later with an unhelpful KeyError.
            raise FileNotFoundError(
                "There is no mom_2d.nc in any part of the given directory!"
            )
        # Sort paths naturally by the appending number
        paths = sorted(paths, key=lambda x: int(x.parent.stem.split("_")[-1]))
        dataset = xr.combine_by_coords(
            [load_mom_2d_group(path, species, variable) for path in paths]
        )

    return dataset[variable]

def load_mom_0d_group(path: Path, group: str, variable: str):
    """
    Read one variable from a group of a mom_0d.nc file.

    The time axis comes from the file's root group, where both 'time' and
    'dim_time' are renamed to 'tau'. The requested ``group`` is then opened,
    its 'dim_time' dimension is renamed to 'tau' and the root-group time axis
    is attached as the 'tau' coordinate. Returns ``variable`` as a DataArray.
    """
    root_group = xr.open_dataset(path).rename({"time": "tau", "dim_time": "tau"})
    time_axis = root_group["tau"]

    group_data = xr.open_dataset(path, group=group).rename({"dim_time": "tau"})
    group_data = group_data.assign_coords({"tau": time_axis})

    return group_data[variable]

def load_from_mom_0d_file(directory_path: Path, species: str, variable: str):
    """
    Load a given variable from a group in the mom_0d.nc file.

    ``species`` selects the NetCDF group. Without parts the single
    ``mom_0d.nc`` resolved via filepath_resolver is used. With parts, the
    variable is loaded from every ``mom_0d.nc`` below ``directory_path``,
    concatenated along 'tau' and sorted by 'tau'. In that case
    ``attrs["name"]`` is set to ``variable``.

    Raises FileNotFoundError if no mom_0d.nc file can be found.
    """
    if not has_parts(directory_path):
        path = filepath_resolver(directory_path, "mom_0d.nc")
        if not path.exists():
            raise FileNotFoundError("There is no mom_0d.nc in the given directory!")

        dataset = load_mom_0d_group(path, species, variable)
    else:
        paths = find_all(directory_path, "mom_0d.nc")
        if not paths:
            # xr.concat([]) would otherwise fail with an unrelated ValueError.
            raise FileNotFoundError(
                "There is no mom_0d.nc in any part of the given directory!"
            )
        # Sort paths naturally by the appending number
        paths = sorted(paths, key=lambda x: int(x.parent.stem.split("_")[-1]))

        dataset = xr.concat(
            [load_mom_0d_group(path, species, variable) for path in paths], dim="tau"
        )
        dataset = dataset.assign_coords({"tau": dataset["tau"]})
        dataset = dataset.sortby("tau")
        dataset.attrs["name"] = variable
    return dataset

def load_checkpoint(path: Path, chunks="auto"):
    """Load data from the checkpoints file."""
    renaming_dict = {"dim_time": "tau",
                     "dim_phi": "phi",
                     "dim_RZ": "points",
                     "dim_sp": "spec",
                     "dim_vp": "vp",
                     "dim_mu": "mu"}

    # If chunks is not auto, map the torx dimension names in the chunks dict
    # to the ones used in the NetCDF file. This requires inverting the renaming
    # dict and renaming the keys.
    if(chunks != "auto"):
        # We need to make a hardcopy of the chunks dict since it is passed by reference
        # to this function. I.e. changes done here will affect succeeding uses outside
        # of this function.
        chunks_local = chunks.copy()

        chunk_renaming = {v: k for k, v in renaming_dict.items()}
        keys = list(chunks_local.keys())
        for k in keys:
            chunks_local[chunk_renaming[k]] = chunks_local.pop(k)
    else:
        chunks_local = chunks

    # Open dataset and rename the keys to match torx convention
    dataset = xr.open_dataset(path, chunks=chunks_local)
    dataset = dataset.rename(renaming_dict)

    lb = dataset.data_vars["lb_stripped"].values
    ub = dataset.data_vars["ub_stripped"].values + 1

    tau = dataset["time"]
    points = np.arange(lb[0], ub[0], dtype = int)
    phi    = np.arange(lb[1], ub[1], dtype = int)
    vp     = np.arange(lb[2], ub[2], dtype = int)
    mu     = np.arange(lb[3], ub[3], dtype = int)
    spec   = np.arange(lb[4], ub[4], dtype = int)

    dataset = dataset.assign_coords({"tau": tau})
    dataset = dataset.assign_coords({"points": points})
    dataset = dataset.assign_coords({"phi": phi})
    dataset = dataset.assign_coords({"spec": spec})
    dataset = dataset.assign_coords({"vp": vp})
    dataset = dataset.assign_coords({"mu": mu})

    return dataset["dist_func"]

def load_from_checkpoint_file(directory_path: Path, chunks="auto", latest=True, which=None):
    """
    Load the distribution function from the checkpoints file.

    Without parts, ``checkpoint.nc`` (resolved via filepath_resolver) is loaded
    with the given ``chunks``.

    With parts, all ``checkpoint_*.nc`` files below ``directory_path`` whose
    stem does not contain ``neut_`` are candidates:

    - ``which`` given: only checkpoints with a path component equal to
      ``which`` (e.g. ``"part_3"``) are kept.
    - otherwise, if ``latest`` is True and there are several candidates, only
      those from the part with the highest ``part_<n>`` number are kept.

    The selected files are merged with ``xr.combine_by_coords`` and sorted by
    'tau'.

    Returns the ``dist_func`` DataArray.
    Raises FileNotFoundError if no matching checkpoint file is found.
    """
    if not has_parts(directory_path):
        path = filepath_resolver(directory_path, "checkpoint.nc")
        if not path.exists():
            raise FileNotFoundError("There is no checkpoint.nc in the given directory!")
        # load_checkpoint already returns the dist_func DataArray; indexing it
        # again with "dist_func" would look up a coordinate and raise KeyError.
        return load_checkpoint(path, chunks=chunks)

    all_paths = find_all(directory_path, "checkpoint_*.nc")
    paths = [p for p in all_paths if "neut_" not in p.stem]
    if which is None:
        # Filter to only load latest part if specified
        if latest and len(paths) > 1:
            latest_num = latest_part(paths)
            part_name = "part_" + str(latest_num)
            print(f"  Latest checkpoint found in {part_name} ...")
            # Compare part numbers instead of substrings: works for
            # zero-padded directory names and cannot pick up other parts.
            paths = [p for p in paths if latest_part([p]) == latest_num]
    else:
        part_name = which
        print(f"  Checkpoint found in {part_name} ...")
        # Match whole path components so that "part_1" does not also select
        # "part_10", "part_11", ...
        paths = [p for p in paths if part_name in p.parts]

    if not paths:
        raise FileNotFoundError(
            "No matching checkpoint_*.nc file found in the given directory!"
        )

    # Sort paths naturally by the appending number
    paths = sorted(paths, key=lambda x: int(x.parent.stem.split("_")[-1]))

    data = [load_checkpoint(path, chunks=chunks) for path in paths]

    # combine_by_coords promotes the named DataArrays to a Dataset.
    dataset = xr.combine_by_coords(data)
    dataset = dataset.sortby("tau")
    dataset.attrs["name"] = "dist_func"
    return dataset["dist_func"]

@autodoc_function
def load_trace_genex(directory_path: Path, species: str, variable: str):
    """
    Load a single 0d variable from the GENE-X diagnostic files.

    Uses the directory path of the GENE-X output directory, the species name
    and the variable name. Valid variable names are those in ``keys_mom_0d``;
    any other name raises ValueError.
    """
    if variable in keys_mom_0d:
        return load_from_mom_0d_file(directory_path, species, variable)
    # Previously an unknown variable silently returned None; fail loudly like
    # load_snaps_genex does.
    raise ValueError(f"Variable {variable} is not a valid key!")
@autodoc_function
def load_snaps_genex(
    directory_path: Path, species: str, variable: str, mask_ghost: bool=True, **loader_kwargs
):
    """
    Load a single 2d variable from the GENE-X diagnostic files.

    Uses the directory path of the GENE-X output directory, the species name
    and the variable name. If species is None variables not related to species
    like 'es_pot' or 'A_par' can be loaded.

    If ``mask_ghost`` is True (default) the data is passed through
    ``mask_ghost_and_filler``; if False it is returned unmasked.
    ``loader_kwargs`` is accepted for interface compatibility and is unused.
    The result is passed through ``auto_chunk``.

    Raises ValueError for a variable name not valid for the given species.
    """
    if species is None:
        if variable in keys_par_con:
            ds = load_from_par_con_file(directory_path, variable)
        elif variable in keys_em_fields:
            ds = load_from_em_fields_file(directory_path, variable)
        else:
            raise ValueError(f"Variable {variable} is not a valid key!")
    else:
        if variable in keys_mom_2d:
            ds = load_from_mom_2d_file(directory_path, species, variable)
        else:
            raise ValueError(f"Variable {variable} is not a valid key!")

    # mask_ghost was previously ignored and masking was always applied.
    if mask_ghost:
        ds = mask_ghost_and_filler(directory_path, ds)
    return auto_chunk(ds)
@autodoc_function
def load_checkpoint_genex(
    directory_path: Path, chunks="auto", latest=True, which=None, mask_ghost: bool=True, **loader_kwargs
):
    """
    Load the 5d distribution function from a GENE-X checkpoint file.

    Uses the directory path of the GENE-X output directory. May specify the
    exact chunking using the torx dimension names and may specify only to load
    the latest checkpoint in the part structure. Or choose which checkpoint to
    load specifically via ``which`` (a part directory name such as "part_3").

    If ``mask_ghost`` is True (default) the data is passed through
    ``mask_ghost_and_filler``; if False it is returned unmasked.
    ``loader_kwargs`` is accepted for interface compatibility and is unused.
    The result is passed through ``auto_chunk``.
    """
    ds = load_from_checkpoint_file(directory_path, chunks=chunks, latest=latest, which=which)
    # mask_ghost was previously ignored and masking was always applied.
    if mask_ghost:
        ds = mask_ghost_and_filler(directory_path, ds)
    return auto_chunk(ds)