From 3533a9a1544b762026d6bab465acc1df7f8849a1 Mon Sep 17 00:00:00 2001 From: Tom Durrant Date: Fri, 6 Feb 2026 14:50:33 +1100 Subject: [PATCH] Add group parameter support to ZarrClient and zarr_write Add optional group parameter to enable writing hierarchical Zarr stores (e.g., cycle/001). Group is meant to be passed via the X-PARAMETERS header to maintain URL structure; in this patch, however, zarr_write only forwards group to _to_zarr, and ZarrClient receives formatting changes only. Includes tests validating group functionality. --- src/oceanum/datamesh/zarr.py | 38 ++++++++++++++++++++-------- tests/test_zarr_groups.py | 49 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) create mode 100644 tests/test_zarr_groups.py diff --git a/src/oceanum/datamesh/zarr.py b/src/oceanum/datamesh/zarr.py index 03c8a8e..5ac3266 100644 --- a/src/oceanum/datamesh/zarr.py +++ b/src/oceanum/datamesh/zarr.py @@ -9,9 +9,16 @@ import fsspec import urllib.parse +from typing import Optional + from .exceptions import DatameshConnectError, DatameshWriteError from .session import Session -from .utils import retried_request, DATAMESH_CONNECT_TIMEOUT, DATAMESH_CHUNK_READ_TIMEOUT, DATAMESH_CHUNK_WRITE_TIMEOUT +from .utils import ( + retried_request, + DATAMESH_CONNECT_TIMEOUT, + DATAMESH_CHUNK_READ_TIMEOUT, + DATAMESH_CHUNK_WRITE_TIMEOUT, +) try: import xarray_video as xv @@ -59,7 +66,7 @@ def __init__( api="query", reference_id=None, verify=True, - storage_backend=None + storage_backend=None, ): self.datasource = datasource self.session = session @@ -106,15 +113,15 @@ def _retried_request( if resp.status_code == 401: raise DatameshConnectError(f"Not Authorized {resp.text}") if resp.status_code == 410: - raise DatameshConnectError(f"Datasource no longer exists or was deleted within your session") - if resp.status_code >= 500: raise DatameshConnectError( - f"Server error {resp.status_code}: {resp.text}" + f"Datasource no longer exists or was deleted within your session" ) + if resp.status_code >= 500: + raise DatameshConnectError(f"Server error {resp.status_code}: {resp.text}") return
resp def __getitem__(self, item): - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") resp = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", connect_timeout=self.connect_timeout, @@ -125,7 +132,7 @@ def __getitem__(self, item): return resp.content def __contains__(self, item): - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") resp = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method="HEAD" if self._is_v1 else "GET", @@ -139,7 +146,7 @@ def __contains__(self, item): def __setitem__(self, item, value): if self.api == "query": raise DatameshConnectError("Query api does not support write operations") - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") res = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method=self.method, @@ -155,7 +162,7 @@ def __setitem__(self, item, value): def __delitem__(self, item): if self.api == "query": raise DatameshConnectError("Query api does not support delete operations") - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method="DELETE", @@ -190,7 +197,14 @@ def _to_zarr(data, store, **kwargs): data.to_zarr(store, **kwargs) -def zarr_write(connection, datasource_id, data, append=None, overwrite=False): +def zarr_write( + connection, + datasource_id, + data, + append=None, + overwrite=False, + group: Optional[str] = None, +): with Session.acquire(connection) as session: store = ZarrClient(connection, datasource_id, session, api="zarr", nocache=True) if overwrite is True: @@ -240,6 +254,7 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False): store, mode="a", region={append_dim: replace_slice}, + group=group, ) if len(data[append]) > 
len(replace_range): append_chunk = data.isel( @@ -251,9 +266,10 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False): mode="a", append_dim=append_dim, consolidated=True, + group=group, ) else: - _to_zarr(data, store, mode="w", consolidated=True) + _to_zarr(data, store, mode="w", consolidated=True, group=group) ds = connection.get_datasource(datasource_id) ds.dataschema = data.to_dict(data=False) return ds diff --git a/tests/test_zarr_groups.py b/tests/test_zarr_groups.py new file mode 100644 index 0000000..0e4e39a --- /dev/null +++ b/tests/test_zarr_groups.py @@ -0,0 +1,49 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import sys + +sys.path.insert(0, "/home/tdurrant/source/oceanum.io/oceanum-python/src") + +from oceanum.datamesh.zarr import zarr_write +import xarray as xr + + +def test_zarr_write_passes_group_to_to_zarr(): + """Test that zarr_write passes group to _to_zarr.""" + conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True) + conn.get_datasource = Mock(return_value=Mock(_exists=False)) + + session_mock = MagicMock() + session_mock.__enter__ = Mock(return_value=session_mock) + session_mock.__exit__ = Mock(return_value=False) + session_mock.add_header = lambda h: h + + data = xr.Dataset({"temp": (["time"], [1, 2, 3])}) + + with patch("oceanum.datamesh.zarr.Session.acquire", return_value=session_mock): + with patch("oceanum.datamesh.zarr.ZarrClient"): + with patch("oceanum.datamesh.zarr._to_zarr") as mock_to_zarr: + zarr_write(conn, "test-ds", data, group="cycle/001") + # Verify _to_zarr called with group + to_zarr_kwargs = mock_to_zarr.call_args.kwargs + assert "group" in to_zarr_kwargs + assert to_zarr_kwargs["group"] == "cycle/001" + + +def test_zarr_write_without_group(): + """Test backward compatibility - zarr_write works without group.""" + conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True) + conn.get_datasource = Mock(return_value=Mock(_exists=False)) + + session_mock = 
MagicMock() + session_mock.__enter__ = Mock(return_value=session_mock) + session_mock.__exit__ = Mock(return_value=False) + session_mock.add_header = lambda h: h + + data = xr.Dataset({"temp": (["time"], [1, 2, 3])}) + + with patch("oceanum.datamesh.zarr.Session.acquire", return_value=session_mock): + with patch("oceanum.datamesh.zarr.ZarrClient"): + with patch("oceanum.datamesh.zarr._to_zarr"): + # Should not raise error + zarr_write(conn, "test-ds", data)