Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ module = [
"gliner",
"huggingface_hub",
"llama_cpp",
"llama_cpp.llama_chat_format",
"peft",
"torch",
"transformers",
Expand Down
4 changes: 2 additions & 2 deletions src/aibackends/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel

from aibackends.core.config import get_runtime, parse_model_text, parse_runtime_text
from aibackends.core.model_manager import ModelManager
from aibackends.core.model_manager import get_model_manager
from aibackends.core.types import RuntimeConfig
from aibackends.tasks.registry import create_task, get_task

Expand Down Expand Up @@ -61,7 +61,7 @@ def pull_model(
None, "--cache-dir", help="Optional Hugging Face cache directory override."
),
) -> None:
manager = ModelManager(cache_dir=cache_dir)
manager = get_model_manager(cache_dir)
location = manager.pull_model(
RuntimeConfig.model_validate(
{
Expand Down
58 changes: 49 additions & 9 deletions src/aibackends/core/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import platform
import threading
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
Expand All @@ -10,6 +11,9 @@
from aibackends.core.types import RuntimeConfig
from aibackends.model_support import get_model_support

_MANAGER_INSTANCES: dict[str | None, ModelManager] = {}
_MANAGER_LOCK = threading.Lock()


@dataclass(slots=True)
class HardwareProfile:
Expand All @@ -23,11 +27,27 @@ class ModelLocation:
local_path: str | None = None


def get_model_manager(cache_dir: str | None = None) -> ModelManager:
    """Return the shared ``ModelManager`` registered for *cache_dir*.

    Managers are kept in the module-level ``_MANAGER_INSTANCES`` mapping,
    keyed by the user-expanded cache directory string (or ``None`` when no
    cache directory is given). A manager is created and registered on the
    first request for a key; later requests for the same key get that same
    object back.
    """
    key = None if not cache_dir else str(Path(cache_dir).expanduser())

    # First look-up happens without the lock so repeat callers do not
    # contend on it once their manager is registered.
    instance = _MANAGER_INSTANCES.get(key)
    if instance is None:
        with _MANAGER_LOCK:
            # Look again while holding the lock: another thread may have
            # registered a manager between the check above and acquiring it.
            instance = _MANAGER_INSTANCES.get(key)
            if instance is None:
                # The manager receives the caller's original value; only the
                # registry key is normalized here.
                instance = ModelManager(cache_dir)
                _MANAGER_INSTANCES[key] = instance
    return instance


class ModelManager:
def __init__(self, cache_dir: str | None = None) -> None:
self.cache_dir = Path(cache_dir).expanduser() if cache_dir else None
if self.cache_dir is not None:
self.cache_dir.mkdir(parents=True, exist_ok=True)
self._gguf_selection_cache: dict[str, PurePosixPath] = {}

def detect_hardware(self) -> HardwareProfile:
machine = platform.machine().lower()
Expand Down Expand Up @@ -86,23 +106,38 @@ def pull_model(self, config: RuntimeConfig) -> ModelLocation:
local_dir = snapshot_download(repo_id=resolved, cache_dir=self._hf_cache_dir())
return ModelLocation(source=resolved, local_path=local_dir)

def _download_gguf_repo(self, repo_id: str) -> Path:
def _download_gguf_repo(
self,
repo_id: str,
preferred_quantization: str | None = None,
) -> Path:
try:
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError as exc:
raise RuntimeImportError(
"Install 'aibackends[llamacpp]' to enable GGUF model download and caching."
) from exc

candidates = self._list_gguf_files(repo_id, list_repo_files)
if not candidates:
raise ModelResolutionError(
f"No GGUF files found in repository: {repo_id}. "
"For llama.cpp, use a GGUF repo ID or a local GGUF file."
cache_key = f"{repo_id}::{preferred_quantization or ''}"
selected = self._gguf_selection_cache.get(cache_key)
if selected is None:
candidates = self._list_gguf_files(repo_id, list_repo_files)
if not candidates:
raise ModelResolutionError(
f"No GGUF files found in repository: {repo_id}. "
"For llama.cpp, use a GGUF repo ID or a local GGUF file."
)
selected = self._select_gguf_file(
candidates,
preferred_quantization=preferred_quantization,
)
self._gguf_selection_cache[cache_key] = selected

selected = self._select_gguf_file(candidates)
subfolder = None if selected.parent == PurePosixPath(".") else selected.parent.as_posix()
subfolder = (
None
if selected.parent == PurePosixPath(".")
else selected.parent.as_posix()
)
local_path = hf_hub_download(
repo_id=repo_id,
filename=selected.name,
Expand All @@ -128,8 +163,13 @@ def _list_gguf_files(
]
return sorted(candidates, key=lambda item: item.as_posix().lower())

def _select_gguf_file(self, candidates: list[PurePosixPath]) -> PurePosixPath:
def _select_gguf_file(
self,
candidates: list[PurePosixPath],
preferred_quantization: str | None = None,
) -> PurePosixPath:
preferred_order = [
*([] if preferred_quantization is None else [preferred_quantization]),
self.default_quantization(),
"Q4_K_M",
"Q5_K_M",
Expand Down
1 change: 1 addition & 0 deletions src/aibackends/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TransformerModelProfile:
chat_template: str | None = None
chat_template_path: str | None = None
generation_defaults: dict[str, Any] = field(default_factory=dict)
preferred_quantization: str | None = None

@property
def names(self) -> tuple[str, ...]:
Expand Down
4 changes: 2 additions & 2 deletions src/aibackends/core/runtimes/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel

from aibackends.core.exceptions import RuntimeImportError, RuntimeRequestError
from aibackends.core.model_manager import ModelLocation, ModelManager
from aibackends.core.model_manager import ModelLocation, get_model_manager
from aibackends.core.prompting import (
build_prompt_messages,
normalise_message_content,
Expand Down Expand Up @@ -185,7 +185,7 @@ def image_path_to_data_uri(path: str | Path) -> str:
class LlamaCppRuntime(BaseRuntime):
def __init__(self, config: RuntimeConfig) -> None:
super().__init__(config)
self.model_manager = ModelManager(cache_dir=config.cache_dir)
self.model_manager = get_model_manager(config.cache_dir)
self._client: Any | None = None
self._multimodal_client: Any | None = None

Expand Down
10 changes: 7 additions & 3 deletions src/aibackends/core/runtimes/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel

from aibackends.core.exceptions import RuntimeImportError
from aibackends.core.model_manager import ModelManager
from aibackends.core.model_manager import get_model_manager
from aibackends.core.model_registry import apply_transformer_model_profile
from aibackends.core.prompting import PromptRenderer
from aibackends.core.registry import RuntimeSpec
Expand All @@ -17,17 +17,21 @@ class TransformersRuntime(BaseRuntime):
def __init__(self, config: RuntimeConfig) -> None:
effective_config = apply_transformer_model_profile(config)
super().__init__(effective_config)
self.model_manager = ModelManager(cache_dir=effective_config.cache_dir)
self.model_manager = get_model_manager(effective_config.cache_dir)
self.prompt_renderer = PromptRenderer(effective_config)
self._tokenizer: Any | None = None
self._generator: Any | None = None
self._embed_model: Any | None = None
self._embed_tokenizer: Any | None = None
self._torch: Any | None = None
self._resolved_model_id: str | None = None

def _model_id(self) -> str:
if self._resolved_model_id is not None:
return self._resolved_model_id
location = self.model_manager.ensure_model(self.config)
return location.local_path or location.source
self._resolved_model_id = location.local_path or location.source
return self._resolved_model_id

def _load_generator(self):
if self._generator is not None and self._tokenizer is not None:
Expand Down
15 changes: 13 additions & 2 deletions src/aibackends/model_support/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from aibackends.core.model_manager import ModelLocation, ModelManager
from aibackends.core.model_registry import resolve_model_profile
from aibackends.core.registry import ModelSupportSpec
from aibackends.core.types import RuntimeConfig

Expand All @@ -10,8 +11,18 @@ def ensure_model(
config: RuntimeConfig,
resolved: str,
) -> ModelLocation:
del config
local_path = manager._download_gguf_repo(resolved)
profile = resolve_model_profile(config.model, runtime=config.runtime)
preferred_quantization = (
profile.preferred_quantization if profile is not None else None
)
if preferred_quantization is None:
preferred_quantization = config.extra_options.get(
"preferred_quantization"
)
local_path = manager._download_gguf_repo(
resolved,
preferred_quantization=preferred_quantization,
)
return ModelLocation(source=resolved, local_path=str(local_path))


Expand Down