diff --git a/dpsynth/domain.py b/dpsynth/domain.py index 67458df..de66d04 100644 --- a/dpsynth/domain.py +++ b/dpsynth/domain.py @@ -50,6 +50,7 @@ from typing import Any, Literal, TypeAlias import attr +import numpy as np import pandas as pd import yaml @@ -141,6 +142,13 @@ class NumericalAttribute: max_value] if they are a numeric type, and to min_value otherwise. If False, out-of-domain values will be grouped together and treated as a single special out-of-domain value. + sentinel: The value to assign to out-of-domain entries during reverse + discretization when ``clip_to_range`` is False. Defaults to ``None``, + which resolves to ``np.nan`` for numeric modes or ``''`` for + ``interval_handling='interval'``. Set to an integer (e.g. ``-1``) to keep + the output array as an integer dtype instead of silently promoting to + float. When explicitly set, must be a string for interval mode and numeric + for other modes. Use ``resolved_sentinel`` to get the effective value. dtype: The dtype of the data (either 'int' or 'float'). interval_handling: Controls how discretized intervals are converted back to numerical values. 'midpoint' returns the interval midpoint (or the finite @@ -153,6 +161,7 @@ class NumericalAttribute: min_value: float = attr.field(converter=float) max_value: float = attr.field(converter=float) clip_to_range: bool = attr.field(default=True) + sentinel: float | int | str | None = attr.field(default=None) dtype: str = attr.field(default='float') interval_handling: str = attr.field(default='midpoint') description: str | None = attr.field(default=None) @@ -174,11 +183,34 @@ def _validate_dtype(self, *_): @interval_handling.validator # pytype: disable=attribute-error def _validate_interval_handling(self, *_): + """Validates interval_handling mode and sentinel type compatibility.""" if self.interval_handling not in ['midpoint', 'sample', 'interval']: raise ValueError( 'interval_handling must be "midpoint", "sample", or "interval",' f' got {self.interval_handling}.' ) + if self.sentinel is not None: + if self.interval_handling == 'interval': + if not isinstance(self.sentinel, str): + raise ValueError( + "interval_handling='interval' requires a string sentinel, got" + f' sentinel={self.sentinel!r}.' + ) + elif not isinstance(self.sentinel, (int, float, np.integer, np.floating)): + raise ValueError( + 'sentinel must be numeric when' + f' interval_handling={self.interval_handling!r}, got' + f' sentinel={self.sentinel!r}.' + ) + + @property + def resolved_sentinel(self) -> float | int | str: + """Returns the effective sentinel, with mode-appropriate defaults.""" + if self.sentinel is not None: + return self.sentinel + if self.interval_handling == 'interval': + return '' + return np.nan @property def exclusive_min_value(self) -> float: diff --git a/dpsynth/local_mode/vectorized_transformations.py b/dpsynth/local_mode/vectorized_transformations.py index 792ce84..30e10dc 100644 --- a/dpsynth/local_mode/vectorized_transformations.py +++ b/dpsynth/local_mode/vectorized_transformations.py @@ -162,7 +162,8 @@ def undiscretize( A 1-D array. For ``'midpoint'`` and ``'sample'`` the dtype is float (or int when ``dtype == 'int'`` and all values are in-domain). For ``'interval'`` the dtype is ``object`` (strings). Out-of-domain bins - (index 0 when ``clip_to_range`` is ``False``) map to ``NaN`` or ``""``. + (index 0 when ``clip_to_range`` is ``False``) map to + ``attribute_domain.resolved_sentinel``. """ rng = np.random.default_rng(rng) min_, max_ = attribute_domain.exclusive_min_value, attribute_domain.max_value @@ -170,16 +171,15 @@ def undiscretize( full_edges = np.r_[min_, bin_edges, max_] lefts, rights = full_edges[:-1], full_edges[1:] handling = attribute_domain.interval_handling + sentinel = attribute_domain.resolved_sentinel if handling == 'interval': values = np.array([f'({l}, {r}]' for l, r in zip(lefts, rights)], dtype=str) if not attribute_domain.clip_to_range: - sentinel = np.array('', dtype=str) - values = np.r_[sentinel, values] + values = np.r_[np.array(sentinel, dtype=str), values] return values[bin_indices] elif handling == 'sample': if not attribute_domain.clip_to_range: - sentinel = np.nan ood = bin_indices == 0 idx = bin_indices - 1 result = np.where(ood, sentinel, rng.uniform(lefts[idx], rights[idx])) @@ -188,14 +188,12 @@ def undiscretize( elif handling == 'midpoint': midpoints = (lefts + rights) / 2.0 if not attribute_domain.clip_to_range: - sentinel = np.nan midpoints = np.r_[sentinel, midpoints] result = midpoints[bin_indices] else: raise ValueError(f'Unsupported interval_handling: {handling}') if attribute_domain.dtype == 'int' and attribute_domain.clip_to_range: - # If clip_to_range=False, then NaNs are possible so we don't cast to int. result = np.ceil(result).astype(int) return result diff --git a/dpsynth/pydantic_api.py b/dpsynth/pydantic_api.py index fbef5bf..ebd5505 100644 --- a/dpsynth/pydantic_api.py +++ b/dpsynth/pydantic_api.py @@ -172,4 +172,10 @@ def dp_synthetic_data_generation( discrete_config=mechanism_config, ) - return [cls(**user) for _, user in synthetic.iterrows()] + return [ + cls(**{ + k: None if isinstance(v, float) and math.isnan(v) else v + for k, v in user.items() + }) + for _, user in synthetic.iterrows() + ] diff --git a/dpsynth/transformations.py b/dpsynth/transformations.py index 3f53cab..c87b612 100644 --- a/dpsynth/transformations.py +++ b/dpsynth/transformations.py @@ -169,6 +169,7 @@ def create_discretize_transformation( ] intervals = pd.IntervalIndex.from_breaks(bin_edges) maybe_none = [] if attribute_domain.clip_to_range else [None] + sentinel = attribute_domain.resolved_sentinel possible_values = maybe_none + list(intervals) def transform(value: Any) -> pd.Interval | None: @@ -190,7 +191,7 @@ def _resolve_finite(interval: pd.Interval) -> float: def reverse(value: pd.Interval | None) -> float | pd.Interval | None: if value is None: - return None + return sentinel if attribute_domain.interval_handling == 'interval': return value if attribute_domain.interval_handling == 'sample': diff --git a/tests/domain_test.py b/tests/domain_test.py index 5cae77a..0469419 100644 --- a/tests/domain_test.py +++ b/tests/domain_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from dpsynth import domain +import numpy as np class TestDomain(absltest.TestCase): @@ -100,6 +101,55 @@ def test_standardize_numerical(self): for value in ood_values: self.assertIsNone(attribute.standardize(value)) + def test_numerical_attribute_default_sentinel(self): + attribute = domain.NumericalAttribute(0, 10) + self.assertIsNone(attribute.sentinel) + self.assertTrue(np.isnan(attribute.resolved_sentinel)) + + def test_numerical_attribute_custom_sentinel(self): + attribute = domain.NumericalAttribute(0, 10, sentinel=-1) + self.assertEqual(attribute.sentinel, -1) + self.assertEqual(attribute.resolved_sentinel, -1) + + def test_sentinel_yaml_roundtrip(self): + original = { + 'num': domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False, sentinel=-1 + ), + } + temp_file = self.create_tempfile('temp.yaml', mode='w+') + domain.to_yaml_file(original, temp_file.full_path) + loaded = domain.from_yaml_file(temp_file.full_path) + self.assertEqual(loaded['num'].sentinel, -1) + + def test_string_sentinel_allowed_with_interval_handling(self): + attr = domain.NumericalAttribute( + 0, 10, sentinel='MISSING', interval_handling='interval' + ) + self.assertEqual(attr.sentinel, 'MISSING') + self.assertEqual(attr.resolved_sentinel, 'MISSING') + + def test_interval_handling_default_sentinel(self): + attr = domain.NumericalAttribute(0, 10, interval_handling='interval') + self.assertIsNone(attr.sentinel) + self.assertEqual(attr.resolved_sentinel, '') + + def test_string_sentinel_rejected_with_midpoint_handling(self): + with self.assertRaises(ValueError): + domain.NumericalAttribute(0, 10, sentinel='MISSING') + + def test_numeric_sentinel_rejected_with_interval_handling(self): + with self.assertRaises(ValueError): + domain.NumericalAttribute( + 0, 10, sentinel=42, interval_handling='interval' + ) + + def test_numpy_numeric_sentinel_accepted(self): + attr = domain.NumericalAttribute(0, 10, sentinel=np.int32(-1)) + self.assertEqual(attr.sentinel, -1) + attr = domain.NumericalAttribute(0, 10, sentinel=np.float32(0.5)) + self.assertAlmostEqual(attr.sentinel, 0.5, places=5) + def test_freeform_text_defaults(self): attribute = domain.FreeFormTextAttribute() self.assertEqual(attribute.max_tokens, 256) diff --git a/tests/local_mode/vectorized_transformations_test.py b/tests/local_mode/vectorized_transformations_test.py index a830d3f..5c9ce8c 100644 --- a/tests/local_mode/vectorized_transformations_test.py +++ b/tests/local_mode/vectorized_transformations_test.py @@ -238,17 +238,18 @@ def test_interval_mode(self): self.assertIn('(', result[0]) self.assertIn(']', result[0]) - def test_interval_mode_ood_empty_string(self): + def test_string_sentinel_interval(self): attr = domain.NumericalAttribute( min_value=0, max_value=10, clip_to_range=False, + sentinel='MISSING', interval_handling='interval', ) result = vectorized_transformations.undiscretize( np.array([0, 1]), np.array([5.0]), attr ) - self.assertEqual(result[0], '') + self.assertEqual(result[0], 'MISSING') self.assertIn('(', result[1]) def test_invalid_bin_edges_raises(self): @@ -266,6 +267,31 @@ def test_invalid_bin_edges_raises(self): np.array([1]), np.array([5.0, 3.0]), attr ) + def test_custom_sentinel_midpoint(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False, sentinel=-1 + ) + result = vectorized_transformations.undiscretize( + np.array([0, 1, 2]), np.array([5.0]), attr + ) + self.assertEqual(result[0], -1) + self.assertBetween(result[1], 0, 5) + self.assertBetween(result[2], 5, 10) + + def test_custom_sentinel_sample(self): + rng = np.random.default_rng(0) + attr = domain.NumericalAttribute( + min_value=0, + max_value=10, + clip_to_range=False, + sentinel=-1, + interval_handling='sample', + ) + result = vectorized_transformations.undiscretize( + np.array([0, 1]), np.array([5.0]), attr, rng=rng + ) + self.assertEqual(result[0], -1) + class MergeRareValuesTest(absltest.TestCase): diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 81cc434..6a68b1a 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -149,7 +149,25 @@ def test_valid_discretization_no_clip_to_range_inverse(self): self.assertBetween(transform_fn.inverse(interval1), 0, 5) self.assertBetween(transform_fn.inverse(interval2), 5, 10) - self.assertIsNone(transform_fn.inverse(None)) + self.assertTrue(np.isnan(transform_fn.inverse(None))) + + def test_discretize_inverse_sentinel_default_nan(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False + ) + _, transform_fn = transformations.create_discretize_transformation( + attr, [5] + ) + self.assertTrue(np.isnan(transform_fn.inverse(None))) + + def test_discretize_inverse_custom_sentinel(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False, sentinel=-1 + ) + _, transform_fn = transformations.create_discretize_transformation( + attr, [5] + ) + self.assertEqual(transform_fn.inverse(None), -1) def test_valid_discretization_for_int_attribute(self): attr = domain.NumericalAttribute(min_value=0, max_value=10, dtype='int') @@ -180,7 +198,7 @@ def test_discretize_interval_handling_sample(self): values.add(value) # Sample mode should produce non-constant output (unlike midpoint). self.assertGreater(len(values), 1) - self.assertIsNone(transform_fn.inverse(None)) + self.assertTrue(np.isnan(transform_fn.inverse(None))) def test_discretize_interval_handling_interval(self): attr = domain.NumericalAttribute( @@ -191,7 +209,7 @@ def test_discretize_interval_handling_interval(self): ) interval = pd.Interval(5, 10) self.assertEqual(transform_fn.inverse(interval), interval) - self.assertIsNone(transform_fn.inverse(None)) + self.assertEqual(transform_fn.inverse(None), '') def test_discretize_reverse_semi_infinite_intervals(self): # Midpoint mode: semi-infinite intervals should return the finite endpoint.