diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4836fc1..fe9d65e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -74,4 +74,4 @@ jobs: - name: Run tests run: | - python -m pytest --pyargs confection + python -m pytest tests/ diff --git a/confection/__init__.py b/confection/__init__.py index ad2964d..6db6ea6 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,1205 +1,21 @@ -import copy -import inspect -import io -import json as _json -import re -import warnings -from configparser import ( - MAX_INTERPOLATION_DEPTH, - ConfigParser, - ExtendedInterpolation, - InterpolationDepthError, - InterpolationMissingOptionError, - InterpolationSyntaxError, - NoOptionError, - NoSectionError, - ParsingError, +# FIXME some symbols are not in __all__; can we remove them? +from ._config import Config # noqa: F401 +from ._constants import ( # noqa: F401 + ARGS_FIELD, + ARGS_FIELD_ALIAS, + RESERVED_FIELDS, + SECTION_PREFIX, + VARIABLE_RE, ) -from dataclasses import dataclass -from pathlib import Path -from types import GeneratorType -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, - get_type_hints, +from ._errors import ConfigValidationError +from ._registry import Promise, registry +from .util import ( # noqa: F401 + SimpleFrozenDict, + SimpleFrozenList, + try_dump_json, + try_load_json, ) - -from .util import SimpleFrozenDict, SimpleFrozenList # noqa: F401 -from .validation import ( - Field, - FieldInfo, - Schema, - ValidationError, - create_schema, - ensure_schema, - validate_type, # noqa: F401 — public API -) - -# Field used for positional arguments, e.g. [section.*.xyz]. The alias is -# required for the schema (shouldn't clash with user-defined arg names) -ARGS_FIELD = "*" -ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS" -# Aliases for fields that would otherwise shadow pydantic attributes. Can be any -# string, so we're using name + space so it looks the same in error messages etc. -RESERVED_FIELDS = {"validate": "validate\u0020"} -# Internal prefix used to mark section references for custom interpolation -SECTION_PREFIX = "__SECTION__:" -# Regex to detect whether a value contains a variable -VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") - - -class CustomInterpolation(ExtendedInterpolation): - _KEYCRE: re.Pattern - - def before_read(self, parser, section, option, value): - # Warn about single-quoted strings (common mistake) - if value and value[0] == value[-1] == "'": - warnings.warn( - f"The value [{value}] seems to be single-quoted, but values " - "use JSON formatting, which requires double quotes." - ) - return super().before_read(parser, section, option, value) - - def _coerce_for_string_context(self, v): - """Coerce a raw config value for use in a compound string expression.""" - # Don't coerce section references - they need to stay quoted for JSON - if SECTION_PREFIX in v: - return v - try: - parsed = _json.loads(v) - except (ValueError, TypeError): - return v # Not valid JSON, already a plain string - if isinstance(parsed, str): - return parsed # Unwrap JSON string - # Use json.dumps() for non-strings, escaping inner quotes so they don't - # conflict with the outer JSON string quotes - return _json.dumps(parsed).replace('"', '\\"') - - def before_get(self, parser, section, option, value, defaults): - # Mostly copy-pasted from the built-in configparser implementation. - # The interpolate() method resolves ${...} references and appends pieces - # to L. For a bare reference like ${x}, L has one element. For compound - # expressions like "hello ${x}", L has multiple pieces that we join. - # Compound results stay as strings (coerced via _coerce_for_string_context), - # while bare references keep their JSON type for _interpret_value to parse. - L = [] - self.interpolate(parser, option, L, value, section, defaults, 1) - if len(L) == 1: - return L[0] - return "".join(self._coerce_for_string_context(piece) for piece in L) - - def interpolate(self, parser, option, accum, rest, section, map, depth): - # Mostly copy-pasted from the built-in configparser implementation. - # We need to overwrite this method so we can add special handling for - # block references :( All values produced here should be strings – - # we need to wait until the whole config is interpreted anyways so - # filling in incomplete values here is pointless. All we need is the - # section reference so we can fetch it later. - rawval = parser.get(section, option, raw=True, fallback=rest) - if depth > MAX_INTERPOLATION_DEPTH: - raise InterpolationDepthError(option, section, rawval) - while rest: - p = rest.find("$") - if p < 0: - accum.append(rest) - return - if p > 0: - accum.append(rest[:p]) - rest = rest[p:] - # p is no longer used - c = rest[1:2] - if c == "$": - accum.append("$") - rest = rest[2:] - elif c == "{": - # We want to treat both ${a:b} and ${a.b} the same - m = self._KEYCRE.match(rest) - if m is None: - err = f"bad interpolation variable reference {rest}" - raise InterpolationSyntaxError(option, section, err) - orig_var = m.group(1) - path = orig_var.replace(":", ".").rsplit(".", 1) - rest = rest[m.end() :] - sect = section - opt = option - try: - if len(path) == 1: - opt = parser.optionxform(path[0]) - # Check if the variable references a section rather - # than a key in the current section. If the key - # exists in the current map but its raw value is the - # same interpolation variable (self-reference), or if - # the key doesn't exist in the map, treat it as a - # section reference. - is_section_ref = opt not in map - if not is_section_ref: - raw = map[opt] - is_section_ref = ( - isinstance(raw, str) and f"${{{orig_var}}}" in raw - ) - if not is_section_ref: - v = map[opt] - else: - # Block reference — store as a special key - section_name = parser[parser.optionxform(path[0])]._name - v = self._get_section_name(section_name) - elif len(path) == 2: - sect = path[0] - opt = parser.optionxform(path[1]) - fallback = "__FALLBACK__" - v = parser.get(sect, opt, raw=True, fallback=fallback) - # If a variable doesn't exist, try again and treat the - # reference as a section - if v == fallback: - v = self._get_section_name(parser[f"{sect}.{opt}"]._name) - else: - err = f"More than one ':' found: {rest}" - raise InterpolationSyntaxError(option, section, err) - except (KeyError, NoSectionError, NoOptionError): - raise InterpolationMissingOptionError( - option, section, rawval, orig_var - ) from None - if "$" in v: - new_map = dict(parser.items(sect, raw=True)) - self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1) - else: - accum.append(v) - else: - err = "'$' must be followed by '$' or '{', found: %r" % (rest,) - raise InterpolationSyntaxError(option, section, err) - - def _get_section_name(self, name: str) -> str: - """Generate the name of a section. Note that we use a quoted string here - so we can use section references within lists and load the list as - JSON. Since section references can't be used within strings, we don't - need the quoted vs. unquoted distinction like we do for variables. - - Examples (assuming section = {"foo": 1}): - - value: ${section.foo} -> value: 1 - - value: "hello ${section.foo}" -> value: "hello 1" - - value: ${section} -> value: {"foo": 1} - - value: "${section}" -> value: {"foo": 1} - - value: "hello ${section}" -> invalid - """ - return f'"{SECTION_PREFIX}{name}"' - - -def get_configparser(interpolate: bool = True) -> ConfigParser: - config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None) - # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719 - config.optionxform = str # type: ignore - return config - - -_ConfigSelf = TypeVar("_ConfigSelf", bound="Config") - - -class Config(dict): - """This class holds the model and training configuration and can load and - save the TOML-style configuration format from/to a string, file or bytes. - The Config class is a subclass of dict and uses Python's ConfigParser - under the hood. - """ - - is_interpolated: bool - section_order: Sequence[str] - _sections: dict - - def __init__( - self, - data: Optional[Union[Mapping[str, Any], "ConfigParser", "Config"]] = None, - *, - is_interpolated: Optional[bool] = None, - section_order: Optional[Sequence[str]] = None, - ) -> None: - """Initialize a new Config object with optional data.""" - dict.__init__(self) - if data is None: - data = {} - if not isinstance(data, (dict, Config, ConfigParser)): - raise ValueError( - f"Can't initialize Config with data. Expected dict, Config or " - f"ConfigParser but got: {type(data)}" - ) - # Whether the config has been interpolated. We can use this to check - # whether we need to interpolate again when it's resolved. We assume - # that a config is interpolated by default. - if is_interpolated is not None: - self.is_interpolated = is_interpolated - elif isinstance(data, Config): - self.is_interpolated = data.is_interpolated - else: - self.is_interpolated = True - if section_order is not None: - self.section_order = section_order - elif isinstance(data, Config): - self.section_order = data.section_order - else: - self.section_order = [] - # Update with data - self.update(self._sort(data)) - - def interpolate(self: _ConfigSelf) -> _ConfigSelf: - """Interpolate a config. Returns a copy of the object.""" - # This is currently the most effective way because we need our custom - # to_str logic to run in order to re-serialize the values so we can - # interpolate them again. ConfigParser.read_dict will just call str() - # on all values, which isn't enough. - return type(self)().from_str(self.to_str()) - - def interpret_config(self, config: "ConfigParser") -> None: - """Interpret a config, parse nested sections and parse the values - as JSON. Mostly used internally and modifies the config in place. - """ - self._validate_sections(config) - # Sort sections by depth, so that we can iterate breadth-first. This - # allows us to check that we're not expanding an undefined block. - get_depth = lambda item: len(item[0].split(".")) - part = "" - for section, values in sorted(config.items(), key=get_depth): - if section == "DEFAULT": - # Skip [DEFAULT] section so it doesn't cause validation error - continue - parts = section.split(".") - node = self - for part in parts[:-1]: - if part == "*": - node = node.setdefault(part, {}) - elif part not in node: - err_title = ( - "Error parsing config section. Perhaps a section name is wrong?" - ) - err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) - else: - node = node[part] - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{ ({part: dict(values)}) }" - raise ConfigValidationError(config=err_cfg, errors=err) - # Set the default section - node = node.setdefault(parts[-1], {}) - if not isinstance(node, dict): - # Happens if both value *and* subsection were defined for a key - err = [{"loc": parts, "msg": "found conflicting values"}] - err_cfg = f"{self}\n{ ({part: dict(values)}) }" - raise ConfigValidationError(config=err_cfg, errors=err) - try: - keys_values = list(values.items()) - except InterpolationMissingOptionError as e: - raise ConfigValidationError(desc=f"{e}") from None - for key, value in keys_values: - config_v = config.get(section, key) - node[key] = self._interpret_value(config_v) - self.replace_section_refs(self) - - def replace_section_refs( - self, config: Union[Dict[str, Any], "Config"], parent: str = "" - ) -> None: - """Replace references to section blocks in the final config.""" - for key, value in config.items(): - key_parent = f"{parent}.{key}".strip(".") - if isinstance(value, dict): - self.replace_section_refs(value, parent=key_parent) - elif isinstance(value, list): - config[key] = [ - self._get_section_ref(v, parent=[parent, key]) for v in value - ] - else: - config[key] = self._get_section_ref(value, parent=[parent, key]) - - def _interpret_value(self, value: Any) -> Any: - """Interpret a single config value.""" - result = try_load_json(value) - # If value is a string and it contains a variable, use original value - # (not interpreted string, which could lead to double quotes: - # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string, - # so we're not keeping lists as strings. - # NOTE: This currently can't handle uninterpolated values like [${x.y}]! - if isinstance(result, str) and VARIABLE_RE.search(value): - result = value - return result - - def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any: - """Get a single section reference.""" - if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'): - value = try_load_json(value) - if ( - isinstance(value, str) - and value.startswith(SECTION_PREFIX) - and value != SECTION_PREFIX - ): - parts = value.replace(SECTION_PREFIX, "", 1).split(".") - result = self - for item in parts: - try: - result = result[item] - except (KeyError, TypeError): - err_title = "Error parsing reference to config section" - err_msg = f"Section '{'.'.join(parts)}' is not defined" - err = [{"loc": parts, "msg": err_msg}] - raise ConfigValidationError( - config=self, errors=err, title=err_title - ) from None - return result - elif ( - isinstance(value, str) - and SECTION_PREFIX in value - and value != SECTION_PREFIX - ): - # String value references a section (either a dict or return - # value of promise). We can't allow this, since variables are - # always interpolated *before* configs are resolved. - err_desc = ( - "Can't reference whole sections or return values of function " - "blocks inside a string or list\n\nYou can change your variable to " - "reference a value instead. Keep in mind that it's not " - "possible to interpolate the return value of a registered " - "function, since variables are interpolated when the config " - "is loaded, and registered functions are resolved afterwards." - ) - err = [{"loc": parent, "msg": "uses section variable in string or list"}] - raise ConfigValidationError(errors=err, desc=err_desc) - return value - - def copy(self: _ConfigSelf) -> _ConfigSelf: - """Deepcopy the config.""" - try: - config = copy.deepcopy(self) - except Exception as e: - raise ValueError(f"Couldn't deep-copy config: {e}") from e - return type(self)( - config, - is_interpolated=self.is_interpolated, - section_order=self.section_order, - ) - - def merge( - self: _ConfigSelf, - updates: Union[Mapping[str, Any], "Config"], - remove_extra: bool = False, - ) -> _ConfigSelf: - """Deep merge the config with updates, using current as defaults.""" - defaults = self.copy() - updates_config = Config(updates).copy() - merged = deep_merge_configs(updates_config, defaults, remove_extra=remove_extra) - return type(self)( - merged, - is_interpolated=defaults.is_interpolated and updates_config.is_interpolated, - section_order=defaults.section_order, - ) - - def _sort( - self, data: Union["Config", "ConfigParser", Dict[str, Any]] - ) -> Dict[str, Any]: - """Sort sections using the currently defined sort order. Sort - sections by index on section order, if available, then alphabetic, and - account for subsections, which should always follow their parent. - """ - sort_map = {section: i for i, section in enumerate(self.section_order)} - sort_key = lambda x: ( - sort_map.get(x[0].split(".")[0], len(sort_map)), - _mask_positional_args(x[0]), - ) - return dict(sorted(data.items(), key=sort_key)) - - def _set_overrides( - self, config: "ConfigParser", overrides: Mapping[str, Any] - ) -> None: - """Set overrides in the ConfigParser before config is interpreted.""" - err_title = "Error parsing config overrides" - for key, value in overrides.items(): - err_msg = "not a section value that can be overridden" - err = [{"loc": key.split("."), "msg": err_msg}] - if "." not in key: - raise ConfigValidationError(errors=err, title=err_title) - section, option = key.rsplit(".", 1) - # Check for section and accept if option not in config[section] - if section not in config: - raise ConfigValidationError(errors=err, title=err_title) - config.set(section, option, try_dump_json(value, overrides)) - - def _validate_sections(self, config: "ConfigParser") -> None: - # If the config defines top-level properties that are not sections (e.g. - # if config was constructed from dict), those values would be added as - # [DEFAULTS] and included in *every other section*. This is usually not - # what we want and it can lead to very confusing results. - default_section = config.defaults() - if default_section: - err_title = "Found config values without a top-level section" - err_msg = "not part of a section" - err = [{"loc": [k], "msg": err_msg} for k in default_section] - raise ConfigValidationError(errors=err, title=err_title) - - def from_str( - self: _ConfigSelf, - text: str, - *, - interpolate: bool = True, - overrides: Mapping[str, Any] = {}, - ) -> _ConfigSelf: - """Load the config from a string.""" - config = get_configparser(interpolate=interpolate) - if overrides: - config = get_configparser(interpolate=False) - try: - config.read_string(text) - except ParsingError as e: - desc = f"Make sure the sections and values are formatted correctly.\n\n{e}" - raise ConfigValidationError(desc=desc) from None - config._sections = self._sort(config._sections) # type: ignore - self._set_overrides(config, overrides) - self.clear() - self.interpret_config(config) - if overrides and interpolate: - # do the interpolation. Avoids recursion because the new call from_str call will have overrides as empty - self = self.interpolate() - self.is_interpolated = interpolate - return self - - def to_str(self, *, interpolate: bool = True) -> str: - """Write the config to a string.""" - flattened = get_configparser(interpolate=interpolate) - queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)] - for path, node in queue: - section_name = ".".join(path) - is_kwarg = path and path[-1] != "*" - has_leaves = any(not hasattr(v, "items") for v in node.values()) - if path and has_leaves and not flattened.has_section(section_name): - # Create sections that have leaf values (including * sections). - # Skip empty sections whose children are all sub-sections. - flattened.add_section(section_name) - elif is_kwarg and not flattened.has_section(section_name): - # Always create non-* sections even if empty, so we don't - # expand blocks that are undefined. - flattened.add_section(section_name) - for key, value in node.items(): - if hasattr(value, "items"): - # Reference to a function with no arguments, serialize - # inline as a dict and don't create new section - if registry.is_promise(value) and len(value) == 1 and is_kwarg: - flattened.set(section_name, key, try_dump_json(value, node)) - else: - queue.append((path + (key,), value)) - else: - flattened.set(section_name, key, try_dump_json(value, node)) - # Order so subsection follow parent (not all sections, then all subs etc.) - flattened._sections = self._sort(flattened._sections) # type: ignore - self._validate_sections(flattened) - string_io = io.StringIO() - flattened.write(string_io) - return string_io.getvalue().strip() - - def to_bytes(self, *, interpolate: bool = True) -> bytes: - """Serialize the config to a byte string.""" - return self.to_str(interpolate=interpolate).encode("utf8") - - def from_bytes( - self: _ConfigSelf, - bytes_data: bytes, - *, - interpolate: bool = True, - overrides: Mapping[str, Any] = {}, - ) -> _ConfigSelf: - """Load the config from a byte string.""" - return self.from_str( - bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides - ) - - def to_disk(self, path: Union[str, Path], *, interpolate: bool = True): - """Serialize the config to a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("w", encoding="utf8") as file_: - file_.write(self.to_str(interpolate=interpolate)) - - def from_disk( - self: _ConfigSelf, - path: Union[str, Path], - *, - interpolate: bool = True, - overrides: Mapping[str, Any] = {}, - ) -> _ConfigSelf: - """Load config from a file.""" - path = Path(path) if isinstance(path, str) else path - with path.open("r", encoding="utf8") as file_: - text = file_.read() - return self.from_str(text, interpolate=interpolate, overrides=overrides) - - -def _mask_positional_args(name: str) -> List[Optional[str]]: - """Create a section name representation that masks names - of positional arguments to retain their order in sorts.""" - - stable_name = cast(List[Optional[str]], name.split(".")) - - # Remove names of sections that are a positional argument. - for i in range(1, len(stable_name)): - if stable_name[i - 1] == "*": - stable_name[i] = None - - return stable_name - - -def try_load_json(value: str) -> Any: - """Load a JSON string if possible, otherwise default to original value.""" - # Guard against ujson quirk where "-" parses as 0 - if value == "-": - return value - try: - return _json.loads(value) - except (ValueError, TypeError): - return value - - -def try_dump_json(value: Any, data: Union[Mapping[str, Any], Config, str] = "") -> str: - """Dump a config value as JSON and output user-friendly error if it fails.""" - # Special case if we have a variable: it's already a string so don't dump - # to preserve ${x:y} vs. "${x:y}" - if isinstance(value, str) and VARIABLE_RE.search(value): - return value - try: - value = _json.dumps(value, separators=(",", ":")) - # Escape all $ to $$ for configparser, then restore valid ${...} - # variable references. This ensures incomplete sequences like "${" - # are escaped, while "${foo.bar}" is preserved. - value = value.replace("$", "$$") - value = re.sub(r"\$(\$\{[\w\.:]+\})", r"\1", value) - return value - except Exception as e: - err_msg = ( - f"Couldn't serialize config value of type {type(value)}: {e}. Make " - f"sure all values in your config are JSON-serializable. If you want " - f"to include Python objects, use a registered function that returns " - f"the object instead." - ) - raise ConfigValidationError(config=data, desc=err_msg) from e - - -def deep_merge_configs( - config: Union[Dict[str, Any], Config], - defaults: Union[Dict[str, Any], Config], - *, - remove_extra: bool = False, -) -> Union[Dict[str, Any], Config]: - """Deep merge two configs.""" - if remove_extra: - # Filter out values in the original config that are not in defaults - keys = list(config.keys()) - for key in keys: - if key not in defaults: - del config[key] - for key, value in defaults.items(): - if isinstance(value, dict): - node = config.setdefault(key, {}) - if not isinstance(node, dict): - continue - value_promises = [k for k in value if k.startswith("@")] - value_promise = value_promises[0] if value_promises else None - node_promises = [k for k in node if k.startswith("@")] if node else [] - node_promise = node_promises[0] if node_promises else None - # We only update the block from defaults if it refers to the same - # registered function - if ( - value_promise - and node_promise - and ( - value_promise in node - and node[value_promise] != value[value_promise] - ) - ): - continue - if node_promise and ( - node_promise not in value or node[node_promise] != value[node_promise] - ): - continue - defaults = deep_merge_configs(node, value, remove_extra=remove_extra) - elif key not in config: - config[key] = value - return config - - -_CVESelf = TypeVar("_CVESelf", bound="ConfigValidationError") - - -class ConfigValidationError(ValueError): - def __init__( - self, - *, - config: Optional[Union[Config, Mapping[str, Any], str]] = None, - errors: Iterable[Mapping[str, Any]] = tuple(), - title: Optional[str] = "Config validation error", - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: bool = True, - ) -> None: - """Custom error for validating configs. - - config (Union[Config, Dict[str, Dict[str, Any]], str]): The - config the validation error refers to. - errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): - A list of errors as dicts with keys "loc" (list of strings - describing the path of the value), "msg" (validation message - to show) and optional "type" (mostly internals). - Same format as produced by pydantic's validation error (e.errors()). - title (str): The error title. - desc (str): Optional error description, displayed below the title. - parent (str): Optional parent to use as prefix for all error locations. - For example, parent "element" will result in "element -> a -> b". - show_config (bool): Whether to print the whole config with the error. - - ATTRIBUTES: - config (Union[Config, Dict[str, Dict[str, Any]], str]): The config. - errors (Iterable[Dict[str, Any]]): The errors. - error_types (Set[str]): All "type" values defined in the errors, if - available. This is most relevant for the pydantic errors that define - types like "type_error.integer". This attribute makes it easy to - check if a config validation error includes errors of a certain - type, e.g. to log additional information or custom help messages. - title (str): The title. - desc (str): The description. - parent (str): The parent. - show_config (bool): Whether to show the config. - text (str): The formatted error text. - """ - self.config = config - self.errors = errors - self.title = title - self.desc = desc - self.parent = parent - self.show_config = show_config - self.error_types = set() - for error in self.errors: - err_type = error.get("type") - if err_type: - self.error_types.add(err_type) - self.text = self._format() - ValueError.__init__(self, self.text) - - @classmethod - def from_error( - cls: Type[_CVESelf], - err: "ConfigValidationError", - title: Optional[str] = None, - desc: Optional[str] = None, - parent: Optional[str] = None, - show_config: Optional[bool] = None, - ) -> _CVESelf: - """Create a new ConfigValidationError based on an existing error, e.g. - to re-raise it with different settings. If no overrides are provided, - the values from the original error are used. - - err (ConfigValidationError): The original error. - title (str): Overwrite error title. - desc (str): Overwrite error description. - parent (str): Overwrite error parent. - show_config (bool): Overwrite whether to show config. - RETURNS (ConfigValidationError): The new error. - """ - return cls( - config=err.config, - errors=err.errors, - title=title if title is not None else err.title, - desc=desc if desc is not None else err.desc, - parent=parent if parent is not None else err.parent, - show_config=show_config if show_config is not None else err.show_config, - ) - - def _format(self) -> str: - """Format the error message.""" - loc_divider = "->" - data = [] - for error in self.errors: - err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])]) - if self.parent: - err_loc = f"{self.parent} {loc_divider} {err_loc}" - data.append((err_loc, error.get("msg"))) - result = [] - if self.title: - result.append(self.title) - if self.desc: - result.append(self.desc) - if data: - result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data])) - if self.config and self.show_config: - result.append(f"{self.config}") - return "\n\n" + "\n".join(result) - - -def alias_generator(name: str) -> str: - """Generate field aliases in promise schema.""" - # Underscore fields are not allowed in model, so use alias - if name == ARGS_FIELD_ALIAS: - return ARGS_FIELD - # Auto-alias fields that shadow base model attributes - if name in RESERVED_FIELDS: - return RESERVED_FIELDS[name] - return name - - -def _override_field_to_any(schema, field_name): - """Return a copy of the schema with one field's type set to Any. - - Creates a new schema class so the original (possibly cached) schema - is not mutated. The copy always uses our own model_validate (not - pydantic delegation) since the field types no longer match. - """ - if field_name not in schema.model_fields: - return schema - # Build a new schema with the overridden field - new_fields = {} - for name, field in schema.model_fields.items(): - if name == field_name: - new_field = FieldInfo(default=field.default, alias=field.alias) - new_field.annotation = Any - new_fields[name] = new_field - else: - new_fields[name] = field - new_schema = type(schema.__name__ + "_any", (Schema,), {}) - new_schema.model_fields = new_fields - new_schema.model_config = schema.model_config - return new_schema - - -def _contains_promise(obj): - """Check if a config dict contains any promise references (nested).""" - if isinstance(obj, dict): - if any(k.startswith("@") for k in obj if isinstance(k, str)): - return True - return any(_contains_promise(v) for v in obj.values()) - return False - - -def _coerce_basemodel_args(func, kwargs): - """Coerce dict kwargs to BaseModel instances where the function signature - expects a BaseModel subclass. This lets registered functions receive - constructed model instances instead of raw dicts (issue #58). - """ - try: - hints = get_type_hints(func) - except Exception: - return kwargs - result = dict(kwargs) - for name, value in result.items(): - if not isinstance(value, dict) or name not in hints: - continue - annotation = hints[name] - # Check if annotation is a class with __fields__ (pydantic v1) - # or model_fields (pydantic v2 / our Schema) - if isinstance(annotation, type) and ( - hasattr(annotation, "__fields__") or hasattr(annotation, "model_fields") - ): - try: - result[name] = annotation(**value) - except Exception: - pass # if construction fails, leave as dict - return result - - -class EmptySchema(Schema): - model_config = {"extra": "allow", "arbitrary_types_allowed": True} - - -@dataclass -class Promise: - registry: str - name: str - args: List[str] - kwargs: Dict[str, Any] - - -class registry: - @classmethod - def has(cls, registry_name: str, func_name: str) -> bool: - """Check whether a function is available in a registry.""" - if not hasattr(cls, registry_name): - return False - reg = getattr(cls, registry_name) - return func_name in reg - - @classmethod - def get(cls, registry_name: str, func_name: str) -> Callable: - """Get a registered function from a given registry.""" - if not hasattr(cls, registry_name): - raise ValueError(f"Unknown registry: '{registry_name}'") - reg = getattr(cls, registry_name) - func = reg.get(func_name) - if func is None: - raise ValueError(f"Could not find '{func_name}' in '{registry_name}'") - return func - - @classmethod - def resolve( - cls, - config: Union[Config, Mapping[str, Mapping[str, Any]]], - *, - schema: Type[Schema] = EmptySchema, - overrides: Mapping[str, Any] = {}, - validate: bool = True, - ) -> Dict[str, Any]: - schema = ensure_schema(schema) - resolved, _ = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=True - ) - return resolved - - @classmethod - def fill( - cls, - config: Union[Config, Mapping[str, Mapping[str, Any]]], - *, - schema: Type[Schema] = EmptySchema, - overrides: Mapping[str, Any] = {}, - validate: bool = True, - ) -> Config: - schema = ensure_schema(schema) - _, filled = cls._make( - config, schema=schema, overrides=overrides, validate=validate, resolve=False - ) - return filled - - @classmethod - def _make( - cls, - config: Union[Config, Mapping[str, Mapping[str, Any]]], - *, - schema: Type[Schema] = EmptySchema, - overrides: Mapping[str, Any] = {}, - resolve: bool = True, - validate: bool = True, - ) -> Tuple[Dict[str, Any], Config]: - """Unpack a config dictionary and create two versions of the config: - a resolved version with objects from the registry created recursively, - and a filled version with all references to registry functions left - intact, but filled with all values and defaults based on the type - annotations. If validate=True, the config will be validated against the - type annotations of the registered functions referenced in the config - (if available) and/or the schema (if available). - """ - # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}} - # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0} - if cls.is_promise(config): - err_msg = "The top-level config object can't be a reference to a registered function." - raise ConfigValidationError(config=config, errors=[{"msg": err_msg}]) - # If a Config was loaded with interpolate=False, we assume it needs to - # be interpolated first, otherwise we take it at face value - is_interpolated = not isinstance(config, Config) or config.is_interpolated - section_order = config.section_order if isinstance(config, Config) else None - orig_config = config - if not is_interpolated: - config = Config(orig_config).interpolate() - filled, _, resolved = cls._fill( - config, # type: ignore[arg-type] - schema, - validate=validate, - overrides=overrides, - resolve=resolve, - ) - filled = Config(filled, section_order=section_order) - # Check that overrides didn't include invalid properties not in config - if validate: - cls._validate_overrides(filled, overrides) - # Merge the original config back to preserve variables if we started - # with a config that wasn't interpolated. Here, we prefer variables to - # allow auto-filling a non-interpolated config without destroying - # variable references. - if not is_interpolated: - filled = filled.merge( - Config(orig_config, is_interpolated=False), remove_extra=True - ) - return dict(resolved), filled - - @classmethod - def _fill( - cls, - config: Union[Config, Dict[str, Dict[str, Any]]], - schema: Type[Schema] = EmptySchema, - *, - validate: bool = True, - resolve: bool = True, - parent: str = "", - overrides: Mapping[str, Any] = {}, - ) -> Tuple[ - Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] - ]: - """Build three representations of the config: - 1. All promises are preserved (just like config user would provide). - 2. Promises are replaced by their return values. This is the validation - copy and will be parsed by pydantic. It lets us include hacks to - work around problems (e.g. handling of generators). - 3. Final copy with promises replaced by their return values. - """ - schema = ensure_schema(schema) - filled: Dict[str, Any] = {} - validation: Dict[str, Any] = {} - final: Dict[str, Any] = {} - for key, value in config.items(): - # If the field name is reserved, we use its alias for validation - v_key = RESERVED_FIELDS.get(key, key) - key_parent = f"{parent}.{key}".strip(".") - if key_parent in overrides: - value = overrides[key_parent] - config[key] = value - if cls.is_promise(value): - if key in schema.model_fields and not resolve: - # If we're not resolving the config, make sure that the field - # expecting the promise is typed Any so it doesn't fail - # validation if it doesn't receive the function return value - schema = _override_field_to_any(schema, key) - promise_schema = cls.make_promise_schema(value, resolve=resolve) - filled[key], validation[v_key], final[key] = cls._fill( - value, - promise_schema, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - reg_name, func_name = cls.get_constructor(final[key]) - args, kwargs = cls.parse_args(final[key]) - if resolve: - # Call the function and populate the field value. We can't - # just create an instance of the type here, since this - # wouldn't work for generics / more complex custom types - getter = cls.get(reg_name, func_name) - # Coerce dict values to BaseModel instances where the - # function annotation expects one (fixes #58). - kwargs = _coerce_basemodel_args(getter, kwargs) - # We don't want to try/except this and raise our own error - # here, because we want the traceback if the function fails. - getter_result = getter(*args, **kwargs) - else: - # We're not resolving and calling the function, so replace - # the getter_result with a Promise class - getter_result = Promise( - registry=reg_name, name=func_name, args=args, kwargs=kwargs - ) - validation[v_key] = getter_result - final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. - validation[v_key] = [] - elif hasattr(value, "items"): - field_type = EmptySchema - if key in schema.model_fields: - field = schema.model_fields[key] - field_type = field.annotation - if not hasattr(field_type, "model_fields"): - # If we don't have a schema and just a type - field_type = EmptySchema - if not resolve and _contains_promise(value): - # If we're not resolving and the value contains nested - # promises, override the field to Any so validation - # doesn't reject Promise objects in typed containers - schema = _override_field_to_any(schema, key) - filled[key], validation[v_key], final[key] = cls._fill( - value, - field_type, - validate=validate, - resolve=resolve, - parent=key_parent, - overrides=overrides, - ) - if key == ARGS_FIELD and isinstance(validation[v_key], dict): - # If the value of variable positional args is a dict (e.g. - # created via config blocks), only use its values - validation[v_key] = list(validation[v_key].values()) - final[key] = list(final[key].values()) - - if ARGS_FIELD_ALIAS in schema.model_fields and not resolve: - # If we're not resolving the config, make sure that the field - # expecting the promise is typed Any so it doesn't fail - # validation if it doesn't receive the function return value - schema = _override_field_to_any(schema, ARGS_FIELD_ALIAS) - else: - filled[key] = value - # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) - final[key] = value - # Now that we've filled in all of the promises, update with defaults - # from schema, and validate if validation is enabled - exclude = [] - if validate: - try: - schema.model_validate(validation) - except ValidationError as e: - raise ConfigValidationError( - config=config, errors=e.errors(), parent=parent - ) from None - else: - # If our schema doesn't allow extra values, we need to filter them - extra_setting = schema.model_config.get("extra", "allow") - if extra_setting in ("forbid", "ignore"): - schema_fields = set(schema.model_fields.keys()) - exclude = [k for k in validation if k not in schema_fields] - # Update validation dict with defaults from schema - exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - for name, field in schema.model_fields.items(): - if name in exclude_validation: - continue - if name not in validation and not field.is_required(): - default = field.default - # Unfreeze frozen containers so _update_from_parsed can write - if isinstance(default, dict): - default = dict(default) - elif isinstance(default, list): - default = list(default) - validation[name] = default - filled, final = cls._update_from_parsed(validation, filled, final) - if exclude: - filled = {k: v for k, v in filled.items() if k not in exclude} - validation = {k: v for k, v in validation.items() if k not in exclude} - final = {k: v for k, v in final.items() if k not in exclude} - return filled, validation, final - - @classmethod - def _update_from_parsed( - cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any] - ): - """Update the final result with the parsed config like converted - values recursively. - """ - for key, value in validation.items(): - if key in RESERVED_FIELDS.values(): - continue # skip aliases for reserved fields - if key not in filled: - filled[key] = value - if key not in final: - final[key] = value - if isinstance(value, dict): - filled[key], final[key] = cls._update_from_parsed( - value, filled[key], final[key] - ) - # Update final config with parsed value if they're not equal (in - # value and in type) but not if it's a generator because we had to - # replace that to validate it correctly - elif key == ARGS_FIELD: - continue # don't substitute if list of positional args - # Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched. - elif str(type(value)) == "": - final[key] = value - elif ( - value != final[key] or not isinstance(type(value), type(final[key])) - ) and not isinstance(final[key], GeneratorType): - final[key] = value - return filled, final - - @classmethod - def _validate_overrides(cls, filled: Config, overrides: Mapping[str, Any]): - """Validate overrides against a filled config to make sure there are - no references to properties that don't exist and weren't used.""" - error_msg = "Invalid override: config value doesn't exist" - errors = [] - for override_key in overrides.keys(): - if not cls._is_in_config(override_key, filled): - errors.append({"msg": error_msg, "loc": [override_key]}) - if errors: - raise ConfigValidationError(config=filled, errors=errors) - - @classmethod - def _is_in_config(cls, prop: str, config: Union[Mapping[str, Any], Config]): - """Check whether a nested config property like "section.subsection.key" - is in a given config.""" - tree = prop.split(".") - obj = dict(config) - while tree: - key = tree.pop(0) - if isinstance(obj, dict) and key in obj: - obj = obj[key] - else: - return False - return True - - @classmethod - def is_promise(cls, obj: Any) -> bool: - """Check whether an object is a "promise", i.e. contains a reference - to a registered function (via a key starting with `"@"`. - """ - if not hasattr(obj, "keys"): - return False - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys): - return True - return False - - @classmethod - def get_constructor(cls, obj: Mapping[str, Any]) -> Tuple[str, str]: - id_keys = [k for k in obj.keys() if k.startswith("@")] - if len(id_keys) != 1: - err_msg = f"A block can only contain one function registry reference. Got: {id_keys}" - raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}]) - else: - key = id_keys[0] - value = obj[key] - return (key[1:], value) - - @classmethod - def parse_args(cls, obj: Mapping[str, Any]) -> Tuple[List[Any], Dict[str, Any]]: - args = [] - kwargs = {} - for key, value in obj.items(): - if not key.startswith("@"): - if key == ARGS_FIELD: - args = value - elif key in RESERVED_FIELDS.values(): - continue - else: - kwargs[key] = value - return args, kwargs - - @classmethod - def make_promise_schema( - cls, obj: Mapping[str, Any], *, resolve: bool = True - ) -> Type[Schema]: - """Create a schema for a promise dict (referencing a registry function) - by inspecting the function signature. - """ - reg_name, func_name = cls.get_constructor(obj) - if not resolve and not cls.has(reg_name, func_name): - return EmptySchema - func = cls.get(reg_name, func_name) - # Read the argument annotations and defaults from the function signature - id_keys = [k for k in obj.keys() if k.startswith("@")] - sig_args: Dict[str, Any] = {id_keys[0]: (str, Field(...))} - for param in inspect.signature(func).parameters.values(): - # If no annotation is specified assume it's anything - annotation = param.annotation if param.annotation != param.empty else Any - # If no default value is specified assume that it's required - default = param.default if param.default != param.empty else ... - # Handle spread arguments and use their annotation as Sequence[whatever] - if param.kind == param.VAR_POSITIONAL: - spread_annot = Sequence[annotation] # type: ignore - sig_args[ARGS_FIELD_ALIAS] = (spread_annot, Field(default)) - else: - name = RESERVED_FIELDS.get(param.name, param.name) - sig_args[name] = (annotation, Field(default)) - return create_schema( - "ArgModel", - __config__={ - "extra": "forbid", - "arbitrary_types_allowed": True, - "alias_generator": alias_generator, - }, - **sig_args, - ) - +from .validation import Schema # noqa: F401 __all__ = [ "Config", diff --git a/confection/_config.py b/confection/_config.py new file mode 100644 index 0000000..31100f0 --- /dev/null +++ b/confection/_config.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import copy +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from ._errors import ConfectionError, ConfigValidationError +from ._parser import parse_config, serialize_config +from .validation import ValidationError, ensure_schema + + +class Config(dict): + # TODO: Improve doc string + """Dict subclass to save TOML-style configuration format from/to string, file + or bytes. + """ + + is_interpolated: bool + + def __init__( + self, + data: Optional[Union[Dict[str, Any], "Config"]] = None, + *, + is_interpolated: Optional[bool] = None, + section_order: Optional[List[str]] = None, + ) -> None: + """Initialize a new Config object with optional data.""" + dict.__init__(self) + if data is None: + data = {} + if not isinstance(data, (dict, Config)): + raise ConfectionError( + f"Can't initialize Config with data. Expected dict or " + f"Config but got: {type(data)}" + ) + # Whether the config has been interpolated. We can use this to check + # whether we need to interpolate again when it's resolved. We assume + # that a config is interpolated by default. + if is_interpolated is not None: + self.is_interpolated = is_interpolated + elif isinstance(data, Config): + self.is_interpolated = data.is_interpolated + else: + self.is_interpolated = True + if section_order is not None: + self.section_order = section_order + elif isinstance(data, Config): + self.section_order = data.section_order + else: + self.section_order = [] + # Update with data + self.update(data) + + def interpolate(self) -> Self: + """Interpolate (resolve var references) a config. + + Returns a copy of the object. + """ + # This is currently the most effective way because we need our custom + # to_str logic to run in order to re-serialize the values so we can + # interpolate them again. ConfigParser.read_dict will just call str() + # on all values, which isn't enough. + return type(self)().from_str(self.to_str()) + + def copy(self) -> Self: + """Deepcopy the config.""" + config = copy.deepcopy(self) + return type(self)( + config, + is_interpolated=self.is_interpolated, + section_order=self.section_order, + ) + + def merge( + self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False + ) -> Self: + """Deep merge the config with updates, using current as defaults.""" + defaults = self.copy() + updates = Config(updates).copy() + merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra) + return type(self)( + merged, + is_interpolated=defaults.is_interpolated and updates.is_interpolated, + section_order=defaults.section_order, + ) + + def validate(self, schema) -> Self: + """Validate the config against a schema. Raises ConfigValidationError + if validation fails. + """ + schema = ensure_schema(schema) + _validate_recursive(dict(self), schema, self) + return self + + def fill_defaults(self, schema) -> Self: + """Fill in missing values from schema defaults and remove extra + fields if the schema forbids them. Modifies in place and returns self. + """ + schema = ensure_schema(schema) + extra = schema.model_config.get("extra", "allow") + # Fill defaults + for name, field in schema.model_fields.items(): + if name not in self and not field.is_required(): + self[name] = field.default + elif name in self and isinstance(self[name], dict): + field_schema = field.annotation + if isinstance(field_schema, type) and hasattr( + field_schema, "model_fields" + ): + sub_schema = ensure_schema(field_schema) + _fill_defaults_recursive(self[name], sub_schema) + # Strip extras + if extra == "forbid": + known = set(schema.model_fields.keys()) + for key in list(self.keys()): + if key not in known: + del self[key] + return self + + def from_str( + self, + text: str, + *, + interpolate: bool = True, + overrides: Dict[str, Any] = {}, + schema=None, + ) -> Self: + """Load the config from a string.""" + self.clear() + self.update(parse_config(text, interpolate=interpolate, overrides=overrides)) + if overrides and interpolate: + # Re-interpolate now that overrides are applied. The recursive + # from_str call will have no overrides, so this doesn't loop. + self = self.interpolate() + self.is_interpolated = interpolate + if schema is not None: + self.fill_defaults(schema) + self.validate(schema) + return self + + def to_str(self, *, interpolate: bool = True) -> str: + """Write the config to a string.""" + return serialize_config(self, interpolate=interpolate) + + def to_bytes(self, *, interpolate: bool = True) -> bytes: + """Serialize the config to a byte string.""" + return self.to_str(interpolate=interpolate).encode("utf8") + + def from_bytes( + self, + bytes_data: bytes, + *, + interpolate: bool = True, + overrides: Dict[str, Any] = {}, + ) -> Self: + """Load the config from a byte string.""" + return self.from_str( + bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides + ) + + def to_disk(self, path: Union[str, Path], *, interpolate: bool = True) -> None: + """Serialize the config to a file.""" + path = Path(path) if isinstance(path, str) else path + with path.open("w", encoding="utf8") as file_: + file_.write(self.to_str(interpolate=interpolate)) + + def from_disk( + self, + path: Union[str, Path], + *, + interpolate: bool = True, + overrides: Dict[str, Any] = {}, + ) -> Self: + """Load config from a file.""" + path = Path(path) if isinstance(path, str) else path + with path.open("r", encoding="utf8") as file_: + text = file_.read() + return self.from_str(text, interpolate=interpolate, overrides=overrides) + + +def _fill_defaults_recursive(data, schema): + """Fill defaults and strip extras recursively for nested schemas.""" + extra = schema.model_config.get("extra", "allow") + for name, field in schema.model_fields.items(): + if name not in data and not field.is_required(): + data[name] = field.default + elif name in data and isinstance(data[name], dict): + field_schema = field.annotation + if isinstance(field_schema, type) and hasattr(field_schema, "model_fields"): + _fill_defaults_recursive(data[name], ensure_schema(field_schema)) + if extra == "forbid": + known = set(schema.model_fields.keys()) + for key in list(data.keys()): + if key not in known: + del data[key] + + +def _validate_recursive(data, schema, config, parent=""): + """Validate data against a schema, recursing into nested schemas.""" + try: + schema.model_validate(data) + except ValidationError as e: + section = f" in [{parent}]" if parent else "" + raise ConfigValidationError( + config=config, + errors=e.errors(), + title=f"Config validation error{section}", + ) from None + # Recurse into fields that are themselves schemas + for name, field in schema.model_fields.items(): + annotation = field.annotation + if ( + isinstance(annotation, type) + and hasattr(annotation, "model_validate") + and name in data + and isinstance(data[name], dict) + ): + child_parent = f"{parent}.{name}" if parent else name + _validate_recursive(data[name], annotation, config, parent=child_parent) + + +def deep_merge_configs( + config: Union[Dict[str, Any], Config], + defaults: Union[Dict[str, Any], Config], + *, + remove_extra: bool = False, +) -> Union[Dict[str, Any], Config]: + """Deep merge two configs.""" + if remove_extra: + # Filter out values in the original config that are not in defaults + keys = list(config.keys()) + for key in keys: + if key not in defaults: + del config[key] + for key, value in defaults.items(): + if isinstance(value, dict): + node = config.setdefault(key, {}) + if not isinstance(node, dict): + continue + value_promises = [k for k in value if k.startswith("@")] + value_promise = value_promises[0] if value_promises else None + node_promises = [k for k in node if k.startswith("@")] if node else [] + node_promise = node_promises[0] if node_promises else None + # We only update the block from defaults if it refers to the same + # registered function + if ( + value_promise + and node_promise + and ( + value_promise in node + and node[value_promise] != value[value_promise] + ) + ): + continue + if node_promise and ( + node_promise not in value or node[node_promise] != value[node_promise] + ): + continue + defaults = deep_merge_configs(node, value, remove_extra=remove_extra) + elif key not in config: + config[key] = value + return config diff --git a/confection/_constants.py b/confection/_constants.py new file mode 100644 index 0000000..a3bc963 --- /dev/null +++ b/confection/_constants.py @@ -0,0 +1,22 @@ +import re + +# Regex to detect whether a value contains a variable +VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") + +# Internal prefix used to mark section references for custom interpolation +SECTION_PREFIX = "__SECTION__:" + +# Field used for positional arguments, e.g. [section.*.xyz]. The alias is +# required for the schema (shouldn't clash with user-defined arg names) +ARGS_FIELD = "*" +ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS" + +# Aliases for fields that would otherwise shadow pydantic attributes. Can be any +# string, so we're using name + space so it looks the same in error messages etc. +RESERVED_FIELDS = { + "validate": "validate\u0020", + "model_config": "model_config\u0020", + "model_validate": "model_validate\u2020", + "model_fields": "model_fields\u2020", +} +RESERVED_FIELDS_REVERSE = {v: k for k, v in RESERVED_FIELDS.items()} diff --git a/confection/_errors.py b/confection/_errors.py new file mode 100644 index 0000000..a650a8c --- /dev/null +++ b/confection/_errors.py @@ -0,0 +1,112 @@ +from typing import Optional + + +class ConfectionError(ValueError): ... + + +class ConfigValidationError(ConfectionError): + def __init__( + self, + *, + config=None, + errors=None, + title: Optional[str] = "Config validation error", + desc: Optional[str] = None, + parent: Optional[str] = None, + show_config: bool = True, + ) -> None: + """Custom error for validating configs. + + config (Union[Config, Dict[str, Dict[str, Any]], str]): The + config the validation error refers to. + errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): + A list of errors as dicts with keys "loc" (list of strings + describing the path of the value), "msg" (validation message + to show) and optional "type" (mostly internals). + Same format as produced by pydantic's validation error (e.errors()). + title (str): The error title. + desc (str): Optional error description, displayed below the title. + parent (str): Optional parent to use as prefix for all error locations. + For example, parent "element" will result in "element -> a -> b". + show_config (bool): Whether to print the whole config with the error. + + ATTRIBUTES: + config (Union[Config, Dict[str, Dict[str, Any]], str]): The config. + errors (Iterable[Dict[str, Any]]): The errors. + error_types (Set[str]): All "type" values defined in the errors, if + available. This is most relevant for the pydantic errors that define + types like "type_error.integer". This attribute makes it easy to + check if a config validation error includes errors of a certain + type, e.g. to log additional information or custom help messages. + title (str): The title. + desc (str): The description. + parent (str): The parent. + show_config (bool): Whether to show the config. + text (str): The formatted error text. + """ + self.config = config + self.errors = errors + self.title = title + self.desc = desc + self.parent = parent + self.show_config = show_config + self.error_types = set() + if self.errors: + for error in self.errors: + err_type = error.get("type") + if err_type: + self.error_types.add(err_type) + self.text = self._format() + ValueError.__init__(self, self.text) + + @classmethod + def from_error( + cls, + err: "ConfigValidationError", + title: Optional[str] = None, + desc: Optional[str] = None, + parent: Optional[str] = None, + show_config: Optional[bool] = None, + ) -> "ConfigValidationError": + """Create a new ConfigValidationError based on an existing error, e.g. + to re-raise it with different settings. If no overrides are provided, + the values from the original error are used. + + err (ConfigValidationError): The original error. + title (str): Overwrite error title. + desc (str): Overwrite error description. + parent (str): Overwrite error parent. + show_config (bool): Overwrite whether to show config. + RETURNS (ConfigValidationError): The new error. + """ + return cls( + config=err.config, + errors=err.errors, + title=title if title is not None else err.title, + desc=desc if desc is not None else err.desc, + parent=parent if parent is not None else err.parent, + show_config=show_config if show_config is not None else err.show_config, + ) + + def _format(self) -> str: + """Format the error message.""" + loc_divider = "->" + data = [] + if self.errors: + for error in self.errors: + err_loc = f" {loc_divider} ".join( + [str(p) for p in error.get("loc", [])] + ) + if self.parent: + err_loc = f"{self.parent} {loc_divider} {err_loc}" + data.append((err_loc, error.get("msg"))) + result = [] + if self.title: + result.append(self.title) + if self.desc: + result.append(self.desc) + if data: + result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data])) + if self.config and self.show_config: + result.append(f"{self.config}") + return "\n\n" + "\n".join(result) diff --git a/confection/_parser.py b/confection/_parser.py new file mode 100644 index 0000000..f05f131 --- /dev/null +++ b/confection/_parser.py @@ -0,0 +1,404 @@ +import io +import json +import warnings +from configparser import ( + MAX_INTERPOLATION_DEPTH, + ConfigParser, + ExtendedInterpolation, + InterpolationDepthError, + InterpolationMissingOptionError, + InterpolationSyntaxError, + NoOptionError, + NoSectionError, + ParsingError, +) +from typing import Any, Dict, List + +from ._constants import SECTION_PREFIX +from ._errors import ConfigValidationError +from .util import VARIABLE_RE, try_dump_json, try_load_json + + +def parse_config( + text: str, + *, + interpolate: bool = True, + overrides: Dict[str, Any] = {}, +) -> dict[str, Any]: + """Parse a config string into a nested dict. + + Handles the full pipeline: parse with ConfigParser, validate structure, + apply overrides, interpret values, and resolve section references. + + Returns the nested dict and whether a second interpolation pass is needed + (when overrides were applied with interpolation enabled). + """ + config_parser = _get_configparser(interpolate=interpolate and not overrides) + try: + config_parser.read_string(text) + except ParsingError as e: + desc = f"Make sure the sections and values are formatted correctly.\n\n{e}" + raise ConfigValidationError(desc=desc) from None + errors = _validate_configparser(config_parser) + if errors: + raise errors[0] + errors = _validate_overrides(config_parser, overrides) + if errors: + raise errors[0] + result: dict[str, Any] = {} + section_parts = [section.split(".") for section in config_parser.sections()] + # Phase 1: Build the skeleton of nested dicts from section names. + for parts in section_parts: + node = result + for part in parts[:-1]: + node = node.setdefault(part, {}) if part == "*" else node[part] + node.setdefault(parts[-1], {}) + # Phase 2: Fill in values, processing breadth-first by section depth. + for section, values in sorted( + config_parser.items(), key=lambda x: len(x[0].split(".")) + ): + if section == "DEFAULT": + continue + parts = section.split(".") + node = result + for part in parts: + node = node[part] + for key in values: + node[key] = _interpret_value(config_parser.get(section, key)) + # Phase 3: Apply overrides on the nested dict. + _apply_overrides(result, overrides) + # Phase 4: Replace section reference placeholders with actual dicts. + _replace_section_refs(result, result) + return result + + +def _apply_overrides(result: dict[str, Any], overrides: Dict[str, Any]) -> None: + """Apply dot-notation overrides to a nested dict. + + Override paths have already been validated by _validate_overrides. + """ + for key, value in overrides.items(): + path = key.split(".") + node = result + for part in path[:-1]: + node = node[part] + node[path[-1]] = value + + +def serialize_config( + data: dict[str, Any], + *, + interpolate: bool = True, + inline_paths: frozenset[str] = frozenset(), +) -> str: + """Serialize a nested config dict to a config string. + + inline_paths: dotted paths whose values should be serialized as inline + JSON rather than expanded into subsections. For example, if "a.b" is + in inline_paths, data["a"]["b"] (even if it's a dict) will be + serialized as ``b = {"key": "value"}`` under [a]. + """ + flattened = _get_configparser(interpolate=interpolate) + queue: list[tuple[tuple, dict[str, Any]]] = [(tuple(), data)] + for path, node in queue: + section_name = ".".join(path) + is_kwarg = path and path[-1] != "*" + has_leaves = any(not hasattr(v, "items") for v in node.values()) + if path and has_leaves and not flattened.has_section(section_name): + # Create sections that have leaf values (including * sections). + flattened.add_section(section_name) + elif is_kwarg and not flattened.has_section(section_name): + flattened.add_section(section_name) + for key, value in node.items(): + child_path = f"{section_name}.{key}" if section_name else key + if hasattr(value, "items") and child_path not in inline_paths: + queue.append((path + (key,), value)) + else: + flattened.set(section_name, key, try_dump_json(value, node)) + string_io = io.StringIO() + flattened.write(string_io) + return string_io.getvalue().strip() + + +def _get_configparser(interpolate: bool = True) -> ConfigParser: + config = ConfigParser(interpolation=_CustomInterpolation() if interpolate else None) + # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719 + config.optionxform = str # type: ignore + return config + + +def _validate_configparser(config_parser: ConfigParser) -> list[ConfigValidationError]: + """Validate a configparser's structure before interpreting it into a Config. + + Checks that: + - No values leak into the DEFAULT section (top-level values without a section) + - All parent sections exist for dotted section names (e.g. "a.b" requires "a") + - No key in a section conflicts with a child section name + - No interpolation errors in values + """ + errors = [] + default_section = config_parser.defaults() + if default_section: # pragma: no cover -- configparser raises ParsingError first + err_title = "Found config values without a top-level section" + err_msg = "not part of a section" + err = [{"loc": [k], "msg": err_msg} for k in default_section] + errors.append(ConfigValidationError(errors=err, title=err_title)) + section_names = set(config_parser.sections()) + for section in config_parser.sections(): + path = section.split(".") + for i in range(1, len(path)): + # "*" is an implicit list section — it doesn't need a parent + # section header, and paths through it are always valid. + if path[i - 1] == "*": + continue + parent = ".".join(path[:i]) + # A parent is valid if it's a declared section OR if the path + # goes through a "*" component (which is implicitly created). + if parent not in section_names and "*" not in parent.split("."): + err_title = ( + "Error parsing config section. Perhaps a section name is wrong?" + ) + err = [{"loc": path, "msg": f"Section '{path[i - 1]}' is not defined"}] + errors.append(ConfigValidationError(errors=err, title=err_title)) + break + keys = set(config_parser.options(section)) + for other in section_names: + if other.startswith(section + "."): + child = other[len(section) + 1 :].split(".")[0] + if child in keys: + err = [{"loc": other.split("."), "msg": "found conflicting values"}] + errors.append(ConfigValidationError(errors=err)) + return errors + + +def _validate_overrides( + config_parser: ConfigParser, overrides: dict[str, Any] +) -> list[ConfigValidationError]: + errors = [] + err_title = "Error parsing config overrides" + for key in overrides: + err_msg = "not a section value that can be overridden" + err = [{"loc": key.split("."), "msg": err_msg}] + if "." not in key: + errors.append(ConfigValidationError(errors=err, title=err_title)) + continue + section, _ = key.rsplit(".", 1) + # Check for section and accept if option not in config[section] + if section not in config_parser: + errors.append(ConfigValidationError(errors=err, title=err_title)) + # TODO: Are we supposed to chek for the *option*? + return errors + + +def _interpret_value(value: Any) -> Any: + """Interpret a single config value.""" + result = try_load_json(value) + # If value is a string and it contains a variable, use original value + # (not interpreted string, which could lead to double quotes: + # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string, + # so we're not keeping lists as strings. + # NOTE: This currently can't handle uninterpolated values like [${x.y}]! + if isinstance(result, str) and VARIABLE_RE.search(value): + result = value + return result + + +def _replace_section_refs( + root: dict[str, Any], node: dict[str, Any], parent: str = "" +) -> None: + """Replace section reference placeholders with actual dicts.""" + for key, value in node.items(): + key_parent = f"{parent}.{key}".strip(".") + if isinstance(value, dict): + _replace_section_refs(root, value, parent=key_parent) + elif isinstance(value, list): + node[key] = [_get_section_ref(root, v, parent=[parent, key]) for v in value] + else: + node[key] = _get_section_ref(root, value, parent=[parent, key]) + + +def _get_section_ref( + root: Dict[str, Any], value: Any, *, parent: List[str] = [] +) -> Any: + """Resolve a single section reference placeholder, or return value as-is.""" + if isinstance(value, str) and value.startswith( + f'"{SECTION_PREFIX}' + ): # pragma: no cover + value = try_load_json(value) # pragma: no cover + if ( + isinstance(value, str) + and value.startswith(SECTION_PREFIX) + and value != SECTION_PREFIX + ): + parts = value.replace(SECTION_PREFIX, "", 1).split(".") + result = root + for item in parts: + result = result[item] + return result + elif isinstance(value, str) and SECTION_PREFIX in value and value != SECTION_PREFIX: + err_desc = ( + "Can't reference whole sections or return values of function " + "blocks inside a string or list\n\nYou can change your variable to " + "reference a value instead. Keep in mind that it's not " + "possible to interpolate the return value of a registered " + "function, since variables are interpolated when the config " + "is loaded, and registered functions are resolved afterwards." + ) + err = [{"loc": parent, "msg": "uses section variable in string or list"}] + raise ConfigValidationError(errors=err, desc=err_desc) + return value + + +class _CustomInterpolation(ExtendedInterpolation): + def before_read( # pyright: ignore[reportIncompatibleMethodOverride] + self, parser: ConfigParser, section: str, option: str, value: str + ) -> str: + # Warn about single-quoted strings (common mistake) + if value and value[0] == value[-1] == "'": + warnings.warn( + f"The value [{value}] seems to be single-quoted, but values " + "use JSON formatting, which requires double quotes." + ) + return super().before_read(parser, section, option, value) + + def _coerce_for_string_context(self, v: str) -> str: + """Coerce a raw config value for use in a compound string expression.""" + # Don't coerce section references - they need to stay quoted for JSON + if SECTION_PREFIX in v: + return v + try: + parsed = json.loads(v) + except json.JSONDecodeError: + return v # Not valid JSON, already a plain string + if isinstance(parsed, str): + return parsed # Unwrap JSON string + # Use json.dumps() for non-strings, escaping inner quotes so they don't + # conflict with the outer JSON string quotes + return json.dumps(parsed).replace('"', '\\"') + + def before_get( # pyright: ignore[reportIncompatibleMethodOverride] + self, + parser: ConfigParser, + section: str, + option: str, + value: str, + defaults: Dict[str, str], + ) -> str: + # Mostly copy-pasted from the built-in configparser implementation. + # The interpolate() method resolves ${...} references and appends pieces + # to L. For a bare reference like ${x}, L has one element. For compound + # expressions like "hello ${x}", L has multiple pieces that we join. + # Compound results stay as strings (coerced via _coerce_for_string_context), + # while bare references keep their JSON type for _interpret_value to parse. + L: List[str] = [] + self.interpolate(parser, option, L, value, section, defaults, 1) + if len(L) == 1: + return L[0] + return "".join(self._coerce_for_string_context(piece) for piece in L) + + def interpolate( + self, + parser: ConfigParser, + option: str, + accum: List[str], + rest: str, + section: str, + map: Dict[str, str], + depth: int, + ) -> None: + """Resolve variable references like ${foo.bar}""" + # Mostly copy-pasted from the built-in configparser implementation. + # We need to overwrite this method so we can add special handling for + # block references :( All values produced here should be strings – + # we need to wait until the whole config is interpreted anyways so + # filling in incomplete values here is pointless. All we need is the + # section reference so we can fetch it later. + rawval = parser.get(section, option, raw=True, fallback=rest) + if depth > MAX_INTERPOLATION_DEPTH: + raise InterpolationDepthError(option, section, rawval) + while rest: + p = rest.find("$") + if p < 0: + accum.append(rest) + return + if p > 0: + accum.append(rest[:p]) + rest = rest[p:] + # p is no longer used + c = rest[1:2] + if c == "$": + accum.append("$") + rest = rest[2:] + elif c == "{": + # We want to treat both ${a:b} and ${a.b} the same + m = self._KEYCRE.match(rest) # type: ignore[attr-defined] + if m is None: + err = f"bad interpolation variable reference {rest}" + raise InterpolationSyntaxError(option, section, err) + orig_var = m.group(1) + path = orig_var.replace(":", ".").rsplit(".", 1) + rest = rest[m.end() :] + sect = section + opt = option + try: + if len(path) == 1: + opt = parser.optionxform(path[0]) + # Check if the variable references a section rather + # than a key in the current section. If the key + # exists in the current map but its raw value is the + # same interpolation variable (self-reference), or if + # the key doesn't exist in the map, treat it as a + # section reference. + is_section_ref = opt not in map + if not is_section_ref: + raw = map[opt] + is_section_ref = raw.strip() == rawval.strip() + if not is_section_ref: + v = map[opt] + else: + # We have block reference, store it as a special key + section_name = parser[parser.optionxform(path[0])].name + v = self._get_section_name(section_name) + elif len(path) == 2: + sect = path[0] + opt = parser.optionxform(path[1]) + fallback = "__FALLBACK__" + v = parser.get(sect, opt, raw=True, fallback=fallback) + # If a variable doesn't exist, try again and treat the + # reference as a section + if v == fallback: + v = self._get_section_name(parser[f"{sect}.{opt}"].name) + else: # pragma: no cover + # Dead code: rsplit(".", 1) produces at most 2 elements + err = f"More than one ':' found: {rest}" + raise InterpolationSyntaxError(option, section, err) + except (KeyError, NoSectionError, NoOptionError): + raise InterpolationMissingOptionError( + option, section, rawval, orig_var + ) from None + if "$" in v: + new_map = dict(parser.items(sect, raw=True)) + self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1) + else: + accum.append(v) + else: + err = "'$' must be followed by '$' or '{', found: %r" % (rest,) + raise InterpolationSyntaxError(option, section, err) + + def _get_section_name(self, name: str) -> str: + """Generate the name of a section. Note that we use a quoted string here + so we can use section references within lists and load the list as + JSON. Since section references can't be used within strings, we don't + need the quoted vs. unquoted distinction like we do for variables. + + Examples (assuming section = {"foo": 1}): + - value: ${section.foo} -> value: 1 + - value: "hello ${section.foo}" -> value: "hello 1" + - value: ${section} -> value: {"foo": 1} + - value: "${section}" -> value: {"foo": 1} + - value: "hello ${section}" -> invalid + """ + return f'"{SECTION_PREFIX}{name}"' + + +__all__ = ["parse_config", "serialize_config"] diff --git a/confection/_registry.py b/confection/_registry.py new file mode 100644 index 0000000..ecda479 --- /dev/null +++ b/confection/_registry.py @@ -0,0 +1,447 @@ +import inspect +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +from ._config import Config +from ._constants import ( + ARGS_FIELD, + ARGS_FIELD_ALIAS, + RESERVED_FIELDS_REVERSE, +) +from ._errors import ConfigValidationError +from .util import is_promise +from .validation import Schema + +_PromisedType = TypeVar("_PromisedType") + + +@dataclass +class Promise(Generic[_PromisedType]): + registry: str + name: str + var_args: List[Any] + kwargs: Dict[str, Any] + getter: Union[Callable[..., _PromisedType], Exception] + + @property + def return_type(self) -> _PromisedType: + if isinstance(self.getter, Exception): # pragma: no cover + raise self.getter # pragma: no cover + signature = inspect.signature(self.getter) + return signature.return_annotation + + def resolve(self) -> Any: + if isinstance(self.getter, Exception): # pragma: no cover + raise self.getter # pragma: no cover + kwargs = _recursive_resolve(self.kwargs) + args = _recursive_resolve(self.var_args) + args = list(args.values()) if isinstance(args, dict) else args + kwargs = _coerce_basemodel_args(self.getter, kwargs) + return self.getter(*args, **kwargs) # type: ignore + + @classmethod + def from_dict(cls, registry, values) -> "Promise": + reg_name, func_name = registry.get_constructor(values) + var_args, kwargs = registry.parse_args(values) + try: + getter = registry.get(reg_name, func_name) + except Exception as e: # pragma: no cover + getter = e # pragma: no cover + output = cls( + registry=reg_name, + name=func_name, + var_args=var_args, + kwargs=kwargs, + getter=getter, + ) + return output + + +def _coerce_basemodel_args(func, kwargs): + """Coerce dict kwargs to BaseModel instances where the function signature + expects a BaseModel subclass. This lets registered functions receive + constructed model instances instead of raw dicts (issue #58). + """ + try: + from typing import get_type_hints + + hints = get_type_hints(func) + except Exception: + return kwargs + result = dict(kwargs) + for name, value in result.items(): + if not isinstance(value, dict): + continue + hint = hints.get(name) + if hint is None or not isinstance(hint, type): + continue + # Check for pydantic BaseModel (v1 or v2) + if hasattr(hint, "model_validate"): + result[name] = hint.model_validate(value) + elif hasattr(hint, "parse_obj"): + result[name] = hint.parse_obj(value) + return result + + +def _recursive_resolve(obj): + if isinstance(obj, list): + return [_recursive_resolve(v) for v in obj] + elif isinstance(obj, dict): + return {k: _recursive_resolve(v) for k, v in obj.items()} + elif isinstance(obj, Promise): + return obj.resolve() + else: + return obj + + +class registry: + @classmethod + def has(cls, registry_name: str, func_name: str) -> bool: + """Check whether a function is available in a registry.""" + if not hasattr(cls, registry_name): + return False + reg = getattr(cls, registry_name) + return func_name in reg + + @classmethod + def get(cls, registry_name: str, func_name: str) -> Callable: + """Get a registered function from a given registry.""" + if not hasattr(cls, registry_name): + raise ValueError(f"Unknown registry: '{registry_name}'") + reg = getattr(cls, registry_name) + func = reg.get(func_name) + if func is None: # pragma: no cover -- catalogue raises RegistryError first + raise ValueError(f"Could not find '{func_name}' in '{registry_name}'") + return func + + @classmethod + def resolve( + cls, + config: Union[Config, Dict[str, Dict[str, Any]]], + *, + schema=None, + overrides: Dict[str, Any] = {}, + validate: bool = True, + ) -> Dict[str, Any]: + config = cls.fill( + config, + schema=schema, + overrides=overrides, + interpolate=True, + ) + promised = insert_promises(cls, config, resolve=True) + resolved = resolve_promises(promised) + fixed = fix_positionals(resolved) + assert isinstance(fixed, dict) + if schema is not None and validate: + Config(fixed).validate(schema) + return fixed + + @classmethod + def fill( + cls, + config: Union[Config, Dict[str, Dict[str, Any]]], + *, + schema=None, + overrides: Dict[str, Any] = {}, + interpolate: bool = False, + validate: bool = True, + ) -> Config: + if cls.is_promise(config): + err_msg = "The top-level config object can't be a reference to a registered function." + raise ConfigValidationError(config=config, errors=[{"msg": err_msg}]) + # If a Config was loaded with interpolate=False, we assume it needs to + # be interpolated first, otherwise we take it at face value + is_interpolated = not isinstance(config, Config) or config.is_interpolated + section_order = config.section_order if isinstance(config, Config) else None + orig_config = config + if not is_interpolated: + config = Config(orig_config).interpolate() + filled = fill_config(cls, config, overrides=overrides, validate=validate) + filled = Config(filled, section_order=section_order) + # Merge the original config back to preserve variables if we started + # with a config that wasn't interpolated. Here, we prefer variables to + # allow auto-filling a non-interpolated config without destroying + # variable references. + if not interpolate and not is_interpolated: + filled = filled.merge( + Config(orig_config, is_interpolated=False), remove_extra=True + ) + if schema is not None: + filled.fill_defaults(schema) + return filled + + @classmethod + def is_promise(cls, obj: Any) -> bool: + """Check whether an object is a "promise", i.e. contains a reference + to a registered function (via a key starting with `"@"`. + """ + return is_promise(obj) + + @classmethod + def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]: + id_keys = [k for k in obj.keys() if k.startswith("@")] + if len(id_keys) != 1: + err_msg = f"A block can only contain one function registry reference. Got: {id_keys}" + raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}]) + else: + key = id_keys[0] + value = obj[key] + return (key[1:], value) + + @classmethod + def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]: + args = [] + kwargs = {} + for key, value in obj.items(): + if not key.startswith("@"): + if key == ARGS_FIELD: + args = value + else: + kwargs[key] = value + return args, kwargs + + +def _is_config_section(obj) -> bool: + """Check if a dict is a config section (all string keys) vs a data value.""" + if not isinstance(obj, dict): + return False + return all(isinstance(k, str) for k in obj.keys()) + + +def alias_generator(name: str) -> str: + """Generate field aliases in promise schema.""" + # Underscore fields are not allowed in model, so use alias + if name == ARGS_FIELD_ALIAS: + return ARGS_FIELD + # Auto-alias fields that shadow base model attributes + return RESERVED_FIELDS_REVERSE.get(name, name) + + +def fill_config( + registry, + config: Dict[str, Any], + *, + overrides: Dict[str, Dict[str, Any]] = {}, + validate: bool = True, +) -> Dict[str, Any]: + overrided = apply_overrides(dict(config), overrides) + return _fill_defaults(registry, overrided, validate=validate) + + +def _fill_defaults( + registry, config: Dict[str, Any], *, validate: bool = True +) -> Dict[str, Any]: + """Recursively fill default values from registered function signatures.""" + output = dict(config) + for key, value in output.items(): + if is_promise(value): + output[key] = _fill_promise_defaults( + registry, value, validate=validate, parent=key + ) + elif isinstance(value, dict): + output[key] = _fill_defaults(registry, value, validate=validate) + return output + + +def _fill_promise_defaults( + registry, + promise_dict: Dict[str, Any], + *, + validate: bool = True, + parent: str = "", +) -> Dict[str, Any]: + """Fill default argument values for a promise block from the function signature.""" + reg_name, func_name = registry.get_constructor(promise_dict) + if not registry.has(reg_name, func_name): + # Can't fill defaults for unknown functions — they'll error at resolve time + return dict(promise_dict) + func = registry.get(reg_name, func_name) + schema = Schema.from_function(func) + filled = dict(promise_dict) + # Fill in defaults from the schema + for param_name, field in schema.model_fields.items(): + if param_name not in filled and not field.is_required(): + filled[param_name] = field.default + # Validate: check for missing required args and type errors + if validate: + _validate_promise_args(filled, schema, func_name, parent) + # Recurse into nested values (which may themselves be promises) + for key, value in filled.items(): + if key.startswith("@"): + continue + if is_promise(value): + filled[key] = _fill_promise_defaults( + registry, value, validate=validate, parent=f"{parent}.{key}" + ) + elif isinstance(value, dict): + filled[key] = _fill_defaults(registry, value, validate=validate) + return filled + + +def _validate_promise_args( + filled: Dict[str, Any], + schema, + func_name: str, + parent: str, +) -> None: + """Validate promise arguments against the function schema.""" + from .validation import validate_type + + errors = [] + for param_name, field in schema.model_fields.items(): + if param_name not in filled: + if field.is_required(): + errors.append( + { + "loc": [parent, param_name] if parent else [param_name], + "msg": f"missing required argument: '{param_name}'", + } + ) + elif not is_promise(filled[param_name]): + # Only validate non-promise values — promises will be validated + # when they're resolved + err = validate_type(filled[param_name], field.annotation) + if err: + errors.append( + { + "loc": [parent, param_name] if parent else [param_name], + "msg": err, + } + ) + # Check for unexpected arguments + known = set(schema.model_fields.keys()) | {k for k in filled if k.startswith("@")} + for key in filled: + if key not in known: + errors.append( + { + "loc": [parent, key] if parent else [key], + "msg": f"unexpected argument: '{key}'", + } + ) + if errors: + raise ConfigValidationError( + config=filled, + errors=errors, + title=f"Config error for '{func_name}'", + ) + + +def insert_promises( + registry, config: Dict[str, Dict[str, Any]], resolve: bool +) -> Dict[str, Dict[str, Any]]: + """Create a version of a config dict where promises are recognised and replaced by + Promise dataclasses + """ + output = {} + for key, value in config.items(): + if registry.is_promise(value): + value = insert_promises(registry, value, resolve=resolve) + output[key] = Promise.from_dict( + registry, + value, + ) + elif isinstance(value, dict): + output[key] = insert_promises(registry, value, resolve=resolve) + else: + output[key] = value + return output + + +def resolve_promises(config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + output = {} + for key, value in config.items(): + if isinstance(value, dict): + output[key] = resolve_promises(value) + elif isinstance(value, Promise): + output[key] = value.resolve() + else: + output[key] = value + return output + + +def fix_positionals(config): + """Ensure positionals are provided as a tuple, rather than a dict.""" + if isinstance(config, dict): + output = {} + for key, value in config.items(): + if key == ARGS_FIELD and isinstance(value, dict): + value = tuple(value.values()) + if isinstance(value, dict): + value = fix_positionals(value) + elif isinstance(value, list) or isinstance(value, tuple): + value = fix_positionals(value) + output[key] = value + return output + elif isinstance(config, list): + return [fix_positionals(v) for v in config] + elif isinstance(config, tuple): + return tuple([fix_positionals(v) for v in config]) + else: + return config + + +def _deep_copy_with_uncopyable(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any: + """Deep copy that passes through objects that can't be copied (like generators).""" + if memo is None: + memo = {} + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + + if isinstance(obj, dict): + result = {} + memo[obj_id] = result + for k, v in obj.items(): + result[_deep_copy_with_uncopyable(k, memo)] = _deep_copy_with_uncopyable( + v, memo + ) + return result + elif isinstance(obj, list): + result = [] + memo[obj_id] = result + for item in obj: + result.append(_deep_copy_with_uncopyable(item, memo)) + return result + elif isinstance(obj, tuple): + # Tuples are immutable, but we still need to copy their contents + return tuple(_deep_copy_with_uncopyable(item, memo) for item in obj) + else: + return obj + + +def apply_overrides( + config: Dict[str, Dict[str, Any]], + overrides: Dict[str, Dict[str, Any]], +) -> Dict[str, Dict[str, Any]]: + """Build first representation of the config:""" + output = dict(config) + for key, value in overrides.items(): + path = key.split(".") + err_title = "Error parsing config overrides" + err_msg = "not a section value that can be overridden" + err = [{"loc": path, "msg": err_msg}] + node = output + for subkey in path[:-1]: + if ( + not isinstance(node, dict) or subkey not in node + ): # pragma: no cover -- overrides validated in _parser + raise ConfigValidationError( + errors=err, title=err_title + ) # pragma: no cover + node = node[subkey] + if path[-1] not in node: # pragma: no cover + raise ConfigValidationError(errors=err, title=err_title) # pragma: no cover + node[path[-1]] = value # pragma: no cover + return output diff --git a/confection/tests/conftest.py b/confection/tests/conftest.py deleted file mode 100644 index 43e9d23..0000000 --- a/confection/tests/conftest.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - - -def pytest_addoption(parser): - parser.addoption("--slow", action="store_true", help="include slow tests") - - -@pytest.fixture() -def pathy_fixture(): - pytest.importorskip("pathy") - import shutil - import tempfile - - from pathy import Pathy, use_fs - - temp_folder = tempfile.mkdtemp(prefix="thinc-pathy") - use_fs(temp_folder) - - root = Pathy("gs://test-bucket") - root.mkdir(exist_ok=True) - - yield root - use_fs(False) - shutil.rmtree(temp_folder) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py deleted file mode 100644 index ce21335..0000000 --- a/confection/tests/test_config.py +++ /dev/null @@ -1,1489 +0,0 @@ -import inspect -import pickle -import platform -import sys -from types import GeneratorType -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import catalogue -import pytest - -if sys.version_info >= (3, 14): - pytest.skip( - "pydantic v1 is not compatible with Python 3.14+", allow_module_level=True - ) - -try: - from pydantic.v1 import BaseModel, PositiveInt, StrictFloat, constr - from pydantic.v1.types import StrictBool -except ImportError: - from pydantic import BaseModel, PositiveInt, StrictFloat, constr # type: ignore - from pydantic.types import StrictBool # type: ignore - -from confection import Config, ConfigValidationError -from confection.tests.util import Cat, make_tempdir, my_registry -from confection.util import Generator, partial - -EXAMPLE_CONFIG = """ -[optimizer] -@optimizers = "Adam.v1" -beta1 = 0.9 -beta2 = 0.999 -use_averages = true - -[optimizer.learn_rate] -@schedules = "warmup_linear.v1" -initial_rate = 0.1 -warmup_steps = 10000 -total_steps = 100000 - -[pipeline] - -[pipeline.classifier] -name = "classifier" -factory = "classifier" - -[pipeline.classifier.model] -@layers = "ClassifierModel.v1" -hidden_depth = 1 -hidden_width = 64 -token_vector_width = 128 - -[pipeline.classifier.model.embedding] -@layers = "Embedding.v1" -width = ${pipeline.classifier.model:token_vector_width} - -""" - -OPTIMIZER_CFG = """ -[optimizer] -@optimizers = "Adam.v1" -beta1 = 0.9 -beta2 = 0.999 -use_averages = true - -[optimizer.learn_rate] -@schedules = "warmup_linear.v1" -initial_rate = 0.1 -warmup_steps = 10000 -total_steps = 100000 -""" - - -class HelloIntsSchema(BaseModel): - hello: int - world: int - - class Config: - extra = "forbid" - - -class DefaultsSchema(BaseModel): - required: int - optional: str = "default value" - - class Config: - extra = "forbid" - - -class ComplexSchema(BaseModel): - outer_req: int - outer_opt: str = "default value" - - level2_req: HelloIntsSchema - level2_opt: DefaultsSchema = DefaultsSchema(required=1) - - -good_catsie = {"@cats": "catsie.v1", "evil": False, "cute": True} -ok_catsie = {"@cats": "catsie.v1", "evil": False, "cute": False} -bad_catsie = {"@cats": "catsie.v1", "evil": True, "cute": True} -worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False} - - -def test_validate_simple_config(): - simple_config = {"hello": 1, "world": 2} - f, _, v = my_registry._fill(simple_config, HelloIntsSchema) - assert f == simple_config - assert v == simple_config - - -def test_invalidate_simple_config(): - invalid_config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(invalid_config, HelloIntsSchema) - error = exc_info.value - assert len(error.errors) == 1 - assert "type_error.integer" in error.error_types - - -def test_invalidate_extra_args(): - invalid_config = {"hello": 1, "world": 2, "extra": 3} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, HelloIntsSchema) - - -def test_fill_defaults_simple_config(): - valid_config = {"required": 1} - filled, _, v = my_registry._fill(valid_config, DefaultsSchema) - assert filled["required"] == 1 - assert filled["optional"] == "default value" - invalid_config = {"optional": "some value"} - with pytest.raises(ConfigValidationError): - my_registry._fill(invalid_config, DefaultsSchema) - - -def test_fill_recursive_config(): - valid_config = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}} - filled, _, validation = my_registry._fill(valid_config, ComplexSchema) - assert filled["outer_req"] == 1 - assert filled["outer_opt"] == "default value" - assert filled["level2_req"]["hello"] == 4 - assert filled["level2_req"]["world"] == 7 - assert filled["level2_opt"]["required"] == 1 - assert filled["level2_opt"]["optional"] == "default value" - - -def test_is_promise(): - assert my_registry.is_promise(good_catsie) - assert not my_registry.is_promise({"hello": "world"}) - assert not my_registry.is_promise(1) - invalid = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - assert my_registry.is_promise(invalid) - - -def test_get_constructor(): - assert my_registry.get_constructor(good_catsie) == ("cats", "catsie.v1") - - -def test_parse_args(): - args, kwargs = my_registry.parse_args(bad_catsie) - assert args == [] - assert kwargs == {"evil": True, "cute": True} - - -def test_make_promise_schema(): - schema = my_registry.make_promise_schema(good_catsie) - assert "evil" in schema.model_fields - assert "cute" in schema.model_fields - - -def test_validate_promise(): - config = {"required": 1, "optional": good_catsie} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled == config - assert validated == {"required": 1, "optional": "meow"} - - -def test_fill_validate_promise(): - config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - filled, _, validated = my_registry._fill(config, DefaultsSchema) - assert filled["optional"]["cute"] is True - - -def test_fill_invalidate_promise(): - config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} - with pytest.raises(ConfigValidationError): - my_registry._fill(config, HelloIntsSchema) - config["optional"]["whiskers"] = True - with pytest.raises(ConfigValidationError): - my_registry._fill(config, DefaultsSchema) - - -def test_create_registry(): - my_registry.dogs = catalogue.create( - my_registry.namespace, "dogs", entry_points=False - ) - assert hasattr(my_registry, "dogs") - assert len(my_registry.dogs.get_all()) == 0 - my_registry.dogs.register("good_boy.v1", func=lambda x: x) - assert len(my_registry.dogs.get_all()) == 1 - - -def test_registry_methods(): - with pytest.raises(ValueError): - my_registry.get("dfkoofkds", "catsie.v1") - my_registry.cats.register("catsie.v123")(None) - with pytest.raises(ValueError): - my_registry.get("cats", "catsie.v123") - - -def test_resolve_no_schema(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - result = my_registry.resolve({"cfg": config})["cfg"] - assert result["one"] == 1 - assert result["two"] == {"three": "scratch!"} - with pytest.raises(ConfigValidationError): - config = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}} - my_registry.resolve(config) - - -def test_resolve_schema(): - class TestBaseSubSchema(BaseModel): - three: str - - class TestBaseSchema(BaseModel): - one: PositiveInt - two: TestBaseSubSchema - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} - with pytest.raises(ConfigValidationError): - # "one" is not a positive int - my_registry.resolve({"cfg": config}, schema=TestSchema) - config = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}} - with pytest.raises(ConfigValidationError): - # "three" is required in subschema - my_registry.resolve({"cfg": config}, schema=TestSchema) - - -def test_resolve_schema_coerced(): - class TestBaseSchema(BaseModel): - test1: str - test2: bool - test3: float - - class TestSchema(BaseModel): - cfg: TestBaseSchema - - config = {"test1": 123, "test2": 1, "test3": 5} - filled = my_registry.fill({"cfg": config}, schema=TestSchema) - result = my_registry.resolve({"cfg": config}, schema=TestSchema) - # Without pydantic, values are not coerced (no int→str, int→float) - assert result["cfg"] == config - assert filled["cfg"] == config - - -def test_read_config(): - byte_string = EXAMPLE_CONFIG.encode("utf8") - cfg = Config().from_bytes(byte_string) - - assert cfg["optimizer"]["beta1"] == 0.9 - assert cfg["optimizer"]["learn_rate"]["initial_rate"] == 0.1 - assert cfg["pipeline"]["classifier"]["factory"] == "classifier" - assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128 - - -def test_optimizer_config(): - cfg = Config().from_str(OPTIMIZER_CFG) - optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] - assert optimizer.beta1 == 0.9 - - -def test_config_to_str(): - cfg = Config().from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_creates_intermediate_blocks(): - cfg = Config({"optimizer": {"foo": {"bar": 1}}}) - assert ( - cfg.to_str().strip() - == """ -[optimizer] - -[optimizer.foo] -bar = 1 - """.strip() - ) - - -def test_config_to_str_escapes(): - section_str = """ - [section] - node1 = "^a$$" - node2 = "$$b$$c" - """ - section_dict = {"section": {"node1": "^a$", "node2": "$b$c"}} - - # parse from escaped string - cfg = Config().from_str(section_str) - assert cfg == section_dict - - # parse from non-escaped dict - cfg = Config(section_dict) - assert cfg == section_dict - - # roundtrip through str - cfg_str = cfg.to_str() - assert "^a$$" in cfg_str - Config().from_str(cfg_str) - assert cfg == section_dict - - -def test_config_incomplete_dollar_brace(): - """Strings containing ${ without a closing } should roundtrip correctly.""" - cfg = Config({"section": {"val": "${"}}) - cfg_str = cfg.to_str() - cfg2 = Config().from_str(cfg_str) - assert cfg2["section"]["val"] == "${" - - cfg = Config({"section": {"val": "hello ${ world"}}) - cfg_str = cfg.to_str() - cfg2 = Config().from_str(cfg_str) - assert cfg2["section"]["val"] == "hello ${ world" - - -def test_config_roundtrip_bytes(): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_bytes = cfg.to_bytes() - new_cfg = Config().from_bytes(cfg_bytes) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk(): - cfg = Config().from_str(OPTIMIZER_CFG) - with make_tempdir() as path: - cfg_path = path / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture): - cfg = Config().from_str(OPTIMIZER_CFG) - cfg_path = pathy_fixture / "config.cfg" - cfg.to_disk(cfg_path) - new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - - -def test_config_to_str_invalid_defaults(): - """Test that an error is raised if a config contains top-level keys without - a section that would otherwise be interpreted as [DEFAULT] (which causes - the values to be included in *all* other sections). - """ - cfg = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}} - with pytest.raises(ConfigValidationError): - Config(cfg).to_str() - config_str = "[DEFAULT]\none = 1" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_validation_custom_types(): - def complex_args( - rate: StrictFloat, - steps: PositiveInt = 10, # type: ignore - log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", # noqa: F821 - ): - return None - - my_registry.complex = catalogue.create( - my_registry.namespace, "complex", entry_points=False - ) - my_registry.complex("complex.v1")(complex_args) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # steps is not a positive int - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"} - with pytest.raises(ConfigValidationError): - # log_level is not a string matching the regex - my_registry.resolve({"config": cfg}) - cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.resolve(cfg) - with pytest.raises(ConfigValidationError): - # top-level object is promise - my_registry.fill(cfg) - cfg = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} - with pytest.raises(ConfigValidationError): - # two constructors - my_registry.resolve({"config": cfg}) - - -def test_validation_no_validate(): - config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}} - result = my_registry.resolve({"cfg": config}, validate=False) - filled = my_registry.fill({"cfg": config}, validate=False) - assert result["cfg"]["one"] == 1 - assert result["cfg"]["two"] == {"three": "scratch!"} - assert filled["cfg"]["two"]["three"]["evil"] == "false" - assert filled["cfg"]["two"]["three"]["cute"] is True - - -def test_validation_fill_defaults(): - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}} - result = my_registry.fill(config, validate=False) - assert len(result["cfg"]["two"]) == 3 - with pytest.raises(ConfigValidationError): - # Required arg "evil" is not defined - my_registry.fill(config) - config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}} - # Fill in with new defaults - result = my_registry.fill(config) - assert len(result["cfg"]["two"]) == 4 - assert result["cfg"]["two"]["evil"] is False - assert result["cfg"]["two"]["cute"] is True - assert result["cfg"]["two"]["cute_level"] == 1 - - -def test_make_config_positional_args(): - @my_registry.cats("catsie.v567") - def catsie_567(*args: Optional[str], foo: str = "bar"): - assert args[0] == "^_^" - assert args[1] == "^(*.*)^" - assert foo == "baz" - return args[0] - - args = ["^_^", "^(*.*)^"] - cfg = {"config": {"@cats": "catsie.v567", "foo": "baz", "*": args}} - assert my_registry.resolve(cfg)["config"] == "^_^" - - -def test_fill_config_positional_args_w_promise(): - @my_registry.cats("catsie.v568") - def catsie_568(*args: str, foo: str = "bar"): - assert args[0] == "^(*.*)^" - assert foo == "baz" - return args[0] - - @my_registry.cats("cat_promise.v568") - def cat_promise() -> str: - return "^(*.*)^" - - cfg = { - "config": { - "@cats": "catsie.v568", - "*": {"promise": {"@cats": "cat_promise.v568"}}, - } - } - filled = my_registry.fill(cfg, validate=True) - assert filled["config"]["foo"] == "bar" - assert filled["config"]["*"] == {"promise": {"@cats": "cat_promise.v568"}} - - -def test_make_config_positional_args_complex(): - @my_registry.cats("catsie.v890") - def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]): - assert args[0] == 123 - return args[0] - - cfg = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}} - assert my_registry.resolve(cfg)["config"] == 123 - cfg = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}} - with pytest.raises(ConfigValidationError): - # "True" is not a valid boolean or positive int - my_registry.resolve(cfg) - - -def test_positional_args_to_from_string(): - cfg = """[a]\nb = 1\n* = ["foo","bar"]""" - assert Config().from_str(cfg).to_str() == cfg - cfg = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1""" - assert Config().from_str(cfg).to_str() == cfg - - @my_registry.cats("catsie.v666") - def catsie_666(*args, meow=False): - return args - - cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ("foo", "bar")} - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1""" - resolved = my_registry.resolve(Config().from_str(cfg)) - assert resolved == {"a": ({"x": 1},)} - - @my_registry.cats("catsie.v777") - def catsie_777(y: int = 1): - return "meow" * y - - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\"""" - filled = my_registry.fill(Config().from_str(cfg)).to_str() - expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1""" - assert filled == expected - cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3""" - result = my_registry.resolve(Config().from_str(cfg)) - assert result == {"a": ("meowmeowmeow",)} - - -def test_validation_generators_iterable(): - @my_registry.optimizers("test_optimizer.v1") - def test_optimizer_v1(rate: float) -> None: - return None - - @my_registry.schedules("test_schedule.v1") - def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]: - while True: - yield some_value - - config = {"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}} - my_registry.resolve(config) - - -def test_validation_unset_type_hints(): - """Test that unset type hints are handled correctly (and treated as Any).""" - - @my_registry.optimizers("test_optimizer.v2") - def test_optimizer_v2(rate, steps: int = 10) -> None: - return None - - config = {"test": {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}} - my_registry.resolve(config) - - -def test_validation_bad_function(): - @my_registry.optimizers("bad.v1") - def bad() -> None: - raise ValueError("This is an error in the function") - return None - - @my_registry.optimizers("good.v1") - def good() -> None: - return None - - # Bad function - config = {"test": {"@optimizers": "bad.v1"}} - with pytest.raises(ValueError): - my_registry.resolve(config) - # Bad function call - config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(config) - - -def test_objects_from_config(): - config = { - "optimizer": { - "@optimizers": "my_cool_optimizer.v1", - "beta1": 0.2, - "learn_rate": { - "@schedules": "my_cool_repetitive_schedule.v1", - "base_rate": 0.001, - "repeat": 4, - }, - } - } - - optimizer = my_registry.resolve(config)["optimizer"] - assert optimizer.beta1 == 0.2 - assert optimizer.learn_rate == [0.001] * 4 - - -def test_partials_from_config(): - """Test that functions registered with partial applications are handled - correctly (e.g. initializers).""" - numpy = pytest.importorskip("numpy") - - def uniform_init( - shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1 - ) -> List[float]: - return numpy.random.uniform(lo, hi, shape).tolist() - - @my_registry.initializers("uniform_init.v1") - def configure_uniform_init( - *, lo: float = -0.1, hi: float = 0.1 - ) -> Callable[[List[float]], List[float]]: - return partial(uniform_init, lo=lo, hi=hi) - - name = "uniform_init.v1" - cfg = {"test": {"@initializers": name, "lo": -0.2}} - func = my_registry.resolve(cfg)["test"] - assert hasattr(func, "__call__") - # The partial will still have lo as an arg, just with default - assert len(inspect.signature(func).parameters) == 3 - # Make sure returned partial function has correct value set - assert inspect.signature(func).parameters["lo"].default == -0.2 - # Actually call the function and verify - assert numpy.asarray(func((2, 3))).shape == (2, 3) - # Make sure validation still works - bad_cfg = {"test": {"@initializers": name, "lo": [0.5]}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - bad_cfg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}} - with pytest.raises(ConfigValidationError): - my_registry.resolve(bad_cfg) - - -def test_partials_from_config_nested(): - """Test that partial functions are passed correctly to other registered - functions that consume them (e.g. initializers -> layers).""" - - def test_initializer(a: int, b: int = 1) -> int: - return a * b - - @my_registry.initializers("test_initializer.v1") - def configure_test_initializer(b: int = 1) -> Callable[[int], int]: - return partial(test_initializer, b=b) - - @my_registry.layers("test_layer.v1") - def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]: - return lambda x: x + init(c) - - cfg = { - "@layers": "test_layer.v1", - "c": 5, - "init": {"@initializers": "test_initializer.v1", "b": 10}, - } - func = my_registry.resolve({"test": cfg})["test"] - assert func(1) == 51 - assert func(100) == 150 - - -def test_validate_generator(): - """Test that generator replacement for validation in config doesn't - actually replace the returned value.""" - - @my_registry.schedules("test_schedule.v2") - def test_schedule(): - while True: - yield 10 - - cfg = {"@schedules": "test_schedule.v2"} - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: - return rate - - cfg = { - "@optimizers": "test_optimizer.v2", - "rate": {"@schedules": "test_schedule.v2"}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: - return schedules["rate"] - - cfg = { - "@optimizers": "test_optimizer.v3", - "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, - } - result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) - - @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: - return schedules[0] - - -def test_handle_generic_type(): - """Test that validation can handle checks against arbitrary generic - types in function argument annotations.""" - - cfg = {"@cats": "generic_cat.v1", "cat": {"@cats": "int_cat.v1", "value_in": 3}} - cat = my_registry.resolve({"test": cfg})["test"] - assert isinstance(cat, Cat) - assert cat.value_in == 3 - assert cat.value_out is None - assert cat.name == "generic_cat" - - -@pytest.mark.parametrize( - "cfg", - [ - "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3", - "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3", - ], -) -def test_handle_error_duplicate_keys(cfg): - """This would cause very cryptic error when interpreting config. - (TypeError: 'X' object does not support item assignment) - """ - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -@pytest.mark.parametrize( - "cfg,is_valid", - [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)], -) -def test_cant_expand_undefined_block(cfg, is_valid): - """Test that you can't expand a block that hasn't been created yet. This - comes up when you typo a name, and if we allow expansion of undefined blocks, - it's very hard to create good errors for those typos. - """ - if is_valid: - Config().from_str(cfg) - else: - with pytest.raises(ConfigValidationError): - Config().from_str(cfg) - - -def test_fill_config_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"]["evil"] is False - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.fill(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises and that the result is - # interpreted and filled correctly - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.fill(config, overrides=overrides) - assert result["cfg"]["two"] is None - assert result["cfg"]["one"]["@cats"] == "catsie.v1" - assert result["cfg"]["one"]["evil"] is False - assert result["cfg"]["one"]["cute"] is True - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.fill(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.fill(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "two.four": True} - my_registry.fill(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.fill(config, overrides=overrides, validate=True) - - -def test_resolve_overrides(): - config = { - "cfg": { - "one": 1, - "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, - } - } - overrides = {"cfg.two.three.evil": False} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == "meow" - # Test that promises can be overwritten as well - overrides = {"cfg.two.three": 3} - result = my_registry.resolve(config, overrides=overrides, validate=True) - assert result["cfg"]["two"]["three"] == 3 - # Test that value can be overwritten with promises - overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} - result = my_registry.resolve(config, overrides=overrides) - assert result["cfg"]["one"] == "meow" - assert result["cfg"]["two"] is None - # Overwriting with wrong types should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": 20} - my_registry.resolve(config, overrides=overrides, validate=True) - # Overwriting with incomplete promises should cause validation error - with pytest.raises(ConfigValidationError): - overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} - my_registry.resolve(config, overrides=overrides) - # Overrides that don't match config should raise error - with pytest.raises(ConfigValidationError): - overrides = {"cfg.two.three.evil": False, "cfg.two.four": True} - my_registry.resolve(config, overrides=overrides, validate=True) - with pytest.raises(ConfigValidationError): - overrides = {"cfg.five": False} - my_registry.resolve(config, overrides=overrides, validate=True) - - -@pytest.mark.parametrize( - "prop,expected", - [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)], -) -def test_is_in_config(prop, expected): - config = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}} - assert my_registry._is_in_config(prop, config) is expected - - -def test_resolve_prefilled_values(): - class Language(object): - def __init__(self): ... - - @my_registry.optimizers("prefilled.v1") - def prefilled(nlp: Language, value: int = 10): - return (nlp, value) - - # Passing an instance of Language here via the config is bad, since it - # won't serialize to a string, but we still test for it - config = {"test": {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}} - resolved = my_registry.resolve(config, validate=True) - result = resolved["test"] - assert isinstance(result[0], Language) - assert result[1] == 50 - - -def test_fill_config_dict_return_type(): - """Test that a registered function returning a dict is handled correctly.""" - - @my_registry.cats.register("catsie_with_dict.v1") - def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]: - return {"not_evil": not evil} - - config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10} - result = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"] - assert result["evil"] is False - assert "not_evil" not in result - result = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"] - assert result["not_evil"] is True - - -def test_deepcopy_config(): - config = Config({"a": 1, "b": {"c": 2, "d": 3}}) - copied = config.copy() - # Same values but not same object - assert config == copied - assert config is not copied - - -@pytest.mark.skipif( - platform.python_implementation() == "PyPy", reason="copy does not fail for pypy" -) -def test_deepcopy_config_pickle(): - numpy = pytest.importorskip("numpy") - # Check for error if value can't be pickled/deepcopied - config = Config({"a": 1, "b": numpy}) - with pytest.raises(ValueError): - config.copy() - - -def test_config_to_str_simple_promises(): - """Test that references to function registries without arguments are - serialized inline as dict.""" - config_str = """[section]\nsubsection = {"@registry":"value"}""" - config = Config().from_str(config_str) - assert config["section"]["subsection"]["@registry"] == "value" - assert config.to_str() == config_str - - -def test_config_from_str_invalid_section(): - config_str = """[a]\nb = null\n\n[a.b]\nc = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - config_str = """[a]\nb = null\n\n[a.b.c]\nd = 1""" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_to_str_order(): - """Test that Config.to_str orders the sections.""" - config = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}} - expected = ( - "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5" - ) - config = Config(config) - assert config.to_str() == expected - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation(d): - """Test that config values are interpolated correctly. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == "hello" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\"""" - assert Config().from_str(c_str)["b"]["bar"] == "hello!" - c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!""" - assert Config().from_str(c_str)["b"]["bar"] == "15!" - c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}""" - assert Config().from_str(c_str)["b"]["bar"] == ["x", "y"] - # Interpolation within the same section - c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\"""" - assert Config().from_str(c_str)["a"]["bar"] == "x" - assert Config().from_str(c_str)["a"]["baz"] == "xy" - - -def test_config_interpolation_lists(): - # Test that lists are preserved correctly - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]""" - config = Config().from_str(c_str, interpolate=False) - assert config["c"]["d"] == ["hello ${a.b}", "world"] - config = config.interpolate() - assert config["c"]["d"] == ["hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == [1, "hello 1", "world"] - config = Config().from_str(c_str, interpolate=False) - # NOTE: This currently doesn't work, because we can't know how to JSON-load - # the uninterpolated list [${a.b}]. - # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"] - # config = config.interpolate() - # assert config["c"]["d"] == [1, "hello 1", "world"] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]""" - config = Config().from_str(c_str) - assert config["c"]["d"] == ["hello", {"b": 1}] - c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]""" - config = Config().from_str(config_str) - assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}] - config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_interpolation_sections(d): - """Test that config sections are interpolated correctly. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - # Simple block references - c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}""" - config = Config().from_str(c_str) - assert config["b"]["c"] == config["a"] - # References with non-string values - c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]""" - config = Config().from_str(c_str) - assert config["a"]["x"]["y"] == config["a"]["b"] - # Multiple references in the same string - c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\"""" - config = Config().from_str(c_str) - assert config["b"]["z"] == "string/10" - # Non-string references in string (converted to string) - c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\"""" - config = Config().from_str(c_str) - assert config["b"]["y"] == 'result: ["hello", "world"]' - # References to sections referencing sections - c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}""" - config = Config().from_str(c_str) - assert config["b"]["bar"] == config["a"] - assert config["c"]["baz"] == config["b"] - # References to section values referencing other sections - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["b"]["bar"] - # References to sections with subsections - c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}""" - config = Config().from_str(c_str) - assert config["c"]["baz"] == config["a"] - # Infinite recursion - c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}""" - config = Config().from_str(c_str) - assert config["a"]["b"]["bar"] == config["a"] - c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}""" - # We can't reference not-yet interpolated subsections - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # Generally invalid references - c_str = f"""[a]\nfoo = ${{b{d}bar}}""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - # We can't reference sections or promises within strings - c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2""" - with pytest.raises(ConfigValidationError): - Config().from_str(c_str) - - -def test_config_from_str_overrides(): - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}""" - # Basic value substitution - overrides = {"a.b": 10, "a.c.d": 20} - config = Config().from_str(config_str, overrides=overrides) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["d"] == 20 - assert config["a"]["c"]["e"] == 3 - # Valid values that previously weren't in config - config = Config().from_str(config_str, overrides={"a.c.f": 100}) - assert config["a"]["c"]["d"] == 2 - assert config["a"]["c"]["e"] == 3 - assert config["a"]["c"]["f"] == 100 - # Invalid keys and sections - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f": 10}) - # This currently isn't expected to work, because the dict in f.g is not - # interpreted as a section while the config is still just the configparser - with pytest.raises(ConfigValidationError): - Config().from_str(config_str, overrides={"f.g.x": "z"}) - # With variables (values) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}""" - config = Config().from_str(config_str, overrides={"a.b": 10}) - assert config["a"]["b"] == 10 - assert config["a"]["c"]["e"] == 10 - # With variables (sections) - config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}""" - config = Config().from_str(config_str, overrides={"a.c.d": 20}) - assert config["a"]["c"]["d"] == 20 - assert config["e"]["f"] == {"d": 20} - - -def test_config_reserved_aliases(): - """Test that the auto-generated pydantic schemas auto-alias reserved - attributes like "validate" that would otherwise cause NameError.""" - - @my_registry.cats("catsie.with_alias") - def catsie_with_alias(validate: StrictBool = False): - return validate - - cfg = {"@cats": "catsie.with_alias", "validate": True} - resolved = my_registry.resolve({"test": cfg}) - filled = my_registry.fill({"test": cfg}) - assert resolved["test"] is True - assert filled["test"] == cfg - cfg = {"@cats": "catsie.with_alias", "validate": 20} - with pytest.raises(ConfigValidationError): - my_registry.resolve({"test": cfg}) - - -@pytest.mark.parametrize("d", [".", ":"]) -def test_config_no_interpolation(d): - """Test that interpolation is correctly preserved. The parametrized - value is the final divider (${a.b} vs. ${a:b}). Both should now work and be - valid. The double {{ }} in the config strings are required to prevent the - references from being interpreted as an actual f-string variable. - """ - numpy = pytest.importorskip("numpy") - c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}""" - config = Config().from_str(c_str, interpolate=False) - assert not config.is_interpolated - assert config["c"]["d"] == f"${{a{d}b}}" - assert config["c"]["e"] == f'"hello${{a{d}b}}"' - assert config["c"]["f"] == "${a}" - config2 = Config().from_str(config.to_str(), interpolate=True) - assert config2.is_interpolated - assert config2["c"]["d"] == 1 - assert config2["c"]["e"] == "hello1" - assert config2["c"]["f"] == {"b": 1} - config3 = config.interpolate() - assert config3.is_interpolated - assert config3["c"]["d"] == 1 - assert config3["c"]["e"] == "hello1" - assert config3["c"]["f"] == {"b": 1} - # Bad non-serializable value - cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}} - with pytest.raises(ConfigValidationError): - Config(cfg).interpolate() - - -def test_config_no_interpolation_registry(): - config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}""" - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - assert config["b"]["evil"] == "${a:bad}" - assert config["c"]["d"] == "${b}" - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] == "${a:bad}" - assert filled["b"]["cute"] is True - assert filled["c"]["d"] == "${b}" - interpolated = filled.interpolate() - assert interpolated.is_interpolated - assert interpolated["b"]["evil"] is True - assert interpolated["c"]["d"] == interpolated["b"] - config = Config().from_str(config_str, interpolate=True) - assert config.is_interpolated - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["b"] == "scratch!" - assert resolved["c"]["d"] == "scratch!" - assert filled["b"]["evil"] is True - assert filled["c"]["d"] == filled["b"] - # Resolving a non-interpolated filled config - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - filled = my_registry.fill(config) - assert not filled.is_interpolated - assert filled["c"]["d"] == "${b}" - resolved = my_registry.resolve(filled) - assert resolved["c"]["d"] == "scratch!" - - -def test_config_deep_merge(): - config = {"a": "hello", "b": {"c": "d"}} - defaults = {"a": "world", "b": {"c": "e", "f": "g"}} - merged = Config(defaults).merge(config) - assert len(merged) == 2 - assert merged["a"] == "hello" - assert merged["b"] == {"c": "d", "f": "g"} - config = {"a": "hello", "b": {"@test": "x", "foo": 1}} - defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "x", "foo": 1} - assert merged["c"] == 100 - # Test that leaving out the factory just adds to existing - config = {"a": "hello", "b": {"foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2} - assert merged["c"] == 100 - # Test that switching to a different factory prevents the default from being added - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": {"@bar": "y"}} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - config = {"a": "hello", "b": {"@foo": 1}, "c": 100} - defaults = {"a": "world", "b": "y"} - merged = Config(defaults).merge(config) - assert len(merged) == 3 - assert merged["a"] == "hello" - assert merged["b"] == {"@foo": 1} - assert merged["c"] == 100 - - -def test_config_deep_merge_variables(): - config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""" - defaults_str = """[a]\nx = 100\n\n[d]\ny = 500""" - config = Config().from_str(config_str, interpolate=False) - defaults = Config().from_str(defaults_str) - merged = defaults.merge(config) - assert merged["a"] == {"b": 1, "c": 2, "x": 100} - assert merged["d"] == {"e": "${a:b}", "y": 500} - assert merged.interpolate()["d"] == {"e": 1, "y": 500} - # With variable in defaults: overwritten by new value - config = Config().from_str("""[a]\nb= 1\nc = 2""") - defaults = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False) - merged = defaults.merge(config) - assert merged["a"]["c"] == 2 - - -def test_config_to_str_roundtrip(): - numpy = pytest.importorskip("numpy") - cfg = {"cfg": {"foo": False}} - config_str = Config(cfg).to_str() - assert config_str == "[cfg]\nfoo = false" - config = Config().from_str(config_str) - assert dict(config) == cfg - cfg = {"cfg": {"foo": "false"}} - config_str = Config(cfg).to_str() - assert config_str == '[cfg]\nfoo = "false"' - config = Config().from_str(config_str) - assert dict(config) == cfg - # Bad non-serializable value - cfg = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}} - config = Config(cfg) - with pytest.raises(ConfigValidationError): - config.to_str() - # Roundtrip with variables: preserve variables correctly (quoted/unquoted) - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\"""" - config = Config().from_str(config_str, interpolate=False) - assert config.to_str() == config_str - - -def test_config_is_interpolated(): - """Test that a config object correctly reports whether it's interpolated.""" - config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}""" - config = Config().from_str(config_str, interpolate=False) - assert not config.is_interpolated - config = config.merge(Config({"x": {"y": "z"}})) - assert not config.is_interpolated - config = Config(config) - assert not config.is_interpolated - config = config.interpolate() - assert config.is_interpolated - config = config.merge(Config().from_str(config_str, interpolate=False)) - assert not config.is_interpolated - - -@pytest.mark.parametrize( - "section_order,expected_str,expected_keys", - [ - # fmt: off - ( - [], - "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", - ["a", "h", "j"], - ), - ( - ["j", "h", "a"], - "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", - ["j", "h", "a"], - ), - ( - ["h"], - "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", - ["h", "a", "j"], - ), - # fmt: on - ], -) -def test_config_serialize_custom_sort(section_order, expected_str, expected_keys): - cfg = { - "j": {"k": 6}, - "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}}, - "h": {"i": 5}, - } - cfg_str = Config(cfg).to_str() - assert Config(cfg, section_order=section_order).to_str() == expected_str - keys = list(Config(section_order=section_order).from_str(cfg_str).keys()) - assert keys == expected_keys - keys = list(Config(cfg, section_order=section_order).keys()) - assert keys == expected_keys - - -def test_config_custom_sort_preserve(): - """Test that sort order is preserved when merging and copying configs, - or when configs are filled and resolved.""" - cfg = {"x": {}, "y": {}, "z": {}} - section_order = ["y", "z", "x"] - expected = "[y]\n\n[z]\n\n[x]" - config = Config(cfg, section_order=section_order) - assert config.to_str() == expected - config2 = config.copy() - assert config2.to_str() == expected - config3 = config.merge({"a": {}}) - assert config3.to_str() == f"{expected}\n\n[a]" - config4 = Config(config) - assert config4.to_str() == expected - config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2""" - section_order = ["c", "a", "t"] - config5 = Config(section_order=section_order).from_str(config_str) - assert list(config5.keys()) == section_order - filled = my_registry.fill(config5) - assert filled.section_order == section_order - - -def test_config_pickle(): - config = Config({"foo": "bar"}, section_order=["foo", "bar", "baz"]) - data = pickle.dumps(config) - config_new = pickle.loads(data) - assert config_new == {"foo": "bar"} - assert config_new.section_order == ["foo", "bar", "baz"] - - -def test_config_fill_extra_fields(): - """Test that filling a config from a schema removes extra fields.""" - - class TestSchemaContent(BaseModel): - a: str - b: int - - class Config: - extra = "forbid" - - class TestSchema(BaseModel): - cfg: TestSchemaContent - - config = Config({"cfg": {"a": "1", "b": 2, "c": True}}) - with pytest.raises(ConfigValidationError): - my_registry.fill(config, schema=TestSchema) - filled = my_registry.fill(config, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config2 = config.interpolate() - filled = my_registry.fill(config2, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False) - filled = my_registry.fill(config3, schema=TestSchema, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2} - - class TestSchemaContent2(BaseModel): - a: str - b: int - - class Config: - extra = "allow" - - class TestSchema2(BaseModel): - cfg: TestSchemaContent2 - - filled = my_registry.fill(config, schema=TestSchema2, validate=False)["cfg"] - assert filled == {"a": "1", "b": 2, "c": True} - - -def test_config_validation_error_custom(): - class Schema(BaseModel): - hello: int - world: int - - config = {"hello": 1, "world": "hi!"} - with pytest.raises(ConfigValidationError) as exc_info: - my_registry._fill(config, Schema) - e1 = exc_info.value - assert e1.title == "Config validation error" - assert e1.desc is None - assert not e1.parent - assert e1.show_config is True - assert len(e1.errors) == 1 - assert e1.errors[0]["loc"] == ("world",) - assert e1.errors[0]["msg"] == "value is not a valid integer" - assert e1.errors[0]["type"] == "type_error.integer" - assert e1.error_types == set(["type_error.integer"]) - # Create a new error with overrides - title = "Custom error" - desc = "Some error description here" - e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False) - assert e2.errors == e1.errors - assert e2.error_types == e1.error_types - assert e2.title == title - assert e2.desc == desc - assert e2.show_config is False - assert e1.text != e2.text - - -def test_config_parsing_error(): - config_str = "[a]\nb c" - with pytest.raises(ConfigValidationError): - Config().from_str(config_str) - - -def test_config_fill_without_resolve(): - class BaseSchema(BaseModel): - catsie: int - - config = {"catsie": {"@cats": "catsie.v1", "evil": False}} - filled = my_registry.fill(config) - resolved = my_registry.resolve(config) - assert resolved["catsie"] == "meow" - assert filled["catsie"]["cute"] is True - with pytest.raises(ConfigValidationError): - my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema) - assert filled2["catsie"]["cute"] is True - resolved = my_registry.resolve(filled2) - assert resolved["catsie"] == "meow" - - # With unavailable function - class BaseSchema2(BaseModel): - catsie: Any - other: int = 12 - - config = {"catsie": {"@cats": "dog", "evil": False}} - filled3 = my_registry.fill(config, schema=BaseSchema2) - assert filled3["catsie"] == config["catsie"] - assert filled3["other"] == 12 - - -def test_config_dataclasses(): - cat = Cat("testcat", value_in=1, value_out=2) - config = {"cfg": {"@cats": "catsie.v3", "arg": cat}} - result = my_registry.resolve(config)["cfg"] - assert isinstance(result, Cat) - assert result.name == cat.name - assert result.value_in == cat.value_in - assert result.value_out == cat.value_out - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # simple substitution should go fine - [342, "${vars.a}", int], - ["342", "${vars.a}", str], - ["everyone", "${vars.a}", str], - ], -) -def test_config_interpolates(greeting, value, expected): - str_cfg = f""" - [project] - my_par = {value} - - [vars] - a = "something" - """ - overrides = {"vars.a": greeting} - cfg = Config().from_str(str_cfg, overrides=overrides) - assert type(cfg["project"]["my_par"]) is expected - - -@pytest.mark.parametrize( - "greeting,value,expected", - [ - # fmt: off - # simple substitution should go fine - ["hello 342", "${vars.a}", "hello 342"], - ["hello everyone", "${vars.a}", "hello everyone"], - ["hello tout le monde", "${vars.a}", "hello tout le monde"], - ["hello 42", "${vars.a}", "hello 42"], - # substituting an element in a list - ["hello 342", "[1, ${vars.a}, 3]", "hello 342"], - ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"], - ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"], - ["hello 42", "[1, ${vars.a}, 3]", "hello 42"], - # substituting part of a string - [342, "hello ${vars.a}", "hello 342"], - ["everyone", "hello ${vars.a}", "hello everyone"], - ["tout le monde", "hello ${vars.a}", "hello tout le monde"], - ["42", "hello ${vars.a}", "hello 42"], - # substituting part of a implicit string inside a list - [342, "[1, hello ${vars.a}, 3]", "hello 342"], - ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"], - ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"], - ["42", "[1, hello ${vars.a}, 3]", "hello 42"], - # substituting part of a explicit string inside a list - [342, "[1, 'hello ${vars.a}', '3']", "hello 342"], - ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"], - ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"], - ["42", "[1, 'hello ${vars.a}', '3']", "hello 42"], - # more complicated example - [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"], - ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"], - [ - "tout le monde", - "[{'name':'x','script':['hello ${vars.a}']}]", - "hello tout le monde", - ], - ["42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42"], - # fmt: on - ], -) -def test_config_overrides(greeting, value, expected): - str_cfg = f""" - [project] - commands = {value} - - [vars] - a = "world" - """ - overrides = {"vars.a": greeting} - assert "${vars.a}" in str_cfg - cfg = Config().from_str(str_cfg, overrides=overrides) - assert expected in str(cfg) - - -def test_warn_single_quotes(): - str_cfg = """ - [project] - commands = 'do stuff' - """ - - with pytest.warns(UserWarning, match="single-quoted"): - Config().from_str(str_cfg) - - # should not warn if single quotes are in the middle - str_cfg = """ - [project] - commands = some'thing - """ - Config().from_str(str_cfg) - - -def test_parse_strings_interpretable_as_ints(): - """Test whether strings interpretable as integers are parsed correctly (i. e. as strings).""" - cfg = Config().from_str( - f"""[a]\nfoo = [${{b.bar}}, "00${{b.bar}}", "y"]\n\n[b]\nbar = 3""" # noqa: F541 - ) - assert cfg["a"]["foo"] == [3, "003", "y"] - assert cfg["b"]["bar"] == 3 diff --git a/confection/tests/test_config_values.py b/confection/tests/test_config_values.py deleted file mode 100644 index 3836068..0000000 --- a/confection/tests/test_config_values.py +++ /dev/null @@ -1,987 +0,0 @@ -"""Systematic tests for config value types using Hypothesis. - -Uses property-based testing to explore the space of possible config values. -""" - -import json -from configparser import ConfigParser, ExtendedInterpolation - -import pytest -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st -from numpy.testing import assert_allclose, assert_equal - -from confection import Config, try_load_json - - -def _json_dumps(value): - return json.dumps(value, separators=(",", ":")) - - -from confection.tests.util import my_registry -from confection.validation import Schema, ValidationError - -# ============================================================================= -# Unit tests for try_load_json - the core parsing function -# ============================================================================= -# Contract: parse as JSON if valid, otherwise return original string - - -class TestTryLoadJson: - """Test the core JSON parsing function.""" - - # Valid JSON literals -> parsed Python values - @pytest.mark.parametrize( - "inp,expected", - [ - ("42", 42), - ("-42", -42), - ("3.14", 3.14), - ("-3.14", -3.14), - ("0", 0), - ("true", True), - ("false", False), - ("null", None), - ("[1, 2, 3]", [1, 2, 3]), - ('{"a": 1}', {"a": 1}), - ("[]", []), - ("{}", {}), - ], - ) - def test_json_literals(self, inp, expected): - """Valid JSON literals are parsed to Python values.""" - assert try_load_json(inp) == expected - - # Quoted strings -> unquoted Python strings - @pytest.mark.parametrize( - "inp,expected", - [ - ('"hello"', "hello"), - ('"with spaces"', "with spaces"), - ('""', ""), - ('"0"', "0"), # Quoted "0" should be string, not int - ('"-42"', "-42"), # Quoted "-42" should be string, not int - ('"true"', "true"), # Quoted "true" should be string, not bool - ('"false"', "false"), - ('"null"', "null"), - ('"3.14"', "3.14"), # Quoted "3.14" should be string, not float - ], - ) - def test_quoted_strings(self, inp, expected): - """Quoted strings are unquoted to Python strings.""" - assert try_load_json(inp) == expected - - # Invalid JSON -> returned as-is - @pytest.mark.parametrize( - "inp", - [ - "hello", # unquoted string - "hello world", # unquoted with space - "not json", - "${var.ref}", # variable reference - "hello ${var}", # string with variable - ], - ) - def test_invalid_json_returned_as_is(self, inp): - """Invalid JSON strings are returned unchanged.""" - assert try_load_json(inp) == inp - - -# ============================================================================= -# Unit tests for CustomInterpolation.before_read -# ============================================================================= -# This is where the bug lives. before_read receives raw INI values and -# preprocesses them before interpolation. - -# ============================================================================= -# Tests for parsing with plain ExtendedInterpolation (no CustomInterpolation) -# ============================================================================= -# These tests define what SHOULD work if we remove the buggy CustomInterpolation - - -class TestPlainExtendedInterpolation: - """Test parsing with plain ExtendedInterpolation (no custom before_read).""" - - def _parse(self, config_str): - """Parse a config string using plain ExtendedInterpolation.""" - parser = ConfigParser(interpolation=ExtendedInterpolation()) - parser.read_string(config_str) - return parser - - def _parse_value(self, raw_value): - """Parse a single raw value through try_load_json.""" - return try_load_json(raw_value) - - # Basic value types - @pytest.mark.parametrize( - "ini_value,expected_type,expected_value", - [ - ("42", int, 42), - ("-42", int, -42), - ("3.14", float, 3.14), - ("true", bool, True), - ("false", bool, False), - ("null", type(None), None), - ("[1, 2, 3]", list, [1, 2, 3]), - ('{"a": 1}', dict, {"a": 1}), - ], - ) - def test_unquoted_json_literals(self, ini_value, expected_type, expected_value): - """Unquoted JSON literals parse to their Python types.""" - parser = self._parse(f"[s]\nv = {ini_value}") - raw = parser.get("s", "v") - parsed = self._parse_value(raw) - assert type(parsed) is expected_type - assert parsed == expected_value - - # Quoted strings - the key test cases - @pytest.mark.parametrize( - "ini_value,expected", - [ - ('"hello"', "hello"), - ('"with spaces"', "with spaces"), - ('""', ""), - ('"0"', "0"), # Must stay string, not become int - ('"-42"', "-42"), # Must stay string, not become int - ('"3.14"', "3.14"), # Must stay string, not become float - ('"true"', "true"), # Must stay string, not become bool - ('"false"', "false"), - ('"null"', "null"), # Must stay string, not become None - ], - ) - def test_quoted_strings_stay_strings(self, ini_value, expected): - """Quoted strings must parse to Python strings, not other types.""" - parser = self._parse(f"[s]\nv = {ini_value}") - raw = parser.get("s", "v") - parsed = self._parse_value(raw) - assert isinstance(parsed, str), f"Expected str, got {type(parsed).__name__}" - assert parsed == expected - - # Value interpolation - def test_value_interpolation(self): - """Value interpolation ${section:key} should work.""" - config_str = """ -[vars] -x = 10 -name = "hello" - -[section] -a = ${vars:x} -b = ${vars:name} -""" - parser = self._parse(config_str) - assert self._parse_value(parser.get("section", "a")) == 10 - assert self._parse_value(parser.get("section", "b")) == "hello" - - -def assert_values_equal(actual, expected): - """Assert values are equal, using approximate comparison for floats. - - Note: Very small floats may round to 0 due to srsly serialization limits. - """ - if isinstance(expected, float): - # Use both relative and absolute tolerance - # atol handles small numbers that round to 0 - assert_allclose(actual, expected, rtol=1e-5, atol=1e-10) - elif isinstance(expected, dict): - assert set(actual.keys()) == set(expected.keys()) - for k in expected: - assert_values_equal(actual[k], expected[k]) - elif isinstance(expected, list): - assert len(actual) == len(expected) - for a, e in zip(actual, expected): - assert_values_equal(a, e) - else: - assert_equal(actual, expected) - - -# ============================================================================= -# Strategies for config values -# ============================================================================= - -# Field names: valid Python identifiers, not starting with @ -field_names = st.from_regex(r"[a-z][a-z0-9_]{0,10}", fullmatch=True) - -# Scalar values -scalar_values = st.one_of( - st.text( - min_size=0, max_size=20, alphabet=st.characters(blacklist_categories=["Cs"]) - ), - st.integers(min_value=-1000000, max_value=1000000), - st.floats(allow_nan=False, allow_infinity=False, min_value=-1e6, max_value=1e6), - st.booleans(), -) - -# Alphabet for dictionary keys - alphanumeric plus underscore, safe for JSON -DICT_KEY_ALPHABET = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" - - -# Recursive strategy for nested values (lists, dicts) -def config_values(): - """Strategy for any value that can appear in a config.""" - return st.recursive( - scalar_values, - lambda children: st.one_of( - st.lists(children, max_size=3), - st.dictionaries( - st.text(min_size=1, max_size=8, alphabet=DICT_KEY_ALPHABET), - children, - max_size=3, - ), - ), - max_leaves=10, - ) - - -# A config section: dict with string field names and config values -config_section = st.dictionaries( - field_names, - config_values(), - min_size=1, - max_size=5, -) - -# Section names for nested sections (no dots allowed in individual names) -section_names = st.from_regex(r"[a-z][a-z0-9_]{0,8}", fullmatch=True) - - -@st.composite -def nested_config(draw): - """Generate a config with nested sections like [outer.inner.deep].""" - # Generate 1-3 levels of nesting - depth = draw(st.integers(min_value=1, max_value=3)) - path = [draw(section_names) for _ in range(depth)] - - # Generate content for the deepest section - content = draw(config_section) - - # Build nested structure from inside out - result = content - for name in reversed(path): - result = {name: result} - - return result, path, content - - -@st.composite -def config_with_positional_args(draw): - """Generate a config with positional args using [section.*.name] syntax.""" - parent_name = draw(section_names) - - parent_fields = draw( - st.dictionaries( - field_names, - st.one_of( - st.integers(min_value=-100, max_value=100), - st.text(min_size=1, max_size=10, alphabet=DICT_KEY_ALPHABET), - ), - min_size=1, - max_size=3, - ) - ) - - positional_name_pool = ["pos1", "pos2", "pos3", "item1", "item2", "item3"] - num_positional = draw(st.integers(min_value=1, max_value=3)) - positional_names = draw( - st.permutations(positional_name_pool).map(lambda x: list(x)[:num_positional]) - ) - - positional_contents = [] - for _ in positional_names: - content = draw( - st.dictionaries( - field_names, - st.one_of( - st.integers(min_value=-100, max_value=100), - st.booleans(), - ), - min_size=1, - max_size=3, - ) - ) - positional_contents.append(content) - - expected = dict(parent_fields) - expected["*"] = { - name: content for name, content in zip(positional_names, positional_contents) - } - - return ( - parent_name, - positional_names, - parent_fields, - positional_contents, - {parent_name: expected}, - ) - - -@st.composite -def config_with_interpolation(draw): - """Generate a config with variable interpolation.""" - source_fields = draw( - st.dictionaries( - field_names, - st.one_of( - st.integers(min_value=-1000, max_value=1000), - st.text(min_size=1, max_size=10, alphabet=DICT_KEY_ALPHABET), - ), - min_size=1, - max_size=5, - ) - ) - - source_keys = list(source_fields.keys()) - num_refs = draw(st.integers(min_value=1, max_value=len(source_keys))) - ref_keys = draw(st.permutations(source_keys))[:num_refs] - - target_fields = {} - expected_target = {} - for source_key in ref_keys: - target_field = draw(field_names) - target_fields[target_field] = f"${{source.{source_key}}}" - expected_target[target_field] = source_fields[source_key] - - config = { - "source": source_fields, - "target": target_fields, - } - - return config, expected_target - - -# ============================================================================= -# Config String Strategy - generates INI-format config strings directly -# ============================================================================= - -ini_string_values = st.text( - min_size=1, - max_size=20, - alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_- ", -) -ini_int_values = st.integers(min_value=-10000, max_value=10000).map(str) -ini_float_values = st.floats( - allow_nan=False, allow_infinity=False, min_value=-1000, max_value=1000 -).map(lambda x: f"{x:.6g}") -ini_bool_values = st.sampled_from(["true", "false"]) - -ini_scalar_value = st.one_of( - ini_string_values.map(lambda s: _json_dumps(s)), - ini_int_values, - ini_float_values, - ini_bool_values, -) - -ini_list_value = st.lists(ini_scalar_value, min_size=0, max_size=5).map( - lambda items: "[" + ", ".join(items) + "]" -) - -ini_value = st.one_of(ini_scalar_value, ini_list_value) - - -@st.composite -def config_string(draw): - """Generate a config string in INI format.""" - num_sections = draw(st.integers(min_value=1, max_value=3)) - sections = [] - expected = {} - - for _ in range(num_sections): - section_name = draw(section_names) - while section_name in expected: - section_name = draw(section_names) - - num_fields = draw(st.integers(min_value=1, max_value=5)) - fields = [] - section_expected = {} - - for _ in range(num_fields): - field_name = draw(field_names) - while field_name in section_expected: - field_name = draw(field_names) - - value_type = draw( - st.sampled_from(["string", "int", "float", "bool", "list"]) - ) - - if value_type == "string": - py_value = draw(ini_string_values) - ini_str = _json_dumps(py_value) - elif value_type == "int": - py_value = draw(st.integers(min_value=-10000, max_value=10000)) - ini_str = str(py_value) - elif value_type == "float": - py_value = draw( - st.floats( - allow_nan=False, - allow_infinity=False, - min_value=-1000, - max_value=1000, - ) - ) - ini_str = f"{py_value:.6g}" - elif value_type == "bool": - py_value = draw(st.booleans()) - ini_str = "true" if py_value else "false" - else: # list - list_len = draw(st.integers(min_value=0, max_value=3)) - py_value = [ - draw(st.integers(min_value=-100, max_value=100)) - for _ in range(list_len) - ] - ini_str = "[" + ", ".join(str(x) for x in py_value) + "]" - - fields.append(f"{field_name} = {ini_str}") - section_expected[field_name] = py_value - - section_str = f"[{section_name}]\n" + "\n".join(fields) - sections.append(section_str) - expected[section_name] = section_expected - - config_str = "\n\n".join(sections) - return config_str, expected - - -# ============================================================================= -# Tests -# ============================================================================= - - -@given(section=config_section) -@settings(max_examples=200, suppress_health_check=[HealthCheck.too_slow]) -def test_config_section_roundtrip(section): - """Test that a config section survives being converted to string and back.""" - cfg = Config({"section": section}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert_values_equal(parsed["section"], section) - - -# Scalar values that are safe for roundtrip (exclude strings with problematic patterns) -safe_scalar_values = st.one_of( - # Exclude strings that look like JSON primitives with whitespace - st.text(min_size=0, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz_"), - st.integers(min_value=-1000000, max_value=1000000), - st.floats(allow_nan=False, allow_infinity=False, min_value=-1e6, max_value=1e6), - st.booleans(), -) - - -@given(value=safe_scalar_values) -@settings(max_examples=100) -def test_scalar_value_roundtrip(value): - """Test that scalar values roundtrip correctly.""" - cfg = Config({"section": {"field": value}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert_values_equal(parsed["section"]["field"], value) - - -@given(items=st.lists(scalar_values, max_size=10)) -@settings(max_examples=100) -def test_list_value_roundtrip(items): - """Test that list values roundtrip correctly.""" - cfg = Config({"section": {"field": items}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert_values_equal(parsed["section"]["field"], items) - - -@given( - mapping=st.dictionaries( - st.text(min_size=1, max_size=10, alphabet=DICT_KEY_ALPHABET), - scalar_values, - max_size=5, - ) -) -@settings(max_examples=100) -def test_dict_value_roundtrip(mapping): - """Test that dict values (data, not sections) roundtrip correctly.""" - cfg = Config({"section": {"field": mapping}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert_values_equal(parsed["section"]["field"], mapping) - - -@given( - section1=config_section, - section2=config_section, -) -@settings(max_examples=100) -def test_multiple_sections_roundtrip(section1, section2): - """Test that multiple sections roundtrip correctly.""" - cfg = Config({"section1": section1, "section2": section2}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert_values_equal(parsed["section1"], section1) - assert_values_equal(parsed["section2"], section2) - - -# ============================================================================= -# Nested Sections -# ============================================================================= - - -@given(data=nested_config()) -@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) -def test_nested_sections_roundtrip(data): - """Test that nested sections like [outer.inner] roundtrip correctly.""" - config, path, content = data - cfg = Config(config) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - - # Navigate to the nested content - node = parsed - for name in path: - node = node[name] - - assert_values_equal(node, content) - - -# ============================================================================= -# Positional Args ([section.*.name] syntax) -# ============================================================================= - - -@given(data=config_with_positional_args()) -@settings( - max_examples=100, - suppress_health_check=[HealthCheck.too_slow, HealthCheck.large_base_example], -) -def test_positional_args_roundtrip(data): - """Test that [section.*.name] positional args syntax roundtrips correctly.""" - parent_name, positional_names, parent_fields, positional_contents, expected = data - - lines = [f"[{parent_name}]"] - for key, value in parent_fields.items(): - lines.append(f"{key} = {_json_dumps(value)}") - - for name, content in zip(positional_names, positional_contents): - lines.append(f"\n[{parent_name}.*.{name}]") - for key, value in content.items(): - lines.append(f"{key} = {_json_dumps(value)}") - - config_str = "\n".join(lines) - - # Parse and verify - parsed = Config().from_str(config_str) - assert_values_equal(dict(parsed), expected) - - # Verify roundtrip - regenerated = parsed.to_str() - parsed2 = Config().from_str(regenerated) - assert_values_equal(dict(parsed2), expected) - - -# ============================================================================= -# Variable Interpolation -# ============================================================================= - - -@given(data=config_with_interpolation()) -@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) -def test_variable_interpolation(data): - """Test that variable interpolation ${section.key} works correctly.""" - config, expected_target = data - - # Create config from dict (uninterpolated) - cfg = Config(config) - - # Convert to string and back (still uninterpolated) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - - # Interpolate - interpolated = parsed.interpolate() - - # Check target section has interpolated values - assert_values_equal(interpolated["target"], expected_target) - - -# ============================================================================= -# Config String Parsing (from INI format) -# ============================================================================= - - -@given(data=config_string()) -@settings(max_examples=200, suppress_health_check=[HealthCheck.too_slow]) -def test_config_string_parsing(data): - """Test that generated config strings parse correctly.""" - config_str, expected = data - parsed = Config().from_str(config_str) - assert_values_equal(dict(parsed), expected) - - -@given(data=config_string()) -@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) -def test_config_string_roundtrip(data): - """Test that config strings survive parse -> to_str -> parse roundtrip.""" - config_str, expected = data - - # Parse the generated string - parsed1 = Config().from_str(config_str) - assert_values_equal(dict(parsed1), expected) - - # Convert back to string - regenerated = parsed1.to_str() - - # Parse again - parsed2 = Config().from_str(regenerated) - assert_values_equal(dict(parsed2), expected) - - -# ============================================================================= -# String Parsing Edge Cases -# ============================================================================= - - -@pytest.mark.parametrize( - "value", - [ - "0\n", - "1\t", - " 42", - "42 ", - "true\n", - "false ", - "null\n", - ], -) -def test_string_with_whitespace_stays_string(value): - """Strings that look like JSON primitives with whitespace stay strings.""" - cfg = Config({"section": {"field": value}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert parsed["section"]["field"] == value - - -@pytest.mark.parametrize( - "value", - [ - "123", - "3.14", - "0", - "0.5", - ], -) -def test_numeric_string_stays_string(value): - """Strings that look like positive numbers stay strings.""" - cfg = Config({"section": {"field": value}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert parsed["section"]["field"] == value - - -@pytest.mark.parametrize( - "value", - [ - "-42", - "-3.14", - ], -) -def test_negative_numeric_string_stays_string(value): - """Negative numeric strings stay strings.""" - cfg = Config({"section": {"field": value}}) - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - assert parsed["section"]["field"] == value - - -# ============================================================================= -# Registry Resolution Tests -# ============================================================================= - - -@st.composite -def config_with_catsie(draw): - """Generate a config with a catsie.v1 registered function.""" - evil = draw(st.booleans()) - cute = draw(st.booleans()) - - config = { - "cat": { - "@cats": "catsie.v1", - "evil": evil, - "cute": cute, - } - } - - expected = "scratch!" if evil else "meow" - return config, expected - - -@st.composite -def config_with_catsie_v2(draw): - """Generate a config with a catsie.v2 registered function.""" - evil = draw(st.booleans()) - cute = draw(st.booleans()) - cute_level = draw(st.integers(min_value=0, max_value=5)) - - config = { - "cat": { - "@cats": "catsie.v2", - "evil": evil, - "cute": cute, - "cute_level": cute_level, - } - } - - if evil: - expected = "scratch!" - elif cute_level > 2: - expected = "meow <3" - else: - expected = "meow" - - return config, expected - - -@given(data=config_with_catsie()) -@settings(max_examples=50) -def test_registry_resolve_catsie_v1(data): - """Test that catsie.v1 configs resolve correctly.""" - config, expected = data - cfg = Config(config) - resolved = my_registry.resolve(cfg) - assert resolved["cat"] == expected - - -@given(data=config_with_catsie_v2()) -@settings(max_examples=50) -def test_registry_resolve_catsie_v2(data): - """Test that catsie.v2 configs resolve correctly.""" - config, expected = data - cfg = Config(config) - resolved = my_registry.resolve(cfg) - assert resolved["cat"] == expected - - -@given(data=config_with_catsie()) -@settings(max_examples=50) -def test_registry_roundtrip_with_resolve(data): - """Test that configs with registry refs survive roundtrip and resolve.""" - config, expected = data - cfg = Config(config) - - # Roundtrip through string - config_str = cfg.to_str() - parsed = Config().from_str(config_str) - - # Resolve should produce same result - resolved = my_registry.resolve(parsed) - assert resolved["cat"] == expected - - -@st.composite -def multiple_registry_sections(draw): - """Generate a config with multiple sections containing registered functions.""" - cat1_evil = draw(st.booleans()) - cat2_evil = draw(st.booleans()) - cat2_cute_level = draw(st.integers(min_value=0, max_value=5)) - - config = { - "cat1": { - "@cats": "catsie.v1", - "evil": cat1_evil, - }, - "cat2": { - "@cats": "catsie.v2", - "evil": cat2_evil, - "cute_level": cat2_cute_level, - }, - } - - cat1_expected = "scratch!" if cat1_evil else "meow" - if cat2_evil: - cat2_expected = "scratch!" - elif cat2_cute_level > 2: - cat2_expected = "meow <3" - else: - cat2_expected = "meow" - - return config, cat1_expected, cat2_expected - - -@given(data=multiple_registry_sections()) -@settings(max_examples=50) -def test_multiple_registry_sections_resolve(data): - """Test that multiple sections with registered functions resolve correctly.""" - config, cat1_expected, cat2_expected = data - cfg = Config(config) - resolved = my_registry.resolve(cfg) - assert resolved["cat1"] == cat1_expected - assert resolved["cat2"] == cat2_expected - - -@given(evil=st.booleans(), cute=st.booleans()) -@settings(max_examples=20) -def test_registry_fill_adds_defaults(evil, cute): - """Test that registry.fill() adds default values.""" - config = {"cat": {"@cats": "catsie.v1", "evil": evil}} - cfg = Config(config) - filled = my_registry.fill(cfg) - - assert filled["cat"]["cute"] is True - assert filled["cat"]["evil"] == evil - - -# ============================================================================= -# Schema Inference Tests (Schema.from_function / get_func_fields) -# ============================================================================= - - -def make_test_func_int(x: int) -> int: - return x - - -def make_test_func_str(x: str) -> str: - return x - - -def make_test_func_bool(x: bool) -> bool: - return x - - -def make_test_func_float(x: float) -> float: - return x - - -def make_test_func_optional(x: int, y: str = "default") -> str: - return f"{x} {y}" - - -def make_test_func_all_optional(x: int = 10, y: str = "default") -> str: - return f"{x} {y}" - - -def make_test_func_list(x: list) -> list: - return x - - -def make_test_func_typed_list(x: list[int]) -> list: - return x - - -class TestMakeFuncSchema: - """Tests for Schema.from_function inferring schemas from functions.""" - - def test_schema_accepts_correct_int(self): - """Schema accepts correct int type.""" - schema = Schema.from_function(make_test_func_int) - result = schema.model_validate({"x": 42}) - assert result.x == 42 - - def test_schema_rejects_wrong_type_for_int(self): - """Schema rejects string when int expected.""" - schema = Schema.from_function(make_test_func_int) - with pytest.raises(ValidationError): - schema.model_validate({"x": "not an int"}) - - def test_schema_accepts_correct_str(self): - """Schema accepts correct str type.""" - schema = Schema.from_function(make_test_func_str) - result = schema.model_validate({"x": "hello"}) - assert result.x == "hello" - - def test_schema_accepts_correct_bool(self): - """Schema accepts correct bool type.""" - schema = Schema.from_function(make_test_func_bool) - result = schema.model_validate({"x": True}) - assert result.x is True - - def test_schema_accepts_correct_float(self): - """Schema accepts correct float type.""" - schema = Schema.from_function(make_test_func_float) - result = schema.model_validate({"x": 3.14}) - assert result.x == 3.14 - - def test_schema_requires_required_param(self): - """Schema requires parameters without defaults.""" - schema = Schema.from_function(make_test_func_optional) - with pytest.raises(ValidationError): - schema.model_validate({"y": "provided"}) # missing x - - def test_schema_uses_default_for_optional(self): - """Schema uses default value for optional parameters.""" - schema = Schema.from_function(make_test_func_optional) - result = schema.model_validate({"x": 5}) - assert result.x == 5 - assert result.y == "default" - - def test_schema_all_optional_uses_defaults(self): - """Schema uses defaults when all params are optional.""" - schema = Schema.from_function(make_test_func_all_optional) - result = schema.model_validate({}) - assert result.x == 10 - assert result.y == "default" - - def test_schema_rejects_extra_fields(self): - """Schema rejects extra fields not in function signature.""" - schema = Schema.from_function(make_test_func_int) - with pytest.raises(ValidationError): - schema.model_validate({"x": 1, "extra": "not allowed"}) - - def test_schema_accepts_list(self): - """Schema accepts list type.""" - schema = Schema.from_function(make_test_func_list) - result = schema.model_validate({"x": [1, 2, 3]}) - assert result.x == [1, 2, 3] - - -@given(value=st.integers()) -@settings(max_examples=50) -def test_schema_int_property(value): - """Property test: schema validates any integer.""" - schema = Schema.from_function(make_test_func_int) - result = schema.model_validate({"x": value}) - assert result.x == value - - -@given(value=st.text(max_size=100)) -@settings(max_examples=50) -def test_schema_str_property(value): - """Property test: schema validates any string.""" - schema = Schema.from_function(make_test_func_str) - result = schema.model_validate({"x": value}) - assert result.x == value - - -@given(value=st.booleans()) -@settings(max_examples=10) -def test_schema_bool_property(value): - """Property test: schema validates any boolean.""" - schema = Schema.from_function(make_test_func_bool) - result = schema.model_validate({"x": value}) - assert result.x == value - - -@given(value=st.floats(allow_nan=False, allow_infinity=False)) -@settings(max_examples=50) -def test_schema_float_property(value): - """Property test: schema validates any float.""" - schema = Schema.from_function(make_test_func_float) - result = schema.model_validate({"x": value}) - assert result.x == value - - -@given(required=st.integers(), optional=st.text(max_size=20) | st.none()) -@settings(max_examples=50) -def test_schema_optional_property(required, optional): - """Property test: schema handles required and optional params.""" - schema = Schema.from_function(make_test_func_optional) - if optional is None: - result = schema.model_validate({"x": required}) - assert result.y == "default" - else: - result = schema.model_validate({"x": required, "y": optional}) - assert result.y == optional - assert result.x == required - - -@given(items=st.lists(st.integers(), max_size=10)) -@settings(max_examples=50) -def test_schema_list_property(items): - """Property test: schema validates lists.""" - schema = Schema.from_function(make_test_func_list) - result = schema.model_validate({"x": items}) - assert result.x == items diff --git a/confection/tests/test_frozen_structures.py b/confection/tests/test_frozen_structures.py deleted file mode 100644 index 91b9767..0000000 --- a/confection/tests/test_frozen_structures.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Dict, List - -import catalogue -import pytest -from pytest import raises - -from confection import Config, SimpleFrozenDict, SimpleFrozenList, registry - - -def test_frozen_list(): - frozen = SimpleFrozenList(range(10)) - - for k in range(10): - assert frozen[k] == k - - with raises(NotImplementedError, match="frozen list"): - frozen.append(5) - - with raises(NotImplementedError, match="frozen list"): - frozen.reverse() - - with raises(NotImplementedError, match="frozen list"): - frozen.pop(0) - - -def test_frozen_dict(): - frozen = SimpleFrozenDict({k: k for k in range(10)}) - - for k in range(10): - assert frozen[k] == k - - with raises(NotImplementedError, match="frozen dictionary"): - frozen[0] = 1 - - with raises(NotImplementedError, match="frozen dictionary"): - frozen[10] = 1 - - -@pytest.mark.parametrize("frozen_type", ("dict", "list")) -def test_frozen_struct_deepcopy(frozen_type): - """Test whether setting default values for a FrozenDict/FrozenList works within a config, which utilizes - deepcopy.""" - registry.bar = catalogue.create("confection", "bar", entry_points=False) - - @registry.bar.register("foo_dict.v1") - def make_dict(values: Dict[str, int] = SimpleFrozenDict(x=3)): - return values - - @registry.bar.register("foo_list.v1") - def make_list(values: List[int] = SimpleFrozenList([1, 2, 3])): - return values - - cfg = Config() - resolved = registry.resolve( - cfg.from_str( - f""" - [something] - @bar = "foo_{frozen_type}.v1" - """ - ) - ) - - assert isinstance(resolved["something"], Dict if frozen_type == "dict" else List) diff --git a/confection/tests/util.py b/confection/tests/util.py deleted file mode 100644 index 5ff4f54..0000000 --- a/confection/tests/util.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Registered functions used for config tests. -""" - -import contextlib -import dataclasses -import shutil -import tempfile -from pathlib import Path -from typing import Generator, Generic, Iterable, List, Optional, TypeVar, Union - -import catalogue - -try: - from pydantic.v1.types import StrictBool -except (ImportError, TypeError): - try: - from pydantic.types import StrictBool # type: ignore - except (ImportError, TypeError): - from confection.validation import StrictBool # type: ignore - -import confection - -FloatOrSeq = Union[float, List[float], Generator] -InT = TypeVar("InT") -OutT = TypeVar("OutT") - - -@dataclasses.dataclass -class Cat(Generic[InT, OutT]): - name: str - value_in: InT - value_out: OutT - - -my_registry_namespace = "config_tests" - - -class my_registry(confection.registry): - namespace = "config_tests" - cats = catalogue.create(namespace, "cats", entry_points=False) - optimizers = catalogue.create(namespace, "optimizers", entry_points=False) - schedules = catalogue.create(namespace, "schedules", entry_points=False) - initializers = catalogue.create(namespace, "initializers", entry_points=False) - layers = catalogue.create(namespace, "layers", entry_points=False) - - -@my_registry.cats.register("catsie.v1") -def catsie_v1(evil: StrictBool, cute: bool = True) -> str: - if evil: - return "scratch!" - else: - return "meow" - - -@my_registry.cats.register("catsie.v2") -def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str: - if evil: - return "scratch!" - else: - if cute_level > 2: - return "meow <3" - return "meow" - - -@my_registry.cats("catsie.v3") -def catsie(arg: Cat) -> Cat: - return arg - - -@my_registry.optimizers("Adam.v1") -def Adam( - learn_rate: FloatOrSeq = 0.001, - *, - beta1: FloatOrSeq = 0.001, - beta2: FloatOrSeq = 0.001, - use_averages: bool = True, -): - """ - Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used - to illustrate how to use the function registry, e.g. with thinc. - """ - - @dataclasses.dataclass - class Optimizer: - learn_rate: FloatOrSeq - beta1: FloatOrSeq - beta2: FloatOrSeq - use_averages: bool - - return Optimizer( - learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages - ) - - -@my_registry.schedules("warmup_linear.v1") -def warmup_linear( - initial_rate: float, warmup_steps: int, total_steps: int -) -> Iterable[float]: - """Generate a series, starting from an initial rate, and then with a warmup - period, and then a linear decline. Used for learning rates. - """ - step = 0 - while True: - if step < warmup_steps: - factor = step / max(1, warmup_steps) - else: - factor = max( - 0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps) - ) - yield factor * initial_rate - step += 1 - - -@my_registry.cats("generic_cat.v1") -def generic_cat(cat: Cat[int, int]) -> Cat[int, int]: - cat.name = "generic_cat" - return cat - - -@my_registry.cats("int_cat.v1") -def int_cat( - value_in: Optional[int] = None, value_out: Optional[int] = None -) -> Cat[Optional[int], Optional[int]]: - """Instantiates cat with integer values.""" - return Cat(name="int_cat", value_in=value_in, value_out=value_out) - - -@my_registry.optimizers.register("my_cool_optimizer.v1") -def make_my_optimizer(learn_rate: List[float], beta1: float): - return Adam(learn_rate, beta1=beta1) - - -@my_registry.schedules("my_cool_repetitive_schedule.v1") -def decaying(base_rate: float, repeat: int) -> List[float]: - return repeat * [base_rate] - - -@contextlib.contextmanager -def make_tempdir(): - d = Path(tempfile.mkdtemp()) - yield d - shutil.rmtree(str(d)) diff --git a/confection/typechecker.py b/confection/typechecker.py new file mode 100644 index 0000000..95ad22c --- /dev/null +++ b/confection/typechecker.py @@ -0,0 +1,565 @@ +""" +A structural type checker with clean separation of concerns. + +The checker validates values against type annotations without requiring +the types to be instantiated. It supports standard library types, generics, +Union/Optional, Literal, Annotated, TypeVar, NewType, dataclasses, and +pydantic-compatible custom types. + +Architecture: + check_type - entry point, dispatches to custom handlers or standard path + get_annot_branches - peels Union/Optional into flat alternatives + check_branch - outer_match then decompose+recurse + outer_match - does the value match this annotation at the top level? + decompose - yield (child_value, child_annotation, child_ctx) triples +""" + +from __future__ import annotations + +import collections.abc +import inspect +import types +from dataclasses import dataclass, field, is_dataclass +from dataclasses import fields as dataclass_fields +from enum import Enum +from pathlib import PurePath +from types import GeneratorType +from typing import ( + Annotated, + Any, + Callable, + Dict, + ForwardRef, + Iterator, + Literal, + Optional, + Tuple, + TypeVar, + Union, + get_args, + get_origin, +) + +# --------------------------------------------------------------------------- +# Error accumulation +# --------------------------------------------------------------------------- + + +@dataclass +class TypeCheckError: + """A single type-check failure, recording the path, value, and expected type.""" + + path: Tuple[Any, ...] + value: Any + annotation: Any + + def __str__(self) -> str: + path_str = " → ".join(str(p) for p in self.path) if self.path else "root" + return f"at {path_str}: {self.value!r} is not {self.annotation}" + + +@dataclass +class Ctx: + """Accumulates errors during a type-check traversal. + + All recursive calls share the same ``errors`` list via ``child()``, + so errors from any depth are collected in one place. + """ + + path: Tuple[Any, ...] = () + errors: list[TypeCheckError] = field(default_factory=list) + + def child(self, segment: Any) -> Ctx: + """Create a child context with *segment* appended to the path.""" + return Ctx(self.path + (segment,), self.errors) # shared errors list + + def fail(self, value: Any, annotation: Any) -> None: + """Record a type-check failure at the current path.""" + self.errors.append(TypeCheckError(self.path, value, annotation)) + + +# --------------------------------------------------------------------------- +# Main entry points +# --------------------------------------------------------------------------- + + +def check_type( + value: Any, + annotation: Any, + custom_handlers: Optional[Dict[type, Callable[..., bool]]] = None, + ctx: Optional[Ctx] = None, +) -> bool: + """Check whether *value* is compatible with *annotation*. + + Returns ``True`` if the value matches, ``False`` otherwise. Errors are + accumulated in *ctx* (created automatically if not provided). + + *custom_handlers* maps ``type(value)`` to a callable + ``(value, annotation, handlers, ctx) -> bool`` that overrides the + default checking logic for that runtime type. + """ + if custom_handlers is None: + custom_handlers = {} + if ctx is None: + ctx = Ctx() + + if type(value) in custom_handlers: + return custom_handlers[type(value)](value, annotation, custom_handlers, ctx) + + return any( + check_branch(value, branch, custom_handlers, ctx) + for branch in get_annot_branches(annotation) + ) + + +def check_branch( + value: Any, + annotation: Any, + custom_handlers: Dict[type, Callable[..., bool]], + ctx: Ctx, +) -> bool: + """Check *value* against a single (non-Union) annotation branch. + + First checks the top-level match via ``outer_match``, then recursively + checks children yielded by ``decompose``. + """ + if not outer_match(value, annotation): + ctx.fail(value, annotation) + return False + return all( + check_type(v, a, custom_handlers, child_ctx) + for v, a, child_ctx in decompose(value, annotation, ctx) + ) + + +# --------------------------------------------------------------------------- +# get_annot_branches: peel Union/Optional into flat alternatives +# --------------------------------------------------------------------------- + + +def get_annot_branches(annotation: Any) -> Tuple[Any, ...]: + """Split a (possibly Union) annotation into individual branches. + + ``Union[int, str]`` becomes ``(int, str)``. ``Optional[X]`` becomes + ``(X, NoneType)``. ``X | Y`` (Python 3.10+) is handled via + ``types.UnionType``. Non-union annotations are returned as a + single-element tuple. + """ + origin = get_origin(annotation) + + # Union[X, Y] and Optional[X] (which is Union[X, None]) + if origin is Union: + return get_args(annotation) + + # types.UnionType handles X | Y syntax (Python 3.10+) + if origin is types.UnionType or isinstance(annotation, types.UnionType): + return get_args(annotation) + + return (annotation,) + + +# --------------------------------------------------------------------------- +# outer_match: does the value match at this level, ignoring children? +# --------------------------------------------------------------------------- + +#: Map from typing generic origins to the runtime types used for isinstance. +ORIGIN_TO_BUILTIN: Dict[Any, Any] = { + list: list, + dict: dict, + tuple: tuple, + set: set, + frozenset: frozenset, + collections.abc.Sequence: collections.abc.Sequence, + collections.abc.MutableSequence: collections.abc.MutableSequence, + collections.abc.Set: collections.abc.Set, + collections.abc.MutableSet: collections.abc.MutableSet, + collections.abc.Mapping: collections.abc.Mapping, + collections.abc.MutableMapping: collections.abc.MutableMapping, + collections.abc.Callable: collections.abc.Callable, + collections.abc.Iterable: collections.abc.Iterable, + collections.abc.Iterator: collections.abc.Iterator, +} + + +def outer_match(value: Any, annotation: Any) -> bool: + """Check whether *value* matches *annotation* at the top level. + + This does **not** recurse into container elements — that is the job of + ``decompose`` + ``check_branch``. Coercion rules: + + * ``bool`` requires an exact bool (``0``/``1`` are rejected). + * ``int`` accepts ints and parseable strings, but rejects bools. + * ``float`` accepts ints, floats, and parseable strings, but rejects bools. + * ``Path``/``PurePath`` accept strings. + * ``str`` enums accept valid member value strings. + * Generators / iterators always pass (to avoid consuming them). + * Unresolved ``ForwardRef`` and string annotations always pass. + """ + # Any / Parameter.empty matches everything + if annotation is Any or annotation is inspect.Parameter.empty: + return True + + # None/NoneType + if annotation is None or annotation is type(None): + return value is None + + # Generators pass through without consumption + if isinstance(value, (GeneratorType, collections.abc.Iterator)) and not isinstance( + value, (str, bytes) + ): + return True + + # Literal[v1, v2, ...] + if get_origin(annotation) is Literal: + return value in get_args(annotation) + + # Annotated[T, ...] — unwrap to inner type + if get_origin(annotation) is Annotated: + inner = get_args(annotation)[0] + metadata = get_args(annotation)[1:] + if _has_strict_metadata(metadata): + return _strict_match(value, inner) + return outer_match(value, inner) + + # Type[X] — value should be a class that is a subclass of X + if get_origin(annotation) is type: + args = get_args(annotation) + if not isinstance(value, type): + return False + if args and args[0] is not Any: + try: + return issubclass(value, args[0]) + except ( + TypeError + ): # pragma: no cover -- modern Python handles Union in issubclass + return True # pragma: no cover + return True + + # Callable — just check callability here, signature checking is hard + # and arguably belongs in a custom handler + if get_origin(annotation) is collections.abc.Callable: + return callable(value) + + # Generic types: List[int], Dict[str, int], etc. + origin = get_origin(annotation) + if origin is not None: + check_against = ORIGIN_TO_BUILTIN.get(origin, origin) + try: + if not isinstance(value, check_against): + return False + except TypeError: # pragma: no cover -- custom generics with non-type origins + return True # pragma: no cover + # Fixed-length tuple: check length here + if origin is tuple: + args = get_args(annotation) + if args and not (len(args) == 2 and args[1] is Ellipsis): + if len(value) != len(args): + return False + return True + + # NewType — unwrap to supertype + if callable(annotation) and hasattr(annotation, "__supertype__"): + return outer_match(value, annotation.__supertype__) # pyright: ignore[reportFunctionMemberAccess] + + # TypeVar + if isinstance(annotation, TypeVar): + bound = annotation.__bound__ + constraints = annotation.__constraints__ + if bound: + try: + return isinstance(value, bound) + except TypeError: + # bound contains unresolved ForwardRefs or complex generics + return outer_match(value, bound) + if constraints: + return any(outer_match(value, c) for c in constraints) + return True + + # Forward references — can't resolve, accept + if isinstance(annotation, (str, ForwardRef)): + return True + + # --- Plain types with coercion --- + if isinstance(annotation, type): + # bool: exact type (don't accept int 0/1) + if annotation is bool: + return isinstance(value, bool) + # int: accept ints (not bools), reject strings that don't parse + if annotation is int: + if isinstance(value, bool): + return False + if isinstance(value, int): + return True + if isinstance(value, str): + try: + int(value) + return True + except (ValueError, TypeError): + return False + return False + # float: accept int/float (not bools), strings that parse + if annotation is float: + if isinstance(value, bool): + return False + if isinstance(value, (int, float)): + return True + if isinstance(value, str): + try: + float(value) + return True + except (ValueError, TypeError): + return False + return False + # str: straightforward + if annotation is str: + return isinstance(value, str) + # Path: accept strings + if issubclass(annotation, PurePath): + return isinstance(value, (str, PurePath)) + # str enums: accept plain strings that are valid member values + if issubclass(annotation, str) and issubclass(annotation, Enum): + try: + annotation(value) + return True + except (ValueError, KeyError): + return False + # Dataclass / Schema with dict value: accept for decompose + if is_dataclass(annotation) and isinstance(value, dict): + return True + if hasattr(annotation, "model_fields") and isinstance(value, dict): + return True + # Pydantic v2 validator hook + if hasattr(annotation, "__get_pydantic_core_schema__"): + return _pydantic_v2_match(value, annotation) + # Pydantic v1 validator hook + if hasattr(annotation, "__get_validators__"): + return _pydantic_v1_match(value, annotation) + # Default isinstance + return isinstance(value, annotation) + + return False + + +# --------------------------------------------------------------------------- +# Helpers for outer_match +# --------------------------------------------------------------------------- + + +def _resolve_dataclass_hints(cls: type) -> Dict[str, Any]: + """Resolve forward references in a dataclass's type annotations. + + Uses ``get_type_hints`` with the class's module globals so that + forward references like ``ForwardRef('Floats3d')`` are resolved to + actual types. Returns an empty dict on failure. + """ + import sys + from typing import get_type_hints + + mod = sys.modules.get(cls.__module__) + globalns = vars(mod) if mod else None + try: + return get_type_hints(cls, globalns=globalns) + except (NameError, AttributeError, TypeError, RecursionError): # pragma: no cover + return {} # pragma: no cover + + +def _has_strict_metadata(metadata: Tuple[Any, ...]) -> bool: + """Check if ``Annotated`` metadata contains a ``Strict()`` marker.""" + return any(getattr(m, "strict", False) for m in metadata if hasattr(m, "strict")) + + +def _strict_match(value: Any, inner_type: type) -> bool: + """Exact type match for ``Annotated[X, Strict()]``. + + Unlike the normal coercion rules, strict matching requires + ``type(value)`` to be exactly the annotated type. + """ + if inner_type is int: + return type(value) is int and not isinstance(value, bool) + if inner_type is float: + return type(value) is float + if inner_type is str: + return type(value) is str + if inner_type is bool: + return type(value) is bool + return isinstance(value, inner_type) + + +class _AnySchemaHandler: + """Minimal stand-in for pydantic's ``GetCoreSchemaHandler``.""" + + def __call__( + self, _source_type: Any + ) -> Dict[str, Any]: # pragma: no cover -- called internally by pydantic hooks + return {"type": "any"} # pragma: no cover + + +def _pydantic_v2_match(value: Any, annotation: type) -> bool: + """Check *value* against a type with ``__get_pydantic_core_schema__``. + + Extracts the validator function from the pydantic core schema and calls + it. Falls back to ``isinstance`` if the value is already an instance. + """ + if isinstance(value, annotation): + return True + try: + schema = annotation.__get_pydantic_core_schema__( + annotation, _AnySchemaHandler() + ) + fn_entry = schema.get("function", {}) + validator = fn_entry.get("function") if isinstance(fn_entry, dict) else None + if callable(validator): + validator(value) + return True + except (ValueError, TypeError, AssertionError): + return False + return False + + +class _PydanticV1FieldShim: + """Minimal shim providing ``field.type_`` for pydantic v1 validators.""" + + def __init__(self, typ: type) -> None: + self.type_ = typ + + +def _pydantic_v1_match(value: Any, annotation: type) -> bool: + """Check *value* against a type with ``__get_validators__``. + + Iterates through the validators yielded by the type's + ``__get_validators__`` classmethod. Validators with more than + 2 parameters are skipped (they require a ``config`` argument we + don't have). + """ + if isinstance(value, annotation): + return True + shim = _PydanticV1FieldShim(annotation) + for validator in annotation.__get_validators__(): + try: + nparams = len(inspect.signature(validator).parameters) + except (ValueError, TypeError): + nparams = 1 + if nparams > 2: + continue + try: + value = validator(value) if nparams == 1 else validator(value, shim) + except (ValueError, TypeError, AssertionError): + return False + return True + + +# --------------------------------------------------------------------------- +# decompose: yield (child_value, child_annotation, child_ctx) triples +# --------------------------------------------------------------------------- + +#: Origins that are sequence-like: one type arg, fan across elements. +SEQUENCE_ORIGINS: set[Any] = { + list, + set, + frozenset, + collections.abc.Sequence, + collections.abc.MutableSequence, + collections.abc.Set, + collections.abc.MutableSet, + collections.abc.Iterable, + collections.abc.Iterator, +} + +#: Origins that are mapping-like: two type args (key, value). +MAPPING_ORIGINS: set[Any] = { + dict, + collections.abc.Mapping, + collections.abc.MutableMapping, +} + + +def decompose(value: Any, annotation: Any, ctx: Ctx) -> Iterator[Tuple[Any, Any, Ctx]]: + """Yield ``(child_value, child_annotation, child_ctx)`` triples. + + This is the recursive engine of the type checker. For container types + it fans out over elements; for dataclasses and schemas it fans out + over fields. ``outer_match`` has already confirmed the top-level + match, so we only need to check the children here. + """ + # Annotated[T, ...] — unwrap + if get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + + origin = get_origin(annotation) + args = get_args(annotation) + + # Schema / model_fields annotation with dict value — fan out over fields + if ( + isinstance(annotation, type) + and hasattr(annotation, "model_fields") + and isinstance(value, dict) + ): + for name, field_info in annotation.model_fields.items(): + data_key = name + if hasattr(field_info, "alias") and field_info.alias is not None: + data_key = field_info.alias + if data_key in value: + yield (value[data_key], field_info.annotation, ctx.child(data_key)) + return + + # Dataclass annotation with dict value — fan out over fields + if ( + isinstance(annotation, type) + and is_dataclass(annotation) + and isinstance(value, dict) + ): + resolved_hints = _resolve_dataclass_hints(annotation) + for f in dataclass_fields(annotation): + if f.name in value: + hint = resolved_hints.get(f.name, f.type) + yield (value[f.name], hint, ctx.child(f.name)) + return + + # Dataclass annotation with dataclass value — match fields + if ( + isinstance(annotation, type) + and is_dataclass(annotation) + and is_dataclass(value) + ): + resolved_hints = _resolve_dataclass_hints(annotation) + for f in dataclass_fields(annotation): + if hasattr(value, f.name): + hint = resolved_hints.get(f.name, f.type) + yield (getattr(value, f.name), hint, ctx.child(f.name)) + return + + # No type args means nothing to recurse into for generics + if not args: + return + + # Tuple: fixed-length or variable-length + if origin is tuple: + if len(args) == 2 and args[1] is Ellipsis: + # Tuple[int, ...] — variable length, one type + for i, elem in enumerate(value): + yield (elem, args[0], ctx.child(i)) + else: + # Tuple[int, str, float] — fixed length + # Length already verified in outer_match + for i, (elem, arg) in enumerate(zip(value, args)): + yield (elem, arg, ctx.child(i)) + return + + # Sequence-like + if origin in SEQUENCE_ORIGINS: + arg = args[0] + for i, elem in enumerate(value): + yield (elem, arg, ctx.child(i)) + return + + # Mapping-like + if origin in MAPPING_ORIGINS: + k_ann, v_ann = args[0], args[1] + for k in value: + yield (k, k_ann, ctx.child(f"key({k!r})")) + yield (value[k], v_ann, ctx.child(k)) + return + + # Type[X], Callable, TypeVar, Literal — no children to decompose + # (fully resolved in outer_match) + return diff --git a/confection/util.py b/confection/util.py index 2f9407e..cecab8b 100644 --- a/confection/util.py +++ b/confection/util.py @@ -1,6 +1,11 @@ import functools +import json +import re from copy import deepcopy -from typing import Any, Callable, Iterator, Protocol, TypeVar +from typing import Any, Callable, Protocol, TypeVar + +from ._constants import VARIABLE_RE +from ._errors import ConfigValidationError _DIn = TypeVar("_DIn") @@ -26,23 +31,6 @@ def partial( return partial_func -class Generator(Iterator): - """Custom generator type. Used to annotate function arguments that accept - generators so they can be validated by pydantic (which doesn't support - iterators/iterables otherwise). - """ - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): - raise TypeError("not a valid iterator") - return v - - DEFAULT_FROZEN_DICT_ERROR = ( "Can't write to frozen dictionary. This is likely an internal " "error. Are you writing to a default function argument?" @@ -80,7 +68,7 @@ def __setitem__(self, key, value): def pop(self, key, default=None): raise NotImplementedError(self.error) - def update(self, other): # type: ignore[override] + def update(self, other=(), /, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride] raise NotImplementedError(self.error) def __deepcopy__(self, memo): @@ -133,3 +121,40 @@ def sort(self, *args, **kwargs): def __deepcopy__(self, memo): return self.__class__(deepcopy(v) for v in self) + + +def is_promise(obj) -> bool: + if not hasattr(obj, "keys"): + return False + id_keys = [k for k in obj.keys() if isinstance(k, str) and k.startswith("@")] + if len(id_keys): + return True + return False + + +def try_load_json(value: str) -> Any: + """Load a JSON string if possible, otherwise default to original value.""" + try: + return json.loads(value) + except Exception: + return value + + +def try_dump_json(value: Any, data: dict[str, dict] | str = "") -> str: + """Dump a config value as JSON and output user-friendly error if it fails.""" + # Special case if we have a variable: it's already a string so don't dump + # to preserve ${x:y} vs. "${x:y}" + if isinstance(value, str) and VARIABLE_RE.search(value): + return value + try: + value = json.dumps(value) + except Exception as e: + err_msg = ( + f"Couldn't serialize config value of type {type(value)}: {e}. Make " + f"sure all values in your config are JSON-serializable. If you want " + f"to include Python objects, use a registered function that returns " + f"the object instead." + ) + raise ConfigValidationError(config=data, desc=err_msg) from e + # Escape $ to $$ for configparser, but preserve ${...} variable references + return re.sub(r"\$(?!\{)", "$$", value) diff --git a/confection/validation.py b/confection/validation.py index 741e727..df99447 100644 --- a/confection/validation.py +++ b/confection/validation.py @@ -4,45 +4,34 @@ for config values against function signatures. """ -import collections.abc import inspect import sys -import types -from pathlib import PurePath -from types import GeneratorType -from typing import ( - Annotated, - Any, - Literal, - Optional, - TypeVar, - Union, - get_args, - get_origin, - get_type_hints, -) +from typing import Any, Optional, get_type_hints + +from .typechecker import Ctx +from .typechecker import check_type as _tc2_check_type # Optional pydantic imports — confection doesn't depend on pydantic, # but if it's installed we can detect and convert BaseModel schemas. try: - from pydantic.v1 import ( # pyright: ignore[reportMissingImports] - BaseModel as _PydanticV1BaseModel, + from pydantic.v1 import ( + BaseModel as _PydanticV1BaseModel, # pyright: ignore[reportMissingImports] ) - from pydantic.v1 import ( # pyright: ignore[reportMissingImports] - ValidationError as _PydanticV1ValidationError, + from pydantic.v1 import ( + ValidationError as _PydanticV1ValidationError, # pyright: ignore[reportMissingImports] ) -except (ImportError, ModuleNotFoundError): +except (ImportError, ModuleNotFoundError): # pragma: no cover _PydanticV1BaseModel = None # type: ignore[assignment,misc] _PydanticV1ValidationError = None # type: ignore[assignment,misc] try: - from pydantic import ( # pyright: ignore[reportMissingImports] - BaseModel as _PydanticV2BaseModel, + from pydantic import ( + BaseModel as _PydanticV2BaseModel, # pyright: ignore[reportMissingImports] ) - from pydantic import ( # pyright: ignore[reportMissingImports] - ValidationError as _PydanticV2ValidationError, + from pydantic import ( + ValidationError as _PydanticV2ValidationError, # pyright: ignore[reportMissingImports] ) -except (ImportError, ModuleNotFoundError): +except (ImportError, ModuleNotFoundError): # pragma: no cover _PydanticV2BaseModel = None # type: ignore[assignment,misc] _PydanticV2ValidationError = None # type: ignore[assignment,misc] @@ -132,18 +121,20 @@ def __init_subclass__(cls, **kwargs): if name in ("model_config", "model_fields") or name.startswith("_"): continue default = ... + alias = None for klass in cls.__mro__: if name in klass.__dict__: val = klass.__dict__[name] if isinstance(val, FieldInfo): default = val.default + alias = val.alias elif not isinstance( val, (type, classmethod, staticmethod, property) ): if not callable(val): default = val break - field = FieldInfo(default=default) + field = FieldInfo(default=default, alias=alias) field.annotation = annotation fields[name] = field @@ -304,13 +295,6 @@ def resolve_type_hints(func): # === Type Validation === -def _is_generator_value(value): - """Check if value is a generator/iterator that shouldn't be consumed.""" - return isinstance( - value, (GeneratorType, collections.abc.Iterator) - ) and not isinstance(value, (str, bytes)) - - def _error_type_for(annotation): """Get an error type string for an annotation.""" if annotation is int or annotation is PositiveInt: @@ -328,409 +312,13 @@ def validate_type(value, annotation): """Validate value against a type annotation. Returns None if valid, or an error message string if invalid. - Generators/iterators always pass through without validation. """ - if annotation is Any or annotation is inspect.Parameter.empty: - return None - - if annotation is type(None): - return None if value is None else "Input should be None" - - # Generators always pass through (they can't be validated without consumption) - if _is_generator_value(value): - return None - - origin = get_origin(annotation) - args = get_args(annotation) - - # Annotated[X, ...] -> validate against X, respecting Strict metadata - # get_origin returns None for Annotated in Python 3.10, so also check __metadata__ - if origin is Annotated or hasattr(annotation, "__metadata__"): - inner_type = get_args(annotation)[0] - metadata = getattr(annotation, "__metadata__", ()) - strict = any( - getattr(m, "strict", False) for m in metadata if hasattr(m, "strict") - ) - if strict: - if inner_type is int: - if type(value) is not int or isinstance(value, bool): - return "Input should be a valid integer (strict)" - return None - elif inner_type is float: - if type(value) is not float: - return "Input should be a valid float (strict)" - return None - elif inner_type is str: - if type(value) is not str: - return "Input should be a valid string (strict)" - return None - elif inner_type is bool: - if type(value) is not bool: - return "Input should be a valid boolean (strict)" - return None - return validate_type(value, inner_type) - - # Union / Optional (typing.Union and Python 3.10+ X | Y syntax) - if origin is Union or origin is types.UnionType: - for arg in args: - if validate_type(value, arg) is None: - return None - type_names = ", ".join(_type_display(a) for a in args) - return f"Input should be valid for union type: {type_names}" - - # Literal - if origin is Literal: - if value in args: - return None - return f"Input should be {' or '.join(repr(a) for a in args)}" - - # NewType: unwrap to the supertype - if callable(annotation) and hasattr(annotation, "__supertype__"): - return validate_type(value, annotation.__supertype__) # pyright: ignore[reportFunctionMemberAccess] - - # TypeVar: validate against bound or constraints - if isinstance(annotation, TypeVar): - if annotation.__bound__ is not None: - return validate_type(value, annotation.__bound__) - if annotation.__constraints__: - for constraint in annotation.__constraints__: - if validate_type(value, constraint) is None: - return None - names = ", ".join(_type_display(c) for c in annotation.__constraints__) - return f"Input should be valid for: {names}" - return None # unconstrained TypeVar accepts anything - - # Constrained types - if annotation is StrictBool: - if type(value) is not bool: - return "Input should be a valid boolean" - return None - - if annotation is PositiveInt: - if isinstance(value, bool) or not isinstance(value, int) or value <= 0: - return "Input should be a positive integer, greater than 0" + ctx = Ctx() + if _tc2_check_type(value, annotation, ctx=ctx): return None - - if annotation is StrictFloat: - if type(value) is not float: - return "Input should be a valid float" - return None - - # Generic types - if origin is not None: - return _validate_generic(value, origin, args) - - # Plain types - if isinstance(annotation, type): - # If annotation is a Schema subclass and value is a dict, validate against it - if issubclass(annotation, Schema) and isinstance(value, dict): - errors = _validate_schema( - value, - annotation.model_fields, - annotation.model_config, - annotation.model_config.get("alias_generator"), - ) - if errors: - return f"Input should be an instance of {annotation.__name__}" - return None - return _validate_plain_type(value, annotation) - - # Unknown annotation - accept - return None - - -def _validate_plain_type(value, typ): - """Validate value against a plain (non-generic) type.""" - if typ is bool: - if not isinstance(value, bool): - return "Input should be a valid boolean" - return None - - if typ is int: - if isinstance(value, bool): - return ( - "Input should be a valid integer, unable to parse string as an integer" - ) - if isinstance(value, int): - return None - if isinstance(value, str): - try: - int(value) - return None - except (ValueError, TypeError): - pass - return "Input should be a valid integer, unable to parse string as an integer" - - if typ is float: - if isinstance(value, bool): - return "Input should be a valid number" - if isinstance(value, (int, float)): - return None - if isinstance(value, str): - try: - float(value) - return None - except (ValueError, TypeError): - pass - return "Input should be a valid number" - - if typ is str: - if isinstance(value, str): - return None - return "Input should be a valid string" - - # Path: accept strings (pydantic coerces str → Path) - if issubclass(typ, PurePath): - if isinstance(value, (str, PurePath)): - return None - return "Input should be a valid path" - - # Custom class - isinstance check - try: - if isinstance(value, typ): - return None - # Types that declare custom validation via pydantic's schema protocol - # (e.g. thinc's Floats2d). Extract and call the validator directly - # — the hook returns a plain dict, no pydantic import needed. - if hasattr(typ, "__get_pydantic_core_schema__"): - return _call_pydantic_schema_validator(typ, value) - # pydantic v1 protocol: __get_validators__ yields single-arg - # validator functions. - if hasattr(typ, "__get_validators__"): - return _call_pydantic_v1_validators(typ, value) - # For constrained subtypes without validator hooks: if the - # annotation inherits from the value's type, the value is - # structurally compatible. - if issubclass(typ, type(value)): - return None - except TypeError: - return None - - return f"Input should be an instance of {getattr(typ, '__name__', str(typ))}" - - -class _AnySchemaHandler: - """Minimal stand-in for pydantic's GetCoreSchemaHandler. - - Passed to __get_pydantic_core_schema__ so we can extract the - validator function without importing pydantic. The handler is - only called as ``handler(source_type)`` and must return a core - schema dict — ``{"type": "any"}`` tells pydantic "accept anything" - which is the right inner schema for a plain validator. - """ - - def __call__(self, _source_type): - return {"type": "any"} - - -def _call_pydantic_schema_validator(typ, value): - """Call the validator from a type's __get_pydantic_core_schema__ hook. - - The hook returns a plain dict describing the schema. We extract the - validator function and call it directly — no pydantic import needed. - Returns None on success, or an error message string on failure. - """ - schema = typ.__get_pydantic_core_schema__(typ, _AnySchemaHandler()) - # Navigate the schema dict to find the validator function. - # Typical shapes: - # {"type": "function-after", "function": {"type": "no-info", "function": }, ...} - # {"type": "function-plain", "function": {"type": "no-info", "function": }} - fn_entry = schema.get("function", {}) - if isinstance(fn_entry, dict): - validator = fn_entry.get("function") - else: - return None # unrecognised shape — can't extract validator - if not callable(validator): - return None - try: - validator(value) - except (ValueError, TypeError, AssertionError) as e: - return str(e) - return None - - -def _call_pydantic_v1_validators(typ, value): - """Call validators from a type's __get_validators__ hook (pydantic v1). - - Validators may accept 1 arg (value), 2 args (value, field), or - 3 args (value, field, config). For multi-arg validators that need - field metadata (like number constraints), we build a minimal shim. - """ - for validator in typ.__get_validators__(): - try: - nparams = len(inspect.signature(validator).parameters) - except (ValueError, TypeError): - nparams = 1 - if nparams > 2: - continue # skip validators requiring pydantic config objects - try: - if nparams == 1: - value = validator(value) - else: - value = validator(value, _PydanticV1FieldShim(typ)) - except (ValueError, TypeError, AssertionError) as e: - return str(e) - return None - - -class _PydanticV1FieldShim: - """Minimal shim for pydantic v1 ModelField, providing just enough - for constraint validators (number_size_validator etc.).""" - - def __init__(self, typ): - self.type_ = typ - - -def _validate_generic(value, origin, args): - """Validate value against a generic type (List[X], Dict[K,V], etc.).""" - # Callable - if origin is collections.abc.Callable: - if callable(value): - return None - return "Input should be callable" - - # list / List[X] - if origin is list: - if not isinstance(value, list): - return "Input should be a valid list" - if args: - for i, item in enumerate(value): - err = validate_type(item, args[0]) - if err: - return f"Item {i}: {err}" - return None - - # dict / Dict[K, V] - if origin is dict: - if not isinstance(value, dict): - return "Input should be a valid dictionary" - if args and len(args) == 2: - for k, v in value.items(): - err = validate_type(k, args[0]) - if err: - return f"Key {k!r}: {err}" - err = validate_type(v, args[1]) - if err: - return f"Value for {k!r}: {err}" - return None - - # tuple / Tuple — Tuple[int, str] (fixed) vs Tuple[int, ...] (variable) - if origin is tuple: - if not isinstance(value, (tuple, list)): - return "Input should be a valid tuple" - if args: - if len(args) == 2 and args[1] is Ellipsis: - # Tuple[X, ...] — variable-length, all elements same type - for i, item in enumerate(value): - err = validate_type(item, args[0]) - if err: - return f"Item {i}: {err}" - elif args != ((),): - # Tuple[X, Y, Z] — fixed-length positional - if len(value) != len(args): - return f"Expected {len(args)} items in tuple, got {len(value)}" - for i, (item, expected) in enumerate(zip(value, args)): - err = validate_type(item, expected) - if err: - return f"Item {i}: {err}" - return None - - # set / Set[X] - if origin is set: - if not isinstance(value, set): - return "Input should be a valid set" - if args: - for item in value: - err = validate_type(item, args[0]) - if err: - return f"Set item: {err}" - return None - - # frozenset / FrozenSet[X] - if origin is frozenset: - if not isinstance(value, frozenset): - return "Input should be a valid frozenset" - if args: - for item in value: - err = validate_type(item, args[0]) - if err: - return f"Frozenset item: {err}" - return None - - # Sequence - if origin is collections.abc.Sequence: - if isinstance(value, (list, tuple)): - if args: - for i, item in enumerate(value): - err = validate_type(item, args[0]) - if err: - return f"Item {i}: {err}" - return None - if isinstance(value, str): - return None - return "Input should be a valid sequence" - - # Iterable - if origin is collections.abc.Iterable: - if hasattr(value, "__iter__"): - return None - return "Input should be iterable" - - # Mapping (covers Mapping, MutableMapping, OrderedDict, etc.) - if isinstance(origin, type) and issubclass(origin, collections.abc.Mapping): - if isinstance(value, collections.abc.Mapping): - return None - return "Input should be a valid mapping" - - # AbstractSet (covers Set, FrozenSet, MutableSet ABCs) - if isinstance(origin, type) and issubclass(origin, collections.abc.Set): - if isinstance(value, collections.abc.Set): - return None - return "Input should be a valid set" - - # Iterator / Generator - if isinstance(origin, type) and issubclass( - origin, (collections.abc.Iterator, collections.abc.Generator) - ): - if hasattr(value, "__next__") or hasattr(value, "__iter__"): - return None - return "Input should be an iterator" - - # Type[X] — check value is a class and optionally a subclass of X - if origin is type: - if not isinstance(value, type): - return "Input should be a type" - if args and args[0] is not Any: - expected = args[0] - try: - if not issubclass(value, expected): - return ( - f"Input should be a subclass of" - f" {getattr(expected, '__name__', expected)}" - ) - except TypeError: - pass # expected is not a class (e.g. Union) — skip - return None - - # For any other generic - try isinstance against origin - if isinstance(origin, type): - try: - if isinstance(value, origin): - return None - except TypeError: - return None - return f"Input should be an instance of {origin.__name__}" - - return None - - -def _type_display(annotation): - """Human-readable name for a type.""" - if annotation is type(None): - return "None" - if hasattr(annotation, "__name__"): - return annotation.__name__ - return str(annotation) + if ctx.errors: + return str(ctx.errors[0]) + return f"{value!r} does not match {annotation}" # pragma: no cover -- defensive fallback # === Schema Validation === @@ -848,8 +436,8 @@ def _extract_pydantic_fields(pydantic_cls): fields = {} if hasattr(pydantic_cls, "model_fields"): - # pydantic v2 interface — check before __fields__ because pydantic v2 - # exposes __fields__ as a deprecated property that triggers warnings + # pydantic v2 interface (check first — v2 also exposes __fields__ + # as a deprecated shim, so we must not fall into the v1 branch) for name, pyd_field in pydantic_cls.model_fields.items(): annotation = pyd_field.annotation if pyd_field.is_required(): @@ -899,13 +487,7 @@ def _extract_pydantic_config(pydantic_cls): """Extract model config from a pydantic BaseModel class (v1 or v2).""" config = {"extra": "allow"} - if hasattr(pydantic_cls, "model_config") and isinstance( - pydantic_cls.model_config, dict - ): - # pydantic v2: dict — check before __config__ because pydantic v2 - # exposes __config__ as a deprecated property that triggers warnings - config = dict(pydantic_cls.model_config) - elif hasattr(pydantic_cls, "__config__"): + if hasattr(pydantic_cls, "__config__"): # pydantic v1: inner class Config cfg = pydantic_cls.__config__ extra = getattr(cfg, "extra", "allow") @@ -915,6 +497,11 @@ def _extract_pydantic_config(pydantic_cls): config["extra"] = extra if isinstance(extra, str) else str(extra) if hasattr(cfg, "arbitrary_types_allowed"): config["arbitrary_types_allowed"] = cfg.arbitrary_types_allowed + elif hasattr(pydantic_cls, "model_config") and isinstance( + pydantic_cls.model_config, dict + ): + # pydantic v2: dict + config = dict(pydantic_cls.model_config) return config @@ -964,10 +551,12 @@ def _pydantic_model_validate(cls, data): pyd_cls.model_validate(data) elif hasattr(pyd_cls, "parse_obj"): pyd_cls.parse_obj(data) - else: - pyd_cls(**data) + else: # pragma: no cover -- all pydantic versions have model_validate or parse_obj + pyd_cls(**data) # pragma: no cover except pyd_validation_err as e: - raise ValidationError(e.errors()) from None # pyright: ignore[reportAttributeAccessIssue] + raise ValidationError( + e.errors() # pyright: ignore[reportAttributeAccessIssue] + ) from None # Return attribute-accessible result with defaults filled in result_data = dict(data) for name, field in cls.model_fields.items(): diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..6344201 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,4 @@ +{ + "exclude": ["confection/tests"], + "reportMissingImports": "warning" +} diff --git a/setup.cfg b/setup.cfg index 8614aa7..7b16a5f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,6 +29,7 @@ zip_safe = true include_package_data = true python_requires = >=3.10 install_requires = + typing_extensions>=4.0,<5.0; python_version < "3.11" [sdist] formats = gztar diff --git a/confection/tests/__init__.py b/tests/conftest.py similarity index 100% rename from confection/tests/__init__.py rename to tests/conftest.py diff --git a/tests/strategies.py b/tests/strategies.py new file mode 100644 index 0000000..6b9a5dc --- /dev/null +++ b/tests/strategies.py @@ -0,0 +1,256 @@ +import json + +from hypothesis import strategies as st + +# Valid config keys: simple identifiers, no dots or special configparser chars +config_keys = st.from_regex(r"[a-z][a-z0-9_]{0,15}", fullmatch=True) + +# Strings safe for config values (no $ to avoid interpolation, no " or \ to +# avoid JSON escaping issues) +config_strings = st.text( + st.characters(whitelist_categories=("L", "N", "Z"), blacklist_characters='$"\\'), + min_size=0, + max_size=20, +) + +# Scalar leaf values: str, int, float, bool, None +scalar_leaves = st.one_of( + config_strings, + st.integers(min_value=-(2**31), max_value=2**31), + st.floats(allow_nan=False, allow_infinity=False), + st.booleans(), + st.none(), +) + +# JSON leaf values: scalars, or lists/dicts of scalars (arbitrarily nested). +# These are values that get serialized as a single JSON-encoded string in the +# config, rather than expanded into subsections. +_json_values = st.recursive( + scalar_leaves, + lambda children: st.one_of( + st.lists(children, max_size=5), + st.dictionaries(config_keys, children, max_size=5), + ), + max_leaves=10, +) +# Only the non-scalar cases are interesting as "json leaves" — scalars are +# already covered by scalar_leaves. +json_leaves = st.one_of( + st.lists(scalar_leaves, min_size=0, max_size=5), + st.dictionaries(config_keys, _json_values, min_size=0, max_size=5), +) + +# Basic config leaves (no JSON-encoded collections) +config_leaves = scalar_leaves + +# A config node is either a leaf or a dict of config nodes. +config_nodes = st.recursive( + config_leaves, + lambda children: st.dictionaries(config_keys, children, min_size=0, max_size=4), + max_leaves=15, +) + +# A valid config must have sections at the top level (all values must be dicts). +config_dicts = st.dictionaries( + config_keys, + st.dictionaries(config_keys, config_nodes, min_size=0, max_size=5), + min_size=1, + max_size=5, +) + + +def _leaf_to_str(value): + """Serialize a leaf value the way confection's config format expects.""" + if value is None: + return "null" + elif isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, str): + return json.dumps(value) + elif isinstance(value, float): + return json.dumps(value) + elif isinstance(value, int): + return str(value) + elif isinstance(value, (list, dict)): + return json.dumps(value) + raise TypeError(f"Unexpected leaf type: {type(value)}") + + +def _flatten_sections(data, prefix="", inline_paths=frozenset()): + """Convert a nested dict into a list of (section_path, {key: leaf_str}) pairs. + + Dicts at paths in inline_paths are serialized as inline JSON values rather + than expanded into subsections. + """ + sections = [] + leaves = {} + for key, value in data.items(): + key_path = f"{prefix}.{key}" if prefix else key + if isinstance(value, dict) and key_path not in inline_paths: + sections.extend(_flatten_sections(value, key_path, inline_paths)) + else: + leaves[key] = _leaf_to_str(value) + if prefix: + sections.insert(0, (prefix, leaves)) + return sections + + +def serialize_with_inline(data, inline_paths=frozenset()): + """Serialize a nested dict to a config string, inlining dicts at the given paths.""" + sections = _flatten_sections(data, inline_paths=inline_paths) + parts = [] + for section_name, leaves in sections: + parts.append(f"[{section_name}]") + for key, value_str in leaves.items(): + parts.append(f"{key} = {value_str}") + parts.append("") + return "\n".join(parts).strip() + + +@st.composite +def json_config_dicts(draw): + """Strategy that produces (data, inline_paths) pairs. + + The data is a nested dict suitable for Config. inline_paths is a set of + dotted paths where the value is a JSON-encoded leaf (list or dict) rather + than a config subsection. This distinction matters for serialization: the + library always expands dicts into subsections, but a valid config string + could also have them as inline JSON values. + """ + # Start with a basic config dict (sections with scalar leaves) + base = draw(config_dicts) + inline_paths = set() + + # Sprinkle some JSON leaf values into the config + for section_key, section in list(base.items()): + for key in list(section.keys()): + if draw(st.booleans()): + # Replace some leaves with JSON-encoded values + value = draw(json_leaves) + section[key] = value + inline_paths.add(f"{section_key}.{key}") + + return base, inline_paths + + +def _collect_scalar_paths(data, prefix=""): + """Collect all (dotted_path, value) pairs for scalar leaves in a config dict.""" + paths = [] + for key, value in data.items(): + path = f"{prefix}.{key}" if prefix else key + if isinstance(value, dict): + paths.extend(_collect_scalar_paths(value, path)) + elif not isinstance(value, (list, dict)): + paths.append((path, value)) + return paths + + +def _set_at_path(data, path, value): + """Set a value at a dotted path in a nested dict.""" + parts = path.split(".") + node = data + for part in parts[:-1]: + node = node[part] + node[parts[-1]] = value + + +def _deep_copy(obj): + """Simple deep copy for nested dicts/lists of scalars.""" + if isinstance(obj, dict): + return {k: _deep_copy(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_deep_copy(v) for v in obj] + return obj + + +def _make_interpolated_config_str(sections, replacements): + """Build a config string from sections, applying variable replacements.""" + modified = {name: dict(leaves) for name, leaves in sections} + for replace_path, target_path in replacements.items(): + section, key = replace_path.rsplit(".", 1) + if section in modified: + modified[section][key] = "${" + target_path + "}" + parts = [] + for section_name, _ in sections: + parts.append(f"[{section_name}]") + for key, val in modified.get(section_name, {}).items(): + parts.append(f"{key} = {val}") + parts.append("") + return "\n".join(parts).strip() + + +@st.composite +def interpolated_config(draw): + """Strategy producing (config_str, expected_dict) with variable interpolation. + + Replaces some scalar leaves with ${section.key} references. Avoids cycles + by only referencing values that are never themselves replaced. + """ + base = draw(config_dicts) + scalar_paths = _collect_scalar_paths(base) + # Filter out bools — they serialize as true/false which can't be + # referenced via ${} (configparser treats them specially) + scalar_paths = [(p, v) for p, v in scalar_paths if not isinstance(v, bool)] + if len(scalar_paths) < 2: + return serialize_with_inline(base), base + + sections = _flatten_sections(base) + expected = _deep_copy(base) + + # Split paths into targets (stable values to reference) and candidates + # (values that may be replaced with refs). A path can't be both. + # Use random subset selection instead of permutations (which is O(n!)). + target_flags = draw( + st.lists( + st.booleans(), + min_size=len(scalar_paths), + max_size=len(scalar_paths), + ) + ) + targets = [sp for sp, flag in zip(scalar_paths, target_flags) if flag] + candidates = [sp for sp, flag in zip(scalar_paths, target_flags) if not flag] + + if not targets or not candidates: + return serialize_with_inline(base), base + + replacements = {} + for cand_path, _ in candidates: + if draw(st.booleans()): + target_path, target_value = draw(st.sampled_from(targets)) + replacements[cand_path] = target_path + _set_at_path(expected, cand_path, target_value) + + config_str = _make_interpolated_config_str(sections, replacements) + return config_str, expected + + +@st.composite +def circular_interpolated_config(draw): + """Strategy producing config strings with circular variable references.""" + base = draw(config_dicts) + scalar_paths = _collect_scalar_paths(base) + scalar_paths = [(p, v) for p, v in scalar_paths if not isinstance(v, bool)] + if len(scalar_paths) < 2: + from hypothesis import assume + + assume(False) + + sections = _flatten_sections(base) + + # Pick 2+ paths and create a cycle: a -> b -> ... -> a + cycle_len = draw(st.integers(min_value=2, max_value=min(4, len(scalar_paths)))) + cycle_indices = draw( + st.lists( + st.sampled_from(range(len(scalar_paths))), + min_size=cycle_len, + max_size=cycle_len, + unique=True, + ) + ) + cycle_paths = [scalar_paths[i][0] for i in cycle_indices] + + replacements = {} + for i in range(cycle_len): + replacements[cycle_paths[i]] = cycle_paths[(i + 1) % cycle_len] + + return _make_interpolated_config_str(sections, replacements) diff --git a/tests/test_config_basic.py b/tests/test_config_basic.py new file mode 100644 index 0000000..332cb47 --- /dev/null +++ b/tests/test_config_basic.py @@ -0,0 +1,145 @@ +"""Test basic config sections. + +* No variable interpolation +* No promises + +Just basic structure and JSON-encoded values. +""" + +from configparser import InterpolationDepthError + +import pytest +from hypothesis import HealthCheck, given, settings + +from confection import Config +from tests.strategies import ( + circular_interpolated_config, + config_dicts, + interpolated_config, + json_config_dicts, + serialize_with_inline, +) + + +@given(config_dicts) +def test_roundtrip(data): + """Config.from_str(config.to_str()) should reproduce the original dict.""" + config = Config(data) + serialized = config.to_str(interpolate=False) + restored = Config().from_str(serialized, interpolate=False) + assert dict_equal(restored, data) + + +@given(json_config_dicts()) +def test_json_leaves_parse(pair): + """Config strings with inline JSON values (lists, dicts) should parse + to the same nested dict structure.""" + data, inline_paths = pair + config_str = serialize_with_inline(data, inline_paths) + restored = Config().from_str(config_str, interpolate=False) + assert dict_equal(restored, data) + + +@given(interpolated_config()) +def test_variable_interpolation(pair): + """Config strings with ${section.key} variable references should resolve + to the referenced values after interpolation.""" + config_str, expected = pair + restored = Config().from_str(config_str, interpolate=True) + assert dict_equal(restored, expected) + + +@settings(suppress_health_check=[HealthCheck.too_slow]) +@given(circular_interpolated_config()) +def test_circular_interpolation_raises(config_str): + """Circular variable references should raise an error.""" + with pytest.raises((InterpolationDepthError, Exception)): + Config().from_str(config_str, interpolate=True) + + +def test_star_sections_parse(): + """[section.*.name] creates a dict under the "*" key.""" + result = Config().from_str( + """ +[section] + +[section.*.first] +x = 1 + +[section.*.second] +x = 2 +""", + interpolate=False, + ) + assert result["section"]["*"] == {"first": {"x": 1}, "second": {"x": 2}} + + +def test_star_sections_roundtrip(): + """Configs with * sections roundtrip through to_str/from_str.""" + original = Config().from_str( + """ +[section] + +[section.*.a] +x = 1 +y = "hello" + +[section.*.b] +x = 2 +y = "world" +""", + interpolate=False, + ) + serialized = original.to_str(interpolate=False) + restored = Config().from_str(serialized, interpolate=False) + assert dict_equal(restored, original) + + +def test_star_sections_nested(): + """* sections can appear at different levels of nesting.""" + result = Config().from_str( + """ +[top] + +[top.*.item] +val = 1 + +[top.*.item.sub] +val = 2 +""", + interpolate=False, + ) + assert result["top"]["*"]["item"]["val"] == 1 + assert result["top"]["*"]["item"]["sub"] == {"val": 2} + + +def test_star_with_interpolation(): + """Variable interpolation works across * sections.""" + result = Config().from_str( + """ +[settings] +lr = 0.001 + +[models] + +[models.*.first] +learning_rate = ${settings.lr} +""", + interpolate=True, + ) + assert result["models"]["*"]["first"]["learning_rate"] == 0.001 + + +def dict_equal(a, b) -> bool: + """Recursively compare two nested dicts, treating empty dicts as equal.""" + if type(a) is not type(b) and not (isinstance(a, dict) and isinstance(b, dict)): + return a == b + if isinstance(a, dict): + all_keys = set(a) | set(b) + for key in all_keys: + av = a.get(key, {}) + bv = b.get(key, {}) + if not dict_equal(av, bv): + return False + return True + return a == b diff --git a/tests/test_config_errors.py b/tests/test_config_errors.py new file mode 100644 index 0000000..18b417f --- /dev/null +++ b/tests/test_config_errors.py @@ -0,0 +1,159 @@ +"""Test error cases in config parsing and validation.""" + +import pytest + +from confection import Config +from confection._errors import ConfigValidationError + + +def test_malformed_config_string(): + """Malformed config syntax should raise ConfigValidationError.""" + with pytest.raises(ConfigValidationError, match="formatted correctly"): + Config().from_str("[[invalid", interpolate=False) + + +def test_malformed_missing_section_header(): + """Values without any section header are malformed.""" + with pytest.raises(ConfigValidationError, match="formatted correctly"): + Config().from_str("key = value", interpolate=False) + + +def test_top_level_values_without_section(): + """Top-level values that leak into DEFAULT should be caught.""" + # configparser treats values before any section header as defaults. + # This is caught by read_string as a MissingSectionHeaderError -> ParsingError. + with pytest.raises(ConfigValidationError): + Config().from_str("x = 1\n[section]\ny = 2", interpolate=False) + + +def test_missing_parent_section(): + """Dotted section names require all parent sections to exist.""" + with pytest.raises(ConfigValidationError, match="not defined"): + Config().from_str("[a.b.c]\nx = 1", interpolate=False) + + +def test_missing_intermediate_section(): + """[a] and [a.b.c] without [a.b] should fail.""" + with pytest.raises(ConfigValidationError, match="not defined"): + Config().from_str("[a]\nx = 1\n\n[a.b.c]\ny = 2", interpolate=False) + + +def test_key_subsection_conflict(): + """A key that conflicts with a subsection name should be caught.""" + with pytest.raises(ConfigValidationError, match="conflicting"): + Config().from_str("[a]\nb = 1\n\n[a.b]\nc = 2", interpolate=False) + + +def test_override_without_dot(): + """Overrides must have a dotted path (section.key).""" + with pytest.raises(ConfigValidationError, match="overrid"): + Config().from_str("[a]\nx = 1", interpolate=False, overrides={"x": 2}) + + +def test_override_nonexistent_section(): + """Overrides for nonexistent sections should fail.""" + with pytest.raises(ConfigValidationError, match="overrid"): + Config().from_str("[a]\nx = 1", interpolate=False, overrides={"b.x": 2}) + + +def test_override_applies(): + """Valid overrides should replace values.""" + result = Config().from_str("[a]\nx = 1", interpolate=False, overrides={"a.x": 99}) + assert result["a"]["x"] == 99 + + +def test_interpolation_missing_variable(): + """Referencing a nonexistent variable should raise.""" + with pytest.raises(Exception): + Config().from_str("[a]\nx = ${b.y}", interpolate=True) + + +def test_section_ref_in_string(): + """Referencing a whole section inside a string should raise.""" + with pytest.raises(ConfigValidationError, match="Can't reference whole sections"): + Config().from_str( + '[defaults]\nlr = 0.001\n\n[a]\nx = "hello ${defaults}"', + interpolate=True, + ) + + +def test_section_reference_resolves(): + """${section} references should resolve to the section dict.""" + result = Config().from_str( + "[defaults]\nlr = 0.001\n\n[a]\nsettings = ${defaults}", + interpolate=True, + ) + assert result["a"]["settings"] == {"lr": 0.001} + + +def test_override_replaces_promise_section(): + """Overriding a promise section with a different promise should replace it.""" + result = Config().from_str( + """ +[a] + +[a.scorer] +@scorers = "old_scorer.v1" +""", + interpolate=False, + overrides={"a.scorer": {"@scorers": "new_scorer.v1"}}, + ) + assert result["a"]["scorer"]["@scorers"] == "new_scorer.v1" + + +def test_uninterpolated_variable_preserved(): + """With interpolate=False, variable references should stay as strings.""" + result = Config().from_str("[a]\nx = 1\n\n[b]\ny = ${a.x}", interpolate=False) + assert result["b"]["y"] == "${a.x}" + + +def test_single_quoted_string_warns(): + """Single-quoted values should emit a warning about JSON formatting.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + Config().from_str("[a]\nx = 'hello'", interpolate=True) + assert any("single-quoted" in str(warning.message) for warning in w) + + +def test_dollar_dollar_escape(): + """$$ in config values should produce a literal $.""" + result = Config().from_str('[a]\nx = "$$100"', interpolate=True) + assert result["a"]["x"] == "$100" + + +def test_bad_interpolation_syntax(): + """Malformed ${...} reference should raise.""" + with pytest.raises(Exception): + Config().from_str("[a]\nx = ${", interpolate=True) + + +def test_bare_dollar_raises(): + """A bare $ not followed by $ or { should raise.""" + with pytest.raises(Exception): + Config().from_str('[a]\nx = "$x"', interpolate=True) + + +def test_same_section_variable(): + """${key} without a section prefix references the same section.""" + result = Config().from_str("[a]\nx = 1\ny = ${x}", interpolate=True) + assert result["a"]["y"] == 1 + + +def test_string_interpolation_coerces_values(): + """Non-string values interpolated into strings should be coerced.""" + result = Config().from_str( + '[a]\nx = 42\n\n[b]\ny = "value is ${a.x}"', + interpolate=True, + ) + assert result["b"]["y"] == "value is 42" + + +def test_string_interpolation_unwraps_json_strings(): + """JSON strings interpolated into compound expressions should be unwrapped.""" + result = Config().from_str( + '[a]\nx = "hello"\n\n[b]\ny = "${a.x} world"', + interpolate=True, + ) + assert result["b"]["y"] == "hello world" diff --git a/tests/test_config_serialization.py b/tests/test_config_serialization.py new file mode 100644 index 0000000..a3bda95 --- /dev/null +++ b/tests/test_config_serialization.py @@ -0,0 +1,234 @@ +"""Test Config serialization methods: to_str, from_str, to_bytes, from_bytes, +to_disk, from_disk, copy, merge, interpolate.""" + +import pytest + +from confection import Config +from confection._errors import ConfectionError + +# -- to_str / from_str basics (beyond the hypothesis tests) -- + + +def test_to_str_from_str_roundtrip(): + data = {"training": {"lr": 0.001, "epochs": 10}, "model": {"name": "cnn"}} + config = Config(data) + restored = Config().from_str(config.to_str(interpolate=False), interpolate=False) + assert dict(restored["training"]) == {"lr": 0.001, "epochs": 10} + assert dict(restored["model"]) == {"name": "cnn"} + + +def test_empty_config(): + config = Config() + assert config.to_str(interpolate=False) == "" + restored = Config().from_str("", interpolate=False) + assert dict(restored) == {} + + +# -- to_bytes / from_bytes -- + + +def test_to_bytes(): + config = Config({"a": {"x": 1}}) + b = config.to_bytes(interpolate=False) + assert isinstance(b, bytes) + assert b"[a]" in b + assert b"x = 1" in b + + +def test_from_bytes_roundtrip(): + config = Config({"a": {"x": 1, "y": "hello"}}) + b = config.to_bytes(interpolate=False) + restored = Config().from_bytes(b, interpolate=False) + assert restored["a"]["x"] == 1 + assert restored["a"]["y"] == "hello" + + +def test_from_bytes_with_overrides(): + config = Config({"a": {"x": 1}}) + b = config.to_bytes(interpolate=False) + restored = Config().from_bytes(b, interpolate=False, overrides={"a.x": 99}) + assert restored["a"]["x"] == 99 + + +# -- to_disk / from_disk -- + + +def test_to_disk_from_disk_roundtrip(tmp_path): + config = Config({"section": {"key": "value", "num": 42}}) + path = tmp_path / "config.cfg" + config.to_disk(path, interpolate=False) + assert path.exists() + restored = Config().from_disk(path, interpolate=False) + assert restored["section"]["key"] == "value" + assert restored["section"]["num"] == 42 + + +def test_to_disk_str_path(tmp_path): + config = Config({"a": {"x": 1}}) + path = str(tmp_path / "config.cfg") + config.to_disk(path, interpolate=False) + restored = Config().from_disk(path, interpolate=False) + assert restored["a"]["x"] == 1 + + +def test_from_disk_with_overrides(tmp_path): + config = Config({"a": {"x": 1}}) + path = tmp_path / "config.cfg" + config.to_disk(path, interpolate=False) + restored = Config().from_disk(path, interpolate=False, overrides={"a.x": 99}) + assert restored["a"]["x"] == 99 + + +# -- copy -- + + +def test_copy_is_deep(): + config = Config({"a": {"x": [1, 2, 3]}}) + copied = config.copy() + copied["a"]["x"].append(4) + assert config["a"]["x"] == [1, 2, 3] + assert copied["a"]["x"] == [1, 2, 3, 4] + + +def test_copy_preserves_metadata(): + config = Config({"a": {"x": 1}}, is_interpolated=False, section_order=["a"]) + copied = config.copy() + assert copied.is_interpolated is False + assert copied.section_order == ["a"] + + +# -- interpolate -- + + +def test_interpolate(): + config = Config().from_str( + "[a]\nx = 1\n\n[b]\ny = ${a.x}", + interpolate=False, + ) + assert config["b"]["y"] == "${a.x}" + interpolated = config.interpolate() + assert interpolated["b"]["y"] == 1 + # Original should be unchanged + assert config["b"]["y"] == "${a.x}" + + +def test_interpolate_returns_new_config(): + config = Config().from_str("[a]\nx = 1", interpolate=False) + interpolated = config.interpolate() + assert interpolated is not config + + +# -- merge -- + + +def test_merge_basic(): + base = Config({"a": {"x": 1, "y": 2}}) + updates = {"a": {"x": 99}} + merged = base.merge(updates) + assert merged["a"]["x"] == 99 + assert merged["a"]["y"] == 2 + + +def test_merge_adds_new_keys(): + base = Config({"a": {"x": 1}}) + updates = {"a": {"x": 1, "y": 2}} + merged = base.merge(updates) + assert merged["a"]["y"] == 2 + + +def test_merge_remove_extra(): + """remove_extra filters keys from updates that aren't in defaults.""" + base = Config({"a": {"x": 1}}) + updates = {"a": {"x": 99, "extra": "gone"}} + merged = base.merge(updates, remove_extra=True) + assert merged["a"]["x"] == 99 + assert "extra" not in merged["a"] + + +def test_merge_does_not_mutate_original(): + base = Config({"a": {"x": 1}}) + updates = {"a": {"x": 99}} + base.merge(updates) + assert base["a"]["x"] == 1 + + +def test_merge_deep(): + base = Config({"a": {"sub": {"x": 1, "y": 2}}}) + updates = {"a": {"sub": {"x": 99}}} + merged = base.merge(updates) + assert merged["a"]["sub"]["x"] == 99 + assert merged["a"]["sub"]["y"] == 2 + + +# -- __init__ -- + + +def test_init_from_dict(): + config = Config({"a": {"x": 1}}) + assert config["a"]["x"] == 1 + assert config.is_interpolated is True + + +def test_init_from_config(): + original = Config({"a": {"x": 1}}, is_interpolated=False, section_order=["a"]) + copy = Config(original) + assert copy["a"]["x"] == 1 + assert copy.is_interpolated is False + assert copy.section_order == ["a"] + + +def test_init_bad_data(): + with pytest.raises(ConfectionError, match="Expected dict"): + Config([1, 2, 3]) + + +def test_init_override_metadata(): + original = Config({"a": {"x": 1}}, is_interpolated=False) + copy = Config(original, is_interpolated=True) + assert copy.is_interpolated is True + + +# -- overrides with interpolation -- + +# -- deep_merge_configs edge cases -- + + +def test_merge_leaf_vs_dict(): + """When config has a leaf where defaults has a dict, skip merging.""" + base = Config({"a": {"sub": {"x": 1}}}) + updates = Config({"a": {"sub": "leaf_value"}}) + merged = base.merge(updates) + assert merged["a"]["sub"] == "leaf_value" + + +def test_merge_different_promise(): + """Blocks with different @registry functions should not merge.""" + base = Config({"a": {"@cats": "meow.v1", "x": 1}}) + updates = Config({"a": {"@cats": "woof.v1", "y": 2}}) + merged = base.merge(updates) + # Updates win — base defaults not merged in because different function + assert merged["a"]["@cats"] == "woof.v1" + assert merged["a"]["y"] == 2 + assert "x" not in merged["a"] + + +def test_merge_promise_not_in_defaults(): + """Promise in updates but not in defaults should not merge defaults.""" + base = Config({"a": {"x": 1}}) + updates = Config({"a": {"@cats": "meow.v1", "y": 2}}) + merged = base.merge(updates) + assert merged["a"]["@cats"] == "meow.v1" + assert "x" not in merged["a"] + + +# -- overrides with interpolation -- + + +def test_overrides_with_interpolation(): + result = Config().from_str( + "[a]\nx = 1\n\n[b]\ny = ${a.x}", + interpolate=True, + overrides={"a.x": 42}, + ) + assert result["a"]["x"] == 42 + assert result["b"]["y"] == 42 diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..7fad649 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,25 @@ +"""Test _errors.py edge cases.""" + +from confection._errors import ConfigValidationError + + +def test_from_error(): + original = ConfigValidationError( + errors=[{"loc": ["a"], "msg": "bad"}], + title="Original", + desc="original desc", + ) + new = ConfigValidationError.from_error(original, title="New title") + assert new.title == "New title" + assert new.desc == "original desc" + assert new.errors == original.errors + + +def test_format_with_parent(): + err = ConfigValidationError( + errors=[{"loc": ["x"], "msg": "bad"}], + title="Error", + parent="section", + ) + assert "section" in err.text + assert "-> x" in err.text diff --git a/confection/tests/test_issue52.py b/tests/test_issue52.py similarity index 100% rename from confection/tests/test_issue52.py rename to tests/test_issue52.py diff --git a/confection/tests/test_issue54.py b/tests/test_issue54.py similarity index 100% rename from confection/tests/test_issue54.py rename to tests/test_issue54.py diff --git a/confection/tests/test_issue58.py b/tests/test_issue58.py similarity index 100% rename from confection/tests/test_issue58.py rename to tests/test_issue58.py diff --git a/confection/tests/test_issue59.py b/tests/test_issue59.py similarity index 100% rename from confection/tests/test_issue59.py rename to tests/test_issue59.py diff --git a/confection/tests/test_issue61.py b/tests/test_issue61.py similarity index 100% rename from confection/tests/test_issue61.py rename to tests/test_issue61.py diff --git a/confection/tests/test_pydantic_shim.py b/tests/test_pydantic_shim.py similarity index 60% rename from confection/tests/test_pydantic_shim.py rename to tests/test_pydantic_shim.py index 2ab0313..2914d2a 100644 --- a/confection/tests/test_pydantic_shim.py +++ b/tests/test_pydantic_shim.py @@ -26,7 +26,7 @@ validator, ) except ImportError: - from pydantic import ( # type: ignore + from pydantic import ( # type: ignore # noqa: F401 BaseModel, Field, StrictFloat, @@ -35,9 +35,7 @@ validator, ) -from confection import ConfigValidationError -from confection.tests.util import my_registry -from confection.validation import Schema, ValidationError, ensure_schema +from confection.validation import Schema, ValidationError, ensure_schema # noqa: E402 # --- ensure_schema conversion --- @@ -151,88 +149,90 @@ def name_must_be_upper(cls, v): converted.model_validate({"name": "HELLO"}) -# --- Registry integration --- +# --- Pydantic v2 native models --- +import pydantic as _pydantic_v2 # noqa: E402 -def test_registry_resolve_with_pydantic_schema(): - class RegSchema(BaseModel): - hello: StrictInt - world: StrictInt - class Config: - extra = "forbid" +class V2SimpleSchema(_pydantic_v2.BaseModel): + model_config = _pydantic_v2.ConfigDict(extra="forbid") + name: str + value: int = 10 - result = my_registry.resolve( - {"hello": 1, "world": 2}, schema=RegSchema, validate=True - ) - assert result == {"hello": 1, "world": 2} +class V2InnerSchema(_pydantic_v2.BaseModel): + model_config = _pydantic_v2.ConfigDict(extra="forbid") + x: int -def test_registry_resolve_rejects_bad_type(): - class RegSchema(BaseModel): - hello: StrictInt - world: StrictInt - class Config: - extra = "forbid" +class V2OuterSchema(_pydantic_v2.BaseModel): + model_config = _pydantic_v2.ConfigDict(extra="forbid") + inner: V2InnerSchema + label: str = "default" - with pytest.raises(ConfigValidationError): - my_registry.resolve( - {"hello": "bad", "world": 2}, schema=RegSchema, validate=True - ) +def test_v2_converts_to_schema(): + converted = ensure_schema(V2SimpleSchema) + assert issubclass(converted, Schema) -def test_registry_fill_with_defaults(): - class FillSchema(BaseModel): - required: StrictInt - optional: StrictStr = "default_value" - class Config: - extra = "forbid" +def test_v2_extracts_fields(): + converted = ensure_schema(V2SimpleSchema) + assert "name" in converted.model_fields + assert "value" in converted.model_fields + assert converted.model_fields["name"].is_required() + assert converted.model_fields["value"].default == 10 - filled = my_registry.fill({"required": 42}, schema=FillSchema) - assert filled["required"] == 42 - assert filled["optional"] == "default_value" +def test_v2_extracts_config(): + converted = ensure_schema(V2SimpleSchema) + assert converted.model_config["extra"] == "forbid" -def test_registry_fill_rejects_extra(): - class StrictSchema(BaseModel): - x: StrictInt - class Config: - extra = "forbid" +def test_v2_nested_conversion(): + converted = ensure_schema(V2OuterSchema) + inner_type = converted.model_fields["inner"].annotation + assert issubclass(inner_type, Schema) + assert "x" in inner_type.model_fields - with pytest.raises(ConfigValidationError): - my_registry.fill({"x": 1, "extra": "bad"}, schema=StrictSchema, validate=True) +def test_v2_validate_correct(): + converted = ensure_schema(V2SimpleSchema) + result = converted.model_validate({"name": "test"}) + assert result.name == "test" + assert result.value == 10 -# --- Mimics spaCy-style schemas --- +def test_v2_validate_rejects_extra(): + converted = ensure_schema(V2SimpleSchema) + with pytest.raises(ValidationError): + converted.model_validate({"name": "x", "extra": 1}) -def test_spacy_style_config_schema(): - """Test a schema structure similar to spaCy's ConfigSchemaTraining.""" - class TrainingSchema(BaseModel): - train_corpus: StrictStr = Field(..., title="Training data path") - dev_corpus: StrictStr = Field(..., title="Dev data path") - dropout: StrictFloat = Field(..., title="Dropout rate") - max_epochs: StrictInt = Field(..., title="Max epochs") - seed: StrictInt = Field(0, title="Random seed") +# --- Config integration with pydantic schema --- + + +def test_config_from_str_with_pydantic_schema(): + """Config.from_str works with a pydantic schema for validation and defaults.""" + from confection import Config + + class MyPydanticSchema(BaseModel): + name: StrictStr + value: StrictInt = 10 class Config: extra = "forbid" - arbitrary_types_allowed = True - - config = { - "train_corpus": "corpus/train", - "dev_corpus": "corpus/dev", - "dropout": 0.2, - "max_epochs": 100, - } - filled = my_registry.fill(config, schema=TrainingSchema) - assert filled["seed"] == 0 - assert filled["dropout"] == 0.2 - - resolved = my_registry.resolve(config, schema=TrainingSchema, validate=True) - assert resolved["train_corpus"] == "corpus/train" - assert resolved["seed"] == 0 + + class TopSchema(BaseModel): + section: MyPydanticSchema + + config = Config().from_str( + """ +[section] +name = "test" +""", + interpolate=False, + schema=TopSchema, + ) + assert config["section"]["name"] == "test" + assert config["section"]["value"] == 10 diff --git a/tests/test_registry_fill.py b/tests/test_registry_fill.py new file mode 100644 index 0000000..741e020 --- /dev/null +++ b/tests/test_registry_fill.py @@ -0,0 +1,287 @@ +"""Test registry.fill() default-filling from function signatures.""" + +import catalogue +import pytest + +from confection import Config, registry +from confection._errors import ConfigValidationError + +# --- Test registry setup --- + + +class _test_registry(registry): + optimizers = catalogue.create("test_fill", "optimizers", entry_points=False) + schedules = catalogue.create("test_fill", "schedules", entry_points=False) + models = catalogue.create("test_fill", "models", entry_points=False) + + +@_test_registry.optimizers.register("Adam.v1") +def adam(learn_rate: float = 0.001, beta1: float = 0.9, beta2: float = 0.999): + return {"learn_rate": learn_rate, "beta1": beta1, "beta2": beta2} + + +@_test_registry.optimizers.register("SGD.v1") +def sgd(learn_rate: float = 0.01, momentum: float = 0.9): + return {"learn_rate": learn_rate, "momentum": momentum} + + +@_test_registry.schedules.register("linear.v1") +def linear(start: float = 0.0, end: float = 1.0, steps: int = 100): + return [start + (end - start) * i / steps for i in range(steps)] + + +@_test_registry.models.register("cnn.v1") +def cnn(width: int = 128, depth: int = 3, dropout: float = 0.1): + return {"width": width, "depth": depth, "dropout": dropout} + + +@_test_registry.models.register("no_defaults.v1") +def no_defaults(width: int, depth: int): + return {"width": width, "depth": depth} + + +# --- Tests --- + + +def test_fill_basic(): + """Fill adds missing defaults from the registered function.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.01 +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["optimizer"]["learn_rate"] == 0.01 + assert filled["optimizer"]["beta1"] == 0.9 + assert filled["optimizer"]["beta2"] == 0.999 + + +def test_fill_preserves_provided(): + """Values explicitly provided should not be overwritten.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.05 +beta1 = 0.8 +beta2 = 0.99 +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["optimizer"]["learn_rate"] == 0.05 + assert filled["optimizer"]["beta1"] == 0.8 + assert filled["optimizer"]["beta2"] == 0.99 + + +def test_fill_all_defaults(): + """When no args are provided, all defaults should be filled.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["optimizer"]["learn_rate"] == 0.001 + assert filled["optimizer"]["beta1"] == 0.9 + assert filled["optimizer"]["beta2"] == 0.999 + + +def test_fill_no_defaults(): + """Function with no defaults should not add anything.""" + config = Config().from_str( + """ +[model] +@models = "no_defaults.v1" +width = 64 +depth = 2 +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["model"]["width"] == 64 + assert filled["model"]["depth"] == 2 + + +def test_fill_nested_promise(): + """Defaults are filled recursively into nested promises.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "SGD.v1" + +[optimizer.learn_rate] +@schedules = "linear.v1" +end = 0.5 +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + # Outer promise filled + assert filled["optimizer"]["momentum"] == 0.9 + # Inner promise filled + lr = filled["optimizer"]["learn_rate"] + assert lr["start"] == 0.0 + assert lr["end"] == 0.5 # provided, not overwritten + assert lr["steps"] == 100 + + +def test_fill_non_promise_sections(): + """Non-promise sections are passed through unchanged.""" + config = Config().from_str( + """ +[training] +epochs = 10 + +[optimizer] +@optimizers = "Adam.v1" +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["training"]["epochs"] == 10 + assert filled["optimizer"]["beta1"] == 0.9 + + +def test_fill_preserves_registry_key(): + """The @registry key should be preserved in the filled config.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["optimizer"]["@optimizers"] == "Adam.v1" + + +def test_fill_returns_config(): + """fill() should return a Config object.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert isinstance(filled, Config) + + +def test_fill_top_level_promise_raises(): + """Top-level config can't be a promise.""" + with pytest.raises(ConfigValidationError): + _test_registry.fill({"@optimizers": "Adam.v1"}) + + +def test_fill_with_overrides(): + """Overrides should be applied before filling defaults.""" + config = Config().from_str( + """ +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.01 +""", + interpolate=False, + ) + filled = _test_registry.fill( + config, overrides={"optimizer.learn_rate": 0.1}, interpolate=True + ) + assert filled["optimizer"]["learn_rate"] == 0.1 + assert filled["optimizer"]["beta1"] == 0.9 + + +def test_fill_promise_rejects_extra_args(): + """Extra args not in the function signature should be rejected.""" + config = Config( + { + "optimizer": { + "@optimizers": "Adam.v1", + "learn_rate": 0.01, + "metadata": {"nested_model": {"@models": "cnn.v1"}}, + } + } + ) + with pytest.raises(ConfigValidationError): + _test_registry.fill(config) + + +def test_fill_nested_non_promise_dict(): + """Non-promise dicts nested under a section with promises should recurse.""" + config = Config().from_str( + """ +[section] + +[section.sub] +x = 1 + +[section.model] +@models = "cnn.v1" +width = 64 +""", + interpolate=False, + ) + filled = _test_registry.fill(config) + assert filled["section"]["sub"] == {"x": 1} + assert filled["section"]["model"]["depth"] == 3 + + +def test_fill_skips_unknown_functions(): + """fill() should skip promises referencing unknown functions, not crash.""" + config = Config( + { + "section": { + "known": {"@optimizers": "Adam.v1"}, + "unknown": {"@optimizers": "nonexistent.v1"}, + } + } + ) + filled = _test_registry.fill(config) + # Known function gets defaults filled + assert filled["section"]["known"]["beta1"] == 0.9 + # Unknown function passes through unchanged + assert filled["section"]["unknown"] == {"@optimizers": "nonexistent.v1"} + + +def test_fill_strips_extra_fields_with_schema(): + """fill() with a schema that has extra='forbid' should strip unknown keys.""" + from confection.validation import Schema + + class TrainingSchema(Schema): + model_config = {"extra": "forbid"} + patience: int = 10 + dropout: float = 0.2 + + class MySchema(Schema): + model_config = {"extra": "forbid"} + training: TrainingSchema + + config = Config({"training": {"patience": 5, "extra_field": "hello"}}) + filled = _test_registry.fill(config, schema=MySchema) + assert filled["training"]["patience"] == 5 + assert filled["training"]["dropout"] == 0.2 + assert "extra_field" not in filled["training"] + + +def test_fill_with_interpolation(): + """fill() with interpolate=True should resolve variables.""" + config = Config().from_str( + """ +[hyper] +lr = 0.01 + +[optimizer] +@optimizers = "Adam.v1" +learn_rate = ${hyper.lr} +""", + interpolate=False, + ) + filled = _test_registry.fill(config, interpolate=True) + assert filled["optimizer"]["learn_rate"] == 0.01 + assert filled["optimizer"]["beta1"] == 0.9 diff --git a/tests/test_registry_resolve.py b/tests/test_registry_resolve.py new file mode 100644 index 0000000..4cd9a61 --- /dev/null +++ b/tests/test_registry_resolve.py @@ -0,0 +1,341 @@ +"""Test registry.resolve() and the Promise lifecycle.""" + +from typing import Callable, List + +import catalogue +import pytest + +from confection import Config, registry +from confection._constants import ARGS_FIELD_ALIAS, RESERVED_FIELDS_REVERSE +from confection._errors import ConfigValidationError +from confection._registry import ( + Promise, + _deep_copy_with_uncopyable, + _is_config_section, + alias_generator, + fix_positionals, + insert_promises, + resolve_promises, +) + +# --- Test registry setup --- + + +class _test_registry(registry): + namespace = "test_resolve" + cats = catalogue.create(namespace, "cats", entry_points=False) + optimizers = catalogue.create(namespace, "optimizers", entry_points=False) + schedules = catalogue.create(namespace, "schedules", entry_points=False) + layers = catalogue.create(namespace, "layers", entry_points=False) + + +@_test_registry.cats.register("catsie.v1") +def catsie(evil: bool, cute: bool = True) -> str: + if evil: + return "scratch!" + return "meow" + + +@_test_registry.optimizers.register("Adam.v1") +def adam(learn_rate: float = 0.001, beta1: float = 0.9, beta2: float = 0.999): + return {"learn_rate": learn_rate, "beta1": beta1, "beta2": beta2} + + +@_test_registry.schedules.register("decay.v1") +def decay(base_rate: float, repeat: int) -> List[float]: + return [base_rate] * repeat + + +@_test_registry.optimizers.register("cool.v1") +def cool_optimizer(learn_rate, beta1: float = 0.9): + return {"learn_rate": learn_rate, "beta1": beta1} + + +@_test_registry.layers.register("linear.v1") +def linear_layer(width: int, init: Callable = lambda: None): + return {"width": width, "init": init} + + +# --- resolve basic --- + + +def test_resolve_simple(): + config = Config({"cat": {"@cats": "catsie.v1", "evil": True}}) + result = _test_registry.resolve(config) + assert result["cat"] == "scratch!" + + +def test_resolve_with_defaults(): + config = Config({"cat": {"@cats": "catsie.v1", "evil": False}}) + result = _test_registry.resolve(config) + assert result["cat"] == "meow" + + +def test_resolve_nested_promise(): + """A promise arg can be another promise.""" + config = Config( + { + "optimizer": { + "@optimizers": "cool.v1", + "learn_rate": { + "@schedules": "decay.v1", + "base_rate": 0.001, + "repeat": 4, + }, + } + } + ) + result = _test_registry.resolve(config) + assert result["optimizer"]["learn_rate"] == [0.001] * 4 + assert result["optimizer"]["beta1"] == 0.9 + + +def test_resolve_non_promise_passthrough(): + config = Config( + {"training": {"epochs": 10}, "cat": {"@cats": "catsie.v1", "evil": True}} + ) + result = _test_registry.resolve(config) + assert result["training"] == {"epochs": 10} + assert result["cat"] == "scratch!" + + +def test_resolve_from_str(): + config = Config().from_str(""" +[optimizer] +@optimizers = "Adam.v1" +learn_rate = 0.01 +""") + result = _test_registry.resolve(config) + assert result["optimizer"]["learn_rate"] == 0.01 + assert result["optimizer"]["beta1"] == 0.9 + + +# --- Promise class --- + + +def test_promise_from_dict(): + values = {"@cats": "catsie.v1", "evil": True, "cute": False} + promise = Promise.from_dict(_test_registry, values) + assert promise.registry == "cats" + assert promise.name == "catsie.v1" + assert promise.kwargs == {"evil": True, "cute": False} + + +def test_promise_resolve(): + values = {"@cats": "catsie.v1", "evil": True} + promise = Promise.from_dict(_test_registry, values) + assert promise.resolve() == "scratch!" + + +def test_promise_return_type(): + values = {"@cats": "catsie.v1", "evil": True} + promise = Promise.from_dict(_test_registry, values) + assert promise.return_type is str + + +# --- insert_promises --- + + +def test_insert_promises(): + config = {"cat": {"@cats": "catsie.v1", "evil": True}, "x": 1} + result = insert_promises(_test_registry, config, resolve=True) + assert isinstance(result["cat"], Promise) + assert result["x"] == 1 + + +def test_insert_promises_nested(): + config = { + "section": { + "cat": {"@cats": "catsie.v1", "evil": False}, + "val": 42, + } + } + result = insert_promises(_test_registry, config, resolve=True) + assert isinstance(result["section"]["cat"], Promise) + assert result["section"]["val"] == 42 + + +# --- resolve_promises --- + + +def test_resolve_promises(): + config = {"cat": {"@cats": "catsie.v1", "evil": True}, "x": 1} + promised = insert_promises(_test_registry, config, resolve=True) + resolved = resolve_promises(promised) + assert resolved["cat"] == "scratch!" + assert resolved["x"] == 1 + + +# --- fix_positionals --- + + +def test_fix_positionals_dict_to_tuple(): + config = {"*": {"0": "a", "1": "b"}, "x": 1} + result = fix_positionals(config) + assert result["*"] == ("a", "b") + assert result["x"] == 1 + + +def test_fix_positionals_nested(): + config = {"section": {"*": {"0": "a"}, "val": 1}} + result = fix_positionals(config) + assert result["section"]["*"] == ("a",) + + +def test_fix_positionals_list(): + config = [{"*": {"0": 1}}, {"x": 2}] + result = fix_positionals(config) + assert result[0]["*"] == (1,) + assert result[1] == {"x": 2} + + +def test_fix_positionals_tuple(): + config = ({"*": {"0": 1}},) + result = fix_positionals(config) + assert result[0]["*"] == (1,) + assert isinstance(result, tuple) + + +def test_fix_positionals_scalar(): + assert fix_positionals(42) == 42 + assert fix_positionals("hello") == "hello" + + +# --- _is_config_section --- + + +def test_is_config_section(): + assert _is_config_section({"a": 1, "b": 2}) + assert not _is_config_section([1, 2]) + assert not _is_config_section("hello") + assert _is_config_section({}) + + +# --- alias_generator --- + + +def test_alias_generator_args_field(): + assert alias_generator(ARGS_FIELD_ALIAS) == "*" + + +def test_alias_generator_reserved(): + for alias, original in RESERVED_FIELDS_REVERSE.items(): + assert alias_generator(alias) == original + + +def test_alias_generator_normal(): + assert alias_generator("some_field") == "some_field" + + +# --- _deep_copy_with_uncopyable --- + + +def test_deep_copy_dict(): + original = {"a": [1, 2], "b": {"c": 3}} + copied = _deep_copy_with_uncopyable(original) + assert copied == original + copied["a"].append(3) + assert original["a"] == [1, 2] + + +def test_deep_copy_list(): + original = [{"a": 1}, {"b": 2}] + copied = _deep_copy_with_uncopyable(original) + assert copied == original + copied[0]["a"] = 99 + assert original[0]["a"] == 1 + + +def test_deep_copy_tuple(): + original = ({"a": 1},) + copied = _deep_copy_with_uncopyable(original) + assert copied == original + assert isinstance(copied, tuple) + + +def test_deep_copy_scalar(): + assert _deep_copy_with_uncopyable(42) == 42 + assert _deep_copy_with_uncopyable("hello") == "hello" + + +def test_deep_copy_generator(): + """Generators can't be deepcopied — should pass through.""" + + def gen(): + yield 1 + + g = gen() + copied = _deep_copy_with_uncopyable(g) + assert copied is g # same object, not copied + + +def test_deep_copy_memo(): + """Shared references should be preserved.""" + shared = [1, 2, 3] + original = {"a": shared, "b": shared} + copied = _deep_copy_with_uncopyable(original) + assert copied["a"] is copied["b"] # still shared + assert copied["a"] is not original["a"] # but different from original + + +# --- registry.has / registry.get --- + + +def test_registry_has(): + assert _test_registry.has("cats", "catsie.v1") + assert not _test_registry.has("cats", "nonexistent") + assert not _test_registry.has("nonexistent_registry", "anything") + + +def test_registry_get(): + func = _test_registry.get("cats", "catsie.v1") + assert func is catsie + + +def test_registry_get_unknown_registry(): + with pytest.raises(ValueError, match="Unknown registry"): + _test_registry.get("nonexistent", "anything") + + +def test_registry_get_unknown_func(): + with pytest.raises((ValueError, catalogue.RegistryError)): + _test_registry.get("cats", "nonexistent") + + +# --- get_constructor / parse_args --- + + +def test_get_constructor(): + assert _test_registry.get_constructor({"@cats": "catsie.v1", "evil": True}) == ( + "cats", + "catsie.v1", + ) + + +def test_get_constructor_multiple_refs(): + with pytest.raises(ConfigValidationError): + _test_registry.get_constructor({"@cats": "catsie.v1", "@dogs": "doggo.v1"}) + + +def test_parse_args(): + args, kwargs = _test_registry.parse_args( + {"@cats": "catsie.v1", "evil": True, "cute": False} + ) + assert args == [] + assert kwargs == {"evil": True, "cute": False} + + +def test_parse_args_with_positionals(): + args, kwargs = _test_registry.parse_args( + {"@cats": "catsie.v1", "*": [1, 2], "evil": True} + ) + assert args == [1, 2] + assert kwargs == {"evil": True} + + +# --- resolve top-level promise raises --- + + +def test_resolve_top_level_promise(): + with pytest.raises(ConfigValidationError): + _test_registry.resolve({"@cats": "catsie.v1", "evil": True}) diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..dcc1623 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,179 @@ +"""Test schema validation and default filling at the Config layer.""" + +from typing import List, Optional + +import pytest + +from confection import Config +from confection._errors import ConfigValidationError +from confection.validation import Schema + + +class TrainingSchema(Schema): + patience: int + dropout: float = 0.2 + use_vectors: bool = False + + +class NlpSchema(Schema): + lang: str + pipeline: List[str] = [] + + +class FullSchema(Schema): + training: TrainingSchema + nlp: NlpSchema + + +# -- fill_defaults -- + + +def test_fill_defaults_top_level(): + config = Config({"training": {"patience": 10}}) + config.fill_defaults(FullSchema) + assert config["training"]["dropout"] == 0.2 + assert config["training"]["use_vectors"] is False + + +def test_fill_defaults_preserves_existing(): + config = Config({"training": {"patience": 10, "dropout": 0.5}}) + config.fill_defaults(FullSchema) + assert config["training"]["dropout"] == 0.5 + + +def test_fill_defaults_adds_missing_sections(): + config = Config({"training": {"patience": 10}}) + config.fill_defaults(FullSchema) + # nlp has all defaults so it should be filled + assert "nlp" not in config # nlp section itself is required, not defaulted + + +def test_fill_defaults_returns_self(): + config = Config({"training": {"patience": 10}}) + result = config.fill_defaults(FullSchema) + assert result is config + + +# -- validate -- + + +def test_validate_passes(): + config = Config( + {"training": {"patience": 10, "dropout": 0.2}, "nlp": {"lang": "en"}} + ) + config.validate(FullSchema) # should not raise + + +def test_validate_missing_required(): + config = Config({"training": {"dropout": 0.5}, "nlp": {"lang": "en"}}) + with pytest.raises(ConfigValidationError): + config.validate(FullSchema) + + +def test_validate_wrong_type(): + config = Config( + {"training": {"patience": "nope", "dropout": 0.2}, "nlp": {"lang": "en"}} + ) + with pytest.raises(ConfigValidationError): + config.validate(FullSchema) + + +# -- from_str with schema -- + + +def test_from_str_with_schema(): + config = Config().from_str( + """ +[training] +patience = 10 + +[nlp] +lang = "en" +""", + interpolate=False, + schema=FullSchema, + ) + assert config["training"]["patience"] == 10 + assert config["training"]["dropout"] == 0.2 + assert config["training"]["use_vectors"] is False + + +def test_from_str_schema_validates(): + with pytest.raises(ConfigValidationError): + Config().from_str( + """ +[training] +dropout = 0.5 + +[nlp] +lang = "en" +""", + interpolate=False, + schema=FullSchema, + ) + + +def test_from_str_schema_with_interpolation(): + config = Config().from_str( + """ +[training] +patience = 10 + +[nlp] +lang = "en" +""", + interpolate=True, + schema=FullSchema, + ) + assert config["training"]["dropout"] == 0.2 + + +# -- Schema with extra="forbid" -- + + +class StrictSchema(Schema): + model_config = {"extra": "forbid"} + x: int + y: str = "default" + + +def test_validate_extra_forbidden(): + config = Config({"x": 1, "y": "hello", "z": "extra"}) + with pytest.raises(ConfigValidationError): + config.validate(StrictSchema) + + +# -- Schema with Optional fields -- + + +class OptionalSchema(Schema): + name: str + description: Optional[str] = None + + +def test_optional_field_defaults_to_none(): + config = Config({"name": "test"}) + config.fill_defaults(OptionalSchema) + assert config["description"] is None + + +def test_optional_field_accepts_value(): + config = Config({"name": "test", "description": "hello"}) + config.validate(OptionalSchema) # should not raise + + +# -- Flat schema (no nesting) -- + + +class FlatSchema(Schema): + x: int + y: float = 3.14 + z: str = "hello" + + +def test_flat_schema_from_str(): + """Schema works for flat configs too (top-level keys are leaves, not sections).""" + config = Config({"x": 1}) + config.fill_defaults(FlatSchema) + assert config["y"] == 3.14 + assert config["z"] == "hello" diff --git a/tests/test_typechecker_edges.py b/tests/test_typechecker_edges.py new file mode 100644 index 0000000..4900485 --- /dev/null +++ b/tests/test_typechecker_edges.py @@ -0,0 +1,278 @@ +"""Tests for typechecker edge cases.""" + +from dataclasses import dataclass +from enum import Enum +from typing import Annotated, Any + +from confection.typechecker import check_type +from confection.validation import Field, Schema + +# --- custom_handlers --- + + +def test_custom_handler(): + def handle_str(value, annotation, handlers, ctx): + return value == "magic" + + assert check_type("magic", int, custom_handlers={str: handle_str}) + assert not check_type("other", int, custom_handlers={str: handle_str}) + + +# --- Annotated with Strict() --- + +try: + from pydantic import Strict +except ImportError: + + class Strict: # type: ignore + strict = True + + +def test_strict_int(): + assert check_type(42, Annotated[int, Strict()]) + assert not check_type(True, Annotated[int, Strict()]) + assert not check_type("42", Annotated[int, Strict()]) + + +def test_strict_float(): + assert check_type(3.14, Annotated[float, Strict()]) + assert not check_type(42, Annotated[float, Strict()]) + + +def test_strict_str(): + assert check_type("hi", Annotated[str, Strict()]) + assert not check_type(42, Annotated[str, Strict()]) + + +def test_strict_bool(): + assert check_type(True, Annotated[bool, Strict()]) + assert not check_type(1, Annotated[bool, Strict()]) + + +def test_strict_other(): + assert check_type([1], Annotated[list, Strict()]) + + +# --- String forward reference --- + + +def test_string_annotation_accepts_anything(): + assert check_type(42, "SomeForwardRef") + assert check_type("hi", "SomeForwardRef") + + +# --- String enum --- + + +class Color(str, Enum): + RED = "red" + GREEN = "green" + + +def test_string_enum(): + assert check_type("red", Color) + assert not check_type("blue", Color) + + +# --- float with non-numeric --- + + +def test_float_rejects_list(): + assert not check_type([], float) + + +# --- Type[X] with non-class generic arg --- + + +def test_type_non_class_arg(): + """Type[X] where X causes TypeError in issubclass should accept.""" + from typing import Type + + # This would raise TypeError in issubclass + assert check_type(int, Type[Any]) + + +# --- Dataclass with dict --- + + +@dataclass +class Point: + x: int + y: int + + +def test_dataclass_accepts_dict(): + assert check_type({"x": 1, "y": 2}, Point) + + +def test_dataclass_dict_validates_fields(): + assert not check_type({"x": "bad", "y": 2}, Point) + + +def test_dataclass_with_dataclass_value(): + p = Point(x=1, y=2) + assert check_type(p, Point) + + +def test_dataclass_value_validates_fields(): + p = Point(x=1, y=2) + # Manually set bad type to test field validation + p.x = "bad" # type: ignore + assert not check_type(p, Point) + + +# --- Schema with aliased field in decompose --- + + +class AliasedSchema(Schema): + x: int = Field(default=..., alias="x_alias") + + +def test_schema_decompose_alias(): + assert check_type({"x_alias": 42}, AliasedSchema) + assert not check_type({"x_alias": "bad"}, AliasedSchema) + + +# --- Fall-through return False --- + + +def test_unknown_annotation(): + """An annotation that doesn't match any branch returns False.""" + # A module object isn't a type, not a TypeVar, not a string, etc. + import os + + assert not check_type(42, os) + + +# --- Pydantic v1 validator edges --- + + +def test_pydantic_v1_validator_many_params(): + """Validators with >2 params are skipped.""" + + class MyType: + @classmethod + def __get_validators__(cls): + def three_params(v, field, config): + return v + + yield three_params + + # Should still pass (validator is skipped) + assert check_type(42, MyType) + + +def test_pydantic_v1_validator_sig_error(): + """Validators with un-inspectable signatures default to 1 param.""" + + class MyType: + @classmethod + def __get_validators__(cls): + yield len # len's signature can be inspected, but let's test the path + + # len(42) will raise TypeError + assert not check_type(42, MyType) + + +# --- Pydantic v2 match edges --- + + +def test_pydantic_v2_no_function_key(): + """Schema without 'function' key in result.""" + + class MyType: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + return {"type": "any"} # no function key + + assert not check_type(42, MyType) + + +def test_pydantic_v2_isinstance_shortcut(): + """If value is already an instance, skip validator.""" + + class MyType: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + return {"type": "any"} + + assert check_type(MyType(), MyType) + + +# --- Parameter.empty --- + + +def test_parameter_empty(): + import inspect + + assert check_type(42, inspect.Parameter.empty) + assert check_type("anything", inspect.Parameter.empty) + + +# --- Type[Union[...]] --- + + +def test_type_union_arg(): + """Type[Union[int, str]] — issubclass raises TypeError, should accept.""" + from typing import Type, Union + + assert check_type(int, Type[Union[int, str]]) + + +# --- _pydantic_v1_match isinstance shortcut --- + + +def test_pydantic_v1_isinstance_shortcut(): + class MyType: + @classmethod + def __get_validators__(cls): + yield lambda v: v + + assert check_type(MyType(), MyType) + + +# --- _pydantic_v1_match uninspectable signature --- + + +def test_pydantic_v1_uninspectable_sig(): + """Validator with uninspectable signature defaults to 1 param.""" + + class MyType: + @classmethod + def __get_validators__(cls): + # A builtin with no inspectable sig in some Python versions + class Uninspectable: + def __call__(self, v): + if not isinstance(v, int): + raise ValueError + return v + + # Make signature() raise + __signature__ = property(lambda self: (_ for _ in ()).throw(ValueError)) + + yield Uninspectable() + + assert check_type(42, MyType) + assert not check_type("hi", MyType) + + +# --- TypeVar bound that fails isinstance --- + + +def test_typevar_bound_isinstance_typeerror(): + """TypeVar bound with ForwardRef that fails isinstance falls back to outer_match.""" + from typing import ForwardRef, TypeVar + + T = TypeVar("T", bound=ForwardRef("NonExistent")) + # ForwardRef can't be used with isinstance — falls back to outer_match + # which accepts ForwardRefs + assert check_type(42, T) + + +# --- ctx default --- + + +def test_check_type_creates_default_ctx(): + """check_type works without explicit ctx.""" + assert check_type(42, int) + assert not check_type("hi", int) diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..f1fff6d --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,143 @@ +"""Test util.py: frozen collections, partial, try_dump_json.""" + +import copy + +import pytest + +from confection._errors import ConfigValidationError +from confection.util import ( + SimpleFrozenDict, + SimpleFrozenList, + partial, + try_dump_json, +) + +# --- SimpleFrozenDict --- + + +def test_frozen_dict_init(): + d = SimpleFrozenDict({"a": 1, "b": 2}) + assert d["a"] == 1 + + +def test_frozen_dict_setitem(): + d = SimpleFrozenDict() + with pytest.raises(NotImplementedError): + d["x"] = 1 + + +def test_frozen_dict_pop(): + d = SimpleFrozenDict({"a": 1}) + with pytest.raises(NotImplementedError): + d.pop("a") + + +def test_frozen_dict_update(): + d = SimpleFrozenDict() + with pytest.raises(NotImplementedError): + d.update({"x": 1}) + + +def test_frozen_dict_deepcopy(): + d = SimpleFrozenDict({"a": [1, 2]}) + d2 = copy.deepcopy(d) + assert isinstance(d2, SimpleFrozenDict) + assert d2["a"] == [1, 2] + assert d2["a"] is not d["a"] + + +def test_frozen_dict_custom_error(): + d = SimpleFrozenDict(error="custom error") + with pytest.raises(NotImplementedError, match="custom error"): + d["x"] = 1 + + +# --- SimpleFrozenList --- + + +def test_frozen_list_init(): + lst = SimpleFrozenList([1, 2, 3]) + assert lst[0] == 1 + + +def test_frozen_list_append(): + lst = SimpleFrozenList() + with pytest.raises(NotImplementedError): + lst.append(1) + + +def test_frozen_list_clear(): + lst = SimpleFrozenList([1]) + with pytest.raises(NotImplementedError): + lst.clear() + + +def test_frozen_list_extend(): + lst = SimpleFrozenList() + with pytest.raises(NotImplementedError): + lst.extend([1]) + + +def test_frozen_list_insert(): + lst = SimpleFrozenList() + with pytest.raises(NotImplementedError): + lst.insert(0, 1) + + +def test_frozen_list_pop(): + lst = SimpleFrozenList([1]) + with pytest.raises(NotImplementedError): + lst.pop() + + +def test_frozen_list_remove(): + lst = SimpleFrozenList([1]) + with pytest.raises(NotImplementedError): + lst.remove(1) + + +def test_frozen_list_reverse(): + lst = SimpleFrozenList([1, 2]) + with pytest.raises(NotImplementedError): + lst.reverse() + + +def test_frozen_list_sort(): + lst = SimpleFrozenList([2, 1]) + with pytest.raises(NotImplementedError): + lst.sort() + + +def test_frozen_list_deepcopy(): + lst = SimpleFrozenList([[1, 2], [3]]) + lst2 = copy.deepcopy(lst) + assert isinstance(lst2, SimpleFrozenList) + assert lst2 == [[1, 2], [3]] + assert lst2[0] is not lst[0] + + +def test_frozen_list_custom_error(): + lst = SimpleFrozenList(error="nope") + with pytest.raises(NotImplementedError, match="nope"): + lst.append(1) + + +# --- partial --- + + +def test_partial_basic(): + def add(a, b): + """Add two numbers.""" + return a + b + + p = partial(add, 1) + assert p(2) == 3 + assert p.__doc__ == "Add two numbers." + + +# --- try_dump_json error --- + + +def test_try_dump_json_unserializable(): + with pytest.raises(ConfigValidationError, match="Couldn't serialize"): + try_dump_json(object()) diff --git a/confection/tests/test_validate_type.py b/tests/test_validate_type.py similarity index 88% rename from confection/tests/test_validate_type.py rename to tests/test_validate_type.py index 822d498..c66ecd7 100644 --- a/confection/tests/test_validate_type.py +++ b/tests/test_validate_type.py @@ -55,47 +55,6 @@ def test_annotated(): assert validate_type("hi", Annotated[int, "metadata"]) is not None -@pytest.mark.skipif(Annotated is None, reason="Annotated not available") -def test_annotated_strict(): - """Test that Annotated types with Strict metadata enforce strict checking. - This covers pydantic's StrictInt, StrictStr, StrictFloat, StrictBool.""" - - class Strict: - def __init__(self, strict=True): - self.strict = strict - - StrictInt = Annotated[int, Strict()] - StrictStr = Annotated[str, Strict()] - StrictFloat = Annotated[float, Strict()] - StrictBool = Annotated[bool, Strict()] - - # StrictInt: only actual ints - assert validate_type(42, StrictInt) is None - assert validate_type("10", StrictInt) is not None - assert validate_type(True, StrictInt) is not None - assert validate_type(3.0, StrictInt) is not None - - # StrictStr: only actual strings - assert validate_type("hello", StrictStr) is None - assert validate_type(42, StrictStr) is not None - - # StrictFloat: only actual floats - assert validate_type(3.14, StrictFloat) is None - assert validate_type(42, StrictFloat) is not None - assert validate_type("3.14", StrictFloat) is not None - - # StrictBool: only actual bools - assert validate_type(True, StrictBool) is None - assert validate_type(False, StrictBool) is None - assert validate_type(1, StrictBool) is not None - assert validate_type(0, StrictBool) is not None - - # Non-strict Annotated should still allow coercion - NonStrict = Annotated[int, Strict(strict=False)] - assert validate_type(42, NonStrict) is None - assert validate_type("10", NonStrict) is None # int("10") works - - # === Union / Optional === diff --git a/tests/test_validation_edges.py b/tests/test_validation_edges.py new file mode 100644 index 0000000..eeddee5 --- /dev/null +++ b/tests/test_validation_edges.py @@ -0,0 +1,348 @@ +"""Tests for edge cases in validation.py.""" + +import sys +from typing import Optional + +import pytest + +from confection.validation import ( + Field, + FieldInfo, + Schema, + _is_pydantic_model, + _pydantic_instance_to_dict, + _validate_schema, + create_schema, + ensure_schema, + validate_type, +) + +# --- Schema with FieldInfo as class default --- + + +class SchemaWithFieldInfo(Schema): + x: int = Field(42) + y: str = Field("hello") + + +def test_field_info_as_class_default(): + assert SchemaWithFieldInfo.model_fields["x"].default == 42 + assert SchemaWithFieldInfo.model_fields["y"].default == "hello" + + +def test_field_info_required(): + class RequiredSchema(Schema): + x: int = Field(...) + + assert RequiredSchema.model_fields["x"].is_required() + + +# --- Schema field aliases --- + + +class AliasedSchema(Schema): + model_config = {"extra": "forbid"} + x: int = Field(default=..., alias="x_alias") + + +def test_model_validate_with_alias(): + result = AliasedSchema.model_validate({"x_alias": 42}) + assert result.x_alias == 42 + + +def test_model_validate_alias_fills_default(): + class AliasDefault(Schema): + x: int = Field(default=10, alias="x_alias") + + result = AliasDefault.model_validate({}) + assert result.x_alias == 10 + + +# --- Schema with alias_generator --- + + +class AliasGenSchema(Schema): + model_config = {"extra": "forbid", "alias_generator": lambda name: name.upper()} + my_field: int + + +def test_model_validate_with_alias_generator(): + result = AliasGenSchema.model_validate({"MY_FIELD": 42}) + assert result.MY_FIELD == 42 + + +def test_model_validate_alias_generator_fills_default(): + class AliasGenDefault(Schema): + model_config = {"alias_generator": lambda name: name.upper()} + x: int = 99 + + result = AliasGenDefault.model_validate({}) + assert result.X == 99 + + +# --- model_dump with nested Schema --- + + +class Inner(Schema): + a: int + + +class Outer(Schema): + inner: Inner + b: str = "hi" + + +def test_model_dump_nested(): + instance = Outer(inner=Inner(a=1)) + dumped = instance.model_dump() + assert dumped == {"inner": {"a": 1}, "b": "hi"} + + +# --- create_schema edge cases --- + + +def test_create_schema_no_config(): + s = create_schema("NoConfig", x=(int, Field(...))) + assert s.model_config == {"extra": "allow"} + + +def test_create_schema_bad_field(): + with pytest.raises(ValueError, match="must be"): + create_schema("Bad", x="not a tuple") + + +def test_create_schema_with_alias_generator(): + s = create_schema( + "AliasGen", + __config__={"extra": "forbid", "alias_generator": lambda n: n.upper()}, + my_field=(int, Field(...)), + ) + assert s.model_fields["my_field"].alias == "MY_FIELD" + + +def test_create_schema_plain_default(): + """Non-FieldInfo second element in tuple gets wrapped.""" + s = create_schema("PlainDefault", x=(int, 42)) + assert s.model_fields["x"].default == 42 + + +# --- resolve_type_hints fallback --- + + +def test_from_function_unresolvable_forward_ref(): + """Forward refs that can't be resolved fall back to raw annotations.""" + + # Create a function with an annotation that can't be resolved + def func(x: "NonExistentType") -> None: # noqa: F821 + pass + + schema = Schema.from_function(func) + # Should still create the schema, just with Any or the raw annotation + assert "x" in schema.model_fields + + +# --- _error_type_for branches --- + + +def test_validate_type_str_error(): + err = validate_type(42, str) + assert err is not None + + +def test_validate_type_float_error(): + err = validate_type("abc", float) + assert err is not None + + +def test_validate_type_bool_error(): + err = validate_type(1, bool) + assert err is not None + + +def test_validate_type_complex_error(): + """Non-primitive type should give value_error.""" + from typing import List + + err = validate_type("not a list", List[int]) + assert err is not None + + +# --- _validate_schema error types --- + + +def test_validate_schema_str_error_type(): + f = FieldInfo(default=...) + f.annotation = str + errors = _validate_schema({"x": 42}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "string_type" + + +def test_validate_schema_float_error_type(): + f = FieldInfo(default=...) + f.annotation = float + errors = _validate_schema({"x": "abc"}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "float_parsing" + + +def test_validate_schema_bool_error_type(): + f = FieldInfo(default=...) + f.annotation = bool + errors = _validate_schema({"x": 1}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "bool_type" + + +def test_validate_schema_positive_int_error_type(): + from confection.validation import PositiveInt + + f = FieldInfo(default=...) + f.annotation = PositiveInt + errors = _validate_schema({"x": -1}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "int_parsing" + + +def test_validate_schema_strict_float_error_type(): + from confection.validation import StrictFloat + + f = FieldInfo(default=...) + f.annotation = StrictFloat + errors = _validate_schema({"x": "abc"}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "float_parsing" + + +def test_validate_schema_generic_error_type(): + from typing import List + + f = FieldInfo(default=...) + f.annotation = List[int] + errors = _validate_schema({"x": "nope"}, {"x": f}, {"extra": "allow"}, None) + assert errors[0]["type"] == "value_error" + + +# --- _is_pydantic_model edge cases --- + + +def test_is_pydantic_model_non_type(): + assert _is_pydantic_model("not a type") is False + + +def test_is_pydantic_model_schema_subclass(): + class MySchema(Schema): + x: int + + assert _is_pydantic_model(MySchema) is False + + +# --- _pydantic_instance_to_dict --- + + +def test_pydantic_instance_to_dict_v2(): + import pydantic + + class M(pydantic.BaseModel): + x: int = 1 + + assert _pydantic_instance_to_dict(M()) == {"x": 1} + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="pydantic v1 not compatible with 3.14+" +) +def test_pydantic_instance_to_dict_v1(): + from pydantic.v1 import BaseModel + + class M(BaseModel): + x: int = 1 + + assert _pydantic_instance_to_dict(M()) == {"x": 1} + + +def test_pydantic_instance_to_dict_plain(): + assert _pydantic_instance_to_dict(42) == 42 + + +# --- ensure_schema passthrough for non-pydantic --- + + +def test_ensure_schema_plain_class(): + class NotAModel: + pass + + assert ensure_schema(NotAModel) is NotAModel + + +# --- v1 allow_none (Optional) --- + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="pydantic v1 not compatible with 3.14+" +) +def test_v1_optional_field(): + from pydantic.v1 import BaseModel as V1Model + + class M(V1Model): + x: Optional[int] = None + + converted = ensure_schema(M) + # Should accept None + result = converted.model_validate({"x": None}) + assert result.x is None + + +# --- v2 model with pydantic instance default --- + + +def test_v2_pydantic_instance_default(): + import pydantic + + class Inner(pydantic.BaseModel): + x: int = 1 + + class Outer(pydantic.BaseModel): + inner: Inner = Inner() + + converted = ensure_schema(Outer) + # Default should be converted from pydantic instance to dict + assert converted.model_fields["inner"].default == {"x": 1} + + +# --- v1 model with pydantic instance default --- + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="pydantic v1 not compatible with 3.14+" +) +def test_v1_pydantic_instance_default(): + from pydantic.v1 import BaseModel as V1Model + + class Inner(V1Model): + x: int = 1 + + class Outer(V1Model): + inner: Inner = Inner() + + converted = ensure_schema(Outer) + assert converted.model_fields["inner"].default == {"x": 1} + + +# --- Schema inheritance --- + + +class BaseSchema(Schema): + x: int = 1 + + +class DerivedSchema(BaseSchema): + y: str = "hello" + + +def test_schema_inheritance(): + assert "x" in DerivedSchema.model_fields + assert "y" in DerivedSchema.model_fields + assert DerivedSchema.model_fields["x"].default == 1 + + +def test_schema_inheritance_override(): + class Override(BaseSchema): + x: int = 99 + + assert Override.model_fields["x"].default == 99