Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion mailsweep/ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def chat(self, messages: list[dict], system: str = "") -> str:
"ollama": {"base_url": "http://localhost:11434/v1", "api_key": "", "model": "qwen3:8b"},
"lm-studio": {"base_url": "http://localhost:1234/v1", "api_key": "", "model": ""},
"openai": {"base_url": "https://api.openai.com/v1", "api_key": "", "model": "gpt-5.2"},
"anthropic": {"base_url": "", "api_key": "", "model": "claude-sonnet-4-6"},
"anthropic": {"base_url": "https://api.anthropic.com/v1", "api_key": "", "model": "claude-sonnet-4-6"},
"custom": {"base_url": "http://localhost:8080/v1", "api_key": "", "model": ""},
}

Expand All @@ -131,11 +131,88 @@ def chat(self, messages: list[dict], system: str = "") -> str:
}


def normalize_url(url: str) -> str:
"""Normalize a user-supplied URL or host:port to a full base URL ending in /v1.

Accepts:
- ``host:port`` → ``http://host:port/v1``
- ``http://host:port`` → ``http://host:port/v1``
- ``http://host:port/v1`` → unchanged
"""
s = url.strip().rstrip("/")
if not s:
return s
if not s.startswith(("http://", "https://")):
return f"http://{s}/v1"
# Has a scheme — ensure it ends with /v1
if not s.endswith("/v1"):
return f"{s}/v1"
return s


def detect_and_fetch(base_url: str, api_key: str = "") -> tuple[str, list[str]]:
"""Probe *base_url* to detect the API type and return (label, models).

Detection order:
1. Ollama native — GET /api/tags
2. OpenAI-compat — GET /v1/models
Returns ("unknown", []) when neither probe succeeds.
"""
full_url = normalize_url(base_url).rstrip("/")
if not full_url:
return ("unknown", [])

# Derive base host (strip trailing /v1 if present)
base_host = full_url[:-3] if full_url.endswith("/v1") else full_url

# 1. Ollama native API
try:
req = urllib.request.Request(f"{base_host}/api/tags", method="GET")
with urllib.request.urlopen(req, timeout=5) as resp:
body = json.loads(resp.read().decode("utf-8"))
if "models" in body:
models = sorted(m["name"] for m in body["models"])
return ("Ollama (native API)", models)
except Exception:
logger.debug("detect_and_fetch: Ollama probe failed for %s", base_host, exc_info=True)

# 2. OpenAI-compatible — /v1/models
try:
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
req = urllib.request.Request(f"{full_url}/models", headers=headers, method="GET")
with urllib.request.urlopen(req, timeout=5) as resp:
body = json.loads(resp.read().decode("utf-8"))
if "data" in body:
models = sorted(item["id"] for item in body["data"])
return ("OpenAI-compatible", models)
except Exception:
logger.debug("detect_and_fetch: OpenAI-compat probe failed for %s", full_url, exc_info=True)

return ("unknown", [])


def fetch_anthropic_models(api_key: str) -> list[str]:
"""Fetch available models from the Anthropic API."""
url = "https://api.anthropic.com/v1/models"
headers = {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
req = urllib.request.Request(url, headers=headers, method="GET")
try:
with urllib.request.urlopen(req, timeout=5) as resp:
body = json.loads(resp.read().decode("utf-8"))
return sorted(item["id"] for item in body.get("data", []))
except Exception:
logger.debug("fetch_anthropic_models failed", exc_info=True)
return []


def fetch_model_list(base_url: str, api_key: str = "") -> list[str]:
"""GET {base_url}/models and return sorted list of model IDs.

Returns an empty list on any error (connection refused, timeout, etc.).
"""
base_url = normalize_url(base_url)
url = f"{base_url.rstrip('/')}/models"
headers: dict[str, str] = {}
if api_key:
Expand All @@ -159,6 +236,7 @@ def create_provider(
raise LLMError("Anthropic provider requires an API key.")
return AnthropicProvider(api_key=api_key, model=model)
# Everything else uses OpenAI-compatible endpoint
base_url = normalize_url(base_url)
if not base_url:
raise LLMError("Base URL is required for non-Anthropic providers.")
return OpenAICompatProvider(base_url=base_url, api_key=api_key, model=model)
35 changes: 23 additions & 12 deletions mailsweep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@

AI_PROVIDER: str = "ollama" # ollama | openai | anthropic | custom
AI_BASE_URL: str = "http://localhost:11434/v1"
AI_API_KEY: str = ""
AI_API_KEY: str = "" # key for the currently active provider
AI_API_KEYS: dict[str, str] = {} # per-provider keys, keyed by provider name
AI_MODEL: str = "llama3.2"


Expand All @@ -74,20 +75,23 @@ def save_settings() -> None:
except Exception as exc:
logger.warning("Could not save settings: %s", exc)

# Store AI API key in system keyring
if AI_API_KEY:
try:
from mailsweep.utils.keyring_store import set_password
# Store AI API keys in system keyring (one entry per provider)
try:
from mailsweep.utils.keyring_store import set_password
for _provider, _key in AI_API_KEYS.items():
set_password("ai_api_key", f"mailsweep_ai_{_provider}", _key)
# Legacy single-key entry kept for backwards compat
if AI_API_KEY:
set_password("ai_api_key", "mailsweep_ai", AI_API_KEY)
except Exception as exc:
logger.warning("Could not save AI API key to keyring: %s", exc)
except Exception as exc:
logger.warning("Could not save AI API key to keyring: %s", exc)


def load_settings() -> None:
"""Load persisted settings from disk, falling back to defaults."""
global SCAN_BATCH_SIZE, MESSAGE_TABLE_MAX_ROWS, DEFAULT_SAVE_DIR
global UNLABELLED_MODE, SKIP_ALL_MAIL, BLOCKLIST_AUTO_MOVE, BLOCKLIST_USE_COMMUNITY, BLOCKLIST_COMMUNITY_URL
global AI_PROVIDER, AI_BASE_URL, AI_API_KEY, AI_MODEL
global AI_PROVIDER, AI_BASE_URL, AI_API_KEY, AI_API_KEYS, AI_MODEL
if not SETTINGS_PATH.exists():
return
try:
Expand All @@ -111,12 +115,19 @@ def load_settings() -> None:
except Exception as exc:
logger.warning("Could not load settings: %s", exc)

# Load AI API key from keyring
# Load AI API keys from keyring (per-provider entries)
try:
from mailsweep.utils.keyring_store import get_password
key = get_password("ai_api_key", "mailsweep_ai")
if key:
AI_API_KEY = key
for _provider in ("ollama", "lm-studio", "openai", "anthropic", "custom"):
_key = get_password("ai_api_key", f"mailsweep_ai_{_provider}")
if _key:
AI_API_KEYS[_provider] = _key
# Legacy fallback: single key entry
if not AI_API_KEYS.get(AI_PROVIDER):
_key = get_password("ai_api_key", "mailsweep_ai")
if _key:
AI_API_KEYS[AI_PROVIDER] = _key
AI_API_KEY = AI_API_KEYS.get(AI_PROVIDER, "")
except Exception as exc:
logger.debug("Could not load AI API key from keyring: %s", exc)

Expand Down
102 changes: 80 additions & 22 deletions mailsweep/ui/ai_dock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from typing import NamedTuple

from mailsweep.ai.providers import PROVIDER_MODELS, PROVIDER_PRESETS, fetch_model_list
from mailsweep.ai.providers import PROVIDER_MODELS, PROVIDER_PRESETS, detect_and_fetch, fetch_anthropic_models, fetch_model_list, normalize_url

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,7 +90,8 @@ def _build_ui(self) -> None:
url_key_row = QHBoxLayout()
url_key_row.addWidget(QLabel("URL:"))
self._url_edit = QLineEdit()
self._url_edit.setPlaceholderText("http://localhost:11434/v1")
self._url_edit.setPlaceholderText("http://host:port/v1 or host:port")
self._url_edit.editingFinished.connect(self._on_url_editing_finished)
url_key_row.addWidget(self._url_edit)

self._key_label = QLabel("Key:")
Expand All @@ -102,6 +103,10 @@ def _build_ui(self) -> None:
url_key_row.addWidget(self._key_edit)
layout.addLayout(url_key_row)

self._api_type_label = QLabel("<i>API type: OpenAI-compatible (auto-detected)</i>")
self._api_type_label.setVisible(False)
layout.addWidget(self._api_type_label)

# ── Chat history ─────────────────────────────────────────────────────
self._chat_browser = QTextBrowser()
self._chat_browser.setOpenExternalLinks(False)
Expand Down Expand Up @@ -151,76 +156,128 @@ def _build_ui(self) -> None:
def _load_from_config(self) -> None:
"""Load AI settings from config module."""
import mailsweep.config as cfg
# Block signals so setting the provider index doesn't prematurely
# trigger a fetch with the wrong (preset) URL.
self._provider_combo.blockSignals(True)
idx = self._provider_combo.findText(cfg.AI_PROVIDER)
if idx >= 0:
self._provider_combo.setCurrentIndex(idx)
self._provider_combo.blockSignals(False)

self._active_provider = cfg.AI_PROVIDER
self._provider_models: dict[str, str] = {cfg.AI_PROVIDER: cfg.AI_MODEL}
self._provider_urls: dict[str, str] = {cfg.AI_PROVIDER: cfg.AI_BASE_URL}
self._url_edit.setText(cfg.AI_BASE_URL)
self._key_edit.setText(cfg.AI_API_KEY)
self._populate_model_combo(cfg.AI_PROVIDER)
self._model_combo.setCurrentText(cfg.AI_MODEL)
if cfg.AI_API_KEY:
self._key_edit.setText(cfg.AI_API_KEY)
self._update_key_visibility()
# Fetch with the correct saved URL
self._on_refresh_models()

def _on_url_editing_finished(self) -> None:
self._on_refresh_models()

def _on_provider_changed(self, provider: str) -> None:
import mailsweep.config as cfg
# Save current key, model, and URL for the outgoing provider
cfg.AI_API_KEYS[self._active_provider] = self._key_edit.text().strip()
self._provider_models[self._active_provider] = self._model_combo.currentText()
self._provider_urls[self._active_provider] = self._url_edit.text().strip()
self._active_provider = provider
preset = PROVIDER_PRESETS.get(provider, {})
if preset.get("base_url"):
self._url_edit.setText(preset["base_url"])
saved_url = self._provider_urls.get(provider)
self._url_edit.setText(saved_url if saved_url is not None else preset.get("base_url", ""))
self._key_edit.setText(cfg.AI_API_KEYS.get(provider, ""))
self._populate_model_combo(provider)
if preset.get("model"):
self._model_combo.setCurrentText(preset["model"])
# Restore saved model for this provider, falling back to preset default
saved_model = self._provider_models.get(provider) or preset.get("model", "")
if saved_model:
self._model_combo.setCurrentText(saved_model)
self._update_key_visibility()
self._on_refresh_models()

def _populate_model_combo(self, provider: str) -> None:
self._model_combo.clear()
# Skip static list for providers where we auto-fetch real models
if provider in ("ollama", "lm-studio", "custom"):
return
models = PROVIDER_MODELS.get(provider, [])
if models:
self._model_combo.addItems(models)

def _update_key_visibility(self) -> None:
hide = self._provider_combo.currentText() in ("ollama", "lm-studio")
self._key_label.setVisible(not hide)
self._key_edit.setVisible(not hide)
provider = self._provider_combo.currentText()
hide_key = provider in ("ollama", "lm-studio")
self._key_label.setVisible(not hide_key)
self._key_edit.setVisible(not hide_key)
self._api_type_label.setVisible(provider == "custom")

def _on_refresh_models(self) -> None:
base_url = self._url_edit.text().strip()
base_url = normalize_url(self._url_edit.text().strip())
api_key = self._key_edit.text().strip()
if not base_url:
provider = self._provider_combo.currentText()
if not base_url and provider != "anthropic":
return
self._refresh_btn.setEnabled(False)
self._refresh_btn.setText("…")
self._pending_model = self._model_combo.currentText()
self._model_combo.clear()

use_detect = provider == "custom"
if use_detect:
self._api_type_label.setText("<i>API type: detecting…</i>")

class _Fetcher(QObject):
done = pyqtSignal(list)
def __init__(self, url, key):
done = pyqtSignal(str, list)
def __init__(self, url, key, prov, detect):
super().__init__()
self._url = url
self._key = key
self._prov = prov
self._detect = detect
def run(self):
self.done.emit(fetch_model_list(self._url, self._key))
if self._prov == "anthropic":
models = fetch_anthropic_models(self._key)
self.done.emit("", models)
elif self._detect:
api_type, models = detect_and_fetch(self._url, self._key)
self.done.emit(api_type, models)
else:
models = fetch_model_list(self._url, self._key)
self.done.emit("", models)

thread = QThread(self)
worker = _Fetcher(base_url, api_key)
worker = _Fetcher(base_url, api_key, provider, use_detect)
worker.moveToThread(thread)
thread.started.connect(worker.run)
worker.done.connect(lambda models: self._on_models_fetched(models))
worker.done.connect(self._on_models_fetched)
worker.done.connect(thread.quit)
worker.done.connect(worker.deleteLater)
thread.finished.connect(thread.deleteLater)
self._refresh_thread = thread
self._refresh_worker = worker
thread.start()

def _on_models_fetched(self, models: list[str]) -> None:
def _on_models_fetched(self, api_type: str, models: list[str]) -> None:
self._refresh_btn.setEnabled(True)
self._refresh_btn.setText("Refresh")
if api_type:
self._api_type_label.setText(f"<i>API type: {api_type} (auto-detected)</i>")
to_restore = getattr(self, "_pending_model", "") or ""
if not models:
# Fall back to static list so the combo isn't left blank
static = PROVIDER_MODELS.get(self._provider_combo.currentText(), [])
if static:
self._model_combo.addItems(static)
if to_restore:
self._model_combo.setCurrentText(to_restore)
return
current = self._model_combo.currentText()
existing = {self._model_combo.itemText(i) for i in range(self._model_combo.count())}
for m in models:
if m not in existing:
self._model_combo.addItem(m)
self._model_combo.setCurrentText(current)
self._model_combo.setCurrentText(to_restore or models[0])

def set_context(self, context: str) -> None:
"""Set the DB context string (called by main_window)."""
Expand All @@ -239,7 +296,7 @@ def _send_message(self, text: str) -> None:
return

provider = self._provider_combo.currentText()
base_url = self._url_edit.text().strip()
base_url = normalize_url(self._url_edit.text().strip())
api_key = self._key_edit.text().strip()
model = self._model_combo.currentText().strip()

Expand Down Expand Up @@ -290,6 +347,7 @@ def _save_to_config(self, provider: str, base_url: str, api_key: str, model: str
cfg.AI_PROVIDER = provider
cfg.AI_BASE_URL = base_url
cfg.AI_API_KEY = api_key
cfg.AI_API_KEYS[provider] = api_key
cfg.AI_MODEL = model
cfg.save_settings()

Expand Down
Loading