Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ data/transformed/uniprot_functional_microbes/uniprot_kgx.zip

data/transformed/ontologies/.Rapp.history
mediadive_cache.sqlite
mediadive_bulk_cache.sqlite
data/raw/mediadive/
data/raw/.keep
kg_microbe/transform_utils/uniprot_human/tmp/relevant_files.tsv
kg_microbe/transform_utils/uniprot/tmp/relevant_files.tsv
Expand Down
162 changes: 137 additions & 25 deletions kg_microbe/utils/mediadive_bulk_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@

import json
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

import requests
import requests_cache
from tqdm import tqdm

# Default to 5 workers — still a large speedup over sequential but polite to
# MediaDive, which is a small academic REST API at DSMZ.
DEFAULT_MAX_WORKERS = 5

# Descriptive User-Agent so the API operator can identify traffic source.
USER_AGENT = "kg-microbe (Knowledge-Graph-Hub; https://github.com/Knowledge-Graph-Hub/kg-microbe)"

# Set up logging for API warnings (written to file, not stdout)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,34 +51,61 @@ def setup_cache(cache_name: str = "mediadive_bulk_cache"):
print(f"HTTP cache enabled: {cache_name}.sqlite")


def _make_session() -> requests.Session:
    """Build a requests Session whose User-Agent identifies kg-microbe traffic."""
    s = requests.Session()
    # Descriptive UA lets the MediaDive operators attribute this traffic.
    s.headers["User-Agent"] = USER_AGENT
    return s


def get_json_from_api(
url: str, retry_count: int = 3, retry_delay: float = 2.0, verbose: bool = False
url: str,
retry_count: int = 3,
retry_delay: float = 2.0,
verbose: bool = False,
session: Optional[requests.Session] = None,
) -> Dict:
"""
Get JSON data from MediaDive API with retry logic.

Respects Retry-After headers on 429 responses.

Args:
----
url: Full API URL to fetch
retry_count: Number of retries on failure
retry_delay: Delay in seconds between retries
retry_delay: Delay in seconds between retries (overridden by Retry-After on 429)
verbose: If True, log empty responses (useful for debugging)
session: Optional requests Session to reuse (uses module-level session if None)

Returns:
-------
Dictionary with API response data (empty dict on failure or empty response)

"""
requester = session or _make_session()
for attempt in range(retry_count):
try:
r = requests.get(url, timeout=30)
r = requester.get(url, timeout=30)
r.raise_for_status()
data_json = r.json()
result = data_json.get(DATA_KEY, {})
# Distinguish empty API response from failure (for debugging)
if not result and verbose:
print(f" Empty response from API: {url}")
return result
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code == 429:
wait = float(e.response.headers.get("Retry-After", retry_delay))
logger.debug(f"429 Too Many Requests — waiting {wait}s (URL: {url})")
time.sleep(wait)
continue
if attempt < retry_count - 1:
logger.debug(f"Retry {attempt + 1}/{retry_count} after error: {e} (URL: {url})")
time.sleep(retry_delay)
else:
logger.warning(f"Request failed after {retry_count} attempts: {e} (URL: {url})")
return {}
except requests.exceptions.RequestException as e:
if attempt < retry_count - 1:
logger.debug(f"Retry {attempt + 1}/{retry_count} after error: {e} (URL: {url})")
Expand Down Expand Up @@ -101,55 +137,117 @@ def load_basic_media_list(basic_file: str) -> List[Dict]:
return media_list


def download_detailed_media(media_list: List[Dict]) -> Dict[str, Dict]:
def _fetch_medium_detail(
    medium: Dict,
    session: requests.Session,
    rate_limiter: threading.Semaphore,
    retry_count: int,
    retry_delay: float,
) -> tuple[str, dict]:
    """Fetch detailed recipe for a single medium. Returns (medium_id, data)."""
    identifier = str(medium.get(ID_KEY))
    endpoint_url = f"{MEDIADIVE_REST_API_BASE_URL}{MEDIUM_ENDPOINT}{identifier}"
    # Hold a rate-limiter slot for the duration of the HTTP round-trip.
    with rate_limiter:
        payload = get_json_from_api(
            endpoint_url,
            retry_count=retry_count,
            retry_delay=retry_delay,
            session=session,
        )
    return identifier, payload


def _fetch_medium_strains(
    medium: Dict,
    session: requests.Session,
    rate_limiter: threading.Semaphore,
    retry_count: int,
    retry_delay: float,
) -> tuple[str, dict]:
    """Fetch strain associations for a single medium. Returns (medium_id, data)."""
    identifier = str(medium.get(ID_KEY))
    endpoint_url = f"{MEDIADIVE_REST_API_BASE_URL}{MEDIUM_STRAINS_ENDPOINT}{identifier}"
    # Hold a rate-limiter slot for the duration of the HTTP round-trip.
    with rate_limiter:
        payload = get_json_from_api(
            endpoint_url,
            retry_count=retry_count,
            retry_delay=retry_delay,
            session=session,
        )
    return identifier, payload


def download_detailed_media(
    media_list: List[Dict],
    max_workers: int = DEFAULT_MAX_WORKERS,
    retry_count: int = 3,
    retry_delay: float = 2.0,
    requests_per_second: float = 10.0,
) -> Dict[str, Dict]:
    """
    Download detailed recipe information for all media.

    Requests run on a thread pool bounded by ``max_workers`` (concurrency)
    and are additionally paced to at most ``requests_per_second`` sustained
    requests, so bursts against the MediaDive API are smoothed.

    Args:
    ----
        media_list: List of basic media records
        max_workers: Number of parallel download threads
        retry_count: Number of retries on request failure
        retry_delay: Seconds between retries (overridden by Retry-After on 429)
        requests_per_second: Maximum sustained request rate (<= 0 disables pacing)

    Returns:
    -------
        Dictionary mapping medium_id -> detailed_recipe_data

    """
    print(f"\nDownloading detailed recipes for {len(media_list)} media...")

    detailed_data: Dict[str, Dict] = {}
    session = _make_session()
    rate_limiter = threading.Semaphore(max_workers)

    # Enforce the advertised rate limit. Previously ``requests_per_second``
    # was accepted but never used (the semaphore alone is a no-op when its
    # count equals the pool size). Each request reserves the next allowed
    # start time under the lock, then sleeps outside the lock until its slot.
    min_interval = 1.0 / requests_per_second if requests_per_second > 0 else 0.0
    pace_lock = threading.Lock()
    next_slot = [0.0]  # single-element list so the closure can mutate it

    def fetch(medium: Dict) -> tuple[str, dict]:
        if min_interval:
            with pace_lock:
                now = time.monotonic()
                delay = next_slot[0] - now
                next_slot[0] = max(now, next_slot[0]) + min_interval
            if delay > 0:
                time.sleep(delay)
        return _fetch_medium_detail(medium, session, rate_limiter, retry_count, retry_delay)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for medium_id, data in tqdm(
            executor.map(fetch, media_list),
            total=len(media_list),
            desc="Downloading medium details",
        ):
            # Skip media whose API response was empty or failed after retries.
            if data:
                detailed_data[medium_id] = data

    print(f"Downloaded {len(detailed_data)} detailed medium recipes")
    return detailed_data


def download_medium_strains(media_list: List[Dict]) -> Dict[str, List]:
def download_medium_strains(
media_list: List[Dict],
max_workers: int = DEFAULT_MAX_WORKERS,
retry_count: int = 3,
retry_delay: float = 2.0,
requests_per_second: float = 10.0,
) -> Dict[str, List]:
"""
Download strain associations for all media.

Args:
----
media_list: List of basic media records
max_workers: Number of parallel download threads
retry_count: Number of retries on request failure
retry_delay: Seconds between retries (overridden by Retry-After on 429)
requests_per_second: Maximum sustained request rate (smooths bursts)

Returns:
-------
Dictionary mapping medium_id -> list_of_strain_data

"""
print(f"\nDownloading strain associations for {len(media_list)} media...")
strain_data = {}

for medium in tqdm(media_list, desc="Downloading medium-strain associations"):
medium_id = str(medium.get(ID_KEY))
url = MEDIADIVE_REST_API_BASE_URL + MEDIUM_STRAINS_ENDPOINT + medium_id
data = get_json_from_api(url)
if data:
strain_data[medium_id] = data
strain_data: Dict[str, List] = {}
session = _make_session()
rate_limiter = threading.Semaphore(max_workers)

def fetch(medium: Dict) -> tuple[str, dict]:
return _fetch_medium_strains(medium, session, rate_limiter, retry_count, retry_delay)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
for medium_id, data in tqdm(
executor.map(fetch, media_list),
total=len(media_list),
desc="Downloading medium-strain associations",
):
if data:
strain_data[medium_id] = data

# Count total strain associations, handling different data types
total_strains = 0
Expand Down Expand Up @@ -241,7 +339,13 @@ def save_json_file(data: Dict, filepath: Path, description: str):
print(f"Saved {description} to {filepath} ({file_size_mb:.2f} MB)")


def download_mediadive_bulk(basic_file: str, output_dir: str):
def download_mediadive_bulk(
basic_file: str,
output_dir: str,
max_workers: int = DEFAULT_MAX_WORKERS,
retry_count: int = 3,
retry_delay: float = 2.0,
):
"""
Download all MediaDive data in bulk.

Expand All @@ -251,6 +355,9 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):
----
basic_file: Path to mediadive.json (basic media list)
output_dir: Directory to save bulk data files
max_workers: Number of parallel download threads (default: 5, polite for small APIs)
retry_count: Number of retries on request failure
retry_delay: Seconds between retries (overridden by Retry-After on 429)

"""
output_path = Path(output_dir)
Expand All @@ -267,6 +374,7 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):
logger.setLevel(logging.DEBUG)
logger.propagate = False # Prevent propagation to root logger and stdout
print(f"API warnings will be logged to: {log_file}")
print(f"Using {max_workers} parallel workers")

# Set up HTTP caching
setup_cache()
Expand All @@ -277,12 +385,16 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):

# Step 2: Download detailed medium recipes
print("\n[2/5] Downloading detailed medium recipes...")
detailed_media = download_detailed_media(media_list)
detailed_media = download_detailed_media(
media_list, max_workers=max_workers, retry_count=retry_count, retry_delay=retry_delay
)
save_json_file(detailed_media, output_path / "media_detailed.json", "detailed media recipes")

# Step 3: Download medium-strain associations
print("\n[3/5] Downloading medium-strain associations...")
media_strains = download_medium_strains(media_list)
media_strains = download_medium_strains(
media_list, max_workers=max_workers, retry_count=retry_count, retry_delay=retry_delay
)
save_json_file(media_strains, output_path / "media_strains.json", "medium-strain associations")

# Step 4: Extract solutions from embedded structure
Expand Down
Loading
Loading