diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a860c82..8c437a3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`766` moves runtime profiling persistence from SQLite to a JSON snapshot plus append-only journal in `.pytask/`, keeping runtime data resilient to crashes and compacted on normal build exits. +- {pull}`776` clears decoration-time `annotation_locals` snapshots after collection so + task functions remain picklable in process-based parallel backends. ## 0.5.8 - 2025-12-30 diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index d19e555b..d18fd909 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -90,9 +90,23 @@ def pytask_collect(session: Session) -> bool: session=session, reports=session.collection_reports, tasks=session.tasks ) + _clear_annotation_locals(session.tasks) + return True +def _clear_annotation_locals(tasks: list[PTask]) -> None: + """Drop decoration-time locals snapshots once collection finishes. + + The snapshot is only needed to evaluate deferred annotations while collecting + dependencies/products. Keeping it afterwards can retain non-picklable objects (for + example locks) and break parallel backends that cloudpickle task functions. + """ + for task in tasks: + if isinstance(task.function, TaskFunction): + task.function.pytask_meta.annotation_locals = None + + def _collect_from_paths(session: Session) -> None: """Collect tasks from paths. diff --git a/tests/test_collect.py b/tests/test_collect.py index fd22b251..7e4d0695 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -5,6 +5,7 @@ import warnings from pathlib import Path +import cloudpickle import pytest from _pytask.collect import _find_shortest_uniquely_identifiable_name_for_tasks @@ -404,6 +405,32 @@ def task_example() -> 'Annotated[str, OUTPUT]': assert tmp_path.joinpath("out.txt").exists() +def test_annotation_locals_are_cleared_after_collection_to_allow_pickling(tmp_path): + source = """ + import threading + + from pytask import task + + lock = threading.RLock() + + for i in range(2): + @task + def task_example(): + return None + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + session = build(paths=tmp_path, dry_run=True) + assert session.exit_code == ExitCode.OK + assert len(session.tasks) == 2 + + for collected_task in session.tasks: + meta = getattr(collected_task.function, "pytask_meta", None) + assert meta is not None + assert meta.annotation_locals is None + cloudpickle.dumps(collected_task.function) + + def test_collect_string_product_raises_error_with_annotation(runner, tmp_path): """The string is not converted to a path.""" source = """