diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index 7ae64e3..5ebf9d0 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -21,12 +21,12 @@ else: from typing_extensions import Self - _COLUMN_ATTR = "__dataframely_columns__" _RULE_ATTR = "__dataframely_rules__" ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__" + # --------------------------------------- UTILS -------------------------------------- # @@ -206,11 +206,17 @@ def _get_metadata(source: dict[str, Any]) -> Metadata: result.columns[value.alias or attr] = value if isinstance(value, RuleFactory): # We must ensure that custom rules do not clash with internal rules. - if attr == "primary_key": + name = value.name or attr + if name == "primary_key": raise ImplementationError( "Custom validation rule must not be named `primary_key`." ) - result.rules[attr] = value + if name in result.rules: + raise ImplementationError( + f"Duplicate validation rule name '{name}' found. " + f"Custom validation rules must have unique names." + ) + result.rules[name] = value return result def __repr__(cls) -> str: diff --git a/dataframely/_rule.py b/dataframely/_rule.py index 6c4e788..5952da3 100644 --- a/dataframely/_rule.py +++ b/dataframely/_rule.py @@ -99,10 +99,14 @@ class RuleFactory: """Factory class for rules created within schemas.""" def __init__( - self, validation_fn: Callable[[Any], pl.Expr], group_columns: list[str] | None + self, + validation_fn: Callable[[Any], pl.Expr], + group_columns: list[str] | None, + name: str | None = None, ) -> None: self.validation_fn = validation_fn self.group_columns = group_columns + self.name = name @classmethod def from_rule(cls, rule: Rule) -> Self: @@ -125,7 +129,7 @@ def make(self, schema: Any) -> Rule: def rule( - *, group_by: list[str] | None = None + *, group_by: list[str] | None = None, name: str | None = None ) -> Callable[[ValidationFunction], RuleFactory]: """Mark a function as a rule to evaluate during validation. @@ -148,6 +152,8 @@ def rule( of rows. If this list is provided, the returned expression must return a single boolean value, i.e. some kind of aggregation function must be used (e.g. `sum`, `any`, ...). + name: A custom name for the rule for user-friendly display. + By default, the name of the decorated function will be used. Note: You'll need to explicitly handle `null` values in your columns when defining @@ -163,7 +169,9 @@ def rule( """ def decorator(validation_fn: ValidationFunction) -> RuleFactory: - return RuleFactory(validation_fn=validation_fn, group_columns=group_by) + return RuleFactory( + validation_fn=validation_fn, group_columns=group_by, name=name + ) return decorator diff --git a/docs/guides/faq.md b/docs/guides/faq.md index d5e24d8..c4eefc3 100644 --- a/docs/guides/faq.md +++ b/docs/guides/faq.md @@ -30,6 +30,23 @@ class UserSchema(dy.Schema): return pl.col("email").is_null() | pl.col("email").is_unique() ``` +## How do I give custom names to the rules in a {class}`~dataframely.Schema`? + +By default, the name of a rule is the name of the function that implements it. +However, you can also provide a custom name for a rule by using the `name` parameter of the `@dy.rule` decorator: + +```python +class UserSchema(dy.Schema): + user_id = dy.UInt64(primary_key=True, nullable=False) + + @dy.rule(name="my-custom-name") + def irrelevant_function_name(cls) -> pl.Expr: + return cls.user_id.col != 42 +``` + +Whenever `dataframely` needs to refer to this rule, it will use the custom name `my-custom-name` instead of the function +name `irrelevant_function_name`. + ## How do I fix the ruff error `First argument of a method should be named self`? See our documentation on [group rules](./quickstart.md#group-rules). diff --git a/docs/guides/quickstart.md b/docs/guides/quickstart.md index 4eb341a..cee303d 100644 --- a/docs/guides/quickstart.md +++ b/docs/guides/quickstart.md @@ -64,6 +64,12 @@ The decorator `@dy.rule()` "registers" the function as a rule using its name (i. The returned expression provides a boolean value for each row of the data which evaluates to `True` whenever the data are valid with respect to this rule. +```{note} +New in `dataframely` v2.8.0: You can now set custom names on rules by specifying the `name` kwarg. +For example, `@dy.rule(name="my-custom-name")` would register the rule under the name `my-custom-name` instead of the function name. +This will affect how validation errors are reported and how the rule is referred to in the `FailureInfo` object. +``` + ## Group rules For defining even more complex rules, the `@dy.rule` decorator allows for a `group_by` diff --git a/tests/schema/test_filter.py b/tests/schema/test_filter.py index 99557e7..6db744a 100644 --- a/tests/schema/test_filter.py +++ b/tests/schema/test_filter.py @@ -275,3 +275,37 @@ def test_filter_details(eager: bool) -> None: "primary_key": "invalid", }, ] + + +@pytest.mark.parametrize("eager", [True, False]) +def test_filter_custom_rule_name(eager: bool) -> None: + """Verify that we can set a custom rule name on a non-group rule.""" + + class MySchema(dy.Schema): + a = dy.Int64() + + @dy.rule(name="custom_rule_name") + def my_rule(cls) -> pl.Expr: + return cls.a.col != 42 + + df = pl.DataFrame({"a": [1, 42, 3]}) + _, fails = _filter_and_collect(MySchema, df, cast=True, eager=eager) + assert fails.counts() == {"custom_rule_name": 1} + + +@pytest.mark.parametrize("eager", [True, False]) +def test_filter_custom_group_rule_name(eager: bool) -> None: + """Verify that we can set a custom rule name on a group rule.""" + + class MySchema(dy.Schema): + a = dy.String() + b = dy.Int64() + + @dy.rule(name="custom_rule_name", group_by=["a"]) + def my_rule(cls) -> pl.Expr: + return cls.b.col.sum() < 10 + + df = pl.DataFrame({"a": ["x", "x", "y"], "b": [75, 75, 75]}) + _, fails = _filter_and_collect(MySchema, df, cast=True, eager=eager) + + assert fails.counts() == {"custom_rule_name": 3} diff --git a/tests/schema/test_rule_implementation.py b/tests/schema/test_rule_implementation.py index 7305101..7633d0b 100644 --- a/tests/schema/test_rule_implementation.py +++ b/tests/schema/test_rule_implementation.py @@ -39,3 +39,35 @@ def test_rule_column_overlap_error() -> None: columns={"test": dy.Integer(alias="a")}, rules={"a": Rule(pl.col("a") > 0)}, ) + + +def test_rule_custom_illegal_name() -> None: + with pytest.raises( + dy.exc.ImplementationError, + match="Custom validation rule must not be named `primary_key`.", + ): + + class MySchema(dy.Schema): + a = dy.String() + + @dy.rule(name="primary_key") + def my_rule(cls) -> pl.Expr: + return cls.a.col < 10 + + +def test_rule_custom_duplicate_name() -> None: + with pytest.raises( + dy.exc.ImplementationError, + match="Custom validation rules must have unique names", + ): + + class MySchema(dy.Schema): + a = dy.String() + + @dy.rule(name="custom") + def my_rule1(cls) -> pl.Expr: + return cls.a.col < 10 + + @dy.rule(name="custom") + def my_rule2(cls) -> pl.Expr: + return cls.a.col > 5 diff --git a/tests/schema/test_serialization.py b/tests/schema/test_serialization.py index c8a1cd3..6cf4417 100644 --- a/tests/schema/test_serialization.py +++ b/tests/schema/test_serialization.py @@ -30,6 +30,16 @@ def test_simple_serialization() -> None: assert set(decoded["rules"].keys()) == set() +class CustomRuleNameSchema(dy.Schema): + """A schema with a custom rule name.""" + + x = dy.Int64() + + @dy.rule(name="custom_rule") + def irrelevant_name_here(cls) -> pl.Expr: + return cls.x.col > 5 + + @pytest.mark.parametrize( "schema", [ @@ -59,6 +69,7 @@ def test_simple_serialization() -> None: "test", {"a": dy.Struct({"x": dy.Int64(min=5, check=lambda expr: expr < 10)})}, ), + CustomRuleNameSchema, ], ) def test_roundtrip_matches(schema: type[dy.Schema]) -> None: