From 75bb535907bb32978e0c477f5a5c863e65d7d6e5 Mon Sep 17 00:00:00 2001 From: Theo Date: Mon, 9 Mar 2026 16:51:09 -0700 Subject: [PATCH 01/27] feat: add streaming write API for iterator-based array writing Add write_array_streaming() to the Rust writer, accepting a Python iterator that yields numpy chunks in C-order. This enables writing a single omfile incrementally, as opposed to write_array() which requires the source array to be fully materialized in memory. --- python/omfiles/_rust/__init__.pyi | 40 ++++++ src/writer.rs | 148 +++++++++++++++++++- tests/test_streaming_write.py | 222 ++++++++++++++++++++++++++++++ 3 files changed, 409 insertions(+), 1 deletion(-) create mode 100644 tests/test_streaming_write.py diff --git a/python/omfiles/_rust/__init__.pyi b/python/omfiles/_rust/__init__.pyi index 5ecfff6..a6b6304 100644 --- a/python/omfiles/_rust/__init__.pyi +++ b/python/omfiles/_rust/__init__.pyi @@ -573,6 +573,46 @@ class OmFileWriter: Raises: ValueError: If the data type is unsupported or if parameters are invalid """ + def write_array_streaming( + self, + dimensions: typing.Sequence[builtins.int], + chunks: typing.Sequence[builtins.int], + chunk_iterator: typing.Any, + dtype: builtins.str, + scale_factor: typing.Optional[builtins.float] = None, + add_offset: typing.Optional[builtins.float] = None, + compression: typing.Optional[builtins.str] = None, + name: typing.Optional[builtins.str] = None, + children: typing.Optional[typing.Sequence[OmVariable]] = None, + ) -> OmVariable: + r""" + Write an array to the .om file by streaming chunks from a Python iterator. + + This method is designed for writing large arrays that do not fit in memory. + Instead of providing the full array, you provide the full array dimensions + and an iterator that yields numpy array chunks. + + Chunks MUST be yielded in row-major order (C-order) of the chunk grid. + Each chunk's shape determines how many internal file chunks it covers. + + Args: + dimensions: Shape of the full array (e.g., [1000, 2000]) + chunks: Chunk sizes for each dimension (e.g., [100, 200]) + chunk_iterator: Python iterable yielding numpy arrays, one per chunk region + dtype: String name of the numpy dtype (e.g., "float32", "int64") + scale_factor: Scale factor for data compression (default: 1.0) + add_offset: Offset value for data compression (default: 0.0) + compression: Compression algorithm to use (default: "pfor_delta_2d") + name: Name of the variable (default: "data") + children: List of child variables (default: []) + + Returns: + :py:data:`omfiles.OmVariable` representing the written array in the file structure + + Raises: + ValueError: If the dtype is unsupported or parameters are invalid + RuntimeError: If there's an error during compression or I/O + """ def write_scalar( self, value: typing.Any, name: builtins.str, children: typing.Optional[typing.Sequence[OmVariable]] = None ) -> OmVariable: diff --git a/src/writer.rs b/src/writer.rs index 37ba160..454ca9a 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -13,7 +13,7 @@ use omfiles_rs::{ OmCompressionType, OmFilesError, OmOffsetSize, }; use pyo3::{ - exceptions::{PyRuntimeError, PyValueError}, + exceptions::{PyRuntimeError, PyStopIteration, PyValueError}, prelude::*, }; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; @@ -99,6 +99,42 @@ impl OmFileWriter { }) } + fn write_array_streaming_internal<'py, T>( + &mut self, + py: Python<'py>, + dimensions: Vec, + chunks: Vec, + scale_factor: f32, + add_offset: f32, + compression: OmCompressionType, + chunk_iterator: &Bound<'py, PyAny>, + ) -> PyResult + where + T: Element + OmFileArrayDataType, + { + self.with_writer(|writer| { + let mut array_writer = writer + .prepare_array::(dimensions, chunks, compression, scale_factor, add_offset) + .map_err(convert_omfilesrs_error)?; + + loop { + let next_item = chunk_iterator.call_method0("__next__"); + match next_item { + Ok(item) => { + let array: PyReadonlyArrayDyn<'_, T> = item.extract()?; + array_writer + .write_data(array.as_array(), None, None) + .map_err(convert_omfilesrs_error)?; + } + Err(err) if err.is_instance_of::(py) => break, + Err(err) => return Err(err), + } + } + + Ok(array_writer.finalize()) + }) + } + fn store_scalar( &mut self, value: T, @@ -302,6 +338,116 @@ impl OmFileWriter { }) } + /// Write an array to the .om file by streaming chunks from a Python iterator. + /// + /// This method is designed for writing large arrays that do not fit in memory. + /// Instead of providing the full array, you provide the full array dimensions + /// and an iterator that yields numpy array chunks. + /// + /// Chunks MUST be yielded in row-major order (C-order) of the chunk grid. + /// Each chunk's shape determines how many internal file chunks it covers. + /// + /// Args: + /// dimensions: Shape of the full array (e.g., [1000, 2000]) + /// chunks: Chunk sizes for each dimension (e.g., [100, 200]) + /// chunk_iterator: Python iterable yielding numpy arrays, one per chunk region + /// dtype: String name of the numpy dtype (e.g., "float32", "int64") + /// scale_factor: Scale factor for data compression (default: 1.0) + /// add_offset: Offset value for data compression (default: 0.0) + /// compression: Compression algorithm to use (default: "pfor_delta_2d") + /// name: Name of the variable (default: "data") + /// children: List of child variables (default: []) + /// + /// Returns: + /// :py:data:`omfiles.OmVariable` representing the written array in the file structure + /// + /// Raises: + /// ValueError: If the dtype is unsupported or parameters are invalid + /// RuntimeError: If there's an error during compression or I/O + #[pyo3( + text_signature = "(dimensions, chunks, chunk_iterator, dtype, scale_factor=1.0, add_offset=0.0, compression='pfor_delta_2d', name='data', children=[])", + signature = (dimensions, chunks, chunk_iterator, dtype, scale_factor=None, add_offset=None, compression=None, name=None, children=None) + )] + fn write_array_streaming( + &mut self, + py: Python<'_>, + dimensions: Vec, + chunks: Vec, + chunk_iterator: &Bound<'_, PyAny>, + dtype: &str, + scale_factor: Option, + add_offset: Option, + compression: Option<&str>, + name: Option<&str>, + children: Option>, + ) -> PyResult { + let name = name.unwrap_or("data"); + let children: Vec = children + .unwrap_or_default() + .iter() + .map(Into::into) + .collect(); + + let scale_factor = scale_factor.unwrap_or(1.0); + let add_offset = add_offset.unwrap_or(0.0); + let compression = compression + .map(|s| PyCompressionType::from_str(s)) + .transpose()? + .unwrap_or(PyCompressionType::PforDelta2d) + .to_omfilesrs(); + + let iter = chunk_iterator.call_method0("__iter__")?; + + let array_meta = match dtype { + "float32" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "float64" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "int8" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "uint8" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "int16" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "uint16" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "int32" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "uint32" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "int64" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + "uint64" => self.write_array_streaming_internal::( + py, dimensions, chunks, scale_factor, add_offset, compression, &iter, + ), + _ => Err(PyValueError::new_err(format!( + "Unsupported dtype: {}", + dtype + ))), + }?; + + self.with_writer(|writer| { + let offset_size = writer + .write_array(array_meta, name, &children) + .map_err(convert_omfilesrs_error)?; + + Ok(OmVariable { + name: name.to_string(), + offset: offset_size.offset, + size: offset_size.size, + }) + }) + } + /// Write a scalar value to the .om file. /// /// Args: diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py new file mode 100644 index 0000000..aed97a6 --- /dev/null +++ b/tests/test_streaming_write.py @@ -0,0 +1,222 @@ +import tempfile + +import numpy as np +import pytest +from omfiles import OmFileReader, OmFileWriter + + +class TestWriteArrayStreaming: + def test_streaming_single_chunk(self): + shape = (10, 20) + chunks = [10, 20] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + + def chunk_iter(): + yield data + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, data, decimal=4) + + def test_streaming_multiple_chunks_2d(self): + shape = (10, 20) + chunks = [5, 10] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + + def chunk_iter(): + for i in range(0, 10, 5): + for j in range(0, 20, 10): + yield data[i : i + 5, j : j + 10].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, data, decimal=4) + + def test_streaming_all_dtypes(self): + shape = (6, 8) + chunks = [3, 4] + dtypes = [ + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ] + + for dt in dtypes: + if np.issubdtype(dt, np.floating): + data = np.random.rand(*shape).astype(dt) + elif np.issubdtype(dt, np.signedinteger): + info = np.iinfo(dt) + data = np.random.randint(max(info.min, -1000), min(info.max, 1000), size=shape, dtype=dt) + else: + info = np.iinfo(dt) + data = np.random.randint(0, min(info.max, 1000), size=shape, dtype=dt) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + + def chunk_iter(d=data): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield d[i:ie, j:je].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(dt).name, + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + assert result.dtype == dt, f"dtype mismatch for {dt}" + np.testing.assert_array_almost_equal(result, data, decimal=4) + + def test_streaming_3d_array(self): + shape = (4, 6, 8) + chunks = [2, 3, 4] + data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + for k in range(0, shape[2], chunks[2]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + ke = min(k + chunks[2], shape[2]) + yield data[i:ie, j:je, k:ke].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="int32", + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, data) + + def test_streaming_boundary_chunks(self): + shape = (7, 13) + chunks = [4, 5] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, data, decimal=4) + + def test_streaming_matches_write_array(self): + shape = (10, 20) + chunks = [5, 10] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f1: + writer1 = OmFileWriter(f1.name) + var1 = writer1.write_array(data, chunks=chunks, scale_factor=10000.0) + writer1.close(var1) + reader1 = OmFileReader(f1.name) + result1 = reader1[:] + reader1.close() + + with tempfile.NamedTemporaryFile(suffix=".om") as f2: + writer2 = OmFileWriter(f2.name) + + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() + + var2 = writer2.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer2.close(var2) + reader2 = OmFileReader(f2.name) + result2 = reader2[:] + reader2.close() + + np.testing.assert_array_equal(result1, result2) + + def test_streaming_unsupported_dtype_raises(self): + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="Unsupported dtype"): + writer.write_array_streaming( + dimensions=[10], + chunks=[5], + chunk_iterator=iter([]), + dtype="complex128", + ) From 3665d74a10912dc24ab67bfd1da0d9ec4e66e053 Mon Sep 17 00:00:00 2001 From: Theo Date: Mon, 9 Mar 2026 16:51:14 -0700 Subject: [PATCH 02/27] feat: add xarray Dataset write support Add write_dataset() for writing xarray Datasets to OM files with support for per-variable encoding, scalar and non-dimension coordinate roundtrip, and automatic streaming for dask-backed arrays. --- python/omfiles/__init__.py | 7 + python/omfiles/xarray.py | 284 ++++++++++++++++++++++++++++++++++++- tests/test_xarray.py | 180 +++++++++++++++++++++++ 3 files changed, 469 insertions(+), 2 deletions(-) diff --git a/python/omfiles/__init__.py b/python/omfiles/__init__.py index 42f492c..f839706 100644 --- a/python/omfiles/__init__.py +++ b/python/omfiles/__init__.py @@ -12,3 +12,10 @@ "OmVariable", "types", ] + +try: + from .xarray import write_dataset + + __all__.append("write_dataset") +except ImportError: + pass diff --git a/python/omfiles/xarray.py b/python/omfiles/xarray.py index 4bcfb43..a4831f9 100644 --- a/python/omfiles/xarray.py +++ b/python/omfiles/xarray.py @@ -3,6 +3,11 @@ from __future__ import annotations +import itertools +import os +import warnings +from typing import Any, Generator, Sequence + import numpy as np try: @@ -21,7 +26,7 @@ from xarray.core.utils import FrozenDict from xarray.core.variable import Variable -from ._rust import OmFileReader, OmVariable +from ._rust import OmFileReader, OmFileWriter, OmVariable # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -41,10 +46,16 @@ def open_dataset( with OmFileReader(filename_or_obj) as root_variable: store = OmDataStore(root_variable) store_entrypoint = StoreBackendEntrypoint() - return store_entrypoint.open_dataset( + ds = store_entrypoint.open_dataset( store, drop_variables=drop_variables, ) + coord_attr = "_COORDINATE_VARIABLES" + if coord_attr in ds.attrs: + coord_names = [c for c in ds.attrs[coord_attr].split(",") if c in ds] + ds = ds.set_coords(coord_names) + ds.attrs = {k: v for k, v in ds.attrs.items() if k != coord_attr} + return ds raise ValueError("Failed to open dataset") description = "Use .om files in Xarray" @@ -76,6 +87,11 @@ def _get_attributes_for_variable(self, reader: OmFileReader, path: str): for k, variable in direct_children.items(): child_reader = reader._init_from_variable(variable) if child_reader.is_scalar: + # Skip scalars that have _ARRAY_DIMENSIONS — they are 0-d + # coordinate variables, not plain attributes. + dim_key = path + "/" + k + "/" + DIMENSION_KEY + if dim_key in self.variables_store: + continue attrs[k] = child_reader.read_scalar() return attrs @@ -153,6 +169,31 @@ def _get_datasets(self, reader: OmFileReader): data = indexing.LazilyIndexedArray(backend_array) datasets[var_key] = Variable(dims=dim_names, data=data, attrs=attrs_for_var, encoding=None, fastpath=True) + + # Handle 0-d (scalar) variables that have _ARRAY_DIMENSIONS metadata. + # These are scalar coordinates written by write_dataset. + for var_key, var in self.variables_store.items(): + if var_key in datasets: + continue + child_reader = reader._init_from_variable(var) + if not child_reader.is_scalar: + continue + dim_path = var_key + "/" + DIMENSION_KEY + if dim_path not in self.variables_store: + continue + dim_reader = reader._init_from_variable(self.variables_store[dim_path]) + dim_names_str = dim_reader.read_scalar() + if isinstance(dim_names_str, str) and dim_names_str == "": + dim_names = () + elif isinstance(dim_names_str, str): + dim_names = tuple(dim_names_str.split(",")) + else: + dim_names = () + scalar_value = child_reader.read_scalar() + attrs = self._get_attributes_for_variable(child_reader, var_key) + attrs_for_var = {k: v for k, v in attrs.items() if k != DIMENSION_KEY} + datasets[var_key] = Variable(dims=dim_names, data=np.array(scalar_value)) + return datasets def close(self): @@ -181,3 +222,242 @@ def __getitem__(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: indexing.IndexingSupport.BASIC, self.reader.__getitem__, ) + + +def _write_scalar_safe(writer: OmFileWriter, value: Any, name: str) -> OmVariable | None: + """Write a scalar, returning None and warning if the type is unsupported.""" + try: + return writer.write_scalar(value, name=name) + except (ValueError, TypeError) as e: + warnings.warn( + f"Skipping attribute '{name}' with value {value!r}: {e}", + UserWarning, + stacklevel=3, + ) + return None + + +def _chunked_block_iterator(data: Any) -> Generator[np.ndarray, None, None]: + """ + Yield numpy arrays from a chunked array in C-order block traversal. + + Works with any array that exposes ``.numblocks``, ``.blocks[idx]``, + and ``.compute()`` (e.g. dask arrays). No dask import required. + """ + block_index_ranges = [range(n) for n in data.numblocks] + for block_indices in itertools.product(*block_index_ranges): + block = data.blocks[block_indices] + if hasattr(block, "compute"): + yield block.compute() + else: + yield np.asarray(block) + + +def _validate_chunk_alignment( + data_chunks: tuple, + om_chunks: list[int], + array_shape: tuple, +) -> None: + """ + Validate dask chunks are compatible with OM chunks for block-level streaming. + + Every non-last dask chunk along each dimension must be an exact multiple + of the corresponding OM chunk size (the last chunk may be smaller). + Additionally, for the leftmost dimension where a dask block contains more + than one OM chunk, every trailing dimension must be fully covered by each + dask block. This ensures the local chunk traversal inside a block matches + the global file order. + """ + import math + + ndim = len(om_chunks) + + for d in range(ndim): + dim_chunks = data_chunks[d] + for i, c in enumerate(dim_chunks[:-1]): + if c % om_chunks[d] != 0: + raise ValueError( + f"Dask chunk size {c} along dimension {d} (block {i}) " + f"is not a multiple of the OM chunk size {om_chunks[d]}." + ) + + first_multi = None + for d in range(ndim): + local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) + if local_n > 1: + first_multi = d + break + + if first_multi is not None: + for d in range(first_multi + 1, ndim): + local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) + global_n = math.ceil(array_shape[d] / om_chunks[d]) + if local_n != global_n: + raise ValueError( + f"Dask blocks have multiple OM chunks in dimension {first_multi}, " + f"but dimension {d} is not fully covered by each dask block " + f"(dask chunk {data_chunks[d][0]} vs array size {array_shape[d]}). " + f"Rechunk so trailing dimensions are fully covered." + ) + + +def _resolve_chunks_for_variable( + var_name: str, + var: Variable, + encoding: dict[str, dict[str, Any]] | None, + global_chunks: dict[str, int] | None, + data_chunks: tuple | None = None, +) -> list[int]: + """Resolve chunk sizes for a variable using the priority chain.""" + if encoding and var_name in encoding and "chunks" in encoding[var_name]: + return list(encoding[var_name]["chunks"]) + + if global_chunks is not None: + return [global_chunks.get(dim, min(size, 512)) for dim, size in zip(var.dims, var.shape)] + + if data_chunks is not None: + return [int(c[0]) for c in data_chunks] + + return [min(size, 512) for size in var.shape] + + +def _resolve_encoding_for_variable( + var_name: str, + encoding: dict[str, dict[str, Any]] | None, + global_scale_factor: float, + global_add_offset: float, + global_compression: str, +) -> tuple[float, float, str]: + """Resolve compression parameters for a variable.""" + var_enc = (encoding or {}).get(var_name, {}) + sf = var_enc.get("scale_factor", global_scale_factor) + ao = var_enc.get("add_offset", global_add_offset) + comp = var_enc.get("compression", global_compression) + return sf, ao, comp + + +def write_dataset( + ds: Dataset, + path: str | os.PathLike, + *, + encoding: dict[str, dict[str, Any]] | None = None, + chunks: dict[str, int] | None = None, + scale_factor: float = 1.0, + add_offset: float = 0.0, + compression: str = "pfor_delta_2d", +) -> None: + """ + Write an xarray Dataset to an OM file. + + The resulting file can be read back with ``xr.open_dataset(path, engine="om")``. + + Args: + ds: The xarray Dataset to write. + path: Output file path. + encoding: Per-variable overrides. Keys per variable: ``"chunks"``, + ``"scale_factor"``, ``"add_offset"``, ``"compression"``. + chunks: Global default chunk sizes as ``{dim_name: chunk_size}``. + scale_factor: Global default scale factor for float compression. + add_offset: Global default offset for float compression. + compression: Global default compression algorithm. + """ + path = str(path) + writer = OmFileWriter(path) + all_children: list[OmVariable] = [] + + def _write_variable(name: str, var: Variable, is_dim_coord: bool) -> None: + """Write a single variable (data var or non-dimension coordinate).""" + if np.issubdtype(var.dtype, np.datetime64) or np.issubdtype(var.dtype, np.timedelta64): + raise TypeError( + f"Variable '{name}' has dtype {var.dtype}. " + "OM files do not support datetime64/timedelta64 natively. " + "Convert to a numeric type before writing." + ) + + var_children: list[OmVariable] = [] + + if not is_dim_coord: + dim_str = ",".join(var.dims) + dim_var = writer.write_scalar(dim_str, name=DIMENSION_KEY) + var_children.append(dim_var) + + for attr_name, attr_value in var.attrs.items(): + scalar = _write_scalar_safe(writer, attr_value, attr_name) + if scalar is not None: + var_children.append(scalar) + + if var.ndim == 0: + om_var = writer.write_scalar( + var.values[()], + name=name, + children=var_children if var_children else None, + ) + all_children.append(om_var) + return + + data = var.data + is_chunked = not is_dim_coord and hasattr(data, "chunks") and data.chunks is not None + + if is_dim_coord: + resolved_chunks = [var.shape[0]] + else: + resolved_chunks = _resolve_chunks_for_variable( + name, + var, + encoding, + chunks, + data_chunks=data.chunks if is_chunked else None, + ) + + sf, ao, comp = _resolve_encoding_for_variable(name, encoding, scale_factor, add_offset, compression) + + if is_chunked: + _validate_chunk_alignment(data.chunks, resolved_chunks, var.shape) + om_var = writer.write_array_streaming( + dimensions=[int(d) for d in var.shape], + chunks=[int(c) for c in resolved_chunks], + chunk_iterator=_chunked_block_iterator(data), + dtype=var.dtype.name, + scale_factor=sf, + add_offset=ao, + compression=comp, + name=name, + children=var_children if var_children else None, + ) + else: + om_var = writer.write_array( + var.values, + chunks=resolved_chunks, + scale_factor=sf, + add_offset=ao, + compression=comp, + name=name, + children=var_children if var_children else None, + ) + all_children.append(om_var) + + for var_name in ds.data_vars: + _write_variable(var_name, ds[var_name].variable, is_dim_coord=False) + + non_dim_coords: list[str] = [] + for coord_name in ds.coords: + if coord_name in ds.data_vars: + continue + coord = ds.coords[coord_name] + is_dim_coord = coord.ndim == 1 and coord.dims[0] == coord_name + if not is_dim_coord: + non_dim_coords.append(coord_name) + _write_variable(coord_name, coord.variable, is_dim_coord=is_dim_coord) + + # Write list of non-dimension coordinates so the reader can restore them + if non_dim_coords: + coord_list_var = writer.write_scalar(",".join(non_dim_coords), name="_COORDINATE_VARIABLES") + all_children.append(coord_list_var) + + for attr_name, attr_value in ds.attrs.items(): + scalar = _write_scalar_safe(writer, attr_value, attr_name) + if scalar is not None: + all_children.append(scalar) + + root_var = writer.write_group(name="", children=all_children) + writer.close(root_var) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 85f66fd..28bf6d2 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -3,6 +3,7 @@ import pytest import xarray as xr from omfiles import OmFileReader, OmFileWriter +from omfiles.xarray import write_dataset from xarray.core import indexing from .test_utils import create_test_om_file, filter_numpy_size_warning @@ -150,3 +151,182 @@ def test_xarray_hierarchical_file(empty_temp_om_file): mean_temp = ds["temperature"].mean(dim="TIME") assert mean_temp.shape == (5, 5, 5) assert mean_temp.dims == ("LATITUDE", "LONGITUDE", "ALTITUDE") + + +@filter_numpy_size_warning +def test_write_dataset_basic_roundtrip(empty_temp_om_file): + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "Test dataset"}, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "Test dataset" + + +@filter_numpy_size_warning +def test_write_dataset_hierarchical_roundtrip(empty_temp_om_file): + """Mirrors test_xarray_hierarchical_file but uses write_dataset.""" + temperature_data = np.random.rand(5, 5, 5, 10).astype(np.float32) + precipitation_data = np.random.rand(5, 5, 10).astype(np.float32) + + ds = xr.Dataset( + { + "temperature": ( + ["LATITUDE", "LONGITUDE", "ALTITUDE", "TIME"], + temperature_data, + {"units": "celsius", "description": "Surface temperature"}, + ), + "precipitation": ( + ["LATITUDE", "LONGITUDE", "TIME"], + precipitation_data, + {"units": "mm", "description": "Precipitation"}, + ), + }, + coords={ + "LATITUDE": np.arange(5, dtype=np.float32), + "LONGITUDE": np.arange(5, dtype=np.float32), + "ALTITUDE": np.arange(5, dtype=np.float32), + "TIME": np.arange(10, dtype=np.float32), + }, + attrs={"description": "This is a hierarchical OM File"}, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + assert ds2.attrs["description"] == "This is a hierarchical OM File" + assert set(ds2.data_vars) == {"temperature", "precipitation"} + + np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + assert ds2["temperature"].dims == ("LATITUDE", "LONGITUDE", "ALTITUDE", "TIME") + assert ds2["temperature"].attrs["units"] == "celsius" + assert ds2["temperature"].attrs["description"] == "Surface temperature" + + np.testing.assert_array_almost_equal(ds2["precipitation"].values, precipitation_data, decimal=4) + assert ds2["precipitation"].dims == ("LATITUDE", "LONGITUDE", "TIME") + assert ds2["precipitation"].attrs["units"] == "mm" + + assert ds2["LATITUDE"].dims == ("LATITUDE",) + assert ds2["LONGITUDE"].dims == ("LONGITUDE",) + assert ds2["ALTITUDE"].dims == ("ALTITUDE",) + assert ds2["TIME"].dims == ("TIME",) + + +@filter_numpy_size_warning +def test_write_dataset_per_variable_encoding(empty_temp_om_file): + ds = xr.Dataset( + { + "high_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), + "low_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), + }, + coords={ + "x": np.arange(10, dtype=np.float32), + "y": np.arange(10, dtype=np.float32), + }, + ) + + write_dataset( + ds, + empty_temp_om_file, + scale_factor=1000.0, + encoding={ + "high_res": {"scale_factor": 100000.0, "chunks": [5, 5]}, + "low_res": {"chunks": [10, 10]}, + }, + ) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["high_res"].values, ds["high_res"].values, decimal=4) + np.testing.assert_array_almost_equal(ds2["low_res"].values, ds["low_res"].values, decimal=2) + + +@filter_numpy_size_warning +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32, np.uint64]) +def test_write_dataset_integer_dtypes(dtype, empty_temp_om_file): + data = np.arange(25, dtype=dtype).reshape(5, 5) + ds = xr.Dataset({"values": (["x", "y"], data)}) + + write_dataset(ds, empty_temp_om_file) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_equal(ds2["values"].values, data) + assert ds2["values"].dtype == dtype + + +@filter_numpy_size_warning +def test_write_dataset_unsupported_attrs_warning(empty_temp_om_file): + ds = xr.Dataset( + {"data": (["x"], np.arange(5, dtype=np.float32))}, + attrs={"valid": "hello", "invalid": [1, 2, 3]}, + ) + + with pytest.warns(UserWarning, match="Skipping attribute"): + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + assert ds2.attrs["valid"] == "hello" + assert "invalid" not in ds2.attrs + + +def test_write_dataset_datetime_raises(empty_temp_om_file): + time_values = np.array( + ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05"], dtype="datetime64[ns]" + ) + ds = xr.Dataset( + {"data": (["time"], np.arange(5, dtype=np.float32))}, + coords={"time": time_values}, + ) + + with pytest.raises(TypeError, match="datetime64"): + write_dataset(ds, empty_temp_om_file) + + +@filter_numpy_size_warning +def test_write_dataset_scalar_coordinate(empty_temp_om_file): + """Writing a dataset with a scalar (0-d) coordinate should not segfault.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + "time": np.float32(42.0), + }, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + loaded = xr.open_dataset(empty_temp_om_file, engine="om") + + assert "time" in loaded.coords + assert "time" not in loaded.data_vars + assert loaded.coords["time"].ndim == 0 + np.testing.assert_almost_equal(float(loaded.coords["time"]), 42.0) + + np.testing.assert_array_almost_equal(loaded["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(loaded.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(loaded.coords["lon"].values, ds.coords["lon"].values) + + +@filter_numpy_size_warning +def test_write_dataset_non_dimension_coordinate(empty_temp_om_file): + """Non-dimension coordinates should preserve their dimensions and coordinate status.""" + valid_time_data = np.arange(6, dtype=np.float32) + ds = xr.Dataset( + {"t2m": (("step", "lat"), np.zeros((6, 10), dtype=np.float32))}, + coords={"valid_time": ("step", valid_time_data)}, + ) + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + loaded = xr.open_dataset(empty_temp_om_file, engine="om") + + assert loaded["valid_time"].dims == ("step",) + assert "valid_time" in loaded.coords + assert "valid_time" not in loaded.data_vars + np.testing.assert_array_equal(loaded["valid_time"].values, valid_time_data) From 7fdaad52b7b9453695d2a43b32b7a9fea93c46ad Mon Sep 17 00:00:00 2001 From: Theo Date: Mon, 9 Mar 2026 16:51:20 -0700 Subject: [PATCH 03/27] feat: add dask array integration Add standalone dask.py module for writing dask arrays via streaming, declare dask as an optional dependency, and add dask-backed tests for both write_dask_array() and write_dataset(). --- examples/dask_larger_than_ram.py | 154 ++++++++++++++++++++++++++ pyproject.toml | 4 +- python/omfiles/dask.py | 97 +++++++++++++++++ tests/test_streaming_write.py | 179 +++++++++++++++++++++++++++++++ tests/test_xarray.py | 129 ++++++++++++++++++++++ uv.lock | 106 +++++++++++++++++- 6 files changed, 667 insertions(+), 2 deletions(-) create mode 100644 examples/dask_larger_than_ram.py create mode 100644 python/omfiles/dask.py diff --git a/examples/dask_larger_than_ram.py b/examples/dask_larger_than_ram.py new file mode 100644 index 0000000..4822c45 --- /dev/null +++ b/examples/dask_larger_than_ram.py @@ -0,0 +1,154 @@ +#!/usr/bin/env -S uv run --script +# +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "omfiles>=1.1.1", # x-release-please-version +# "dask[array]>=2023.1.0", +# ] +# /// +# +# This example demonstrates writing a dask array that is larger than the +# available process memory to an OM file using streaming writes. +# +# A process memory limit is set via resource.setrlimit to simulate a +# constrained environment. The dask array is never fully materialized — +# only one chunk is held in memory at a time thanks to write_dask_array(). +# +# NOTE: resource.setrlimit(RLIMIT_AS) is only enforced on Linux. +# On macOS the kernel ignores RSS/AS limits, so the script uses +# tracemalloc and ru_maxrss to prove that peak memory stays low. + +import os +import platform +import resource +import tempfile +import tracemalloc + +import dask.array as da +import numpy as np +from omfiles import OmFileReader, OmFileWriter +from omfiles.dask import write_dask_array + +# Configuration +MEMORY_LIMIT_MB = 128 # process memory cap (enforced on Linux) +DATASET_SIZE_MB = 512 # total size of the dask array +CHUNK_SIZE = 1024 # chunk edge length (CHUNK_SIZE x CHUNK_SIZE) +DTYPE = np.float32 # 4 bytes per element + +# Derived constants +bytes_per_element = np.dtype(DTYPE).itemsize +total_elements = (DATASET_SIZE_MB * 1024 * 1024) // bytes_per_element +side_length = int(np.sqrt(total_elements)) # square array for simplicity +actual_size_mb = (side_length * side_length * bytes_per_element) / (1024 * 1024) + + +def set_memory_limit(limit_mb: int) -> bool: + """Try to cap the process address space. Returns True if enforced.""" + limit_bytes = limit_mb * 1024 * 1024 + try: + _, hard = resource.getrlimit(resource.RLIMIT_AS) + resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, hard)) + if platform.system() == "Linux": + print(f" Memory limit set to {limit_mb} MB (enforced on Linux)") + return True + else: + print( + f" Memory limit requested ({limit_mb} MB) but {platform.system()} " + "does not enforce RLIMIT_AS — relying on memory tracking instead" + ) + return False + except (ValueError, OSError, AttributeError) as e: + print(f" Could not set memory limit: {e}") + return False + + +def get_peak_rss_mb() -> float: + """Return peak RSS in MB (works on Linux and macOS).""" + usage = resource.getrusage(resource.RUSAGE_SELF) + if platform.system() == "Darwin": + return usage.ru_maxrss / (1024 * 1024) # macOS reports bytes + return usage.ru_maxrss / 1024 # Linux reports kilobytes + + +def main(): + print("=" * 60) + print("Dask larger-than-RAM write example") + print("=" * 60) + + # Set memory limit + print(f"\nSetting process memory limit to {MEMORY_LIMIT_MB} MB...") + enforced = set_memory_limit(MEMORY_LIMIT_MB) + + # Start memory tracking + tracemalloc.start() + + # Create a dask array larger than the memory limit + print( + f"\nCreating dask array: {side_length} x {side_length} {DTYPE.__name__} " + f"({actual_size_mb:.0f} MB, chunked {CHUNK_SIZE} x {CHUNK_SIZE})" + ) + + data = da.random.random( + (side_length, side_length), + chunks=(CHUNK_SIZE, CHUNK_SIZE), + ).astype(DTYPE) + + print(f" Shape: {data.shape}") + print(f" Chunks: {data.chunksize}") + print(f" Num blocks: {data.numblocks} ({np.prod(data.numblocks)} total)") + + # Write to .om file via streaming + fd, filepath = tempfile.mkstemp(suffix=".om") + os.close(fd) + + print(f"\nWriting to {filepath} ...") + writer = OmFileWriter(filepath) + root = write_dask_array(writer, data, name="temperature") + writer.close(root) + + file_size_mb = os.path.getsize(filepath) / (1024 * 1024) + print(f" File size on disk: {file_size_mb:.1f} MB (compression ratio: {actual_size_mb / file_size_mb:.1f}x)") + + # Read back a slice and verify + print("\nReading back a slice to verify...") + with OmFileReader(filepath) as reader: + print(f" Reader shape: {reader.shape}, dtype: {reader.dtype}") + sample = reader[0:10, 0:10] + print(f" Sample slice [0:10, 0:10] shape: {sample.shape}") + print(f" Sample values (first row): {sample[0, :5]}") + assert sample.shape == (10, 10), "Unexpected slice shape" + assert not np.any(np.isnan(sample)), "Found NaN values in readback" + + print(" Verification passed!") + + # Memory summary + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + peak_traced_mb = peak / (1024 * 1024) + peak_rss_mb = get_peak_rss_mb() + + print("\n" + "=" * 60) + print("Memory summary") + print("=" * 60) + print(f" Dataset size: {actual_size_mb:.0f} MB") + if enforced: + print(f" Process memory limit: {MEMORY_LIMIT_MB} MB (enforced)") + else: + print(f" Process memory limit: {MEMORY_LIMIT_MB} MB (not enforced on {platform.system()})") + print(f" Peak traced (Python): {peak_traced_mb:.1f} MB") + print(f" Peak RSS (process): {peak_rss_mb:.1f} MB") + print(f" Ratio (dataset/peak): {actual_size_mb / peak_rss_mb:.1f}x") + print() + + if peak_rss_mb < actual_size_mb: + print("The entire dataset was written WITHOUT loading it all into memory.") + else: + print("WARNING: Peak RSS exceeded dataset size — streaming may not have worked as expected.") + + # Cleanup + os.unlink(filepath) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c43f42b..71f6b58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,13 +38,15 @@ codec = [ xarray = ["xarray>=2023.1.0"] fsspec = ["fsspec>=2023.1.0", "s3fs>=2023.1.0"] grids = ["pyproj>=3.1.0"] +dask = ["dask[array]>=2023.1.0"] all = [ "zarr>=2.18.2", "numcodecs>=0.12.1", "xarray>=2023.1.0", "fsspec>=2023.10.0", "s3fs>=2023.1.0", - "pyproj>=3.3.0" + "pyproj>=3.3.0", + "dask[array]>=2023.1.0", ] [dependency-groups] diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py new file mode 100644 index 0000000..2fa9c5b --- /dev/null +++ b/python/omfiles/dask.py @@ -0,0 +1,97 @@ +"""Dask array integration for writing to OM files.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Optional, Sequence + +from omfiles._rust import OmFileWriter, OmVariable +from omfiles.xarray import _validate_chunk_alignment + +if TYPE_CHECKING: + import dask.array as da + + +def _dask_block_iterator(dask_array: da.Array): + """ + Yield computed numpy arrays from a dask array in C-order block traversal. + + The OM file format requires chunks to be written in sequential order + corresponding to a row-major (C-order) traversal of the chunk grid. + ``itertools.product`` naturally produces this ordering since the last + index varies fastest. + """ + block_index_ranges = [range(n) for n in dask_array.numblocks] + for block_indices in itertools.product(*block_index_ranges): + yield dask_array.blocks[block_indices].compute() + + +def write_dask_array( + writer: OmFileWriter, + data: da.Array, + chunks: Optional[Sequence[int]] = None, + scale_factor: float = 1.0, + add_offset: float = 0.0, + compression: str = "pfor_delta_2d", + name: str = "data", + children: Optional[Sequence[OmVariable]] = None, +) -> OmVariable: + """ + Write a dask array to an OM file using streaming/incremental writes. + + Iterates over the blocks of the dask array, computing each block + on-the-fly, and streams them to the OM file writer. Only one block + is held in memory at a time. + + The dask array's chunk structure is used to determine the OM file's + chunk dimensions by default. Dask chunks must be multiples of the OM + chunk sizes (except the last chunk along each dimension which may be + smaller). When a dask block contains more than one OM chunk in a + dimension, all trailing dimensions must be fully covered by each block. + + Performance: write speed depends on the number of dask tasks, not just + data size. For best performance, use dask chunks much larger than the + OM chunk sizes — ideally covering the full extent of trailing dimensions. + For example, with OM chunks of (124, 124) on an (8192, 8192) array, + dask chunks of (124, 8192) will write ~10x faster than (124, 124). + + Args: + writer: An open OmFileWriter instance. + data: A dask array to write. + chunks: OM file chunk sizes per dimension. If None, uses the dask + array's chunk sizes. Dask chunks must be multiples of these. + scale_factor: Scale factor for float compression (default: 1.0). + add_offset: Offset for float compression (default: 0.0). + compression: Compression algorithm (default: "pfor_delta_2d"). + name: Variable name (default: "data"). + children: Child variables (default: None). + + Returns: + OmVariable representing the written array. + + Raises: + TypeError: If data is not a dask array. + ValueError: If dask chunks are incompatible with OM chunks. + ImportError: If dask is not installed. + """ + import dask.array as da + + if not isinstance(data, da.Array): + raise TypeError(f"Expected a dask array, got {type(data)}") + + if chunks is None: + chunks = [c[0] for c in data.chunks] + + _validate_chunk_alignment(data.chunks, list(chunks), data.shape) + + return writer.write_array_streaming( + dimensions=[int(d) for d in data.shape], + chunks=[int(c) for c in chunks], + chunk_iterator=_dask_block_iterator(data), + dtype=data.dtype.name, + scale_factor=scale_factor, + add_offset=add_offset, + compression=compression, + name=name, + children=list(children) if children else [], + ) diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index aed97a6..7b56734 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -220,3 +220,182 @@ def test_streaming_unsupported_dtype_raises(self): chunk_iterator=iter([]), dtype="complex128", ) + + +class TestWriteDaskArray: + @pytest.fixture(autouse=True) + def _import_dask(self): + pytest.importorskip("dask.array") + from omfiles.dask import write_dask_array + + self.write_dask_array = write_dask_array + + @pytest.fixture + def dask_array_2d(self): + import dask.array as da + + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + return da.from_array(np_data, chunks=(5, 10)) + + @pytest.fixture + def dask_array_3d(self): + import dask.array as da + + np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) + return da.from_array(np_data, chunks=(2, 3, 4)) + + def test_dask_roundtrip_2d(self, dask_array_2d): + expected = dask_array_2d.compute() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array( + writer, + dask_array_2d, + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, expected, decimal=4) + + def test_dask_roundtrip_3d(self, dask_array_3d): + expected = dask_array_3d.compute() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array(writer, dask_array_3d) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, expected) + + def test_dask_boundary_chunks(self): + import dask.array as da + + np_data = np.arange(91, dtype=np.float32).reshape(7, 13) + darr = da.from_array(np_data, chunks=(4, 5)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array(writer, darr, scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + def test_dask_custom_name(self, dask_array_2d): + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array( + writer, + dask_array_2d, + scale_factor=10000.0, + name="temperature", + ) + assert var.name == "temperature" + writer.close(var) + + def test_dask_non_multiple_chunks_raises(self): + """Dask chunks that aren't multiples of OM chunks should raise.""" + import dask.array as da + + np_data = np.arange(30, dtype=np.float32).reshape(6, 5) + darr = da.from_array(np_data, chunks=(3, 5)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="not a multiple"): + self.write_dask_array(writer, darr, chunks=[2, 5]) + + def test_dask_larger_chunks_than_om_2d(self): + """Dask blocks spanning multiple OM chunks along dim 1 (full trailing dim).""" + import dask.array as da + + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(10, 20)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array( + writer, + darr, + chunks=[5, 10], + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + def test_dask_larger_chunks_than_om_3d(self): + """Dask blocks with full trailing dims, multiple OM chunks in dim 0.""" + import dask.array as da + + np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) + darr = da.from_array(np_data, chunks=(4, 6, 8)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array(writer, darr, chunks=[2, 3, 4]) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, np_data) + + def test_dask_single_om_chunk_per_slow_dim(self): + """Dask blocks with 1 OM chunk in dim 0, partial trailing dim coverage.""" + import dask.array as da + + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(5, 10)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array( + writer, + darr, + chunks=[5, 5], + scale_factor=10000.0, + ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + def test_dask_misaligned_trailing_dims_raises(self): + """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" + import dask.array as da + + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(10, 10)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="not fully covered"): + self.write_dask_array(writer, darr, chunks=[5, 5]) + + def test_dask_not_a_dask_array_raises(self): + np_data = np.arange(20, dtype=np.float32).reshape(4, 5) + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(TypeError, match="Expected a dask array"): + self.write_dask_array(writer, np_data) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 28bf6d2..273e11c 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -330,3 +330,132 @@ def test_write_dataset_non_dimension_coordinate(empty_temp_om_file): assert "valid_time" in loaded.coords assert "valid_time" not in loaded.data_vars np.testing.assert_array_equal(loaded["valid_time"].values, valid_time_data) + + +@filter_numpy_size_warning +def test_write_dataset_dask_roundtrip(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(10, 20).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(5, 10)) + + ds = xr.Dataset( + {"temperature": (["lat", "lon"], dask_data)}, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + }, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + + +@filter_numpy_size_warning +def test_write_dataset_dask_mixed_variables(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_temp = np.random.rand(10, 20).astype(np.float32) + dask_temp = da.from_array(np_temp, chunks=(5, 10)) + np_precip = np.random.rand(10, 20).astype(np.float32) + + ds = xr.Dataset( + { + "temperature": (["lat", "lon"], dask_temp), + "precipitation": (["lat", "lon"], np_precip), + }, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + }, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_temp, decimal=4) + np.testing.assert_array_almost_equal(ds2["precipitation"].values, np_precip, decimal=4) + + +@filter_numpy_size_warning +def test_write_dataset_dask_boundary_chunks(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.arange(91, dtype=np.float32).reshape(7, 13) + dask_data = da.from_array(np_data, chunks=(4, 5)) + + ds = xr.Dataset({"data": (["x", "y"], dask_data)}) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["data"].values, np_data, decimal=4) + + +@filter_numpy_size_warning +def test_write_dataset_dask_with_attributes(empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(5, 5).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(5, 5)) + + ds = xr.Dataset( + {"temp": (["x", "y"], dask_data, {"units": "K", "long_name": "temperature"})}, + attrs={"source": "test"}, + ) + + write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temp"].values, np_data, decimal=4) + assert ds2["temp"].attrs["units"] == "K" + assert ds2["temp"].attrs["long_name"] == "temperature" + assert ds2.attrs["source"] == "test" + + +@filter_numpy_size_warning +@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32]) +def test_write_dataset_dask_integer_dtypes(dtype, empty_temp_om_file): + da = pytest.importorskip("dask.array") + + np_data = np.arange(25, dtype=dtype).reshape(5, 5) + dask_data = da.from_array(np_data, chunks=(5, 5)) + + ds = xr.Dataset({"values": (["x", "y"], dask_data)}) + + write_dataset(ds, empty_temp_om_file) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_equal(ds2["values"].values, np_data) + assert ds2["values"].dtype == dtype + + +@filter_numpy_size_warning +def test_write_dataset_dask_larger_chunks_than_om(empty_temp_om_file): + """Dask blocks larger than OM chunks with explicit smaller OM chunk sizes.""" + da = pytest.importorskip("dask.array") + + np_data = np.random.rand(10, 20).astype(np.float32) + dask_data = da.from_array(np_data, chunks=(10, 20)) + + ds = xr.Dataset( + {"temperature": (["lat", "lon"], dask_data)}, + coords={ + "lat": np.arange(10, dtype=np.float32), + "lon": np.arange(20, dtype=np.float32), + }, + ) + + write_dataset( + ds, + empty_temp_om_file, + chunks={"lat": 5, "lon": 10}, + scale_factor=100000.0, + ) + ds2 = xr.open_dataset(empty_temp_om_file, engine="om") + + np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4) diff --git a/uv.lock b/uv.lock index e9a49bc..12a8d90 100644 --- a/uv.lock +++ b/uv.lock @@ -339,6 +339,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, +] + +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -348,6 +369,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "dask" +version = "2026.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "cloudpickle" }, + { name = "fsspec" }, + { name = "importlib-metadata", marker = "python_full_version < '3.12'" }, + { name = "packaging" }, + { name = "partd" }, + { name = "pyyaml" }, + { name = "toolz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/52/b0f9172b22778def907db1ff173249e4eb41f054b46a9c83b1528aaf811f/dask-2026.1.2.tar.gz", hash = "sha256:1136683de2750d98ea792670f7434e6c1cfce90cab2cc2f2495a9e60fd25a4fc", size = 10997838, upload-time = "2026-01-30T21:04:20.54Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/23/d39ccc4ed76222db31530b0a7d38876fdb7673e23f838e8d8f0ed4651a4f/dask-2026.1.2-py3-none-any.whl", hash = "sha256:46a0cf3b8d87f78a3d2e6b145aea4418a6d6d606fe6a16c79bd8ca2bb862bc91", size = 1482084, upload-time = "2026-01-30T21:04:18.363Z" }, +] + +[package.optional-dependencies] +array = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -573,6 +619,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769, upload-time = "2022-07-01T12:21:02.467Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "iniconfig" version = "2.3.0" @@ -603,6 +661,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, ] +[[package]] +name = "locket" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/83/97b29fe05cb6ae28d2dbd30b81e2e402a3eed5f460c26e9eaa5895ceacf5/locket-1.0.0.tar.gz", hash = "sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632", size = 4350, upload-time = "2022-04-20T22:04:44.312Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/bc/83e112abc66cd466c6b83f99118035867cecd41802f8d044638aa78a106e/locket-1.0.0-py2.py3-none-any.whl", hash = "sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3", size = 4398, upload-time = "2022-04-20T22:04:42.23Z" }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1126,6 +1193,7 @@ dependencies = [ [package.optional-dependencies] all = [ + { name = "dask", extra = ["array"] }, { name = "fsspec" }, { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numcodecs", version = "0.16.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -1143,6 +1211,9 @@ codec = [ { name = "zarr", version = "2.18.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "zarr", version = "3.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] +dask = [ + { name = "dask", extra = ["array"] }, +] fsspec = [ { name = "fsspec" }, { name = "s3fs" }, @@ -1183,6 +1254,8 @@ test = [ [package.metadata] requires-dist = [ + { name = "dask", extras = ["array"], marker = "extra == 'all'", specifier = ">=2023.1.0" }, + { name = "dask", extras = ["array"], marker = "extra == 'dask'", specifier = ">=2023.1.0" }, { name = "fsspec", marker = "extra == 'all'", specifier = ">=2023.10.0" }, { name = "fsspec", marker = "extra == 'fsspec'", specifier = ">=2023.1.0" }, { name = "numcodecs", marker = "extra == 'all'", specifier = ">=0.12.1" }, @@ -1197,7 +1270,7 @@ requires-dist = [ { name = "zarr", marker = "extra == 'all'", specifier = ">=2.18.2" }, { name = "zarr", marker = "extra == 'codec'", specifier = ">=2.18.2" }, ] -provides-extras = ["all", "codec", "fsspec", "grids", "xarray"] +provides-extras = ["all", "codec", "dask", "fsspec", "grids", "xarray"] [package.metadata.requires-dev] dev = [ @@ -1295,6 +1368,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175, upload-time = "2025-09-29T23:31:59.173Z" }, ] +[[package]] +name = "partd" +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "locket" }, + { name = "toolz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b2/3a/3f06f34820a31257ddcabdfafc2672c5816be79c7e353b02c1f318daa7d4/partd-1.4.2.tar.gz", hash = "sha256:d022c33afbdc8405c226621b015e8067888173d85f7f5ecebb3cafed9a20f02c", size = 21029, upload-time = "2024-05-06T19:51:41.945Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/e7/40fb618334dcdf7c5a316c0e7343c5cd82d3d866edc100d98e29bc945ecd/partd-1.4.2-py3-none-any.whl", hash = "sha256:978e4ac767ec4ba5b86c6eaa52e5a2a3bc748a2ca839e8cc798f1cc6ce6efb0f", size = 18905, upload-time = "2024-05-06T19:51:39.271Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1910,6 +1996,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, ] +[[package]] +name = "toolz" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/d6/114b492226588d6ff54579d95847662fc69196bdeec318eb45393b24c192/toolz-1.1.0.tar.gz", hash = "sha256:27a5c770d068c110d9ed9323f24f1543e83b2f300a687b7891c1a6d56b697b5b", size = 52613, upload-time = "2025-10-17T04:03:21.661Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/12/5911ae3eeec47800503a238d971e51722ccea5feb8569b735184d5fcdbc0/toolz-1.1.0-py3-none-any.whl", hash = "sha256:15ccc861ac51c53696de0a5d6d4607f99c210739caf987b5d2054f3efed429d8", size = 58093, upload-time = "2025-10-17T04:03:20.435Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -2205,3 +2300,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/fc/76/7fa87f57c112c7b9c wheels = [ { url = "https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl", hash = "sha256:29cd905afb6235b94c09decda4258c888fcb79bb6c862ef7c0b8fe009b5c8563", size = 284067, upload-time = "2025-11-21T14:05:59.235Z" }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +] From c3aafdc9d660df801d2bd439bb458da32c9bfcc8 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 10 Mar 2026 18:35:47 -0700 Subject: [PATCH 04/27] refactor: Remove `resource` import --- examples/dask_larger_than_ram.py | 59 ++++---------------------------- 1 file changed, 7 insertions(+), 52 deletions(-) diff --git a/examples/dask_larger_than_ram.py b/examples/dask_larger_than_ram.py index 4822c45..852cc11 100644 --- a/examples/dask_larger_than_ram.py +++ b/examples/dask_larger_than_ram.py @@ -11,17 +11,11 @@ # This example demonstrates writing a dask array that is larger than the # available process memory to an OM file using streaming writes. # -# A process memory limit is set via resource.setrlimit to simulate a -# constrained environment. The dask array is never fully materialized — -# only one chunk is held in memory at a time thanks to write_dask_array(). -# -# NOTE: resource.setrlimit(RLIMIT_AS) is only enforced on Linux. -# On macOS the kernel ignores RSS/AS limits, so the script uses -# tracemalloc and ru_maxrss to prove that peak memory stays low. +# The dask array is never fully materialized — only one chunk is held in +# memory at a time thanks to write_dask_array(). tracemalloc is used to +# prove that peak memory stays well below the total dataset size. import os -import platform -import resource import tempfile import tracemalloc @@ -31,7 +25,6 @@ from omfiles.dask import write_dask_array # Configuration -MEMORY_LIMIT_MB = 128 # process memory cap (enforced on Linux) DATASET_SIZE_MB = 512 # total size of the dask array CHUNK_SIZE = 1024 # chunk edge length (CHUNK_SIZE x CHUNK_SIZE) DTYPE = np.float32 # 4 bytes per element @@ -43,47 +36,15 @@ actual_size_mb = (side_length * side_length * bytes_per_element) / (1024 * 1024) -def set_memory_limit(limit_mb: int) -> bool: - """Try to cap the process address space. Returns True if enforced.""" - limit_bytes = limit_mb * 1024 * 1024 - try: - _, hard = resource.getrlimit(resource.RLIMIT_AS) - resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, hard)) - if platform.system() == "Linux": - print(f" Memory limit set to {limit_mb} MB (enforced on Linux)") - return True - else: - print( - f" Memory limit requested ({limit_mb} MB) but {platform.system()} " - "does not enforce RLIMIT_AS — relying on memory tracking instead" - ) - return False - except (ValueError, OSError, AttributeError) as e: - print(f" Could not set memory limit: {e}") - return False - - -def get_peak_rss_mb() -> float: - """Return peak RSS in MB (works on Linux and macOS).""" - usage = resource.getrusage(resource.RUSAGE_SELF) - if platform.system() == "Darwin": - return usage.ru_maxrss / (1024 * 1024) # macOS reports bytes - return usage.ru_maxrss / 1024 # Linux reports kilobytes - - def main(): print("=" * 60) print("Dask larger-than-RAM write example") print("=" * 60) - # Set memory limit - print(f"\nSetting process memory limit to {MEMORY_LIMIT_MB} MB...") - enforced = set_memory_limit(MEMORY_LIMIT_MB) - # Start memory tracking tracemalloc.start() - # Create a dask array larger than the memory limit + # Create a dask array larger than available memory print( f"\nCreating dask array: {side_length} x {side_length} {DTYPE.__name__} " f"({actual_size_mb:.0f} MB, chunked {CHUNK_SIZE} x {CHUNK_SIZE})" @@ -126,25 +87,19 @@ def main(): current, peak = tracemalloc.get_traced_memory() tracemalloc.stop() peak_traced_mb = peak / (1024 * 1024) - peak_rss_mb = get_peak_rss_mb() print("\n" + "=" * 60) print("Memory summary") print("=" * 60) print(f" Dataset size: {actual_size_mb:.0f} MB") - if enforced: - print(f" Process memory limit: {MEMORY_LIMIT_MB} MB (enforced)") - else: - print(f" Process memory limit: {MEMORY_LIMIT_MB} MB (not enforced on {platform.system()})") print(f" Peak traced (Python): {peak_traced_mb:.1f} MB") - print(f" Peak RSS (process): {peak_rss_mb:.1f} MB") - print(f" Ratio (dataset/peak): {actual_size_mb / peak_rss_mb:.1f}x") + print(f" Ratio (dataset/peak): {actual_size_mb / peak_traced_mb:.1f}x") print() - if peak_rss_mb < actual_size_mb: + if peak_traced_mb < actual_size_mb: print("The entire dataset was written WITHOUT loading it all into memory.") else: - print("WARNING: Peak RSS exceeded dataset size — streaming may not have worked as expected.") + print("WARNING: Peak memory exceeded dataset size — streaming may not have worked as expected.") # Cleanup os.unlink(filepath) From ca8b232a5f4881c1bef637d8080971e80c741b02 Mon Sep 17 00:00:00 2001 From: Theo Date: Tue, 10 Mar 2026 18:36:51 -0700 Subject: [PATCH 05/27] feat: Add memory usage test --- tests/test_streaming_write.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index 7b56734..dee419d 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -1,4 +1,5 @@ import tempfile +import tracemalloc import numpy as np import pytest @@ -399,3 +400,32 @@ def test_dask_not_a_dask_array_raises(self): writer = OmFileWriter(f.name) with pytest.raises(TypeError, match="Expected a dask array"): self.write_dask_array(writer, np_data) + + def test_dask_streaming_memory_stays_bounded(self): + """Peak memory during a dask streaming write stays well below the full dataset size.""" + # ~16 MB dataset (2048 x 2048 x float32), written in 256x256 chunks (~256 KB each) + side = 2048 + chunk = 256 + dtype = np.float32 + dataset_bytes = side * side * np.dtype(dtype).itemsize + + import dask.array as da + + darr = da.random.random((side, side), chunks=(chunk, chunk)).astype(dtype) + + tracemalloc.start() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = self.write_dask_array(writer, darr) + writer.close(var) + + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Peak Python memory should be a fraction of the total dataset size, + # proving that chunks are streamed rather than fully materialized. + assert peak < dataset_bytes, ( + f"Peak traced memory ({peak / 1024 / 1024:.1f} MB) should be less than " + f"the dataset size ({dataset_bytes / 1024 / 1024:.1f} MB)" + ) From b1f26645735f10469065472939105f9112ada2c2 Mon Sep 17 00:00:00 2001 From: Theo Date: Mon, 16 Mar 2026 14:32:47 -0700 Subject: [PATCH 06/27] feat: add and test fsspec xarray support --- python/omfiles/xarray.py | 11 +++- tests/test_fsspec.py | 116 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/python/omfiles/xarray.py b/python/omfiles/xarray.py index a4831f9..40e3888 100644 --- a/python/omfiles/xarray.py +++ b/python/omfiles/xarray.py @@ -340,6 +340,7 @@ def write_dataset( ds: Dataset, path: str | os.PathLike, *, + fs: Any | None = None, encoding: dict[str, dict[str, Any]] | None = None, chunks: dict[str, int] | None = None, scale_factor: float = 1.0, @@ -353,7 +354,10 @@ def write_dataset( Args: ds: The xarray Dataset to write. - path: Output file path. + path: Output file path (local path or path within the fsspec filesystem). + fs: Optional fsspec filesystem object. When provided, the file is written + via ``OmFileWriter.from_fsspec(fs, path)`` instead of the default + local-file writer. encoding: Per-variable overrides. Keys per variable: ``"chunks"``, ``"scale_factor"``, ``"add_offset"``, ``"compression"``. chunks: Global default chunk sizes as ``{dim_name: chunk_size}``. @@ -362,7 +366,10 @@ def write_dataset( compression: Global default compression algorithm. """ path = str(path) - writer = OmFileWriter(path) + if fs is not None: + writer = OmFileWriter.from_fsspec(fs, path) + else: + writer = OmFileWriter(path) all_children: list[OmVariable] = [] def _write_variable(name: str, var: Variable, is_dim_coord: bool) -> None: diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py index 9e50cfa..7a8a394 100644 --- a/tests/test_fsspec.py +++ b/tests/test_fsspec.py @@ -14,6 +14,8 @@ from fsspec.implementations.memory import MemoryFileSystem from s3fs import S3FileSystem +from omfiles.xarray import write_dataset + from .test_utils import filter_numpy_size_warning, find_chunk_for_timestamp # --- Fixtures --- @@ -278,3 +280,117 @@ def read_slice(idx, start): assert len(results) == num_threads for i, arr in enumerate(results): np.testing.assert_array_equal(arr, data[i * slice_size : (i + 1) * slice_size, :]) + + +# --- write_dataset fsspec tests --- + + +@filter_numpy_size_warning +def test_write_dataset_memory_fsspec(memory_fs): + """write_dataset with fs= writes to a memory filesystem and reads back.""" + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "Test dataset"}, + ) + write_dataset(ds, "dataset_test.om", fs=memory_fs, scale_factor=100000.0) + assert_file_exists(memory_fs, "dataset_test.om") + + reader = omfiles.OmFileReader.from_fsspec(memory_fs, "dataset_test.om") + assert reader.num_children > 0 + reader.close() + + +@filter_numpy_size_warning +def test_write_dataset_memory_fsspec_roundtrip(memory_fs): + """Full roundtrip: write_dataset via memory fs, read back with xarray.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "fsspec roundtrip test"}, + ) + path = "roundtrip_dataset.om" + write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) + + # Dump from memory fs to a temp file so xarray can read it back + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp.write(memory_fs.cat(path)) + tmp_path = tmp.name + try: + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "fsspec roundtrip test" + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_dataset_local_fsspec(local_fs): + """write_dataset with a local fsspec filesystem produces a valid file.""" + ds = xr.Dataset( + {"temperature": (["lat", "lon"], np.random.rand(8, 8).astype(np.float32))}, + coords={ + "lat": np.arange(8, dtype=np.float32), + "lon": np.arange(8, dtype=np.float32), + }, + ) + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp_path = tmp.name + try: + write_dataset(ds, tmp_path, fs=local_fs, scale_factor=100000.0) + assert os.path.exists(tmp_path) + assert os.path.getsize(tmp_path) > 0 + + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_dataset_fs_none_backward_compatible(): + """Passing fs=None behaves identically to the default (local path).""" + ds = xr.Dataset( + {"data": (["x"], np.arange(5, dtype=np.float32))}, + ) + with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: + tmp_path = tmp.name + try: + write_dataset(ds, tmp_path, fs=None) + ds2 = xr.open_dataset(tmp_path, engine="om") + np.testing.assert_array_equal(ds2["data"].values, ds["data"].values) + finally: + os.unlink(tmp_path) + + +@filter_numpy_size_warning +def test_write_and_read_dataset_fsspec_roundtrip(memory_fs): + """Full fsspec roundtrip: write_dataset via fs, read back via fsspec file object.""" + temperature_data = np.random.rand(5, 5).astype(np.float32) + ds = xr.Dataset( + {"temperature": (["lat", "lon"], temperature_data)}, + coords={ + "lat": np.arange(5, dtype=np.float32), + "lon": np.arange(5, dtype=np.float32), + }, + attrs={"description": "full fsspec roundtrip"}, + ) + path = "fsspec_full_roundtrip.om" + write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) + + # Read back via fsspec.core.OpenFile, which xr.open_dataset supports + backend = fsspec.core.OpenFile(memory_fs, path, mode="rb") + ds2 = xr.open_dataset(backend, engine="om") + np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) + np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) + np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) + assert ds2.attrs["description"] == "full fsspec roundtrip" From a423c24994dcfd08f71c33345f76fd04d1cbc203 Mon Sep 17 00:00:00 2001 From: terraputix Date: Thu, 19 Mar 2026 11:34:34 +0100 Subject: [PATCH 07/27] fix linter --- tests/test_fsspec.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py index 7a8a394..75602fd 100644 --- a/tests/test_fsspec.py +++ b/tests/test_fsspec.py @@ -12,9 +12,8 @@ import xarray as xr from fsspec.implementations.local import LocalFileSystem from fsspec.implementations.memory import MemoryFileSystem -from s3fs import S3FileSystem - from omfiles.xarray import write_dataset +from s3fs import S3FileSystem from .test_utils import filter_numpy_size_warning, find_chunk_for_timestamp From 9e2ec1eaa69860db2c906bc118fec1008865b59e Mon Sep 17 00:00:00 2001 From: terraputix Date: Fri, 20 Mar 2026 11:22:30 +0100 Subject: [PATCH 08/27] use similar code for streaming and full writes --- src/writer.rs | 404 +++++++++++++++++++++++++++++--------------------- 1 file changed, 237 insertions(+), 167 deletions(-) diff --git a/src/writer.rs b/src/writer.rs index 454ca9a..d0bfd14 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,7 +9,7 @@ use numpy::{ }; use omfiles_rs::{ traits::{OmFileArrayDataType, OmFileScalarDataType, OmFileWriterBackend}, - writer::{OmFileWriter as OmFileWriterRs, OmFileWriterArrayFinalized}, + writer::{OmFileWriter as OmFileWriterRs, OmFileWriterArray, OmFileWriterArrayFinalized}, OmCompressionType, OmFilesError, OmOffsetSize, }; use pyo3::{ @@ -22,6 +22,162 @@ use std::{ sync::{Mutex, PoisonError}, }; +// --------------------------------------------------------------------------- +// Canonical element-type enum — normalizes numpy dtype and string dtype into +// a single representation so that the 10-way type dispatch only exists once. +// --------------------------------------------------------------------------- + +/// All array element types supported by the writer. +enum OmElementType { + Float32, + Float64, + Int8, + Uint8, + Int16, + Uint16, + Int32, + Uint32, + Int64, + Uint64, +} + +impl OmElementType { + /// Resolve from a numpy `PyArrayDescr` (used by `write_array`). + fn from_numpy_dtype(py: Python<'_>, d: &Bound<'_, PyArrayDescr>) -> PyResult { + if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Float32) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Float64) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Int8) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Uint8) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Int16) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Uint16) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Int32) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Uint32) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Int64) + } else if d.is_equiv_to(&dtype::(py)) { + Ok(Self::Uint64) + } else { + Err(OmFileWriter::unsupported_array_type_error(d.clone())) + } + } + + /// Resolve from a dtype string like `"float32"` (used by `write_array_streaming`). + fn from_str(s: &str) -> PyResult { + match s { + "float32" => Ok(Self::Float32), + "float64" => Ok(Self::Float64), + "int8" => Ok(Self::Int8), + "uint8" => Ok(Self::Uint8), + "int16" => Ok(Self::Int16), + "uint16" => Ok(Self::Uint16), + "int32" => Ok(Self::Int32), + "uint32" => Ok(Self::Uint32), + "int64" => Ok(Self::Int64), + "uint64" => Ok(Self::Uint64), + _ => Err(PyValueError::new_err(format!("Unsupported dtype: {}", s))), + } + } +} + +// --------------------------------------------------------------------------- +// DataFeeder trait — abstracts how data is fed into a typed OmFileWriterArray. +// The generic method lets the single dispatcher pick T, then call feed::(). +// Used as a generic bound (F: DataFeeder), never as dyn — fully monomorphized. +// --------------------------------------------------------------------------- + +/// Abstracts over the two write strategies (full array vs streaming iterator). +trait DataFeeder<'py> { + fn feed( + self, + writer: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, + ) -> PyResult<()>; +} + +/// Feeds an entire numpy array in a single `write_data` call. +struct FullArrayFeeder<'a, 'py> { + data: &'a Bound<'py, PyUntypedArray>, +} + +impl<'a, 'py> DataFeeder<'py> for FullArrayFeeder<'a, 'py> { + fn feed( + self, + w: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, + ) -> PyResult<()> { + let array = self.data.cast::>()?.readonly(); + w.write_data(array.as_array(), None, None) + .map_err(convert_omfilesrs_error) + } +} + +/// Feeds data chunk-by-chunk from a Python iterator. +struct StreamingFeeder<'py> { + py: Python<'py>, + iter: Bound<'py, PyAny>, +} + +impl<'py> DataFeeder<'py> for StreamingFeeder<'py> { + fn feed( + self, + w: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, + ) -> PyResult<()> { + loop { + match self.iter.call_method0("__next__") { + Ok(item) => { + let array: PyReadonlyArrayDyn<'_, T> = item.extract()?; + w.write_data(array.as_array(), None, None) + .map_err(convert_omfilesrs_error)?; + } + Err(err) if err.is_instance_of::(self.py) => break, + Err(err) => return Err(err), + } + } + Ok(()) + } +} + +/// Resolved parameters shared by both `write_array` and `write_array_streaming`. +struct WriteArrayParams<'a> { + name: &'a str, + children: Vec, + scale_factor: f32, + add_offset: f32, + compression: OmCompressionType, +} + +impl<'a> WriteArrayParams<'a> { + fn from_options( + name: Option<&'a str>, + children: Option>, + scale_factor: Option, + add_offset: Option, + compression: Option<&str>, + ) -> PyResult { + Ok(Self { + name: name.unwrap_or("data"), + children: children + .unwrap_or_default() + .iter() + .map(Into::into) + .collect(), + scale_factor: scale_factor.unwrap_or(1.0), + add_offset: add_offset.unwrap_or(0.0), + compression: compression + .map(|s| PyCompressionType::from_str(s)) + .transpose()? + .unwrap_or(PyCompressionType::PforDelta2d) + .to_omfilesrs(), + }) + } +} + /// A Python wrapper for the Rust OmFileWriter implementation. #[gen_stub_pyclass] #[pyclass] @@ -55,7 +211,7 @@ impl OmFileWriter { PyErr::new::(format!("Unsupported scalar data type: {}", type_name)) } - // Helper method for safe writer access + // Helper method for safe writer access. fn with_writer(&self, f: F) -> PyResult where F: FnOnce(&mut OmFileWriterRs) -> PyResult, @@ -68,70 +224,98 @@ impl OmFileWriter { } } - fn write_array_internal<'py, T>( + /// Prepare a typed array writer, feed data into it, and finalize. + /// + /// The `feed_data` closure is the only thing that differs between full writes + /// (one `write_data` call) and streaming writes (a loop over an iterator). + /// Because `F` is `FnOnce`, the compiler monomorphizes each call-site — + /// no dynamic dispatch, zero overhead. + fn write_array_unified( &mut self, - data: PyReadonlyArrayDyn<'py, T>, + dimensions: Vec, chunks: Vec, scale_factor: f32, add_offset: f32, compression: OmCompressionType, + feed_data: F, ) -> PyResult where T: Element + OmFileArrayDataType, + F: FnOnce(&mut OmFileWriterArray<'_, T, WriterBackendImpl>) -> PyResult<()>, { - let dimensions = data - .shape() - .into_iter() - .map(|x| *x as u64) - .collect::>(); - self.with_writer(|writer| { let mut array_writer = writer .prepare_array::(dimensions, chunks, compression, scale_factor, add_offset) .map_err(convert_omfilesrs_error)?; - array_writer - .write_data(data.as_array(), None, None) - .map_err(convert_omfilesrs_error)?; + feed_data(&mut array_writer)?; - let variable_meta = array_writer.finalize(); - Ok(variable_meta) + Ok(array_writer.finalize()) }) } - fn write_array_streaming_internal<'py, T>( + /// The **single** 10-way type dispatch. + /// + /// Resolves `element_type` to a concrete `T`, calls `write_array_unified` + /// with `feeder.feed::()`, then registers the result as a named variable. + /// Both `write_array` and `write_array_streaming` delegate here. + fn write_array_dispatched<'py, F: DataFeeder<'py>>( &mut self, - py: Python<'py>, + element_type: OmElementType, dimensions: Vec, chunks: Vec, - scale_factor: f32, - add_offset: f32, - compression: OmCompressionType, - chunk_iterator: &Bound<'py, PyAny>, - ) -> PyResult - where - T: Element + OmFileArrayDataType, - { + params: &WriteArrayParams<'_>, + feeder: F, + ) -> PyResult { + // Each match arm is identical except for the type, this macro generates the correct match arms. + macro_rules! dispatch { + ($($variant:ident => $T:ty),+ $(,)?) => { + match element_type { + $(OmElementType::$variant => self.write_array_unified::<$T, _>( + dimensions, + chunks, + params.scale_factor, + params.add_offset, + params.compression, + |w| feeder.feed::<$T>(w), + ),)+ + } + }; + } + + let array_meta = dispatch! { + Float32 => f32, + Float64 => f64, + Int8 => i8, + Uint8 => u8, + Int16 => i16, + Uint16 => u16, + Int32 => i32, + Uint32 => u32, + Int64 => i64, + Uint64 => u64, + }?; + + self.finalize_array_variable(array_meta, params.name, ¶ms.children) + } + + /// Finalize array metadata and register it in the file structure as a named variable. + fn finalize_array_variable( + &self, + array_meta: OmFileWriterArrayFinalized, + name: &str, + children: &[OmOffsetSize], + ) -> PyResult { self.with_writer(|writer| { - let mut array_writer = writer - .prepare_array::(dimensions, chunks, compression, scale_factor, add_offset) + let offset_size = writer + .write_array(array_meta, name, children) .map_err(convert_omfilesrs_error)?; - loop { - let next_item = chunk_iterator.call_method0("__next__"); - match next_item { - Ok(item) => { - let array: PyReadonlyArrayDyn<'_, T> = item.extract()?; - array_writer - .write_data(array.as_array(), None, None) - .map_err(convert_omfilesrs_error)?; - } - Err(err) if err.is_instance_of::(py) => break, - Err(err) => return Err(err), - } - } - - Ok(array_writer.finalize()) + Ok(OmVariable { + name: name.to_string(), + offset: offset_size.offset, + size: offset_size.size, + }) }) } @@ -273,69 +457,13 @@ impl OmFileWriter { name: Option<&str>, children: Option>, ) -> PyResult { - let name = name.unwrap_or("data"); - let children: Vec = children - .unwrap_or_default() - .iter() - .map(Into::into) - .collect(); - - let element_type = data.dtype(); - let py = data.py(); - - let scale_factor = scale_factor.unwrap_or(1.0); - let add_offset = add_offset.unwrap_or(0.0); - let compression = compression - .map(|s| PyCompressionType::from_str(s)) - .transpose()? - .unwrap_or(PyCompressionType::PforDelta2d) - .to_omfilesrs(); - - let array_meta = if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else if element_type.is_equiv_to(&dtype::(py)) { - let array = data.cast::>()?.readonly(); - self.write_array_internal(array, chunks, scale_factor, add_offset, compression) - } else { - Err(Self::unsupported_array_type_error(element_type)) - }?; - - self.with_writer(|writer| { - let offset_size = writer - .write_array(array_meta, name, &children) - .map_err(convert_omfilesrs_error)?; + let params = + WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; + let element_type = OmElementType::from_numpy_dtype(data.py(), &data.dtype())?; + let dimensions = data.shape().iter().map(|x| *x as u64).collect(); + let feeder = FullArrayFeeder { data }; - Ok(OmVariable { - name: name.to_string(), - offset: offset_size.offset, - size: offset_size.size, - }) - }) + self.write_array_dispatched(element_type, dimensions, chunks, ¶ms, feeder) } /// Write an array to the .om file by streaming chunks from a Python iterator. @@ -381,71 +509,13 @@ impl OmFileWriter { name: Option<&str>, children: Option>, ) -> PyResult { - let name = name.unwrap_or("data"); - let children: Vec = children - .unwrap_or_default() - .iter() - .map(Into::into) - .collect(); - - let scale_factor = scale_factor.unwrap_or(1.0); - let add_offset = add_offset.unwrap_or(0.0); - let compression = compression - .map(|s| PyCompressionType::from_str(s)) - .transpose()? - .unwrap_or(PyCompressionType::PforDelta2d) - .to_omfilesrs(); - + let params = + WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; + let element_type = OmElementType::from_str(dtype)?; let iter = chunk_iterator.call_method0("__iter__")?; + let feeder = StreamingFeeder { py, iter }; - let array_meta = match dtype { - "float32" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "float64" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "int8" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "uint8" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "int16" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "uint16" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "int32" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "uint32" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "int64" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - "uint64" => self.write_array_streaming_internal::( - py, dimensions, chunks, scale_factor, add_offset, compression, &iter, - ), - _ => Err(PyValueError::new_err(format!( - "Unsupported dtype: {}", - dtype - ))), - }?; - - self.with_writer(|writer| { - let offset_size = writer - .write_array(array_meta, name, &children) - .map_err(convert_omfilesrs_error)?; - - Ok(OmVariable { - name: name.to_string(), - offset: offset_size.offset, - size: offset_size.size, - }) - }) + self.write_array_dispatched(element_type, dimensions, chunks, ¶ms, feeder) } /// Write a scalar value to the .om file. From b5d09b58373bf7da1c68678a5abc598c6d934e9e Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 08:07:42 +0100 Subject: [PATCH 09/27] revert unnecessary changes on this branch --- python/omfiles/__init__.py | 7 - python/omfiles/xarray.py | 291 +------------------------------------ tests/test_fsspec.py | 116 --------------- tests/test_xarray.py | 179 ----------------------- 4 files changed, 2 insertions(+), 591 deletions(-) diff --git a/python/omfiles/__init__.py b/python/omfiles/__init__.py index f839706..42f492c 100644 --- a/python/omfiles/__init__.py +++ b/python/omfiles/__init__.py @@ -12,10 +12,3 @@ "OmVariable", "types", ] - -try: - from .xarray import write_dataset - - __all__.append("write_dataset") -except ImportError: - pass diff --git a/python/omfiles/xarray.py b/python/omfiles/xarray.py index 40e3888..4bcfb43 100644 --- a/python/omfiles/xarray.py +++ b/python/omfiles/xarray.py @@ -3,11 +3,6 @@ from __future__ import annotations -import itertools -import os -import warnings -from typing import Any, Generator, Sequence - import numpy as np try: @@ -26,7 +21,7 @@ from xarray.core.utils import FrozenDict from xarray.core.variable import Variable -from ._rust import OmFileReader, OmFileWriter, OmVariable +from ._rust import OmFileReader, OmVariable # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -46,16 +41,10 @@ def open_dataset( with OmFileReader(filename_or_obj) as root_variable: store = OmDataStore(root_variable) store_entrypoint = StoreBackendEntrypoint() - ds = store_entrypoint.open_dataset( + return store_entrypoint.open_dataset( store, drop_variables=drop_variables, ) - coord_attr = "_COORDINATE_VARIABLES" - if coord_attr in ds.attrs: - coord_names = [c for c in ds.attrs[coord_attr].split(",") if c in ds] - ds = ds.set_coords(coord_names) - ds.attrs = {k: v for k, v in ds.attrs.items() if k != coord_attr} - return ds raise ValueError("Failed to open dataset") description = "Use .om files in Xarray" @@ -87,11 +76,6 @@ def _get_attributes_for_variable(self, reader: OmFileReader, path: str): for k, variable in direct_children.items(): child_reader = reader._init_from_variable(variable) if child_reader.is_scalar: - # Skip scalars that have _ARRAY_DIMENSIONS — they are 0-d - # coordinate variables, not plain attributes. - dim_key = path + "/" + k + "/" + DIMENSION_KEY - if dim_key in self.variables_store: - continue attrs[k] = child_reader.read_scalar() return attrs @@ -169,31 +153,6 @@ def _get_datasets(self, reader: OmFileReader): data = indexing.LazilyIndexedArray(backend_array) datasets[var_key] = Variable(dims=dim_names, data=data, attrs=attrs_for_var, encoding=None, fastpath=True) - - # Handle 0-d (scalar) variables that have _ARRAY_DIMENSIONS metadata. - # These are scalar coordinates written by write_dataset. - for var_key, var in self.variables_store.items(): - if var_key in datasets: - continue - child_reader = reader._init_from_variable(var) - if not child_reader.is_scalar: - continue - dim_path = var_key + "/" + DIMENSION_KEY - if dim_path not in self.variables_store: - continue - dim_reader = reader._init_from_variable(self.variables_store[dim_path]) - dim_names_str = dim_reader.read_scalar() - if isinstance(dim_names_str, str) and dim_names_str == "": - dim_names = () - elif isinstance(dim_names_str, str): - dim_names = tuple(dim_names_str.split(",")) - else: - dim_names = () - scalar_value = child_reader.read_scalar() - attrs = self._get_attributes_for_variable(child_reader, var_key) - attrs_for_var = {k: v for k, v in attrs.items() if k != DIMENSION_KEY} - datasets[var_key] = Variable(dims=dim_names, data=np.array(scalar_value)) - return datasets def close(self): @@ -222,249 +181,3 @@ def __getitem__(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: indexing.IndexingSupport.BASIC, self.reader.__getitem__, ) - - -def _write_scalar_safe(writer: OmFileWriter, value: Any, name: str) -> OmVariable | None: - """Write a scalar, returning None and warning if the type is unsupported.""" - try: - return writer.write_scalar(value, name=name) - except (ValueError, TypeError) as e: - warnings.warn( - f"Skipping attribute '{name}' with value {value!r}: {e}", - UserWarning, - stacklevel=3, - ) - return None - - -def _chunked_block_iterator(data: Any) -> Generator[np.ndarray, None, None]: - """ - Yield numpy arrays from a chunked array in C-order block traversal. - - Works with any array that exposes ``.numblocks``, ``.blocks[idx]``, - and ``.compute()`` (e.g. dask arrays). No dask import required. - """ - block_index_ranges = [range(n) for n in data.numblocks] - for block_indices in itertools.product(*block_index_ranges): - block = data.blocks[block_indices] - if hasattr(block, "compute"): - yield block.compute() - else: - yield np.asarray(block) - - -def _validate_chunk_alignment( - data_chunks: tuple, - om_chunks: list[int], - array_shape: tuple, -) -> None: - """ - Validate dask chunks are compatible with OM chunks for block-level streaming. - - Every non-last dask chunk along each dimension must be an exact multiple - of the corresponding OM chunk size (the last chunk may be smaller). - Additionally, for the leftmost dimension where a dask block contains more - than one OM chunk, every trailing dimension must be fully covered by each - dask block. This ensures the local chunk traversal inside a block matches - the global file order. - """ - import math - - ndim = len(om_chunks) - - for d in range(ndim): - dim_chunks = data_chunks[d] - for i, c in enumerate(dim_chunks[:-1]): - if c % om_chunks[d] != 0: - raise ValueError( - f"Dask chunk size {c} along dimension {d} (block {i}) " - f"is not a multiple of the OM chunk size {om_chunks[d]}." - ) - - first_multi = None - for d in range(ndim): - local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) - if local_n > 1: - first_multi = d - break - - if first_multi is not None: - for d in range(first_multi + 1, ndim): - local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) - global_n = math.ceil(array_shape[d] / om_chunks[d]) - if local_n != global_n: - raise ValueError( - f"Dask blocks have multiple OM chunks in dimension {first_multi}, " - f"but dimension {d} is not fully covered by each dask block " - f"(dask chunk {data_chunks[d][0]} vs array size {array_shape[d]}). " - f"Rechunk so trailing dimensions are fully covered." - ) - - -def _resolve_chunks_for_variable( - var_name: str, - var: Variable, - encoding: dict[str, dict[str, Any]] | None, - global_chunks: dict[str, int] | None, - data_chunks: tuple | None = None, -) -> list[int]: - """Resolve chunk sizes for a variable using the priority chain.""" - if encoding and var_name in encoding and "chunks" in encoding[var_name]: - return list(encoding[var_name]["chunks"]) - - if global_chunks is not None: - return [global_chunks.get(dim, min(size, 512)) for dim, size in zip(var.dims, var.shape)] - - if data_chunks is not None: - return [int(c[0]) for c in data_chunks] - - return [min(size, 512) for size in var.shape] - - -def _resolve_encoding_for_variable( - var_name: str, - encoding: dict[str, dict[str, Any]] | None, - global_scale_factor: float, - global_add_offset: float, - global_compression: str, -) -> tuple[float, float, str]: - """Resolve compression parameters for a variable.""" - var_enc = (encoding or {}).get(var_name, {}) - sf = var_enc.get("scale_factor", global_scale_factor) - ao = var_enc.get("add_offset", global_add_offset) - comp = var_enc.get("compression", global_compression) - return sf, ao, comp - - -def write_dataset( - ds: Dataset, - path: str | os.PathLike, - *, - fs: Any | None = None, - encoding: dict[str, dict[str, Any]] | None = None, - chunks: dict[str, int] | None = None, - scale_factor: float = 1.0, - add_offset: float = 0.0, - compression: str = "pfor_delta_2d", -) -> None: - """ - Write an xarray Dataset to an OM file. - - The resulting file can be read back with ``xr.open_dataset(path, engine="om")``. - - Args: - ds: The xarray Dataset to write. - path: Output file path (local path or path within the fsspec filesystem). - fs: Optional fsspec filesystem object. When provided, the file is written - via ``OmFileWriter.from_fsspec(fs, path)`` instead of the default - local-file writer. - encoding: Per-variable overrides. Keys per variable: ``"chunks"``, - ``"scale_factor"``, ``"add_offset"``, ``"compression"``. - chunks: Global default chunk sizes as ``{dim_name: chunk_size}``. - scale_factor: Global default scale factor for float compression. - add_offset: Global default offset for float compression. - compression: Global default compression algorithm. - """ - path = str(path) - if fs is not None: - writer = OmFileWriter.from_fsspec(fs, path) - else: - writer = OmFileWriter(path) - all_children: list[OmVariable] = [] - - def _write_variable(name: str, var: Variable, is_dim_coord: bool) -> None: - """Write a single variable (data var or non-dimension coordinate).""" - if np.issubdtype(var.dtype, np.datetime64) or np.issubdtype(var.dtype, np.timedelta64): - raise TypeError( - f"Variable '{name}' has dtype {var.dtype}. " - "OM files do not support datetime64/timedelta64 natively. " - "Convert to a numeric type before writing." - ) - - var_children: list[OmVariable] = [] - - if not is_dim_coord: - dim_str = ",".join(var.dims) - dim_var = writer.write_scalar(dim_str, name=DIMENSION_KEY) - var_children.append(dim_var) - - for attr_name, attr_value in var.attrs.items(): - scalar = _write_scalar_safe(writer, attr_value, attr_name) - if scalar is not None: - var_children.append(scalar) - - if var.ndim == 0: - om_var = writer.write_scalar( - var.values[()], - name=name, - children=var_children if var_children else None, - ) - all_children.append(om_var) - return - - data = var.data - is_chunked = not is_dim_coord and hasattr(data, "chunks") and data.chunks is not None - - if is_dim_coord: - resolved_chunks = [var.shape[0]] - else: - resolved_chunks = _resolve_chunks_for_variable( - name, - var, - encoding, - chunks, - data_chunks=data.chunks if is_chunked else None, - ) - - sf, ao, comp = _resolve_encoding_for_variable(name, encoding, scale_factor, add_offset, compression) - - if is_chunked: - _validate_chunk_alignment(data.chunks, resolved_chunks, var.shape) - om_var = writer.write_array_streaming( - dimensions=[int(d) for d in var.shape], - chunks=[int(c) for c in resolved_chunks], - chunk_iterator=_chunked_block_iterator(data), - dtype=var.dtype.name, - scale_factor=sf, - add_offset=ao, - compression=comp, - name=name, - children=var_children if var_children else None, - ) - else: - om_var = writer.write_array( - var.values, - chunks=resolved_chunks, - scale_factor=sf, - add_offset=ao, - compression=comp, - name=name, - children=var_children if var_children else None, - ) - all_children.append(om_var) - - for var_name in ds.data_vars: - _write_variable(var_name, ds[var_name].variable, is_dim_coord=False) - - non_dim_coords: list[str] = [] - for coord_name in ds.coords: - if coord_name in ds.data_vars: - continue - coord = ds.coords[coord_name] - is_dim_coord = coord.ndim == 1 and coord.dims[0] == coord_name - if not is_dim_coord: - non_dim_coords.append(coord_name) - _write_variable(coord_name, coord.variable, is_dim_coord=is_dim_coord) - - # Write list of non-dimension coordinates so the reader can restore them - if non_dim_coords: - coord_list_var = writer.write_scalar(",".join(non_dim_coords), name="_COORDINATE_VARIABLES") - all_children.append(coord_list_var) - - for attr_name, attr_value in ds.attrs.items(): - scalar = _write_scalar_safe(writer, attr_value, attr_name) - if scalar is not None: - all_children.append(scalar) - - root_var = writer.write_group(name="", children=all_children) - writer.close(root_var) diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py index 75602fd..31edebf 100644 --- a/tests/test_fsspec.py +++ b/tests/test_fsspec.py @@ -2,7 +2,6 @@ import os import tempfile import threading -from typing import Tuple import fsspec import numpy as np @@ -12,7 +11,6 @@ import xarray as xr from fsspec.implementations.local import LocalFileSystem from fsspec.implementations.memory import MemoryFileSystem -from omfiles.xarray import write_dataset from s3fs import S3FileSystem from .test_utils import filter_numpy_size_warning, find_chunk_for_timestamp @@ -279,117 +277,3 @@ def read_slice(idx, start): assert len(results) == num_threads for i, arr in enumerate(results): np.testing.assert_array_equal(arr, data[i * slice_size : (i + 1) * slice_size, :]) - - -# --- write_dataset fsspec tests --- - - -@filter_numpy_size_warning -def test_write_dataset_memory_fsspec(memory_fs): - """write_dataset with fs= writes to a memory filesystem and reads back.""" - ds = xr.Dataset( - {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, - coords={ - "lat": np.arange(5, dtype=np.float32), - "lon": np.arange(5, dtype=np.float32), - }, - attrs={"description": "Test dataset"}, - ) - write_dataset(ds, "dataset_test.om", fs=memory_fs, scale_factor=100000.0) - assert_file_exists(memory_fs, "dataset_test.om") - - reader = omfiles.OmFileReader.from_fsspec(memory_fs, "dataset_test.om") - assert reader.num_children > 0 - reader.close() - - -@filter_numpy_size_warning -def test_write_dataset_memory_fsspec_roundtrip(memory_fs): - """Full roundtrip: write_dataset via memory fs, read back with xarray.""" - temperature_data = np.random.rand(5, 5).astype(np.float32) - ds = xr.Dataset( - {"temperature": (["lat", "lon"], temperature_data)}, - coords={ - "lat": np.arange(5, dtype=np.float32), - "lon": np.arange(5, dtype=np.float32), - }, - attrs={"description": "fsspec roundtrip test"}, - ) - path = "roundtrip_dataset.om" - write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) - - # Dump from memory fs to a temp file so xarray can read it back - with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: - tmp.write(memory_fs.cat(path)) - tmp_path = tmp.name - try: - ds2 = xr.open_dataset(tmp_path, engine="om") - np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) - np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) - np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) - assert ds2.attrs["description"] == "fsspec roundtrip test" - finally: - os.unlink(tmp_path) - - -@filter_numpy_size_warning -def test_write_dataset_local_fsspec(local_fs): - """write_dataset with a local fsspec filesystem produces a valid file.""" - ds = xr.Dataset( - {"temperature": (["lat", "lon"], np.random.rand(8, 8).astype(np.float32))}, - coords={ - "lat": np.arange(8, dtype=np.float32), - "lon": np.arange(8, dtype=np.float32), - }, - ) - with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: - tmp_path = tmp.name - try: - write_dataset(ds, tmp_path, fs=local_fs, scale_factor=100000.0) - assert os.path.exists(tmp_path) - assert os.path.getsize(tmp_path) > 0 - - ds2 = xr.open_dataset(tmp_path, engine="om") - np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) - finally: - os.unlink(tmp_path) - - -@filter_numpy_size_warning -def test_write_dataset_fs_none_backward_compatible(): - """Passing fs=None behaves identically to the default (local path).""" - ds = xr.Dataset( - {"data": (["x"], np.arange(5, dtype=np.float32))}, - ) - with tempfile.NamedTemporaryFile(suffix=".om", delete=False) as tmp: - tmp_path = tmp.name - try: - write_dataset(ds, tmp_path, fs=None) - ds2 = xr.open_dataset(tmp_path, engine="om") - np.testing.assert_array_equal(ds2["data"].values, ds["data"].values) - finally: - os.unlink(tmp_path) - - -@filter_numpy_size_warning -def test_write_and_read_dataset_fsspec_roundtrip(memory_fs): - """Full fsspec roundtrip: write_dataset via fs, read back via fsspec file object.""" - temperature_data = np.random.rand(5, 5).astype(np.float32) - ds = xr.Dataset( - {"temperature": (["lat", "lon"], temperature_data)}, - coords={ - "lat": np.arange(5, dtype=np.float32), - "lon": np.arange(5, dtype=np.float32), - }, - attrs={"description": "full fsspec roundtrip"}, - ) - path = "fsspec_full_roundtrip.om" - write_dataset(ds, path, fs=memory_fs, scale_factor=100000.0) - - # Read back via fsspec.core.OpenFile, which xr.open_dataset supports - backend = fsspec.core.OpenFile(memory_fs, path, mode="rb") - ds2 = xr.open_dataset(backend, engine="om") - np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) - np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) - np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) - assert ds2.attrs["description"] == "full fsspec roundtrip" diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 273e11c..8d12d10 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -153,185 +153,6 @@ def test_xarray_hierarchical_file(empty_temp_om_file): assert mean_temp.dims == ("LATITUDE", "LONGITUDE", "ALTITUDE") -@filter_numpy_size_warning -def test_write_dataset_basic_roundtrip(empty_temp_om_file): - ds = xr.Dataset( - {"temperature": (["lat", "lon"], np.random.rand(5, 5).astype(np.float32))}, - coords={ - "lat": np.arange(5, dtype=np.float32), - "lon": np.arange(5, dtype=np.float32), - }, - attrs={"description": "Test dataset"}, - ) - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["temperature"].values, ds["temperature"].values, decimal=4) - np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) - np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) - assert ds2.attrs["description"] == "Test dataset" - - -@filter_numpy_size_warning -def test_write_dataset_hierarchical_roundtrip(empty_temp_om_file): - """Mirrors test_xarray_hierarchical_file but uses write_dataset.""" - temperature_data = np.random.rand(5, 5, 5, 10).astype(np.float32) - precipitation_data = np.random.rand(5, 5, 10).astype(np.float32) - - ds = xr.Dataset( - { - "temperature": ( - ["LATITUDE", "LONGITUDE", "ALTITUDE", "TIME"], - temperature_data, - {"units": "celsius", "description": "Surface temperature"}, - ), - "precipitation": ( - ["LATITUDE", "LONGITUDE", "TIME"], - precipitation_data, - {"units": "mm", "description": "Precipitation"}, - ), - }, - coords={ - "LATITUDE": np.arange(5, dtype=np.float32), - "LONGITUDE": np.arange(5, dtype=np.float32), - "ALTITUDE": np.arange(5, dtype=np.float32), - "TIME": np.arange(10, dtype=np.float32), - }, - attrs={"description": "This is a hierarchical OM File"}, - ) - - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - assert ds2.attrs["description"] == "This is a hierarchical OM File" - assert set(ds2.data_vars) == {"temperature", "precipitation"} - - np.testing.assert_array_almost_equal(ds2["temperature"].values, temperature_data, decimal=4) - assert ds2["temperature"].dims == ("LATITUDE", "LONGITUDE", "ALTITUDE", "TIME") - assert ds2["temperature"].attrs["units"] == "celsius" - assert ds2["temperature"].attrs["description"] == "Surface temperature" - - np.testing.assert_array_almost_equal(ds2["precipitation"].values, precipitation_data, decimal=4) - assert ds2["precipitation"].dims == ("LATITUDE", "LONGITUDE", "TIME") - assert ds2["precipitation"].attrs["units"] == "mm" - - assert ds2["LATITUDE"].dims == ("LATITUDE",) - assert ds2["LONGITUDE"].dims == ("LONGITUDE",) - assert ds2["ALTITUDE"].dims == ("ALTITUDE",) - assert ds2["TIME"].dims == ("TIME",) - - -@filter_numpy_size_warning -def test_write_dataset_per_variable_encoding(empty_temp_om_file): - ds = xr.Dataset( - { - "high_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), - "low_res": (["x", "y"], np.random.rand(10, 10).astype(np.float32)), - }, - coords={ - "x": np.arange(10, dtype=np.float32), - "y": np.arange(10, dtype=np.float32), - }, - ) - - write_dataset( - ds, - empty_temp_om_file, - scale_factor=1000.0, - encoding={ - "high_res": {"scale_factor": 100000.0, "chunks": [5, 5]}, - "low_res": {"chunks": [10, 10]}, - }, - ) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["high_res"].values, ds["high_res"].values, decimal=4) - np.testing.assert_array_almost_equal(ds2["low_res"].values, ds["low_res"].values, decimal=2) - - -@filter_numpy_size_warning -@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32, np.uint64]) -def test_write_dataset_integer_dtypes(dtype, empty_temp_om_file): - data = np.arange(25, dtype=dtype).reshape(5, 5) - ds = xr.Dataset({"values": (["x", "y"], data)}) - - write_dataset(ds, empty_temp_om_file) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_equal(ds2["values"].values, data) - assert ds2["values"].dtype == dtype - - -@filter_numpy_size_warning -def test_write_dataset_unsupported_attrs_warning(empty_temp_om_file): - ds = xr.Dataset( - {"data": (["x"], np.arange(5, dtype=np.float32))}, - attrs={"valid": "hello", "invalid": [1, 2, 3]}, - ) - - with pytest.warns(UserWarning, match="Skipping attribute"): - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - assert ds2.attrs["valid"] == "hello" - assert "invalid" not in ds2.attrs - - -def test_write_dataset_datetime_raises(empty_temp_om_file): - time_values = np.array( - ["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04", "2020-01-05"], dtype="datetime64[ns]" - ) - ds = xr.Dataset( - {"data": (["time"], np.arange(5, dtype=np.float32))}, - coords={"time": time_values}, - ) - - with pytest.raises(TypeError, match="datetime64"): - write_dataset(ds, empty_temp_om_file) - - -@filter_numpy_size_warning -def test_write_dataset_scalar_coordinate(empty_temp_om_file): - """Writing a dataset with a scalar (0-d) coordinate should not segfault.""" - temperature_data = np.random.rand(5, 5).astype(np.float32) - ds = xr.Dataset( - {"temperature": (["lat", "lon"], temperature_data)}, - coords={ - "lat": np.arange(5, dtype=np.float32), - "lon": np.arange(5, dtype=np.float32), - "time": np.float32(42.0), - }, - ) - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - loaded = xr.open_dataset(empty_temp_om_file, engine="om") - - assert "time" in loaded.coords - assert "time" not in loaded.data_vars - assert loaded.coords["time"].ndim == 0 - np.testing.assert_almost_equal(float(loaded.coords["time"]), 42.0) - - np.testing.assert_array_almost_equal(loaded["temperature"].values, temperature_data, decimal=4) - np.testing.assert_array_equal(loaded.coords["lat"].values, ds.coords["lat"].values) - np.testing.assert_array_equal(loaded.coords["lon"].values, ds.coords["lon"].values) - - -@filter_numpy_size_warning -def test_write_dataset_non_dimension_coordinate(empty_temp_om_file): - """Non-dimension coordinates should preserve their dimensions and coordinate status.""" - valid_time_data = np.arange(6, dtype=np.float32) - ds = xr.Dataset( - {"t2m": (("step", "lat"), np.zeros((6, 10), dtype=np.float32))}, - coords={"valid_time": ("step", valid_time_data)}, - ) - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - loaded = xr.open_dataset(empty_temp_om_file, engine="om") - - assert loaded["valid_time"].dims == ("step",) - assert "valid_time" in loaded.coords - assert "valid_time" not in loaded.data_vars - np.testing.assert_array_equal(loaded["valid_time"].values, valid_time_data) - - @filter_numpy_size_warning def test_write_dataset_dask_roundtrip(empty_temp_om_file): da = pytest.importorskip("dask.array") From 0925814cf6b9adb642f4c5c4afe0e86a06ac770d Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 08:11:11 +0100 Subject: [PATCH 10/27] remove dask related changes in xarray tests --- tests/test_xarray.py | 130 ------------------------------------------- 1 file changed, 130 deletions(-) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 8d12d10..85f66fd 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -3,7 +3,6 @@ import pytest import xarray as xr from omfiles import OmFileReader, OmFileWriter -from omfiles.xarray import write_dataset from xarray.core import indexing from .test_utils import create_test_om_file, filter_numpy_size_warning @@ -151,132 +150,3 @@ def test_xarray_hierarchical_file(empty_temp_om_file): mean_temp = ds["temperature"].mean(dim="TIME") assert mean_temp.shape == (5, 5, 5) assert mean_temp.dims == ("LATITUDE", "LONGITUDE", "ALTITUDE") - - -@filter_numpy_size_warning -def test_write_dataset_dask_roundtrip(empty_temp_om_file): - da = pytest.importorskip("dask.array") - - np_data = np.random.rand(10, 20).astype(np.float32) - dask_data = da.from_array(np_data, chunks=(5, 10)) - - ds = xr.Dataset( - {"temperature": (["lat", "lon"], dask_data)}, - coords={ - "lat": np.arange(10, dtype=np.float32), - "lon": np.arange(20, dtype=np.float32), - }, - ) - - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4) - np.testing.assert_array_equal(ds2.coords["lat"].values, ds.coords["lat"].values) - np.testing.assert_array_equal(ds2.coords["lon"].values, ds.coords["lon"].values) - - -@filter_numpy_size_warning -def test_write_dataset_dask_mixed_variables(empty_temp_om_file): - da = pytest.importorskip("dask.array") - - np_temp = np.random.rand(10, 20).astype(np.float32) - dask_temp = da.from_array(np_temp, chunks=(5, 10)) - np_precip = np.random.rand(10, 20).astype(np.float32) - - ds = xr.Dataset( - { - "temperature": (["lat", "lon"], dask_temp), - "precipitation": (["lat", "lon"], np_precip), - }, - coords={ - "lat": np.arange(10, dtype=np.float32), - "lon": np.arange(20, dtype=np.float32), - }, - ) - - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["temperature"].values, np_temp, decimal=4) - np.testing.assert_array_almost_equal(ds2["precipitation"].values, np_precip, decimal=4) - - -@filter_numpy_size_warning -def test_write_dataset_dask_boundary_chunks(empty_temp_om_file): - da = pytest.importorskip("dask.array") - - np_data = np.arange(91, dtype=np.float32).reshape(7, 13) - dask_data = da.from_array(np_data, chunks=(4, 5)) - - ds = xr.Dataset({"data": (["x", "y"], dask_data)}) - - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["data"].values, np_data, decimal=4) - - -@filter_numpy_size_warning -def test_write_dataset_dask_with_attributes(empty_temp_om_file): - da = pytest.importorskip("dask.array") - - np_data = np.random.rand(5, 5).astype(np.float32) - dask_data = da.from_array(np_data, chunks=(5, 5)) - - ds = xr.Dataset( - {"temp": (["x", "y"], dask_data, {"units": "K", "long_name": "temperature"})}, - attrs={"source": "test"}, - ) - - write_dataset(ds, empty_temp_om_file, scale_factor=100000.0) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["temp"].values, np_data, decimal=4) - assert ds2["temp"].attrs["units"] == "K" - assert ds2["temp"].attrs["long_name"] == "temperature" - assert ds2.attrs["source"] == "test" - - -@filter_numpy_size_warning -@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.uint32]) -def test_write_dataset_dask_integer_dtypes(dtype, empty_temp_om_file): - da = pytest.importorskip("dask.array") - - np_data = np.arange(25, dtype=dtype).reshape(5, 5) - dask_data = da.from_array(np_data, chunks=(5, 5)) - - ds = xr.Dataset({"values": (["x", "y"], dask_data)}) - - write_dataset(ds, empty_temp_om_file) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_equal(ds2["values"].values, np_data) - assert ds2["values"].dtype == dtype - - -@filter_numpy_size_warning -def test_write_dataset_dask_larger_chunks_than_om(empty_temp_om_file): - """Dask blocks larger than OM chunks with explicit smaller OM chunk sizes.""" - da = pytest.importorskip("dask.array") - - np_data = np.random.rand(10, 20).astype(np.float32) - dask_data = da.from_array(np_data, chunks=(10, 20)) - - ds = xr.Dataset( - {"temperature": (["lat", "lon"], dask_data)}, - coords={ - "lat": np.arange(10, dtype=np.float32), - "lon": np.arange(20, dtype=np.float32), - }, - ) - - write_dataset( - ds, - empty_temp_om_file, - chunks={"lat": 5, "lon": 10}, - scale_factor=100000.0, - ) - ds2 = xr.open_dataset(empty_temp_om_file, engine="om") - - np.testing.assert_array_almost_equal(ds2["temperature"].values, np_data, decimal=4) From 13f0ba35babd1501be7fd0668cff6a56f85f174b Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 08:17:17 +0100 Subject: [PATCH 11/27] add missing method --- python/omfiles/dask.py | 48 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index 2fa9c5b..b054309 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -3,15 +3,61 @@ from __future__ import annotations import itertools +import math from typing import TYPE_CHECKING, Optional, Sequence from omfiles._rust import OmFileWriter, OmVariable -from omfiles.xarray import _validate_chunk_alignment if TYPE_CHECKING: import dask.array as da +def _validate_chunk_alignment( + data_chunks: tuple, + om_chunks: list[int], + array_shape: tuple, +) -> None: + """ + Validate dask chunks are compatible with OM chunks for block-level streaming. + + Every non-last dask chunk along each dimension must be an exact multiple + of the corresponding OM chunk size (the last chunk may be smaller). + Additionally, for the leftmost dimension where a dask block contains more + than one OM chunk, every trailing dimension must be fully covered by each + dask block. This ensures the local chunk traversal inside a block matches + the global file order. + """ + ndim = len(om_chunks) + + for d in range(ndim): + dim_chunks = data_chunks[d] + for i, c in enumerate(dim_chunks[:-1]): + if c % om_chunks[d] != 0: + raise ValueError( + f"Dask chunk size {c} along dimension {d} (block {i}) " + f"is not a multiple of the OM chunk size {om_chunks[d]}." + ) + + first_multi = None + for d in range(ndim): + local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) + if local_n > 1: + first_multi = d + break + + if first_multi is not None: + for d in range(first_multi + 1, ndim): + local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) + global_n = math.ceil(array_shape[d] / om_chunks[d]) + if local_n != global_n: + raise ValueError( + f"Dask blocks have multiple OM chunks in dimension {first_multi}, " + f"but dimension {d} is not fully covered by each dask block " + f"(dask chunk {data_chunks[d][0]} vs array size {array_shape[d]}). " + f"Rechunk so trailing dimensions are fully covered." + ) + + def _dask_block_iterator(dask_array: da.Array): """ Yield computed numpy arrays from a dask array in C-order block traversal. From 13b850a6ced21836dfee7f344e257c80f1d0693c Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 11:09:11 +0100 Subject: [PATCH 12/27] remove unnecessary dtype from test --- tests/test_read_write.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_read_write.py b/tests/test_read_write.py index f6daeda..e658a08 100644 --- a/tests/test_read_write.py +++ b/tests/test_read_write.py @@ -28,19 +28,19 @@ def test_round_trip_array_datatypes(): shape = (5, 5, 5, 2) chunks = [2, 2, 2, 1] test_cases = [ - (np.random.rand(*shape).astype(np.float32), "float32"), - (np.random.rand(*shape).astype(np.float64), "float64"), - (np.random.randint(-128, 127, size=shape, dtype=np.int8), "int8"), - (np.random.randint(-32768, 32767, size=shape, dtype=np.int16), "int16"), - (np.random.randint(-2147483648, 2147483647, size=shape, dtype=np.int32), "int32"), - (np.random.randint(-9223372036854775808, 9223372036854775807, size=shape, dtype=np.int64), "int64"), - (np.random.randint(0, 255, size=shape, dtype=np.uint8), "uint8"), - (np.random.randint(0, 65535, size=shape, dtype=np.uint16), "uint16"), - (np.random.randint(0, 4294967295, size=shape, dtype=np.uint32), "uint32"), - (np.random.randint(0, 18446744073709551615, size=shape, dtype=np.uint64), "uint64"), + np.random.rand(*shape).astype(np.float32), + np.random.rand(*shape).astype(np.float64), + np.random.randint(-128, 127, size=shape, dtype=np.int8), + np.random.randint(-32768, 32767, size=shape, dtype=np.int16), + np.random.randint(-2147483648, 2147483647, size=shape, dtype=np.int32), + np.random.randint(-9223372036854775808, 9223372036854775807, size=shape, dtype=np.int64), + np.random.randint(0, 255, size=shape, dtype=np.uint8), + np.random.randint(0, 65535, size=shape, dtype=np.uint16), + np.random.randint(0, 4294967295, size=shape, dtype=np.uint32), + np.random.randint(0, 18446744073709551615, size=shape, dtype=np.uint64), ] - for test_data, dtype in test_cases: + for test_data in test_cases: with tempfile.NamedTemporaryFile(suffix=".om") as temp_file: writer = omfiles.OmFileWriter(temp_file.name) variable = writer.write_array(test_data, chunks=chunks, scale_factor=10000.0, add_offset=0.0) From 10c3d6911c12cb525ad919051760094fce555a20 Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 11:31:48 +0100 Subject: [PATCH 13/27] use enum based disctinction --- src/writer.rs | 244 ++++++++++++++++++-------------------------------- 1 file changed, 88 insertions(+), 156 deletions(-) diff --git a/src/writer.rs b/src/writer.rs index d0bfd14..c5ad3c2 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,7 +9,7 @@ use numpy::{ }; use omfiles_rs::{ traits::{OmFileArrayDataType, OmFileScalarDataType, OmFileWriterBackend}, - writer::{OmFileWriter as OmFileWriterRs, OmFileWriterArray, OmFileWriterArrayFinalized}, + writer::{OmFileWriter as OmFileWriterRs, OmFileWriterArray}, OmCompressionType, OmFilesError, OmOffsetSize, }; use pyo3::{ @@ -22,10 +22,14 @@ use std::{ sync::{Mutex, PoisonError}, }; -// --------------------------------------------------------------------------- -// Canonical element-type enum — normalizes numpy dtype and string dtype into -// a single representation so that the 10-way type dispatch only exists once. -// --------------------------------------------------------------------------- +/// Helper to convert OmOffsetSize to OmVariable +fn to_variable(name: &str, os: OmOffsetSize) -> OmVariable { + OmVariable { + name: name.to_string(), + offset: os.offset, + size: os.size, + } +} /// All array element types supported by the writer. enum OmElementType { @@ -87,59 +91,42 @@ impl OmElementType { } } -// --------------------------------------------------------------------------- -// DataFeeder trait — abstracts how data is fed into a typed OmFileWriterArray. -// The generic method lets the single dispatcher pick T, then call feed::(). -// Used as a generic bound (F: DataFeeder), never as dyn — fully monomorphized. -// --------------------------------------------------------------------------- - /// Abstracts over the two write strategies (full array vs streaming iterator). -trait DataFeeder<'py> { - fn feed( - self, - writer: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, - ) -> PyResult<()>; -} - -/// Feeds an entire numpy array in a single `write_data` call. -struct FullArrayFeeder<'a, 'py> { - data: &'a Bound<'py, PyUntypedArray>, -} - -impl<'a, 'py> DataFeeder<'py> for FullArrayFeeder<'a, 'py> { - fn feed( - self, - w: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, - ) -> PyResult<()> { - let array = self.data.cast::>()?.readonly(); - w.write_data(array.as_array(), None, None) - .map_err(convert_omfilesrs_error) - } -} - -/// Feeds data chunk-by-chunk from a Python iterator. -struct StreamingFeeder<'py> { - py: Python<'py>, - iter: Bound<'py, PyAny>, +enum Feeder<'a, 'py> { + /// Feeds an entire numpy array in a single `write_data` call. + Full { + data: &'a Bound<'py, PyUntypedArray>, + }, + /// Feeds data chunk-by-chunk from a Python iterator. + Streaming { + py: Python<'py>, + iter: Bound<'py, PyAny>, + }, } -impl<'py> DataFeeder<'py> for StreamingFeeder<'py> { +impl<'a, 'py> Feeder<'a, 'py> { fn feed( self, w: &mut OmFileWriterArray<'_, T, WriterBackendImpl>, ) -> PyResult<()> { - loop { - match self.iter.call_method0("__next__") { - Ok(item) => { - let array: PyReadonlyArrayDyn<'_, T> = item.extract()?; - w.write_data(array.as_array(), None, None) - .map_err(convert_omfilesrs_error)?; - } - Err(err) if err.is_instance_of::(self.py) => break, - Err(err) => return Err(err), + match self { + Feeder::Full { data } => { + let array = data.cast::>()?.readonly(); + w.write_data(array.as_array(), None, None) + .map_err(convert_omfilesrs_error) } + Feeder::Streaming { py, iter } => loop { + match iter.call_method0("__next__") { + Ok(item) => { + let array: PyReadonlyArrayDyn<'_, T> = item.extract()?; + w.write_data(array.as_array(), None, None) + .map_err(convert_omfilesrs_error)?; + } + Err(err) if err.is_instance_of::(py) => break Ok(()), + Err(err) => break Err(err), + } + }, } - Ok(()) } } @@ -211,130 +198,80 @@ impl OmFileWriter { PyErr::new::(format!("Unsupported scalar data type: {}", type_name)) } - // Helper method for safe writer access. + /// Helper method for safe writer access. fn with_writer(&self, f: F) -> PyResult where F: FnOnce(&mut OmFileWriterRs) -> PyResult, { let mut guard = self.writer.lock().map_err(|e| Self::lock_error(e))?; - match guard.as_mut() { Some(writer) => f(writer), None => Err(Self::closed_error()), } } - /// Prepare a typed array writer, feed data into it, and finalize. + /// Unified 10-way type dispatch. /// - /// The `feed_data` closure is the only thing that differs between full writes - /// (one `write_data` call) and streaming writes (a loop over an iterator). - /// Because `F` is `FnOnce`, the compiler monomorphizes each call-site — - /// no dynamic dispatch, zero overhead. - fn write_array_unified( - &mut self, - dimensions: Vec, - chunks: Vec, - scale_factor: f32, - add_offset: f32, - compression: OmCompressionType, - feed_data: F, - ) -> PyResult - where - T: Element + OmFileArrayDataType, - F: FnOnce(&mut OmFileWriterArray<'_, T, WriterBackendImpl>) -> PyResult<()>, - { - self.with_writer(|writer| { - let mut array_writer = writer - .prepare_array::(dimensions, chunks, compression, scale_factor, add_offset) - .map_err(convert_omfilesrs_error)?; - - feed_data(&mut array_writer)?; - - Ok(array_writer.finalize()) - }) - } - - /// The **single** 10-way type dispatch. - /// - /// Resolves `element_type` to a concrete `T`, calls `write_array_unified` - /// with `feeder.feed::()`, then registers the result as a named variable. + /// Resolves `element_type` to a concrete `T`, prepares a typed array writer, + /// feeds data, finalizes, and registers the result as a named variable. /// Both `write_array` and `write_array_streaming` delegate here. - fn write_array_dispatched<'py, F: DataFeeder<'py>>( - &mut self, + fn write_array_dispatched( + &self, element_type: OmElementType, dimensions: Vec, chunks: Vec, params: &WriteArrayParams<'_>, - feeder: F, - ) -> PyResult { - // Each match arm is identical except for the type, this macro generates the correct match arms. - macro_rules! dispatch { - ($($variant:ident => $T:ty),+ $(,)?) => { - match element_type { - $(OmElementType::$variant => self.write_array_unified::<$T, _>( - dimensions, - chunks, - params.scale_factor, - params.add_offset, - params.compression, - |w| feeder.feed::<$T>(w), - ),)+ - } - }; - } - - let array_meta = dispatch! { - Float32 => f32, - Float64 => f64, - Int8 => i8, - Uint8 => u8, - Int16 => i16, - Uint16 => u16, - Int32 => i32, - Uint32 => u32, - Int64 => i64, - Uint64 => u64, - }?; - - self.finalize_array_variable(array_meta, params.name, ¶ms.children) - } - - /// Finalize array metadata and register it in the file structure as a named variable. - fn finalize_array_variable( - &self, - array_meta: OmFileWriterArrayFinalized, - name: &str, - children: &[OmOffsetSize], + feeder: Feeder<'_, '_>, ) -> PyResult { self.with_writer(|writer| { - let offset_size = writer - .write_array(array_meta, name, children) - .map_err(convert_omfilesrs_error)?; + macro_rules! dispatch { + ($($variant:ident => $T:ty),+ $(,)?) => { + match element_type { + $(OmElementType::$variant => { + let mut w = writer + .prepare_array::<$T>( + dimensions, chunks, params.compression, + params.scale_factor, params.add_offset, + ) + .map_err(convert_omfilesrs_error)?; + feeder.feed::<$T>(&mut w)?; + w.finalize() + }),+ + } + }; + } - Ok(OmVariable { - name: name.to_string(), - offset: offset_size.offset, - size: offset_size.size, - }) + let array_meta = dispatch! { + Float32 => f32, + Float64 => f64, + Int8 => i8, + Uint8 => u8, + Int16 => i16, + Uint16 => u16, + Int32 => i32, + Uint32 => u32, + Int64 => i64, + Uint64 => u64, + }; + + writer + .write_array(array_meta, params.name, ¶ms.children) + .map_err(convert_omfilesrs_error) + .map(|os| to_variable(params.name, os)) }) } fn store_scalar( - &mut self, + &self, value: T, name: &str, children: &[OmOffsetSize], ) -> PyResult { self.with_writer(|writer| { - let offset_size = writer + writer .write_scalar(value, name, children) - .map_err(convert_omfilesrs_error)?; - - Ok(OmVariable { - name: name.to_string(), - offset: offset_size.offset, - size: offset_size.size, - }) + .map_err(convert_omfilesrs_error) + .map(|os| to_variable(name, os)) }) } } @@ -400,8 +337,9 @@ impl OmFileWriter { let mut guard = self.writer.lock().map_err(|e| Self::lock_error(e))?; if let Some(writer) = guard.as_mut() { - let result = writer.write_trailer(root_variable.into()); - result.map_err(convert_omfilesrs_error)?; + writer + .write_trailer(root_variable.into()) + .map_err(convert_omfilesrs_error)?; // Take ownership and drop to ensure proper file closure guard.take(); } else { @@ -415,7 +353,6 @@ impl OmFileWriter { #[getter] fn closed(&self) -> PyResult { let guard = self.writer.lock().map_err(|e| Self::lock_error(e))?; - Ok(guard.is_none()) } @@ -461,7 +398,7 @@ impl OmFileWriter { WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; let element_type = OmElementType::from_numpy_dtype(data.py(), &data.dtype())?; let dimensions = data.shape().iter().map(|x| *x as u64).collect(); - let feeder = FullArrayFeeder { data }; + let feeder = Feeder::Full { data }; self.write_array_dispatched(element_type, dimensions, chunks, ¶ms, feeder) } @@ -513,7 +450,7 @@ impl OmFileWriter { WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; let element_type = OmElementType::from_str(dtype)?; let iter = chunk_iterator.call_method0("__iter__")?; - let feeder = StreamingFeeder { py, iter }; + let feeder = Feeder::Streaming { py, iter }; self.write_array_dispatched(element_type, dimensions, chunks, ¶ms, feeder) } @@ -622,15 +559,10 @@ impl OmFileWriter { let children: Vec = children.iter().map(Into::into).collect(); self.with_writer(|writer| { - let offset_size = writer + writer .write_none(name, &children) - .map_err(convert_omfilesrs_error)?; - - Ok(OmVariable { - name: name.to_string(), - offset: offset_size.offset, - size: offset_size.size, - }) + .map_err(convert_omfilesrs_error) + .map(|os| to_variable(name, os)) }) } } From c9a07fb4c0f75f89e0ca5573c44ec839cb2e918c Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 11:39:29 +0100 Subject: [PATCH 14/27] remove example --- examples/dask_larger_than_ram.py | 109 ------------------------------- 1 file changed, 109 deletions(-) delete mode 100644 examples/dask_larger_than_ram.py diff --git a/examples/dask_larger_than_ram.py b/examples/dask_larger_than_ram.py deleted file mode 100644 index 852cc11..0000000 --- a/examples/dask_larger_than_ram.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env -S uv run --script -# -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "omfiles>=1.1.1", # x-release-please-version -# "dask[array]>=2023.1.0", -# ] -# /// -# -# This example demonstrates writing a dask array that is larger than the -# available process memory to an OM file using streaming writes. -# -# The dask array is never fully materialized — only one chunk is held in -# memory at a time thanks to write_dask_array(). tracemalloc is used to -# prove that peak memory stays well below the total dataset size. - -import os -import tempfile -import tracemalloc - -import dask.array as da -import numpy as np -from omfiles import OmFileReader, OmFileWriter -from omfiles.dask import write_dask_array - -# Configuration -DATASET_SIZE_MB = 512 # total size of the dask array -CHUNK_SIZE = 1024 # chunk edge length (CHUNK_SIZE x CHUNK_SIZE) -DTYPE = np.float32 # 4 bytes per element - -# Derived constants -bytes_per_element = np.dtype(DTYPE).itemsize -total_elements = (DATASET_SIZE_MB * 1024 * 1024) // bytes_per_element -side_length = int(np.sqrt(total_elements)) # square array for simplicity -actual_size_mb = (side_length * side_length * bytes_per_element) / (1024 * 1024) - - -def main(): - print("=" * 60) - print("Dask larger-than-RAM write example") - print("=" * 60) - - # Start memory tracking - tracemalloc.start() - - # Create a dask array larger than available memory - print( - f"\nCreating dask array: {side_length} x {side_length} {DTYPE.__name__} " - f"({actual_size_mb:.0f} MB, chunked {CHUNK_SIZE} x {CHUNK_SIZE})" - ) - - data = da.random.random( - (side_length, side_length), - chunks=(CHUNK_SIZE, CHUNK_SIZE), - ).astype(DTYPE) - - print(f" Shape: {data.shape}") - print(f" Chunks: {data.chunksize}") - print(f" Num blocks: {data.numblocks} ({np.prod(data.numblocks)} total)") - - # Write to .om file via streaming - fd, filepath = tempfile.mkstemp(suffix=".om") - os.close(fd) - - print(f"\nWriting to {filepath} ...") - writer = OmFileWriter(filepath) - root = write_dask_array(writer, data, name="temperature") - writer.close(root) - - file_size_mb = os.path.getsize(filepath) / (1024 * 1024) - print(f" File size on disk: {file_size_mb:.1f} MB (compression ratio: {actual_size_mb / file_size_mb:.1f}x)") - - # Read back a slice and verify - print("\nReading back a slice to verify...") - with OmFileReader(filepath) as reader: - print(f" Reader shape: {reader.shape}, dtype: {reader.dtype}") - sample = reader[0:10, 0:10] - print(f" Sample slice [0:10, 0:10] shape: {sample.shape}") - print(f" Sample values (first row): {sample[0, :5]}") - assert sample.shape == (10, 10), "Unexpected slice shape" - assert not np.any(np.isnan(sample)), "Found NaN values in readback" - - print(" Verification passed!") - - # Memory summary - current, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - peak_traced_mb = peak / (1024 * 1024) - - print("\n" + "=" * 60) - print("Memory summary") - print("=" * 60) - print(f" Dataset size: {actual_size_mb:.0f} MB") - print(f" Peak traced (Python): {peak_traced_mb:.1f} MB") - print(f" Ratio (dataset/peak): {actual_size_mb / peak_traced_mb:.1f}x") - print() - - if peak_traced_mb < actual_size_mb: - print("The entire dataset was written WITHOUT loading it all into memory.") - else: - print("WARNING: Peak memory exceeded dataset size — streaming may not have worked as expected.") - - # Cleanup - os.unlink(filepath) - - -if __name__ == "__main__": - main() From 9669d3d6bcf65ab098ff77ec66bfe996815a7027 Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 11:51:28 +0100 Subject: [PATCH 15/27] use freestanding tests and separate into two files --- tests/test_dask.py | 183 ++++++++++++ tests/test_streaming_write.py | 547 +++++++++++----------------------- 2 files changed, 356 insertions(+), 374 deletions(-) create mode 100644 tests/test_dask.py diff --git a/tests/test_dask.py b/tests/test_dask.py new file mode 100644 index 0000000..82b6ca4 --- /dev/null +++ b/tests/test_dask.py @@ -0,0 +1,183 @@ +import tempfile +import tracemalloc + +import dask.array as da +import numpy as np +import pytest +from omfiles import OmFileReader, OmFileWriter +from omfiles.dask import write_dask_array + + +@pytest.fixture +def dask_array_2d(): + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + return da.from_array(np_data, chunks=(5, 10)) + + +@pytest.fixture +def dask_array_3d(): + np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) + return da.from_array(np_data, chunks=(2, 3, 4)) + + +def test_dask_roundtrip_2d(dask_array_2d): + expected = dask_array_2d.compute() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, expected, decimal=4) + + +def test_dask_roundtrip_3d(dask_array_3d): + expected = dask_array_3d.compute() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, dask_array_3d) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, expected) + + +def test_dask_boundary_chunks(): + np_data = np.arange(91, dtype=np.float32).reshape(7, 13) + darr = da.from_array(np_data, chunks=(4, 5)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, darr, scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + +def test_dask_custom_name(dask_array_2d): + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0, name="temperature") + assert var.name == "temperature" + writer.close(var) + + +def test_dask_non_multiple_chunks_raises(): + """Dask chunks that aren't multiples of OM chunks should raise.""" + np_data = np.arange(30, dtype=np.float32).reshape(6, 5) + darr = da.from_array(np_data, chunks=(3, 5)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="not a multiple"): + write_dask_array(writer, darr, chunks=[2, 5]) + + +def test_dask_larger_chunks_than_om_2d(): + """Dask blocks spanning multiple OM chunks along dim 1 (full trailing dim).""" + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(10, 20)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, darr, chunks=[5, 10], scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + +def test_dask_larger_chunks_than_om_3d(): + """Dask blocks with full trailing dims, multiple OM chunks in dim 0.""" + np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) + darr = da.from_array(np_data, chunks=(4, 6, 8)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, darr, chunks=[2, 3, 4]) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, np_data) + + +def test_dask_single_om_chunk_per_slow_dim(): + """Dask blocks with 1 OM chunk in dim 0, partial trailing dim coverage.""" + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(5, 10)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, darr, chunks=[5, 5], scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4) + + +def test_dask_misaligned_trailing_dims_raises(): + """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(10, 10)) + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="not fully covered"): + write_dask_array(writer, darr, chunks=[5, 5]) + + +def test_dask_not_a_dask_array_raises(): + np_data = np.arange(20, dtype=np.float32).reshape(4, 5) + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(TypeError, match="Expected a dask array"): + write_dask_array(writer, np_data) + + +def test_dask_streaming_memory_stays_bounded(): + """Peak memory during a dask streaming write stays well below the full dataset size.""" + # ~16 MB dataset (2048 x 2048 x float32), written in 256x256 chunks (~256 KB each) + side = 2048 + chunk = 256 + dtype = np.float32 + dataset_bytes = side * side * np.dtype(dtype).itemsize + + darr = da.random.random((side, side), chunks=(chunk, chunk)).astype(dtype) + + tracemalloc.start() + + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + var = write_dask_array(writer, darr) + writer.close(var) + + _, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Peak Python memory should be a fraction of the total dataset size, + # proving that chunks are streamed rather than fully materialized. + assert peak < dataset_bytes, ( + f"Peak traced memory ({peak / 1024 / 1024:.1f} MB) should be less than " + f"the dataset size ({dataset_bytes / 1024 / 1024:.1f} MB)" + ) diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index dee419d..03a0f81 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -1,169 +1,107 @@ import tempfile -import tracemalloc import numpy as np import pytest from omfiles import OmFileReader, OmFileWriter -class TestWriteArrayStreaming: - def test_streaming_single_chunk(self): - shape = (10, 20) - chunks = [10, 20] - data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) +def test_streaming_single_chunk(): + shape = (10, 20) + chunks = [10, 20] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) - def chunk_iter(): - yield data + def chunk_iter(): + yield data - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype="float32", - scale_factor=10000.0, - ) - writer.close(var) + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(f.name) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, data, decimal=4) + np.testing.assert_array_almost_equal(result, data, decimal=4) - def test_streaming_multiple_chunks_2d(self): - shape = (10, 20) - chunks = [5, 10] - data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) +def test_streaming_multiple_chunks_2d(): + shape = (10, 20) + chunks = [5, 10] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - def chunk_iter(): - for i in range(0, 10, 5): - for j in range(0, 20, 10): - yield data[i : i + 5, j : j + 10].copy() + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype="float32", - scale_factor=10000.0, - ) - writer.close(var) + def chunk_iter(): + for i in range(0, 10, 5): + for j in range(0, 20, 10): + yield data[i : i + 5, j : j + 10].copy() - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer.close(var) - np.testing.assert_array_almost_equal(result, data, decimal=4) - - def test_streaming_all_dtypes(self): - shape = (6, 8) - chunks = [3, 4] - dtypes = [ - np.float32, - np.float64, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ] - - for dt in dtypes: - if np.issubdtype(dt, np.floating): - data = np.random.rand(*shape).astype(dt) - elif np.issubdtype(dt, np.signedinteger): - info = np.iinfo(dt) - data = np.random.randint(max(info.min, -1000), min(info.max, 1000), size=shape, dtype=dt) - else: - info = np.iinfo(dt) - data = np.random.randint(0, min(info.max, 1000), size=shape, dtype=dt) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - - def chunk_iter(d=data): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - yield d[i:ie, j:je].copy() - - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(dt).name, - scale_factor=10000.0, - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - assert result.dtype == dt, f"dtype mismatch for {dt}" - np.testing.assert_array_almost_equal(result, data, decimal=4) + reader = OmFileReader(f.name) + result = reader[:] + reader.close() - def test_streaming_3d_array(self): - shape = (4, 6, 8) - chunks = [2, 3, 4] - data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) + np.testing.assert_array_almost_equal(result, data, decimal=4) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - def chunk_iter(): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): - for k in range(0, shape[2], chunks[2]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - ke = min(k + chunks[2], shape[2]) - yield data[i:ie, j:je, k:ke].copy() +def test_streaming_all_dtypes(): + shape = (6, 8) + chunks = [3, 4] + dtypes = [ + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ] - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype="int32", - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_equal(result, data) - - def test_streaming_boundary_chunks(self): - shape = (7, 13) - chunks = [4, 5] - data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + for dt in dtypes: + if np.issubdtype(dt, np.floating): + data = np.random.rand(*shape).astype(dt) + elif np.issubdtype(dt, np.signedinteger): + info = np.iinfo(dt) + data = np.random.randint(max(info.min, -1000), min(info.max, 1000), size=shape, dtype=dt) + else: + info = np.iinfo(dt) + data = np.random.randint(0, min(info.max, 1000), size=shape, dtype=dt) with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) - def chunk_iter(): + def chunk_iter(d=data): for i in range(0, shape[0], chunks[0]): for j in range(0, shape[1], chunks[1]): ie = min(i + chunks[0], shape[0]) je = min(j + chunks[1], shape[1]) - yield data[i:ie, j:je].copy() + yield d[i:ie, j:je].copy() var = writer.write_array_streaming( dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="float32", + dtype=np.dtype(dt).name, scale_factor=10000.0, ) writer.close(var) @@ -172,260 +110,121 @@ def chunk_iter(): result = reader[:] reader.close() - np.testing.assert_array_almost_equal(result, data, decimal=4) + assert result.dtype == dt, f"dtype mismatch for {dt}" + if np.issubdtype(dt, np.floating): + np.testing.assert_array_almost_equal(result, data, decimal=4) + else: + np.testing.assert_array_equal(result, data) - def test_streaming_matches_write_array(self): - shape = (10, 20) - chunks = [5, 10] - data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f1: - writer1 = OmFileWriter(f1.name) - var1 = writer1.write_array(data, chunks=chunks, scale_factor=10000.0) - writer1.close(var1) - reader1 = OmFileReader(f1.name) - result1 = reader1[:] - reader1.close() +def test_streaming_3d_array(): + shape = (4, 6, 8) + chunks = [2, 3, 4] + data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f2: - writer2 = OmFileWriter(f2.name) + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) - def chunk_iter(): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + for k in range(0, shape[2], chunks[2]): ie = min(i + chunks[0], shape[0]) je = min(j + chunks[1], shape[1]) - yield data[i:ie, j:je].copy() - - var2 = writer2.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype="float32", - scale_factor=10000.0, - ) - writer2.close(var2) - reader2 = OmFileReader(f2.name) - result2 = reader2[:] - reader2.close() - - np.testing.assert_array_equal(result1, result2) - - def test_streaming_unsupported_dtype_raises(self): - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="Unsupported dtype"): - writer.write_array_streaming( - dimensions=[10], - chunks=[5], - chunk_iterator=iter([]), - dtype="complex128", - ) - - -class TestWriteDaskArray: - @pytest.fixture(autouse=True) - def _import_dask(self): - pytest.importorskip("dask.array") - from omfiles.dask import write_dask_array - - self.write_dask_array = write_dask_array - - @pytest.fixture - def dask_array_2d(self): - import dask.array as da - - np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - return da.from_array(np_data, chunks=(5, 10)) - - @pytest.fixture - def dask_array_3d(self): - import dask.array as da - - np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) - return da.from_array(np_data, chunks=(2, 3, 4)) - - def test_dask_roundtrip_2d(self, dask_array_2d): - expected = dask_array_2d.compute() - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array( - writer, - dask_array_2d, - scale_factor=10000.0, - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_almost_equal(result, expected, decimal=4) - - def test_dask_roundtrip_3d(self, dask_array_3d): - expected = dask_array_3d.compute() - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array(writer, dask_array_3d) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_equal(result, expected) - - def test_dask_boundary_chunks(self): - import dask.array as da - - np_data = np.arange(91, dtype=np.float32).reshape(7, 13) - darr = da.from_array(np_data, chunks=(4, 5)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array(writer, darr, scale_factor=10000.0) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_almost_equal(result, np_data, decimal=4) - - def test_dask_custom_name(self, dask_array_2d): - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array( - writer, - dask_array_2d, - scale_factor=10000.0, - name="temperature", - ) - assert var.name == "temperature" - writer.close(var) - - def test_dask_non_multiple_chunks_raises(self): - """Dask chunks that aren't multiples of OM chunks should raise.""" - import dask.array as da - - np_data = np.arange(30, dtype=np.float32).reshape(6, 5) - darr = da.from_array(np_data, chunks=(3, 5)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="not a multiple"): - self.write_dask_array(writer, darr, chunks=[2, 5]) - - def test_dask_larger_chunks_than_om_2d(self): - """Dask blocks spanning multiple OM chunks along dim 1 (full trailing dim).""" - import dask.array as da - - np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(10, 20)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array( - writer, - darr, - chunks=[5, 10], - scale_factor=10000.0, - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_almost_equal(result, np_data, decimal=4) - - def test_dask_larger_chunks_than_om_3d(self): - """Dask blocks with full trailing dims, multiple OM chunks in dim 0.""" - import dask.array as da - - np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) - darr = da.from_array(np_data, chunks=(4, 6, 8)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array(writer, darr, chunks=[2, 3, 4]) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_equal(result, np_data) - - def test_dask_single_om_chunk_per_slow_dim(self): - """Dask blocks with 1 OM chunk in dim 0, partial trailing dim coverage.""" - import dask.array as da - - np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(5, 10)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array( - writer, - darr, - chunks=[5, 5], - scale_factor=10000.0, - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - np.testing.assert_array_almost_equal(result, np_data, decimal=4) - - def test_dask_misaligned_trailing_dims_raises(self): - """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" - import dask.array as da - - np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(10, 10)) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="not fully covered"): - self.write_dask_array(writer, darr, chunks=[5, 5]) - - def test_dask_not_a_dask_array_raises(self): - np_data = np.arange(20, dtype=np.float32).reshape(4, 5) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(TypeError, match="Expected a dask array"): - self.write_dask_array(writer, np_data) + ke = min(k + chunks[2], shape[2]) + yield data[i:ie, j:je, k:ke].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="int32", + ) + writer.close(var) - def test_dask_streaming_memory_stays_bounded(self): - """Peak memory during a dask streaming write stays well below the full dataset size.""" - # ~16 MB dataset (2048 x 2048 x float32), written in 256x256 chunks (~256 KB each) - side = 2048 - chunk = 256 - dtype = np.float32 - dataset_bytes = side * side * np.dtype(dtype).itemsize + reader = OmFileReader(f.name) + result = reader[:] + reader.close() - import dask.array as da + np.testing.assert_array_equal(result, data) - darr = da.random.random((side, side), chunks=(chunk, chunk)).astype(dtype) - tracemalloc.start() +def test_streaming_boundary_chunks(): + shape = (7, 13) + chunks = [4, 5] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = self.write_dask_array(writer, darr) - writer.close(var) + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) - _, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() - # Peak Python memory should be a fraction of the total dataset size, - # proving that chunks are streamed rather than fully materialized. - assert peak < dataset_bytes, ( - f"Peak traced memory ({peak / 1024 / 1024:.1f} MB) should be less than " - f"the dataset size ({dataset_bytes / 1024 / 1024:.1f} MB)" + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, ) + writer.close(var) + + reader = OmFileReader(f.name) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, data, decimal=4) + + +def test_streaming_matches_write_array(): + shape = (10, 20) + chunks = [5, 10] + data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tempfile.NamedTemporaryFile(suffix=".om") as f1: + writer1 = OmFileWriter(f1.name) + var1 = writer1.write_array(data, chunks=chunks, scale_factor=10000.0) + writer1.close(var1) + reader1 = OmFileReader(f1.name) + result1 = reader1[:] + reader1.close() + + with tempfile.NamedTemporaryFile(suffix=".om") as f2: + writer2 = OmFileWriter(f2.name) + + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() + + var2 = writer2.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype="float32", + scale_factor=10000.0, + ) + writer2.close(var2) + reader2 = OmFileReader(f2.name) + result2 = reader2[:] + reader2.close() + + np.testing.assert_array_equal(result1, result2) + + +def test_streaming_unsupported_dtype_raises(): + with tempfile.NamedTemporaryFile(suffix=".om") as f: + writer = OmFileWriter(f.name) + with pytest.raises(ValueError, match="Unsupported dtype"): + writer.write_array_streaming( + dimensions=[10], + chunks=[5], + chunk_iterator=iter([]), + dtype="complex128", + ) From c65f173d68fc7cce5d64c46abdfa13942933f815 Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 12:08:16 +0100 Subject: [PATCH 16/27] fix pyright errors --- tests/test_chunk_reader.py | 2 +- tests/test_dask.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_chunk_reader.py b/tests/test_chunk_reader.py index 093558e..ad779f0 100644 --- a/tests/test_chunk_reader.py +++ b/tests/test_chunk_reader.py @@ -92,7 +92,7 @@ def test_load_data_success(chunk_reader: OmChunkFileReader, icond2_om_chunks_met for chunk_idx in chunk_reader.chunk_indices: chunk_times = icond2_om_chunks_meta.get_chunk_time_range(chunk_idx) time_mask = (chunk_times >= chunk_reader.start_date) & (chunk_times <= chunk_reader.end_date) - num_points = np.sum(time_mask) + num_points = int(np.sum(time_mask)) # Create mock data with the correct length expected_data.append(np.arange(num_points, dtype=np.float32)) diff --git a/tests/test_dask.py b/tests/test_dask.py index 82b6ca4..8578ddc 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -11,13 +11,13 @@ @pytest.fixture def dask_array_2d(): np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - return da.from_array(np_data, chunks=(5, 10)) + return da.from_array(np_data, chunks=(5, 10)) # type: ignore[arg-type] @pytest.fixture def dask_array_3d(): np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) - return da.from_array(np_data, chunks=(2, 3, 4)) + return da.from_array(np_data, chunks=(2, 3, 4)) # type: ignore[arg-type] def test_dask_roundtrip_2d(dask_array_2d): @@ -52,7 +52,7 @@ def test_dask_roundtrip_3d(dask_array_3d): def test_dask_boundary_chunks(): np_data = np.arange(91, dtype=np.float32).reshape(7, 13) - darr = da.from_array(np_data, chunks=(4, 5)) + darr = da.from_array(np_data, chunks=(4, 5)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -77,7 +77,7 @@ def test_dask_custom_name(dask_array_2d): def test_dask_non_multiple_chunks_raises(): """Dask chunks that aren't multiples of OM chunks should raise.""" np_data = np.arange(30, dtype=np.float32).reshape(6, 5) - darr = da.from_array(np_data, chunks=(3, 5)) + darr = da.from_array(np_data, chunks=(3, 5)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -88,7 +88,7 @@ def test_dask_non_multiple_chunks_raises(): def test_dask_larger_chunks_than_om_2d(): """Dask blocks spanning multiple OM chunks along dim 1 (full trailing dim).""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(10, 20)) + darr = da.from_array(np_data, chunks=(10, 20)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -105,7 +105,7 @@ def test_dask_larger_chunks_than_om_2d(): def test_dask_larger_chunks_than_om_3d(): """Dask blocks with full trailing dims, multiple OM chunks in dim 0.""" np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) - darr = da.from_array(np_data, chunks=(4, 6, 8)) + darr = da.from_array(np_data, chunks=(4, 6, 8)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -122,7 +122,7 @@ def test_dask_larger_chunks_than_om_3d(): def test_dask_single_om_chunk_per_slow_dim(): """Dask blocks with 1 OM chunk in dim 0, partial trailing dim coverage.""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(5, 10)) + darr = da.from_array(np_data, chunks=(5, 10)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -139,7 +139,7 @@ def test_dask_single_om_chunk_per_slow_dim(): def test_dask_misaligned_trailing_dims_raises(): """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) - darr = da.from_array(np_data, chunks=(10, 10)) + darr = da.from_array(np_data, chunks=(10, 10)) # type: ignore[arg-type] with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) @@ -152,7 +152,7 @@ def test_dask_not_a_dask_array_raises(): with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) with pytest.raises(TypeError, match="Expected a dask array"): - write_dask_array(writer, np_data) + write_dask_array(writer, np_data) # type: ignore[arg-type] def test_dask_streaming_memory_stays_bounded(): From 9124d6d0eab27a454c6953ab1c981eebf4ec53dd Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 12:38:36 +0100 Subject: [PATCH 17/27] use numpy dtype instead of string --- python/omfiles/_rust/__init__.pyi | 4 ++-- python/omfiles/dask.py | 2 +- src/writer.rs | 30 +++++++----------------------- tests/test_streaming_write.py | 16 ++++++++-------- 4 files changed, 18 insertions(+), 34 deletions(-) diff --git a/python/omfiles/_rust/__init__.pyi b/python/omfiles/_rust/__init__.pyi index a6b6304..03279d9 100644 --- a/python/omfiles/_rust/__init__.pyi +++ b/python/omfiles/_rust/__init__.pyi @@ -578,7 +578,7 @@ class OmFileWriter: dimensions: typing.Sequence[builtins.int], chunks: typing.Sequence[builtins.int], chunk_iterator: typing.Any, - dtype: builtins.str, + dtype: numpy.dtype, scale_factor: typing.Optional[builtins.float] = None, add_offset: typing.Optional[builtins.float] = None, compression: typing.Optional[builtins.str] = None, @@ -599,7 +599,7 @@ class OmFileWriter: dimensions: Shape of the full array (e.g., [1000, 2000]) chunks: Chunk sizes for each dimension (e.g., [100, 200]) chunk_iterator: Python iterable yielding numpy arrays, one per chunk region - dtype: String name of the numpy dtype (e.g., "float32", "int64") + dtype: Numpy dtype of the array. scale_factor: Scale factor for data compression (default: 1.0) add_offset: Offset value for data compression (default: 0.0) compression: Compression algorithm to use (default: "pfor_delta_2d") diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index b054309..06ab286 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -134,7 +134,7 @@ def write_dask_array( dimensions=[int(d) for d in data.shape], chunks=[int(c) for c in chunks], chunk_iterator=_dask_block_iterator(data), - dtype=data.dtype.name, + dtype=data.dtype, scale_factor=scale_factor, add_offset=add_offset, compression=compression, diff --git a/src/writer.rs b/src/writer.rs index c5ad3c2..4504748 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -47,7 +47,8 @@ enum OmElementType { impl OmElementType { /// Resolve from a numpy `PyArrayDescr` (used by `write_array`). - fn from_numpy_dtype(py: Python<'_>, d: &Bound<'_, PyArrayDescr>) -> PyResult { + fn from_numpy_dtype(d: &Bound<'_, PyArrayDescr>) -> PyResult { + let py = d.py(); if d.is_equiv_to(&dtype::(py)) { Ok(Self::Float32) } else if d.is_equiv_to(&dtype::(py)) { @@ -72,23 +73,6 @@ impl OmElementType { Err(OmFileWriter::unsupported_array_type_error(d.clone())) } } - - /// Resolve from a dtype string like `"float32"` (used by `write_array_streaming`). - fn from_str(s: &str) -> PyResult { - match s { - "float32" => Ok(Self::Float32), - "float64" => Ok(Self::Float64), - "int8" => Ok(Self::Int8), - "uint8" => Ok(Self::Uint8), - "int16" => Ok(Self::Int16), - "uint16" => Ok(Self::Uint16), - "int32" => Ok(Self::Int32), - "uint32" => Ok(Self::Uint32), - "int64" => Ok(Self::Int64), - "uint64" => Ok(Self::Uint64), - _ => Err(PyValueError::new_err(format!("Unsupported dtype: {}", s))), - } - } } /// Abstracts over the two write strategies (full array vs streaming iterator). @@ -396,7 +380,7 @@ impl OmFileWriter { ) -> PyResult { let params = WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; - let element_type = OmElementType::from_numpy_dtype(data.py(), &data.dtype())?; + let element_type = OmElementType::from_numpy_dtype(&data.dtype())?; let dimensions = data.shape().iter().map(|x| *x as u64).collect(); let feeder = Feeder::Full { data }; @@ -416,7 +400,7 @@ impl OmFileWriter { /// dimensions: Shape of the full array (e.g., [1000, 2000]) /// chunks: Chunk sizes for each dimension (e.g., [100, 200]) /// chunk_iterator: Python iterable yielding numpy arrays, one per chunk region - /// dtype: String name of the numpy dtype (e.g., "float32", "int64") + /// dtype: Numpy dtype of the array. /// scale_factor: Scale factor for data compression (default: 1.0) /// add_offset: Offset value for data compression (default: 0.0) /// compression: Compression algorithm to use (default: "pfor_delta_2d") @@ -433,13 +417,13 @@ impl OmFileWriter { text_signature = "(dimensions, chunks, chunk_iterator, dtype, scale_factor=1.0, add_offset=0.0, compression='pfor_delta_2d', name='data', children=[])", signature = (dimensions, chunks, chunk_iterator, dtype, scale_factor=None, add_offset=None, compression=None, name=None, children=None) )] - fn write_array_streaming( + fn write_array_streaming<'py>( &mut self, py: Python<'_>, dimensions: Vec, chunks: Vec, chunk_iterator: &Bound<'_, PyAny>, - dtype: &str, + dtype: &Bound<'py, PyArrayDescr>, scale_factor: Option, add_offset: Option, compression: Option<&str>, @@ -448,7 +432,7 @@ impl OmFileWriter { ) -> PyResult { let params = WriteArrayParams::from_options(name, children, scale_factor, add_offset, compression)?; - let element_type = OmElementType::from_str(dtype)?; + let element_type = OmElementType::from_numpy_dtype(dtype)?; let iter = chunk_iterator.call_method0("__iter__")?; let feeder = Feeder::Streaming { py, iter }; diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index 03a0f81..9185e29 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -20,7 +20,7 @@ def chunk_iter(): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="float32", + dtype=np.dtype(np.float32), scale_factor=10000.0, ) writer.close(var) @@ -49,7 +49,7 @@ def chunk_iter(): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="float32", + dtype=np.dtype(np.float32), scale_factor=10000.0, ) writer.close(var) @@ -101,7 +101,7 @@ def chunk_iter(d=data): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype=np.dtype(dt).name, + dtype=np.dtype(dt), scale_factor=10000.0, ) writer.close(var) @@ -138,7 +138,7 @@ def chunk_iter(): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="int32", + dtype=np.dtype(np.int32), ) writer.close(var) @@ -168,7 +168,7 @@ def chunk_iter(): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="float32", + dtype=np.dtype(np.float32), scale_factor=10000.0, ) writer.close(var) @@ -207,7 +207,7 @@ def chunk_iter(): dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype="float32", + dtype=np.dtype(np.float32), scale_factor=10000.0, ) writer2.close(var2) @@ -221,10 +221,10 @@ def chunk_iter(): def test_streaming_unsupported_dtype_raises(): with tempfile.NamedTemporaryFile(suffix=".om") as f: writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="Unsupported dtype"): + with pytest.raises(ValueError, match="Unsupported array data type"): writer.write_array_streaming( dimensions=[10], chunks=[5], chunk_iterator=iter([]), - dtype="complex128", + dtype=np.dtype(np.complex128), ) From ae99fb599e05936efbc5c7ac83c9df3f939c187c Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 12:52:40 +0100 Subject: [PATCH 18/27] use same pattern for temporary file as in other tests --- tests/test_dask.py | 150 ++++++++++----------- tests/test_streaming_write.py | 237 +++++++++++++++++----------------- 2 files changed, 184 insertions(+), 203 deletions(-) diff --git a/tests/test_dask.py b/tests/test_dask.py index 8578ddc..aeb5b75 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,4 +1,3 @@ -import tempfile import tracemalloc import dask.array as da @@ -20,142 +19,132 @@ def dask_array_3d(): return da.from_array(np_data, chunks=(2, 3, 4)) # type: ignore[arg-type] -def test_dask_roundtrip_2d(dask_array_2d): +def test_dask_roundtrip_2d(empty_temp_om_file, dask_array_2d): expected = dask_array_2d.compute() - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, expected, decimal=4) + np.testing.assert_array_almost_equal(result, expected, decimal=4) -def test_dask_roundtrip_3d(dask_array_3d): +def test_dask_roundtrip_3d(empty_temp_om_file, dask_array_3d): expected = dask_array_3d.compute() - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, dask_array_3d) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, dask_array_3d) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_equal(result, expected) + np.testing.assert_array_equal(result, expected) -def test_dask_boundary_chunks(): +def test_dask_boundary_chunks(empty_temp_om_file): np_data = np.arange(91, dtype=np.float32).reshape(7, 13) darr = da.from_array(np_data, chunks=(4, 5)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, darr, scale_factor=10000.0) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr, scale_factor=10000.0) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, np_data, decimal=4) + np.testing.assert_array_almost_equal(result, np_data, decimal=4) -def test_dask_custom_name(dask_array_2d): - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0, name="temperature") - assert var.name == "temperature" - writer.close(var) +def test_dask_custom_name(empty_temp_om_file, dask_array_2d): + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, dask_array_2d, scale_factor=10000.0, name="temperature") + assert var.name == "temperature" + writer.close(var) -def test_dask_non_multiple_chunks_raises(): +def test_dask_non_multiple_chunks_raises(empty_temp_om_file): """Dask chunks that aren't multiples of OM chunks should raise.""" np_data = np.arange(30, dtype=np.float32).reshape(6, 5) darr = da.from_array(np_data, chunks=(3, 5)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="not a multiple"): - write_dask_array(writer, darr, chunks=[2, 5]) + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(ValueError, match="not a multiple"): + write_dask_array(writer, darr, chunks=[2, 5]) -def test_dask_larger_chunks_than_om_2d(): +def test_dask_larger_chunks_than_om_2d(empty_temp_om_file): """Dask blocks spanning multiple OM chunks along dim 1 (full trailing dim).""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) darr = da.from_array(np_data, chunks=(10, 20)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, darr, chunks=[5, 10], scale_factor=10000.0) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr, chunks=[5, 10], scale_factor=10000.0) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, np_data, decimal=4) + np.testing.assert_array_almost_equal(result, np_data, decimal=4) -def test_dask_larger_chunks_than_om_3d(): +def test_dask_larger_chunks_than_om_3d(empty_temp_om_file): """Dask blocks with full trailing dims, multiple OM chunks in dim 0.""" np_data = np.arange(192, dtype=np.int32).reshape(4, 6, 8) darr = da.from_array(np_data, chunks=(4, 6, 8)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, darr, chunks=[2, 3, 4]) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr, chunks=[2, 3, 4]) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_equal(result, np_data) + np.testing.assert_array_equal(result, np_data) -def test_dask_single_om_chunk_per_slow_dim(): +def test_dask_single_om_chunk_per_slow_dim(empty_temp_om_file): """Dask blocks with 1 OM chunk in dim 0, partial trailing dim coverage.""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) darr = da.from_array(np_data, chunks=(5, 10)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, darr, chunks=[5, 5], scale_factor=10000.0) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr, chunks=[5, 5], scale_factor=10000.0) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, np_data, decimal=4) + np.testing.assert_array_almost_equal(result, np_data, decimal=4) -def test_dask_misaligned_trailing_dims_raises(): +def test_dask_misaligned_trailing_dims_raises(empty_temp_om_file): """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) darr = da.from_array(np_data, chunks=(10, 10)) # type: ignore[arg-type] - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="not fully covered"): - write_dask_array(writer, darr, chunks=[5, 5]) + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(ValueError, match="not fully covered"): + write_dask_array(writer, darr, chunks=[5, 5]) -def test_dask_not_a_dask_array_raises(): +def test_dask_not_a_dask_array_raises(empty_temp_om_file): np_data = np.arange(20, dtype=np.float32).reshape(4, 5) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(TypeError, match="Expected a dask array"): - write_dask_array(writer, np_data) # type: ignore[arg-type] + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(TypeError, match="Expected a dask array"): + write_dask_array(writer, np_data) # type: ignore[arg-type] -def test_dask_streaming_memory_stays_bounded(): +def test_dask_streaming_memory_stays_bounded(empty_temp_om_file): """Peak memory during a dask streaming write stays well below the full dataset size.""" # ~16 MB dataset (2048 x 2048 x float32), written in 256x256 chunks (~256 KB each) side = 2048 @@ -167,10 +156,9 @@ def test_dask_streaming_memory_stays_bounded(): tracemalloc.start() - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - var = write_dask_array(writer, darr) - writer.close(var) + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr) + writer.close(var) _, peak = tracemalloc.get_traced_memory() tracemalloc.stop() diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index 9185e29..48e5cd5 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -5,63 +5,61 @@ from omfiles import OmFileReader, OmFileWriter -def test_streaming_single_chunk(): +def test_streaming_single_chunk(empty_temp_om_file): shape = (10, 20) chunks = [10, 20] data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) + writer = OmFileWriter(empty_temp_om_file) - def chunk_iter(): - yield data + def chunk_iter(): + yield data - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(np.float32), - scale_factor=10000.0, - ) - writer.close(var) + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(np.float32), + scale_factor=10000.0, + ) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, data, decimal=4) + np.testing.assert_array_almost_equal(result, data, decimal=4) -def test_streaming_multiple_chunks_2d(): +def test_streaming_multiple_chunks_2d(empty_temp_om_file): shape = (10, 20) chunks = [5, 10] data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) + writer = OmFileWriter(empty_temp_om_file) - def chunk_iter(): - for i in range(0, 10, 5): - for j in range(0, 20, 10): - yield data[i : i + 5, j : j + 10].copy() + def chunk_iter(): + for i in range(0, 10, 5): + for j in range(0, 20, 10): + yield data[i : i + 5, j : j + 10].copy() - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(np.float32), - scale_factor=10000.0, - ) - writer.close(var) + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(np.float32), + scale_factor=10000.0, + ) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, data, decimal=4) + np.testing.assert_array_almost_equal(result, data, decimal=4) -def test_streaming_all_dtypes(): +def test_streaming_all_dtypes(empty_temp_om_file): shape = (6, 8) chunks = [3, 4] dtypes = [ @@ -87,111 +85,107 @@ def test_streaming_all_dtypes(): info = np.iinfo(dt) data = np.random.randint(0, min(info.max, 1000), size=shape, dtype=dt) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - - def chunk_iter(d=data): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - yield d[i:ie, j:je].copy() - - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(dt), - scale_factor=10000.0, - ) - writer.close(var) - - reader = OmFileReader(f.name) - result = reader[:] - reader.close() - - assert result.dtype == dt, f"dtype mismatch for {dt}" - if np.issubdtype(dt, np.floating): - np.testing.assert_array_almost_equal(result, data, decimal=4) - else: - np.testing.assert_array_equal(result, data) - - -def test_streaming_3d_array(): - shape = (4, 6, 8) - chunks = [2, 3, 4] - data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) - - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) + writer = OmFileWriter(empty_temp_om_file) - def chunk_iter(): + def chunk_iter(d=data): for i in range(0, shape[0], chunks[0]): for j in range(0, shape[1], chunks[1]): - for k in range(0, shape[2], chunks[2]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - ke = min(k + chunks[2], shape[2]) - yield data[i:ie, j:je, k:ke].copy() + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield d[i:ie, j:je].copy() var = writer.write_array_streaming( dimensions=list(shape), chunks=chunks, chunk_iterator=chunk_iter(), - dtype=np.dtype(np.int32), + dtype=np.dtype(dt), + scale_factor=10000.0, ) writer.close(var) - reader = OmFileReader(f.name) + reader = OmFileReader(empty_temp_om_file) result = reader[:] reader.close() - np.testing.assert_array_equal(result, data) + assert result.dtype == dt, f"dtype mismatch for {dt}" + if np.issubdtype(dt, np.floating): + np.testing.assert_array_almost_equal(result, data, decimal=4) + else: + np.testing.assert_array_equal(result, data) + + +def test_streaming_3d_array(empty_temp_om_file): + shape = (4, 6, 8) + chunks = [2, 3, 4] + data = np.arange(np.prod(shape), dtype=np.int32).reshape(shape) + writer = OmFileWriter(empty_temp_om_file) -def test_streaming_boundary_chunks(): + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + for k in range(0, shape[2], chunks[2]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + ke = min(k + chunks[2], shape[2]) + yield data[i:ie, j:je, k:ke].copy() + + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(np.int32), + ) + writer.close(var) + + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() + + np.testing.assert_array_equal(result, data) + + +def test_streaming_boundary_chunks(empty_temp_om_file): shape = (7, 13) chunks = [4, 5] data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) + writer = OmFileWriter(empty_temp_om_file) - def chunk_iter(): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - yield data[i:ie, j:je].copy() + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() - var = writer.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(np.float32), - scale_factor=10000.0, - ) - writer.close(var) + var = writer.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(np.float32), + scale_factor=10000.0, + ) + writer.close(var) - reader = OmFileReader(f.name) - result = reader[:] - reader.close() + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() - np.testing.assert_array_almost_equal(result, data, decimal=4) + np.testing.assert_array_almost_equal(result, data, decimal=4) -def test_streaming_matches_write_array(): +def test_streaming_matches_write_array(empty_temp_om_file): shape = (10, 20) chunks = [5, 10] data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with tempfile.NamedTemporaryFile(suffix=".om") as f1: - writer1 = OmFileWriter(f1.name) - var1 = writer1.write_array(data, chunks=chunks, scale_factor=10000.0) - writer1.close(var1) - reader1 = OmFileReader(f1.name) - result1 = reader1[:] - reader1.close() + writer1 = OmFileWriter(empty_temp_om_file) + var1 = writer1.write_array(data, chunks=chunks, scale_factor=10000.0) + writer1.close(var1) + reader1 = OmFileReader(empty_temp_om_file) + result1 = reader1[:] + reader1.close() with tempfile.NamedTemporaryFile(suffix=".om") as f2: writer2 = OmFileWriter(f2.name) @@ -218,13 +212,12 @@ def chunk_iter(): np.testing.assert_array_equal(result1, result2) -def test_streaming_unsupported_dtype_raises(): - with tempfile.NamedTemporaryFile(suffix=".om") as f: - writer = OmFileWriter(f.name) - with pytest.raises(ValueError, match="Unsupported array data type"): - writer.write_array_streaming( - dimensions=[10], - chunks=[5], - chunk_iterator=iter([]), - dtype=np.dtype(np.complex128), - ) +def test_streaming_unsupported_dtype_raises(empty_temp_om_file): + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(ValueError, match="Unsupported array data type"): + writer.write_array_streaming( + dimensions=[10], + chunks=[5], + chunk_iterator=iter([]), + dtype=np.dtype(np.complex128), + ) From 5a5c1c8aed016041923bce5de8e0739ea9a033a8 Mon Sep 17 00:00:00 2001 From: terraputix Date: Tue, 24 Mar 2026 13:04:40 +0100 Subject: [PATCH 19/27] raise top level import error if dask not available --- python/omfiles/dask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index 06ab286..c8f404b 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -4,12 +4,14 @@ import itertools import math -from typing import TYPE_CHECKING, Optional, Sequence +from typing import Optional, Sequence from omfiles._rust import OmFileWriter, OmVariable -if TYPE_CHECKING: +try: import dask.array as da +except ImportError: + raise ImportError("omfiles[dask] is required for dask functionality") def _validate_chunk_alignment( @@ -120,8 +122,6 @@ def write_dask_array( ValueError: If dask chunks are incompatible with OM chunks. ImportError: If dask is not installed. """ - import dask.array as da - if not isinstance(data, da.Array): raise TypeError(f"Expected a dask array, got {type(data)}") From 0a37954e290c79ff1c89156134b10166049bb87f Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 07:46:20 +0100 Subject: [PATCH 20/27] use fixture for second empty file as well --- tests/conftest.py | 3 +++ tests/test_streaming_write.py | 41 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fc984eb..6a614c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,6 +97,9 @@ def empty_temp_om_file(): warnings.warn(f"Failed to remove temporary file {filename}: {e}") +empty_temp_om_file_2 = empty_temp_om_file + + @pytest.fixture def icon_d2_meta_json() -> str: # return meta_str diff --git a/tests/test_streaming_write.py b/tests/test_streaming_write.py index 48e5cd5..3cfbd96 100644 --- a/tests/test_streaming_write.py +++ b/tests/test_streaming_write.py @@ -1,5 +1,3 @@ -import tempfile - import numpy as np import pytest from omfiles import OmFileReader, OmFileWriter @@ -175,7 +173,7 @@ def chunk_iter(): np.testing.assert_array_almost_equal(result, data, decimal=4) -def test_streaming_matches_write_array(empty_temp_om_file): +def test_streaming_matches_write_array(empty_temp_om_file, empty_temp_om_file_2): shape = (10, 20) chunks = [5, 10] data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -187,27 +185,26 @@ def test_streaming_matches_write_array(empty_temp_om_file): result1 = reader1[:] reader1.close() - with tempfile.NamedTemporaryFile(suffix=".om") as f2: - writer2 = OmFileWriter(f2.name) + writer2 = OmFileWriter(empty_temp_om_file_2) - def chunk_iter(): - for i in range(0, shape[0], chunks[0]): - for j in range(0, shape[1], chunks[1]): - ie = min(i + chunks[0], shape[0]) - je = min(j + chunks[1], shape[1]) - yield data[i:ie, j:je].copy() + def chunk_iter(): + for i in range(0, shape[0], chunks[0]): + for j in range(0, shape[1], chunks[1]): + ie = min(i + chunks[0], shape[0]) + je = min(j + chunks[1], shape[1]) + yield data[i:ie, j:je].copy() - var2 = writer2.write_array_streaming( - dimensions=list(shape), - chunks=chunks, - chunk_iterator=chunk_iter(), - dtype=np.dtype(np.float32), - scale_factor=10000.0, - ) - writer2.close(var2) - reader2 = OmFileReader(f2.name) - result2 = reader2[:] - reader2.close() + var2 = writer2.write_array_streaming( + dimensions=list(shape), + chunks=chunks, + chunk_iterator=chunk_iter(), + dtype=np.dtype(np.float32), + scale_factor=10000.0, + ) + writer2.close(var2) + reader2 = OmFileReader(empty_temp_om_file_2) + result2 = reader2[:] + reader2.close() np.testing.assert_array_equal(result1, result2) From a857360c60bb001fff127b7cef06d52ca347dd15 Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 07:49:05 +0100 Subject: [PATCH 21/27] delete memory usage test --- tests/test_dask.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/tests/test_dask.py b/tests/test_dask.py index aeb5b75..42b0d1f 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -142,30 +142,3 @@ def test_dask_not_a_dask_array_raises(empty_temp_om_file): writer = OmFileWriter(empty_temp_om_file) with pytest.raises(TypeError, match="Expected a dask array"): write_dask_array(writer, np_data) # type: ignore[arg-type] - - -def test_dask_streaming_memory_stays_bounded(empty_temp_om_file): - """Peak memory during a dask streaming write stays well below the full dataset size.""" - # ~16 MB dataset (2048 x 2048 x float32), written in 256x256 chunks (~256 KB each) - side = 2048 - chunk = 256 - dtype = np.float32 - dataset_bytes = side * side * np.dtype(dtype).itemsize - - darr = da.random.random((side, side), chunks=(chunk, chunk)).astype(dtype) - - tracemalloc.start() - - writer = OmFileWriter(empty_temp_om_file) - var = write_dask_array(writer, darr) - writer.close(var) - - _, peak = tracemalloc.get_traced_memory() - tracemalloc.stop() - - # Peak Python memory should be a fraction of the total dataset size, - # proving that chunks are streamed rather than fully materialized. - assert peak < dataset_bytes, ( - f"Peak traced memory ({peak / 1024 / 1024:.1f} MB) should be less than " - f"the dataset size ({dataset_bytes / 1024 / 1024:.1f} MB)" - ) From f9ccd5e7434050f1acb5863aa43f5fd1d847a185 Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 07:54:33 +0100 Subject: [PATCH 22/27] consistency in doc comments --- python/omfiles/_rust/__init__.pyi | 2 +- src/writer.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/omfiles/_rust/__init__.pyi b/python/omfiles/_rust/__init__.pyi index 03279d9..32c5ef8 100644 --- a/python/omfiles/_rust/__init__.pyi +++ b/python/omfiles/_rust/__init__.pyi @@ -599,7 +599,7 @@ class OmFileWriter: dimensions: Shape of the full array (e.g., [1000, 2000]) chunks: Chunk sizes for each dimension (e.g., [100, 200]) chunk_iterator: Python iterable yielding numpy arrays, one per chunk region - dtype: Numpy dtype of the array. + dtype: Numpy dtype of the array (e.g., np.dtype(np.float32)) scale_factor: Scale factor for data compression (default: 1.0) add_offset: Offset value for data compression (default: 0.0) compression: Compression algorithm to use (default: "pfor_delta_2d") diff --git a/src/writer.rs b/src/writer.rs index 4504748..64e8de7 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -400,7 +400,7 @@ impl OmFileWriter { /// dimensions: Shape of the full array (e.g., [1000, 2000]) /// chunks: Chunk sizes for each dimension (e.g., [100, 200]) /// chunk_iterator: Python iterable yielding numpy arrays, one per chunk region - /// dtype: Numpy dtype of the array. + /// dtype: Numpy dtype of the array (e.g., np.dtype(np.float32)) /// scale_factor: Scale factor for data compression (default: 1.0) /// add_offset: Offset value for data compression (default: 0.0) /// compression: Compression algorithm to use (default: "pfor_delta_2d") From 7b0937732f7276ee472be6a3e5e5b4dcd7990ee6 Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 08:04:44 +0100 Subject: [PATCH 23/27] cleanup --- python/omfiles/dask.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index c8f404b..0319f95 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -1,11 +1,12 @@ """Dask array integration for writing to OM files.""" -from __future__ import annotations - import itertools import math +from collections.abc import Generator from typing import Optional, Sequence +import numpy as np + from omfiles._rust import OmFileWriter, OmVariable try: @@ -60,7 +61,7 @@ def _validate_chunk_alignment( ) -def _dask_block_iterator(dask_array: da.Array): +def _dask_block_iterator(dask_array: da.Array) -> Generator[np.ndarray, None, None]: """ Yield computed numpy arrays from a dask array in C-order block traversal. @@ -120,7 +121,6 @@ def write_dask_array( Raises: TypeError: If data is not a dask array. ValueError: If dask chunks are incompatible with OM chunks. - ImportError: If dask is not installed. """ if not isinstance(data, da.Array): raise TypeError(f"Expected a dask array, got {type(data)}") @@ -139,5 +139,5 @@ def write_dask_array( add_offset=add_offset, compression=compression, name=name, - children=list(children) if children else [], + children=list(children) if children is not None else [], ) From d76ccb661154e5d81e1056993038df52838c44c7 Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 08:18:24 +0100 Subject: [PATCH 24/27] use ndindex over itertools --- python/omfiles/dask.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index 0319f95..231b613 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -1,9 +1,7 @@ """Dask array integration for writing to OM files.""" -import itertools import math -from collections.abc import Generator -from typing import Optional, Sequence +from typing import Iterator, Optional, Sequence import numpy as np @@ -61,17 +59,15 @@ def _validate_chunk_alignment( ) -def _dask_block_iterator(dask_array: da.Array) -> Generator[np.ndarray, None, None]: +def _dask_block_iterator(dask_array: da.Array) -> Iterator[np.ndarray]: """ Yield computed numpy arrays from a dask array in C-order block traversal. The OM file format requires chunks to be written in sequential order corresponding to a row-major (C-order) traversal of the chunk grid. - ``itertools.product`` naturally produces this ordering since the last - index varies fastest. + ndindex does this: the last dimension is iterated over first. """ - block_index_ranges = [range(n) for n in dask_array.numblocks] - for block_indices in itertools.product(*block_index_ranges): + for block_indices in np.ndindex(*dask_array.numblocks): yield dask_array.blocks[block_indices].compute() @@ -125,14 +121,12 @@ def write_dask_array( if not isinstance(data, da.Array): raise TypeError(f"Expected a dask array, got {type(data)}") - if chunks is None: - chunks = [c[0] for c in data.chunks] - - _validate_chunk_alignment(data.chunks, list(chunks), data.shape) + om_chunks = list(chunks) if chunks is not None else [int(c[0]) for c in data.chunks] + _validate_chunk_alignment(data.chunks, om_chunks, data.shape) return writer.write_array_streaming( dimensions=[int(d) for d in data.shape], - chunks=[int(c) for c in chunks], + chunks=om_chunks, chunk_iterator=_dask_block_iterator(data), dtype=data.dtype, scale_factor=scale_factor, From eee9f03c82cffda3000001b06d98f20ca1e9e74a Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 08:31:58 +0100 Subject: [PATCH 25/27] guard against mismatching chunk shapes --- python/omfiles/dask.py | 7 +++++-- tests/test_dask.py | 19 ++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index 231b613..b217c70 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -65,7 +65,7 @@ def _dask_block_iterator(dask_array: da.Array) -> Iterator[np.ndarray]: The OM file format requires chunks to be written in sequential order corresponding to a row-major (C-order) traversal of the chunk grid. - ndindex does this: the last dimension is iterated over first. + np.ndindex yields indices in C-order: the last axis index varies fastest. """ for block_indices in np.ndindex(*dask_array.numblocks): yield dask_array.blocks[block_indices].compute() @@ -121,7 +121,10 @@ def write_dask_array( if not isinstance(data, da.Array): raise TypeError(f"Expected a dask array, got {type(data)}") - om_chunks = list(chunks) if chunks is not None else [int(c[0]) for c in data.chunks] + if chunks is not None and len(chunks) != data.ndim: + raise ValueError(f"chunks has {len(chunks)} element(s) but data has {data.ndim} dimension(s).") + + om_chunks: list[int] = list(chunks) if chunks is not None else [c[0] for c in data.chunks] _validate_chunk_alignment(data.chunks, om_chunks, data.shape) return writer.write_array_streaming( diff --git a/tests/test_dask.py b/tests/test_dask.py index 42b0d1f..9632536 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,5 +1,3 @@ -import tracemalloc - import dask.array as da import numpy as np import pytest @@ -128,7 +126,6 @@ def test_dask_single_om_chunk_per_slow_dim(empty_temp_om_file): def test_dask_misaligned_trailing_dims_raises(empty_temp_om_file): - """Dask blocks with multi-chunk dim 0 but partial trailing dim raises.""" np_data = np.arange(200, dtype=np.float32).reshape(10, 20) darr = da.from_array(np_data, chunks=(10, 10)) # type: ignore[arg-type] @@ -142,3 +139,19 @@ def test_dask_not_a_dask_array_raises(empty_temp_om_file): writer = OmFileWriter(empty_temp_om_file) with pytest.raises(TypeError, match="Expected a dask array"): write_dask_array(writer, np_data) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + "bad_chunks", + [ + pytest.param([5], id="too_few"), + pytest.param([5, 10, 4], id="too_many"), + ], +) +def test_dask_chunk_ndim_mismatch_raises(empty_temp_om_file, bad_chunks): + np_data = np.arange(200, dtype=np.float32).reshape(10, 20) + darr = da.from_array(np_data, chunks=(5, 10)) # type: ignore[arg-type] + + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(ValueError, match=r"chunks has \d+ element"): + write_dask_array(writer, darr, chunks=bad_chunks) From 12f861b273a87eb11d46f89c3034295b6fd865cb Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 08:59:36 +0100 Subject: [PATCH 26/27] improve iter type hint --- python/omfiles/_rust/__init__.pyi | 2 +- src/writer.rs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/omfiles/_rust/__init__.pyi b/python/omfiles/_rust/__init__.pyi index 32c5ef8..17e60f1 100644 --- a/python/omfiles/_rust/__init__.pyi +++ b/python/omfiles/_rust/__init__.pyi @@ -577,7 +577,7 @@ class OmFileWriter: self, dimensions: typing.Sequence[builtins.int], chunks: typing.Sequence[builtins.int], - chunk_iterator: typing.Any, + chunk_iterator: typing.Iterator, dtype: numpy.dtype, scale_factor: typing.Optional[builtins.float] = None, add_offset: typing.Optional[builtins.float] = None, diff --git a/src/writer.rs b/src/writer.rs index 64e8de7..6893379 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -15,6 +15,7 @@ use omfiles_rs::{ use pyo3::{ exceptions::{PyRuntimeError, PyStopIteration, PyValueError}, prelude::*, + types::PyIterator, }; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use std::{ @@ -422,7 +423,8 @@ impl OmFileWriter { py: Python<'_>, dimensions: Vec, chunks: Vec, - chunk_iterator: &Bound<'_, PyAny>, + #[gen_stub(override_type(type_repr="typing.Iterator", imports=("typing")))] + chunk_iterator: &Bound<'_, PyIterator>, dtype: &Bound<'py, PyArrayDescr>, scale_factor: Option, add_offset: Option, From 05d58590ec03cfd71bd5cfea6ac49e432bea8324 Mon Sep 17 00:00:00 2001 From: terraputix Date: Wed, 25 Mar 2026 09:30:12 +0100 Subject: [PATCH 27/27] validate all dimensions --- python/omfiles/dask.py | 9 ++++----- tests/test_dask.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/python/omfiles/dask.py b/python/omfiles/dask.py index b217c70..ffcc303 100644 --- a/python/omfiles/dask.py +++ b/python/omfiles/dask.py @@ -41,20 +41,19 @@ def _validate_chunk_alignment( first_multi = None for d in range(ndim): - local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) + local_n = max(math.ceil(c / om_chunks[d]) for c in data_chunks[d]) if local_n > 1: first_multi = d break if first_multi is not None: for d in range(first_multi + 1, ndim): - local_n = math.ceil(data_chunks[d][0] / om_chunks[d]) - global_n = math.ceil(array_shape[d] / om_chunks[d]) - if local_n != global_n: + dim_chunks = data_chunks[d] + if not (len(dim_chunks) == 1 and dim_chunks[0] == array_shape[d]): raise ValueError( f"Dask blocks have multiple OM chunks in dimension {first_multi}, " f"but dimension {d} is not fully covered by each dask block " - f"(dask chunk {data_chunks[d][0]} vs array size {array_shape[d]}). " + f"(dask chunks {dim_chunks} vs array size {array_shape[d]}). " f"Rechunk so trailing dimensions are fully covered." ) diff --git a/tests/test_dask.py b/tests/test_dask.py index 9632536..ea882f3 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -155,3 +155,41 @@ def test_dask_chunk_ndim_mismatch_raises(empty_temp_om_file, bad_chunks): writer = OmFileWriter(empty_temp_om_file) with pytest.raises(ValueError, match=r"chunks has \d+ element"): write_dask_array(writer, darr, chunks=bad_chunks) + + +def test_dask_irregular_chunks_misaligned_raises(empty_temp_om_file): + """ + Non-first dask block spans multiple OM chunks while trailing dim is not + fully covered. + + Array (12, 16), dask chunks ((4, 8), (8, 8)), OM chunks [4, 8]: + block (1,0) shape (8, 8) → 2 OM rows but only 1 of 2 OM columns. + """ + np_data = np.arange(192, dtype=np.float32).reshape(12, 16) + darr = da.from_array(np_data, chunks=((4, 8), (8, 8))) # type: ignore[arg-type] + + writer = OmFileWriter(empty_temp_om_file) + with pytest.raises(ValueError, match="not fully covered"): + write_dask_array(writer, darr, chunks=[4, 8]) + + +def test_dask_irregular_chunks_valid_roundtrip(empty_temp_om_file): + """ + Non-first dask block spans multiple OM chunks but trailing dim IS fully + covered — this configuration is valid and must produce correct output. + + Array (12, 16), dask chunks ((4, 8), (16,)), OM chunks [4, 8]: + block (1,0) shape (8, 16) → 2 OM rows and all OM columns — safe. + """ + np_data = np.arange(192, dtype=np.float32).reshape(12, 16) + darr = da.from_array(np_data, chunks=((4, 8), (16,))) # type: ignore[arg-type] + + writer = OmFileWriter(empty_temp_om_file) + var = write_dask_array(writer, darr, chunks=[4, 8], scale_factor=10000.0) + writer.close(var) + + reader = OmFileReader(empty_temp_om_file) + result = reader[:] + reader.close() + + np.testing.assert_array_almost_equal(result, np_data, decimal=4)