Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions dataframely/_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,21 @@ def with_evaluation_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.LazyFr
name: rule for name, rule in rules.items() if isinstance(rule, GroupRule)
}

# Before we can select all of the simple expressions, we need to turn the
# group rules into something to use in a `select` statement as well.
result = (
# NOTE: A value of `null` always validates successfully as nullability should
# already be checked via dedicated rules.
lf.pipe(_with_group_rules, group_rules).with_columns(
**{name: expr.fill_null(True) for name, expr in simple_exprs.items()},
)
# First, evaluate row-level (simple) rules
# NOTE: A value of `null` always validates successfully as nullability should
# already be checked via dedicated rules.
result = lf.with_columns(
**{name: expr.fill_null(True) for name, expr in simple_exprs.items()},
)

# Group rules must be evaluated AFTER row-level rules to ensure they only
# apply to rows that pass row-level validation. Otherwise, a row that fails
# a row-level rule could be part of a group evaluation, and when that row is
# filtered out, the group might become invalid.
if group_rules:
# Evaluate group rules only on rows that pass all row-level rules
result = result.pipe(_with_group_rules, group_rules, simple_exprs)

# If there is at least one rule that checks for successful dtype casting, we need
# to take an extra step: rules other than the "dtype rules" might not be reliable
# if casting failed, i.e. if any of the "dtype rules" evaluated to `False`. For
Expand All @@ -231,27 +236,49 @@ def with_evaluation_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.LazyFr
return result


def _with_group_rules(
    lf: pl.LazyFrame,
    rules: dict[str, GroupRule],
    simple_exprs: dict[str, pl.Expr] | None = None,
) -> pl.LazyFrame:
    """Evaluate group-level rules and broadcast the results onto ``lf``.

    Args:
        lf: The input frame onto which one boolean column per group rule is added.
        rules: The group rules to evaluate, keyed by rule name.
        simple_exprs: The already-evaluated row-level rule expressions, keyed by
            rule name. When provided, group aggregations only consider rows that
            pass *all* row-level rules so that a row failing a row-level rule
            cannot influence the outcome of any group rule.

    Returns:
        ``lf`` with one additional boolean column per rule in ``rules``. Rows that
        were excluded from the group aggregation receive ``null`` for the group
        rule columns, which is treated as valid downstream.
    """
    # First, we partition the rules by group columns. This minimizes the number
    # of `group_by` calls and joins to make.
    grouped_rules: dict[frozenset[str], dict[str, pl.Expr]] = defaultdict(dict)
    for name, rule in rules.items():
        # NOTE: `null` indicates validity, see note above.
        grouped_rules[frozenset(rule.group_columns)][name] = rule.expr.fill_null(True)

    # If we have row-level rules, we need to evaluate group rules only on rows that
    # pass all row-level rules. This ensures that a row failing a row-level rule
    # doesn't affect group rule evaluation. We do this by filtering the data for
    # group aggregation, but still joining back to the full dataset.
    if simple_exprs:
        # A row participates in group aggregation only if every row-level rule
        # evaluated to `True` (or `null`, which counts as valid).
        all_simple_valid = pl.all_horizontal(
            [pl.col(name).fill_null(True) for name in simple_exprs]
        )
        lf_for_groups = lf.filter(all_simple_valid)
    else:
        lf_for_groups = lf

    # Then, for each `group_by`, we apply the relevant rules and keep all the rule
    # evaluations around.
    group_evaluations: dict[frozenset[str], pl.LazyFrame] = {}
    for group_columns, group_rules_exprs in grouped_rules.items():
        # We group by the group columns and apply all expressions at once.
        group_evaluations[group_columns] = lf_for_groups.group_by(group_columns).agg(
            **group_rules_exprs
        )

    # Eventually, we apply the rule evaluations onto the input data frame. For this,
    # we "broadcast" the results within each group across rows in the same group.
    # We join to the FULL dataset (`lf`), not the filtered one, so rows that failed
    # row-level rules get `null` for group rules (which are treated as valid). The
    # left join is required here: an inner join would silently drop such rows.
    result = lf
    for group_columns, frame in group_evaluations.items():
        result = result.join(
            frame,
            on=list(group_columns),
            nulls_equal=True,
            maintain_order="left",
            how="left",
        )
    return result

Expand Down
78 changes: 78 additions & 0 deletions tests/schema/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,81 @@ def test_filter_details(eager: bool) -> None:
"primary_key": "invalid",
},
]


# --------------------------- GROUP RULES WITH ROW RULES ----------------------------- #


@pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
@pytest.mark.parametrize("eager", [True, False])
def test_filter_group_rule_after_row_rule_filtering(
    df_type: type[pl.DataFrame] | type[pl.LazyFrame], eager: bool
) -> None:
    """Group rules must only see rows that survive row-level validation.

    Regression test for a bug where a row failing a row-level rule could still
    participate in a group rule's aggregation and skew the validation outcome.
    """

    class DiagnosisSchema(dy.Schema):
        invoice_id = dy.String(primary_key=True)
        diagnosis = dy.String(primary_key=True, regex="^[A-Z]{3}$")
        is_main = dy.Bool(nullable=False)

        @dy.rule(group_by=["invoice_id"])
        def exactly_one_main_diagnosis(cls) -> pl.Expr:
            return pl.col("is_main").sum() == 1

    # Case 1: the only main diagnosis sits on a row that fails the regex rule.
    # Before the fix, the group rule was evaluated on the unfiltered data and
    # the invoice incorrectly passed validation.
    frame_main_on_bad_row = df_type(
        {
            "invoice_id": ["A", "A", "A"],
            "diagnosis": ["ABC", "ABD", "123"],  # "123" fails regex
            "is_main": [False, False, True],  # Main diagnosis on invalid row
        }
    )
    good, bad = _filter_and_collect(DiagnosisSchema, frame_main_on_bad_row, eager=eager)
    # Everything is dropped: the "123" row fails the regex, and without it the
    # remaining rows have no main diagnosis, so the group rule fails too.
    assert len(good) == 0
    assert "exactly_one_main_diagnosis" in bad.counts()
    assert "diagnosis|regex" in bad.counts()

    # Case 2: fully valid data satisfies both rule levels.
    frame_all_valid = df_type(
        {
            "invoice_id": ["A", "A", "A"],
            "diagnosis": ["ABC", "ABD", "AEF"],
            "is_main": [False, True, False],
        }
    )
    good, bad = _filter_and_collect(DiagnosisSchema, frame_all_valid, eager=eager)
    assert len(good) == 3
    assert good.select(pl.col("is_main").sum()).item() == 1
    assert len(bad) == 0

    # Case 3: two invoices; only one of them contains a regex-violating row.
    frame_two_groups = df_type(
        {
            "invoice_id": ["A", "A", "A", "B", "B"],
            "diagnosis": ["ABC", "ABD", "AEF", "XYZ", "123"],
            "is_main": [False, True, False, True, False],
        }
    )
    good, bad = _filter_and_collect(DiagnosisSchema, frame_two_groups, eager=eager)
    # Invoice A survives intact; invoice B loses its regex-failing row while the
    # remaining B row (the main diagnosis) stays valid.
    assert len(good) == 4
    assert bad.counts() == {"diagnosis|regex": 1}

    # Case 4: every row is fine on its own, but the group as a whole is not.
    frame_no_main = df_type(
        {
            "invoice_id": ["A", "A", "A"],
            "diagnosis": ["ABC", "ABD", "AEF"],
            "is_main": [False, False, False],  # No main diagnosis
        }
    )
    good, bad = _filter_and_collect(DiagnosisSchema, frame_no_main, eager=eager)
    assert len(good) == 0
    assert bad.counts() == {"exactly_one_main_diagnosis": 3}