# Source code for torx.performance.chunking_m
"""
Contains functions to automatically chunk xarray datasets.
Ensures that each chunk is smaller than the limiting memory.
"""
import sys
import math
import xarray as xr
import warnings
from torx.autodoc_decorators_m import autodoc_function
from .system_information_m import get_mem_avail, get_num_cores
from .byte_tools_m import convert_to_bytes
# Chunking priority per dimension name: lower value = chunked first.
# Negative entries mark dimensions that should be chunked last; they are
# remapped to very large positive priorities in _create_chunking_dict.
chunk_priority = {"tau": 0, "spec": 1, "mu": 2, "vp": 3, "phi": 4,
                  "vector": 5, "points": -1}
# Hard upper bound on the size of a single chunk in bytes:
# 2147483647 = 2**31 - 1 (largest signed 32-bit integer).
chunk_limit = 2147483647
@autodoc_function
def get_chunk_limit():
    """Return the currently configured chunk size limit in bytes."""
    return chunk_limit
@autodoc_function
def set_chunk_limit(chlim):
    """
    Set the current chunk limit to a value given.

    Parameters:
        chlim: New limit. Either an ``int`` (bytes), a ``float`` (bytes,
            truncated to an integer) or a ``str`` byte-string parsed by
            ``convert_to_bytes``.

    Raises:
        RuntimeError: If ``chlim`` has any other type. ``bool`` is rejected
            explicitly even though it subclasses ``int``.
    """
    global chunk_limit
    # bool subclasses int, so it must be excluded before the isinstance
    # checks to keep the original strict-type behavior (bools were rejected).
    if isinstance(chlim, bool):
        raise RuntimeError("Datatype of chunk limit must be either integer, "
                           f"float or string! (was {type(chlim)})")
    if isinstance(chlim, int):
        chunk_limit = chlim
    elif isinstance(chlim, float):
        # Truncate toward zero, matching int() semantics.
        chunk_limit = int(chlim)
    elif isinstance(chlim, str):
        chunk_limit = convert_to_bytes(chlim)
    else:
        raise RuntimeError("Datatype of chunk limit must be either integer, "
                           f"float or string! (was {type(chlim)})")
def _dict_values_prod(dict_to_prod: dict):
"""Return the product of the values contained in a dict."""
prod = 1
for x in dict_to_prod.values():
prod = prod * x
return prod
def _get_n_chunks(n_chunks_target: int):
"""Return a valid number of chunks given the targeted number."""
# Treat large number of chunks where prime search is not viable: if odd,
# add 1 to make it even, otherwise return it.
if n_chunks_target > int(1e6):
if n_chunks_target % 2 == 0:
return n_chunks_target + 1
else:
return n_chunks_target
# Standard treatment: check if n_chunks is prime, if not return it.
# Otherwise add 1 to make it non prime and return that.
if n_chunks_target == 1:
return n_chunks_target
for i in range(2, n_chunks_target):
if n_chunks_target%i == 0:
return n_chunks_target
return n_chunks_target + 1
def _create_chunking_dict(dims, n_chunks_target):
    """Create the chunking dict which can be used to chunk an xarray dataset.

    :param dims: mapping of dimension name -> dimension length (e.g. the
        ``sizes`` mapping of an ``xarray.Dataset``).
    :param n_chunks_target: total number of chunks to split the dataset into.
    :returns: dict mapping each dimension name to its chunk size along that
        dimension (dimensions left unsplit keep their full length).
    """
    n_chunks = n_chunks_target
    priority_dict = chunk_priority.copy()
    # First check all keys which are not in the chunk_priority dict and
    # assign them the same lowest priority
    for key in dims.keys():
        if not key in chunk_priority:
            priority_dict[key] = max(chunk_priority.values()) + 1
    # Second, remove all keys from priority_dict which are not in dims
    for key in chunk_priority.keys():
        if not key in dims:
            priority_dict.pop(key)
    # Third, replace the lowest priority entries (<0) by large positive values
    # (sys.maxsize + negative priority stays huge, so these dimensions sort
    # last and are only chunked if all others were insufficient).
    for key in priority_dict.keys():
        if(priority_dict[key] < 0):
            priority_dict[key] = sys.maxsize + priority_dict[key]
    # Create chunking dict: xarray dataset dims dict is frozen and needs to
    # be manually copied
    chunking_dict = {}
    for key, var in dims.items():
        chunking_dict[key] = var
    # Set chunk size in chunking dict, starting from highest chunking priority
    for key in sorted(priority_dict, key=priority_dict.get):
        if dims[key] > n_chunks:
            # This dimension is long enough to absorb all remaining chunks:
            # split it and stop.
            chunk_size = math.floor(dims[key] / n_chunks)
            chunking_dict[key] = chunk_size
            break
        else:
            # Dimension shorter than the remaining chunk count: split off as
            # many chunks as divide it evenly (gcd) and carry the remainder
            # of the chunk budget to the next dimension.
            chunk_size = math.gcd(dims[key], n_chunks)
            if(chunk_size == 1): continue
            chunking_dict[key] = chunk_size
            n_chunks = n_chunks // chunk_size
    return chunking_dict
@autodoc_function
def auto_chunk(ds: xr.Dataset):
    """Automatically chunk an xarray dataset and returns the result."""
    dims = ds.sizes
    # Estimated total footprint in bytes; assumes 8 bytes per data point
    # (i.e. float64 data) — TODO confirm for other dtypes.
    n_points = _dict_values_prod(dims)
    required_bytes = n_points * 8
    # TODO: Check influence of hyperthreading on available mem. Check if
    #       submitted jobs use hyperthreading or not. If required, adjust
    #       the available memory per thread.
    per_thread_mem = get_mem_avail() // get_num_cores()
    # Either the available memory or the hard chunk size limit are used,
    # whichever is stricter.
    limiting_mem = min(per_thread_mem, chunk_limit)
    n_chunks = _get_n_chunks(math.ceil(required_bytes / limiting_mem))
    return ds.chunk(chunks=_create_chunking_dict(dims, n_chunks))