Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gokart/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class WorkerProtocol(Protocol):
This protocol is determined by luigi.worker.Worker.
"""

def add(self, task: TaskOnKart) -> bool: ...
def add(self, task: TaskOnKart[Any]) -> bool: ...

def run(self) -> bool: ...

Expand Down Expand Up @@ -124,7 +124,7 @@ class TaskDumpConfig:
output_type: TaskDumpOutputType = TaskDumpOutputType.NONE


def process_task_info(task: TaskOnKart, task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
def process_task_info(task: TaskOnKart[Any], task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
match task_dump_config:
case TaskDumpConfig(mode=TaskDumpMode.NONE, output_type=TaskDumpOutputType.NONE):
pass
Expand Down
4 changes: 3 additions & 1 deletion gokart/config_params.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import luigi

import gokart
Expand All @@ -18,7 +20,7 @@ def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str,
self._config_class: type[luigi.Config] = config_class
self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {}

def __call__(self, task_class: type[gokart.TaskOnKart]) -> type[gokart.TaskOnKart]:
def __call__(self, task_class: type[gokart.TaskOnKart[Any]]) -> type[gokart.TaskOnKart[Any]]:
# wrap task to prevent task name from being changed
@luigi.task._task_wraps(task_class)
class Wrapped(task_class): # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion gokart/conflict_prevention_lock/task_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TaskLockException(Exception):


class RedisClient:
_instances: dict = {}
_instances: dict[Any, Any] = {}

def __new__(cls, *args, **kwargs):
key = (args, tuple(sorted(kwargs.items())))
Expand Down
23 changes: 14 additions & 9 deletions gokart/conflict_prevention_lock/task_lock_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
import functools
from collections.abc import Callable
from logging import getLogger
from typing import Any
from typing import ParamSpec, TypeVar

from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock

logger = getLogger(__name__)


def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_check: Callable) -> Callable:
P = ParamSpec('P')
R = TypeVar('R')


def wrap_dump_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams, exist_check: Callable[..., bool]) -> Callable[P, R | None]:
"""Redis lock wrapper function for TargetOnKart.dump().
When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existence check.
https://github.com/m3dev/gokart/issues/265
Expand All @@ -19,14 +23,15 @@ def wrap_dump_with_lock(func: Callable, task_lock_params: TaskLockParams, exist_
if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

try:
logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.')
if not exist_check():
func(*args, **kwargs)
return func(*args, **kwargs)
return None
finally:
logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.')
task_lock.release()
Expand All @@ -35,7 +40,7 @@ def wrapper(*args, **kwargs):
return wrapper


def wrap_load_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Callable:
def wrap_load_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
"""Redis lock wrapper function for TargetOnKart.load().
When TargetOnKart.load() is called, redis lock will be locked and released before load().
https://github.com/m3dev/gokart/issues/265
Expand All @@ -44,7 +49,7 @@ def wrap_load_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Cal
if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

Expand All @@ -58,15 +63,15 @@ def wrapper(*args, **kwargs):
return wrapper


def wrap_remove_with_lock(func: Callable, task_lock_params: TaskLockParams) -> Callable:
def wrap_remove_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
"""Redis lock wrapper function for TargetOnKart.remove().
When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
https://github.com/m3dev/gokart/issues/265
"""
if not task_lock_params.should_task_lock:
return func

def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
task_lock = set_task_lock(task_lock_params=task_lock_params)
scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)

Expand All @@ -86,7 +91,7 @@ def wrapper(*args, **kwargs):
return wrapper


def wrap_run_with_lock(run_func: Callable[[], Any], task_lock_params: TaskLockParams) -> Callable[[], Any]:
def wrap_run_with_lock(run_func: Callable[[], R], task_lock_params: TaskLockParams) -> Callable[[], R]:
@functools.wraps(run_func)
def wrapped():
task_lock = set_task_lock(task_lock_params=task_lock_params)
Expand Down
2 changes: 1 addition & 1 deletion gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_patched_obj_metadata(
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> dict | Any:
) -> dict[str, Any] | Any:
# If the metadata in the response for the bucket and object information is not a dictionary,
# something may have gone wrong, so return the original metadata unpatched.
if not isinstance(metadata, dict):
Expand Down
5 changes: 3 additions & 2 deletions gokart/info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from logging import getLogger
from typing import Any

import luigi

Expand All @@ -11,7 +12,7 @@


def make_tree_info(
task: TaskOnKart,
task: TaskOnKart[Any],
indent: str = '',
last: bool = True,
details: bool = False,
Expand Down Expand Up @@ -43,7 +44,7 @@ def make_tree_info(
return make_task_info_as_tree_str(task=task, details=details, abbr=abbr, ignore_task_names=ignore_task_names)


class tree_info(TaskOnKart):
class tree_info(TaskOnKart[Any]):
mode: str = luigi.Parameter(default='', description='This must be in ["simple", "all"].')
output_path: str = luigi.Parameter(default='tree.txt', description='Output file path.')

Expand Down
3 changes: 2 additions & 1 deletion gokart/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
from logging import getLogger
from typing import Any

import luigi
import luigi.cmdline
Expand Down Expand Up @@ -49,7 +50,7 @@ def _try_tree_info(cmdline_args):

def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]) -> None:
with CmdlineParser.global_instance(cmdline_args) as cp:
task = cp.get_task_obj() # type: gokart.TaskOnKart
task: gokart.TaskOnKart[Any] = cp.get_task_obj()
if task.delete_unnecessary_output_files:
if ObjectStorage.if_object_storage_path(task.workspace_directory):
logger.info('delete-unnecessary-output-files is not support s3/gcs.')
Expand Down
10 changes: 5 additions & 5 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ def input(self) -> FlattenableItems[TargetOnKart]:
def output(self) -> FlattenableItems[TargetOnKart]:
return self.make_target()

def requires(self) -> FlattenableItems[TaskOnKart]:
def requires(self) -> FlattenableItems[TaskOnKart[Any]]:
tasks = self.make_task_instance_dictionary()
return tasks or [] # when tasks is empty dict, then this returns empty list.

def make_task_instance_dictionary(self) -> dict[str, TaskOnKart]:
def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]:
return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}

@staticmethod
Expand Down Expand Up @@ -395,7 +395,7 @@ def _to_str_params(task):
dependencies.append(self.get_own_code())
return hashlib.md5(str(dependencies).encode()).hexdigest()

def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart | list[TaskOnKart]) -> FlattenableItems[TargetOnKart]:
def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart[Any] | list[TaskOnKart[Any]]) -> FlattenableItems[TargetOnKart]:
if target is None:
return self.input()
if isinstance(target, str):
Expand Down Expand Up @@ -438,7 +438,7 @@ def get_info(self, only_significant=False):
def _get_task_log_target(self):
return self.make_target(f'log/task_log/{type(self).__name__}.pkl')

def get_task_log(self) -> dict:
def get_task_log(self) -> dict[str, Any]:
target = self._get_task_log_target()
if self.task_log:
return self.task_log
Expand All @@ -455,7 +455,7 @@ def _dump_task_log(self):
def _get_task_params_target(self):
return self.make_target(f'log/task_params/{type(self).__name__}.pkl')

def get_task_params(self) -> dict:
def get_task_params(self) -> dict[str, Any]:
target = self._get_task_log_target()
if target.exists():
return cast(dict[Any, Any], self.load(target))
Expand Down
3 changes: 2 additions & 1 deletion gokart/task_complete_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import functools
from collections.abc import Callable
from logging import getLogger
from typing import Any

logger = getLogger(__name__)


def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable) -> Callable:
def task_complete_check_wrapper(run_func: Callable[..., Any], complete_check_func: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(run_func)
def wrapper(*args, **kwargs):
if complete_check_func():
Expand Down
9 changes: 5 additions & 4 deletions gokart/testing/check_if_run_with_empty_data_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import sys
from typing import Any

import luigi
from luigi.cmdline_parser import CmdlineParser
Expand All @@ -14,15 +15,15 @@
test_logger.setLevel(logging.INFO)


class test_run(gokart.TaskOnKart):
class test_run(gokart.TaskOnKart[Any]):
pandas: bool = luigi.BoolParameter()
namespace: str | None = luigi.OptionalParameter(
default=None, description='When task namespace is not defined explicitly, please use "__not_user_specified".'
)


class _TestStatus:
def __init__(self, task: gokart.TaskOnKart) -> None:
def __init__(self, task: gokart.TaskOnKart[Any]) -> None:
self.namespace = task.task_namespace
self.name = type(task).__name__
self.task_id = task.make_unique_id()
Expand All @@ -39,14 +40,14 @@ def fail(self) -> bool:
return self.status != 'OK'


def _get_all_tasks(task: gokart.TaskOnKart) -> list[gokart.TaskOnKart]:
def _get_all_tasks(task: gokart.TaskOnKart[Any]) -> list[gokart.TaskOnKart[Any]]:
result = [task]
for o in flatten(task.requires()):
result.extend(_get_all_tasks(o))
return result


def _run_with_test_status(task: gokart.TaskOnKart) -> _TestStatus:
def _run_with_test_status(task: gokart.TaskOnKart[Any]) -> _TestStatus:
test_message = _TestStatus(task)
try:
task.run()
Expand Down
9 changes: 5 additions & 4 deletions gokart/tree/task_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from typing import Any

import pandas as pd

Expand All @@ -9,7 +10,7 @@
from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list


def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None) -> str:
def make_task_info_as_tree_str(task: TaskOnKart[Any], details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None) -> str:
"""
Return a string representation of the tasks, their statuses/parameters in a dependency tree format

Expand All @@ -33,7 +34,7 @@ def make_task_info_as_tree_str(task: TaskOnKart, details: bool = False, abbr: bo
return result


def make_task_info_as_table(task: TaskOnKart, ignore_task_names: list[str] | None = None) -> pd.DataFrame:
def make_task_info_as_table(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None) -> pd.DataFrame:
"""Return a table containing information about dependent tasks.

Parameters
Expand All @@ -54,7 +55,7 @@ def make_task_info_as_table(task: TaskOnKart, ignore_task_names: list[str] | Non
return task_info_table


def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None) -> None:
def dump_task_info_table(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None) -> None:
"""Dump a table containing information about dependent tasks.

Parameters
Expand All @@ -79,7 +80,7 @@ def dump_task_info_table(task: TaskOnKart, task_info_dump_path: str, ignore_task
task_info_target.dump(obj=task_info_table, lock_at_dump=False)


def dump_task_info_tree(task: TaskOnKart, task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True) -> None:
def dump_task_info_tree(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True) -> None:
"""Dump the task info tree object (TaskInfo) to a pickle file.

Parameters
Expand Down
8 changes: 4 additions & 4 deletions gokart/tree/task_info_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
import warnings
from dataclasses import dataclass
from typing import NamedTuple
from typing import Any, NamedTuple

from gokart.task import TaskOnKart
from gokart.utils import FlattenableItems, flatten
Expand All @@ -14,10 +14,10 @@ class TaskInfo:
name: str
unique_id: str
output_paths: list[str]
params: dict
params: dict[str, Any]
processing_time: str
is_complete: str
task_log: dict
task_log: dict[str, Any]
requires: FlattenableItems[RequiredTask]
children_task_infos: list[TaskInfo]

Expand Down Expand Up @@ -59,7 +59,7 @@ def _make_requires_info(requires):
raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`')


def make_task_info_tree(task: TaskOnKart, ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo:
def make_task_info_tree(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo:
with warnings.catch_warnings():
warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method')
is_task_complete = task.complete()
Expand Down
8 changes: 4 additions & 4 deletions gokart/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self,
task: luigi.Task,
worker_id: str,
result_queue: multiprocessing.Queue,
result_queue: multiprocessing.Queue[Any],
status_reporter: luigi.worker.TaskStatusReporter,
use_multiprocessing: bool = False,
worker_timeout: int = 0,
Expand All @@ -150,11 +150,11 @@ def __init__(
# completeness check using the cache
self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

def _run_task(self) -> collections.abc.Generator | None:
def _run_task(self) -> collections.abc.Generator[Any, Any, Any] | None:
if self.task_completion_check_at_run and self.check_complete(self.task):
logger.warning(f'{self.task} is skipped because the task is already completed.')
return None
return cast(collections.abc.Generator | None, self.task.run())
return cast(collections.abc.Generator[Any, Any, Any] | None, self.task.run())

def _run_get_new_deps(self) -> list[tuple[str, str, dict[str, str]]] | None:
task_gen = self._run_task()
Expand Down Expand Up @@ -447,7 +447,7 @@ def __init__(
pass

# Keep info about what tasks are running (could be in other processes)
self._task_result_queue: multiprocessing.Queue = _fork_context.Queue()
self._task_result_queue: multiprocessing.Queue[Any] = _fork_context.Queue()
self._running_tasks: dict[str, TaskProcess] = {}
self._idle_since: datetime.datetime | None = None

Expand Down
Loading