# ---------------------------------------------------------------------------
# diff --git a/effectful/handlers/llm/encodable_type.py
#            b/effectful/handlers/llm/encodable_type.py
# new file mode 100644, index 00000000..32c6f51e
# ---------------------------------------------------------------------------
"""Encodable type for LLM-synthesized classes."""

import ast
import collections
import ctypes
import inspect
import linecache
import sys
import textwrap
import types
import typing
from collections import ChainMap
from typing import Any

import pydantic
from pydantic import Field

from effectful.handlers.llm.encoding import EncodableAs, type_to_encodable_type
from effectful.handlers.llm.providers import OpenAIMessageContentListBlock
from effectful.handlers.llm.synthesis import SynthesisError


class _PyMappingProxyObject(ctypes.Structure):
    """Internal ctypes structure to access the underlying dict of a mappingproxy.

    Retained only so existing imports of this name keep working; ``decode``
    no longer pokes at interpreter memory and sets ``__firstlineno__`` through
    ``type.__setattr__`` instead.
    """

    _fields_ = [
        ("ob_refcnt", ctypes.c_ssize_t),
        ("ob_type", ctypes.py_object),
        ("mapping", ctypes.py_object),
    ]


class SynthesizedType(pydantic.BaseModel):
    """Structured output for type/class synthesis.

    Pydantic model representing synthesized class code with type name and module code.
    """

    type_name: str = Field(
        ...,
        description="The name of the class that satisfies the specification",
    )
    module_code: str = Field(
        ...,
        description="Complete Python module code with the class definition (no imports needed)",
    )


def _class_first_lineno(node: ast.ClassDef) -> int:
    """Return the first source line of a class definition, decorators included."""
    return min([node.lineno, *(deco.lineno for deco in node.decorator_list)])


def _inject_firstlineno(tree: ast.Module, type_name: str) -> int:
    """Insert ``__firstlineno__ = <line>`` into every class body of ``tree``.

    Returns the first line of the first class named ``type_name`` (in
    ``ast.walk`` order), or ``1`` when no such class exists.
    """
    target_lineno: int | None = None
    # Collect the classes up front so the tree is not mutated while walking it.
    class_nodes = [n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]
    for node in class_nodes:
        first_line = _class_first_lineno(node)
        if target_lineno is None and node.name == type_name:
            target_lineno = first_line

        # Give the synthetic statement a complete, self-consistent position:
        # an end line smaller than the start line is an invalid AST location.
        assign = ast.Assign(
            targets=[ast.Name(id="__firstlineno__", ctx=ast.Store())],
            value=ast.Constant(value=first_line),
            lineno=node.lineno,
            col_offset=0,
            end_lineno=node.lineno,
            end_col_offset=0,
        )
        ast.fix_missing_locations(assign)

        # Keep a docstring as the first statement, otherwise the class would
        # silently lose its ``__doc__``.
        has_docstring = ast.get_docstring(node, clean=False) is not None
        node.body.insert(1 if has_docstring else 0, assign)

    ast.fix_missing_locations(tree)
    return 1 if target_lineno is None else target_lineno


def _unregister_synthesized_module(module_name: str, filename: str) -> None:
    """Drop the ``sys.modules`` and ``linecache`` entries of a failed decode."""
    sys.modules.pop(module_name, None)
    linecache.cache.pop(filename, None)


@type_to_encodable_type.register(type)
class EncodableSynthesizedType(
    EncodableAs[type, SynthesizedType],
):
    """Encodes type to SynthesizedType and vice versa."""

    t = SynthesizedType

    # Counter for unique module names / pseudo-filenames
    _decode_counter: typing.ClassVar[int] = 0

    @classmethod
    def encode(
        cls, vl: type, context: ChainMap[str, Any] | None = None
    ) -> SynthesizedType:
        """Encode a type to a SynthesizedType.

        Extracts the type name and source code. When the source cannot be
        retrieved (``inspect.getsource`` raises ``OSError`` or ``TypeError``,
        e.g. for builtins) a one-line placeholder class is emitted instead.
        ``context`` is accepted for interface compatibility and is not used.
        """
        type_name = vl.__name__
        try:
            source = inspect.getsource(vl)
        except (OSError, TypeError):
            # If we can't get source, create a minimal representation
            source = f"class {type_name}: pass  # Source unavailable"

        return SynthesizedType(
            type_name=type_name, module_code=textwrap.dedent(source).strip()
        )

    @classmethod
    def decode(cls, vl: SynthesizedType) -> type:
        """Decode a SynthesizedType to a type.

        Executes the module code in a fresh module registered in
        ``sys.modules`` and returns the class named ``vl.type_name``.
        Uses the ``_decode_context`` attribute on ``vl`` if present (set by the
        TypeSynthesis handler) to seed the module globals.

        Raises:
            SynthesisError: if the code does not parse, fails while executing,
                does not define ``vl.type_name``, binds that name to a
                non-type, or yields a type that cannot carry the synthesis
                metadata. The temporary module is unregistered in that case.
        """
        context: ChainMap[str, Any] | None = getattr(vl, "_decode_context", None)
        type_name = vl.type_name
        # Always newline-terminated, which is what linecache expects.
        module_code = textwrap.dedent(vl.module_code).strip() + "\n"

        # NOTE: for functions, ``co_filename`` links the object to its source
        # automatically. Classes are located through ``__module__`` ->
        # ``sys.modules[...].__file__`` -> linecache, so all three are set up
        # by hand for inspect.getsource() to work.
        cls._decode_counter += 1
        module_name = (
            f"_llm_effectful_synthesized_types.{type_name}.{cls._decode_counter}"
        )
        # Unique, non-empty pseudo-filename. The angle brackets mark it as not
        # being a real path; an empty ``__file__`` would make inspect treat the
        # class as a builtin.
        filename = f"<{module_name}>"

        # Register source for inspect/linecache (mtime ``None`` = never stale)
        linecache.cache[filename] = (
            len(module_code),
            None,
            module_code.splitlines(keepends=True),
            filename,
        )

        # Create a real module and put it to sys.modules
        mod = types.ModuleType(module_name)
        mod.__file__ = filename
        sys.modules[module_name] = mod

        # globals = module.__dict__ + context
        namespace = mod.__dict__
        namespace["collections"] = collections
        if context is not None:
            namespace.update(context)
        namespace.update({"__name__": module_name, "__file__": filename})
        # ModuleType pre-populates ``__package__`` with None, so a plain
        # setdefault() would never take effect.
        if namespace.get("__package__") is None:
            namespace["__package__"] = module_name.rpartition(".")[0]

        try:
            return cls._exec_and_extract(
                vl, module_code, module_name, filename, namespace
            )
        except SynthesisError:
            # Do not leave a half-initialised module or its source behind.
            _unregister_synthesized_module(module_name, filename)
            raise

    @classmethod
    def _exec_and_extract(
        cls,
        vl: SynthesizedType,
        module_code: str,
        module_name: str,
        filename: str,
        namespace: dict[str, Any],
    ) -> type:
        """Run ``module_code`` in ``namespace`` and return the requested class."""
        type_name = vl.type_name

        try:
            # NOTE: inject __firstlineno__ into class bodies. On Python 3.13+
            # inspect.getsource() reads it from vars(cls).
            tree = ast.parse(module_code)
            firstlineno = _inject_firstlineno(tree, type_name)
            code_obj = compile(tree, filename, "exec")
            # SECURITY: this executes model-generated code with full
            # interpreter privileges; callers must treat the input as trusted
            # or sandbox the process.
            exec(code_obj, namespace, namespace)
        except SyntaxError as exc:
            raise SynthesisError(
                f"Syntax error in generated code: {exc}", module_code
            ) from exc
        except Exception as exc:
            raise SynthesisError(f"Evaluation failed: {exc!r}", module_code) from exc

        if type_name not in namespace:
            public_names = [k for k in namespace if not k.startswith("_")]
            raise SynthesisError(
                f"Type '{type_name}' not found after execution. "
                f"Available names: {public_names}",
                module_code,
            )

        synthesized_type = namespace[type_name]

        if not isinstance(synthesized_type, type):
            raise SynthesisError(
                f"'{type_name}' is not a type, got {type(synthesized_type).__name__}",
                module_code,
            )

        try:
            # Attach source code and module name
            synthesized_type.__source__ = module_code  # type: ignore[attr-defined]
            synthesized_type.__synthesized__ = vl  # type: ignore[attr-defined]
            synthesized_type.__module__ = module_name

            # NOTE: restore __firstlineno__ AFTER the __module__ assignment:
            # in Python 3.13 setting __module__ removes it from vars().
            # type.__setattr__ writes into the class namespace directly and
            # bypasses any metaclass __setattr__ override.
            if "__firstlineno__" not in vars(synthesized_type):
                type.__setattr__(synthesized_type, "__firstlineno__", firstlineno)
        except (AttributeError, TypeError) as exc:
            # e.g. the name was bound to an immutable builtin type
            raise SynthesisError(
                f"Cannot attach synthesis metadata to '{type_name}': {exc!r}",
                module_code,
            ) from exc

        return synthesized_type

    @classmethod
    def serialize(cls, vl: SynthesizedType) -> list[OpenAIMessageContentListBlock]:
        """Serialize ``vl`` as a single text content block holding its JSON dump."""
        return [{"type": "text", "text": vl.model_dump_json()}]


# ---------------------------------------------------------------------------
# diff --git a/tests/test_handlers_llm_encoding.py
#            b/tests/test_handlers_llm_encoding.py
# index 41999a45..742afd78 100644
# Hunks: @@ -1,3 +1,5 @@, @@ -5,7 +7,12 @@, @@ -718,3 +725,249 @@
# The last hunk follows these unchanged context lines of an existing test:
#     assert decoded_from_model == person
#     assert isinstance(decoded_from_model, Person)
#     assert isinstance(decoded_from_model.address, Address)
# ---------------------------------------------------------------------------
import inspect
from collections import ChainMap
from dataclasses import asdict, dataclass
from typing import NamedTuple, TypedDict

import pytest
from PIL import Image

from effectful.handlers.llm.encodable_type import (
    EncodableSynthesizedType,
    SynthesizedType,
)
from effectful.handlers.llm.encoding import type_to_encodable_type
from effectful.handlers.llm.synthesis import SynthesisError
from effectful.ops.types import Operation, Term


class TestEncodableSynthesizedType:
    """Tests for EncodableSynthesizedType encode/decode functionality."""

    def test_decode_simple_class(self):
        """Test decoding a simple class from SynthesizedType."""
        synth = SynthesizedType(
            type_name="Greeter",
            module_code="""\
class Greeter:
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"
""",
        )

        result = EncodableSynthesizedType.decode(synth)

        assert isinstance(result, type)
        assert result.__name__ == "Greeter"

        # Test instantiation and method call
        instance = result()
        assert instance.greet("World") == "Hello, World!"

    def test_decode_with_inheritance(self):
        """Test decoding a class that inherits from a base class in context."""

        class Animal:
            def speak(self) -> str:
                raise NotImplementedError

        synth = SynthesizedType(
            type_name="Dog",
            module_code="""\
class Dog(Animal):
    def speak(self) -> str:
        return "Woof!"
""",
        )

        # Attach context with base class
        object.__setattr__(synth, "_decode_context", ChainMap({"Animal": Animal}))

        result = EncodableSynthesizedType.decode(synth)

        assert isinstance(result, type)
        assert issubclass(result, Animal)
        assert result.__name__ == "Dog"

        instance = result()
        assert instance.speak() == "Woof!"

    def test_decode_attaches_source_attribute(self):
        """Test that decoded types have __source__ attribute."""
        synth = SynthesizedType(
            type_name="Simple",
            module_code="class Simple:\n    pass",
        )

        result = EncodableSynthesizedType.decode(synth)

        assert hasattr(result, "__source__")
        assert "class Simple" in result.__source__

    def test_decode_attaches_synthesized_attribute(self):
        """Test that decoded types have __synthesized__ attribute."""
        synth = SynthesizedType(
            type_name="Simple",
            module_code="class Simple:\n    pass",
        )

        result = EncodableSynthesizedType.decode(synth)

        assert hasattr(result, "__synthesized__")
        assert result.__synthesized__ is synth

    def test_decode_inspect_getsource_works(self):
        """Test that inspect.getsource() works on synthesized types."""
        synth = SynthesizedType(
            type_name="Documented",
            module_code='''\
class Documented:
    """A documented class."""

    def method(self) -> int:
        return 42
''',
        )

        result = EncodableSynthesizedType.decode(synth)
        source = inspect.getsource(result)

        assert "class Documented" in source
        assert "A documented class" in source
        assert "def method" in source
        assert source == result.__source__

    def test_decode_preserves_class_docstring(self):
        """Test that line-number injection does not displace the docstring."""
        synth = SynthesizedType(
            type_name="Documented",
            module_code='''\
class Documented:
    """A documented class."""
''',
        )

        result = EncodableSynthesizedType.decode(synth)

        assert result.__doc__ == "A documented class."

    def test_decode_class_not_on_first_line(self):
        """Test decoding when module-level code precedes the class."""
        synth = SynthesizedType(
            type_name="Late",
            module_code="""\
OFFSET = 10


class Late:
    def value(self) -> int:
        return OFFSET + 1
""",
        )

        result = EncodableSynthesizedType.decode(synth)

        assert result().value() == 11
        assert "class Late" in inspect.getsource(result)

    def test_decode_with_helper_in_class(self):
        """Test decoding a class that uses a helper method."""
        synth = SynthesizedType(
            type_name="Counter",
            module_code="""\
class Counter:
    def __init__(self):
        self.value = 0

    def _increment(self, x):
        return x + 1

    def increment(self):
        self.value = self._increment(self.value)
        return self.value
""",
        )

        result = EncodableSynthesizedType.decode(synth)
        instance = result()

        assert instance.increment() == 1
        assert instance.increment() == 2
        assert instance.increment() == 3

    def test_decode_syntax_error_raises_synthesis_error(self):
        """Test that syntax errors raise SynthesisError."""
        synth = SynthesizedType(
            type_name="Broken",
            module_code="class Broken\n    pass  # missing colon",
        )

        with pytest.raises(SynthesisError, match="Syntax error"):
            EncodableSynthesizedType.decode(synth)

    def test_decode_failure_unregisters_module(self):
        """Test that a failed decode leaves no synthetic module behind."""
        import sys

        before = set(sys.modules)
        synth = SynthesizedType(
            type_name="Broken",
            module_code="class Broken\n    pass  # missing colon",
        )

        with pytest.raises(SynthesisError):
            EncodableSynthesizedType.decode(synth)

        leaked = [
            name
            for name in set(sys.modules) - before
            if name.startswith("_llm_effectful_synthesized_types.")
        ]
        assert not leaked

    def test_decode_missing_type_raises_synthesis_error(self):
        """Test that missing type name raises SynthesisError."""
        synth = SynthesizedType(
            type_name="Missing",
            module_code="class WrongName:\n    pass",
        )

        with pytest.raises(SynthesisError, match="not found after execution"):
            EncodableSynthesizedType.decode(synth)

    def test_decode_non_type_raises_synthesis_error(self):
        """Test that non-type result raises SynthesisError."""
        synth = SynthesizedType(
            type_name="NotAType",
            module_code="NotAType = 42",
        )

        with pytest.raises(SynthesisError, match="is not a type"):
            EncodableSynthesizedType.decode(synth)

    def test_encode_simple_class(self):
        """Test encoding a simple class to SynthesizedType."""

        class MyClass:
            def method(self) -> str:
                return "hello"

        result = EncodableSynthesizedType.encode(MyClass)

        assert isinstance(result, SynthesizedType)
        assert result.type_name == "MyClass"
        assert "class MyClass" in result.module_code
        assert "def method" in result.module_code

    def test_encode_builtin_class_fallback(self):
        """Test encoding a builtin class (source unavailable) uses fallback."""
        # int is a builtin, so inspect.getsource() will fail
        result = EncodableSynthesizedType.encode(int)

        assert isinstance(result, SynthesizedType)
        assert result.type_name == "int"
        assert "class int" in result.module_code
        assert "Source unavailable" in result.module_code

    def test_serialize_produces_json(self):
        """Test that serialize produces valid JSON content blocks."""
        synth = SynthesizedType(
            type_name="TestType",
            module_code="class TestType:\n    pass",
        )

        result = EncodableSynthesizedType.serialize(synth)

        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0]["type"] == "text"
        # Verify it's valid JSON
        import json

        parsed = json.loads(result[0]["text"])
        assert parsed["type_name"] == "TestType"
        assert "class TestType" in parsed["module_code"]

    def test_decode_unique_module_names(self):
        """Test that each decoded type gets a unique module name."""
        synth1 = SynthesizedType(
            type_name="Unique",
            module_code="class Unique:\n    value = 1",
        )
        synth2 = SynthesizedType(
            type_name="Unique",
            module_code="class Unique:\n    value = 2",
        )

        result1 = EncodableSynthesizedType.decode(synth1)
        result2 = EncodableSynthesizedType.decode(synth2)

        # Both should be different types with different module names
        assert result1 is not result2
        assert result1.__module__ != result2.__module__
        assert result1.value == 1
        assert result2.value == 2

    def test_decode_context_with_multiple_items(self):
        """Test decoding with context containing multiple items."""

        class BaseA:
            pass

        class BaseB:
            pass

        def helper() -> int:
            return 100

        synth = SynthesizedType(
            type_name="Combined",
            module_code="""\
class Combined(BaseA, BaseB):
    def get_value(self) -> int:
        return helper()
""",
        )

        context = ChainMap({"BaseA": BaseA, "BaseB": BaseB, "helper": helper})
        object.__setattr__(synth, "_decode_context", context)

        result = EncodableSynthesizedType.decode(synth)

        assert issubclass(result, BaseA)
        assert issubclass(result, BaseB)
        instance = result()
        assert instance.get_value() == 100