From c3fefdf3ac99eb22570fe19ce74afd2aeaf5a4aa Mon Sep 17 00:00:00 2001 From: Sebastien Delaux Date: Fri, 23 Jan 2026 08:58:56 +1300 Subject: [PATCH 1/4] Implementing fork safe http connection pool for datamesh connector and zarr client --- src/oceanum/datamesh/connection.py | 17 ++++++++- src/oceanum/datamesh/utils.py | 59 +++++++++++++++++++++++++++++- src/oceanum/datamesh/zarr.py | 33 ++++++++++++----- 3 files changed, 97 insertions(+), 12 deletions(-) diff --git a/src/oceanum/datamesh/connection.py b/src/oceanum/datamesh/connection.py index dfad41c..e8c6e35 100644 --- a/src/oceanum/datamesh/connection.py +++ b/src/oceanum/datamesh/connection.py @@ -35,6 +35,7 @@ from .session import Session from .utils import ( retried_request, + HTTPSession, DATAMESH_WRITE_TIMEOUT, DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT, @@ -42,6 +43,7 @@ ) from ..__init__ import __version__ + DEFAULT_CONFIG = {"DATAMESH_SERVICE": "https://datamesh.oceanum.io"} DASK_QUERY_SIZE = 1000000000 # 1GB @@ -114,10 +116,13 @@ def __init__( if not verify: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + self.http_session = HTTPSession() + self._check_info() if self._host.split(".")[-1] != self._gateway.split(".")[-1]: warnings.warn("Gateway and service domain do not match") + def _init_auth_headers(self, token: str | None, user: str | None = None): if token is not None: if token.startswith("Bearer "): @@ -149,6 +154,7 @@ def _status(self): f"{self._proto}://{self._host}", headers=self._auth_headers, verify=self._verify, + http_session=self.http_session, ) return resp.status_code == 200 @@ -166,6 +172,7 @@ def _check_info(self): headers=self._auth_headers, retries=1, verify=self._verify, + http_session=self.http_session, ) if resp.status_code == 200: r = resp.json() @@ -178,7 +185,7 @@ def _check_info(self): raise DatameshConnectError( f"Failed to reach datamesh: {resp.status_code}-{resp.text}" ) - except: + except Exception as e: _gateway = self._gateway or f"{self._proto}://gateway.{self._host}" self._gateway = _gateway self._is_v1 = False @@ -212,6 +219,7 @@ def _metadata_request(self, datasource_id="", params={}): params=params, headers=self._auth_headers, verify=self._verify, + http_session=self.http_session, ) if resp.status_code == 404: raise DatameshConnectError(f"Datasource {datasource_id} not found") @@ -232,6 +240,7 @@ def _metadata_write(self, datasource): data=data, headers=headers, verify=self._verify, + http_session=self.http_session, ) else: @@ -241,6 +250,7 @@ def _metadata_write(self, datasource): data=data, headers=headers, verify=self._verify, + http_session=self.http_session, ) self._validate_response(resp) return resp @@ -251,6 +261,7 @@ def _delete(self, datasource_id): method="DELETE", headers=self._auth_headers, verify=self._verify, + http_session=self.http_session, ) self._validate_response(resp) return True @@ -262,6 +273,7 @@ def _data_request(self, datasource_id, data_format="application/json", cache=Fal headers={"Accept": data_format, **self._auth_headers}, timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT), verify=self._verify, + http_session=self.http_session, ) self._validate_response(resp) with open(tmpfile, "wb") as f: @@ -286,6 +298,7 @@ def _data_write( headers={"Content-Type": data_format, **self._auth_headers}, timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT), verify=self._verify, + http_session=self.http_session, ) else: headers = {"Content-Type": data_format, **self._auth_headers} @@ -298,6 +311,7 @@ def _data_write( headers=headers, timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT), verify=self._verify, + http_session=self.http_session, ) self._validate_response(resp) return Datasource(**resp.json()) @@ -314,6 +328,7 @@ def _stage_request(self, query, session, cache=False): data=query.model_dump_json(warnings=False), timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_STAGE_READ_TIMEOUT), verify=self._verify, + http_session=self.http_session, ) if resp.status_code >= 400: try: diff --git a/src/oceanum/datamesh/utils.py b/src/oceanum/datamesh/utils.py index b80748d..e5d4ff8 100644 --- a/src/oceanum/datamesh/utils.py +++ b/src/oceanum/datamesh/utils.py @@ -1,5 +1,6 @@ from time import sleep import requests +from requests.adapters import HTTPAdapter import numpy as np from .exceptions import DatameshConnectError import os @@ -52,6 +53,58 @@ ) +class HTTPSession: + """ + A requests.Session wrapper that is safe to use across forked processes + Attributes + ---------- + pool_size : int, optional + The size of the connection pool, by default None + Methods + ------- + session : requests.Session + Returns a requests.Session object that is safe to use in the current process + __getstate__ : dict + Returns the state of the object for pickling + __setstate__ : None + Restores the state of the object from pickling + """ + + def __init__(self, pool_size=os.environ.get("DATAMESH_CONNECTION_POOL_SIZE", 100)): + self._session = None + self._pid = None + self._pool_size = int(pool_size) + + def _create_session(self): + session = requests.Session() + adapter = HTTPAdapter( + pool_connections=self._pool_size, + pool_maxsize=self._pool_size + ) + session.mount('https://', adapter) + session.mount('http://', adapter) + return session + + @property + def session(self): + if self._session is None or self._pid != os.getpid(): + self._pid = os.getpid() + self._session = self._create_session() + return self._session + + def __getstate__(self): + state = self.__dict__.copy() + state["_session"] = None + state["_pid"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def request(self, method, url, *args, **kwargs): + return self.session.request(method, url, *args, **kwargs) + + def retried_request( url, method="GET", @@ -61,6 +114,7 @@ def retried_request( retries=8, timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_READ_TIMEOUT), verify=True, + http_session: HTTPSession = None, ): """ Retried request function with exponential backoff @@ -79,6 +133,8 @@ def retried_request( Number of retries, by default 8 timeout : tupe(float, float), optional Request connect and read timeout in seconds, by default (3.05, 10) + http_session : HTTPSession, optional + Session object to use for request Returns ------- @@ -91,10 +147,11 @@ def retried_request( If request fails """ + requester = http_session if http_session else requests retried = 0 while retried < retries: try: - resp = requests.request( + resp = requester.request( method=method, url=url, data=data, diff --git a/src/oceanum/datamesh/zarr.py b/src/oceanum/datamesh/zarr.py index 03c8a8e..4d69dc2 100644 --- a/src/oceanum/datamesh/zarr.py +++ b/src/oceanum/datamesh/zarr.py @@ -8,10 +8,17 @@ import xarray import fsspec import urllib.parse +import requests from .exceptions import DatameshConnectError, DatameshWriteError from .session import Session -from .utils import retried_request, DATAMESH_CONNECT_TIMEOUT, DATAMESH_CHUNK_READ_TIMEOUT, DATAMESH_CHUNK_WRITE_TIMEOUT +from .utils import ( + retried_request, + HTTPSession, + DATAMESH_CONNECT_TIMEOUT, + DATAMESH_CHUNK_READ_TIMEOUT, + DATAMESH_CHUNK_WRITE_TIMEOUT, +) try: import xarray_video as xv @@ -85,6 +92,7 @@ def __init__( self.verify = verify if storage_backend is not None: self.headers["X-DATAMESH-STORAGE-BACKEND"] = storage_backend + self.http_session = HTTPSession() def _retried_request( self, @@ -94,15 +102,20 @@ def _retried_request( connect_timeout=DATAMESH_CONNECT_TIMEOUT, read_timeout=DATAMESH_CHUNK_READ_TIMEOUT, ): - resp = retried_request( - url=path, - method=method, - data=data, - headers=self.headers, - retries=self.retries, - timeout=(connect_timeout, read_timeout), - verify=self.verify, - ) + try: + resp = retried_request( + url=path, + method=method, + data=data, + headers=self.headers, + retries=self.retries, + timeout=(connect_timeout, read_timeout), + verify=self.verify, + http_session=self.http_session, + ) + except requests.RequestException as e: + raise DatameshConnectError(str(e)) + if resp.status_code == 401: raise DatameshConnectError(f"Not Authorized {resp.text}") if resp.status_code == 410: From 6acf9f44197daa8463edfd91f00d2c967f35b1d8 Mon Sep 17 00:00:00 2001 From: Sebastien Delaux Date: Fri, 23 Jan 2026 09:41:16 +1300 Subject: [PATCH 2/4] Trying to fix test --- tests/test_verify_parameter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_verify_parameter.py b/tests/test_verify_parameter.py index 8e4aa46..967d7e2 100644 --- a/tests/test_verify_parameter.py +++ b/tests/test_verify_parameter.py @@ -12,8 +12,8 @@ @pytest.fixture def mock_request(): - """Mock the requests.request function to check verify parameter""" - with patch("requests.request") as mock_req: + """Mock the requests.Session.request function to check verify parameter""" + with patch("requests.Session.request") as mock_req: mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"id": "test-id", "name": "test"} From 6ac0223410a5712c5d7e4342e55014ee42b188fe Mon Sep 17 00:00:00 2001 From: Sebastien Delaux Date: Wed, 28 Jan 2026 15:43:00 +1300 Subject: [PATCH 3/4] Setting default headers + retried_request wrapper to simplify code --- src/oceanum/datamesh/connection.py | 75 ++++++++++++------------------ src/oceanum/datamesh/utils.py | 7 ++- src/oceanum/datamesh/zarr.py | 3 +- 3 files changed, 36 insertions(+), 49 deletions(-) diff --git a/src/oceanum/datamesh/connection.py b/src/oceanum/datamesh/connection.py index e8c6e35..a396bb8 100644 --- a/src/oceanum/datamesh/connection.py +++ b/src/oceanum/datamesh/connection.py @@ -116,7 +116,7 @@ def __init__( if not verify: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - self.http_session = HTTPSession() + self.http_session = HTTPSession(headers=self._auth_headers) self._check_info() if self._host.split(".")[-1] != self._gateway.split(".")[-1]: @@ -139,6 +139,15 @@ def _init_auth_headers(self, token: str | None, user: str | None = None): "A valid key must be supplied as a connection constructor argument or defined in environment variables as DATAMESH_TOKEN" ) + def _retried_request(self, *args, **kwargs): + """Wrapper around retried_request to use connection settings""" + return retried_request( + *args, + verify=self._verify, + http_session=self.http_session, + **kwargs, + ) + @property def host(self): """Datamesh host @@ -150,11 +159,8 @@ def host(self): # Check the status of the metadata server def _status(self): - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}", - headers=self._auth_headers, - verify=self._verify, - http_session=self.http_session, ) return resp.status_code == 200 @@ -167,12 +173,9 @@ def _check_info(self): _gateway = self._gateway or f"{self._proto}://{self._host}" try: - resp = retried_request( + resp = self._retried_request( f"{_gateway}/info/oceanum_python/{__version__}", - headers=self._auth_headers, retries=1, - verify=self._verify, - http_session=self.http_session, ) if resp.status_code == 200: r = resp.json() @@ -191,11 +194,9 @@ def _check_info(self): self._is_v1 = False print("Using datamesh API version 0") try: - resp = retried_request( + resp = self._retried_request( f"https://datamesh-v1.oceanum.io/info/oceanum_python/{__version__}", - headers=self._auth_headers, retries=1, - verify=self._verify, ) if resp.status_code == 200: r = resp.json() @@ -214,12 +215,9 @@ def _validate_response(self, resp): raise DatameshConnectError(msg) def _metadata_request(self, datasource_id="", params={}): - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/{datasource_id}", params=params, - headers=self._auth_headers, - verify=self._verify, - http_session=self.http_session, ) if resp.status_code == 404: raise DatameshConnectError(f"Datasource {datasource_id} not found") @@ -232,48 +230,39 @@ def _metadata_write(self, datasource): data = datasource.model_dump_json(by_alias=True, warnings=False).encode( "utf-8", "ignore" ) - headers = {**self._auth_headers, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} if datasource._exists: - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/{datasource.id}/", method="PATCH", data=data, headers=headers, - verify=self._verify, - http_session=self.http_session, ) else: - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/", method="POST", data=data, headers=headers, - verify=self._verify, - http_session=self.http_session, ) self._validate_response(resp) return resp def _delete(self, datasource_id): - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", method="DELETE", - headers=self._auth_headers, - verify=self._verify, - http_session=self.http_session, ) self._validate_response(resp) return True def _data_request(self, datasource_id, data_format="application/json", cache=False): tmpfile = os.path.join(self._cachedir.name, datasource_id) - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", - headers={"Accept": data_format, **self._auth_headers}, + headers={"Accept": data_format}, timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT), - verify=self._verify, - http_session=self.http_session, ) self._validate_response(resp) with open(tmpfile, "wb") as f: @@ -290,28 +279,24 @@ def _data_write( ): # Connection timeout does not act in the same way in write and read contexts # and using a short connection timeout in write contexts leads to closed connections + headers = {"Content-Type": data_format} if overwrite: - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", method="PUT", data=data, - headers={"Content-Type": data_format, **self._auth_headers}, + headers=headers, timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT), - verify=self._verify, - http_session=self.http_session, ) else: - headers = {"Content-Type": data_format, **self._auth_headers} if append: headers["X-Append"] = str(append) - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", method="PATCH", data=data, headers=headers, timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT), - verify=self._verify, - http_session=self.http_session, ) self._validate_response(resp) return Datasource(**resp.json()) @@ -321,14 +306,12 @@ def _stage_request(self, query, session, cache=False): query.model_dump_json(warnings=False).encode() ).hexdigest() - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/oceanql/stage/", method="POST", - headers=session.add_header(self._auth_headers), + headers=session.header, data=query.model_dump_json(warnings=False), timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_STAGE_READ_TIMEOUT), - verify=self._verify, - http_session=self.http_session, ) if resp.status_code >= 400: try: @@ -385,14 +368,14 @@ def _query(self, query, use_dask=False, cache_timeout=0, retry=0): if stage.container == Container.Dataset else "application/parquet" ) - headers = {"Accept": transfer_format, **self._auth_headers} - resp = retried_request( + headers = {"Accept": transfer_format, + **session.header} + resp = self._retried_request( f"{self._gateway}/oceanql/", method="POST", headers=headers, data=query.model_dump_json(warnings=False), timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT), - verify=self._verify, ) if resp.status_code > 500: if cache_timeout: diff --git a/src/oceanum/datamesh/utils.py b/src/oceanum/datamesh/utils.py index e5d4ff8..efe31ee 100644 --- a/src/oceanum/datamesh/utils.py +++ b/src/oceanum/datamesh/utils.py @@ -60,6 +60,8 @@ class HTTPSession: ---------- pool_size : int, optional The size of the connection pool, by default None + headers : dict, optional + Default headers to include in each request, by default None Methods ------- session : requests.Session @@ -70,10 +72,11 @@ class HTTPSession: Restores the state of the object from pickling """ - def __init__(self, pool_size=os.environ.get("DATAMESH_CONNECTION_POOL_SIZE", 100)): + def __init__(self, pool_size=os.environ.get("DATAMESH_CONNECTION_POOL_SIZE", 100), headers=None): self._session = None self._pid = None self._pool_size = int(pool_size) + self._headers = headers def _create_session(self): session = requests.Session() @@ -83,6 +86,8 @@ def _create_session(self): ) session.mount('https://', adapter) session.mount('http://', adapter) + if self._headers: + session.headers.update(self._headers) return session @property diff --git a/src/oceanum/datamesh/zarr.py b/src/oceanum/datamesh/zarr.py index 4d69dc2..4674a8a 100644 --- a/src/oceanum/datamesh/zarr.py +++ b/src/oceanum/datamesh/zarr.py @@ -92,7 +92,7 @@ def __init__( self.verify = verify if storage_backend is not None: self.headers["X-DATAMESH-STORAGE-BACKEND"] = storage_backend - self.http_session = HTTPSession() + self.http_session = HTTPSession(headers=self.headers) def _retried_request( self, @@ -107,7 +107,6 @@ def _retried_request( url=path, method=method, data=data, - headers=self.headers, retries=self.retries, timeout=(connect_timeout, read_timeout), verify=self.verify, From d31b9964edbd0d13b1e4c838e5f0687fc529125e Mon Sep 17 00:00:00 2001 From: Sebastien Delaux Date: Wed, 28 Jan 2026 17:05:59 +1300 Subject: [PATCH 4/4] Making V1 the default --- src/oceanum/datamesh/connection.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/oceanum/datamesh/connection.py b/src/oceanum/datamesh/connection.py index a396bb8..b837e5a 100644 --- a/src/oceanum/datamesh/connection.py +++ b/src/oceanum/datamesh/connection.py @@ -170,41 +170,31 @@ def _check_info(self): Typically will ask to update the client if the version is outdated. Also will try to guess gateway address if not provided. """ - _gateway = self._gateway or f"{self._proto}://{self._host}" + self._is_v1 = True try: resp = self._retried_request( f"{_gateway}/info/oceanum_python/{__version__}", - retries=1, + retries=5, ) if resp.status_code == 200: r = resp.json() if "message" in r: print(r["message"]) - print("Using datamesh API version 1") self._gateway = _gateway - self._is_v1 = True + return + elif resp.status_code == 404: + print("Using datamesh API version 0") + self._is_v1 = False + self._gateway = self._gateway or f"{self._proto}://gateway.{self._host}" return raise DatameshConnectError( f"Failed to reach datamesh: {resp.status_code}-{resp.text}" ) except Exception as e: - _gateway = self._gateway or f"{self._proto}://gateway.{self._host}" + warnings.warn(f"Failed to reach datamesh gateway at {_gateway}: {e}") + warnings.warn("Assuming datamesh API version 1") self._gateway = _gateway - self._is_v1 = False - print("Using datamesh API version 0") - try: - resp = self._retried_request( - f"https://datamesh-v1.oceanum.io/info/oceanum_python/{__version__}", - retries=1, - ) - if resp.status_code == 200: - r = resp.json() - if "message" in r: - print(r["message"]) - except: - pass - return def _validate_response(self, resp): if resp.status_code >= 400: