From ab706a466f48b273176b875ece0d90436214f310 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 15 Oct 2025 18:10:23 +0200 Subject: [PATCH 1/5] Make PipelineWithTask generic to preserve original function metadata --- openhexa/sdk/pipelines/task.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py index 3a6da086..853f0ae7 100644 --- a/openhexa/sdk/pipelines/task.py +++ b/openhexa/sdk/pipelines/task.py @@ -7,6 +7,7 @@ import datetime import typing +from functools import wraps import openhexa.sdk.pipelines.pipeline @@ -127,18 +128,23 @@ def __repr__(self): return self.name -class PipelineWithTask: +P = typing.ParamSpec("P") +R = typing.TypeVar("R") + + +class PipelineWithTask(typing.Generic[P, R]): """Pipeline with attached tasks, usually through the @task decorator.""" def __init__( self, - function: typing.Callable, + function: typing.Callable[P, R], pipeline: openhexa.sdk.pipelines.Pipeline, ): self.function = function self.pipeline = pipeline + wraps(function)(self) - def __call__(self, *task_args, **task_kwargs) -> Task: + def __call__(self, *task_args: P.args, **task_kwargs: P.kwargs) -> Task: """Attach the new task to the decorated pipeline and return it.""" task = Task(self.function)(*task_args, **task_kwargs) self.pipeline.tasks.append(task) From f6996ef8c0f0e64e3b98858999d97a332811f635 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 15 Oct 2025 18:12:10 +0200 Subject: [PATCH 2/5] Update Pipeline.task decorator signature --- openhexa/sdk/pipelines/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index ed19f777..fffa6831 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -20,7 +20,7 @@ from openhexa.sdk.utils import Environment, Settings, get_environment from .parameter import FunctionWithParameter, Parameter, ParameterValueError -from .task import PipelineWithTask, Task +from .task import P, PipelineWithTask, R, Task from .utils import get_local_workspace_config logger = getLogger(__name__) @@ -60,7 +60,7 @@ def __init__( self.functional_type = functional_type self.tasks = [] - def task(self, function) -> PipelineWithTask: + def task(self, function: typing.Callable[P, R]) -> PipelineWithTask[P, R]: """Task decorator. Examples From eb73f19830878a1ebfaa2f954acfcf4822cca938 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 15 Oct 2025 18:26:49 +0200 Subject: [PATCH 3/5] Loosen type checking of task args --- openhexa/sdk/pipelines/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py index 853f0ae7..dcc7152d 100644 --- a/openhexa/sdk/pipelines/task.py +++ b/openhexa/sdk/pipelines/task.py @@ -144,7 +144,7 @@ def __init__( self.pipeline = pipeline wraps(function)(self) - def __call__(self, *task_args: P.args, **task_kwargs: P.kwargs) -> Task: + def __call__(self, *task_args: typing.Any, **task_kwargs: typing.Any) -> Task: """Attach the new task to the decorated pipeline and return it.""" task = Task(self.function)(*task_args, **task_kwargs) self.pipeline.tasks.append(task) From f88dbb7dc0fd6fec011bf752c73a80f207337cc9 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 15 Oct 2025 18:41:15 +0200 Subject: [PATCH 4/5] Claim to return function return type instead of Task --- openhexa/sdk/pipelines/pipeline.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index fffa6831..f14d565a 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -11,6 +11,7 @@ import sys import time import typing +from functools import wraps from logging import getLogger from pathlib import Path @@ -60,7 +61,7 @@ def __init__( self.functional_type = functional_type self.tasks = [] - def task(self, function: typing.Callable[P, R]) -> PipelineWithTask[P, R]: + def task(self, function: typing.Callable[P, R]) -> typing.Callable[P, R]: """Task decorator. Examples @@ -78,7 +79,23 @@ def task(self, function: typing.Callable[P, R]) -> PipelineWithTask[P, R]: ... def task_2(foo: int): ... pass """ - return PipelineWithTask(function, self) + + @wraps(function) + def wrapper(*task_args: P.args, **task_kwargs: P.kwargs) -> R: + """Attach task to the decorated pipeline and return it. + + NB: We claim to return type R but we actually return a Task object. + This is for better DX when writing pipeline DAGs in IDEs. + """ + task = Task(function)(*task_args, **task_kwargs) + self.tasks.append(task) + return task # type: ignore[return-value] + + # store references to original function + wrapper._original_function = function # type: ignore + wrapper._pipeline = self # type: ignore + + return wrapper def run(self, config: dict[str, typing.Any]): """Run the pipeline using the provided config. From 37a21026998852ec26e95c4a3d6fd276ed6cd90e Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 15 Oct 2025 20:13:55 +0200 Subject: [PATCH 5/5] Ruff formatting --- openhexa/sdk/pipelines/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index f14d565a..004f19f8 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -21,7 +21,7 @@ from openhexa.sdk.utils import Environment, Settings, get_environment from .parameter import FunctionWithParameter, Parameter, ParameterValueError -from .task import P, PipelineWithTask, R, Task +from .task import P, R, Task from .utils import get_local_workspace_config logger = getLogger(__name__)