Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ venvPath = "."
venv = ".venv"
stubPath = "stubs"
reportImplicitStringConcatenation = false
reportAny = false

[tool.pytest.ini_options]
markers = [
Expand Down
7 changes: 3 additions & 4 deletions src/resolver_athena_client/client/athena_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,8 @@ async def shutdown_worker(
) -> None:
"""Safely shutdown a single worker, handling mocks/errors."""
try:
shutdown_method = getattr(worker_batcher, "shutdown", None)
if shutdown_method and callable(shutdown_method):
if hasattr(worker_batcher, "shutdown"):
shutdown_method = worker_batcher.shutdown
shutdown_coro = shutdown_method()
# Only await if it's actually a coroutine (not a mock)
if asyncio.iscoroutine(shutdown_coro):
Expand All @@ -421,8 +421,7 @@ async def shutdown_worker(
self.logger.debug(
"Skipping non-coroutine shutdown method"
)
else:
self.logger.debug("Worker has no shutdown method")

except (AttributeError, TypeError):
# Worker doesn't have shutdown method or it's not callable
self.logger.debug("Worker shutdown failed, skipping")
Expand Down
20 changes: 14 additions & 6 deletions src/resolver_athena_client/client/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
from dataclasses import dataclass
from typing import override
from typing import cast, override

import grpc
import httpx
Expand Down Expand Up @@ -230,10 +230,18 @@ def _refresh_token(self) -> None:
)
_ = response.raise_for_status()

raw = response.json()
access_token: str = raw["access_token"]
expires_in: int = raw.get("expires_in", 3600) # Default 1 hour
token_type = raw.get("token_type", "Bearer")
raw = cast("dict[str, object]", response.json())
access_token = str(raw["access_token"])
expires_in_raw = raw.get("expires_in")
expires_in: int = (
int(cast("int", expires_in_raw))
if expires_in_raw is not None
else 3600
)
token_type_raw = raw.get("token_type")
token_type: str = (
str(token_type_raw) if token_type_raw is not None else "Bearer"
)
scheme: str = token_type.strip() if token_type else "Bearer"
current_time = time.time()
self._token_data = TokenData(
Expand All @@ -247,7 +255,7 @@ def _refresh_token(self) -> None:
except httpx.HTTPStatusError as e:
error_detail = ""
try:
error_data = e.response.json()
error_data = cast("dict[str, str]", e.response.json())
error_desc = error_data.get(
"error_description", error_data.get("error", "")
)
Expand Down
6 changes: 5 additions & 1 deletion src/resolver_athena_client/client/transformers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import enum
from typing import cast

import brotli
import cv2 as cv
Expand Down Expand Up @@ -72,7 +73,10 @@ def process_image() -> tuple[bytes, bool]:
err = "Failed to decode image data for resizing"
raise ValueError(err)

if img.shape[0] == EXPECTED_HEIGHT and img.shape[1] == EXPECTED_WIDTH:
shape = cast("tuple[int, int, int]", img.shape)
height: int = shape[0]
width: int = shape[1]
if height == EXPECTED_HEIGHT and width == EXPECTED_WIDTH:
resized_img = img
else:
resized_img = cv.resize(
Expand Down
37 changes: 23 additions & 14 deletions tests/client/test_athena_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import contextlib
from typing import cast
from unittest import mock

import pytest
Expand Down Expand Up @@ -64,9 +65,10 @@ async def test_classify_images_success(

# Setup mock classifier client
with mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls:
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)

# Create mock stream that returns our responses
mock_classify = MockAsyncIterator(test_responses)
Expand Down Expand Up @@ -121,9 +123,10 @@ async def test_client_context_manager_success(
) # Success response will have default empty global_error

with mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls:
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)

# Create mock stream that returns our response
mock_classify = MockAsyncIterator([init_response])
Expand Down Expand Up @@ -157,7 +160,8 @@ async def get_one_response() -> None:
await classify_task

# Verify channel was closed
mock_channel.close.assert_called_once()
close_mock = cast("mock.MagicMock", mock_channel.close)
close_mock.assert_called_once()


@pytest.mark.asyncio
Expand All @@ -176,9 +180,10 @@ async def test_client_context_manager_error(
)

with mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls:
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)

# Create mock stream that returns our error response
mock_classify = MockAsyncIterator([error_response])
Expand Down Expand Up @@ -225,9 +230,10 @@ async def test_client_transformers_disabled(
)

with mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls:
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
mock_classify = MockAsyncIterator([test_response])
mock_client.classify = mock_classify

Expand Down Expand Up @@ -277,9 +283,10 @@ async def test_client_transformers_enabled(
)

with mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls:
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
mock_classify = MockAsyncIterator([test_response])
mock_client.classify = mock_classify

Expand Down Expand Up @@ -337,13 +344,14 @@ async def test_client_num_workers_configuration(

with (
mock.patch(
"resolver_athena_client.client.athena_client.ClassifierServiceClient"
"resolver_athena_client.client.athena_client.ClassifierServiceClient",
spec=ClassifierServiceClient,
) as mock_client_cls,
mock.patch(
"resolver_athena_client.client.athena_client.WorkerBatcher"
) as mock_worker_batcher_cls,
):
mock_client = mock_client_cls.return_value
mock_client = cast("mock.MagicMock", mock_client_cls.return_value)
mock_classify = MockAsyncIterator([test_response])
mock_client.classify = mock_classify

Expand Down Expand Up @@ -391,4 +399,5 @@ async def test_client_close(

await client.close()

mock_channel.close.assert_called_once()
close_mock = cast("mock.MagicMock", mock_channel.close)
close_mock.assert_called_once()
Loading