From 059a5e46d24ac1f0855cf2c219e646522cc0f241 Mon Sep 17 00:00:00 2001 From: Vittorio Distefano <101409677+blackopsrepl@users.noreply.github.com> Date: Sun, 12 Apr 2026 11:48:44 +0200 Subject: [PATCH 1/2] Document Path-2 postmortem against removed solverforge-py experiment --- README.md | 2 + docs/python-model-ir.md | 89 +++++++++ docs/python-path2-postmortem.md | 79 ++++++++ python/solverforge_ir/__init__.py | 55 ++++++ python/solverforge_ir/codegen.py | 317 ++++++++++++++++++++++++++++++ python/solverforge_ir/model.py | 305 ++++++++++++++++++++++++++++ python/tests/test_codegen.py | 112 +++++++++++ python/tests/test_model_ir.py | 103 ++++++++++ 8 files changed, 1062 insertions(+) create mode 100644 docs/python-model-ir.md create mode 100644 docs/python-path2-postmortem.md create mode 100644 python/solverforge_ir/__init__.py create mode 100644 python/solverforge_ir/codegen.py create mode 100644 python/solverforge_ir/model.py create mode 100644 python/tests/test_codegen.py create mode 100644 python/tests/test_model_ir.py diff --git a/README.md b/README.md index 4ab325b4..d22141ac 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,8 @@ The current templates cover standard-variable and list-heavy planning models, an - `README.md` is the user-facing entry point for the workspace and generated-project integration model. - `docs/extend-solver.md` and `docs/extend-domain.md` cover scaffold extension workflows. - `docs/lifecycle-pause-resume-contract.md` defines the retained lifecycle contract, including exact pause/resume semantics, snapshot identity, and terminal-state cleanup rules. +- `docs/python-model-ir.md` outlines the Python-first declarative model IR prototype that lowers into typed Rust solver code. +- `docs/python-path2-postmortem.md` compares Path 2 against the removed `solverforge-py` experiment and records hard guardrails. 
- `docs/typed-contract-audit.md` records the current neutral selector and extractor naming model, including the `EntityCollectionExtractor`, `ValueSelector`, and `MoveSelector` surface adopted in `0.7.0`. - `crates/*/WIREFRAME.md` files are the canonical public API maps for each crate. - `AGENTS.md` defines repository-level engineering and documentation expectations for coding agents. diff --git a/docs/python-model-ir.md b/docs/python-model-ir.md new file mode 100644 index 00000000..ee366099 --- /dev/null +++ b/docs/python-model-ir.md @@ -0,0 +1,89 @@ +# Python Model IR (Path 2: Codegen + Compile) + +This document defines the Python-first model surface that lowers into typed SolverForge Rust code, then compiles as a Rust crate. + +## Goals + +- Preserve SolverForge zero-erasure and monomorphized hot paths. +- Let Python users model the same planning constructs and lifecycle workflows. +- Keep the Python/Rust bridge thin: compile once, run in Rust, stream lifecycle events. + +## Modules + +## Historical Context + +Path 2 guardrails were derived from the removed `solverforge-py` experiment; see `docs/python-path2-postmortem.md`. 
+ +- IR schema and validation: `python/solverforge_ir/model.py` +- Rust code generation and project writer: `python/solverforge_ir/codegen.py` + +## Design + +The IR is declarative and typed: + +- Domain declarations (`FactDef`, `EntityDef`, `VariableDef`, `SolutionDef`) +- Constraint declarations (`ConstraintDef`, `JoinSpec`, `FilterSpec`, `ImpactSpec`) +- Runtime configuration (`TerminationDef`, `SolverDef`) +- Top-level container (`ModelDef`) + +Expressions are represented as an AST (not executable Python callbacks): + +- `RefExpr` +- `ConstExpr` +- `CompareExpr` +- `BoolExpr` +- `CallExpr` + +## Lambda Lowering + +`lambda_to_expr(fn, aliases)` lowers a restricted subset of Python lambda/function syntax into the expression AST: + +- Attribute references from known stream aliases +- `==`, `!=`, `<`, `<=`, `>`, `>=` +- Boolean `and`, `or`, `not` +- Whitelisted calls (`contains`, `overlaps`, `len`) + +Unsupported constructs fail fast with `LambdaLoweringError`. + +## Validation + +`validate_model(model)` performs structural validation: + +- Unique entity/fact names +- Solution collection references target known entities/facts +- Constraint source and join collection references exist +- Join-specific required fields are present (`left_key/right_key` for keyed joins, predicate for predicate joins) + +## Code Generation (Path 2) + +`generate_rust_module(model)` emits Rust source with: + +- Domain structs annotated by `#[problem_fact]`, `#[planning_entity]`, `#[planning_solution]` +- Typed `define_constraints()` function using `ConstraintFactory` and fluent stream builders +- Join lowering for `self_equal`, `cross_keyed`, and `cross_predicate` +- Filter/impact/name lowering per constraint + +`write_rust_project(model, out_dir, crate_name)` writes a compilable crate: + +- `Cargo.toml` +- `src/lib.rs` + +Returned metadata (`GeneratedRustProject`) points to generated paths. 
+
+## Intended Lowering Contract
+
+The IR lowers into the Rust stream API:
+
+- `source(collection)` -> `ConstraintFactory::<Solution, Score>::new().<collection>()`
+- `join(self_equal|cross_keyed|cross_predicate)` -> `.join(...)`
+- `filter(expr)` -> `.filter(...)`
+- `impact` -> `.penalize_*()` / `.reward_*()`
+- `name` -> `.named(...)`
+
+This keeps solving and scoring in Rust while preserving Python modeling ergonomics.
+
+## Current Limitations
+
+- Codegen currently targets common standard-variable patterns.
+- Advanced list-variable selectors/phases are not yet emitted.
+- Project writing creates a Rust crate artifact; packaging via PyO3/maturin is a follow-up step.
diff --git a/docs/python-path2-postmortem.md b/docs/python-path2-postmortem.md
new file mode 100644
index 00000000..4e7d09ed
--- /dev/null
+++ b/docs/python-path2-postmortem.md
@@ -0,0 +1,79 @@
+# Python Path 2 vs Historical `solverforge-py` (Postmortem)
+
+This note compares the current Python Path 2 direction (IR -> Rust codegen -> compile) with the removed `crates/solverforge-py` experiment.
+
+## Historical Reference Point
+
+The latest commit where `crates/solverforge-py` still existed was:
+
+- `be76aaf` (2026-02-06) `refactor(py): remove Solver pyclass and unify API under SolverManager`
+
+The deletion happened at:
+
+- `559c57d` (2026-03-08) `chore: delete dynamic, py + all cranelift stuff; delete stub dotfiles that were used with zoyd`
+
+## Why the Old Experiment Failed (Structural Issues)
+
+### 1) Dynamic runtime model instead of typed compile-time model
+
+Old `solverforge-py` built solutions with dynamic descriptors and dynamic values at runtime:
+
+- `DynamicDescriptor`, `DynamicEntity`, `DynamicSolution`, `DynamicConstraintSet`
+- runtime-defined classes and value ranges
+
+This created a separate dynamic execution path that diverged from the typed core.
+ +### 2) String-expression constraints + +Old constraints were built from string expressions like: + +- `"A.row == B.row"` +- `"field + 1"` + +and parsed at runtime with ad-hoc parsing logic (`parse_expr`, `parse_simple_expr`). + +This was fragile, hard to validate statically, and not aligned with typed stream APIs. + +### 3) Python API drift from Rust public API + +The old interface (`entity_class`, `add_entities`, string joins/filters) did not match the typed Rust modeling surface and lifecycle contracts. + +### 4) Lifecycle/telemetry mismatch + +The old manager API exposed coarse status strings and custom async controls, which did not map to retained `job/snapshot/checkpoint` semantics used in modern SolverForge. + +## Path 2 Correctives + +Path 2 intentionally avoids the above failure modes: + +1. **Typed IR, not dynamic runtime objects** + - Python describes model structure and expressions as a typed AST. + +2. **Compile to Rust, do not interpret in Python** + - Emit Rust structs/macros/constraint streams and compile. + +3. **No string DSL at runtime** + - Expressions are AST nodes lowered into Rust source. + +4. **Keep Rust as the only execution path** + - Scoring/moves/phases run in generated Rust. + +5. **Align with retained lifecycle contracts** + - Future bindings should expose `job`/`snapshot`/`events` directly rather than inventing a parallel lifecycle model. + +## Non-Negotiable Guardrails + +- No runtime expression parser for user strings. +- No dynamic scoring/move engine fork for Python. +- No Python callback execution in hot scoring/move loops. +- Generated code must target the same public SolverForge contracts used by Rust users. + +## Current Status + +The repository now contains a Python IR + codegen prototype under `python/solverforge_ir` and docs in `docs/python-model-ir.md`. 
+ +Remaining work includes: + +- pyproject/maturin packaging for produced crates, +- list-variable parity, +- lifecycle bridge that forwards retained runtime events as web/SSE-friendly payloads. diff --git a/python/solverforge_ir/__init__.py b/python/solverforge_ir/__init__.py new file mode 100644 index 00000000..b8424b99 --- /dev/null +++ b/python/solverforge_ir/__init__.py @@ -0,0 +1,55 @@ +from .codegen import GeneratedRustProject, generate_rust_module, write_rust_project +from .model import ( + BoolExpr, + CallExpr, + CompareExpr, + ConstraintDef, + ConstraintSetDef, + ConstExpr, + EntityDef, + FactDef, + FieldDef, + FilterSpec, + ImpactSpec, + JoinSpec, + LambdaLoweringError, + ModelDef, + ModelValidationError, + RefExpr, + SolutionDef, + SolverDef, + Source, + TerminationDef, + VariableDef, + lambda_to_expr, + validate_model, +) + +__all__ = [ + "BoolExpr", + "CallExpr", + "CompareExpr", + "ConstraintDef", + "ConstraintSetDef", + "ConstExpr", + "EntityDef", + "FactDef", + "FieldDef", + "FilterSpec", + "ImpactSpec", + "JoinSpec", + "LambdaLoweringError", + "ModelDef", + "ModelValidationError", + "RefExpr", + "SolutionDef", + "SolverDef", + "Source", + "TerminationDef", + "VariableDef", + "lambda_to_expr", + "validate_model", + "GeneratedRustProject", + "generate_rust_module", + "write_rust_project", +] diff --git a/python/solverforge_ir/codegen.py b/python/solverforge_ir/codegen.py new file mode 100644 index 00000000..83cddb04 --- /dev/null +++ b/python/solverforge_ir/codegen.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +from .model import ( + BoolExpr, + CallExpr, + CompareExpr, + ConstraintDef, + ConstExpr, + EntityDef, + Expr, + ImpactSpec, + JoinSpec, + ModelDef, + RefExpr, +) + + +@dataclass(frozen=True) +class GeneratedRustProject: + crate_dir: Path + cargo_toml: Path + lib_rs: Path + + +_SCORE_MAP = { + "soft": "SoftScore", + "hard_soft": 
"HardSoftScore", + "hard_medium_soft": "HardMediumSoftScore", + "hard_soft_decimal": "HardSoftDecimalScore", + "bendable": "BendableScore<1, 1>", +} + + +def generate_rust_module(model: ModelDef) -> str: + score_ty = _score_type(model.solution.score_kind) + lines: list[str] = [] + lines.extend( + [ + "use solverforge::prelude::*;", + "use solverforge::stream::{joiner::*, ConstraintFactory};", + "", + ] + ) + + lines.extend(_emit_facts(model.facts)) + lines.extend(_emit_entities(model.entities)) + lines.extend(_emit_solution(model)) + + lines.append(f"use {model.solution.name}ConstraintStreams;") + lines.append("") + lines.append( + f"pub fn define_constraints() -> impl ConstraintSet<{model.solution.name}, {score_ty}> {{" + ) + + for idx, constraint in enumerate(model.constraints.constraints): + c_name = f"c_{idx}" + lines.extend(_emit_constraint_builder(model, constraint, c_name, score_ty)) + + tuple_items = ", ".join([f"c_{i}" for i in range(len(model.constraints.constraints))]) + lines.append(f" ({tuple_items})") + lines.append("}") + lines.append("") + + return "\n".join(lines) + + +def write_rust_project(model: ModelDef, out_dir: Path, crate_name: str = "solverforge_py_model") -> GeneratedRustProject: + crate_dir = out_dir / crate_name + src_dir = crate_dir / "src" + src_dir.mkdir(parents=True, exist_ok=True) + + cargo_toml = crate_dir / "Cargo.toml" + lib_rs = src_dir / "lib.rs" + + cargo_toml.write_text( + "\n".join( + [ + "[package]", + f'name = "{crate_name}"', + 'version = "0.1.0"', + 'edition = "2021"', + "", + "[dependencies]", + 'solverforge = { version = "0.8", features = ["console"] }', + "", + ] + ) + ) + lib_rs.write_text(generate_rust_module(model)) + + return GeneratedRustProject(crate_dir=crate_dir, cargo_toml=cargo_toml, lib_rs=lib_rs) + + +def _emit_facts(facts: Iterable) -> list[str]: + lines: list[str] = [] + for fact in facts: + lines.append("#[problem_fact]") + lines.append(f"pub struct {fact.name} {{") + for field in fact.fields: + rust_ty 
= _map_type(field.type_ref) + lines.append(f" pub {field.name}: {rust_ty},") + lines.append("}") + lines.append("") + return lines + + +def _emit_entities(entities: Iterable[EntityDef]) -> list[str]: + lines: list[str] = [] + for entity in entities: + variable_names = {v.name for v in entity.variables} + + lines.append("#[planning_entity]") + lines.append(f"pub struct {entity.name} {{") + for field in entity.fields: + if field.name == entity.planning_id_field: + lines.append(" #[planning_id]") + if field.name in variable_names: + lines.append(" #[planning_variable]") + rust_ty = _map_type(field.type_ref) + lines.append(f" pub {field.name}: {rust_ty},") + lines.append("}") + lines.append("") + return lines + + +def _emit_solution(model: ModelDef) -> list[str]: + score_ty = _score_type(model.solution.score_kind) + lines: list[str] = ["#[planning_solution]", f"pub struct {model.solution.name} {{"] + + for field_name, fact_name in model.solution.fact_collections.items(): + lines.append(" #[problem_fact_collection]") + lines.append(f" pub {field_name}: Vec<{fact_name}>,") + + for field_name, entity_name in model.solution.entity_collections.items(): + lines.append(" #[planning_entity_collection]") + lines.append(f" pub {field_name}: Vec<{entity_name}>,") + + lines.append(" #[planning_score]") + lines.append(f" pub score: Option<{score_ty}>,") + lines.append("}") + lines.append("") + return lines + + +def _emit_constraint_builder(model: ModelDef, constraint: ConstraintDef, c_name: str, score_ty: str) -> list[str]: + solution_name = model.solution.name + source_entity_name = model.solution.entity_collections[constraint.source.collection_field] + + lines = [ + f" let {c_name} = ConstraintFactory::<{solution_name}, {score_ty}>::new()", + f" .{constraint.source.collection_field}()", + ] + + current_aliases: list[str] = ["left"] + current_entity_names: list[str] = [source_entity_name] + + for join in constraint.joins: + right_entity = None + if join.right_collection_field: + 
right_entity = model.solution.entity_collections.get(join.right_collection_field) + if right_entity is None: + right_entity = model.solution.fact_collections[join.right_collection_field] + + if join.kind == "self_equal": + left_ty = current_entity_names[0] + left_expr = _rust_expr(join.left_key or RefExpr("left", ""), {"left": "left"}) + lines.append(f" .join(equal(|left: &{left_ty}| {left_expr}))") + current_aliases.append("right") + current_entity_names.append(left_ty) + + elif join.kind == "cross_keyed": + if not join.right_collection_field or not join.left_key or not join.right_key: + raise ValueError("cross_keyed join requires right_collection_field, left_key, right_key") + + left_ty = current_entity_names[0] + right_ty = right_entity + left_key_expr = _rust_expr(join.left_key, {"left": "left"}) + right_key_expr = _rust_expr(join.right_key, {"right": "right"}) + lines.extend( + [ + " .join((", + f" |s: &{solution_name}| &s.{join.right_collection_field},", + " equal_bi(", + f" |left: &{left_ty}| {left_key_expr},", + f" |right: &{right_ty}| {right_key_expr},", + " ),", + " ))", + ] + ) + current_aliases.append("right") + current_entity_names.append(right_ty) + + elif join.kind == "cross_predicate": + if not join.right_collection_field or not join.predicate: + raise ValueError("cross_predicate join requires right_collection_field and predicate") + left_ty = current_entity_names[0] + right_ty = right_entity + pred = _rust_expr(join.predicate, {"left": "left", "right": "right"}) + lines.extend( + [ + " .join((", + f" |s: &{solution_name}| &s.{join.right_collection_field},", + f" |left: &{left_ty}, right: &{right_ty}| {pred},", + " ))", + ] + ) + current_aliases.append("right") + current_entity_names.append(right_ty) + + for f in constraint.filters: + args = [] + alias_map: dict[str, str] = {} + for idx, alias in enumerate(current_aliases): + var_name = alias if idx == 0 else f"{alias}{idx}" + ty = current_entity_names[idx] + args.append(f"{var_name}: &{ty}") + 
alias_map[alias] = var_name + + pred = _rust_expr(f.predicate, alias_map) + args_sig = ", ".join(args) + lines.append(f" .filter(|{args_sig}| {pred})") + + lines.append(f" .{_impact_method(constraint.impact)}({constraint.impact.weight})") + lines.append(f" .named(\"{constraint.name}\");") + lines.append("") + return lines + + +def _impact_method(impact: ImpactSpec) -> str: + prefix = "penalize" if impact.impact == "penalize" else "reward" + return f"{prefix}_{impact.level}" + + +def _rust_expr(expr: Expr, alias_map: dict[str, str]) -> str: + if isinstance(expr, ConstExpr): + return _const_to_rust(expr.value) + + if isinstance(expr, RefExpr): + base = alias_map.get(expr.stream_alias, expr.stream_alias) + if expr.field_path: + return f"{base}.{expr.field_path}" + return base + + if isinstance(expr, CompareExpr): + left = _rust_expr(expr.left, alias_map) + right = _rust_expr(expr.right, alias_map) + return f"({left} {expr.op} {right})" + + if isinstance(expr, BoolExpr): + if expr.op == "not": + return f"(!{_rust_expr(expr.args[0], alias_map)})" + joiner = " && " if expr.op == "and" else " || " + return f"({joiner.join(_rust_expr(a, alias_map) for a in expr.args)})" + + if isinstance(expr, CallExpr): + args = [_rust_expr(a, alias_map) for a in expr.args] + if expr.fn == "contains": + if len(args) != 2: + raise ValueError("contains expects 2 args") + return f"{args[0]}.contains(&{args[1]})" + if expr.fn == "len": + if len(args) != 1: + raise ValueError("len expects 1 arg") + return f"{args[0]}.len()" + if expr.fn == "overlaps": + if len(args) != 4: + raise ValueError("overlaps expects 4 args: a_start,a_end,b_start,b_end") + return f"(({args[0]} < {args[3]}) && ({args[2]} < {args[1]}))" + raise ValueError(f"Unsupported call: {expr.fn}") + + raise TypeError(f"Unsupported expression type: {type(expr).__name__}") + + +def _const_to_rust(value: object) -> str: + if value is None: + return "None" + if value is True: + return "true" + if value is False: + return "false" + if 
isinstance(value, str): + return f'"{value}".to_string()' + return repr(value) + + +def _score_type(score_kind: str) -> str: + try: + return _SCORE_MAP[score_kind] + except KeyError as e: + raise ValueError(f"Unsupported score kind '{score_kind}'.") from e + + +def _map_type(type_ref: str) -> str: + cleaned = type_ref.strip() + + if cleaned.startswith("Option[") and cleaned.endswith("]"): + inner = cleaned[len("Option[") : -1] + return f"Option<{_map_type(inner)}>" + + if cleaned.startswith("Vec[") and cleaned.endswith("]"): + inner = cleaned[len("Vec[") : -1] + return f"Vec<{_map_type(inner)}>" + + primitives = { + "str": "String", + "string": "String", + "i64": "i64", + "int": "i64", + "f64": "f64", + "float": "f64", + "bool": "bool", + } + return primitives.get(cleaned, cleaned) diff --git a/python/solverforge_ir/model.py b/python/solverforge_ir/model.py new file mode 100644 index 00000000..b0b29b61 --- /dev/null +++ b/python/solverforge_ir/model.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import ast +import inspect +import textwrap +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Literal, Optional, Union + +ScoreKind = Literal["soft", "hard_soft", "hard_medium_soft", "bendable", "hard_soft_decimal"] +VarKind = Literal["basic", "list"] +JoinKind = Literal["self_equal", "cross_keyed", "cross_predicate"] +Impact = Literal["penalize", "reward"] +Level = Literal["hard", "medium", "soft"] + + +@dataclass(frozen=True) +class FieldDef: + name: str + type_ref: str + nullable: bool = False + + +@dataclass(frozen=True) +class FactDef: + name: str + fields: list[FieldDef] + + +@dataclass(frozen=True) +class VariableDef: + name: str + value_type: str + kind: VarKind + pinned_field: Optional[str] = None + + +@dataclass(frozen=True) +class EntityDef: + name: str + planning_id_field: str + fields: list[FieldDef] + variables: list[VariableDef] = field(default_factory=list) + + +@dataclass(frozen=True) +class SolutionDef: + 
name: str + score_kind: ScoreKind + entity_collections: dict[str, str] + fact_collections: dict[str, str] + + +@dataclass(frozen=True) +class RefExpr: + stream_alias: str + field_path: str + + +@dataclass(frozen=True) +class ConstExpr: + value: Any + + +@dataclass(frozen=True) +class CallExpr: + fn: str + args: list["Expr"] + + +@dataclass(frozen=True) +class CompareExpr: + op: Literal["==", "!=", "<", "<=", ">", ">="] + left: "Expr" + right: "Expr" + + +@dataclass(frozen=True) +class BoolExpr: + op: Literal["and", "or", "not"] + args: list["Expr"] + + +Expr = Union[RefExpr, ConstExpr, CallExpr, CompareExpr, BoolExpr] + + +@dataclass(frozen=True) +class Source: + collection_field: str + + +@dataclass(frozen=True) +class JoinSpec: + kind: JoinKind + right_collection_field: Optional[str] = None + left_key: Optional[Expr] = None + right_key: Optional[Expr] = None + predicate: Optional[Expr] = None + + +@dataclass(frozen=True) +class FilterSpec: + predicate: Expr + + +@dataclass(frozen=True) +class ImpactSpec: + impact: Impact + level: Level + weight: int = 1 + + +@dataclass(frozen=True) +class ConstraintDef: + name: str + source: Source + joins: list[JoinSpec] = field(default_factory=list) + filters: list[FilterSpec] = field(default_factory=list) + impact: ImpactSpec = field(default_factory=lambda: ImpactSpec("penalize", "hard", 1)) + + +@dataclass(frozen=True) +class ConstraintSetDef: + constraints: list[ConstraintDef] + + +@dataclass(frozen=True) +class TerminationDef: + time_limit_ms: Optional[int] = None + step_count_limit: Optional[int] = None + unimproved_time_limit_ms: Optional[int] = None + unimproved_step_limit: Optional[int] = None + + +@dataclass(frozen=True) +class SolverDef: + termination: TerminationDef + + +@dataclass(frozen=True) +class ModelDef: + facts: list[FactDef] + entities: list[EntityDef] + solution: SolutionDef + constraints: ConstraintSetDef + solver: SolverDef + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +_ALLOWED_CALLS 
= { + "contains", + "overlaps", + "len", +} + + +class LambdaLoweringError(ValueError): + pass + + +class ModelValidationError(ValueError): + pass + + +def lambda_to_expr(fn: Callable[..., Any], aliases: list[str]) -> Expr: + source = textwrap.dedent(inspect.getsource(fn)) + tree = ast.parse(source) + expr_node: ast.AST | None = None + + if isinstance(tree.body[0], ast.Expr) and isinstance(tree.body[0].value, ast.Lambda): + lam = tree.body[0].value + expr_node = lam.body + elif isinstance(tree.body[0], ast.Assign) and isinstance(tree.body[0].value, ast.Lambda): + lam = tree.body[0].value + expr_node = lam.body + elif isinstance(tree.body[0], ast.FunctionDef): + fn_def = tree.body[0] + returns = [n for n in fn_def.body if isinstance(n, ast.Return)] + if len(returns) != 1: + raise LambdaLoweringError("Function must have exactly one return statement.") + expr_node = returns[0].value + else: + raise LambdaLoweringError("Only lambda or single-return function expressions are supported.") + + return _lower_ast_expr(expr_node, aliases) + + +def _lower_ast_expr(node: ast.AST, aliases: list[str]) -> Expr: + if isinstance(node, ast.Constant): + return ConstExpr(node.value) + + if isinstance(node, ast.Name): + if node.id in aliases: + return RefExpr(stream_alias=node.id, field_path="") + raise LambdaLoweringError(f"Unknown name '{node.id}'.") + + if isinstance(node, ast.Attribute): + parts: list[str] = [] + cur = node + while isinstance(cur, ast.Attribute): + parts.append(cur.attr) + cur = cur.value + if isinstance(cur, ast.Name) and cur.id in aliases: + return RefExpr(stream_alias=cur.id, field_path=".".join(reversed(parts))) + raise LambdaLoweringError("Attribute access must start from a known stream alias.") + + if isinstance(node, ast.Compare): + if len(node.ops) != 1 or len(node.comparators) != 1: + raise LambdaLoweringError("Only single comparisons are supported.") + op_map = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: 
">=", + } + for k, v in op_map.items(): + if isinstance(node.ops[0], k): + return CompareExpr( + op=v, + left=_lower_ast_expr(node.left, aliases), + right=_lower_ast_expr(node.comparators[0], aliases), + ) + raise LambdaLoweringError("Unsupported comparison operator.") + + if isinstance(node, ast.BoolOp): + if isinstance(node.op, ast.And): + return BoolExpr(op="and", args=[_lower_ast_expr(v, aliases) for v in node.values]) + if isinstance(node.op, ast.Or): + return BoolExpr(op="or", args=[_lower_ast_expr(v, aliases) for v in node.values]) + raise LambdaLoweringError("Unsupported boolean operator.") + + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return BoolExpr(op="not", args=[_lower_ast_expr(node.operand, aliases)]) + + if isinstance(node, ast.Call): + fn_name = _call_name(node.func) + if fn_name not in _ALLOWED_CALLS: + raise LambdaLoweringError(f"Call '{fn_name}' is not in the allowed call whitelist.") + args = [_lower_ast_expr(a, aliases) for a in node.args] + return CallExpr(fn=fn_name, args=args) + + raise LambdaLoweringError(f"Unsupported syntax node: {type(node).__name__}") + + +def _call_name(func_node: ast.AST) -> str: + if isinstance(func_node, ast.Name): + return func_node.id + if isinstance(func_node, ast.Attribute): + return func_node.attr + raise LambdaLoweringError("Unsupported callable expression.") + + +def validate_model(model: ModelDef) -> None: + entity_names = {e.name for e in model.entities} + fact_names = {f.name for f in model.facts} + + if len(entity_names) != len(model.entities): + raise ModelValidationError("Duplicate entity names are not allowed.") + if len(fact_names) != len(model.facts): + raise ModelValidationError("Duplicate fact names are not allowed.") + + for field_name, entity_name in model.solution.entity_collections.items(): + if entity_name not in entity_names: + raise ModelValidationError( + f"Solution entity collection '{field_name}' references unknown entity '{entity_name}'." 
+ ) + + for field_name, fact_name in model.solution.fact_collections.items(): + if fact_name not in fact_names: + raise ModelValidationError( + f"Solution fact collection '{field_name}' references unknown fact '{fact_name}'." + ) + + all_collections = set(model.solution.entity_collections) | set(model.solution.fact_collections) + + for constraint in model.constraints.constraints: + if constraint.source.collection_field not in all_collections: + raise ModelValidationError( + f"Constraint '{constraint.name}' source references unknown collection " + f"'{constraint.source.collection_field}'." + ) + if not constraint.name: + raise ModelValidationError("Constraint name must not be empty.") + + for join in constraint.joins: + if join.kind in {"cross_keyed", "cross_predicate"} and not join.right_collection_field: + raise ModelValidationError( + f"Constraint '{constraint.name}' has join '{join.kind}' without right collection." + ) + if join.right_collection_field and join.right_collection_field not in all_collections: + raise ModelValidationError( + f"Constraint '{constraint.name}' join references unknown collection " + f"'{join.right_collection_field}'." + ) + if join.kind == "cross_keyed" and (join.left_key is None or join.right_key is None): + raise ModelValidationError( + f"Constraint '{constraint.name}' keyed join requires left_key and right_key." + ) + if join.kind == "cross_predicate" and join.predicate is None: + raise ModelValidationError( + f"Constraint '{constraint.name}' predicate join requires predicate expression." 
+ ) diff --git a/python/tests/test_codegen.py b/python/tests/test_codegen.py new file mode 100644 index 00000000..c92a83f8 --- /dev/null +++ b/python/tests/test_codegen.py @@ -0,0 +1,112 @@ +import tempfile +import unittest +from pathlib import Path + +from solverforge_ir.codegen import generate_rust_module, write_rust_project +from solverforge_ir.model import ( + CallExpr, + CompareExpr, + ConstraintDef, + ConstraintSetDef, + ConstExpr, + EntityDef, + FactDef, + FieldDef, + FilterSpec, + ImpactSpec, + JoinSpec, + ModelDef, + RefExpr, + SolutionDef, + SolverDef, + Source, + TerminationDef, + VariableDef, +) + + +def _schedule_model() -> ModelDef: + return ModelDef( + facts=[ + FactDef(name="Employee", fields=[FieldDef("id", "i64"), FieldDef("skills", "Vec[str]")]) + ], + entities=[ + EntityDef( + name="Shift", + planning_id_field="id", + fields=[ + FieldDef("id", "i64"), + FieldDef("employee_id", "Option[i64]", nullable=True), + FieldDef("required_skill", "str"), + ], + variables=[VariableDef(name="employee_id", value_type="Option[i64]", kind="basic")], + ) + ], + solution=SolutionDef( + name="Schedule", + score_kind="hard_soft", + entity_collections={"shifts": "Shift"}, + fact_collections={"employees": "Employee"}, + ), + constraints=ConstraintSetDef( + constraints=[ + ConstraintDef( + name="required_skill", + source=Source("shifts"), + joins=[ + JoinSpec( + kind="cross_keyed", + right_collection_field="employees", + left_key=RefExpr("left", "employee_id"), + right_key=RefExpr("right", "id"), + ) + ], + filters=[ + FilterSpec( + predicate=CallExpr( + "contains", + [RefExpr("right", "skills"), RefExpr("left", "required_skill")], + ) + ) + ], + impact=ImpactSpec("penalize", "hard", 1), + ), + ConstraintDef( + name="assigned", + source=Source("shifts"), + filters=[ + FilterSpec( + predicate=CompareExpr( + "!=", + left=RefExpr("left", "employee_id"), + right=ConstExpr(None), + ) + ) + ], + impact=ImpactSpec("reward", "soft", 1), + ), + ] + ), + 
solver=SolverDef(termination=TerminationDef(step_count_limit=1000)),
+    )
+
+
+class TestCodegen(unittest.TestCase):
+    def test_generate_rust_module_contains_core_builders(self) -> None:
+        text = generate_rust_module(_schedule_model())
+
+        self.assertIn("#[planning_solution]", text)
+        self.assertIn("ConstraintFactory::<Schedule, HardSoftScore>::new()", text)
+        self.assertIn('.named("required_skill")', text)
+        self.assertIn('.reward_soft(1)', text)
+
+    def test_write_project_emits_cargo_and_lib(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            out = write_rust_project(_schedule_model(), Path(tmp), crate_name="demo_py_model")
+            self.assertTrue(out.cargo_toml.exists())
+            self.assertTrue(out.lib_rs.exists())
+            self.assertIn('name = "demo_py_model"', out.cargo_toml.read_text())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/tests/test_model_ir.py b/python/tests/test_model_ir.py
new file mode 100644
index 00000000..44e10770
--- /dev/null
+++ b/python/tests/test_model_ir.py
@@ -0,0 +1,103 @@
+import unittest
+
+from solverforge_ir.model import (
+    CompareExpr,
+    ConstExpr,
+    ConstraintDef,
+    ConstraintSetDef,
+    EntityDef,
+    FactDef,
+    FieldDef,
+    FilterSpec,
+    ImpactSpec,
+    JoinSpec,
+    ModelDef,
+    ModelValidationError,
+    SolutionDef,
+    SolverDef,
+    Source,
+    TerminationDef,
+    lambda_to_expr,
+    validate_model,
+)
+
+
+def _valid_model() -> ModelDef:
+    employee = FactDef(
+        name="Employee",
+        fields=[FieldDef("id", "i64"), FieldDef("skills", "Vec[str]")],
+    )
+    shift = EntityDef(
+        name="Shift",
+        planning_id_field="id",
+        fields=[
+            FieldDef("id", "i64"),
+            FieldDef("employee_id", "Option[i64]", nullable=True),
+            FieldDef("required_skill", "str"),
+        ],
+    )
+
+    required_skill = ConstraintDef(
+        name="required_skill",
+        source=Source("shifts"),
+        joins=[
+            JoinSpec(
+                kind="cross_keyed",
+                right_collection_field="employees",
+                left_key=ConstExpr("shift.employee_id"),
+                right_key=ConstExpr("employee.id"),
+            )
+        ],
+        filters=[FilterSpec(predicate=ConstExpr(True))],
+
impact=ImpactSpec(impact="penalize", level="hard", weight=1), + ) + + return ModelDef( + facts=[employee], + entities=[shift], + solution=SolutionDef( + name="Schedule", + score_kind="hard_soft", + entity_collections={"shifts": "Shift"}, + fact_collections={"employees": "Employee"}, + ), + constraints=ConstraintSetDef([required_skill]), + solver=SolverDef(TerminationDef(step_count_limit=10000)), + ) + + +def shift_unassigned(shift): + return shift.employee_id == None + + +class TestModelIr(unittest.TestCase): + def test_lambda_lowering_comparison(self) -> None: + expr = lambda_to_expr(shift_unassigned, aliases=["shift"]) + self.assertIsInstance(expr, CompareExpr) + self.assertEqual(expr.op, "==") + + def test_model_validation_ok(self) -> None: + model = _valid_model() + validate_model(model) + + def test_model_validation_unknown_collection(self) -> None: + model = _valid_model() + bad = ConstraintDef( + name="bad", + source=Source("missing_collection"), + impact=ImpactSpec("penalize", "hard", 1), + ) + broken = ModelDef( + facts=model.facts, + entities=model.entities, + solution=model.solution, + constraints=ConstraintSetDef(model.constraints.constraints + [bad]), + solver=model.solver, + ) + + with self.assertRaises(ModelValidationError): + validate_model(broken) + + +if __name__ == "__main__": + unittest.main() From 4386f63d8383efed58b111ec334de198c2df2498 Mon Sep 17 00:00:00 2001 From: Vittorio Distefano Date: Sun, 12 Apr 2026 16:30:26 +0200 Subject: [PATCH 2/2] docs(python): keep Path 2 design docs-only --- README.md | 4 +- docs/python-model-ir.md | 32 +-- docs/python-path2-postmortem.md | 13 +- python/solverforge_ir/__init__.py | 55 ------ python/solverforge_ir/codegen.py | 317 ------------------------------ python/solverforge_ir/model.py | 305 ---------------------------- python/tests/test_codegen.py | 112 ----------- python/tests/test_model_ir.py | 103 ---------- 8 files changed, 32 insertions(+), 909 deletions(-) delete mode 100644 
python/solverforge_ir/__init__.py delete mode 100644 python/solverforge_ir/codegen.py delete mode 100644 python/solverforge_ir/model.py delete mode 100644 python/tests/test_codegen.py delete mode 100644 python/tests/test_model_ir.py diff --git a/README.md b/README.md index d22141ac..ac08a570 100644 --- a/README.md +++ b/README.md @@ -45,8 +45,8 @@ The current templates cover standard-variable and list-heavy planning models, an - `README.md` is the user-facing entry point for the workspace and generated-project integration model. - `docs/extend-solver.md` and `docs/extend-domain.md` cover scaffold extension workflows. - `docs/lifecycle-pause-resume-contract.md` defines the retained lifecycle contract, including exact pause/resume semantics, snapshot identity, and terminal-state cleanup rules. -- `docs/python-model-ir.md` outlines the Python-first declarative model IR prototype that lowers into typed Rust solver code. -- `docs/python-path2-postmortem.md` compares Path 2 against the removed `solverforge-py` experiment and records hard guardrails. +- `docs/python-model-ir.md` outlines the proposed Python-first declarative model IR and the intended lowering contract into typed Rust solver code. +- `docs/python-path2-postmortem.md` compares Path 2 against the removed `solverforge-py` experiment, records hard guardrails, and explains why any implementation should live outside this workspace. - `docs/typed-contract-audit.md` records the current neutral selector and extractor naming model, including the `EntityCollectionExtractor`, `ValueSelector`, and `MoveSelector` surface adopted in `0.7.0`. - `crates/*/WIREFRAME.md` files are the canonical public API maps for each crate. - `AGENTS.md` defines repository-level engineering and documentation expectations for coding agents. 
diff --git a/docs/python-model-ir.md b/docs/python-model-ir.md index ee366099..6d46b833 100644 --- a/docs/python-model-ir.md +++ b/docs/python-model-ir.md @@ -1,6 +1,6 @@ # Python Model IR (Path 2: Codegen + Compile) -This document defines the Python-first model surface that lowers into typed SolverForge Rust code, then compiles as a Rust crate. +This document defines the proposed Python-first model surface that should lower into typed SolverForge Rust code, then compile as a Rust crate in a standalone integration repository. ## Goals @@ -8,14 +8,20 @@ This document defines the Python-first model surface that lowers into typed Solv - Let Python users model the same planning constructs and lifecycle workflows. - Keep the Python/Rust bridge thin: compile once, run in Rust, stream lifecycle events. -## Modules - ## Historical Context Path 2 guardrails were derived from the removed `solverforge-py` experiment; see `docs/python-path2-postmortem.md`. -- IR schema and validation: `python/solverforge_ir/model.py` -- Rust code generation and project writer: `python/solverforge_ir/codegen.py` +This workspace intentionally keeps Path 2 at the documentation level. Any Python implementation should live outside the SolverForge Rust workspace and consume the public `solverforge` API as a client. 
+ +## Planned Modules + +A standalone Python integration should roughly split into: + +- IR schema and validation +- Expression lowering +- Rust code generation and project writing +- Build/runtime bridge around compiled generated crates ## Design @@ -36,7 +42,7 @@ Expressions are represented as an AST (not executable Python callbacks): ## Lambda Lowering -`lambda_to_expr(fn, aliases)` lowers a restricted subset of Python lambda/function syntax into the expression AST: +A convenience helper such as `lambda_to_expr(fn, aliases)` can lower a restricted subset of Python lambda/function syntax into the expression AST: - Attribute references from known stream aliases - `==`, `!=`, `<`, `<=`, `>`, `>=` @@ -47,7 +53,7 @@ Unsupported constructs fail fast with `LambdaLoweringError`. ## Validation -`validate_model(model)` performs structural validation: +`validate_model(model)` should perform structural validation: - Unique entity/fact names - Solution collection references target known entities/facts @@ -56,19 +62,19 @@ Unsupported constructs fail fast with `LambdaLoweringError`. ## Code Generation (Path 2) -`generate_rust_module(model)` emits Rust source with: +A generator such as `generate_rust_module(model)` should emit Rust source with: - Domain structs annotated by `#[problem_fact]`, `#[planning_entity]`, `#[planning_solution]` - Typed `define_constraints()` function using `ConstraintFactory` and fluent stream builders - Join lowering for `self_equal`, `cross_keyed`, and `cross_predicate` - Filter/impact/name lowering per constraint -`write_rust_project(model, out_dir, crate_name)` writes a compilable crate: +A project writer such as `write_rust_project(model, out_dir, crate_name)` should write a compilable crate: - `Cargo.toml` - `src/lib.rs` -Returned metadata (`GeneratedRustProject`) points to generated paths. +Returned metadata should point to generated paths and build artifacts. 
## Intended Lowering Contract @@ -84,6 +90,6 @@ This keeps solving and scoring in Rust while preserving Python modeling ergonomi ## Current Limitations -- Codegen currently targets common standard-variable patterns. -- Advanced list-variable selectors/phases are not yet emitted. -- Project writing creates a Rust crate artifact; packaging via PyO3/maturin is a follow-up step. +- The first production scope should target common standard-variable patterns only. +- Advanced list-variable selectors/phases should be designed as a separate lowering track, not implied by the initial IR. +- Packaging, native build/import flow, and lifecycle bridging should live in the standalone Python integration repo rather than this Rust workspace. diff --git a/docs/python-path2-postmortem.md b/docs/python-path2-postmortem.md index 4e7d09ed..4ee4cc3c 100644 --- a/docs/python-path2-postmortem.md +++ b/docs/python-path2-postmortem.md @@ -1,6 +1,6 @@ # Python Path 2 vs Historical `solverforge-py` (Postmortem) -This note compares the current Python Path 2 direction (IR -> Rust codegen -> compile) with the removed `crates/solverforge-py` experiment. +This note compares the proposed Python Path 2 direction (IR -> Rust codegen -> compile) with the removed `crates/solverforge-py` experiment. ## Historical Reference Point @@ -68,12 +68,21 @@ Path 2 intentionally avoids the above failure modes: - No Python callback execution in hot scoring/move loops. - Generated code must target the same public SolverForge contracts used by Rust users. +## Repository Boundary + +This workspace should remain Rust-first and docs-first for Path 2. + +- Keep the design and guardrails in this repository. +- Build any Python implementation in a standalone repository that depends on the published/public `solverforge` surface. +- Do not reintroduce a `python/` implementation subtree or a second solver runtime inside this workspace. 
+ ## Current Status -The repository now contains a Python IR + codegen prototype under `python/solverforge_ir` and docs in `docs/python-model-ir.md`. +This repository now documents the intended direction in `docs/python-model-ir.md` and records the architectural guardrails here. Remaining work includes: +- creating the standalone Python integration repository, - pyproject/maturin packaging for produced crates, - list-variable parity, - lifecycle bridge that forwards retained runtime events as web/SSE-friendly payloads. diff --git a/python/solverforge_ir/__init__.py b/python/solverforge_ir/__init__.py deleted file mode 100644 index b8424b99..00000000 --- a/python/solverforge_ir/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -from .codegen import GeneratedRustProject, generate_rust_module, write_rust_project -from .model import ( - BoolExpr, - CallExpr, - CompareExpr, - ConstraintDef, - ConstraintSetDef, - ConstExpr, - EntityDef, - FactDef, - FieldDef, - FilterSpec, - ImpactSpec, - JoinSpec, - LambdaLoweringError, - ModelDef, - ModelValidationError, - RefExpr, - SolutionDef, - SolverDef, - Source, - TerminationDef, - VariableDef, - lambda_to_expr, - validate_model, -) - -__all__ = [ - "BoolExpr", - "CallExpr", - "CompareExpr", - "ConstraintDef", - "ConstraintSetDef", - "ConstExpr", - "EntityDef", - "FactDef", - "FieldDef", - "FilterSpec", - "ImpactSpec", - "JoinSpec", - "LambdaLoweringError", - "ModelDef", - "ModelValidationError", - "RefExpr", - "SolutionDef", - "SolverDef", - "Source", - "TerminationDef", - "VariableDef", - "lambda_to_expr", - "validate_model", - "GeneratedRustProject", - "generate_rust_module", - "write_rust_project", -] diff --git a/python/solverforge_ir/codegen.py b/python/solverforge_ir/codegen.py deleted file mode 100644 index 83cddb04..00000000 --- a/python/solverforge_ir/codegen.py +++ /dev/null @@ -1,317 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from pathlib import Path -from typing import Iterable - 
-from .model import ( - BoolExpr, - CallExpr, - CompareExpr, - ConstraintDef, - ConstExpr, - EntityDef, - Expr, - ImpactSpec, - JoinSpec, - ModelDef, - RefExpr, -) - - -@dataclass(frozen=True) -class GeneratedRustProject: - crate_dir: Path - cargo_toml: Path - lib_rs: Path - - -_SCORE_MAP = { - "soft": "SoftScore", - "hard_soft": "HardSoftScore", - "hard_medium_soft": "HardMediumSoftScore", - "hard_soft_decimal": "HardSoftDecimalScore", - "bendable": "BendableScore<1, 1>", -} - - -def generate_rust_module(model: ModelDef) -> str: - score_ty = _score_type(model.solution.score_kind) - lines: list[str] = [] - lines.extend( - [ - "use solverforge::prelude::*;", - "use solverforge::stream::{joiner::*, ConstraintFactory};", - "", - ] - ) - - lines.extend(_emit_facts(model.facts)) - lines.extend(_emit_entities(model.entities)) - lines.extend(_emit_solution(model)) - - lines.append(f"use {model.solution.name}ConstraintStreams;") - lines.append("") - lines.append( - f"pub fn define_constraints() -> impl ConstraintSet<{model.solution.name}, {score_ty}> {{" - ) - - for idx, constraint in enumerate(model.constraints.constraints): - c_name = f"c_{idx}" - lines.extend(_emit_constraint_builder(model, constraint, c_name, score_ty)) - - tuple_items = ", ".join([f"c_{i}" for i in range(len(model.constraints.constraints))]) - lines.append(f" ({tuple_items})") - lines.append("}") - lines.append("") - - return "\n".join(lines) - - -def write_rust_project(model: ModelDef, out_dir: Path, crate_name: str = "solverforge_py_model") -> GeneratedRustProject: - crate_dir = out_dir / crate_name - src_dir = crate_dir / "src" - src_dir.mkdir(parents=True, exist_ok=True) - - cargo_toml = crate_dir / "Cargo.toml" - lib_rs = src_dir / "lib.rs" - - cargo_toml.write_text( - "\n".join( - [ - "[package]", - f'name = "{crate_name}"', - 'version = "0.1.0"', - 'edition = "2021"', - "", - "[dependencies]", - 'solverforge = { version = "0.8", features = ["console"] }', - "", - ] - ) - ) - 
lib_rs.write_text(generate_rust_module(model)) - - return GeneratedRustProject(crate_dir=crate_dir, cargo_toml=cargo_toml, lib_rs=lib_rs) - - -def _emit_facts(facts: Iterable) -> list[str]: - lines: list[str] = [] - for fact in facts: - lines.append("#[problem_fact]") - lines.append(f"pub struct {fact.name} {{") - for field in fact.fields: - rust_ty = _map_type(field.type_ref) - lines.append(f" pub {field.name}: {rust_ty},") - lines.append("}") - lines.append("") - return lines - - -def _emit_entities(entities: Iterable[EntityDef]) -> list[str]: - lines: list[str] = [] - for entity in entities: - variable_names = {v.name for v in entity.variables} - - lines.append("#[planning_entity]") - lines.append(f"pub struct {entity.name} {{") - for field in entity.fields: - if field.name == entity.planning_id_field: - lines.append(" #[planning_id]") - if field.name in variable_names: - lines.append(" #[planning_variable]") - rust_ty = _map_type(field.type_ref) - lines.append(f" pub {field.name}: {rust_ty},") - lines.append("}") - lines.append("") - return lines - - -def _emit_solution(model: ModelDef) -> list[str]: - score_ty = _score_type(model.solution.score_kind) - lines: list[str] = ["#[planning_solution]", f"pub struct {model.solution.name} {{"] - - for field_name, fact_name in model.solution.fact_collections.items(): - lines.append(" #[problem_fact_collection]") - lines.append(f" pub {field_name}: Vec<{fact_name}>,") - - for field_name, entity_name in model.solution.entity_collections.items(): - lines.append(" #[planning_entity_collection]") - lines.append(f" pub {field_name}: Vec<{entity_name}>,") - - lines.append(" #[planning_score]") - lines.append(f" pub score: Option<{score_ty}>,") - lines.append("}") - lines.append("") - return lines - - -def _emit_constraint_builder(model: ModelDef, constraint: ConstraintDef, c_name: str, score_ty: str) -> list[str]: - solution_name = model.solution.name - source_entity_name = 
model.solution.entity_collections[constraint.source.collection_field] - - lines = [ - f" let {c_name} = ConstraintFactory::<{solution_name}, {score_ty}>::new()", - f" .{constraint.source.collection_field}()", - ] - - current_aliases: list[str] = ["left"] - current_entity_names: list[str] = [source_entity_name] - - for join in constraint.joins: - right_entity = None - if join.right_collection_field: - right_entity = model.solution.entity_collections.get(join.right_collection_field) - if right_entity is None: - right_entity = model.solution.fact_collections[join.right_collection_field] - - if join.kind == "self_equal": - left_ty = current_entity_names[0] - left_expr = _rust_expr(join.left_key or RefExpr("left", ""), {"left": "left"}) - lines.append(f" .join(equal(|left: &{left_ty}| {left_expr}))") - current_aliases.append("right") - current_entity_names.append(left_ty) - - elif join.kind == "cross_keyed": - if not join.right_collection_field or not join.left_key or not join.right_key: - raise ValueError("cross_keyed join requires right_collection_field, left_key, right_key") - - left_ty = current_entity_names[0] - right_ty = right_entity - left_key_expr = _rust_expr(join.left_key, {"left": "left"}) - right_key_expr = _rust_expr(join.right_key, {"right": "right"}) - lines.extend( - [ - " .join((", - f" |s: &{solution_name}| &s.{join.right_collection_field},", - " equal_bi(", - f" |left: &{left_ty}| {left_key_expr},", - f" |right: &{right_ty}| {right_key_expr},", - " ),", - " ))", - ] - ) - current_aliases.append("right") - current_entity_names.append(right_ty) - - elif join.kind == "cross_predicate": - if not join.right_collection_field or not join.predicate: - raise ValueError("cross_predicate join requires right_collection_field and predicate") - left_ty = current_entity_names[0] - right_ty = right_entity - pred = _rust_expr(join.predicate, {"left": "left", "right": "right"}) - lines.extend( - [ - " .join((", - f" |s: &{solution_name}| 
&s.{join.right_collection_field},", - f" |left: &{left_ty}, right: &{right_ty}| {pred},", - " ))", - ] - ) - current_aliases.append("right") - current_entity_names.append(right_ty) - - for f in constraint.filters: - args = [] - alias_map: dict[str, str] = {} - for idx, alias in enumerate(current_aliases): - var_name = alias if idx == 0 else f"{alias}{idx}" - ty = current_entity_names[idx] - args.append(f"{var_name}: &{ty}") - alias_map[alias] = var_name - - pred = _rust_expr(f.predicate, alias_map) - args_sig = ", ".join(args) - lines.append(f" .filter(|{args_sig}| {pred})") - - lines.append(f" .{_impact_method(constraint.impact)}({constraint.impact.weight})") - lines.append(f" .named(\"{constraint.name}\");") - lines.append("") - return lines - - -def _impact_method(impact: ImpactSpec) -> str: - prefix = "penalize" if impact.impact == "penalize" else "reward" - return f"{prefix}_{impact.level}" - - -def _rust_expr(expr: Expr, alias_map: dict[str, str]) -> str: - if isinstance(expr, ConstExpr): - return _const_to_rust(expr.value) - - if isinstance(expr, RefExpr): - base = alias_map.get(expr.stream_alias, expr.stream_alias) - if expr.field_path: - return f"{base}.{expr.field_path}" - return base - - if isinstance(expr, CompareExpr): - left = _rust_expr(expr.left, alias_map) - right = _rust_expr(expr.right, alias_map) - return f"({left} {expr.op} {right})" - - if isinstance(expr, BoolExpr): - if expr.op == "not": - return f"(!{_rust_expr(expr.args[0], alias_map)})" - joiner = " && " if expr.op == "and" else " || " - return f"({joiner.join(_rust_expr(a, alias_map) for a in expr.args)})" - - if isinstance(expr, CallExpr): - args = [_rust_expr(a, alias_map) for a in expr.args] - if expr.fn == "contains": - if len(args) != 2: - raise ValueError("contains expects 2 args") - return f"{args[0]}.contains(&{args[1]})" - if expr.fn == "len": - if len(args) != 1: - raise ValueError("len expects 1 arg") - return f"{args[0]}.len()" - if expr.fn == "overlaps": - if len(args) != 4: 
- raise ValueError("overlaps expects 4 args: a_start,a_end,b_start,b_end") - return f"(({args[0]} < {args[3]}) && ({args[2]} < {args[1]}))" - raise ValueError(f"Unsupported call: {expr.fn}") - - raise TypeError(f"Unsupported expression type: {type(expr).__name__}") - - -def _const_to_rust(value: object) -> str: - if value is None: - return "None" - if value is True: - return "true" - if value is False: - return "false" - if isinstance(value, str): - return f'"{value}".to_string()' - return repr(value) - - -def _score_type(score_kind: str) -> str: - try: - return _SCORE_MAP[score_kind] - except KeyError as e: - raise ValueError(f"Unsupported score kind '{score_kind}'.") from e - - -def _map_type(type_ref: str) -> str: - cleaned = type_ref.strip() - - if cleaned.startswith("Option[") and cleaned.endswith("]"): - inner = cleaned[len("Option[") : -1] - return f"Option<{_map_type(inner)}>" - - if cleaned.startswith("Vec[") and cleaned.endswith("]"): - inner = cleaned[len("Vec[") : -1] - return f"Vec<{_map_type(inner)}>" - - primitives = { - "str": "String", - "string": "String", - "i64": "i64", - "int": "i64", - "f64": "f64", - "float": "f64", - "bool": "bool", - } - return primitives.get(cleaned, cleaned) diff --git a/python/solverforge_ir/model.py b/python/solverforge_ir/model.py deleted file mode 100644 index b0b29b61..00000000 --- a/python/solverforge_ir/model.py +++ /dev/null @@ -1,305 +0,0 @@ -from __future__ import annotations - -import ast -import inspect -import textwrap -from dataclasses import asdict, dataclass, field -from typing import Any, Callable, Literal, Optional, Union - -ScoreKind = Literal["soft", "hard_soft", "hard_medium_soft", "bendable", "hard_soft_decimal"] -VarKind = Literal["basic", "list"] -JoinKind = Literal["self_equal", "cross_keyed", "cross_predicate"] -Impact = Literal["penalize", "reward"] -Level = Literal["hard", "medium", "soft"] - - -@dataclass(frozen=True) -class FieldDef: - name: str - type_ref: str - nullable: bool = False - - 
-@dataclass(frozen=True) -class FactDef: - name: str - fields: list[FieldDef] - - -@dataclass(frozen=True) -class VariableDef: - name: str - value_type: str - kind: VarKind - pinned_field: Optional[str] = None - - -@dataclass(frozen=True) -class EntityDef: - name: str - planning_id_field: str - fields: list[FieldDef] - variables: list[VariableDef] = field(default_factory=list) - - -@dataclass(frozen=True) -class SolutionDef: - name: str - score_kind: ScoreKind - entity_collections: dict[str, str] - fact_collections: dict[str, str] - - -@dataclass(frozen=True) -class RefExpr: - stream_alias: str - field_path: str - - -@dataclass(frozen=True) -class ConstExpr: - value: Any - - -@dataclass(frozen=True) -class CallExpr: - fn: str - args: list["Expr"] - - -@dataclass(frozen=True) -class CompareExpr: - op: Literal["==", "!=", "<", "<=", ">", ">="] - left: "Expr" - right: "Expr" - - -@dataclass(frozen=True) -class BoolExpr: - op: Literal["and", "or", "not"] - args: list["Expr"] - - -Expr = Union[RefExpr, ConstExpr, CallExpr, CompareExpr, BoolExpr] - - -@dataclass(frozen=True) -class Source: - collection_field: str - - -@dataclass(frozen=True) -class JoinSpec: - kind: JoinKind - right_collection_field: Optional[str] = None - left_key: Optional[Expr] = None - right_key: Optional[Expr] = None - predicate: Optional[Expr] = None - - -@dataclass(frozen=True) -class FilterSpec: - predicate: Expr - - -@dataclass(frozen=True) -class ImpactSpec: - impact: Impact - level: Level - weight: int = 1 - - -@dataclass(frozen=True) -class ConstraintDef: - name: str - source: Source - joins: list[JoinSpec] = field(default_factory=list) - filters: list[FilterSpec] = field(default_factory=list) - impact: ImpactSpec = field(default_factory=lambda: ImpactSpec("penalize", "hard", 1)) - - -@dataclass(frozen=True) -class ConstraintSetDef: - constraints: list[ConstraintDef] - - -@dataclass(frozen=True) -class TerminationDef: - time_limit_ms: Optional[int] = None - step_count_limit: Optional[int] = 
None - unimproved_time_limit_ms: Optional[int] = None - unimproved_step_limit: Optional[int] = None - - -@dataclass(frozen=True) -class SolverDef: - termination: TerminationDef - - -@dataclass(frozen=True) -class ModelDef: - facts: list[FactDef] - entities: list[EntityDef] - solution: SolutionDef - constraints: ConstraintSetDef - solver: SolverDef - - def to_dict(self) -> dict[str, Any]: - return asdict(self) - - -_ALLOWED_CALLS = { - "contains", - "overlaps", - "len", -} - - -class LambdaLoweringError(ValueError): - pass - - -class ModelValidationError(ValueError): - pass - - -def lambda_to_expr(fn: Callable[..., Any], aliases: list[str]) -> Expr: - source = textwrap.dedent(inspect.getsource(fn)) - tree = ast.parse(source) - expr_node: ast.AST | None = None - - if isinstance(tree.body[0], ast.Expr) and isinstance(tree.body[0].value, ast.Lambda): - lam = tree.body[0].value - expr_node = lam.body - elif isinstance(tree.body[0], ast.Assign) and isinstance(tree.body[0].value, ast.Lambda): - lam = tree.body[0].value - expr_node = lam.body - elif isinstance(tree.body[0], ast.FunctionDef): - fn_def = tree.body[0] - returns = [n for n in fn_def.body if isinstance(n, ast.Return)] - if len(returns) != 1: - raise LambdaLoweringError("Function must have exactly one return statement.") - expr_node = returns[0].value - else: - raise LambdaLoweringError("Only lambda or single-return function expressions are supported.") - - return _lower_ast_expr(expr_node, aliases) - - -def _lower_ast_expr(node: ast.AST, aliases: list[str]) -> Expr: - if isinstance(node, ast.Constant): - return ConstExpr(node.value) - - if isinstance(node, ast.Name): - if node.id in aliases: - return RefExpr(stream_alias=node.id, field_path="") - raise LambdaLoweringError(f"Unknown name '{node.id}'.") - - if isinstance(node, ast.Attribute): - parts: list[str] = [] - cur = node - while isinstance(cur, ast.Attribute): - parts.append(cur.attr) - cur = cur.value - if isinstance(cur, ast.Name) and cur.id in aliases: 
- return RefExpr(stream_alias=cur.id, field_path=".".join(reversed(parts))) - raise LambdaLoweringError("Attribute access must start from a known stream alias.") - - if isinstance(node, ast.Compare): - if len(node.ops) != 1 or len(node.comparators) != 1: - raise LambdaLoweringError("Only single comparisons are supported.") - op_map = { - ast.Eq: "==", - ast.NotEq: "!=", - ast.Lt: "<", - ast.LtE: "<=", - ast.Gt: ">", - ast.GtE: ">=", - } - for k, v in op_map.items(): - if isinstance(node.ops[0], k): - return CompareExpr( - op=v, - left=_lower_ast_expr(node.left, aliases), - right=_lower_ast_expr(node.comparators[0], aliases), - ) - raise LambdaLoweringError("Unsupported comparison operator.") - - if isinstance(node, ast.BoolOp): - if isinstance(node.op, ast.And): - return BoolExpr(op="and", args=[_lower_ast_expr(v, aliases) for v in node.values]) - if isinstance(node.op, ast.Or): - return BoolExpr(op="or", args=[_lower_ast_expr(v, aliases) for v in node.values]) - raise LambdaLoweringError("Unsupported boolean operator.") - - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): - return BoolExpr(op="not", args=[_lower_ast_expr(node.operand, aliases)]) - - if isinstance(node, ast.Call): - fn_name = _call_name(node.func) - if fn_name not in _ALLOWED_CALLS: - raise LambdaLoweringError(f"Call '{fn_name}' is not in the allowed call whitelist.") - args = [_lower_ast_expr(a, aliases) for a in node.args] - return CallExpr(fn=fn_name, args=args) - - raise LambdaLoweringError(f"Unsupported syntax node: {type(node).__name__}") - - -def _call_name(func_node: ast.AST) -> str: - if isinstance(func_node, ast.Name): - return func_node.id - if isinstance(func_node, ast.Attribute): - return func_node.attr - raise LambdaLoweringError("Unsupported callable expression.") - - -def validate_model(model: ModelDef) -> None: - entity_names = {e.name for e in model.entities} - fact_names = {f.name for f in model.facts} - - if len(entity_names) != len(model.entities): - raise 
ModelValidationError("Duplicate entity names are not allowed.") - if len(fact_names) != len(model.facts): - raise ModelValidationError("Duplicate fact names are not allowed.") - - for field_name, entity_name in model.solution.entity_collections.items(): - if entity_name not in entity_names: - raise ModelValidationError( - f"Solution entity collection '{field_name}' references unknown entity '{entity_name}'." - ) - - for field_name, fact_name in model.solution.fact_collections.items(): - if fact_name not in fact_names: - raise ModelValidationError( - f"Solution fact collection '{field_name}' references unknown fact '{fact_name}'." - ) - - all_collections = set(model.solution.entity_collections) | set(model.solution.fact_collections) - - for constraint in model.constraints.constraints: - if constraint.source.collection_field not in all_collections: - raise ModelValidationError( - f"Constraint '{constraint.name}' source references unknown collection " - f"'{constraint.source.collection_field}'." - ) - if not constraint.name: - raise ModelValidationError("Constraint name must not be empty.") - - for join in constraint.joins: - if join.kind in {"cross_keyed", "cross_predicate"} and not join.right_collection_field: - raise ModelValidationError( - f"Constraint '{constraint.name}' has join '{join.kind}' without right collection." - ) - if join.right_collection_field and join.right_collection_field not in all_collections: - raise ModelValidationError( - f"Constraint '{constraint.name}' join references unknown collection " - f"'{join.right_collection_field}'." - ) - if join.kind == "cross_keyed" and (join.left_key is None or join.right_key is None): - raise ModelValidationError( - f"Constraint '{constraint.name}' keyed join requires left_key and right_key." - ) - if join.kind == "cross_predicate" and join.predicate is None: - raise ModelValidationError( - f"Constraint '{constraint.name}' predicate join requires predicate expression." 
- ) diff --git a/python/tests/test_codegen.py b/python/tests/test_codegen.py deleted file mode 100644 index c92a83f8..00000000 --- a/python/tests/test_codegen.py +++ /dev/null @@ -1,112 +0,0 @@ -import tempfile -import unittest -from pathlib import Path - -from solverforge_ir.codegen import generate_rust_module, write_rust_project -from solverforge_ir.model import ( - CallExpr, - CompareExpr, - ConstraintDef, - ConstraintSetDef, - ConstExpr, - EntityDef, - FactDef, - FieldDef, - FilterSpec, - ImpactSpec, - JoinSpec, - ModelDef, - RefExpr, - SolutionDef, - SolverDef, - Source, - TerminationDef, - VariableDef, -) - - -def _schedule_model() -> ModelDef: - return ModelDef( - facts=[ - FactDef(name="Employee", fields=[FieldDef("id", "i64"), FieldDef("skills", "Vec[str]")]) - ], - entities=[ - EntityDef( - name="Shift", - planning_id_field="id", - fields=[ - FieldDef("id", "i64"), - FieldDef("employee_id", "Option[i64]", nullable=True), - FieldDef("required_skill", "str"), - ], - variables=[VariableDef(name="employee_id", value_type="Option[i64]", kind="basic")], - ) - ], - solution=SolutionDef( - name="Schedule", - score_kind="hard_soft", - entity_collections={"shifts": "Shift"}, - fact_collections={"employees": "Employee"}, - ), - constraints=ConstraintSetDef( - constraints=[ - ConstraintDef( - name="required_skill", - source=Source("shifts"), - joins=[ - JoinSpec( - kind="cross_keyed", - right_collection_field="employees", - left_key=RefExpr("left", "employee_id"), - right_key=RefExpr("right", "id"), - ) - ], - filters=[ - FilterSpec( - predicate=CallExpr( - "contains", - [RefExpr("right", "skills"), RefExpr("left", "required_skill")], - ) - ) - ], - impact=ImpactSpec("penalize", "hard", 1), - ), - ConstraintDef( - name="assigned", - source=Source("shifts"), - filters=[ - FilterSpec( - predicate=CompareExpr( - "!=", - left=RefExpr("left", "employee_id"), - right=ConstExpr(None), - ) - ) - ], - impact=ImpactSpec("reward", "soft", 1), - ), - ] - ), - 
solver=SolverDef(termination=TerminationDef(step_count_limit=1000)), - ) - - -class TestCodegen(unittest.TestCase): - def test_generate_rust_module_contains_core_builders(self) -> None: - text = generate_rust_module(_schedule_model()) - - self.assertIn("#[planning_solution]", text) - self.assertIn("ConstraintFactory::::new()", text) - self.assertIn('.named("required_skill")', text) - self.assertIn('.reward_soft(1)', text) - - def test_write_project_emits_cargo_and_lib(self) -> None: - with tempfile.TemporaryDirectory() as tmp: - out = write_rust_project(_schedule_model(), Path(tmp), crate_name="demo_py_model") - self.assertTrue(out.cargo_toml.exists()) - self.assertTrue(out.lib_rs.exists()) - self.assertIn('name = "demo_py_model"', out.cargo_toml.read_text()) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/tests/test_model_ir.py b/python/tests/test_model_ir.py deleted file mode 100644 index 44e10770..00000000 --- a/python/tests/test_model_ir.py +++ /dev/null @@ -1,103 +0,0 @@ -import unittest - -from solverforge_ir.model import ( - CompareExpr, - ConstExpr, - ConstraintDef, - ConstraintSetDef, - EntityDef, - FactDef, - FieldDef, - FilterSpec, - ImpactSpec, - JoinSpec, - ModelDef, - ModelValidationError, - SolutionDef, - SolverDef, - Source, - TerminationDef, - lambda_to_expr, - validate_model, -) - - -def _valid_model() -> ModelDef: - employee = FactDef( - name="Employee", - fields=[FieldDef("id", "i64"), FieldDef("skills", "Vec[str]")], - ) - shift = EntityDef( - name="Shift", - planning_id_field="id", - fields=[ - FieldDef("id", "i64"), - FieldDef("employee_id", "Option[i64]", nullable=True), - FieldDef("required_skill", "str"), - ], - ) - - required_skill = ConstraintDef( - name="required_skill", - source=Source("shifts"), - joins=[ - JoinSpec( - kind="cross_keyed", - right_collection_field="employees", - left_key=ConstExpr("shift.employee_id"), - right_key=ConstExpr("employee.id"), - ) - ], - 
filters=[FilterSpec(predicate=ConstExpr(True))], - impact=ImpactSpec(impact="penalize", level="hard", weight=1), - ) - - return ModelDef( - facts=[employee], - entities=[shift], - solution=SolutionDef( - name="Schedule", - score_kind="hard_soft", - entity_collections={"shifts": "Shift"}, - fact_collections={"employees": "Employee"}, - ), - constraints=ConstraintSetDef([required_skill]), - solver=SolverDef(TerminationDef(step_count_limit=10000)), - ) - - -def shift_unassigned(shift): - return shift.employee_id == None - - -class TestModelIr(unittest.TestCase): - def test_lambda_lowering_comparison(self) -> None: - expr = lambda_to_expr(shift_unassigned, aliases=["shift"]) - self.assertIsInstance(expr, CompareExpr) - self.assertEqual(expr.op, "==") - - def test_model_validation_ok(self) -> None: - model = _valid_model() - validate_model(model) - - def test_model_validation_unknown_collection(self) -> None: - model = _valid_model() - bad = ConstraintDef( - name="bad", - source=Source("missing_collection"), - impact=ImpactSpec("penalize", "hard", 1), - ) - broken = ModelDef( - facts=model.facts, - entities=model.entities, - solution=model.solution, - constraints=ConstraintSetDef(model.constraints.constraints + [bad]), - solver=model.solver, - ) - - with self.assertRaises(ModelValidationError): - validate_model(broken) - - -if __name__ == "__main__": - unittest.main()