diff --git a/src/oceanum/datamesh/connection.py b/src/oceanum/datamesh/connection.py
index dfad41c..b837e5a 100644
--- a/src/oceanum/datamesh/connection.py
+++ b/src/oceanum/datamesh/connection.py
@@ -35,6 +35,7 @@
 from .session import Session
 from .utils import (
     retried_request,
+    HTTPSession,
     DATAMESH_WRITE_TIMEOUT,
     DATAMESH_CONNECT_TIMEOUT,
     DATAMESH_DOWNLOAD_TIMEOUT,
@@ -42,6 +43,7 @@
 )
 from ..__init__ import __version__
 
+
 DEFAULT_CONFIG = {"DATAMESH_SERVICE": "https://datamesh.oceanum.io"}
 
 DASK_QUERY_SIZE = 1000000000  # 1GB
@@ -114,10 +116,13 @@ def __init__(
         if not verify:
             urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
+        self.http_session = HTTPSession(headers=self._auth_headers)
+
         self._check_info()
         if self._host.split(".")[-1] != self._gateway.split(".")[-1]:
             warnings.warn("Gateway and service domain do not match")
 
+
     def _init_auth_headers(self, token: str | None, user: str | None = None):
         if token is not None:
             if token.startswith("Bearer "):
@@ -134,6 +139,15 @@ def _init_auth_headers(self, token: str | None, user: str | None = None):
                 "A valid key must be supplied as a connection constructor argument or defined in environment variables as DATAMESH_TOKEN"
             )
 
+    def _retried_request(self, *args, **kwargs):
+        """Wrapper around retried_request that applies the connection settings"""
+        return retried_request(
+            *args,
+            verify=self._verify,
+            http_session=self.http_session,
+            **kwargs,
+        )
+
     @property
     def host(self):
         """Datamesh host
@@ -145,10 +159,8 @@ def host(self):
 
     # Check the status of the metadata server
     def _status(self):
-        resp = retried_request(
+        resp = self._retried_request(
             f"{self._proto}://{self._host}",
-            headers=self._auth_headers,
-            verify=self._verify,
         )
         return resp.status_code == 200
 
@@ -158,45 +170,32 @@ def _check_info(self):
         Typically will ask to update the client if the version is outdated.
         Also will try to guess gateway address if not provided.
         """
         _gateway = self._gateway or f"{self._proto}://{self._host}"
+        self._is_v1 = True
         try:
-            resp = retried_request(
+            resp = self._retried_request(
                 f"{_gateway}/info/oceanum_python/{__version__}",
-                headers=self._auth_headers,
-                retries=1,
-                verify=self._verify,
+                retries=5,
             )
             if resp.status_code == 200:
                 r = resp.json()
                 if "message" in r:
                     print(r["message"])
-                print("Using datamesh API version 1")
                 self._gateway = _gateway
-                self._is_v1 = True
+                return
+            elif resp.status_code == 404:
+                print("Using datamesh API version 0")
+                self._is_v1 = False
+                self._gateway = self._gateway or f"{self._proto}://gateway.{self._host}"
                 return
             raise DatameshConnectError(
                 f"Failed to reach datamesh: {resp.status_code}-{resp.text}"
             )
-        except:
-            _gateway = self._gateway or f"{self._proto}://gateway.{self._host}"
+        except Exception as e:
+            warnings.warn(f"Failed to reach datamesh gateway at {_gateway}: {e}")
+            warnings.warn("Assuming datamesh API version 1")
             self._gateway = _gateway
-            self._is_v1 = False
-            print("Using datamesh API version 0")
-            try:
-                resp = retried_request(
-                    f"https://datamesh-v1.oceanum.io/info/oceanum_python/{__version__}",
-                    headers=self._auth_headers,
-                    retries=1,
-                    verify=self._verify,
-                )
-                if resp.status_code == 200:
-                    r = resp.json()
-                    if "message" in r:
-                        print(r["message"])
-            except:
-                pass
-        return
 
     def _validate_response(self, resp):
         if resp.status_code >= 400:
@@ -207,11 +206,9 @@ def _validate_response(self, resp):
             raise DatameshConnectError(msg)
 
     def _metadata_request(self, datasource_id="", params={}):
-        resp = retried_request(
+        resp = self._retried_request(
             f"{self._proto}://{self._host}/datasource/{datasource_id}",
             params=params,
-            headers=self._auth_headers,
-            verify=self._verify,
         )
         if resp.status_code == 404:
             raise DatameshConnectError(f"Datasource {datasource_id} not found")
@@ -224,44 +221,39 @@ def _metadata_write(self, datasource):
         data = datasource.model_dump_json(by_alias=True, warnings=False).encode(
             "utf-8", "ignore"
         )
-        headers = {**self._auth_headers, "Content-Type": "application/json"}
+        headers = {"Content-Type": "application/json"}
         if datasource._exists:
-            resp = retried_request(
+            resp = self._retried_request(
                 f"{self._proto}://{self._host}/datasource/{datasource.id}/",
                 method="PATCH",
                 data=data,
                 headers=headers,
-                verify=self._verify,
             )
         else:
-            resp = retried_request(
+            resp = self._retried_request(
                 f"{self._proto}://{self._host}/datasource/",
                 method="POST",
                 data=data,
                 headers=headers,
-                verify=self._verify,
             )
         self._validate_response(resp)
         return resp
 
     def _delete(self, datasource_id):
-        resp = retried_request(
+        resp = self._retried_request(
             f"{self._gateway}/data/{datasource_id}",
             method="DELETE",
-            headers=self._auth_headers,
-            verify=self._verify,
         )
         self._validate_response(resp)
         return True
 
     def _data_request(self, datasource_id, data_format="application/json", cache=False):
         tmpfile = os.path.join(self._cachedir.name, datasource_id)
-        resp = retried_request(
+        resp = self._retried_request(
             f"{self._gateway}/data/{datasource_id}",
-            headers={"Accept": data_format, **self._auth_headers},
+            headers={"Accept": data_format},
             timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT),
-            verify=self._verify,
         )
         self._validate_response(resp)
         with open(tmpfile, "wb") as f:
@@ -278,26 +270,24 @@ def _data_write(
     ):
         # Connection timeout does not act in the same way in write and read contexts
         # and using a short connection timeout in write contexts leads to closed connections
+        headers = {"Content-Type": data_format}
         if overwrite:
-            resp = retried_request(
+            resp = self._retried_request(
                 f"{self._gateway}/data/{datasource_id}",
                 method="PUT",
                 data=data,
-                headers={"Content-Type": data_format, **self._auth_headers},
+                headers=headers,
                 timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT),
-                verify=self._verify,
             )
         else:
-            headers = {"Content-Type": data_format, **self._auth_headers}
             if append:
                 headers["X-Append"] = str(append)
-            resp = retried_request(
+            resp = self._retried_request(
                 f"{self._gateway}/data/{datasource_id}",
                 method="PATCH",
                 data=data,
                 headers=headers,
                 timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT),
-                verify=self._verify,
             )
         self._validate_response(resp)
         return Datasource(**resp.json())
@@ -307,13 +297,12 @@ def _stage_request(self, query, session, cache=False):
             query.model_dump_json(warnings=False).encode()
         ).hexdigest()
 
-        resp = retried_request(
+        resp = self._retried_request(
             f"{self._gateway}/oceanql/stage/",
             method="POST",
-            headers=session.add_header(self._auth_headers),
+            headers=session.header,
             data=query.model_dump_json(warnings=False),
             timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_STAGE_READ_TIMEOUT),
-            verify=self._verify,
         )
         if resp.status_code >= 400:
             try:
@@ -370,14 +359,13 @@ def _query(self, query, use_dask=False, cache_timeout=0, retry=0):
             if stage.container == Container.Dataset
             else "application/parquet"
         )
-        headers = {"Accept": transfer_format, **self._auth_headers}
-        resp = retried_request(
+        headers = {"Accept": transfer_format, **session.header}
+        resp = self._retried_request(
             f"{self._gateway}/oceanql/",
             method="POST",
             headers=headers,
             data=query.model_dump_json(warnings=False),
             timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT),
-            verify=self._verify,
         )
         if resp.status_code > 500:
             if cache_timeout:
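
Note on the header changes above: dropping `**self._auth_headers` from the per-call headers is safe because requests merges session-level headers into every prepared request, and the auth headers now live on the shared HTTPSession. A minimal sketch of that merging behaviour (the token value and URL are hypothetical):

    import requests

    session = requests.Session()
    session.headers.update({"Authorization": "Token abc123"})  # hypothetical token

    # Per-request headers are merged on top of the session defaults
    req = requests.Request(
        "GET", "https://example.com/data", headers={"Accept": "application/json"}
    )
    prepared = session.prepare_request(req)

    assert prepared.headers["Authorization"] == "Token abc123"  # from the session
    assert prepared.headers["Accept"] == "application/json"     # from the request
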
""" - _gateway = self._gateway or f"{self._proto}://{self._host}" + self._is_v1 = True try: - resp = retried_request( + resp = self._retried_request( f"{_gateway}/info/oceanum_python/{__version__}", - headers=self._auth_headers, - retries=1, - verify=self._verify, + retries=5, ) if resp.status_code == 200: r = resp.json() if "message" in r: print(r["message"]) - print("Using datamesh API version 1") self._gateway = _gateway - self._is_v1 = True + return + elif resp.status_code == 404: + print("Using datamesh API version 0") + self._is_v1 = False + self._gateway = self._gateway or f"{self._proto}://gateway.{self._host}" return raise DatameshConnectError( f"Failed to reach datamesh: {resp.status_code}-{resp.text}" ) - except: - _gateway = self._gateway or f"{self._proto}://gateway.{self._host}" + except Exception as e: + warnings.warn(f"Failed to reach datamesh gateway at {_gateway}: {e}") + warnings.warn("Assuming datamesh API version 1") self._gateway = _gateway - self._is_v1 = False - print("Using datamesh API version 0") - try: - resp = retried_request( - f"https://datamesh-v1.oceanum.io/info/oceanum_python/{__version__}", - headers=self._auth_headers, - retries=1, - verify=self._verify, - ) - if resp.status_code == 200: - r = resp.json() - if "message" in r: - print(r["message"]) - except: - pass - return def _validate_response(self, resp): if resp.status_code >= 400: @@ -207,11 +205,9 @@ def _validate_response(self, resp): raise DatameshConnectError(msg) def _metadata_request(self, datasource_id="", params={}): - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/{datasource_id}", params=params, - headers=self._auth_headers, - verify=self._verify, ) if resp.status_code == 404: raise DatameshConnectError(f"Datasource {datasource_id} not found") @@ -224,44 +220,39 @@ def _metadata_write(self, datasource): data = datasource.model_dump_json(by_alias=True, warnings=False).encode( "utf-8", "ignore" ) - headers = {**self._auth_headers, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} if datasource._exists: - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/{datasource.id}/", method="PATCH", data=data, headers=headers, - verify=self._verify, ) else: - resp = retried_request( + resp = self._retried_request( f"{self._proto}://{self._host}/datasource/", method="POST", data=data, headers=headers, - verify=self._verify, ) self._validate_response(resp) return resp def _delete(self, datasource_id): - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", method="DELETE", - headers=self._auth_headers, - verify=self._verify, ) self._validate_response(resp) return True def _data_request(self, datasource_id, data_format="application/json", cache=False): tmpfile = os.path.join(self._cachedir.name, datasource_id) - resp = retried_request( + resp = self._retried_request( f"{self._gateway}/data/{datasource_id}", - headers={"Accept": data_format, **self._auth_headers}, + headers={"Accept": data_format}, timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT), - verify=self._verify, ) self._validate_response(resp) with open(tmpfile, "wb") as f: @@ -278,26 +269,24 @@ def _data_write( ): # Connection timeout does not act in the same way in write and read contexts # and using a short connection timeout in write contexts leads to closed connections + headers = {"Content-Type": data_format} if overwrite: - resp = retried_request( + resp = 
diff --git a/src/oceanum/datamesh/zarr.py b/src/oceanum/datamesh/zarr.py
index 03c8a8e..4674a8a 100644
--- a/src/oceanum/datamesh/zarr.py
+++ b/src/oceanum/datamesh/zarr.py
@@ -8,10 +8,17 @@
 import xarray
 import fsspec
 import urllib.parse
+import requests
 from .exceptions import DatameshConnectError, DatameshWriteError
 from .session import Session
-from .utils import retried_request, DATAMESH_CONNECT_TIMEOUT, DATAMESH_CHUNK_READ_TIMEOUT, DATAMESH_CHUNK_WRITE_TIMEOUT
+from .utils import (
+    retried_request,
+    HTTPSession,
+    DATAMESH_CONNECT_TIMEOUT,
+    DATAMESH_CHUNK_READ_TIMEOUT,
+    DATAMESH_CHUNK_WRITE_TIMEOUT,
+)
 
 try:
     import xarray_video as xv
@@ -85,6 +92,7 @@ def __init__(
         self.verify = verify
         if storage_backend is not None:
             self.headers["X-DATAMESH-STORAGE-BACKEND"] = storage_backend
+        self.http_session = HTTPSession(headers=self.headers)
 
     def _retried_request(
         self,
@@ -94,15 +102,19 @@ def _retried_request(
         connect_timeout=DATAMESH_CONNECT_TIMEOUT,
         read_timeout=DATAMESH_CHUNK_READ_TIMEOUT,
     ):
-        resp = retried_request(
-            url=path,
-            method=method,
-            data=data,
-            headers=self.headers,
-            retries=self.retries,
-            timeout=(connect_timeout, read_timeout),
-            verify=self.verify,
-        )
+        try:
+            resp = retried_request(
+                url=path,
+                method=method,
+                data=data,
+                retries=self.retries,
+                timeout=(connect_timeout, read_timeout),
+                verify=self.verify,
+                http_session=self.http_session,
+            )
+        except requests.RequestException as e:
+            raise DatameshConnectError(str(e)) from e
+
         if resp.status_code == 401:
             raise DatameshConnectError(f"Not Authorized {resp.text}")
         if resp.status_code == 410:
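
The try/except added to `_retried_request` translates low-level `requests` failures into the library's own `DatameshConnectError`, so zarr-layer callers only have to catch one exception type. A standalone sketch of the pattern (the stand-in exception class and the URL are hypothetical):

    import requests

    class DatameshConnectError(Exception):
        """Stand-in for oceanum.datamesh.exceptions.DatameshConnectError."""

    def fetch_chunk(url):
        try:
            return requests.get(url, timeout=(3.05, 10))
        except requests.RequestException as e:
            # Chain the original exception so the root cause stays visible
            raise DatameshConnectError(str(e)) from e

    try:
        fetch_chunk("https://nonexistent.invalid/chunk")
    except DatameshConnectError as e:
        print(f"translated: {e}")
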
diff --git a/tests/test_verify_parameter.py b/tests/test_verify_parameter.py
index 8e4aa46..967d7e2 100644
--- a/tests/test_verify_parameter.py
+++ b/tests/test_verify_parameter.py
@@ -12,8 +12,8 @@
 
 @pytest.fixture
 def mock_request():
-    """Mock the requests.request function to check verify parameter"""
-    with patch("requests.request") as mock_req:
+    """Mock the requests.Session.request method to check the verify parameter"""
+    with patch("requests.Session.request") as mock_req:
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.json.return_value = {"id": "test-id", "name": "test"}
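
With every request now issued through a `requests.Session` (the `requests.request` helper also constructs a `Session` internally), `requests.Session.request` is the single seam worth patching. A sketch of a test that exercises the new code path, assuming `retried_request` forwards its `verify` argument as the existing tests expect:

    from unittest.mock import MagicMock, patch

    from oceanum.datamesh.utils import retried_request

    def test_verify_is_forwarded():
        with patch("requests.Session.request") as mock_req:
            mock_req.return_value = MagicMock(status_code=200)
            retried_request("https://example.com", verify=False)
            # The verify flag should reach the underlying session call
            assert mock_req.call_args.kwargs.get("verify") is False
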