diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index ed19f777..004f19f8 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -11,6 +11,7 @@ import sys import time import typing +from functools import wraps from logging import getLogger from pathlib import Path @@ -20,7 +21,7 @@ from openhexa.sdk.utils import Environment, Settings, get_environment from .parameter import FunctionWithParameter, Parameter, ParameterValueError -from .task import PipelineWithTask, Task +from .task import P, R, Task from .utils import get_local_workspace_config logger = getLogger(__name__) @@ -60,7 +61,7 @@ def __init__( self.functional_type = functional_type self.tasks = [] - def task(self, function) -> PipelineWithTask: + def task(self, function: typing.Callable[P, R]) -> typing.Callable[P, R]: """Task decorator. Examples @@ -78,7 +79,23 @@ def task(self, function) -> PipelineWithTask: ... def task_2(foo: int): ... pass """ - return PipelineWithTask(function, self) + + @wraps(function) + def wrapper(*task_args: P.args, **task_kwargs: P.kwargs) -> R: + """Attach task to the decorated pipeline and return it. + + NB: We claim to return type R but we actually return a Task object. + This is for better DX when writing pipeline DAGs in IDEs. + """ + task = Task(function)(*task_args, **task_kwargs) + self.tasks.append(task) + return task # type: ignore[return-value] + + # store references to original function + wrapper._original_function = function # type: ignore + wrapper._pipeline = self # type: ignore + + return wrapper def run(self, config: dict[str, typing.Any]): """Run the pipeline using the provided config. diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py index 3a6da086..dcc7152d 100644 --- a/openhexa/sdk/pipelines/task.py +++ b/openhexa/sdk/pipelines/task.py @@ -7,6 +7,7 @@ import datetime import typing +from functools import wraps import openhexa.sdk.pipelines.pipeline @@ -127,18 +128,23 @@ def __repr__(self): return self.name -class PipelineWithTask: +P = typing.ParamSpec("P") +R = typing.TypeVar("R") + + +class PipelineWithTask(typing.Generic[P, R]): """Pipeline with attached tasks, usually through the @task decorator.""" def __init__( self, - function: typing.Callable, + function: typing.Callable[P, R], pipeline: openhexa.sdk.pipelines.Pipeline, ): self.function = function self.pipeline = pipeline + wraps(function)(self) - def __call__(self, *task_args, **task_kwargs) -> Task: + def __call__(self, *task_args: typing.Any, **task_kwargs: typing.Any) -> Task: """Attach the new task to the decorated pipeline and return it.""" task = Task(self.function)(*task_args, **task_kwargs) self.pipeline.tasks.append(task)