Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
else:
from typing_extensions import Self


_COLUMN_ATTR = "__dataframely_columns__"
_RULE_ATTR = "__dataframely_rules__"

ORIGINAL_COLUMN_PREFIX = "__DATAFRAMELY_ORIGINAL__"


# --------------------------------------- UTILS -------------------------------------- #


Expand Down Expand Up @@ -206,11 +206,17 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
result.columns[value.alias or attr] = value
if isinstance(value, RuleFactory):
# We must ensure that custom rules do not clash with internal rules.
if attr == "primary_key":
name = value.name or attr
if name == "primary_key":
raise ImplementationError(
"Custom validation rule must not be named `primary_key`."
)
result.rules[attr] = value
if name in result.rules:
raise ImplementationError(
f"Duplicate validation rule name '{name}' found. "
f"Custom validation rules must have unique names."
)
result.rules[name] = value
return result

def __repr__(cls) -> str:
Expand Down
14 changes: 11 additions & 3 deletions dataframely/_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,14 @@ class RuleFactory:
"""Factory class for rules created within schemas."""

def __init__(
self, validation_fn: Callable[[Any], pl.Expr], group_columns: list[str] | None
self,
validation_fn: Callable[[Any], pl.Expr],
group_columns: list[str] | None,
name: str | None = None,
) -> None:
self.validation_fn = validation_fn
self.group_columns = group_columns
self.name = name

@classmethod
def from_rule(cls, rule: Rule) -> Self:
Expand All @@ -125,7 +129,7 @@ def make(self, schema: Any) -> Rule:


def rule(
*, group_by: list[str] | None = None
*, group_by: list[str] | None = None, name: str | None = None
) -> Callable[[ValidationFunction], RuleFactory]:
"""Mark a function as a rule to evaluate during validation.

Expand All @@ -148,6 +152,8 @@ def rule(
of rows. If this list is provided, the returned expression must return a
single boolean value, i.e. some kind of aggregation function must be used
(e.g. `sum`, `any`, ...).
name: A custom name for the rule for user-friendly display.
By default, the name of the decorated function will be used.

Note:
You'll need to explicitly handle `null` values in your columns when defining
Expand All @@ -163,7 +169,9 @@ def rule(
"""

def decorator(validation_fn: ValidationFunction) -> RuleFactory:
return RuleFactory(validation_fn=validation_fn, group_columns=group_by)
return RuleFactory(
validation_fn=validation_fn, group_columns=group_by, name=name
)

return decorator

Expand Down
17 changes: 17 additions & 0 deletions docs/guides/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ class UserSchema(dy.Schema):
return pl.col("email").is_null() | pl.col("email").is_unique()
```

## How do I give custom names to the rules in a {class}`~dataframely.Schema`?

By default, the name of a rule is the name of the function that implements it.
However, you can also provide a custom name for a rule by using the `name` parameter of the `@dy.rule` decorator:

```python
class UserSchema(dy.Schema):
user_id = dy.UInt64(primary_key=True, nullable=False)

@dy.rule(name="my-custom-name")
def irrelevant_function_name(cls) -> pl.Expr:
return cls.user_id.col != 42
```

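Whenever `dataframely` needs to refer to this rule, it will use the custom name `my-custom-name` instead of the function
name `irrelevant_function_name`.

Rule names must be unique within a schema and must not be `primary_key`, which is reserved for an internal rule.
Violating either constraint raises an `ImplementationError` when the schema class is defined.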

## How do I fix the ruff error `First argument of a method should be named self`?

See our documentation on [group rules](./quickstart.md#group-rules).
Expand Down
6 changes: 6 additions & 0 deletions docs/guides/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ The decorator `@dy.rule()` "registers" the function as a rule using its name (i.
The returned expression provides a boolean value for each row of the data which evaluates to `True` whenever the data
are valid with respect to this rule.

```{note}
New in `dataframely` v2.8.0: You can give a rule a custom name via the `name` keyword argument of `@dy.rule`.
For example, `@dy.rule(name="my-custom-name")` registers the rule under the name `my-custom-name` instead of the function name.
The custom name is then used when validation errors are reported and when the rule is referred to in the `FailureInfo` object.
```

## Group rules

For defining even more complex rules, the `@dy.rule` decorator allows for a `group_by`
Expand Down
34 changes: 34 additions & 0 deletions tests/schema/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,37 @@ def test_filter_details(eager: bool) -> None:
"primary_key": "invalid",
},
]


@pytest.mark.parametrize("eager", [True, False])
def test_filter_custom_rule_name(eager: bool) -> None:
    """Check that a row-level rule is reported under its custom name.

    The rule function is called ``my_rule`` but is registered with
    ``name="custom_rule_name"``, so the failure counts must be keyed by the
    custom name.
    """

    class MySchema(dy.Schema):
        a = dy.Int64()

        @dy.rule(name="custom_rule_name")
        def my_rule(cls) -> pl.Expr:
            return cls.a.col != 42

    # Exactly one of the three input rows (the value 42) violates the rule.
    frame = pl.DataFrame({"a": [1, 42, 3]})
    expected_counts = {"custom_rule_name": 1}

    _valid, failure_info = _filter_and_collect(
        MySchema, frame, cast=True, eager=eager
    )
    assert failure_info.counts() == expected_counts


@pytest.mark.parametrize("eager", [True, False])
def test_filter_custom_group_rule_name(eager: bool) -> None:
    """Check that a group rule is reported under its custom name.

    The rule is grouped by column ``a`` and registered with
    ``name="custom_rule_name"``, so the failure counts must be keyed by the
    custom name rather than by ``my_rule``.
    """

    class MySchema(dy.Schema):
        a = dy.String()
        b = dy.Int64()

        @dy.rule(name="custom_rule_name", group_by=["a"])
        def my_rule(cls) -> pl.Expr:
            return cls.b.col.sum() < 10

    # Both groups ("x" and "y") sum to at least 10, so all three rows are
    # expected to be counted as failures.
    frame = pl.DataFrame({"a": ["x", "x", "y"], "b": [75, 75, 75]})
    expected_counts = {"custom_rule_name": 3}

    _valid, failure_info = _filter_and_collect(
        MySchema, frame, cast=True, eager=eager
    )
    assert failure_info.counts() == expected_counts
32 changes: 32 additions & 0 deletions tests/schema/test_rule_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,35 @@ def test_rule_column_overlap_error() -> None:
columns={"test": dy.Integer(alias="a")},
rules={"a": Rule(pl.col("a") > 0)},
)


def test_rule_custom_illegal_name() -> None:
    """Check that naming a custom rule ``primary_key`` is rejected.

    The error is expected while the schema class body is being processed,
    which is why the class definition sits inside the ``pytest.raises`` block.
    """
    expected_message = "Custom validation rule must not be named `primary_key`."

    with pytest.raises(dy.exc.ImplementationError, match=expected_message):

        class MySchema(dy.Schema):
            a = dy.String()

            # The function name is legal; only the explicit `name` clashes.
            @dy.rule(name="primary_key")
            def my_rule(cls) -> pl.Expr:
                return cls.a.col < 10


def test_rule_custom_duplicate_name() -> None:
    """Check that two rules sharing one custom name are rejected.

    The two rule functions have distinct names but both are registered with
    ``name="custom"``; defining the schema is expected to raise.
    """
    expected_message = "Custom validation rules must have unique names"

    with pytest.raises(dy.exc.ImplementationError, match=expected_message):

        class MySchema(dy.Schema):
            a = dy.String()

            @dy.rule(name="custom")
            def my_rule1(cls) -> pl.Expr:
                return cls.a.col < 10

            # Same custom name as `my_rule1` above -- this is the clash.
            @dy.rule(name="custom")
            def my_rule2(cls) -> pl.Expr:
                return cls.a.col > 5
11 changes: 11 additions & 0 deletions tests/schema/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def test_simple_serialization() -> None:
assert set(decoded["rules"].keys()) == set()


class CustomRuleNameSchema(dy.Schema):
    """A schema with a custom rule name."""

    # Single integer column that the rule below refers to.
    x = dy.Int64()

    # The rule is registered as "custom_rule"; the function name differs on
    # purpose so the two cannot be confused.
    @dy.rule(name="custom_rule")
    def irrelevant_name_here(cls) -> pl.Expr:
        return cls.x.col > 5


@pytest.mark.parametrize(
"schema",
[
Expand Down Expand Up @@ -59,6 +69,7 @@ def test_simple_serialization() -> None:
"test",
{"a": dy.Struct({"x": dy.Int64(min=5, check=lambda expr: expr < 10)})},
),
CustomRuleNameSchema,
],
)
def test_roundtrip_matches(schema: type[dy.Schema]) -> None:
Expand Down
Loading