Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sys
import time
import typing
from functools import wraps
from logging import getLogger
from pathlib import Path

Expand All @@ -20,7 +21,7 @@
from openhexa.sdk.utils import Environment, Settings, get_environment

from .parameter import FunctionWithParameter, Parameter, ParameterValueError
from .task import PipelineWithTask, Task
from .task import P, R, Task
from .utils import get_local_workspace_config

logger = getLogger(__name__)
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(
self.functional_type = functional_type
self.tasks = []

def task(self, function) -> PipelineWithTask:
def task(self, function: typing.Callable[P, R]) -> typing.Callable[P, R]:
"""Task decorator.

Examples
Expand All @@ -78,7 +79,23 @@ def task(self, function) -> PipelineWithTask:
... def task_2(foo: int):
... pass
"""
return PipelineWithTask(function, self)

@wraps(function)
def wrapper(*task_args: P.args, **task_kwargs: P.kwargs) -> R:
"""Attach task to the decorated pipeline and return it.

NB: We claim to return type R but we actually return a Task object.
This is for better DX when writing pipeline DAGs in IDEs.
"""
task = Task(function)(*task_args, **task_kwargs)
self.tasks.append(task)
return task # type: ignore[return-value]

# store references to original function
wrapper._original_function = function # type: ignore
wrapper._pipeline = self # type: ignore

return wrapper

def run(self, config: dict[str, typing.Any]):
"""Run the pipeline using the provided config.
Expand Down
12 changes: 9 additions & 3 deletions openhexa/sdk/pipelines/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import datetime
import typing
from functools import wraps

import openhexa.sdk.pipelines.pipeline

Expand Down Expand Up @@ -127,18 +128,23 @@ def __repr__(self):
return self.name


class PipelineWithTask:
P = typing.ParamSpec("P")
R = typing.TypeVar("R")


class PipelineWithTask(typing.Generic[P, R]):
"""Pipeline with attached tasks, usually through the @task decorator."""

def __init__(
self,
function: typing.Callable,
function: typing.Callable[P, R],
pipeline: openhexa.sdk.pipelines.Pipeline,
):
self.function = function
self.pipeline = pipeline
wraps(function)(self)

def __call__(self, *task_args, **task_kwargs) -> Task:
def __call__(self, *task_args: typing.Any, **task_kwargs: typing.Any) -> Task:
"""Attach the new task to the decorated pipeline and return it."""
task = Task(self.function)(*task_args, **task_kwargs)
self.pipeline.tasks.append(task)
Expand Down
Loading