Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions src/oceanum/datamesh/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import urllib.parse
import requests

from typing import Optional

from .exceptions import DatameshConnectError, DatameshWriteError
from .session import Session
from .utils import (
Expand Down Expand Up @@ -66,7 +68,7 @@ def __init__(
api="query",
reference_id=None,
verify=True,
storage_backend=None
storage_backend=None,
):
self.datasource = datasource
self.session = session
Expand Down Expand Up @@ -118,15 +120,15 @@ def _retried_request(
if resp.status_code == 401:
raise DatameshConnectError(f"Not Authorized {resp.text}")
if resp.status_code == 410:
raise DatameshConnectError(f"Datasource no longer exists or was deleted within your session")
if resp.status_code >= 500:
raise DatameshConnectError(
f"Server error {resp.status_code}: {resp.text}"
f"Datasource no longer exists or was deleted within your session"
)
if resp.status_code >= 500:
raise DatameshConnectError(f"Server error {resp.status_code}: {resp.text}")
return resp

def __getitem__(self, item):
encoded_item = urllib.parse.quote(item, safe='/')
encoded_item = urllib.parse.quote(item, safe="/")
resp = self._retried_request(
f"{self._proxy}/{self.datasource}/{encoded_item}",
connect_timeout=self.connect_timeout,
Expand All @@ -137,7 +139,7 @@ def __getitem__(self, item):
return resp.content

def __contains__(self, item):
encoded_item = urllib.parse.quote(item, safe='/')
encoded_item = urllib.parse.quote(item, safe="/")
resp = self._retried_request(
f"{self._proxy}/{self.datasource}/{encoded_item}",
method="HEAD" if self._is_v1 else "GET",
Expand All @@ -151,7 +153,7 @@ def __contains__(self, item):
def __setitem__(self, item, value):
if self.api == "query":
raise DatameshConnectError("Query api does not support write operations")
encoded_item = urllib.parse.quote(item, safe='/')
encoded_item = urllib.parse.quote(item, safe="/")
res = self._retried_request(
f"{self._proxy}/{self.datasource}/{encoded_item}",
method=self.method,
Expand All @@ -167,7 +169,7 @@ def __setitem__(self, item, value):
def __delitem__(self, item):
if self.api == "query":
raise DatameshConnectError("Query api does not support delete operations")
encoded_item = urllib.parse.quote(item, safe='/')
encoded_item = urllib.parse.quote(item, safe="/")
self._retried_request(
f"{self._proxy}/{self.datasource}/{encoded_item}",
method="DELETE",
Expand Down Expand Up @@ -202,7 +204,14 @@ def _to_zarr(data, store, **kwargs):
data.to_zarr(store, **kwargs)


def zarr_write(connection, datasource_id, data, append=None, overwrite=False):
def zarr_write(
connection,
datasource_id,
data,
append=None,
overwrite=False,
group: Optional[str] = None,
):
with Session.acquire(connection) as session:
store = ZarrClient(connection, datasource_id, session, api="zarr", nocache=True)
if overwrite is True:
Expand Down Expand Up @@ -252,6 +261,7 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False):
store,
mode="a",
region={append_dim: replace_slice},
group=group,
)
if len(data[append]) > len(replace_range):
append_chunk = data.isel(
Expand All @@ -263,9 +273,10 @@ def zarr_write(connection, datasource_id, data, append=None, overwrite=False):
mode="a",
append_dim=append_dim,
consolidated=True,
group=group,
)
else:
_to_zarr(data, store, mode="w", consolidated=True)
_to_zarr(data, store, mode="w", consolidated=True, group=group)
ds = connection.get_datasource(datasource_id)
ds.dataschema = data.to_dict(data=False)
return ds
49 changes: 49 additions & 0 deletions tests/test_zarr_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
import sys

# NOTE(review): hard-coded, developer-local absolute path. On any machine where
# this directory does not exist the insert has no useful effect, and the import
# below then depends on the package being importable some other way.
# Presumably an installed (e.g. editable) package makes this line unnecessary —
# confirm and remove.
sys.path.insert(0, "/home/tdurrant/source/oceanum.io/oceanum-python/src")

from oceanum.datamesh.zarr import zarr_write
import xarray as xr


def test_zarr_write_passes_group_to_to_zarr():
    """zarr_write must forward its ``group`` argument to ``_to_zarr``."""
    # Connection stub exposing only the attributes set here.
    connection = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True)
    connection.get_datasource = Mock(return_value=Mock(_exists=False))

    # Session stub usable as a context manager that yields itself.
    session = MagicMock()
    session.__enter__ = Mock(return_value=session)
    session.__exit__ = Mock(return_value=False)
    session.add_header = lambda h: h

    dataset = xr.Dataset({"temp": (["time"], [1, 2, 3])})

    # Build the patchers up front so a single flat ``with`` can activate them.
    acquire_patch = patch(
        "oceanum.datamesh.zarr.Session.acquire", return_value=session
    )
    client_patch = patch("oceanum.datamesh.zarr.ZarrClient")
    to_zarr_patch = patch("oceanum.datamesh.zarr._to_zarr")

    with acquire_patch, client_patch, to_zarr_patch as mock_to_zarr:
        zarr_write(connection, "test-ds", dataset, group="cycle/001")

    # The recorded call outlives the patch context, so inspect it afterwards.
    passed_kwargs = mock_to_zarr.call_args.kwargs
    assert "group" in passed_kwargs
    assert passed_kwargs["group"] == "cycle/001"


def test_zarr_write_without_group():
    """Backward compatibility: zarr_write works when ``group`` is omitted.

    Besides checking that the call does not raise, this verifies that
    ``_to_zarr`` is actually invoked and that no group is passed when the
    caller does not supply one. Without these assertions the test would pass
    even if the write were silently skipped.
    """
    # Connection stub exposing only the attributes set here.
    conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True)
    conn.get_datasource = Mock(return_value=Mock(_exists=False))

    # Session stub usable as a context manager that yields itself.
    session_mock = MagicMock()
    session_mock.__enter__ = Mock(return_value=session_mock)
    session_mock.__exit__ = Mock(return_value=False)
    session_mock.add_header = lambda h: h

    data = xr.Dataset({"temp": (["time"], [1, 2, 3])})

    with patch("oceanum.datamesh.zarr.Session.acquire", return_value=session_mock):
        with patch("oceanum.datamesh.zarr.ZarrClient"):
            with patch("oceanum.datamesh.zarr._to_zarr") as mock_to_zarr:
                # Should not raise error
                zarr_write(conn, "test-ds", data)

    # The write must still happen, and the group must stay at its default.
    assert mock_to_zarr.called
    assert mock_to_zarr.call_args.kwargs.get("group") is None