diff --git a/dataframely/_rule.py b/dataframely/_rule.py
index 6c4e788..3616dfe 100644
--- a/dataframely/_rule.py
+++ b/dataframely/_rule.py
@@ -199,16 +199,21 @@ def with_evaluation_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.LazyFr
         name: rule for name, rule in rules.items() if isinstance(rule, GroupRule)
     }
 
-    # Before we can select all of the simple expressions, we need to turn the
-    # group rules into something to use in a `select` statement as well.
-    result = (
-        # NOTE: A value of `null` always validates successfully as nullability should
-        # already be checked via dedicated rules.
-        lf.pipe(_with_group_rules, group_rules).with_columns(
-            **{name: expr.fill_null(True) for name, expr in simple_exprs.items()},
-        )
+    # First, evaluate the row-level (simple) rules.
+    # NOTE: A value of `null` always validates successfully as nullability should
+    # already be checked via dedicated rules.
+    result = lf.with_columns(
+        **{name: expr.fill_null(True) for name, expr in simple_exprs.items()},
     )
+    # Group rules must be evaluated AFTER row-level rules so that they only apply
+    # to rows that pass row-level validation. Otherwise, a row failing a row-level
+    # rule could take part in a group evaluation and, once that row is filtered
+    # out, the surviving rows might no longer satisfy the group rule.
+    if group_rules:
+        # Evaluate group rules only on rows that pass all row-level rules.
+        result = result.pipe(_with_group_rules, group_rules, simple_exprs)
+
     # If there is at least one rule that checks for successful dtype casting, we need
     # to take an extra step: rules other than the "dtype rules" might not be reliable
     # if casting failed, i.e. if any of the "dtype rules" evaluated to `False`. For
@@ -231,7 +236,11 @@ def with_evaluation_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.LazyFr
     return result
 
 
-def _with_group_rules(lf: pl.LazyFrame, rules: dict[str, GroupRule]) -> pl.LazyFrame:
+def _with_group_rules(
+    lf: pl.LazyFrame,
+    rules: dict[str, GroupRule],
+    simple_exprs: dict[str, pl.Expr] | None = None,
+) -> pl.LazyFrame:
     # First, we partition the rules by group columns. This will minimize the number
     # of `group_by` calls and joins to make.
     grouped_rules: dict[frozenset[str], dict[str, pl.Expr]] = defaultdict(dict)
@@ -239,19 +248,41 @@ def _with_group_rules(lf: pl.LazyFrame, rules: dict[str, GroupRule]) -> pl.LazyF
         # NOTE: `null` indicates validity, see note above.
         grouped_rules[frozenset(rule.group_columns)][name] = rule.expr.fill_null(True)
 
+    # If we have row-level rules, evaluate group rules only on rows that pass all
+    # of them. This ensures that a row failing a row-level rule cannot influence
+    # group rule evaluation: we filter the data for the group aggregation but
+    # still join the results back onto the full dataset.
+    if simple_exprs:
+        # Build a predicate selecting rows that pass all row-level rules.
+        all_simple_valid = pl.all_horizontal(
+            [pl.col(name).fill_null(True) for name in simple_exprs]
+        )
+        # Restrict the group aggregations below to these rows.
+        lf_for_groups = lf.filter(all_simple_valid)
+    else:
+        lf_for_groups = lf
+
     # Then, for each `group_by`, we apply the relevant rules and keep all the rule
     # evaluations around
     group_evaluations: dict[frozenset[str], pl.LazyFrame] = {}
-    for group_columns, group_rules in grouped_rules.items():
+    for group_columns, group_rules_exprs in grouped_rules.items():
         # We group by the group columns and apply all expressions
-        group_evaluations[group_columns] = lf.group_by(group_columns).agg(**group_rules)
+        group_evaluations[group_columns] = lf_for_groups.group_by(group_columns).agg(
+            **group_rules_exprs
+        )
 
     # Eventually, we apply the rule evaluations onto the input data frame. For this, we
     # "broadcast" the results within each group across rows in the same group.
+    # We join onto the FULL dataset (`lf`), not the filtered one. Rows whose group
+    # was filtered out entirely get `null` for the group rules, i.e. they validate.
     result = lf
     for group_columns, frame in group_evaluations.items():
         result = result.join(
-            frame, on=list(group_columns), nulls_equal=True, maintain_order="left"
+            frame,
+            on=list(group_columns),
+            nulls_equal=True,
+            maintain_order="left",
+            how="left",
         )
     return result
diff --git a/tests/schema/test_filter.py b/tests/schema/test_filter.py
index 99557e7..2d01bb4 100644
--- a/tests/schema/test_filter.py
+++ b/tests/schema/test_filter.py
@@ -275,3 +275,81 @@ def test_filter_details(eager: bool) -> None:
             "primary_key": "invalid",
         },
     ]
+
+
+# --------------------------- GROUP RULES WITH ROW RULES ----------------------------- #
+
+
+@pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
+@pytest.mark.parametrize("eager", [True, False])
+def test_filter_group_rule_after_row_rule_filtering(
+    df_type: type[pl.DataFrame] | type[pl.LazyFrame], eager: bool
+) -> None:
+    """Test that group rules are evaluated only on rows that pass row-level rules.
+
+    This guards against a bug where a row failing a row-level rule could skew group
+    rule evaluation and cause incorrect validation results.
+ """ + + class DiagnosisSchema(dy.Schema): + invoice_id = dy.String(primary_key=True) + diagnosis = dy.String(primary_key=True, regex="^[A-Z]{3}$") + is_main = dy.Bool(nullable=False) + + @dy.rule(group_by=["invoice_id"]) + def exactly_one_main_diagnosis(cls) -> pl.Expr: + return pl.col("is_main").sum() == 1 + + # Case 1: Row-level rule removes row that would make group rule pass + # Without the fix, this would incorrectly pass validation + df1 = df_type( + { + "invoice_id": ["A", "A", "A"], + "diagnosis": ["ABC", "ABD", "123"], # "123" fails regex + "is_main": [False, False, True], # Main diagnosis on invalid row + } + ) + valid1, failures1 = _filter_and_collect(DiagnosisSchema, df1, eager=eager) + # All rows should be filtered out: row with "123" fails regex, + # and remaining rows fail group rule (no main diagnosis) + assert len(valid1) == 0 + assert "exactly_one_main_diagnosis" in failures1.counts() + assert "diagnosis|regex" in failures1.counts() + + # Case 2: Valid data passes both row-level and group rules + df2 = df_type( + { + "invoice_id": ["A", "A", "A"], + "diagnosis": ["ABC", "ABD", "AEF"], + "is_main": [False, True, False], + } + ) + valid2, failures2 = _filter_and_collect(DiagnosisSchema, df2, eager=eager) + assert len(valid2) == 3 + assert valid2.select(pl.col("is_main").sum()).item() == 1 + assert len(failures2) == 0 + + # Case 3: Multiple groups, one has invalid row-level data + df3 = df_type( + { + "invoice_id": ["A", "A", "A", "B", "B"], + "diagnosis": ["ABC", "ABD", "AEF", "XYZ", "123"], + "is_main": [False, True, False, True, False], + } + ) + valid3, failures3 = _filter_and_collect(DiagnosisSchema, df3, eager=eager) + # Group A: all valid, group B: one row fails regex, remaining row is valid + assert len(valid3) == 4 + assert failures3.counts() == {"diagnosis|regex": 1} + + # Case 4: All rows pass row-level but fail group rule + df4 = df_type( + { + "invoice_id": ["A", "A", "A"], + "diagnosis": ["ABC", "ABD", "AEF"], + "is_main": [False, False, False], # No main diagnosis + } + ) + valid4, failures4 = _filter_and_collect(DiagnosisSchema, df4, eager=eager) + assert len(valid4) == 0 + assert failures4.counts() == {"exactly_one_main_diagnosis": 3}