|
19 | 19 |
|
20 | 20 | import celpy |
21 | 21 | from celpy import celtypes |
22 | | -from google.protobuf import ( # type: ignore[attr-defined] |
| 22 | +from google.protobuf import ( |
23 | 23 | any_pb2, |
24 | 24 | descriptor, |
25 | 25 | duration_pb2, |
|
33 | 33 | from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has |
34 | 34 |
|
35 | 35 | # protobuf 7+ removed FieldDescriptor.label / LABEL_REPEATED in favour of is_repeated. |
36 | | -if hasattr(descriptor.FieldDescriptor, "is_repeated"): |
| 36 | +_FieldDescriptorClass = descriptor.FieldDescriptor |
| 37 | +if hasattr(_FieldDescriptorClass, "is_repeated"): |
37 | 38 |
|
38 | 39 | def _is_repeated(field: descriptor.FieldDescriptor) -> bool: |
39 | | - return field.is_repeated # type: ignore[attr-defined] |
| 40 | + return field.is_repeated |
40 | 41 |
|
41 | 42 | else: |
42 | 43 |
|
43 | 44 | def _is_repeated(field: descriptor.FieldDescriptor) -> bool: |
44 | | - return field.label == descriptor.FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined] |
| 45 | + return field.label == descriptor.FieldDescriptor.LABEL_REPEATED |
45 | 46 |
|
46 | 47 |
|
47 | 48 | class CompilationError(Exception): |
@@ -93,14 +94,14 @@ def __init__(self, msg: message.Message): |
93 | 94 | continue |
94 | 95 | self[field.name] = field_to_cel(self.msg, field) |
95 | 96 |
|
96 | | - def __getitem__(self, name): |
97 | | - field = self.desc.fields_by_name[name] |
98 | | - if field.has_presence and not self.msg.HasField(name): |
| 97 | + def __getitem__(self, key): |
| 98 | + field = self.desc.fields_by_name[key] |
| 99 | + if field.has_presence and not self.msg.HasField(key): |
99 | 100 | if in_has(): |
100 | 101 | raise KeyError() |
101 | 102 | else: |
102 | 103 | return _zero_value(field) |
103 | | - return super().__getitem__(name) |
| 104 | + return super().__getitem__(key) |
104 | 105 |
|
105 | 106 |
|
106 | 107 | def _msg_to_cel(msg: message.Message) -> celtypes.Value: |
@@ -153,14 +154,14 @@ def _get_type_ctor(fd: typing.Any) -> typing.Callable[..., celtypes.Value] | Non |
153 | 154 |
|
154 | 155 | def _proto_message_has_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: |
155 | 156 | if field.is_extension: |
156 | | - return msg.HasExtension(field) # type: ignore |
| 157 | + return msg.HasExtension(field) # ty: ignore[invalid-argument-type] |
157 | 158 | else: |
158 | 159 | return msg.HasField(field.name) |
159 | 160 |
|
160 | 161 |
|
161 | 162 | def _proto_message_get_field(msg: message.Message, field: descriptor.FieldDescriptor) -> typing.Any: |
162 | 163 | if field.is_extension: |
163 | | - return msg.Extensions[field] # type: ignore |
| 164 | + return msg.Extensions[field] # ty: ignore[invalid-argument-type] |
164 | 165 | else: |
165 | 166 | return getattr(msg, field.name) |
166 | 167 |
|
@@ -322,7 +323,7 @@ def sub_context(self) -> "RuleContext": |
322 | 323 | class Rules: |
323 | 324 | """The rules associated with a single 'rules' message.""" |
324 | 325 |
|
325 | | - def validate(self, ctx: RuleContext, _: message.Message): |
| 326 | + def validate(self, ctx: RuleContext, message: message.Message): # noqa: ARG002 |
326 | 327 | """Validate the message against the rules in this rule.""" |
327 | 328 | ctx.add(Violation(rule_id="unimplemented", message="Unimplemented")) |
328 | 329 |
|
@@ -440,8 +441,8 @@ def __init__(self, fields: list[descriptor.FieldDescriptor], *, required: bool): |
440 | 441 | self._fields = fields |
441 | 442 | self._required = required |
442 | 443 |
|
443 | | - def validate(self, ctx: RuleContext, msg: message.Message): |
444 | | - num_set_fields = sum(1 for field in self._fields if not _is_empty_field(msg, field)) |
| 444 | + def validate(self, ctx: RuleContext, message: message.Message): |
| 445 | + num_set_fields = sum(1 for field in self._fields if not _is_empty_field(message, field)) |
445 | 446 | if num_set_fields > 1: |
446 | 447 | ctx.add( |
447 | 448 | Violation( |
@@ -586,8 +587,8 @@ def __init__( |
586 | 587 | # For each set field in the message, look for the private rule |
587 | 588 | # extension. |
588 | 589 | for list_field, _ in rules.ListFields(): |
589 | | - if validate_pb2.predefined in list_field.GetOptions().Extensions: # type: ignore |
590 | | - for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # type: ignore |
| 590 | + if validate_pb2.predefined in list_field.GetOptions().Extensions: # ty: ignore[unsupported-operator] |
| 591 | + for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # ty: ignore[invalid-argument-type] |
591 | 592 | self.add_rule( |
592 | 593 | env, |
593 | 594 | funcs, |
@@ -642,11 +643,13 @@ def validate(self, ctx: RuleContext, message: message.Message): |
642 | 643 | sub_ctx.add_field_path_element(element) |
643 | 644 | ctx.add_errors(sub_ctx) |
644 | 645 |
|
645 | | - def validate_item(self, ctx: RuleContext, val: typing.Any, *, for_key: bool = False): |
646 | | - self._validate_value(ctx, val, for_key=for_key) |
647 | | - self._validate_cel(ctx, this_value=val, this_cel=_scalar_field_value_to_cel(val, self._field), for_key=for_key) |
| 646 | + def validate_item(self, ctx: RuleContext, value: typing.Any, *, for_key: bool = False): |
| 647 | + self._validate_value(ctx, value, for_key=for_key) |
| 648 | + self._validate_cel( |
| 649 | + ctx, this_value=value, this_cel=_scalar_field_value_to_cel(value, self._field), for_key=for_key |
| 650 | + ) |
648 | 651 |
|
649 | | - def _validate_value(self, ctx: RuleContext, val: typing.Any, *, for_key: bool = False): |
| 652 | + def _validate_value(self, ctx: RuleContext, value: typing.Any, *, for_key: bool = False): |
650 | 653 | pass |
651 | 654 |
|
652 | 655 |
|
@@ -1079,8 +1082,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]: |
1079 | 1082 | result: list[Rules] = [] |
1080 | 1083 | rule: Rules | None = None |
1081 | 1084 | all_msg_oneof_fields = set() |
1082 | | - if desc.GetOptions().HasExtension(validate_pb2.message): # type: ignore |
1083 | | - message_level = desc.GetOptions().Extensions[validate_pb2.message] # type: ignore |
| 1085 | + if desc.GetOptions().HasExtension(validate_pb2.message): # ty: ignore[invalid-argument-type] |
| 1086 | + message_level = desc.GetOptions().Extensions[validate_pb2.message] # ty: ignore[invalid-argument-type] |
1084 | 1087 | for oneof in message_level.oneof: |
1085 | 1088 | all_msg_oneof_fields.update(oneof.fields) |
1086 | 1089 | if rule := self._new_message_rule(message_level, desc): |
|
0 commit comments