Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 63 additions & 17 deletions half_json/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,44 @@
import json.decoder
from dataclasses import dataclass
from enum import Enum, auto
from json.decoder import JSONDecoder, py_scanstring # type: ignore[attr-defined]
from json.scanner import py_make_scanner # type: ignore[attr-defined]
from typing import Any
from typing import Callable, Protocol, TypeVar, Union, cast

# Recursive alias for a decoded JSON value: a dict with string keys, a list,
# or one of the scalar types below.
# The self-references are quoted so they are kept as forward references
# rather than being looked up while the alias itself is still being built.
JSONValue = Union[
    dict[str, "JSONValue"],
    list["JSONValue"],
    str,
    int,
    float,
    bool,
    None,
]


# Protocol for exception with parser attribute
class _ParseError(Protocol):
parser: str
msg: str
pos: int


# Protocol for extended JSONDecoder with monkey-patched attributes
class _ExtendedDecoder(Protocol):
    """JSONDecoder with additional attributes set by _make_decoder.

    Exists only for the type checker: ``_make_decoder`` casts a real
    ``json.JSONDecoder`` to this protocol so the attribute assignments below
    type-check without ``type: ignore`` comments.
    """

    # String, object and array parsers; ``_make_decoder`` replaces each one
    # with a ``_record_parser_name``-wrapped callable.
    parse_string: Callable[..., object]
    parse_object: Callable[..., object]
    parse_array: Callable[..., object]
    # Scanner built by ``py_make_scanner``; called as ``scan_once(text, idx)``
    # and returns the decoded value together with the index where it stopped.
    scan_once: Callable[[str, int], tuple[JSONValue, int]]


# Protocol for json.decoder module with scanstring attribute
class _JsonDecoderModule(Protocol):
"""json.decoder module with scanstring attribute for patching."""

scanstring: Callable[..., object]
py_scanstring: Callable[..., object] # not in typeshed stubs


class ErrorType(Enum):
Expand Down Expand Up @@ -55,14 +90,17 @@ class ParseContext:
bracket_stack: tuple[str, ...]
nextchar: str
lastchar: str
partial_result: Any = None
partial_result: JSONValue = None
consumed_end: int = 0


def _record_parser_name(parser: Any) -> Any:
_T = TypeVar("_T")


def _record_parser_name(parser: Callable[..., _T]) -> Callable[..., _T]:
"""Decorator that attaches parser name to exceptions."""

def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: object, **kwargs: object) -> _T:
try:
return parser(*args, **kwargs)
except Exception as e:
Expand All @@ -74,23 +112,28 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def _make_decoder(*, strict: bool = True) -> _ExtendedDecoder:
    """Create a JSONDecoder with parser-name tracking.

    ``strict`` is forwarded unchanged to ``json.JSONDecoder``. The returned
    object is a real ``JSONDecoder`` at runtime; it is cast to
    ``_ExtendedDecoder`` so the attribute assignments below type-check.

    Note: json.decoder.scanstring must be patched at module level because
    JSONObject references it from module scope — no way to inject per-decoder.
    """
    # These imports are not exposed in typeshed stubs but exist at runtime
    from json.decoder import py_scanstring  # type: ignore[attr-defined]
    from json.scanner import py_make_scanner  # type: ignore[attr-defined]

    decoder = cast(_ExtendedDecoder, json.JSONDecoder(strict=strict))
    # Wrap each parser so exceptions raised inside it carry the parser name.
    decoder.parse_string = _record_parser_name(py_scanstring)
    decoder.parse_object = _record_parser_name(decoder.parse_object)
    decoder.parse_array = _record_parser_name(decoder.parse_array)
    # Build the scanner last: py_make_scanner reads the parse_* attributes
    # from the decoder when it is created, so it must see the wrapped ones.
    decoder.scan_once = py_make_scanner(decoder)
    return decoder


# Patch json.decoder.scanstring once so JSONObject uses our tracked version.
# This is unavoidable: JSONObject hard-references the module-level scanstring.
# The cast gives the type checker a view of the module that declares both
# attributes (py_scanstring is not in the typeshed stubs).
_decoder_module = cast(_JsonDecoderModule, json.decoder)
_decoder_module.scanstring = _record_parser_name(_decoder_module.py_scanstring)

# Built once at import time; diagnose() picks one based on its strict flag.
_decoder_strict = _make_decoder(strict=True)
_decoder_unstrict = _make_decoder(strict=False)
Expand Down Expand Up @@ -120,7 +163,8 @@ def diagnose(text: str, *, strict: bool = True) -> ParseContext | None:

decoder = _decoder_strict if strict else _decoder_unstrict
try:
obj, end = decoder.scan_once(text, 0) # type: ignore[attr-defined]
obj: JSONValue
obj, end = decoder.scan_once(text, 0)
if end == len(text):
return None # valid JSON
# Partial parse — decoded something but there's leftover
Expand Down Expand Up @@ -148,15 +192,17 @@ def diagnose(text: str, *, strict: bool = True) -> ParseContext | None:
)
except ValueError as e:
parser = e.__dict__.get("parser", "")
etype = _classify_error(parser, e.msg) # type: ignore[attr-defined]
# Cast to _ParseError protocol for type checking
err = cast(_ParseError, e)
etype = _classify_error(parser, err.msg)
if etype is None:
return None # unknown error, treat as unfixable
pos = e.pos # type: ignore[attr-defined]
pos = err.pos
return ParseContext(
input=text,
error_type=etype,
pos=pos,
message=e.msg, # type: ignore[attr-defined]
message=err.msg,
bracket_stack=build_bracket_stack(text, pos),
nextchar=text[pos : pos + 1],
lastchar=text[pos - 1 : pos],
Expand Down