diff --git a/changelog.d/fix-eternity-input-cache-preservation.fixed.md b/changelog.d/fix-eternity-input-cache-preservation.fixed.md new file mode 100644 index 00000000..d9cae03d --- /dev/null +++ b/changelog.d/fix-eternity-input-cache-preservation.fixed.md @@ -0,0 +1 @@ +Preserve user-provided ETERNITY inputs across cache invalidation when they were set for an ordinary period. diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index ba8e173a..e221c904 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1,5 +1,6 @@ import hashlib import tempfile +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Type, Union import numpy as np @@ -57,6 +58,14 @@ def _stable_hash_to_seed(value: str) -> int: ) +@dataclass(frozen=True) +class PreservedUserInput: + variable_name: str + branch_name: str + period: Period + value: object + + class Simulation: """ Represents a simulation, and handles the calculation logic @@ -273,16 +282,23 @@ def _invalidate_all_caches(self) -> None: self._fast_cache = {} self.invalidated_caches = set() # Snapshot user-provided inputs before wiping so they can be - # replayed into the fresh storage. Storage keys each entry as - # f"{branch_name}:{period}"; preserve exactly those keys. - preserved: dict[str, dict[str, object]] = {} + # replayed into the fresh storage. Use the storage API instead of + # hand-building keys, since ETERNITY variables canonicalize every + # period to the single ETERNITY storage key. + preserved: list[PreservedUserInput] = [] user_input_keys = getattr(self, "_user_input_keys", None) or set() for variable_name, branch_name, period in user_input_keys: holder = self.get_holder(variable_name) - storage_key = f"{branch_name}:{period}" - stored_value = holder._memory_storage._arrays.get(storage_key) + stored_value = holder._memory_storage.get(period, branch_name) if stored_value is not None: - preserved.setdefault(variable_name, {})[storage_key] = stored_value + preserved.append( + PreservedUserInput( + variable_name=variable_name, + branch_name=branch_name, + period=period, + value=stored_value, + ) + ) # Iterate only over holders that already exist on each population — # lazy-creating a holder for every variable in the tax-benefit # system (thousands in policyengine-us) inflated the cost of @@ -295,9 +311,13 @@ def _invalidate_all_caches(self) -> None: if holder._disk_storage is not None: holder._disk_storage._files = {} # Replay preserved user inputs so ``calculate`` still sees them. - for variable_name, key_to_array in preserved.items(): - holder = self.get_holder(variable_name) - holder._memory_storage._arrays.update(key_to_array) + for user_input in preserved: + holder = self.get_holder(user_input.variable_name) + holder._memory_storage.put( + user_input.value, + user_input.period, + user_input.branch_name, + ) for branch in self.branches.values(): branch._invalidate_all_caches() diff --git a/tests/core/test_apply_reform_preserves_user_inputs.py b/tests/core/test_apply_reform_preserves_user_inputs.py index 6cea6016..f11c5c57 100644 --- a/tests/core/test_apply_reform_preserves_user_inputs.py +++ b/tests/core/test_apply_reform_preserves_user_inputs.py @@ -89,6 +89,60 @@ def apply(self): assert sim.calculate("age", period=period)[0] == 27 +def test_apply_reform_preserves_eternity_inputs_set_for_a_period( + tax_benefit_system, +): + """ETERNITY inputs set for ordinary periods must survive reform apply.""" + sim = SimulationBuilder().build_from_entities( + tax_benefit_system, situation_examples.single + ) + period = "2017" + expected_person_id = np.array([123], dtype=np.int32) + + sim.set_input("person_id", period, expected_person_id) + + class NoOpReform(Reform): + def apply(self): + pass + + sim.apply_reform(NoOpReform) + + result = sim.calculate("person_id", period=period) + assert np.array_equal(result, expected_person_id), ( + "apply_reform lost an ETERNITY input set through Simulation.set_input " + f"for {period}; got {result} instead of {expected_person_id}." + ) + + +def test_apply_reform_preserves_eternity_inputs_set_through_holder( + tax_benefit_system, +): + """ETERNITY preservation must also cover direct ``holder.set_input``.""" + sim = SimulationBuilder().build_from_entities( + tax_benefit_system, situation_examples.single + ) + period = "2017" + expected_person_id = np.array([456], dtype=np.int32) + + sim.get_holder("person_id").set_input( + period, + expected_person_id, + sim.branch_name, + ) + + class NoOpReform(Reform): + def apply(self): + pass + + sim.apply_reform(NoOpReform) + + result = sim.calculate("person_id", period=period) + assert np.array_equal(result, expected_person_id), ( + "apply_reform lost an ETERNITY input set through Holder.set_input " + f"for {period}; got {result} instead of {expected_person_id}." + ) + + def test_apply_reform_preserves_situation_dict_inputs(tax_benefit_system): """Situation-dict inputs must survive ``apply_reform`` too.