From 6609632bf6c94d53768a7d8f82bf9b8a8810c851 Mon Sep 17 00:00:00 2001 From: neurocis Date: Tue, 2 Jun 2026 21:54:48 -0700 Subject: [PATCH 1/2] feat(login): Implement secure post-login redirection Return users to their original requested page after login, with robust same-origin validation to prevent open redirects. Co-Author Agent Hero --- helpers/api.py | 33 +++++++++++++++++++++++- helpers/ui_server.py | 12 ++++++--- tests/test_http_auth_csrf.py | 49 ++++++++++++++++++++++++++++++++++++ webui/js/api.js | 20 +++++++++++---- webui/login.html | 5 +++- 5 files changed, 109 insertions(+), 10 deletions(-) diff --git a/helpers/api.py b/helpers/api.py index 9616528b7c..9ac198c163 100644 --- a/helpers/api.py +++ b/helpers/api.py @@ -1,6 +1,7 @@ from abc import abstractmethod import json import threading +from urllib.parse import urlsplit from functools import wraps from pathlib import Path from typing import Union, Dict, Any @@ -102,6 +103,36 @@ def use_context(self, ctxid: str, create_if_not_exists: bool = True): from helpers.network import is_loopback_address +def is_safe_next_url(value: str | None) -> bool: + """Return True when value is a safe same-origin redirect target.""" + if not value: + return False + if "\r" in value or "\n" in value: + return False + + parsed = urlsplit(value) + if parsed.scheme or parsed.netloc: + return False + + # Require an absolute path within this origin, but reject protocol-relative URLs. + return parsed.path.startswith("/") and not parsed.path.startswith("//") + + +def get_safe_next_url(value: str | None, fallback: str | None = None) -> str | None: + """Return value if it is a safe next URL, otherwise return a safe fallback.""" + if is_safe_next_url(value): + return value + if is_safe_next_url(fallback): + return fallback + return None + + +def get_current_request_next_url() -> str: + """Return the current request path/query as a safe relative redirect target.""" + next_url = request.full_path if request.query_string else request.path + return get_safe_next_url(next_url, url_for("serve_index")) or url_for("serve_index") + + def requires_api_key(f): @wraps(f) async def decorated(*args, **kwargs): @@ -142,7 +173,7 @@ async def decorated(*args, **kwargs): if not user_pass_hash: return await f(*args, **kwargs) if session.get("authentication") != user_pass_hash: - return redirect(url_for("login_handler")) + return redirect(url_for("login_handler", next=get_current_request_next_url())) return await f(*args, **kwargs) return decorated diff --git a/helpers/ui_server.py b/helpers/ui_server.py index 0756c6e0d6..6cf9b7bb98 100644 --- a/helpers/ui_server.py +++ b/helpers/ui_server.py @@ -26,7 +26,7 @@ import socketio # type: ignore[import-untyped] from helpers import dotenv, fasta2a_server, files, git, login, mcp_server, runtime -from helpers.api import register_api_route, requires_auth +from helpers.api import get_safe_next_url, register_api_route, requires_auth from helpers.extension import extensible from helpers.files import get_abs_path from helpers.print_style import PrintStyle @@ -200,19 +200,25 @@ def __init__(self, runtime_state: UiServerRuntime) -> None: @extensible async def login_handler(self): error = None + fallback_url = url_for("serve_index") + next_url = get_safe_next_url( + request.form.get("next") if request.method == "POST" else request.args.get("next"), + fallback_url, + ) + if request.method == "POST": user = dotenv.get_dotenv_value("AUTH_LOGIN") password = dotenv.get_dotenv_value("AUTH_PASSWORD") if request.form["username"] == user and request.form["password"] == password: session["authentication"] = login.get_credentials_hash() - return redirect(url_for("serve_index")) + return redirect(next_url or fallback_url) else: await asyncio.sleep(1) error = "Invalid Credentials. Please try again." login_page_content = files.read_file("webui/login.html") - return render_template_string(login_page_content, error=error) + return render_template_string(login_page_content, error=error, next=next_url) @extensible async def logout_handler(self): diff --git a/tests/test_http_auth_csrf.py b/tests/test_http_auth_csrf.py index 73d9750556..ddb01a4b37 100644 --- a/tests/test_http_auth_csrf.py +++ b/tests/test_http_auth_csrf.py @@ -122,3 +122,52 @@ async def secure(): _set_csrf_cookie(client, "csrf-4") response = client.get("/secure") assert response.status_code == 200 + + +def test_safe_next_url_accepts_plugin_page_path() -> None: + from helpers.api import get_safe_next_url, is_safe_next_url + + target = "/plugins/a0_voqualizer/webui/voqualizer.html" + assert is_safe_next_url(target) + assert get_safe_next_url(target, "/") == target + + +def test_safe_next_url_preserves_query_string() -> None: + from helpers.api import get_safe_next_url + + target = "/plugins/a0_voqualizer/webui/voqualizer.html?context=rlO1iMV7" + assert get_safe_next_url(target, "/") == target + + +def test_safe_next_url_rejects_external_and_protocol_relative_urls() -> None: + from helpers.api import get_safe_next_url, is_safe_next_url + + fallback = "/" + for value in [ + "https://evil.example/plugins/a0_voqualizer/webui/voqualizer.html", + "//evil.example/plugins/a0_voqualizer/webui/voqualizer.html", + "javascript:alert(1)", + "/safe\nLocation: https://evil.example", + ]: + assert not is_safe_next_url(value) + assert get_safe_next_url(value, fallback) == fallback + + +def test_auth_redirect_includes_original_path_and_query(monkeypatch) -> None: + from run_ui import requires_auth + + monkeypatch.setattr("helpers.login.get_credentials_hash", lambda: "hash") + + app = _make_app() + + @app.get("/plugins/a0_voqualizer/webui/voqualizer.html") + @requires_auth + async def voqualizer_page(): + return Response("ok", status=200) + + client = app.test_client() + response = client.get("/plugins/a0_voqualizer/webui/voqualizer.html?context=rlO1iMV7") + assert response.status_code == 302 + location = response.headers["Location"] + assert location.startswith("/login?next=") + assert "%2Fplugins%2Fa0_voqualizer%2Fwebui%2Fvoqualizer.html%3Fcontext%3DrlO1iMV7" in location diff --git a/webui/js/api.js b/webui/js/api.js index df41a4845d..26637686b2 100644 --- a/webui/js/api.js +++ b/webui/js/api.js @@ -266,10 +266,20 @@ function _normalizeApiUrl(url) { } function redirect(response) { - if (!(response.redirected && response.url.endsWith("/login"))) return false; + if (!response.redirected) return false; + const _redirectUrl = new URL(response.url); - if (_redirectUrl.origin === window.location.origin) { - window.location.href = response.url; + if ( + _redirectUrl.origin === window.location.origin && + _redirectUrl.pathname === "/login" + ) { + const currentUrl = `${window.location.pathname}${window.location.search}${window.location.hash}`; + if (currentUrl && currentUrl !== "/login") { + _redirectUrl.searchParams.set("next", currentUrl); + } + window.location.href = _redirectUrl.toString(); + return true; } - return true; -} \ No newline at end of file + + return false; +} diff --git a/webui/login.html b/webui/login.html index 74c7143885..679302192d 100644 --- a/webui/login.html +++ b/webui/login.html @@ -9,7 +9,10 @@