Skip to content

Commit d2214e0

Browse files
Switch from mypy to ty (#465)
Redux of #418 - `ty` is still iterating, but now used over in connect-python (connectrpc/connect-python#242); seems reasonable to be consistent across our projects.
1 parent 70cde78 commit d2214e0

6 files changed

Lines changed: 55 additions & 192 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ conformance: $(BIN)/protovalidate-conformance generate install ## Run conformanc
8181
lint: install $(BIN)/buf ## Lint code
8282
buf format -d --exit-code
8383
uv run -- ruff format --check --diff protovalidate test
84-
uv run -- mypy protovalidate
84+
uv run -- ty check protovalidate
8585
uv run -- ruff check protovalidate test
8686
uv run -- tombi format --check
8787
uv run -- tombi lint

protovalidate/internal/cel_field_presence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def in_has() -> bool:
3333

3434

3535
class InterpretedRunner(celpy.InterpretedRunner):
36-
def evaluate(self, context) -> celpy.celtypes.Value:
36+
def evaluate(self, context: celpy.Context) -> celpy.celtypes.Value:
3737
class Evaluator(celpy.Evaluator):
3838
def macro_has_eval(self, exprlist) -> celpy.celtypes.BoolType:
3939
_has_state.in_has = True

protovalidate/internal/rules.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import celpy
2121
from celpy import celtypes
22-
from google.protobuf import ( # type: ignore[attr-defined]
22+
from google.protobuf import (
2323
any_pb2,
2424
descriptor,
2525
duration_pb2,
@@ -33,15 +33,16 @@
3333
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has
3434

3535
# protobuf 7+ removed FieldDescriptor.label / LABEL_REPEATED in favour of is_repeated.
36-
if hasattr(descriptor.FieldDescriptor, "is_repeated"):
36+
_FieldDescriptorClass = descriptor.FieldDescriptor
37+
if hasattr(_FieldDescriptorClass, "is_repeated"):
3738

3839
def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
39-
return field.is_repeated # type: ignore[attr-defined]
40+
return field.is_repeated
4041

4142
else:
4243

4344
def _is_repeated(field: descriptor.FieldDescriptor) -> bool:
44-
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined]
45+
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED
4546

4647

4748
class CompilationError(Exception):
@@ -93,14 +94,14 @@ def __init__(self, msg: message.Message):
9394
continue
9495
self[field.name] = field_to_cel(self.msg, field)
9596

96-
def __getitem__(self, name):
97-
field = self.desc.fields_by_name[name]
98-
if field.has_presence and not self.msg.HasField(name):
97+
def __getitem__(self, key):
98+
field = self.desc.fields_by_name[key]
99+
if field.has_presence and not self.msg.HasField(key):
99100
if in_has():
100101
raise KeyError()
101102
else:
102103
return _zero_value(field)
103-
return super().__getitem__(name)
104+
return super().__getitem__(key)
104105

105106

106107
def _msg_to_cel(msg: message.Message) -> celtypes.Value:
@@ -153,14 +154,14 @@ def _get_type_ctor(fd: typing.Any) -> typing.Callable[..., celtypes.Value] | Non
153154

154155
def _proto_message_has_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any:
155156
if field.is_extension:
156-
return msg.HasExtension(field) # type: ignore
157+
return msg.HasExtension(field) # ty: ignore[invalid-argument-type]
157158
else:
158159
return msg.HasField(field.name)
159160

160161

161162
def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any:
162163
if field.is_extension:
163-
return msg.Extensions[field] # type: ignore
164+
return msg.Extensions[field] # ty: ignore[invalid-argument-type]
164165
else:
165166
return getattr(msg, field.name)
166167

@@ -322,7 +323,7 @@ def sub_context(self) -> "RuleContext":
322323
class Rules:
323324
"""The rules associated with a single 'rules' message."""
324325

325-
def validate(self, ctx: RuleContext, _: message.Message):
326+
def validate(self, ctx: RuleContext, message: message.Message): # noqa: ARG002
326327
"""Validate the message against the rules in this rule."""
327328
ctx.add(Violation(rule_id="unimplemented", message="Unimplemented"))
328329

@@ -440,8 +441,8 @@ def __init__(self, fields: list[descriptor.FieldDescriptor], *, required: bool):
440441
self._fields = fields
441442
self._required = required
442443

443-
def validate(self, ctx: RuleContext, msg: message.Message):
444-
num_set_fields = sum(1 for field in self._fields if not _is_empty_field(msg, field))
444+
def validate(self, ctx: RuleContext, message: message.Message):
445+
num_set_fields = sum(1 for field in self._fields if not _is_empty_field(message, field))
445446
if num_set_fields > 1:
446447
ctx.add(
447448
Violation(
@@ -586,8 +587,8 @@ def __init__(
586587
# For each set field in the message, look for the private rule
587588
# extension.
588589
for list_field, _ in rules.ListFields():
589-
if validate_pb2.predefined in list_field.GetOptions().Extensions: # type: ignore
590-
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # type: ignore
590+
if validate_pb2.predefined in list_field.GetOptions().Extensions: # ty: ignore[unsupported-operator]
591+
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # ty: ignore[invalid-argument-type]
591592
self.add_rule(
592593
env,
593594
funcs,
@@ -642,11 +643,13 @@ def validate(self, ctx: RuleContext, message: message.Message):
642643
sub_ctx.add_field_path_element(element)
643644
ctx.add_errors(sub_ctx)
644645

645-
def validate_item(self, ctx: RuleContext, val: typing.Any, *, for_key: bool = False):
646-
self._validate_value(ctx, val, for_key=for_key)
647-
self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key)
646+
def validate_item(self, ctx: RuleContext, value: typing.Any, *, for_key: bool = False):
647+
self._validate_value(ctx, value, for_key=for_key)
648+
self._validate_cel(
649+
ctx, this_value=value, this_cel=_scalar_field_value_to_cel(value, self._field), for_key=for_key
650+
)
648651

649-
def _validate_value(self, ctx: RuleContext, val: typing.Any, *, for_key: bool = False):
652+
def _validate_value(self, ctx: RuleContext, value: typing.Any, *, for_key: bool = False):
650653
pass
651654

652655

@@ -1079,8 +1082,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
10791082
result: list[Rules] = []
10801083
rule: Rules | None = None
10811084
all_msg_oneof_fields = set()
1082-
if desc.GetOptions().HasExtension(validate_pb2.message): # type: ignore
1083-
message_level = desc.GetOptions().Extensions[validate_pb2.message] # type: ignore
1085+
if desc.GetOptions().HasExtension(validate_pb2.message): # ty: ignore[invalid-argument-type]
1086+
message_level = desc.GetOptions().Extensions[validate_pb2.message] # ty: ignore[invalid-argument-type]
10841087
for oneof in message_level.oneof:
10851088
all_msg_oneof_fields.update(oneof.fields)
10861089
if rule := self._new_message_rule(message_level, desc):

protovalidate/validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def collect_violations(
8484
break
8585
for violation in ctx.violations:
8686
if violation.proto.HasField("field"):
87-
violation.proto.field.elements.reverse() # type: ignore
87+
violation.proto.field.elements.reverse()
8888
if violation.proto.HasField("rule"):
89-
violation.proto.rule.elements.reverse() # type: ignore
89+
violation.proto.rule.elements.reverse()
9090
return ctx.violations
9191

9292

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ Source = "https://github.com/bufbuild/protovalidate-python"
4444
[dependency-groups]
4545
dev = [
4646
"google-re2-stubs==0.1.1",
47-
"mypy==1.20.2",
4847
"pytest==9.0.3",
4948
"ruff==0.15.12",
5049
"tombi==0.10.5",
50+
"ty==0.0.34",
5151
"types-protobuf==6.32.1.20260221",
5252
]
5353

@@ -59,9 +59,6 @@ build-backend = "hatchling.build"
5959
source = "vcs"
6060
raw-options = { fallback_version = "0.0.0" }
6161

62-
[tool.mypy]
63-
mypy_path = "gen"
64-
6562
[tool.pytest]
6663
# Turn all warnings into errors,
6764
# except DeprecationWarnings (which we knowingly tolerate due to using old `protobuf` APIs).

0 commit comments

Comments
 (0)