From a43d4b62244b5332bf42001c73b0dc9665ca7985 Mon Sep 17 00:00:00 2001 From: snus-kin Date: Mon, 23 Feb 2026 15:03:12 +0000 Subject: [PATCH 1/2] feat: enable reportAny --- pyproject.toml | 1 - .../client/athena_client.py | 7 +- src/resolver_athena_client/client/channel.py | 20 +- .../client/transformers/core.py | 6 +- tests/client/test_athena_client.py | 37 ++- tests/client/test_channel.py | 263 +++++++++++++----- tests/client/test_deployment_selector.py | 31 ++- tests/client/test_timeout_behavior.py | 12 +- tests/functional/conftest.py | 3 +- tests/functional/e2e/test_classify_single.py | 11 +- tests/functional/e2e/testcases/parser.py | 19 +- tests/test_classify_single.py | 60 ++-- 12 files changed, 341 insertions(+), 129 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a5e62f8..ad5b0cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ venvPath = "." venv = ".venv" stubPath = "stubs" reportImplicitStringConcatenation = false -reportAny = false [tool.pytest.ini_options] markers = [ diff --git a/src/resolver_athena_client/client/athena_client.py b/src/resolver_athena_client/client/athena_client.py index 2001128..e00c5f3 100644 --- a/src/resolver_athena_client/client/athena_client.py +++ b/src/resolver_athena_client/client/athena_client.py @@ -410,8 +410,8 @@ async def shutdown_worker( ) -> None: """Safely shutdown a single worker, handling mocks/errors.""" try: - shutdown_method = getattr(worker_batcher, "shutdown", None) - if shutdown_method and callable(shutdown_method): + if hasattr(worker_batcher, "shutdown"): + shutdown_method = worker_batcher.shutdown shutdown_coro = shutdown_method() # Only await if it's actually a coroutine (not a mock) if asyncio.iscoroutine(shutdown_coro): @@ -421,8 +421,7 @@ async def shutdown_worker( self.logger.debug( "Skipping non-coroutine shutdown method" ) - else: - self.logger.debug("Worker has no shutdown method") + except (AttributeError, TypeError): # Worker doesn't have shutdown method or it's not callable self.logger.debug("Worker shutdown failed, skipping") diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index 497341a..da9b927 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -5,7 +5,7 @@ import threading import time from dataclasses import dataclass -from typing import override +from typing import cast, override import grpc import httpx @@ -230,10 +230,18 @@ def _refresh_token(self) -> None: ) _ = response.raise_for_status() - raw = response.json() - access_token: str = raw["access_token"] - expires_in: int = raw.get("expires_in", 3600) # Default 1 hour - token_type = raw.get("token_type", "Bearer") + raw = cast("dict[str, object]", response.json()) + access_token = str(raw["access_token"]) + expires_in_raw = raw.get("expires_in") + expires_in: int = ( + int(cast("int", expires_in_raw)) + if expires_in_raw is not None + else 3600 + ) + token_type_raw = raw.get("token_type") + token_type: str = ( + str(token_type_raw) if token_type_raw is not None else "Bearer" + ) scheme: str = token_type.strip() if token_type else "Bearer" current_time = time.time() self._token_data = TokenData( @@ -247,7 +255,7 @@ def _refresh_token(self) -> None: except httpx.HTTPStatusError as e: error_detail = "" try: - error_data = e.response.json() + error_data = cast("dict[str, str]", e.response.json()) error_desc = error_data.get( "error_description", error_data.get("error", "") ) diff --git a/src/resolver_athena_client/client/transformers/core.py b/src/resolver_athena_client/client/transformers/core.py index 8041a28..d9803a1 100644 --- a/src/resolver_athena_client/client/transformers/core.py +++ b/src/resolver_athena_client/client/transformers/core.py @@ -7,6 +7,7 @@ import asyncio import enum +from typing import cast import brotli import cv2 as cv @@ -72,7 +73,10 @@ def process_image() -> tuple[bytes, bool]: err = "Failed to decode image data for resizing" raise ValueError(err) - if img.shape[0] == EXPECTED_HEIGHT and img.shape[1] == EXPECTED_WIDTH: + shape = cast("tuple[int, int, int]", img.shape) + height: int = shape[0] + width: int = shape[1] + if height == EXPECTED_HEIGHT and width == EXPECTED_WIDTH: resized_img = img else: resized_img = cv.resize( diff --git a/tests/client/test_athena_client.py b/tests/client/test_athena_client.py index 33f83f6..f5aaae9 100644 --- a/tests/client/test_athena_client.py +++ b/tests/client/test_athena_client.py @@ -2,6 +2,7 @@ import asyncio import contextlib +from typing import cast from unittest import mock import pytest @@ -64,9 +65,10 @@ async def test_classify_images_success( # Setup mock classifier client with mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create mock stream that returns our responses mock_classify = MockAsyncIterator(test_responses) @@ -121,9 +123,10 @@ async def test_client_context_manager_success( ) # Success response will have default empty global_error with mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create mock stream that returns our response mock_classify = MockAsyncIterator([init_response]) @@ -157,7 +160,8 @@ async def get_one_response() -> None: await classify_task # Verify channel was closed - mock_channel.close.assert_called_once() + close_mock = cast("mock.MagicMock", mock_channel.close) + close_mock.assert_called_once() @pytest.mark.asyncio @@ -176,9 +180,10 @@ async def test_client_context_manager_error( ) with mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create mock stream that returns our error response mock_classify = MockAsyncIterator([error_response]) @@ -225,9 +230,10 @@ async def test_client_transformers_disabled( ) with mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_classify = MockAsyncIterator([test_response]) mock_client.classify = mock_classify @@ -277,9 +283,10 @@ async def test_client_transformers_enabled( ) with mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_classify = MockAsyncIterator([test_response]) mock_client.classify = mock_classify @@ -337,13 +344,14 @@ async def test_client_num_workers_configuration( with ( mock.patch( - "resolver_athena_client.client.athena_client.ClassifierServiceClient" + "resolver_athena_client.client.athena_client.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls, mock.patch( "resolver_athena_client.client.athena_client.WorkerBatcher" ) as mock_worker_batcher_cls, ): - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_classify = MockAsyncIterator([test_response]) mock_client.classify = mock_classify @@ -391,4 +399,5 @@ async def test_client_close( await client.close() - mock_channel.close.assert_called_once() + close_mock = cast("mock.MagicMock", mock_channel.close) + close_mock.assert_called_once() diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 0a538fa..0cc4687 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -3,6 +3,7 @@ # Ideally we don't use private attributes in the tests but hard to test without import time +from typing import cast from unittest import mock import httpx @@ -48,7 +49,8 @@ async def test_create_channel_does_not_eagerly_fetch_token() -> None: _ = await create_channel_with_credentials(test_host, mock_helper) # Token should NOT be fetched at channel creation time - mock_helper.get_token.assert_not_called() + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.assert_not_called() class TestCredentialHelper: @@ -179,17 +181,30 @@ def test_get_token_success(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "access_token": "new_access_token", "expires_in": 3600, "token_type": "Bearer", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response token_data = helper.get_token() @@ -205,17 +220,30 @@ def test_get_token_respects_token_type(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "access_token": "some_token", "expires_in": 3600, "token_type": "DPoP", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response token_data = helper.get_token() @@ -228,16 +256,29 @@ def test_get_token_defaults_to_bearer(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "access_token": "some_token", "expires_in": 3600, } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response token_data = helper.get_token() @@ -260,19 +301,30 @@ def test_get_token_preserves_server_casing(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "access_token": "test_token", "expires_in": 3600, "token_type": server_type, } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = ( - mock_client.return_value.__enter__.return_value + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ ) - mock_response_obj.post.return_value = mock_response + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response token_data = helper.get_token() @@ -307,9 +359,10 @@ def test_refresh_token_http_error(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() + mock_response = mock.Mock(spec=httpx.Response) mock_response.status_code = 401 - mock_response.json.return_value = { + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "error": "invalid_client", "error_description": "Invalid client credentials", } @@ -321,8 +374,17 @@ def test_refresh_token_http_error(self) -> None: ) with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.side_effect = http_error + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.side_effect = http_error with pytest.raises( OAuthError, match="OAuth request failed with status 401" @@ -339,8 +401,17 @@ def test_refresh_token_request_error(self) -> None: request_error = httpx.RequestError("Connection failed") with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.side_effect = request_error + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.side_effect = request_error with pytest.raises( OAuthError, match="Failed to connect to OAuth server" @@ -354,15 +425,28 @@ def test_refresh_token_invalid_response(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { "invalid_field": "missing_access_token", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response with pytest.raises( OAuthError, match="Invalid OAuth response format" @@ -404,21 +488,37 @@ def test_get_token_refreshes_after_invalidation(self) -> None: ) helper.invalidate_token() - mock_response = mock.Mock() - mock_response.json.return_value = { - "access_token": "refreshed_token", + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { + "access_token": "fresh_token", "expires_in": 3600, - "token_type": "bearer", + "token_type": "Bearer", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response - token_data = helper.get_token() + # First call should fetch a new token + token = helper.get_token() - assert token_data.access_token == "refreshed_token" + # Verify network call was made + post_mock.assert_called_once() + assert token.access_token == "fresh_token" class TestAutoRefreshTokenAuthMetadataPlugin: @@ -427,7 +527,8 @@ class TestAutoRefreshTokenAuthMetadataPlugin: def test_plugin_passes_bearer_token_to_callback(self) -> None: """Plugin fetches token and passes Bearer metadata.""" mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = TokenData( + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.return_value = TokenData( access_token="test-bearer-token", expires_at=time.time() + 3600, scheme="Bearer", @@ -440,14 +541,16 @@ def test_plugin_passes_bearer_token_to_callback(self) -> None: plugin(mock_context, mock_callback) - mock_helper.get_token.assert_called_once() + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.assert_called_once() expected_metadata = (("authorization", "Bearer test-bearer-token"),) mock_callback.assert_called_once_with(expected_metadata, None) def test_plugin_respects_token_scheme(self) -> None: """Plugin uses the scheme from TokenData, not hardcoded Bearer.""" mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = TokenData( + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.return_value = TokenData( access_token="dpop-token", expires_at=time.time() + 3600, scheme="Dpop", @@ -467,7 +570,8 @@ def test_plugin_passes_oauth_error_to_callback(self) -> None: """Test that OAuthError is forwarded to the callback as an error.""" mock_helper = mock.Mock(spec=CredentialHelper) oauth_error = OAuthError("token acquisition failed") - mock_helper.get_token.side_effect = oauth_error + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.side_effect = oauth_error plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) mock_callback = mock.Mock() @@ -481,7 +585,8 @@ def test_plugin_catches_unexpected_exceptions(self) -> None: """Non-OAuthError exceptions are forwarded to callback.""" mock_helper = mock.Mock(spec=CredentialHelper) runtime_error = RuntimeError("unexpected failure") - mock_helper.get_token.side_effect = runtime_error + get_token_mock = cast("mock.MagicMock", mock_helper.get_token) + get_token_mock.side_effect = runtime_error plugin = _AutoRefreshTokenAuthMetadataPlugin(mock_helper) mock_callback = mock.Mock() @@ -582,7 +687,8 @@ def test_background_refresh_does_not_start_if_already_running(self) -> None: # Mock a running refresh thread mock_thread = mock.Mock() - mock_thread.is_alive.return_value = True + is_alive_mock = cast("mock.MagicMock", mock_thread.is_alive) + is_alive_mock.return_value = True helper._refresh_thread = mock_thread with mock.patch("threading.Thread") as mock_thread_class: @@ -603,7 +709,8 @@ def test_background_refresh_starts_new_thread_if_none_exists(self) -> None: helper._start_background_refresh() # Should have started the thread - mock_thread.start.assert_called_once() + start_mock = cast("mock.MagicMock", mock_thread.start) + start_mock.assert_called_once() def test_background_refresh_silently_handles_errors(self) -> None: """Test that background refresh silently ignores errors.""" @@ -657,24 +764,37 @@ def test_get_token_blocks_for_expired_token(self) -> None: issued_at=time.time() - 3700, ) - mock_response = mock.Mock() - mock_response.json.return_value = { - "access_token": "new_token", + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { + "access_token": "new_access_token", "expires_in": 3600, - "token_type": "bearer", + "token_type": "Bearer", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response token_data = helper.get_token() # Should have refreshed and returned new token - assert token_data.access_token == "new_token" + assert token_data.access_token == "new_access_token" # Should have called the OAuth endpoint - mock_response_obj.post.assert_called_once() + post_mock.assert_called_once() def test_refresh_token_sets_issued_at(self) -> None: """Test that _refresh_token sets the issued_at timestamp.""" @@ -683,18 +803,31 @@ def test_refresh_token_sets_issued_at(self) -> None: client_secret="test_client_secret", ) - mock_response = mock.Mock() - mock_response.json.return_value = { - "access_token": "new_token", + mock_response = mock.Mock(spec=httpx.Response) + json_mock = cast("mock.MagicMock", mock_response.json) + json_mock.return_value = { + "access_token": "test_token", "expires_in": 3600, - "token_type": "bearer", + "token_type": "Bearer", } - mock_response.raise_for_status.return_value = None + raise_for_status_mock = cast( + "mock.MagicMock", mock_response.raise_for_status + ) + raise_for_status_mock.return_value = None before_time = time.time() with mock.patch("httpx.Client") as mock_client: - mock_response_obj = mock_client.return_value.__enter__.return_value - mock_response_obj.post.return_value = mock_response + mock_client_instance = cast( + "mock.MagicMock", mock_client.return_value + ) + mock_context = cast( + "mock.MagicMock", mock_client_instance.__enter__ + ) + mock_response_obj = cast( + "mock.MagicMock", mock_context.return_value + ) + post_mock = cast("mock.MagicMock", mock_response_obj.post) + post_mock.return_value = mock_response _ = helper.get_token() diff --git a/tests/client/test_deployment_selector.py b/tests/client/test_deployment_selector.py index 84d3dfa..7e5c6a2 100644 --- a/tests/client/test_deployment_selector.py +++ b/tests/client/test_deployment_selector.py @@ -1,5 +1,6 @@ """Tests for deployment selector.""" +from typing import cast from unittest import mock import pytest @@ -39,9 +40,10 @@ async def test_list_deployments_success(mock_channel: mock.Mock) -> None: # Setup mock with mock.patch( - "resolver_athena_client.client.deployment_selector.ClassifierServiceClient" + "resolver_athena_client.client.deployment_selector.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_client.list_deployments = mock.AsyncMock( return_value=expected_response ) @@ -61,7 +63,10 @@ async def test_list_deployments_success(mock_channel: mock.Mock) -> None: # Verify client interaction mock_client_cls.assert_called_once_with(mock_channel) - mock_client.list_deployments.assert_called_once() + list_deployments_mock = cast( + "mock.AsyncMock", mock_client.list_deployments + ) + list_deployments_mock.assert_called_once() @pytest.mark.asyncio @@ -72,9 +77,10 @@ async def test_list_deployments_empty(mock_channel: mock.Mock) -> None: # Setup mock with mock.patch( - "resolver_athena_client.client.deployment_selector.ClassifierServiceClient" + "resolver_athena_client.client.deployment_selector.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_client.list_deployments = mock.AsyncMock( return_value=empty_response ) @@ -89,7 +95,10 @@ async def test_list_deployments_empty(mock_channel: mock.Mock) -> None: # Verify client interaction mock_client_cls.assert_called_once_with(mock_channel) - mock_client.list_deployments.assert_called_once() + list_deployments_mock = cast( + "mock.AsyncMock", mock_client.list_deployments + ) + list_deployments_mock.assert_called_once() @pytest.mark.asyncio @@ -97,9 +106,10 @@ async def test_list_deployments_client_error(mock_channel: mock.Mock) -> None: """Test deployment listing when client raises an error.""" # Setup mock to raise error with mock.patch( - "resolver_athena_client.client.deployment_selector.ClassifierServiceClient" + "resolver_athena_client.client.deployment_selector.ClassifierServiceClient", + spec=ClassifierServiceClient, ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_client.list_deployments = mock.AsyncMock( side_effect=RuntimeError("Test error") ) @@ -113,4 +123,7 @@ async def test_list_deployments_client_error(mock_channel: mock.Mock) -> None: # Verify client interaction mock_client_cls.assert_called_once_with(mock_channel) - mock_client.list_deployments.assert_called_once() + list_deployments_mock = cast( + "mock.AsyncMock", mock_client.list_deployments + ) + list_deployments_mock.assert_called_once() diff --git a/tests/client/test_timeout_behavior.py b/tests/client/test_timeout_behavior.py index 34fa837..8feeddd 100644 --- a/tests/client/test_timeout_behavior.py +++ b/tests/client/test_timeout_behavior.py @@ -4,7 +4,7 @@ import contextlib import time from collections.abc import AsyncIterator -from typing import Self, TypeVar, override +from typing import Self, TypeVar, cast, override from unittest import mock import grpc @@ -82,7 +82,7 @@ async def test_timeout_behavior() -> None: with mock.patch( "resolver_athena_client.client.athena_client.ClassifierServiceClient" ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create an iterator that will wait longer than the timeout mock_classify = SlowMockAsyncIterator(test_responses, delay=0.02) mock_client.classify = mock.AsyncMock(return_value=mock_classify) @@ -138,7 +138,7 @@ async def test_infinite_timeout() -> None: with mock.patch( "resolver_athena_client.client.athena_client.ClassifierServiceClient" ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create an iterator with significant delays mock_classify = SlowMockAsyncIterator(test_responses, delay=0.02) mock_client.classify = mock.AsyncMock(return_value=mock_classify) @@ -202,7 +202,7 @@ async def test_custom_timeout() -> None: with mock.patch( "resolver_athena_client.client.athena_client.ClassifierServiceClient" ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Create an iterator with delays between responses mock_classify = SlowMockAsyncIterator(test_responses, delay=0.015) mock_client.classify = mock.AsyncMock(return_value=mock_classify) @@ -258,7 +258,7 @@ async def test_timeout_with_errors() -> None: with mock.patch( "resolver_athena_client.client.athena_client.ClassifierServiceClient" ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) # Use our custom MockGrpcError error = MockGrpcError( code=StatusCode.INTERNAL, @@ -315,7 +315,7 @@ async def test_timeout_with_cancellation() -> None: with mock.patch( "resolver_athena_client.client.athena_client.ClassifierServiceClient" ) as mock_client_cls: - mock_client = mock_client_cls.return_value + mock_client = cast("mock.MagicMock", mock_client_cls.return_value) mock_classify = SlowMockAsyncIterator(test_responses, delay=0.01) mock_client.classify = mock.AsyncMock(return_value=mock_classify) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 4d58a8c..fc9d782 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -1,5 +1,6 @@ import os import uuid +from typing import cast import cv2 as cv import numpy as np @@ -111,7 +112,7 @@ def valid_formatted_image( Images are cached to disk to avoid regenerating on every test run. """ - image_format = request.param + image_format = cast("str", request.param) image_dir = tmp_path_factory.mktemp("images") base_image = _create_base_test_image_opencv(EXPECTED_WIDTH, EXPECTED_HEIGHT) diff --git a/tests/functional/e2e/test_classify_single.py b/tests/functional/e2e/test_classify_single.py index 06d26a4..446239f 100644 --- a/tests/functional/e2e/test_classify_single.py +++ b/tests/functional/e2e/test_classify_single.py @@ -19,10 +19,19 @@ FP_ERROR_TOLERANCE = 1e-4 +def _get_test_case_id(tc: AthenaTestCase) -> str: + """Get the test case ID for pytest parametrize.""" + return tc.id + + @pytest.mark.asyncio @pytest.mark.functional @pytest.mark.e2e -@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda tc: tc.id) +@pytest.mark.parametrize( + "test_case", + TEST_CASES, + ids=_get_test_case_id, +) async def test_classify_single( athena_options: AthenaOptions, credential_helper: CredentialHelper, diff --git a/tests/functional/e2e/testcases/parser.py b/tests/functional/e2e/testcases/parser.py index 80d7901..8df93b3 100644 --- a/tests/functional/e2e/testcases/parser.py +++ b/tests/functional/e2e/testcases/parser.py @@ -1,5 +1,6 @@ import json from pathlib import Path +from typing import cast # Path to the shared testcases directory in athena-protobufs _REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent @@ -27,12 +28,20 @@ def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]: with Path.open( Path(TESTCASES_DIR / dirname / "expected_outputs.json"), ) as f: - test_cases = json.load(f) + test_cases = cast( + "dict[str, list[str] | list[list[str | list[float]]]]", json.load(f) + ) + classification_labels = cast( + "list[str]", test_cases["classification_labels"] + ) + images = cast("list[list[str | list[float]]]", test_cases["images"]) return [ AthenaTestCase( - str(Path(TESTCASES_DIR / dirname / "images" / item[0])), - item[1], - test_cases["classification_labels"], + str( + Path(TESTCASES_DIR / dirname / "images" / cast("str", item[0])) + ), + cast("list[float]", item[1]), + classification_labels, ) - for item in test_cases["images"] + for item in images ] diff --git a/tests/test_classify_single.py b/tests/test_classify_single.py index d517ca3..3d0eff0 100644 --- a/tests/test_classify_single.py +++ b/tests/test_classify_single.py @@ -1,6 +1,7 @@ """Tests for the classify_single method in AthenaClient.""" import uuid +from typing import cast from unittest.mock import AsyncMock, Mock import cv2 as cv @@ -93,8 +94,11 @@ async def test_classify_single_success( assert not result.HasField("error") # Verify the call was made with correct parameters - athena_client.classifier.classify_single.assert_called_once() - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + classify_single_mock.assert_called_once() + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert isinstance(call_args, ClassificationInput) assert call_args.affiliate == "test-affiliate" @@ -130,7 +134,10 @@ async def test_classify_single_with_correlation_id( _ = await athena_client.classify_single(copied_image_data) # Verify correlation ID was used - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.correlation_id == custom_correlation_id @@ -152,13 +159,15 @@ async def test_classify_single_auto_correlation_id( _ = await athena_client.classify_single(sample_image_data) # Verify a correlation ID was generated - call_args = athena_client.classifier.classify_single.call_args[0][0] - assert call_args.correlation_id is not None - assert len(call_args.correlation_id) > 0 + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) + correlation_id = call_args.correlation_id + assert correlation_id is not None + assert len(correlation_id) > 0 # Should be a valid UUID format - _ = uuid.UUID( - call_args.correlation_id - ) # This will raise if not a valid UUID + _ = uuid.UUID(correlation_id) # This will raise if not a valid UUID @pytest.mark.asyncio @@ -182,7 +191,10 @@ async def test_classify_single_with_compression( _ = await athena_client.classify_single(sample_image_data) # Verify compression settings were applied - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.encoding == RequestEncoding.REQUEST_ENCODING_BROTLI # Data should be compressed - check it's the same as the modified image data assert call_args.data == sample_image_data.data @@ -218,7 +230,10 @@ async def test_classify_single_error_handling( _ = await athena_client.classify_single(valid_image_data) # Verify resizing was processed (encoding should be uncompressed) - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.encoding == RequestEncoding.REQUEST_ENCODING_UNCOMPRESSED @@ -284,7 +299,8 @@ async def test_classify_single_timeout_parameter( _ = await athena_client.classify_single(sample_image_data) # Verify timeout was passed - call_kwargs = athena_client.classifier.classify_single.call_args[1] + classify_single_mock = athena_client.classifier.classify_single + call_kwargs = cast("dict[str, float]", classify_single_mock.call_args[1]) expected_timeout = 30.0 # From the fixture options assert call_kwargs["timeout"] == expected_timeout @@ -314,7 +330,10 @@ async def test_classify_single_multiple_hashes( _ = await athena_client.classify_single(image_data) # Verify all hashes were included - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) expected_hash_count = 3 # Original + 2 transformations assert len(call_args.hashes) == expected_hash_count for hash_obj in call_args.hashes: @@ -405,7 +424,10 @@ async def test_classify_single_with_png_format( _ = await athena_client.classify_single(image_data) # Verify PNG format was detected and sent - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.format == ImageFormat.IMAGE_FORMAT_PNG @@ -431,7 +453,10 @@ async def test_classify_single_with_jpeg_format( _ = await athena_client.classify_single(image_data) # Verify JPEG format was detected and sent - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.format == ImageFormat.IMAGE_FORMAT_JPEG @@ -456,6 +481,9 @@ async def test_classify_single_never_sends_unspecified( _ = await athena_client.classify_single(unknown_data) # Verify UNSPECIFIED was converted to RAW_UINT8 - call_args = athena_client.classifier.classify_single.call_args[0][0] + classify_single_mock = athena_client.classifier.classify_single + call_args = cast( + "ClassificationInput", classify_single_mock.call_args[0][0] + ) assert call_args.format != ImageFormat.IMAGE_FORMAT_UNSPECIFIED assert call_args.format == ImageFormat.IMAGE_FORMAT_RAW_UINT8_BGR From eac65cbb9328d94119419adb06a4a15bceeb0b5a Mon Sep 17 00:00:00 2001 From: Thomas Carroll Date: Wed, 25 Feb 2026 17:30:17 +0000 Subject: [PATCH 2/2] refactor: use TypedDict --- tests/functional/e2e/testcases/parser.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/functional/e2e/testcases/parser.py b/tests/functional/e2e/testcases/parser.py index 8df93b3..5e72fd8 100644 --- a/tests/functional/e2e/testcases/parser.py +++ b/tests/functional/e2e/testcases/parser.py @@ -1,12 +1,17 @@ import json from pathlib import Path -from typing import cast +from typing import TypedDict, cast # Path to the shared testcases directory in athena-protobufs _REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent TESTCASES_DIR = _REPO_ROOT / "athena-protobufs" / "testcases" +class TestCases(TypedDict): + classification_labels: list[str] + images: list[list[str | list[float]]] + + class AthenaTestCase: def __init__( self, @@ -28,13 +33,9 @@ def load_test_cases(dirname: str = "benign_model") -> list[AthenaTestCase]: with Path.open( Path(TESTCASES_DIR / dirname / "expected_outputs.json"), ) as f: - test_cases = cast( - "dict[str, list[str] | list[list[str | list[float]]]]", json.load(f) - ) - classification_labels = cast( - "list[str]", test_cases["classification_labels"] - ) - images = cast("list[list[str | list[float]]]", test_cases["images"]) + test_cases: TestCases = cast("TestCases", json.load(f)) + classification_labels = test_cases["classification_labels"] + images = test_cases["images"] return [ AthenaTestCase( str(