Source code for torx.performance.chunking_m

"""
Contains functions to automatically chunk xarray datasets.

Ensures that each chunk is smaller than the limiting memory.
"""
import sys
import math
import xarray as xr
import warnings

from torx.autodoc_decorators_m import autodoc_function

from .system_information_m import get_mem_avail, get_num_cores
from .byte_tools_m import convert_to_bytes

chunk_priority = {"tau": 0, "spec": 1, "mu": 2, "vp": 3, "phi": 4,
                  "vector": 5, "points":-1}
chunk_limit = 2147483647

@autodoc_function
def get_chunk_limit():
    """Return the current chunk limit in bytes.

    This is the module-level ``chunk_limit`` value, adjustable via
    ``set_chunk_limit``.
    """
    return chunk_limit
@autodoc_function
def set_chunk_limit(chlim):
    """
    Set the current chunk limit to a value given.

    Parameters
    ----------
    chlim : int, float or str
        New limit in bytes. A float is truncated to an integer byte
        count; a string is parsed by ``convert_to_bytes``.

    Raises
    ------
    RuntimeError
        If ``chlim`` is of any other type.
    """
    global chunk_limit
    # isinstance is preferred over ``type(x) is T``; bool is excluded
    # explicitly because isinstance(True, int) holds, while the original
    # type() checks rejected bools.
    if isinstance(chlim, int) and not isinstance(chlim, bool):
        chunk_limit = chlim
    elif isinstance(chlim, float):
        chunk_limit = int(chlim)
    elif isinstance(chlim, str):
        chunk_limit = convert_to_bytes(chlim)
    else:
        raise RuntimeError("Datatype of chunk limit must be either integer, "
                           f"float or string! (was {type(chlim)})")
def _dict_values_prod(dict_to_prod: dict): """Return the product of the values contained in a dict.""" prod = 1 for x in dict_to_prod.values(): prod = prod * x return prod def _get_n_chunks(n_chunks_target: int): """Return a valid number of chunks given the targeted number.""" # Treat large number of chunks where prime search is not viable: if odd, # add 1 to make it even, otherwise return it. if n_chunks_target > int(1e6): if n_chunks_target % 2 == 0: return n_chunks_target + 1 else: return n_chunks_target # Standard treatment: check if n_chunks is prime, if not return it. # Otherwise add 1 to make it non prime and return that. if n_chunks_target == 1: return n_chunks_target for i in range(2, n_chunks_target): if n_chunks_target%i == 0: return n_chunks_target return n_chunks_target + 1 def _create_chunking_dict(dims, n_chunks_target): """Create the chunking dict which can be used to chunk an xarray dataset.""" n_chunks = n_chunks_target priority_dict = chunk_priority.copy() # First check all keys which are not in the chunk_priority dict and # assign them the same lowest priority for key in dims.keys(): if not key in chunk_priority: priority_dict[key] = max(chunk_priority.values()) + 1 # Second, remove all keys from priority_dict which are not in dims for key in chunk_priority.keys(): if not key in dims: priority_dict.pop(key) # Third, replace the lowest priority entries (<0) by large positive values for key in priority_dict.keys(): if(priority_dict[key] < 0): priority_dict[key] = sys.maxsize + priority_dict[key] # Create chunking dict: xarray dataset dims dict is frozen and needs to # be manually copied chunking_dict = {} for key, var in dims.items(): chunking_dict[key] = var # Set chunk size in chunking dict, starting from highest chunking priority for key in sorted(priority_dict, key=priority_dict.get): if dims[key] > n_chunks: chunk_size = math.floor(dims[key] / n_chunks) chunking_dict[key] = chunk_size break else: chunk_size = math.gcd(dims[key], n_chunks) 
if(chunk_size == 1): continue chunking_dict[key] = chunk_size n_chunks = n_chunks // chunk_size return chunking_dict
@autodoc_function
def auto_chunk(ds: xr.Dataset):
    """Automatically chunk an xarray dataset and returns the result."""
    dims = ds.sizes
    # Estimated in-memory footprint of the dataset, at 8 bytes per point
    # -- assumes 64-bit values; TODO confirm this holds for all dtypes.
    mem_required = _dict_values_prod(dims) * 8
    # TODO: Check influence of hyperthreading on available mem. Check if
    #       submitted jobs use hyperthreading or not. If required, adjust
    #       the available memory per thread.
    per_thread_mem = get_mem_avail() // get_num_cores()
    # Either the available memory or the hard chunk size limit are used.
    mem_cap = min(per_thread_mem, chunk_limit)
    target_chunks = _get_n_chunks(math.ceil(mem_required / mem_cap))
    return ds.chunk(chunks=_create_chunking_dict(dims, target_chunks))