diff --git a/half_json/diagnosis.py b/half_json/diagnosis.py
index bb7d7bf..683ca21 100644
--- a/half_json/diagnosis.py
+++ b/half_json/diagnosis.py
@@ -4,9 +4,44 @@ import json.decoder
 
 from dataclasses import dataclass
 from enum import Enum, auto
-from json.decoder import JSONDecoder, py_scanstring  # type: ignore[attr-defined]
-from json.scanner import py_make_scanner  # type: ignore[attr-defined]
-from typing import Any
+from typing import Callable, Protocol, TypeVar, Union, cast
+
+# Recursive alias for decoded JSON values: the containers nest JSONValue itself.
+# "JSONValue" is quoted inside the Union because the name is not bound yet.
+JSONValue = Union[
+    dict[str, "JSONValue"],
+    list["JSONValue"],
+    str,
+    int,
+    float,
+    bool,
+    None,
+]
+
+
+# Structural type for parse errors: parser name plus the msg/pos diagnose reads.
+class _ParseError(Protocol):
+    parser: str
+    msg: str
+    pos: int
+
+
+# Structural type for the JSONDecoder whose parse hooks _make_decoder rebinds.
+class _ExtendedDecoder(Protocol):
+    """JSONDecoder view exposing the attributes that _make_decoder assigns."""
+
+    parse_string: Callable[..., object]
+    parse_object: Callable[..., object]
+    parse_array: Callable[..., object]
+    scan_once: Callable[[str, int], tuple[JSONValue, int]]
+
+
+# Structural type for the json.decoder module so scanstring can be reassigned.
+class _JsonDecoderModule(Protocol):
+    """json.decoder module view: scanstring is patched, py_scanstring is read."""
+
+    scanstring: Callable[..., object]
+    py_scanstring: Callable[..., object]  # not in typeshed stubs
 
 
 class ErrorType(Enum):
@@ -55,14 +90,17 @@ class ParseContext:
     bracket_stack: tuple[str, ...]
     nextchar: str
     lastchar: str
-    partial_result: Any = None
+    partial_result: JSONValue = None
     consumed_end: int = 0
 
 
-def _record_parser_name(parser: Any) -> Any:
+_T = TypeVar("_T")
+
+
+def _record_parser_name(parser: Callable[..., _T]) -> Callable[..., _T]:
     """Decorator that attaches parser name to exceptions."""
 
-    def wrapper(*args: Any, **kwargs: Any) -> Any:
+    def wrapper(*args: object, **kwargs: object) -> _T:
         try:
             return parser(*args, **kwargs)
         except Exception as e:
@@ -74,23 +112,28 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
-def _make_decoder(*, strict: bool = True) -> JSONDecoder:
+def _make_decoder(*, strict: bool = True) -> _ExtendedDecoder:
     """Create a JSONDecoder with parser-name tracking.
 
     Note: json.decoder.scanstring must be patched at module level because
     JSONObject references it from module scope — no way to inject per-decoder.
     """
-    decoder = JSONDecoder(strict=strict)
-    decoder.parse_string = _record_parser_name(py_scanstring)  # type: ignore[attr-defined]
-    decoder.parse_object = _record_parser_name(decoder.parse_object)  # type: ignore[attr-defined]
-    decoder.parse_array = _record_parser_name(decoder.parse_array)  # type: ignore[attr-defined]
-    decoder.scan_once = py_make_scanner(decoder)  # type: ignore[attr-defined]
+    # Absent from the typeshed stubs (hence the ignores) but defined at runtime.
+    from json.decoder import py_scanstring  # type: ignore[attr-defined]
+    from json.scanner import py_make_scanner  # type: ignore[attr-defined]
+
+    decoder = cast(_ExtendedDecoder, json.JSONDecoder(strict=strict))
+    decoder.parse_string = _record_parser_name(py_scanstring)
+    decoder.parse_object = _record_parser_name(decoder.parse_object)
+    decoder.parse_array = _record_parser_name(decoder.parse_array)
+    decoder.scan_once = py_make_scanner(decoder)
     return decoder
 
 
 # Patch json.decoder.scanstring once so JSONObject uses our tracked version.
 # This is unavoidable: JSONObject hard-references the module-level scanstring.
-json.decoder.scanstring = _record_parser_name(py_scanstring)  # type: ignore[attr-defined]
+_decoder_module = cast(_JsonDecoderModule, json.decoder)
+_decoder_module.scanstring = _record_parser_name(_decoder_module.py_scanstring)
 _decoder_strict = _make_decoder(strict=True)
 _decoder_unstrict = _make_decoder(strict=False)
 
@@ -120,7 +163,8 @@ def diagnose(text: str, *, strict: bool = True) -> ParseContext | None:
     decoder = _decoder_strict if strict else _decoder_unstrict
 
     try:
-        obj, end = decoder.scan_once(text, 0)  # type: ignore[attr-defined]
+        obj: JSONValue
+        obj, end = decoder.scan_once(text, 0)
         if end == len(text):
             return None  # valid JSON
         # Partial parse — decoded something but there's leftover
@@ -148,15 +192,17 @@ def diagnose(text: str, *, strict: bool = True) -> ParseContext | None:
         )
     except ValueError as e:
         parser = e.__dict__.get("parser", "")
-        etype = _classify_error(parser, e.msg)  # type: ignore[attr-defined]
+        # cast() is a no-op at runtime; it only exposes msg/pos to the type checker.
+        err = cast(_ParseError, e)
+        etype = _classify_error(parser, err.msg)
         if etype is None:
             return None  # unknown error, treat as unfixable
-        pos = e.pos  # type: ignore[attr-defined]
+        pos = err.pos
         return ParseContext(
             input=text,
             error_type=etype,
             pos=pos,
-            message=e.msg,  # type: ignore[attr-defined]
+            message=err.msg,
             bracket_stack=build_bracket_stack(text, pos),
             nextchar=text[pos : pos + 1],
             lastchar=text[pos - 1 : pos],