Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion helpers/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import abstractmethod
import json
import threading
from urllib.parse import urlsplit, unquote
from functools import wraps
from pathlib import Path
from typing import Union, Dict, Any
Expand Down Expand Up @@ -102,6 +103,44 @@ def use_context(self, ctxid: str, create_if_not_exists: bool = True):
from helpers.network import is_loopback_address


def is_safe_next_url(value: str | None) -> bool:
"""Return True when value is a safe same-origin redirect target."""
if not value:
return False
if "\r" in value or "\n" in value:
return False
# Reject raw backslashes (browsers normalize `/\host` to `//host` -> external).
if "\\" in value:
return False

# Decode percent-escapes so encoded backslashes (e.g. `%5C`) are caught too.
decoded = unquote(value)
if "\\" in decoded:
return False

parsed = urlsplit(decoded)
if parsed.scheme or parsed.netloc:
return False

# Require an absolute path within this origin, but reject protocol-relative URLs.
return parsed.path.startswith("/") and not parsed.path.startswith("//")


def get_safe_next_url(value: str | None, fallback: str | None = None) -> str | None:
"""Return value if it is a safe next URL, otherwise return a safe fallback."""
if is_safe_next_url(value):
return value
if is_safe_next_url(fallback):
return fallback
return None


def get_current_request_next_url() -> str:
"""Return the current request path/query as a safe relative redirect target."""
next_url = request.full_path if request.query_string else request.path
return get_safe_next_url(next_url, url_for("serve_index")) or url_for("serve_index")


def requires_api_key(f):
@wraps(f)
async def decorated(*args, **kwargs):
Expand Down Expand Up @@ -142,7 +181,7 @@ async def decorated(*args, **kwargs):
if not user_pass_hash:
return await f(*args, **kwargs)
if session.get("authentication") != user_pass_hash:
return redirect(url_for("login_handler"))
return redirect(url_for("login_handler", next=get_current_request_next_url()))
return await f(*args, **kwargs)

return decorated
Expand Down
12 changes: 9 additions & 3 deletions helpers/ui_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import socketio # type: ignore[import-untyped]

from helpers import dotenv, fasta2a_server, files, git, login, mcp_server, runtime
from helpers.api import register_api_route, requires_auth
from helpers.api import get_safe_next_url, register_api_route, requires_auth
from helpers.extension import extensible
from helpers.files import get_abs_path
from helpers.print_style import PrintStyle
Expand Down Expand Up @@ -200,19 +200,25 @@ def __init__(self, runtime_state: UiServerRuntime) -> None:
@extensible
async def login_handler(self):
error = None
fallback_url = url_for("serve_index")
next_url = get_safe_next_url(
request.form.get("next") if request.method == "POST" else request.args.get("next"),
fallback_url,
)

if request.method == "POST":
user = dotenv.get_dotenv_value("AUTH_LOGIN")
password = dotenv.get_dotenv_value("AUTH_PASSWORD")

if request.form["username"] == user and request.form["password"] == password:
session["authentication"] = login.get_credentials_hash()
return redirect(url_for("serve_index"))
return redirect(next_url or fallback_url)
else:
await asyncio.sleep(1)
error = "Invalid Credentials. Please try again."

login_page_content = files.read_file("webui/login.html")
return render_template_string(login_page_content, error=error)
return render_template_string(login_page_content, error=error, next=next_url)

@extensible
async def logout_handler(self):
Expand Down
69 changes: 69 additions & 0 deletions tests/test_http_auth_csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,72 @@ async def secure():
_set_csrf_cookie(client, "csrf-4")
response = client.get("/secure")
assert response.status_code == 200


def test_safe_next_url_accepts_plugin_page_path() -> None:
from helpers.api import get_safe_next_url, is_safe_next_url

target = "/plugins/a0_voqualizer/webui/voqualizer.html"
assert is_safe_next_url(target)
assert get_safe_next_url(target, "/") == target


def test_safe_next_url_preserves_query_string() -> None:
from helpers.api import get_safe_next_url

target = "/plugins/a0_voqualizer/webui/voqualizer.html?context=rlO1iMV7"
assert get_safe_next_url(target, "/") == target


def test_safe_next_url_rejects_external_and_protocol_relative_urls() -> None:
from helpers.api import get_safe_next_url, is_safe_next_url

fallback = "/"
for value in [
"https://evil.example/plugins/a0_voqualizer/webui/voqualizer.html",
"//evil.example/plugins/a0_voqualizer/webui/voqualizer.html",
"javascript:alert(1)",
"/safe\nLocation: https://evil.example",
]:
assert not is_safe_next_url(value)
assert get_safe_next_url(value, fallback) == fallback


def test_auth_redirect_includes_original_path_and_query(monkeypatch) -> None:
from run_ui import requires_auth

monkeypatch.setattr("helpers.login.get_credentials_hash", lambda: "hash")

app = _make_app()

@app.get("/plugins/a0_voqualizer/webui/voqualizer.html")
@requires_auth
async def voqualizer_page():
return Response("ok", status=200)

client = app.test_client()
response = client.get("/plugins/a0_voqualizer/webui/voqualizer.html?context=rlO1iMV7")
assert response.status_code == 302
location = response.headers["Location"]
assert location.startswith("/login?next=")
assert "%2Fplugins%2Fa0_voqualizer%2Fwebui%2Fvoqualizer.html%3Fcontext%3DrlO1iMV7" in location


def test_is_safe_next_url_rejects_backslash_open_redirects() -> None:
from helpers.api import is_safe_next_url

# Raw backslash forms
assert is_safe_next_url("/\\evil.example") is False
assert is_safe_next_url("\\/evil.example") is False
assert is_safe_next_url("/path\\evil") is False

# Percent-encoded backslash forms
assert is_safe_next_url("/%5Cevil.example") is False
assert is_safe_next_url("%5C/evil.example") is False
assert is_safe_next_url("/%5cevil.example") is False # lowercase hex

# Mixed / double-encoded edge
assert is_safe_next_url("/path/%5Cevil") is False

# Sanity: a legitimate relative path still passes
assert is_safe_next_url("/plugins/a0_voqualizer/webui/voqualizer.html") is True
20 changes: 15 additions & 5 deletions webui/js/api.js
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,20 @@ function _normalizeApiUrl(url) {
}

function redirect(response) {
if (!(response.redirected && response.url.endsWith("/login"))) return false;
if (!response.redirected) return false;

const _redirectUrl = new URL(response.url);
if (_redirectUrl.origin === window.location.origin) {
window.location.href = response.url;
if (
_redirectUrl.origin === window.location.origin &&
_redirectUrl.pathname === "/login"
) {
const currentUrl = `${window.location.pathname}${window.location.search}${window.location.hash}`;
if (currentUrl && currentUrl !== "/login") {
_redirectUrl.searchParams.set("next", currentUrl);
}
window.location.href = _redirectUrl.toString();
return true;
}
return true;
}

return false;
}
5 changes: 4 additions & 1 deletion webui/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
</head>
<body>
<div class="login-container">
<form class="login-form" method="POST" action="/login">
<form class="login-form" method="POST" action="/login{% if next %}?next={{ next|urlencode }}{% endif %}">
{% if next %}
<input type="hidden" name="next" value="{{ next }}">
{% endif %}
<img src="/public/splash.jpg" alt="Agent Zero Logo" class="logo">
<h2>Agent Zero</h2>
<div class="input-group">
Expand Down