Skip to content
3 changes: 2 additions & 1 deletion app/templates/classifier_page.html
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,8 @@ <h2>{{ heading }}</h2>
aria-expanded="false" aria-haspopup="listbox"
class="border border-gray-400 rounded text-left px-1 py-1 mb-1">
{% for v in versions %}
<option value="{{ v }}" {% if v==url_params.version %}selected{% endif %}>{{ v }}</option>
<option value="{{ v }}" {% if v==(url_params.version if url_params.version else first_version)
%}selected{% endif %}>{{ v }}</option>
{% endfor %}
</select>
</div>
Expand Down
105 changes: 75 additions & 30 deletions app/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,50 @@ def get_sample_cache_profile(sample_path: str):
return STATIC_MEDIA


def build_classification_results_context(
    request: Request,
    classifier_type: str,
    query: str,
    version: str,
    top_k: int,
) -> dict[str, object]:
    """Assemble the template context for rendering classification results.

    The query has runs of whitespace collapsed to single spaces and is
    stripped; the classifier type is trimmed and upper-cased before being
    handed to ``perform_classification``.  A blank query short-circuits to
    an empty context so no backend call (and no timing) happens.

    Args:
        request: Incoming request; its ``app.state`` supplies the embed,
            Qdrant, quantization-cache, and optional ZeroEntropy clients.
        classifier_type: Classifier identifier (normalized to upper case).
        query: Raw user query text.
        version: Classifier version key to run against.
        top_k: Number of results to request.

    Returns:
        Mapping with ``query``, ``results_for_query``, ``base_url``,
        ``tooltip``, and ``total_request_time`` keys, ready to merge into
        the results template context.
    """
    cleaned_query = re.sub(r"\s+", " ", query).strip()
    canonical_type = classifier_type.strip().upper()

    # Start from the "no results" shape; the empty-query path returns it as-is.
    context: dict[str, object] = {
        "query": cleaned_query,
        "results_for_query": [],
        "base_url": "",
        "tooltip": "",
        "total_request_time": 0,
    }
    if not cleaned_query:
        return context

    started = time.perf_counter()
    # Both caches are optional on app.state; fall back to safe defaults.
    cache = getattr(request.app.state, "collection_quantization_cache", {})
    reranker = getattr(request.app.state, "zclient", None)
    outcome = perform_classification(
        embed_client=request.app.state.embed_client,
        qdrant_client=request.app.state.qdrant_client,
        query=cleaned_query,
        classifier_type=canonical_type,
        version=version,
        top_k=top_k,
        quantization_cache=cache,
        zclient=reranker,
    )
    elapsed = time.perf_counter() - started

    version_config = outcome["version_config"]
    context.update(
        results_for_query=outcome["results"],
        base_url=version_config.get("base_url", ""),
        tooltip=version_config.get("tooltip", ""),
        total_request_time=elapsed,
    )
    return context


def get_related_mapping_products(product: MappingProduct) -> list[MappingProduct]:
related_products: list[MappingProduct] = []
for slug in product.related_slugs:
Expand Down Expand Up @@ -390,27 +434,14 @@ async def get_classification_fragment(
add_quota_headers(response, usage_status)
return response

start_total_time = time.perf_counter()

try:
quantization_cache = getattr(
request.app.state, "collection_quantization_cache", {}
)
# Use shared classification service with ZeroEntropy reranking
zclient = getattr(request.app.state, "zclient", None)
result = perform_classification(
embed_client=request.app.state.embed_client,
qdrant_client=request.app.state.qdrant_client,
query=normalized_description,
results_context = build_classification_results_context(
request=request,
classifier_type=upper_type,
query=normalized_description,
version=version,
top_k=top_k,
quantization_cache=quantization_cache,
zclient=zclient,
)

classification_results = result["results"]

except HTTPException:
# Let HTTP exceptions propagate
raise
Expand All @@ -420,9 +451,6 @@ async def get_classification_fragment(
status_code=500, detail=f"Error processing request: {str(e)}"
)

end_total_time = time.perf_counter()
total_request_time = end_total_time - start_total_time

# Calculate dynamic page title for OOB swap
page_title = None
if push_url:
Expand All @@ -433,11 +461,7 @@ async def get_classification_fragment(
request,
"results.html",
{
"query": normalized_description,
"results_for_query": classification_results,
"base_url": result["version_config"].get("base_url", ""),
"tooltip": result["version_config"].get("tooltip", ""),
"total_request_time": total_request_time,
**results_context,
"page_title": page_title,
},
)
Expand Down Expand Up @@ -536,6 +560,7 @@ async def show_classifier_page_with_query(
# Get first version for default handling
versions_list = list(config["versions"].keys())
first_version = versions_list[0] if versions_list else ""
validated_version = version if version in config["versions"] else first_version

# Initialize results data structure
results_data = {
Expand All @@ -549,20 +574,36 @@ async def show_classifier_page_with_query(
raw_example = config["example"].strip()
display_example = raw_example if raw_example else ""

# Determine if we should trigger a search on load
# This is true if we have a URL search query OR if we're falling back to the example
trigger_search_on_load = False
# Track whether the base page is seeded from the configured example text.
# This must always be initialized before the template context is built.
default_example_prefill = False

trigger_search_on_load = False

if decoded_search_query:
trigger_search_on_load = True
else:
# If no search query (base URL), use example query
example_query = raw_example
if example_query:
results_data["query"] = example_query
trigger_search_on_load = True
default_example_prefill = True
results_data["query"] = example_query
try:
results_data = build_classification_results_context(
request=request,
classifier_type=effective_classifier_type,
query=example_query,
version=validated_version,
top_k=top_k,
)
except Exception as e:
logger.warning(
"SSR fallback for '%s' page classification due to %s: %s",
effective_classifier_type,
type(e).__name__,
e,
)
trigger_search_on_load = True

today = datetime.now()
current_year = today.year
Expand All @@ -580,7 +621,11 @@ async def show_classifier_page_with_query(
"example": display_example,
"url_params": {
"search": decoded_search_query,
"version": version if version and version != first_version else "",
"version": (
validated_version
if validated_version and validated_version != first_version
else ""
),
"top_k": top_k,
},
"default_example_prefill": default_example_prefill,
Comment thread
DmitryMatv marked this conversation as resolved.
Expand Down
66 changes: 27 additions & 39 deletions tests/test_request_validation_and_metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from html import escape
from pathlib import Path
from unittest.mock import patch

import httpx
from fastapi import FastAPI, Request
Expand Down Expand Up @@ -51,6 +52,11 @@ def _build_web_test_app() -> FastAPI:
"/static", StaticFiles(directory=BASE_DIR / "app" / "static"), name="static"
)
app.include_router(router)
app.state.embed_client = object()
app.state.qdrant_client = object()
app.state.collection_quantization_cache = {}
app.state.zclient = None
app.state.redis_client = object()
return app


Expand Down Expand Up @@ -236,58 +242,40 @@ async def test_search_page_gates_initial_loader_on_auth_ready(self):
response = await client.get(f"/{self.classifier_type}/industrial_pump")

self.assertEqual(response.status_code, 200)
self.assertIn('id="classifier-form"', response.text)
self.assertIn('data-autoload-enabled="true"', response.text)
self.assertIn('data-initial-query-present="true"', response.text)
self.assertIn('data-initial-track-usage="true"', response.text)
self.assertIn('data-default-example-prefill="false"', response.text)
self.assertIn(
f'data-default-top-k="{get_default_top_k(self.classifier_type)}"',
response.text,
)
self.assertIn(
f'data-default-version="{next(iter(self.config["versions"]))}"',
response.text,
)
self.assertIn(">industrial pump</textarea>", response.text)
self.assertNotIn('id="initial-results-loader"', response.text)
self.assertNotIn('name="push_url"', response.text)
self.assertNotIn('name="track_usage"', response.text)
self.assertNotIn("data-auth-gated=", response.text)

async def test_base_page_keeps_immediate_unmetered_initial_loader(self):
async def test_search_page_normalizes_invalid_version_in_rendered_form(self):
transport = httpx.ASGITransport(app=self.app)
default_version = next(iter(self.config["versions"]))

async with httpx.AsyncClient(
transport=transport,
base_url="http://testserver",
) as client:
response = await client.get(f"/{self.classifier_type}/")
response = await client.get(
f"/{self.classifier_type}/industrial_pump",
params={"version": "missing-version"},
)

self.assertEqual(response.status_code, 200)
self.assertIn('id="classifier-form"', response.text)
self.assertIn('data-autoload-enabled="true"', response.text)
self.assertIn('data-initial-query-present="false"', response.text)
self.assertIn('data-initial-track-usage="false"', response.text)
self.assertIn('data-default-example-prefill="true"', response.text)
self.assertIn(
f'data-default-top-k="{get_default_top_k(self.classifier_type)}"',
f'<option value="{escape(default_version)}" selected>',
response.text,
)
self.assertIn(
f'data-default-version="{next(iter(self.config["versions"]))}"',
response.text,
)
example_query = self.config["example"]
self.assertIn(f">{escape(example_query)}</textarea>", response.text)
self.assertIn(
f'placeholder="{escape(example_query)}"',
self.assertNotIn(
'<option value="missing-version" selected>',
response.text,
)
self.assertNotIn('id="initial-results-loader"', response.text)
self.assertNotIn('name="push_url"', response.text)
self.assertNotIn('name="track_usage"', response.text)

async def test_base_page_uses_textarea_value_as_initial_autoload_source(self):
@patch("app.web.perform_classification", side_effect=RuntimeError("no backend"))
async def test_base_page_falls_back_to_initial_loader_when_ssr_cannot_run(
self,
perform_classification_mock,
):
transport = httpx.ASGITransport(app=self.app)

async with httpx.AsyncClient(
Expand All @@ -297,12 +285,12 @@ async def test_base_page_uses_textarea_value_as_initial_autoload_source(self):
response = await client.get(f"/{self.classifier_type}/")

self.assertEqual(response.status_code, 200)
example_query = self.config["example"]
self.assertIn(
f'id="product_description_area" placeholder="{escape(example_query)}"',
response.text,
)
self.assertIn(f">{escape(example_query)}</textarea>", response.text)
self.assertIn('data-autoload-enabled="true"', response.text)
self.assertIn('data-default-example-prefill="true"', response.text)
self.assertIn('data-initial-track-usage="false"', response.text)
self.assertNotIn('id="initial-results-loader"', response.text)
self.assertNotIn("data-auth-gated=", response.text)
perform_classification_mock.assert_called_once()

def _expected_generic_name(self) -> str:
if self.classifier_type == "UNSPSC":
Expand Down
Loading