src/oceanum/datamesh/connection.py (96 changes: 42 additions & 54 deletions)
@@ -35,13 +35,15 @@
from .session import Session
from .utils import (
retried_request,
HTTPSession,
DATAMESH_WRITE_TIMEOUT,
DATAMESH_CONNECT_TIMEOUT,
DATAMESH_DOWNLOAD_TIMEOUT,
DATAMESH_STAGE_READ_TIMEOUT,
)
from ..__init__ import __version__


DEFAULT_CONFIG = {"DATAMESH_SERVICE": "https://datamesh.oceanum.io"}

DASK_QUERY_SIZE = 1000000000 # 1GB
Expand Down Expand Up @@ -114,10 +116,13 @@ def __init__(
if not verify:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self.http_session = HTTPSession(headers=self._auth_headers)

self._check_info()
if self._host.split(".")[-1] != self._gateway.split(".")[-1]:
warnings.warn("Gateway and service domain do not match")


def _init_auth_headers(self, token: str | None, user: str | None = None):
if token is not None:
if token.startswith("Bearer "):
@@ -134,6 +139,15 @@ def _init_auth_headers(self, token: str | None, user: str | None = None):
"A valid key must be supplied as a connection constructor argument or defined in environment variables as DATAMESH_TOKEN"
)

def _retried_request(self, *args, **kwargs):
"""Wrapper around retried_request to use connection settings"""
return retried_request(
*args,
verify=self._verify,
http_session=self.http_session,
**kwargs,
)

@property
def host(self):
"""Datamesh host
@@ -145,10 +159,8 @@ def host(self):

# Check the status of the metadata server
def _status(self):
resp = retried_request(
resp = self._retried_request(
f"{self._proto}://{self._host}",
headers=self._auth_headers,
verify=self._verify,
)
return resp.status_code == 200

@@ -158,45 +170,31 @@ def _check_info(self):
Typically will ask to update the client if the version is outdated.
Also will try to guess gateway address if not provided.
"""

_gateway = self._gateway or f"{self._proto}://{self._host}"
self._is_v1 = True
try:
resp = retried_request(
resp = self._retried_request(
f"{_gateway}/info/oceanum_python/{__version__}",
headers=self._auth_headers,
retries=1,
verify=self._verify,
retries=5,
)
if resp.status_code == 200:
r = resp.json()
if "message" in r:
print(r["message"])
print("Using datamesh API version 1")
self._gateway = _gateway
self._is_v1 = True
return
elif resp.status_code == 404:
print("Using datamesh API version 0")
self._is_v1 = False
self._gateway = self._gateway or f"{self._proto}://gateway.{self._host}"
return
raise DatameshConnectError(
f"Failed to reach datamesh: {resp.status_code}-{resp.text}"
)
except:
_gateway = self._gateway or f"{self._proto}://gateway.{self._host}"
except Exception as e:
warnings.warn(f"Failed to reach datamesh gateway at {_gateway}: {e}")
warnings.warn("Assuming datamesh API version 1")
self._gateway = _gateway
self._is_v1 = False
print("Using datamesh API version 0")
try:
resp = retried_request(
f"https://datamesh-v1.oceanum.io/info/oceanum_python/{__version__}",
headers=self._auth_headers,
retries=1,
verify=self._verify,
)
if resp.status_code == 200:
r = resp.json()
if "message" in r:
print(r["message"])
except:
pass
return

def _validate_response(self, resp):
if resp.status_code >= 400:
@@ -207,11 +205,9 @@ def _validate_response(self, resp):
raise DatameshConnectError(msg)

def _metadata_request(self, datasource_id="", params={}):
resp = retried_request(
resp = self._retried_request(
f"{self._proto}://{self._host}/datasource/{datasource_id}",
params=params,
headers=self._auth_headers,
verify=self._verify,
)
if resp.status_code == 404:
raise DatameshConnectError(f"Datasource {datasource_id} not found")
@@ -224,44 +220,39 @@ def _metadata_write(self, datasource):
data = datasource.model_dump_json(by_alias=True, warnings=False).encode(
"utf-8", "ignore"
)
headers = {**self._auth_headers, "Content-Type": "application/json"}
headers = {"Content-Type": "application/json"}
if datasource._exists:
resp = retried_request(
resp = self._retried_request(
f"{self._proto}://{self._host}/datasource/{datasource.id}/",
method="PATCH",
data=data,
headers=headers,
verify=self._verify,
)

else:
resp = retried_request(
resp = self._retried_request(
f"{self._proto}://{self._host}/datasource/",
method="POST",
data=data,
headers=headers,
verify=self._verify,
)
self._validate_response(resp)
return resp

def _delete(self, datasource_id):
resp = retried_request(
resp = self._retried_request(
f"{self._gateway}/data/{datasource_id}",
method="DELETE",
headers=self._auth_headers,
verify=self._verify,
)
self._validate_response(resp)
return True

def _data_request(self, datasource_id, data_format="application/json", cache=False):
tmpfile = os.path.join(self._cachedir.name, datasource_id)
resp = retried_request(
resp = self._retried_request(
f"{self._gateway}/data/{datasource_id}",
headers={"Accept": data_format, **self._auth_headers},
headers={"Accept": data_format},
timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT),
verify=self._verify,
)
self._validate_response(resp)
with open(tmpfile, "wb") as f:
@@ -278,26 +269,24 @@ def _data_write(
):
# Connection timeout does not act in the same way in write and read contexts
# and using a short connection timeout in write contexts leads to closed connections
headers = {"Content-Type": data_format}
if overwrite:
resp = retried_request(
resp = self._retried_request(
f"{self._gateway}/data/{datasource_id}",
method="PUT",
data=data,
headers={"Content-Type": data_format, **self._auth_headers},
headers=headers,
timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT),
verify=self._verify,
)
else:
headers = {"Content-Type": data_format, **self._auth_headers}
if append:
headers["X-Append"] = str(append)
resp = retried_request(
resp = self._retried_request(
f"{self._gateway}/data/{datasource_id}",
method="PATCH",
data=data,
headers=headers,
timeout=(DATAMESH_WRITE_TIMEOUT, DATAMESH_WRITE_TIMEOUT),
verify=self._verify,
)
self._validate_response(resp)
return Datasource(**resp.json())
@@ -307,13 +296,12 @@ def _stage_request(self, query, session, cache=False):
query.model_dump_json(warnings=False).encode()
).hexdigest()

resp = retried_request(
resp = self._retried_request(
f"{self._gateway}/oceanql/stage/",
method="POST",
headers=session.add_header(self._auth_headers),
headers=session.header,
data=query.model_dump_json(warnings=False),
timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_STAGE_READ_TIMEOUT),
verify=self._verify,
)
if resp.status_code >= 400:
try:
@@ -370,14 +358,14 @@ def _query(self, query, use_dask=False, cache_timeout=0, retry=0):
if stage.container == Container.Dataset
else "application/parquet"
)
headers = {"Accept": transfer_format, **self._auth_headers}
resp = retried_request(
headers = {"Accept": transfer_format,
**session.header}
resp = self._retried_request(
f"{self._gateway}/oceanql/",
method="POST",
headers=headers,
data=query.model_dump_json(warnings=False),
timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_DOWNLOAD_TIMEOUT),
verify=self._verify,
)
if resp.status_code > 500:
if cache_timeout:
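
Taken together, the connection.py changes funnel every request through the new _retried_request wrapper, so the pooled HTTPSession (which now carries the auth headers) and the connection's TLS verify flag are applied in one place rather than at each call site. A minimal sketch of that pattern, assuming nothing beyond the standard requests library (DemoClient, base_url, and the /status endpoint are illustrative, not part of this PR):

import requests

class DemoClient:
    """Illustrative only: one pooled session, one request choke point."""

    def __init__(self, base_url, token, verify=True):
        self.base_url = base_url
        self.verify = verify
        self.session = requests.Session()  # reuses TCP connections via pooling
        self.session.headers.update({"Authorization": f"Bearer {token}"})

    def _request(self, path, method="GET", **kwargs):
        # Single choke point: auth headers and verify are applied once here;
        # call sites only pass what differs (method, body, extra headers).
        return self.session.request(
            method, f"{self.base_url}{path}", verify=self.verify, **kwargs
        )

    def status_ok(self):
        return self._request("/status").status_code == 200
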
src/oceanum/datamesh/utils.py (64 changes: 63 additions & 1 deletion)
@@ -1,5 +1,6 @@
from time import sleep
import requests
from requests.adapters import HTTPAdapter
import numpy as np
from .exceptions import DatameshConnectError
import os
@@ -52,6 +53,63 @@
)


class HTTPSession:
"""
    A requests.Session wrapper that is safe to use across forked processes.
    Parameters
    ----------
    pool_size : int, optional
        The size of the connection pool, by default DATAMESH_CONNECTION_POOL_SIZE or 100
    headers : dict, optional
        Default headers to include in each request, by default None
Methods
-------
session : requests.Session
Returns a requests.Session object that is safe to use in the current process
__getstate__ : dict
Returns the state of the object for pickling
__setstate__ : None
Restores the state of the object from pickling
"""

def __init__(self, pool_size=os.environ.get("DATAMESH_CONNECTION_POOL_SIZE", 100), headers=None):
self._session = None
self._pid = None
self._pool_size = int(pool_size)
self._headers = headers

def _create_session(self):
session = requests.Session()
adapter = HTTPAdapter(
pool_connections=self._pool_size,
pool_maxsize=self._pool_size
)
session.mount('https://', adapter)
session.mount('http://', adapter)
if self._headers:
session.headers.update(self._headers)
return session

@property
def session(self):
if self._session is None or self._pid != os.getpid():
self._pid = os.getpid()
self._session = self._create_session()
return self._session

def __getstate__(self):
state = self.__dict__.copy()
state["_session"] = None
state["_pid"] = None
return state

def __setstate__(self, state):
self.__dict__.update(state)

def request(self, method, url, *args, **kwargs):
return self.session.request(method, url, *args, **kwargs)


def retried_request(
url,
method="GET",
@@ -61,6 +119,7 @@ def retried_request(
retries=8,
timeout=(DATAMESH_CONNECT_TIMEOUT, DATAMESH_READ_TIMEOUT),
verify=True,
    http_session: HTTPSession | None = None,
):
"""
Retried request function with exponential backoff
Expand All @@ -79,6 +138,8 @@ def retried_request(
Number of retries, by default 8
    timeout : tuple(float, float), optional
        Request connect and read timeout in seconds, by default (3.05, 10)
    http_session : HTTPSession, optional
        Session object to use for the request, by default None

Returns
-------
Expand All @@ -91,10 +152,11 @@ def retried_request(
If request fails

"""
requester = http_session if http_session else requests
retried = 0
while retried < retries:
try:
resp = requests.request(
resp = requester.request(
method=method,
url=url,
data=data,
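
The HTTPSession class added above lazily rebuilds its underlying requests.Session whenever the current process ID differs from the one that created it, and __getstate__/__setstate__ drop the live session so the object pickles cleanly. A sketch of the behaviour this is designed to support, assuming oceanum is installed and workers are forked or receive pickled copies (the URL and header values are illustrative):

import multiprocessing as mp

from oceanum.datamesh.utils import HTTPSession

def fetch(args):
    http, url = args
    # First use in the worker: the stored PID no longer matches, so a fresh
    # requests.Session is built instead of reusing the parent's sockets.
    return http.request("GET", url).status_code

if __name__ == "__main__":
    http = HTTPSession(headers={"User-Agent": "datamesh-example"})
    http.request("GET", "https://datamesh.oceanum.io")  # session bound to parent PID
    jobs = [(http, "https://datamesh.oceanum.io")] * 4
    with mp.Pool(2) as pool:  # workers receive pickled copies with _session=None
        print(pool.map(fetch, jobs))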
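
With the new http_session argument, every retry attempt in retried_request reuses the pooled session; omitting it falls back to the module-level requests.request as before. A short usage sketch (the endpoint and token are placeholders, not a documented API):

from oceanum.datamesh.utils import HTTPSession, retried_request

# One pooled, fork-safe session shared across calls and retry attempts.
http = HTTPSession(headers={"Authorization": "Bearer <token>"})

resp = retried_request(
    "https://datamesh.oceanum.io/datasource/",  # placeholder endpoint
    retries=3,
    http_session=http,
)

# Legacy behaviour: no session, each attempt goes through requests.request.
resp = retried_request("https://datamesh.oceanum.io/datasource/")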