diff --git a/src/oceanum/datamesh/zarr.py b/src/oceanum/datamesh/zarr.py index 4674a8a..ecb58ec 100644 --- a/src/oceanum/datamesh/zarr.py +++ b/src/oceanum/datamesh/zarr.py @@ -10,6 +10,8 @@ import urllib.parse import requests +from typing import Optional + from .exceptions import DatameshConnectError, DatameshWriteError from .session import Session from .utils import ( @@ -66,7 +68,7 @@ def __init__( api="query", reference_id=None, verify=True, - storage_backend=None + storage_backend=None, ): self.datasource = datasource self.session = session @@ -118,15 +120,15 @@ def _retried_request( if resp.status_code == 401: raise DatameshConnectError(f"Not Authorized {resp.text}") if resp.status_code == 410: - raise DatameshConnectError(f"Datasource no longer exists or was deleted within your session") - if resp.status_code >= 500: raise DatameshConnectError( - f"Server error {resp.status_code}: {resp.text}" + f"Datasource no longer exists or was deleted within your session" ) + if resp.status_code >= 500: + raise DatameshConnectError(f"Server error {resp.status_code}: {resp.text}") return resp def __getitem__(self, item): - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") resp = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", connect_timeout=self.connect_timeout, @@ -137,7 +139,7 @@ def __getitem__(self, item): return resp.content def __contains__(self, item): - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") resp = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method="HEAD" if self._is_v1 else "GET", @@ -151,7 +153,7 @@ def __contains__(self, item): def __setitem__(self, item, value): if self.api == "query": raise DatameshConnectError("Query api does not support write operations") - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") res = self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method=self.method, @@ -167,7 +169,7 @@ def __setitem__(self, item, value): def __delitem__(self, item): if self.api == "query": raise DatameshConnectError("Query api does not support delete operations") - encoded_item = urllib.parse.quote(item, safe='/') + encoded_item = urllib.parse.quote(item, safe="/") self._retried_request( f"{self._proxy}/{self.datasource}/{encoded_item}", method="DELETE", @@ -202,7 +204,14 @@ def _to_zarr(data, store, **kwargs): data.to_zarr(store, **kwargs) -def zarr_write(connection, datasource_id, data, append=None, overwrite=False): +def zarr_write( + connection, + datasource_id, + data, + append=None, + overwrite=False, + group: Optional[str] = None, +): with Session.acquire(connection) as session: store = ZarrClient(connection, datasource_id, session, api="zarr", nocache=True) if overwrite is True: @@ -252,6 +261,7 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False): store, mode="a", region={append_dim: replace_slice}, + group=group, ) if len(data[append]) > len(replace_range): append_chunk = data.isel( @@ -263,9 +273,10 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False): mode="a", append_dim=append_dim, consolidated=True, + group=group, ) else: - _to_zarr(data, store, mode="w", consolidated=True) + _to_zarr(data, store, mode="w", consolidated=True, group=group) ds = connection.get_datasource(datasource_id) ds.dataschema = data.to_dict(data=False) return ds diff --git a/tests/test_zarr_groups.py b/tests/test_zarr_groups.py new file mode 100644 index 0000000..0e4e39a --- /dev/null +++ b/tests/test_zarr_groups.py @@ -0,0 +1,49 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +import sys + +sys.path.insert(0, "/home/tdurrant/source/oceanum.io/oceanum-python/src") + +from oceanum.datamesh.zarr import zarr_write +import xarray as xr + + +def test_zarr_write_passes_group_to_to_zarr(): + """Test that zarr_write passes group to _to_zarr.""" + conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True) + conn.get_datasource = Mock(return_value=Mock(_exists=False)) + + session_mock = MagicMock() + session_mock.__enter__ = Mock(return_value=session_mock) + session_mock.__exit__ = Mock(return_value=False) + session_mock.add_header = lambda h: h + + data = xr.Dataset({"temp": (["time"], [1, 2, 3])}) + + with patch("oceanum.datamesh.zarr.Session.acquire", return_value=session_mock): + with patch("oceanum.datamesh.zarr.ZarrClient"): + with patch("oceanum.datamesh.zarr._to_zarr") as mock_to_zarr: + zarr_write(conn, "test-ds", data, group="cycle/001") + # Verify _to_zarr called with group + to_zarr_kwargs = mock_to_zarr.call_args.kwargs + assert "group" in to_zarr_kwargs + assert to_zarr_kwargs["group"] == "cycle/001" + + +def test_zarr_write_without_group(): + """Test backward compatibility - zarr_write works without group.""" + conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True) + conn.get_datasource = Mock(return_value=Mock(_exists=False)) + + session_mock = MagicMock() + session_mock.__enter__ = Mock(return_value=session_mock) + session_mock.__exit__ = Mock(return_value=False) + session_mock.add_header = lambda h: h + + data = xr.Dataset({"temp": (["time"], [1, 2, 3])}) + + with patch("oceanum.datamesh.zarr.Session.acquire", return_value=session_mock): + with patch("oceanum.datamesh.zarr.ZarrClient"): + with patch("oceanum.datamesh.zarr._to_zarr"): + # Should not raise error + zarr_write(conn, "test-ds", data)