Source code for torx.units_handling_m
"""
Set up the pint unit library for use in torx.
See unit_registry.txt for defined units (for reference only).
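You should ALWAYS import Quantity and unit_registry from this file,
and never from pint directly: all quantities must share a single unit
registry, otherwise operations between quantities created by different
registries (including unit conversions) can fail or behave unexpectedly.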
"""
import os
import warnings
from typing import Union, TypeAlias

# Disable Pint's old fallback behavior. Pint reads this variable at import
# time, so it must be set BEFORE `import pint` below (and before any other
# third-party import that might pull pint in transitively).
# NOTE(review): if pint was already imported elsewhere in the process before
# this module runs, this assignment has no effect; recent pint releases may
# also no longer read this variable at all — confirm against the pinned pint.
os.environ["PINT_ARRAY_PROTOCOL_FALLBACK"] = "0"

import numpy as np  # noqa: E402  (must follow the env-var setup above)
import xarray as xr  # noqa: E402
import pint  # noqa: E402
from torx.autodoc_decorators_m import autodoc_function  # noqa: E402
# NOTE: The classes below override pint's Quantity class with a custom
# subclass that behaves identically, but lets type checks in the library
# ensure that only this version is used.
class TorxQuantity(pint.Quantity):
    """
    Quantity subclass used throughout torx.

    It adds no behaviour of its own on top of ``pint.Quantity``. It exists
    so that type checks in torx can tell quantities belonging to the torx
    unit registry apart from plain pint quantities.
    """
class TorxUnitRegistry(pint.registry.UnitRegistry):
    """
    Unit registry object for use in torx.

    Overrides the ``Quantity`` class attribute with ``TorxQuantity`` so that
    quantities created through this registry use the torx subclass.
    NOTE(review): presumably pint derives the registry-bound
    ``unit_registry.Quantity`` class from this attribute; confirm against the
    installed pint version.
    """
    # Quantity class this registry should build its quantities from.
    Quantity: TypeAlias = TorxQuantity
    # Unit class is left as pint's own ``pint.Unit``.
    Unit: TypeAlias = pint.Unit
# The single registry instance shared by torx. Every Quantity should come from
# here so that all quantities belong to the same registry.
unit_registry = TorxUnitRegistry()
# Registry-bound Quantity class; use this instead of pint.Quantity.
Quantity = unit_registry.Quantity
# Dimensionless quantity with magnitude 1 and no units.
Dimensionless = Quantity(1, "")
# Silence NEP18 warning
# Builds one array-backed Quantity while all warnings are suppressed.
# NOTE(review): presumably pint only emits its NEP18-related warning on the
# first such construction, so doing it here keeps it from surfacing later.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    Quantity([])
# Register pint's matplotlib integration for this registry.
# NOTE(review): presumably requires matplotlib to be installed — confirm.
unit_registry.setup_matplotlib()
# Default string format for quantities/units. Per pint's documented format
# spec, "~" requests abbreviated unit symbols and "P" pretty printing.
# NOTE(review): presumably this also governs str(quant.u) as used by
# serialize_Quantity for the "netcdf" target.
unit_registry.formatter.default_format = "~P"
# Hide a warning that units are stripped when downcasting to numpy arrays
# (this installs a process-wide warnings filter, not one scoped to torx).
warnings.filterwarnings("ignore", category=pint.errors.UnitStrippedWarning)
# Backend used by serialize_Quantity / deserialize_Quantity; change it via
# set_serialization_target ("default" or "netcdf").
serialization_target = "default"
[docs]
@autodoc_function
def set_serialization_target(target: str):
    """
    Set the target backend of the serialization routines.

    Should be "default" or "netcdf". The choice is stored in the module-level
    ``serialization_target`` and read by ``serialize_Quantity`` and
    ``deserialize_Quantity``.

    Args:
        target: Either "default" or "netcdf".

    Raises:
        ValueError: If ``target`` is not one of the supported backends.
    """
    global serialization_target
    valid_targets = ("default", "netcdf")
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown target through silently.
    if target not in valid_targets:
        raise ValueError(
            f"Unknown serialization target {target!r}; "
            f"expected one of {valid_targets}."
        )
    serialization_target = target
[docs]
@autodoc_function
def serialize_Quantity(quant: Quantity):
    """
    Turn a Quantity into a plain, serializable value (e.g. to store it in a file).

    The output format depends on the module-level ``serialization_target``:

    * "netcdf": a two-element list ``[magnitude, unit_string]``. The magnitude
      is converted with ``float()``, so it must be a scalar.
    * any other value: the result of ``quant.to_tuple()``.
    """
    if serialization_target != "netcdf":
        return quant.to_tuple()
    magnitude = float(quant.m)
    unit_text = str(quant.u)
    return [magnitude, unit_text]
[docs]
@autodoc_function
def deserialize_Quantity(quant) -> Quantity:
    """
    Convert a serialized quantity object back into a Quantity.

    This is the inverse of ``serialize_Quantity`` for the currently selected
    module-level ``serialization_target``.

    Args:
        quant: For the "netcdf" target, a two-entry sequence
            ``[magnitude, unit_string]``. Otherwise, the tuple form produced
            by ``Quantity.to_tuple()``.

    Returns:
        The reconstructed Quantity.

    Raises:
        ValueError: If a "netcdf" payload does not have exactly two entries.
    """
    if serialization_target == "netcdf":
        # Explicit check instead of ``assert`` so malformed file data is still
        # rejected when Python runs with ``-O`` (asserts are stripped there).
        if len(quant) != 2:
            raise ValueError(
                "Serialized Quantity must contain two entries: "
                "magnitude and unit."
            )
        return Quantity(float(quant[0]), quant[1])
    return Quantity.from_tuple(quant)
[docs]
@autodoc_function
def convert_xarray_to_quantity(input_array: xr.DataArray) -> Quantity:
    """
    Turn a DataArray that carries a units attribute into a unit-aware Quantity.

    Handy for unit checking. If the underlying data is already a Quantity, its
    units are silently replaced. Every attribute other than the units is
    dropped. Intended only for sanity checks and tests.
    """
    # Converting Quantity-backed data to a plain array strips its units, which
    # pint would warn about; that stripping is intended here, so silence it.
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=pint.errors.UnitStrippedWarning
        )
        # NOTE(review): ``.norm`` is not defined in this module; presumably a
        # torx DataArray accessor yielding the unit Quantity described by the
        # array's units attribute — confirm.
        unit_factor = input_array.norm
        plain_values = np.asarray(input_array)
        return unit_factor * plain_values
@autodoc_function
def check_units(
    value: Quantity, unit: Union[str, dict], name: str = "Value"
) -> None:
    """
    Check value is a torx quantity and its units have the expected dimensions.

    Args:
        value: Object to validate.
        unit: Expected dimensionality, in any form accepted by
            ``Quantity.check``, e.g. the string ``"[length]"`` or a
            dimension dict.
        name: Label used in the error messages.

    Raises:
        TypeError: If ``value`` is not an instance of this module's Quantity.
        ValueError: If ``value`` does not have the dimensions given by ``unit``.
    """
    if not isinstance(value, Quantity):
        raise TypeError(
            f"{name} must be a Quantity, got {type(value).__name__}."
        )
    if not value.check(unit):
        raise ValueError(
            f"{name} must have dimensions {unit}, got {value.units}."
        )