diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31b2f10..daa4806 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,7 +69,7 @@ jobs: if: always() && inputs.notify_teams && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/develop') steps: - name: Notify Teams - uses: ecmwf-actions/notify-teams@v1 + uses: ecmwf/notify-teams@v1 with: incoming_webhook: ${{ secrets.incoming_webhook }} needs_context: ${{ toJSON(needs) }} diff --git a/.gitignore b/.gitignore index 82bab73..1dc13ca 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ docs/_build _version.py .vscode .DS_Store +.weave \ No newline at end of file diff --git a/docs/content/api-reference.rst b/docs/content/api-reference.rst index ef743fc..f7df6d5 100644 --- a/docs/content/api-reference.rst +++ b/docs/content/api-reference.rst @@ -226,19 +226,21 @@ External .. _Extern: -.. autoclass:: pyflow.Extern +.. autoclass:: pyflow.ExternSuite -.. autoclass:: pyflow.ExternNode +.. autoclass:: pyflow.ExternFamily .. autoclass:: pyflow.ExternTask -.. autoclass:: pyflow.ExternFamily +.. autoclass:: pyflow.ExternVariable + +.. autoclass:: pyflow.ExternLimit .. autoclass:: pyflow.ExternEvent .. autoclass:: pyflow.ExternMeter -.. autoclass:: pyflow.ExternYMD +.. autoclass:: pyflow.ExternRepeat Deployment diff --git a/docs/content/introductory-course/flow-control.ipynb b/docs/content/introductory-course/flow-control.ipynb index 16542fb..075ed91 100644 --- a/docs/content/introductory-course/flow-control.ipynb +++ b/docs/content/introductory-course/flow-control.ipynb @@ -1001,7 +1001,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "a2937746-a676-49ea-9332-315e763167ca", "metadata": {}, "outputs": [ @@ -1034,7 +1034,7 @@ " etask = pf.ExternTask('/a/b/c/d')\n", " efamily = pf.ExternFamily('/f/g/h/i')\n", " \n", - " eymd = pf.ExternYMD('/a/b/c/d:YMD')\n", + " eymd = pf.ExternRepeat('/a/b/c/d:YMD')\n", " eevent = pf.ExternEvent('/e/f/g/h:ev')\n", " emeter = pf.ExternMeter('/g/h/i/j:mt')\n", " \n", diff --git a/docs/content/introductory-course/helper-functionality.ipynb b/docs/content/introductory-course/helper-functionality.ipynb index 25adaa1..9c4460b 100644 --- a/docs/content/introductory-course/helper-functionality.ipynb +++ b/docs/content/introductory-course/helper-functionality.ipynb @@ -158,7 +158,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "b2e2a324-3df0-4874-8c69-d3d93ea39bc4", "metadata": {}, "outputs": [ @@ -196,7 +196,7 @@ " pf.RepeatDate(\"YMD\", datetime.date(2019, 1, 1), datetime.date(2019, 12, 31))\n", " with pf.Family('follower') as follower:\n", " pf.RepeatDate(\"YMD\", datetime.date(2019, 1, 1), datetime.date(2019, 12, 31))\n", - " follower.follow = leader.YMD\n", + " follower.follow = leader\n", "\n", "s" ] diff --git a/pyflow/__init__.py b/pyflow/__init__.py index 8e6f63b..305eb3d 100644 --- a/pyflow/__init__.py +++ b/pyflow/__init__.py @@ -48,11 +48,15 @@ from .expressions import Deferred, all_complete, sequence from .extern import ( Extern, + ExternAttribute, ExternEvent, ExternFamily, + ExternLimit, ExternMeter, - ExternNode, + ExternRepeat, + ExternSuite, ExternTask, + ExternVariable, ExternYMD, ) from .header import FileHeader, FileTail, Header, InlineCodeHeader diff --git a/pyflow/adder.py b/pyflow/adder.py index 80ac9c3..1e77f37 100644 --- a/pyflow/adder.py +++ b/pyflow/adder.py @@ -55,7 +55,9 @@ def replace(self, other): self.add(other) def _create(self, other): - assert other is not None + result = [] + if other is None: + return result if isinstance(other, dict): other = [it for it in other.items()] @@ -63,7 +65,6 @@ def _create(self, other): if not isinstance(other, list): other = [other] - result = [] for o in other: if isinstance(o, tuple): if len(o) == 2 and isinstance(o[1], dict): diff --git a/pyflow/attributes.py b/pyflow/attributes.py index e8358b5..559ceb4 100644 --- a/pyflow/attributes.py +++ b/pyflow/attributes.py @@ -3,6 +3,7 @@ import datetime import re +from . import warn from .anchor import AnchorMixin from .base import Base, GenerateError from .cron import Crontab @@ -18,6 +19,7 @@ Le, Lt, Mod, + Mul, Ne, Sub, expression_from_json, @@ -92,26 +94,6 @@ def generate_stub(self): shape = "box" -class RepeatDay(Attribute): - """ - An attribute that allows a node to be repeated infinitely. - - Parameters: - value(int): The repeat step. - - Example:: - - pyflow.attributes.RepeatDay(1) - """ - - def __init__(self, value): - super().__init__("_repeat") - self._value = value - - def _build(self, ecflow_parent): - ecflow_parent.add_repeat(ecflow.RepeatDay(int(self.value))) - - class Time(Attribute): """ An attribute for setting a time dependency of the node. @@ -341,7 +323,43 @@ def __init__(self, **kwargs): Variable(key, val) -class RepeatString(Exportable): +class Repeat(Exportable): + """ + A virtual class that defines a repeat attribute + """ + + def __init__(self, name, values=None): + super().__init__(name, values) + if self.parent._repeat is not None: + warn( + "Overwriting an existing repeat value!", + category=UserWarning, + stacklevel=2, + ) + self.parent._repeat = self # set the repeat at the node level + + +class RepeatDay(Repeat): + """ + An attribute that allows a node to be repeated infinitely. + + Parameters: + value(int): The repeat step. + + Example:: + + pyflow.attributes.RepeatDay(1) + """ + + def __init__(self, value): + super().__init__("_repeat") + self._value = value + + def _build(self, ecflow_parent): + ecflow_parent.add_repeat(ecflow.RepeatDay(int(self.value))) + + +class RepeatString(Repeat): """ An attribute that allows a node to be repeated by a string value. @@ -411,8 +429,11 @@ def __add__(self, other): def __sub__(self, other): return Sub(self, other) + def __mul__(self, other): + return Mul(self, other) + -class RepeatEnumerated(Exportable): +class RepeatEnumerated(Repeat): """ An attribute that allows a node to be repeated by an enumerated list. @@ -422,7 +443,7 @@ class RepeatEnumerated(Exportable): Example:: - pyflow.RepeatEnumerated("REPEAT_STRING", ["a", "b", "c", "d", "e"]) + pyflow.RepeatEnumerated("REPEAT_STRING", [1, 3, 4, 5]) """ def __init__(self, name, value): @@ -437,17 +458,17 @@ def values(self): """*list*: The list of enumerated values.""" return [str(x) for x in self.value] - def settings(self): - return self.value - def __add__(self, other): return Add(self, other) def __sub__(self, other): return Sub(self, other) + def __mul__(self, other): + return Mul(self, other) + -class RepeatDateList(Exportable): +class RepeatDateList(Repeat): """ An attribute that allows a node to be repeated over a list of dates. @@ -479,9 +500,6 @@ def values(self): v = [int(x) for x in v] return v - def settings(self): - return self.value - def __add__(self, other): return Add(self, other) @@ -489,7 +507,7 @@ def __sub__(self, other): return Sub(self, other) -class RepeatInteger(Exportable): +class RepeatInteger(Repeat): """ An attribute that allows a node to be repeated by an integer range. @@ -526,8 +544,11 @@ def __add__(self, other): def __sub__(self, other): return Sub(self, other) + def __mul__(self, other): + return Mul(self, other) + -class RepeatDate(Exportable): +class RepeatDate(Repeat): """ An attribute that allows a node to be repeated by a date value. @@ -580,9 +601,6 @@ def __sub__(self, other): result = Sub(self.julian, other.julian) return result - def settings(self): - return self._start, self._end, self._increment - @property def julian(self): """*int*: The Julian date of the repeat date.""" @@ -681,9 +699,6 @@ def __add__(self, other): def __sub__(self, other): return Sub(self, other) - def settings(self): - return self._start, self._end, self._increment - def _delta_to_string(self, delta): # there is no strftime for timedelta so we make our own total_seconds = int(delta.total_seconds()) @@ -712,12 +727,6 @@ def day_of_week(self): return Mod(Add(Div(self, 86400), 4), 7) -def string_or_enumerated(name, value): - if all(isinstance(v, int) for v in value): - return RepeatEnumerated(name, value) - return RepeatString(name, value) - - def is_date(value): return ( isinstance(value, (datetime.date, datetime.datetime)) @@ -796,27 +805,9 @@ def make_variable(node, name, value): with node: if isinstance(value, (tuple, list)): - if len(value) in [2, 3]: - if is_date(value[0]) and is_date(value[1]): - if len(value) == 3: - if isinstance(value[2], int): - return RepeatDate( - name, - as_date(value[0]), - as_date(value[1]), - value[2], - ) - else: - return RepeatDate(name, as_date(value[0]), as_date(value[1]), 1) - - if isinstance(value[0], int) and isinstance(value[1], int): - if len(value) == 3: - if isinstance(value[2], int): - return RepeatInteger(name, value[0], value[1], value[2]) - else: - return RepeatInteger(name, value[0], value[1], 1) - - return string_or_enumerated(name, value) + raise Exception( + "Repeat construction through a list is not supported anymore" + ) if isinstance(value, (str, int, float)): return Variable(name, value) @@ -1242,23 +1233,32 @@ class Follow(_Trigger): An attribute for setting a condition for running the node behind another repeated node which has completed. Parameters: - value(RepeatDate_): The repeat date attribute of the followed node. + value(Repeat_ or Task_ or Family_ or Suite_): The followed node or the repeat attribute of the followed node. Example:: - - pyflow.attributes.Follow(pyflow.RepeatDate('REPEAT_DATE', - datetime.date(year=2019, month=1, day=1), - datetime.date(year=2019, month=12, day=31))) + t1 = Task("t1", repeat=(RepeatEnumerated, "NUM", [1, 2, 3])) + t2 = Task("t2", repeat=(RepeatEnumerated, "NUM", [1, 2, 3])) + t2.follow = t1 """ def __init__(self, value): - super().__init__("_follow_%s" % (value,), value) - if not hasattr(value, "settings"): - raise Exception( - "Cannot follow a node of type %s (%r)" % (type(value), value) + super().__init__(f"_follow_{value.name}") + from .nodes import Node # yeah it's bad but there's a circular import + + if isinstance(value, Node): + parent = value + repeat = value.repeat + elif isinstance(value, Repeat): + parent = value.parent + repeat = value + else: + raise TypeError( + f"Follow attribute {self.name} requires a Repeat or a Node instance" ) - self.parent[value.name] = value.settings() - self._value = value.parent.complete | (self.parent[value.name] < value) + + if repeat is None: + raise TypeError(f"Follow attribute {self.name} requires a repeat") + self._value = parent.complete | (self.parent.repeat < repeat) ################################################################### @@ -1404,6 +1404,7 @@ class Mirror(Attribute): polling(str, int): The time interval used to poll the remote ecFlow server. ssl(bool): The flag indicating if SSL communication is enabled. auth(str): The path to the ecFlow authentication credentials file. + force(bool): The flag indicating whether to override lower limit on polling time Example:: @@ -1412,14 +1413,36 @@ class Mirror(Attribute): "/suite/family/task", "remote-ecflow-server", "3141", - 60, + 300, False "/path/to/auth.json") """ - def __init__(self, name, remote_path, remote_host, remote_port, polling, ssl, auth): + def __init__( + self, + name: str, + remote_path: str, + remote_host: str, + remote_port: str, + auth: str = "", + polling: int = 300, + ssl: bool = False, + force: bool = False, + ): super().__init__(name) + if not force: + try: + polling = int(polling) + except ValueError: + pass # polling is not an integer, so we cannot validate it here + + if isinstance(polling, int) and polling < 60: + raise ValueError( + "Mirror polling interval must be at least 60s to avoid overloading \ + the remote server. Use force=True to override." + ) + self.remote_path = str(remote_path) self.remote_host = str(remote_host) self.remote_port = str(remote_port) diff --git a/pyflow/base.py b/pyflow/base.py index 92ff78d..9fabcfe 100644 --- a/pyflow/base.py +++ b/pyflow/base.py @@ -21,6 +21,10 @@ def remove_node(self, node): def add_node(self, node): pass + @property + def repeat(self): + return None + @property def host(self): return None diff --git a/pyflow/expressions.py b/pyflow/expressions.py index 4719ea8..a8ee24f 100644 --- a/pyflow/expressions.py +++ b/pyflow/expressions.py @@ -265,6 +265,17 @@ def simplify(self): return self._simplify() +class BinMatOp(BinOp): + def __mul__(self, other): + return Mul(self, other) + + def __add__(self, other): + return Add(self, other) + + def __sub__(self, other): + return Sub(self, other) + + class Ne(BinOp): def __init__(self, left, right): super().__init__("ne", left, right, 1) @@ -300,7 +311,7 @@ def __init__(self, left, right): super().__init__("or", left, right, 0) def _simplify(self): - (l, r) = (self._left.evaluate(), self._right.evaluate()) + l, r = (self._left.evaluate(), self._right.evaluate()) if l is not UNDEFINED and r is not UNDEFINED: return self._left.value or self._right.value @@ -327,7 +338,7 @@ def __init__(self, left, right): super().__init__("and", left, right, 0) def _simplify(self): - (l, r) = (self._left.evaluate(), self._right.evaluate()) + l, r = (self._left.evaluate(), self._right.evaluate()) if l is not UNDEFINED and r is not UNDEFINED: return self._left.value and self._right.value @@ -349,26 +360,31 @@ def _simplify(self): return self -class Sub(BinOp): +class Sub(BinMatOp): def __init__(self, left, right): super().__init__("-", left, right, 2) -class Add(BinOp): +class Add(BinMatOp): def __init__(self, left, right): super().__init__("+", left, right, 2) -class Mod(BinOp): +class Mod(BinMatOp): def __init__(self, left, right): super().__init__("%", left, right, 3) -class Div(BinOp): +class Div(BinMatOp): def __init__(self, left, right): super().__init__("/", left, right, 3) +class Mul(BinMatOp): + def __init__(self, left, right): + super().__init__("*", left, right, 3) + + class Atom(Expression): _priority = 99 diff --git a/pyflow/extern.py b/pyflow/extern.py index 489db20..e3e280c 100644 --- a/pyflow/extern.py +++ b/pyflow/extern.py @@ -1,6 +1,7 @@ import datetime -from .attributes import Event, Meter, RepeatDate +from . import warn +from .attributes import Event, Limit, Meter, Repeat, RepeatDate, Variable from .base import Root from .nodes import Family, Suite, Task @@ -45,10 +46,65 @@ def ExternNode(path, tail_cls=Family): def ExternAttribute(path, cls, *args): KNOWN_EXTERNS.add(path) path, attr = path.split(":") - with ExternNode(path): + kind = Family if "/" in path[1:] else Suite + with ExternNode(path, kind): return cls(attr, *args) +def ExternVariable(path): + """ + Maps an external variable. + + Parameters: + path(*str*): Path of the item. + + Returns: + Variable_: An object that corresponds to an external variable. + + Example:: + + pyflow.ExternVariable('/a/b:var') + """ + return ExternAttribute(path, Variable, 1) + + +def ExternLimit(path): + """ + Maps an external limit. + + Parameters: + path(*str*): Path of the item. + + Returns: + Limit_: An object that corresponds to an external item. + + Example:: + + pyflow.ExternLimit('/a/limits:hpc') + """ + return ExternAttribute(path, Limit, 1) + + +def ExternRepeat(path): + """ + Maps an external repeat, i.e. a repeat that is not built from the same repository. + Cannot be a generic attribute as the repeat can be used with the follow() approach, + which requires an object of type Repeat. + + Parameters: + path(*str*): Path of the external repeat. + + Returns: + RepeatDate_: An object that corresponds to an external repeat. + + Example:: + + pyflow.ExternRepeat('/a/b/c/d:YMD') + """ + + return ExternAttribute(path, Repeat) + + def ExternYMD(path): """ Maps an external repeat date, i.e. a repeat date that is not built from the same repository. @@ -63,7 +119,11 @@ def ExternYMD(path): pyflow.ExternYMD('/a/b/c/d:YMD') """ - + warn( + "'ExternYMD' is deprecated, use ExternAttribute instead", + DeprecationWarning, + stacklevel=1, + ) return ExternAttribute( path, RepeatDate, datetime.datetime.now(), datetime.datetime.now() ) @@ -83,7 +143,6 @@ def ExternEvent(path): pyflow.ExternEvent('/e/f/g/h:ev') """ - return ExternAttribute(path, Event) @@ -101,7 +160,6 @@ def ExternMeter(path): pyflow.ExternMeter('/g/h/i/j:mt') """ - return ExternAttribute(path, Meter, 0) @@ -119,10 +177,31 @@ def Extern(path): pyflow.Extern('/f/g/h/i') """ - + warn( + "'Extern' is deprecated, use ExternSuite, ExternFamily or ExternTask instead", + DeprecationWarning, + stacklevel=1, + ) return ExternNode(path) +def ExternSuite(path): + """ + Maps an external suite. + + Parameters: + path(str): Path of the external suite. + + Returns: + Suite_: An object that corresponds to an external suite. + + Example:: + + pyflow.ExternSuite('/a') + """ + return ExternNode(path, Suite) + + def ExternFamily(path): """ Maps an external family, i.e. a family that is not built from the same repository. @@ -137,8 +216,7 @@ def ExternFamily(path): pyflow.ExternFamily('/f/g/h/i') """ - - return ExternNode(path) + return ExternNode(path, Family) def ExternTask(path): @@ -155,5 +233,4 @@ def ExternTask(path): pyflow.ExternTask('/a/b/c/d') """ - - return ExternNode(path, tail_cls=Task) + return ExternNode(path, Task) diff --git a/pyflow/host.py b/pyflow/host.py index 35ed405..9f407a5 100644 --- a/pyflow/host.py +++ b/pyflow/host.py @@ -83,6 +83,23 @@ SSH_COMMAND = "ssh -v -o StrictHostKeyChecking=no" +HOST_REGISTRY = {} + + +def register_host(registry_key): + """ + Registers a host class in the host registry. + + Parameters: + registry_key(str): The key to register the host class under. + """ + + def decorator(cls): + HOST_REGISTRY[registry_key] = cls + return cls + + return decorator + class Host: """ @@ -186,21 +203,52 @@ def __str__(self): def __repr__(self): return str(self) - @property - def ecflow_variables(self): - """*dict*: The variables that must be set on relevant nodes to run on this host.""" + def update_node_attributes(self, options): + """ + Updates the attributes of a node with the host-specific values. + + Parameters: + - options (dict): The options dictionary to update with host-specific attributes. + """ if self.server_ecfvars: - vars = {} + # Use generated variables to be able to export the variables in tasks + host_attrs = { + "generated_variables": [ + "ECF_JOB_CMD", + "ECF_KILL_CMD", + "ECF_STATUS_CMD", + "ECF_CHECK_CMD", + "ECF_OUT", + ], + "variables": {**self.extra_variables}, + } else: - vars = { - "ECF_JOB_CMD": self.job_cmd, - "ECF_KILL_CMD": self.kill_cmd, - "ECF_STATUS_CMD": self.status_cmd, - "ECF_CHECK_CMD": self.check_cmd, - "ECF_OUT": self.log_directory, + host_attrs = { + "generated_variables": [], + "variables": { + "ECF_JOB_CMD": self.job_cmd, + "ECF_KILL_CMD": self.kill_cmd, + "ECF_STATUS_CMD": self.status_cmd, + "ECF_CHECK_CMD": self.check_cmd, + "ECF_OUT": self.log_directory, + **self.extra_variables, + }, } - vars.update(self.extra_variables) - return vars + + # Update the options with host-specific attributes + for attribute, values in host_attrs.items(): + if attribute in ["variables"]: + variables = options.pop(attribute, {}) + values.update(variables) + options[attribute] = values + elif attribute in ["generated_variables"]: + gen_variables = options.pop(attribute, []) + values += gen_variables + options[attribute] = values + else: + raise Exception("Unknown attribute: {}".format(attribute)) + + return options @property def job_cmd(self): @@ -347,9 +395,7 @@ def preamble_init(self, ecflowpath): *str*: The preamble initialisation script. """ - script = ( - textwrap.dedent( - """ + script = textwrap.dedent(""" # ----------------------------- ECFLOW INIT ---------------------------- export PATH=%(ecf_path)s:$PATH @@ -358,10 +404,7 @@ def preamble_init(self, ecflowpath): # Tell ecFlow we have started ecflow_client --init=$$ - """ - ) - % {"ecf_path": ecflowpath} - ) + """) % {"ecf_path": ecflowpath} return script @@ -378,22 +421,18 @@ def preamble_error_function(self, ecflowpath, exit_hook=None): """ script = "" - script += textwrap.dedent( - """ + script += textwrap.dedent(""" # custom exit/cleanup code exit_hook () { echo "cleaning up ...." - """ - ) + """) if exit_hook: for line in exit_hook: script += f" {line}\n" script += "}\n\n" signal_list = " ".join(str(s) for s in self.trap_signals) - script += textwrap.dedent( - ( - """ + script += textwrap.dedent((""" # ----------------------------- TRAPS FOR SUBMITTED JOBS ---------------------------- set +x # Define a error handler @@ -421,10 +460,7 @@ def preamble_error_function(self, ecflowpath, exit_hook=None): # Trap any calls to exit and errors caught by the -e flag trap ERROR 0 set -x - """ # noqa: E501 - ) - % {"ecf_path": ecflowpath, "signal_list": signal_list} - ) + """) % {"ecf_path": ecflowpath, "signal_list": signal_list}) # noqa: E501 return script def job_preamble(self, exit_hook=None): @@ -434,6 +470,7 @@ def job_preamble(self, exit_hook=None): ) + self.preamble_error_function(self.ecflow_path, exit_hook).split("\n") +@register_host("null") class NullHost(Host): """ A dummy host object invisible to **ecFlow**, but still throws exceptions if **pyflow** attempts to create tasks @@ -469,10 +506,8 @@ def __init__(self, **kwargs): kwargs.setdefault("limit", None) super().__init__("null", **kwargs) - @property - def ecflow_variables(self): - """*dict*: The variables that must be set on relevant nodes to run on this host, always empty.""" - return {} + def update_node_attributes(self, options): + return options def host_preamble(self, exit_hook=None): """ @@ -500,6 +535,7 @@ def build_label(self): return None +@register_host("localhost") class LocalHost(Host): """ A host object that executes scripts directly on the **ecFlow** server. @@ -611,6 +647,7 @@ def copy_file_to(self, source_file, target_file): ) +@register_host("ecflow-default") class EcflowDefaultHost(LocalHost): """ By default we just use LocalHost... Slightly modified from ecflow default of @@ -623,6 +660,7 @@ def __init__(self, **kwargs): super().__init__("default", **kwargs) +@register_host("ssh") class SSHHost(Host): """ A host object that executes scripts on the **ecFlow** server via SSH protocol. @@ -798,9 +836,10 @@ def host_postamble(self): return [] +@register_host("ssh-simple") class SimpleSSHHost(Host): - def __init__(self, host): - super().__init__(host) + def __init__(self, host, **kwargs): + super().__init__(host, **kwargs) self.host = host @property @@ -832,6 +871,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("slurm") class SLURMHost(SSHHost): """ A host object that executes scripts on the **ecFlow** server via Slurm job scheduling system. @@ -926,6 +966,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("pbs") class PBSHost(SSHHost): """ A host object that executes scripts on the **ecFlow** server via batch server. @@ -1020,6 +1061,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("troika") class TroikaHost(Host): """ A host object that executes scripts on the **ecFlow** server via the troika job submitter. @@ -1027,6 +1069,10 @@ class TroikaHost(Host): Parameters: name(str): The name of the host. user(str): The user to use for troika commands to the host. + troika_exec(str): The path to the troika executable, defaults to `%TROIKA:troika%`. + troika_config(str): The path to the troika configuration file, defaults to `%TROIKA_CONFIG%`. + Value False or None will deactivate the config in the command. + troika_version(str): The version of the troika executable, defaults to `0.2.3`. hostname(str): The hostname of the host, otherwise `name` will be used. scratch_directory(str): The path in which tasks will be run, unless otherwise specified. log_directory(str): The directory to use for script output. Normally `ECF_HOME`, but may need to be changed on @@ -1051,24 +1097,26 @@ class TroikaHost(Host): pass """ - def __init__(self, name, user, **kwargs): - self.troika_exec = kwargs.pop("troika_exec", "troika") - self.troika_config = kwargs.pop("troika_config", "") - self.troika_version = tuple( - map(int, kwargs.pop("troika_version", "0.2.1").split(".")) - ) + def __init__( + self, + name, + user, + troika_exec="%TROIKA:troika%", + troika_config=None, + troika_version="0.2.3", + **kwargs, + ): + self.troika_exec = troika_exec + self.troika_config = troika_config + self.troika_version = tuple(map(int, troika_version.split("."))) super().__init__(name, user=user, **kwargs) def troika_command(self, command): cmd = " ".join( [ - f"%TROIKA:{self.troika_exec}%", + f"{self.troika_exec}", "-vv", - ( - f"-c %TROIKA_CONFIG:{self.troika_config}%" - if self.troika_config - else "" - ), + (f"-c {self.troika_config}" if self.troika_config else ""), f"{command}", f"-u {self.user}", ] @@ -1187,3 +1235,24 @@ def _translate_sthost(val): args.append("#TROIKA {}={}".format(arg, val)) return args + + +def host_factory(key, *args, **kwargs): + """ + Factory function to create host objects based on a key. + + Parameters: + key(str): The key specifying the type of host to create. + *args: Positional arguments to pass to the host constructor. + **kwargs: Keyword arguments to pass to the host + constructor. + Returns: + Host: The created host object. + """ + + if (target := HOST_REGISTRY.get(key)) is not None: + return target(*args, **kwargs) + else: + raise ValueError( + f"Unknown host type: {key}. Available host types are: {list(HOST_REGISTRY.keys())}" + ) diff --git a/pyflow/nodes.py b/pyflow/nodes.py index 650f4a5..67b5a46 100644 --- a/pyflow/nodes.py +++ b/pyflow/nodes.py @@ -25,7 +25,7 @@ Limit, Manual, Meter, - RepeatDay, + Mirror, Time, Today, Trigger, @@ -112,6 +112,7 @@ def __init__( purge_modules=False, extern=False, workdir=None, + repeat=None, **kwargs, ): """ @@ -142,7 +143,7 @@ def __init__( limits(Limit_): An attribute for a simple load management by limiting the number of tasks submitted by a specific **ecFlow** server. meters(Meter_): An attribute for a range of integer values that can be set from a script. - repeat(RepeatDay_): An attribute that allows a node to be repeated infinitely. + mirror(Mirror_): An attribute for a mirroring a node on another ecflow server. tasks(Task_): An attribute for adding a child task on the node. time(Time_): An attribute for setting a time dependency of the node. today(Today_): An attribute for setting a cron dependency of the node for the current day. @@ -152,6 +153,7 @@ def __init__( generated_variables(GeneratedVariable_): An attribute for setting an **ecFlow** generated variable. zombies(Zombies_): An attribute that defines how a zombie should be handled in an automated fashion. events(Event_): An attribute for declaring an action that a task can trigger while it is running. + repeat(Repeat_): An attribute for setting a repeat schedule for the node. **kwargs(str): Accept extra keyword arguments as variables to be set on the node. """ @@ -162,12 +164,17 @@ def __init__( self._modules = modules or [] self._purge_modules = purge_modules self._extern = extern + self._repeat = None # can't be set in constructor, needs to be done in context + if repeat is not None: + if not isinstance(repeat, (list, tuple)): + raise TypeError("Repeat attribute must be passed as a list or tuple") + with self: + self._repeat = repeat[0](*repeat[1:]) # If we have changed the host, then set the relevant directories self._host = host if host is not None: - for variable_name, variable_val in host.ecflow_variables.items(): - kwargs.setdefault(variable_name, variable_val) + host.update_node_attributes(kwargs) # If we have set/changed the host, then add a label as decided by the Host object with self: @@ -302,6 +309,26 @@ def append_node(self, node): self.add_node(node) return self + @property + def repeat(self): + """ + Returns the currently active repeat object. + If not found in current node, search in parents. + + Returns: + Repeat_: Currently active repeat object. + """ + if self._repeat is not None: + return self._repeat + return self.parent.repeat + + @repeat.setter + def repeat(self, value): + if not isinstance(value, (list, tuple)): + raise TypeError("Repeat attribute must be passed as a list or tuple") + with self: + self._repeat = value[0](*value[1:]) + @property def host(self): """ @@ -602,7 +629,9 @@ def ecflow_definition(self): d.auto_add_externs(True) for ext in d.externs: - assert is_extern_known(ext), "Attempting to add unknown extern reference" + assert is_extern_known( + ext + ), f"Attempting to add unknown extern reference {ext}" return d @@ -830,7 +859,6 @@ def __init__( limits(Limit_): An attribute for a simple load management by limiting the number of tasks submitted by a specific **ecFlow** server. meters(Meter_): An attribute for a range of integer values that can be set from a script. - repeat(RepeatDay_): An attribute that allows a node to be repeated infinitely. tasks(Task_): An attribute for adding a child task on the node. time(Time_): An attribute for setting a time dependency of the node. today(Today_): An attribute for setting a cron dependency of the node for the current day. @@ -840,6 +868,7 @@ def __init__( generated_variables(GeneratedVariable_): An attribute for setting an **ecFlow** generated variable. zombies(Zombies_): An attribute that defines how a zombie should be handled in an automated fashion. events(Event_): An attribute for declaring an action that a task can trigger while it is running. + repeat(Repeat_): An attribute for setting a repeat schedule for the node. **kwargs(str): Accept extra keyword arguments as variables to be set on the family. Example:: @@ -958,7 +987,6 @@ def __init__( limits(Limit_): An attribute for a simple load management by limiting the number of tasks submitted by a specific **ecFlow** server. meters(Meter_): An attribute for a range of integer values that can be set from a script. - repeat(RepeatDay_): An attribute that allows a node to be repeated infinitely. tasks(Task_): An attribute for adding a child task on the node. time(Time_): An attribute for setting a time dependency of the node. today(Today_): An attribute for setting a cron dependency of the node for the current day. @@ -967,6 +995,7 @@ def __init__( variables(Variable_): An attribute for setting an **ecFlow** variable. zombies(Zombies_): An attribute that defines how a zombie should be handled in an automated fashion. events(Event_): An attribute for declaring an action that a task can trigger while it is running. + repeat(Repeat_): An attribute for setting a repeat schedule for the node. **kwargs(str): Accept extra keyword arguments as variables to be set on the anchor family. Example:: @@ -1033,7 +1062,6 @@ def __init__(self, name, host=None, exit_hook=None, *args, **kwargs): limits(Limit_): An attribute for a simple load management by limiting the number of tasks submitted by a specific **ecFlow** server. meters(Meter_): An attribute for a range of integer values that can be set from a script. - repeat(RepeatDay_): An attribute that allows a node to be repeated infinitely. tasks(Task_): An attribute for adding a child task on the node. time(Time_): An attribute for setting a time dependency of the node. today(Today_): An attribute for setting a cron dependency of the node for the current day. @@ -1043,6 +1071,7 @@ def __init__(self, name, host=None, exit_hook=None, *args, **kwargs): generated_variables(GeneratedVariable_): An attribute for setting an **ecFlow** generated variable. zombies(Zombies_): An attribute that defines how a zombie should be handled in an automated fashion. events(Event_): An attribute for declaring an action that a task can trigger while it is running. + repeat(Repeat_): An attribute for setting a repeat schedule for the node. **kwargs(str): Accept extra keyword arguments as variables to be set on the suite. Example:: @@ -1159,11 +1188,7 @@ def deploy_suite(self, target=FileSystem, node=None, **options): for t in node.all_tasks: script, includes = t.generate_script() - try: - target.deploy_task(t.deploy_path, script, includes) - except Exception as e: - print(f"\nERROR when deploying task: {t.fullname}\n") - raise (e) + target.deploy_task(t.deploy_path, script, includes) for f in node.all_families: manual = self.generate_stub(f.manual) if manual: @@ -1243,7 +1268,7 @@ def __init__( limits(Limit_): An attribute for a simple load management by limiting the number of tasks submitted by a specific **ecFlow** server. meters(Meter_): An attribute for a range of integer values that can be set from a script. - repeat(RepeatDay_): An attribute that allows a node to be repeated infinitely. + mirror(Mirror_): An attribute for a mirroring a node on another ecflow server. tasks(Task_): An attribute for adding a child task on the node. time(Time_): An attribute for setting a time dependency of the node. today(Today_): An attribute for setting a cron dependency of the node for the current day. @@ -1253,6 +1278,7 @@ def __init__( generated_variables(GeneratedVariable_): An attribute for setting an **ecFlow** generated variable. zombies(Zombies_): An attribute that defines how a zombie should be handled in an automated fashion. events(Event_): An attribute for declaring an action that a task can trigger while it is running. + repeat(Repeat_): An attribute for setting a repeat schedule for the node. **kwargs(str): Accept extra keyword arguments as variables to be set on the task. Example:: @@ -1518,7 +1544,6 @@ def generate_script(self): ("labels", Label), ("limits", Limit), ("meters", Meter), - ("repeat", RepeatDay), ("tasks", Task), ("time", Time), ("today", Today), @@ -1527,6 +1552,7 @@ def generate_script(self): ("generated_variables", GeneratedVariable), ("zombies", Zombies), ("events", Event), + ("mirror", Mirror), ] diff --git a/pyflow/resource.py b/pyflow/resource.py index 7844aac..4bc59df 100644 --- a/pyflow/resource.py +++ b/pyflow/resource.py @@ -81,30 +81,7 @@ def get_resource(self, filename): return [] - def install_file_stub(self, target): - """ - Installs any data associated with the resource object that is going to be deployed from the **ecFlow** server. - - Parameters: - target(Deployment): The target deployment where the resource data should be installed. - """ - - """ - n.b. If a resource does not need to save data at deployment time, it should not do so (e.g. WebResource) - """ - # Install path is for the suite, so we don't need to include the suite name - assert self.fullname.count("/") > 1 - subpath = self.fullname[self.fullname.find("/", 1) + 1 :] - - self._server_filename = os.path.join( - target.files_install_path(), subpath, self.name - ) - - super().install_file_stub(target) - - self.save_data(target, self._server_filename) - - def build_script(self): + def generate_script(self): """ Returns the installer script for the data resource. @@ -129,7 +106,7 @@ def build_script(self): for h in self._hosts: lines += h.copy_file_to(self._server_filename, self.location()).split("\n") - return lines + return lines, [] def location(self): """ @@ -184,6 +161,8 @@ def save_data(self, target, filename): filename(str): The filename for the resource data. """ + self._server_filename = filename + """ Resources don't all need to save data at generation time """ @@ -207,9 +186,8 @@ class FileResource(Resource): """ def __init__(self, name, hosts, source_file): - self._source = source_file - super().__init__(name, hosts) + self._server_filename = source_file def md5(self): """ @@ -231,7 +209,7 @@ def data(self): The resource data. """ - with open(self._source, "rb") as f: + with open(self._server_filename, "rb") as f: return f.read() def save_data(self, target, filename): diff --git a/tests/file_resource.txt b/tests/file_resource.txt new file mode 100644 index 0000000..c53e6c6 --- /dev/null +++ b/tests/file_resource.txt @@ -0,0 +1 @@ +# Example for a file resource diff --git a/tests/test8.json b/tests/test8.json index 85e8970..ce617c7 100644 --- a/tests/test8.json +++ b/tests/test8.json @@ -1,10 +1,6 @@ { "FOO": 42, "f1": { - "YMD": [ - "2010-01-01", - "2011-01-01" - ], "labels": { "foo_label": "bar" }, @@ -69,6 +65,5 @@ }, "f2": { "FOO": 42 - }, - "repeat": true + } } diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 22fcaaf..594eac9 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -46,9 +46,9 @@ def test_reassign_variable(): with pyflow.Suite("s") as s: with pyflow.Task("t1") as t: t.FOO = 61 - t.FOO = (1, 10) + t.FOO = 100 with pyflow.Task("t2") as t: - t.FOO = (1, 10) + t.FOO = 100 assert s.t1.FOO.value == s.t2.FOO.value @@ -276,9 +276,11 @@ def test_date(): assert "date *.*.3" in str(s.ecflow_definition()) with pyflow.Suite("s") as s: - t1 = pyflow.Task("t1", DATE=("20180105", "20180206")) + t1 = pyflow.Task( + "t1", repeat=(pyflow.RepeatString, "DATE", ["20180105", "20180206"]) + ) - assert t1.DATE.value == ("20180105", "20180206") + assert t1.DATE.value == ["20180105", "20180206"] assert 'repeat string DATE "20180105" "20180206"' in str(s.ecflow_definition()) @@ -288,26 +290,34 @@ class TestRepeats: def test_string_repeat(self): with pyflow.Suite("s") as s: with pyflow.Family("f1") as f1: - f1.STRING_REPEAT = [str(v) for v in reversed(range(10))] + pyflow.RepeatString( + "STRING_REPEAT", [str(v) for v in reversed(range(10))] + ) t1 = pyflow.Task("t1") t1.triggers = (f1.STRING_REPEAT == "7") & (f1.STRING_REPEAT == 3) - t1.triggers |= (f1.STRING_REPEAT + 2 == 7) & (f1.STRING_REPEAT - 1 == 6) + t1.triggers |= (f1.STRING_REPEAT + 2 + 1 == 7) & ( + f1.STRING_REPEAT * 2 - 1 == 6 + ) assert ( str(t1.triggers.value) == "(((/s/f1:STRING_REPEAT eq 2)" " and (/s/f1:STRING_REPEAT eq 3))" - " or (((/s/f1:STRING_REPEAT + 2) eq 7)" - " and ((/s/f1:STRING_REPEAT - 1) eq 6)))" + " or ((((/s/f1:STRING_REPEAT + 2) + 1) eq 7)" + " and (((/s/f1:STRING_REPEAT * 2) - 1) eq 6)))" ) s.check_definition() def test_combined_string_repeats(self): with pyflow.Suite("s") as s: - t1 = pyflow.Task("t1", YMD=["20170101", "20180101"]) - t2 = pyflow.Task("t2", YMD=["20170101", "20180101"]) - t2.triggers = t1.YMD >= t2.YMD + t1 = pyflow.Task( + "t1", repeat=(pyflow.RepeatDate, "YMD", "20170101", "20180101") + ) + t2 = pyflow.Task( + "t2", repeat=(pyflow.RepeatDate, "YMD", "20170101", "20180101") + ) + t2.triggers = t1.YMD >= t2.repeat assert str(t2.triggers.value) == "(/s/t1:YMD ge /s/t2:YMD)" s.check_definition() @@ -315,29 +325,31 @@ def test_combined_string_repeats(self): def test_enumerated_repeat(self): with pyflow.Suite("s") as s: with pyflow.Family("f2") as f2: - f2.ENUMERATED_REPEAT = list(reversed(range(10))) + pyflow.RepeatEnumerated("ENUMERATED_REPEAT", list(range(10))) t2 = pyflow.Task("t2") t2.triggers = (f2.ENUMERATED_REPEAT == "7") & ( f2.ENUMERATED_REPEAT == 3 ) - t2.triggers |= (f2.ENUMERATED_REPEAT + 2 == 7) & ( - f2.ENUMERATED_REPEAT - 1 == 6 + t2.triggers |= (f2.ENUMERATED_REPEAT + 2 + 1 == 7) & ( + f2.ENUMERATED_REPEAT * 2 - 1 == 6 ) assert ( str(t2.triggers.value) == "(((/s/f2:ENUMERATED_REPEAT eq 7) and " "(/s/f2:ENUMERATED_REPEAT eq 3)) or " - "(((/s/f2:ENUMERATED_REPEAT + 2) eq 7) and " - "((/s/f2:ENUMERATED_REPEAT - 1) eq 6)))" + "((((/s/f2:ENUMERATED_REPEAT + 2) + 1) eq 7) and " + "(((/s/f2:ENUMERATED_REPEAT * 2) - 1) eq 6)))" ) s.check_definition() def test_combined_enumerated_repeats(self): with pyflow.Suite("s") as s: - t1 = pyflow.Task("t1", ENUMERATED_REPEAT=list(range(10))) - t2 = pyflow.Task("t2", ENUMERATED_REPEAT=list(range(10))) + with pyflow.Task("t1") as t1: + pyflow.RepeatEnumerated("ENUMERATED_REPEAT", list(range(10))) + with pyflow.Task("t2") as t2: + pyflow.RepeatEnumerated("ENUMERATED_REPEAT", list(range(10))) t2.triggers = t1.ENUMERATED_REPEAT >= t2.ENUMERATED_REPEAT assert ( str(t2.triggers.value) @@ -353,15 +365,15 @@ def test_integer_repeat(self): t3 = pyflow.Task("t3") t3.triggers = (f3.INTEGER_REPEAT == "7") & (f3.INTEGER_REPEAT == 3) - t3.triggers |= (f3.INTEGER_REPEAT + 2 == 7) & ( - f3.INTEGER_REPEAT - 1 == 6 + t3.triggers |= (f3.INTEGER_REPEAT + 2 + 1 == 7) & ( + f3.INTEGER_REPEAT * 2 - 1 == 6 ) assert ( str(t3.triggers.value) == "(((/s/f3:INTEGER_REPEAT eq 7) and " "(/s/f3:INTEGER_REPEAT eq 3)) or " - "(((/s/f3:INTEGER_REPEAT + 2) eq 7) and " - "((/s/f3:INTEGER_REPEAT - 1) eq 6)))" + "((((/s/f3:INTEGER_REPEAT + 2) + 1) eq 7) and " + "(((/s/f3:INTEGER_REPEAT * 2) - 1) eq 6)))" ) s.check_definition() @@ -382,7 +394,9 @@ def test_combined_integer_repeats(self): def test_date_datetime_repeat(self): with pyflow.Suite("s") as s: with pyflow.Family("f4") as f4: - f4.DATE_REPEAT = (datetime(2018, 1, 1), datetime(2019, 12, 31)) + pyflow.RepeatDateTime( + "DATE_REPEAT", datetime(2018, 1, 1), datetime(2019, 12, 31) + ) t4 = pyflow.Task("t4") t4.triggers = (f4.DATE_REPEAT >= "20180301") & ( @@ -403,7 +417,7 @@ def test_date_datetime_repeat(self): def test_date_date_repeat(self): with pyflow.Suite("s") as s: with pyflow.Family("f4") as f4: - f4.DATE_REPEAT = (date(2018, 1, 1), date(2019, 12, 31)) + pyflow.RepeatDate("DATE_REPEAT", date(2018, 1, 1), date(2019, 12, 31)) t4 = pyflow.Task("t4") t4.triggers = (f4.DATE_REPEAT >= "20180301") & ( @@ -669,46 +683,36 @@ class TestMirror: """A set of tests for Mirror attributes.""" def test_create_mirror_from_strings(self): - name = "MIRROR_ATTRIBUTE" - remote_path = "/s/f/t" - remote_host = "remote-ecflow-server" - remote_port = "3141" - polling = "%ECFLOW_MIRROR_POLLING%" - ssl = True - auth = "/path/to/auth.json" + setup = { + "name": "MIRROR_ATTRIBUTE", + "remote_path": "/s/f/t", + "remote_host": "remote-ecflow-server", + "remote_port": "3141", + "polling": "%ECFLOW_MIRROR_REMOTE_POLLING%", + "ssl": True, + "auth": "/path/to/auth.json", + } - attr = pyflow.Mirror( - name, remote_path, remote_host, remote_port, polling, ssl, auth - ) + attr = pyflow.Mirror(**setup) - assert attr.name == name - assert attr.remote_path == remote_path - assert attr.remote_host == remote_host - assert attr.remote_port == remote_port - assert attr.polling == polling - assert attr.ssl == ssl - assert attr.auth == auth + for name, value in setup.items(): + assert getattr(attr, name) == value def test_create_mirror_from_objects(self): - name = "MIRROR_ATTRIBUTE" - remote_path = "/s/f/t" - remote_host = "remote-ecflow-server" - remote_port = 3141 - polling = 60 - ssl = True - auth = "/path/to/auth.json" + setup = { + "name": "MIRROR_ATTRIBUTE", + "remote_path": "/s/f/t", + "remote_host": "remote-ecflow-server", + "remote_port": "3141", + "polling": 600, + "ssl": True, + "auth": "/path/to/auth.json", + } - attr = pyflow.Mirror( - name, remote_path, remote_host, remote_port, polling, ssl, auth - ) + attr = pyflow.Mirror(**setup) - assert attr.name == name - assert attr.remote_path == remote_path - assert attr.remote_host == remote_host - assert attr.remote_port == str(remote_port) - assert attr.polling == str(polling) - assert attr.ssl == ssl - assert attr.auth == auth + for name, value in setup.items(): + assert type(value)(getattr(attr, name)) == value def test_create_mirror_on_task(self): with pyflow.Suite("s") as s: @@ -718,16 +722,14 @@ def test_create_mirror_on_task(self): with pyflow.Task("t") as t: assert "t" == t.name - name = "MIRROR_ATTRIBUTE" - remote_path = "/s/f/t" - remote_host = "remote-ecflow-server" - remote_port = 3141 - polling = 60 - ssl = True - auth = "/path/to/auth.json" - pyflow.Mirror( - name, remote_path, remote_host, remote_port, polling, ssl, auth + name="MIRROR_ATTRIBUTE", + remote_path="/s/f/t", + remote_host="remote-ecflow-server", + remote_port="3141", + polling=600, + ssl=True, + auth="/path/to/auth.json", ) s.check_definition() @@ -740,16 +742,14 @@ def test_definitions_content_with_mirror_attribute(self): with pyflow.Task("t") as t: assert "t" == t.name - name = "MIRROR_ATTRIBUTE" - remote_path = "/s/f/t" - remote_host = "remote-ecflow-server" - remote_port = 3141 - polling = 60 - ssl = True - auth = "/path/to/auth.json" - pyflow.Mirror( - name, remote_path, remote_host, remote_port, polling, ssl, auth + name="MIRROR_ATTRIBUTE", + remote_path="/s/f/t", + remote_host="remote-ecflow-server", + remote_port="3141", + polling=600, + ssl=True, + auth="/path/to/auth.json", ) defs = s.ecflow_definition() diff --git a/tests/test_contextmanager.py b/tests/test_contextmanager.py index f7fcebc..5ca6eb6 100644 --- a/tests/test_contextmanager.py +++ b/tests/test_contextmanager.py @@ -9,6 +9,7 @@ Label, Limit, Meter, + RepeatEnumerated, Suite, Task, Tasks, @@ -22,15 +23,15 @@ def test_suite(): Limit("foo", 1) Limit("bar", 2) - with Family("f", BAR=[1, 2, 3]): + with Family("f", repeat=(RepeatEnumerated, "BAR", [1, 2, 3])): Task("t1") Task("t2") Task("t3").triggers = (s.f.t1 == "complete") | "2 < 8" - with Family("g") as g: + with Family("g"): InLimit("foo") - g.QUUX = [1, 2] + RepeatEnumerated("QUUX", [1, 2]) Task("t4").triggers = s.f.t1 == "aborted" with Task("t5"): diff --git a/tests/test_extern.py b/tests/test_extern.py index 1baa6a8..2aa0059 100644 --- a/tests/test_extern.py +++ b/tests/test_extern.py @@ -7,29 +7,35 @@ Event, ExternEvent, ExternFamily, + ExternLimit, ExternMeter, + ExternRepeat, + ExternSuite, ExternTask, - ExternYMD, + ExternVariable, Family, + Limit, Meter, Notebook, RepeatDate, Suite, Task, + Variable, ) -from pyflow.extern import KNOWN_EXTERNS +from pyflow.extern import KNOWN_EXTERNS, Repeat now = datetime.datetime.now() def test_extern(): with Suite("s") as s: - t1 = Task("t1", YMD=(now, now)) + t1 = Task("t1", repeat=(RepeatDate, "YMD", now, now)) et = ExternTask("/a/b/c/d") ef = ExternFamily("/f/g/h/i") + es = ExternSuite("/j") - t1.triggers = et & ef + t1.triggers = et & ef & es # Check that the externs have real types --> will have correct functionality available @@ -39,6 +45,9 @@ def test_extern(): assert isinstance(ef, Family) assert ef.name == "i" assert ef.fullname == "/f/g/h/i" + assert isinstance(es, Suite) + assert es.name == "j" + assert es.fullname == "/j" # Check that they work! @@ -62,22 +71,46 @@ def test_extern(): with pytest.raises(AssertionError) as excinfo: s.ecflow_definition() - assert excinfo.value.args == ("Attempting to add unknown extern reference",) + assert excinfo.value.args == ( + "Attempting to add unknown extern reference /a/b/c/d", + ) def test_extern_attributes(): + sext = ExternSuite("/limits") # extern shall not be under a node suite/family/task + evar = ExternVariable("/a/main:SUITE_START") + svar = ExternVariable("/a:SUITE_START") + limit = ExternLimit( + "/limits:hpc" + ) # extern shall not be under a node suite/family/task + + # svar = ExternEdit("/b:SUITE_START") # OK with Suite("s") as s: - eymd = ExternYMD("/a/b/c/d:YMD") + eymd = ExternRepeat("/a/b/c/d:YMD") + elimit = ExternLimit("/limits/lim:hpc") + slimit = ExternLimit("/limits:hpc") eevent = ExternEvent("/e/f/g/h:ev") emeter = ExternMeter("/g/h/i/j:mt") - Task("t1", YMD=(now, now)).follow = eymd + t1 = Task("t1", repeat=(RepeatDate, "YMD", now, now)) + t1.follow = eymd Task("t2").triggers = eevent Task("t3").triggers = emeter == 10 + Task("t4").completes = evar != eymd + Task("t5").inlimits = [elimit, slimit] + Task("t6").completes = svar != eymd + Task("ts").completes = sext.complete + # Check that the externs have real types --> will have correct functionality available - # Check that the externs have real types --> will have correct functionality available + assert isinstance(elimit, Limit) + assert limit.name == "hpc" + assert limit.fullname == "/limits:hpc" + + assert isinstance(svar, Variable) + assert svar.name == "SUITE_START" + assert svar.fullname == "/a:SUITE_START" - assert isinstance(eymd, RepeatDate) + assert isinstance(eymd, Repeat) assert eymd.name == "YMD" assert eymd.fullname == "/a/b/c/d:YMD" @@ -109,16 +142,21 @@ def test_extern_attributes(): def test_extern_safety(): externs = [] + externs.append(ExternSuite("/limits")) + externs.append(ExternLimit("/limits:hpc")) + externs.append(ExternLimit("/limits/lim:hpc")) + externs.append(ExternVariable("/a/main:SUITE_START")) + externs.append(ExternVariable("/a:SUITE_START")) with Suite("s"): - externs.append(ExternTask("/a/b/c/d")) - externs.append(ExternFamily("/e/f/g/h")) + externs.append(ExternFamily("/a/b/c/d")) + externs.append(ExternTask("/e/f/g/h")) with externs[-1]: # n.b. should never do this in reality, but trying to break things... externs.append(Task("e3")) - externs.append(ExternYMD("i/j/k/l:YMD")) + externs.append(ExternRepeat("i/j/k/l:YMD")) externs.append(ExternEvent("m/n/o/p:ev")) externs.append(ExternMeter("q/s/t/u:mt")) diff --git a/tests/test_follow.py b/tests/test_follow.py index 81caa29..64c7666 100644 --- a/tests/test_follow.py +++ b/tests/test_follow.py @@ -1,23 +1,35 @@ import datetime -from pyflow import Notebook, Suite, Task +from pyflow import Family, Notebook, RepeatDate, Suite, Task -now = datetime.datetime.now() +# Use a date object, not the datetime.date descriptor method +now = datetime.date(2025, 8, 14) def test_follow(): with Suite("s") as s: - t1 = Task("t1", YMD=(now, now)) - t2 = Task("t2") - Task("t3") + with Task("t1"): + r1 = RepeatDate("YMD1", now, now) + with Family("f1") as f1: + f1.repeat = (RepeatDate, "YMD2", now, now) + t2 = Task("t2") + t3 = Task("t3", repeat=(RepeatDate, "YMD3", now, now)) - t2.triggers = t1.complete - t2.follow = t1.YMD + t2.follow = r1 + t3.follow = t2 s.check_definition() s.generate_node() s.deploy_suite(target=Notebook) + print(s) + print(t3.repeat) + print(str(t2.triggers)) + print(str(t3.triggers)) + print(t2.triggers) + assert "trigger ../t1 eq complete or ../f1:YMD2 lt ../t1:YMD1" in str(t2.triggers) + assert "trigger f1/t2 eq complete or t3:YMD3 lt f1:YMD2" in str(t3.triggers) + assert "repeat date YMD3 20250814 20250814 1" in str(t3.triggers) if __name__ == "__main__": diff --git a/tests/test_host.py b/tests/test_host.py index 09fb0b4..9978600 100644 --- a/tests/test_host.py +++ b/tests/test_host.py @@ -2,6 +2,18 @@ import pyflow import pyflow.host +from pyflow.host import ( + HOST_REGISTRY, + LocalHost, + NullHost, + PBSHost, + SimpleSSHHost, + SLURMHost, + SSHHost, + TroikaHost, + host_factory, + register_host, +) def test_host_task(): @@ -258,13 +270,13 @@ def test_troika_host(): host1 = pyflow.TroikaHost( name="test_host", user="test_user", + troika_version="0.2.1", + troika_config="%TROIKA_CONFIG%", ) - host2 = pyflow.TroikaHost( - name="test_host", user="test_user", troika_version="2.2.2" - ) + host2 = pyflow.TroikaHost(name="test_host", user="test_user") submit_args = { - "tasks": 2, # deprecated option, will be translated to total_tasks + "total_tasks": 2, "gpus": 1, "sthost": "/foo/bar", "distribution": "test", # generates TROIKA pragma for recent version of troika, SBATCH for older versions @@ -284,11 +296,11 @@ def test_troika_host(): assert ( s.ECF_JOB_CMD.value - == "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" + == "%TROIKA:troika% -vv -c %TROIKA_CONFIG% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" ) assert ( s.ECF_KILL_CMD.value - == "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%" + == "%TROIKA:troika% -vv -c %TROIKA_CONFIG% kill -u test_user test_host %ECF_JOB%" ) t1_script = t1.generate_script() @@ -317,7 +329,7 @@ def test_troika_host(): def test_host_submit_args(): submit_args = { "troika": { - "tasks": 2, # deprecated option, will be translated to total_tasks + "total_tasks": 2, "gpus": 1, "sthost": "/foo/bar", "distribution": "test", # generates TROIKA pragma for recent version of troika, SBATCH for older versions @@ -385,15 +397,34 @@ def test_troika_host_options(): assert ( s.ECF_JOB_CMD.value - == "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 + == "/path/to/troika -vv -c /path/to/troika.cfg submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 ) assert ( s.ECF_KILL_CMD.value - == "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% kill -u test_user test_host %ECF_JOB%" # noqa: E501 + == "/path/to/troika -vv -c /path/to/troika.cfg kill -u test_user test_host %ECF_JOB%" # noqa: E501 ) assert s.host.troika_version == (2, 1, 3) +def test_troika_host_options_no_config(): + host = pyflow.TroikaHost( + name="test_host", + user="test_user", + troika_config=None, + ) + + s = pyflow.Suite("s", host=host) + + assert ( + s.ECF_JOB_CMD.value + == "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 + ) + assert ( + s.ECF_KILL_CMD.value + == "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%" # noqa: E501 + ) + + def test_traps(): sigs = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13] with pyflow.Suite("s") as s1: @@ -416,6 +447,77 @@ def test_traps(): assert signal_list2 in s2 +@pytest.mark.parametrize( + "key,expected_class,kwargs", + [ + ("null", NullHost, {}), + ("localhost", LocalHost, {}), + ("ssh", SSHHost, {"name": "test"}), + ("ssh-simple", SimpleSSHHost, {"host": "test"}), + ("slurm", SLURMHost, {"name": "test"}), + ("pbs", PBSHost, {"name": "test"}), + ("troika", TroikaHost, {"name": "test", "user": "testuser"}), + ], +) +def test_host_factory_returns_correct_types(key, expected_class, kwargs): + result = host_factory(key, **kwargs) + assert isinstance(result, expected_class) + + +def test_host_factory_forwards_kwargs(): + result = host_factory("localhost", name="myhost", scratch_directory="/tmp/test") + assert result.name == "myhost" + assert result.scratch_directory == "/tmp/test" + + +def test_host_factory_raises_and_lists_available_types(): + with pytest.raises(ValueError, match="Unknown host type: bogus") as exc_info: + host_factory("bogus") + exc_str = str(exc_info.value) + for key in ("null", "localhost", "ssh", "ssh-simple", "slurm", "pbs", "troika"): + assert key in exc_str + + +def test_register_host_adds_to_registry(): + try: + + @register_host("test-dummy") + class DummyHost: + pass + + assert HOST_REGISTRY["test-dummy"] is DummyHost + finally: + del HOST_REGISTRY["test-dummy"] + + +def test_register_host_returns_class_unchanged(): + try: + + class DummyHost2: + pass + + result = register_host("test-dummy2")(DummyHost2) + assert result is DummyHost2 + finally: + del HOST_REGISTRY["test-dummy2"] + + +def test_register_host_duplicate_key_overwrites(): + try: + + @register_host("test-dup") + class DummyHostA: + pass + + @register_host("test-dup") + class DummyHostB: + pass + + assert HOST_REGISTRY["test-dup"] is DummyHostB + finally: + del HOST_REGISTRY["test-dup"] + + if __name__ == "__main__": from os import path diff --git a/tests/test_json.py b/tests/test_json.py index 29f53da..9318013 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -9,6 +9,7 @@ def test_json(): x = json.loads(f.read()) s = Suite("s", json=x) + print(s) s.check_definition() s.generate_node() diff --git a/tests/test_resource.py b/tests/test_resource.py new file mode 100644 index 0000000..bda8046 --- /dev/null +++ b/tests/test_resource.py @@ -0,0 +1,44 @@ +from os import path +from os.path import join + +from pyflow import FileResource, Notebook, Suite +from pyflow.host import SSHHost + + +def test_file_resource(): + + resouces_directory = "/resources_directory" + + sshhost_1 = SSHHost("example_ssh_host_1", resources_directory=resouces_directory) + sshhost_2 = SSHHost("example_ssh_host_2", resources_directory=resouces_directory) + host_set = [sshhost_1, sshhost_2] + + source_file = path.join(path.dirname(path.abspath(__file__)), "file_resource.txt") + name = "file_resource" + + with Suite("s", host=sshhost_1) as s: + s.resource_file = FileResource(name, hosts=host_set, source_file=source_file) + + # Check that variables are set correctly + assert s.resource_file.host == sshhost_1 + assert s.resource_file.location() == join( + str(sshhost_1.resources_directory), s.name, name + ) + assert s.resource_file._hosts == host_set + + # Check that the deployment scripts have been generated + s.check_definition() + s.generate_node() + + s.deploy_suite(target=Notebook) + + generate_file_resource_script_lines, _ = s.resource_file.generate_script() + assert any(sshhost_1.name in s for s in generate_file_resource_script_lines) + assert any(sshhost_2.name in s for s in generate_file_resource_script_lines) + + +if __name__ == "__main__": + + import pytest + + pytest.main(path.abspath(__file__)) diff --git a/tests/test_script.py b/tests/test_script.py index bdb4ba7..582dcb3 100644 --- a/tests/test_script.py +++ b/tests/test_script.py @@ -44,20 +44,14 @@ def test_script_lists(): t2.script += "echo 'bit4'" t2.script += ["echo 'bit5'", "echo 'bit6'"] - checkscript = os.linesep.join( - line - for line in textwrap.dedent( - """ + checkscript = os.linesep.join(line for line in textwrap.dedent(""" echo 'bit1' echo 'bit2' echo 'bit3' echo 'bit4' echo 'bit5' echo 'bit6' - """ - ).splitlines() - if line - ) + """).splitlines() if line) assert t1.script.value == checkscript assert t2.script.value == checkscript @@ -156,10 +150,7 @@ def test_python_script(): t = pyflow.Task("t", script=[s1, s2, s3, s4, s5]) - checkscript = os.linesep.join( - line - for line in textwrap.dedent( - """ + checkscript = os.linesep.join(line for line in textwrap.dedent(""" echo 'bit1' python3 -u - <