Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ docs/_build
_version.py
.vscode
.DS_Store
.weave
80 changes: 66 additions & 14 deletions pyflow/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@

SSH_COMMAND = "ssh -v -o StrictHostKeyChecking=no"

HOST_REGISTRY = {}


def register_host(registry_key):
"""
Registers a host class in the host registry.

Parameters:
registry_key(str): The key to register the host class under.
"""

def decorator(cls):
HOST_REGISTRY[registry_key] = cls
return cls

return decorator


class Host:
"""
Expand Down Expand Up @@ -453,6 +470,7 @@ def job_preamble(self, exit_hook=None):
) + self.preamble_error_function(self.ecflow_path, exit_hook).split("\n")


@register_host("null")
class NullHost(Host):
"""
A dummy host object invisible to **ecFlow**, but still throws exceptions if **pyflow** attempts to create tasks
Expand Down Expand Up @@ -517,6 +535,7 @@ def build_label(self):
return None


@register_host("localhost")
class LocalHost(Host):
"""
A host object that executes scripts directly on the **ecFlow** server.
Expand Down Expand Up @@ -628,6 +647,7 @@ def copy_file_to(self, source_file, target_file):
)


@register_host("ecflow-default")
class EcflowDefaultHost(LocalHost):
"""
By default we just use LocalHost... Slightly modified from ecflow default of
Expand All @@ -640,6 +660,7 @@ def __init__(self, **kwargs):
super().__init__("default", **kwargs)


@register_host("ssh")
class SSHHost(Host):
"""
A host object that executes scripts on the **ecFlow** server via SSH protocol.
Expand Down Expand Up @@ -815,9 +836,10 @@ def host_postamble(self):
return []


@register_host("ssh-simple")
class SimpleSSHHost(Host):
def __init__(self, host):
super().__init__(host)
def __init__(self, host, **kwargs):
super().__init__(host, **kwargs)
self.host = host

@property
Expand Down Expand Up @@ -849,6 +871,7 @@ def host_postamble(self):
return POSTAMBLE_SUBMITTED_JOBS.split("\n")


@register_host("slurm")
class SLURMHost(SSHHost):
"""
A host object that executes scripts on the **ecFlow** server via Slurm job scheduling system.
Expand Down Expand Up @@ -943,6 +966,7 @@ def host_postamble(self):
return POSTAMBLE_SUBMITTED_JOBS.split("\n")


@register_host("pbs")
class PBSHost(SSHHost):
"""
A host object that executes scripts on the **ecFlow** server via batch server.
Expand Down Expand Up @@ -1037,13 +1061,18 @@ def host_postamble(self):
return POSTAMBLE_SUBMITTED_JOBS.split("\n")


@register_host("troika")
class TroikaHost(Host):
"""
A host object that executes scripts on the **ecFlow** server via the troika job submitter.

Parameters:
name(str): The name of the host.
user(str): The user to use for troika commands to the host.
troika_exec(str): The path to the troika executable, defaults to `%TROIKA:troika%`.
troika_config(str): The path to the troika configuration file, defaults to `%TROIKA_CONFIG%`.
Value False or None will deactivate the config in the command.
troika_version(str): The version of the troika executable, defaults to `0.2.3`.
hostname(str): The hostname of the host, otherwise `name` will be used.
scratch_directory(str): The path in which tasks will be run, unless otherwise specified.
log_directory(str): The directory to use for script output. Normally `ECF_HOME`, but may need to be changed on
Expand All @@ -1068,24 +1097,26 @@ class TroikaHost(Host):
pass
"""

def __init__(self, name, user, **kwargs):
self.troika_exec = kwargs.pop("troika_exec", "troika")
self.troika_config = kwargs.pop("troika_config", "")
self.troika_version = tuple(
map(int, kwargs.pop("troika_version", "0.2.1").split("."))
)
def __init__(
self,
name,
user,
troika_exec="%TROIKA:troika%",
troika_config=None,
troika_version="0.2.3",
**kwargs,
):
self.troika_exec = troika_exec
self.troika_config = troika_config
self.troika_version = tuple(map(int, troika_version.split(".")))
super().__init__(name, user=user, **kwargs)

def troika_command(self, command):
cmd = " ".join(
[
f"%TROIKA:{self.troika_exec}%",
f"{self.troika_exec}",
"-vv",
(
f"-c %TROIKA_CONFIG:{self.troika_config}%"
if self.troika_config
else ""
),
(f"-c {self.troika_config}" if self.troika_config else ""),
f"{command}",
f"-u {self.user}",
]
Expand Down Expand Up @@ -1204,3 +1235,24 @@ def _translate_sthost(val):
args.append("#TROIKA {}={}".format(arg, val))

return args


def host_factory(key, *args, **kwargs):
"""
Factory function to create host objects based on a key.

Parameters:
key(str): The key specifying the type of host to create.
*args: Positional arguments to pass to the host constructor.
**kwargs: Keyword arguments to pass to the host
constructor.
Returns:
Host: The created host object.
"""

if (target := HOST_REGISTRY.get(key)) is not None:
return target(*args, **kwargs)
else:
raise ValueError(
f"Unknown host type: {key}. Available host types are: {list(HOST_REGISTRY.keys())}"
)
116 changes: 109 additions & 7 deletions tests/test_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

import pyflow
import pyflow.host
from pyflow.host import (
HOST_REGISTRY,
LocalHost,
NullHost,
PBSHost,
SimpleSSHHost,
SLURMHost,
SSHHost,
TroikaHost,
host_factory,
register_host,
)


def test_host_task():
Expand Down Expand Up @@ -258,10 +270,10 @@ def test_troika_host():
host1 = pyflow.TroikaHost(
name="test_host",
user="test_user",
troika_version="0.2.1",
troika_config="%TROIKA_CONFIG%",
)
host2 = pyflow.TroikaHost(
name="test_host", user="test_user", troika_version="2.2.2"
)
host2 = pyflow.TroikaHost(name="test_host", user="test_user")

submit_args = {
"total_tasks": 2,
Expand All @@ -284,11 +296,11 @@ def test_troika_host():

assert (
s.ECF_JOB_CMD.value
== "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%"
== "%TROIKA:troika% -vv -c %TROIKA_CONFIG% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%"
)
assert (
s.ECF_KILL_CMD.value
== "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%"
== "%TROIKA:troika% -vv -c %TROIKA_CONFIG% kill -u test_user test_host %ECF_JOB%"
)

t1_script = t1.generate_script()
Expand Down Expand Up @@ -385,15 +397,34 @@ def test_troika_host_options():

assert (
s.ECF_JOB_CMD.value
== "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501
== "/path/to/troika -vv -c /path/to/troika.cfg submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501
)
assert (
s.ECF_KILL_CMD.value
== "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% kill -u test_user test_host %ECF_JOB%" # noqa: E501
== "/path/to/troika -vv -c /path/to/troika.cfg kill -u test_user test_host %ECF_JOB%" # noqa: E501
)
assert s.host.troika_version == (2, 1, 3)


def test_troika_host_options_no_config():
host = pyflow.TroikaHost(
name="test_host",
user="test_user",
troika_config=None,
)

s = pyflow.Suite("s", host=host)

assert (
s.ECF_JOB_CMD.value
== "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501
)
assert (
s.ECF_KILL_CMD.value
== "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%" # noqa: E501
)


def test_traps():
sigs = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13]
with pyflow.Suite("s") as s1:
Expand All @@ -416,6 +447,77 @@ def test_traps():
assert signal_list2 in s2


@pytest.mark.parametrize(
"key,expected_class,kwargs",
[
("null", NullHost, {}),
("localhost", LocalHost, {}),
("ssh", SSHHost, {"name": "test"}),
("ssh-simple", SimpleSSHHost, {"host": "test"}),
("slurm", SLURMHost, {"name": "test"}),
("pbs", PBSHost, {"name": "test"}),
("troika", TroikaHost, {"name": "test", "user": "testuser"}),
],
)
def test_host_factory_returns_correct_types(key, expected_class, kwargs):
result = host_factory(key, **kwargs)
assert isinstance(result, expected_class)


def test_host_factory_forwards_kwargs():
result = host_factory("localhost", name="myhost", scratch_directory="/tmp/test")
assert result.name == "myhost"
assert result.scratch_directory == "/tmp/test"


def test_host_factory_raises_and_lists_available_types():
with pytest.raises(ValueError, match="Unknown host type: bogus") as exc_info:
host_factory("bogus")
exc_str = str(exc_info.value)
for key in ("null", "localhost", "ssh", "ssh-simple", "slurm", "pbs", "troika"):
assert key in exc_str


def test_register_host_adds_to_registry():
try:

@register_host("test-dummy")
class DummyHost:
pass

assert HOST_REGISTRY["test-dummy"] is DummyHost
finally:
del HOST_REGISTRY["test-dummy"]


def test_register_host_returns_class_unchanged():
try:

class DummyHost2:
pass

result = register_host("test-dummy2")(DummyHost2)
assert result is DummyHost2
finally:
del HOST_REGISTRY["test-dummy2"]


def test_register_host_duplicate_key_overwrites():
try:

@register_host("test-dup")
class DummyHostA:
pass

@register_host("test-dup")
class DummyHostB:
pass

assert HOST_REGISTRY["test-dup"] is DummyHostB
finally:
del HOST_REGISTRY["test-dup"]


if __name__ == "__main__":
from os import path

Expand Down
Loading