diff --git a/.gitignore b/.gitignore index 82bab73..1dc13ca 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ docs/_build _version.py .vscode .DS_Store +.weave \ No newline at end of file diff --git a/pyflow/host.py b/pyflow/host.py index d8ec54c..9f407a5 100644 --- a/pyflow/host.py +++ b/pyflow/host.py @@ -83,6 +83,23 @@ SSH_COMMAND = "ssh -v -o StrictHostKeyChecking=no" +HOST_REGISTRY = {} + + +def register_host(registry_key): + """ + Registers a host class in the host registry. + + Parameters: + registry_key(str): The key to register the host class under. + """ + + def decorator(cls): + HOST_REGISTRY[registry_key] = cls + return cls + + return decorator + class Host: """ @@ -453,6 +470,7 @@ def job_preamble(self, exit_hook=None): ) + self.preamble_error_function(self.ecflow_path, exit_hook).split("\n") +@register_host("null") class NullHost(Host): """ A dummy host object invisible to **ecFlow**, but still throws exceptions if **pyflow** attempts to create tasks @@ -517,6 +535,7 @@ def build_label(self): return None +@register_host("localhost") class LocalHost(Host): """ A host object that executes scripts directly on the **ecFlow** server. @@ -628,6 +647,7 @@ def copy_file_to(self, source_file, target_file): ) +@register_host("ecflow-default") class EcflowDefaultHost(LocalHost): """ By default we just use LocalHost... Slightly modified from ecflow default of @@ -640,6 +660,7 @@ def __init__(self, **kwargs): super().__init__("default", **kwargs) +@register_host("ssh") class SSHHost(Host): """ A host object that executes scripts on the **ecFlow** server via SSH protocol. 
@@ -815,9 +836,10 @@ def host_postamble(self): return [] +@register_host("ssh-simple") class SimpleSSHHost(Host): - def __init__(self, host): - super().__init__(host) + def __init__(self, host, **kwargs): + super().__init__(host, **kwargs) self.host = host @property @@ -849,6 +871,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("slurm") class SLURMHost(SSHHost): """ A host object that executes scripts on the **ecFlow** server via Slurm job scheduling system. @@ -943,6 +966,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("pbs") class PBSHost(SSHHost): """ A host object that executes scripts on the **ecFlow** server via batch server. @@ -1037,6 +1061,7 @@ def host_postamble(self): return POSTAMBLE_SUBMITTED_JOBS.split("\n") +@register_host("troika") class TroikaHost(Host): """ A host object that executes scripts on the **ecFlow** server via the troika job submitter. @@ -1044,6 +1069,10 @@ class TroikaHost(Host): Parameters: name(str): The name of the host. user(str): The user to use for troika commands to the host. + troika_exec(str): The path to the troika executable, defaults to `%TROIKA:troika%`. + troika_config(str): The path to the troika configuration file, defaults to `None`. + A value of False or None omits the `-c` option from the command. + troika_version(str): The version of the troika executable, defaults to `0.2.3`. hostname(str): The hostname of the host, otherwise `name` will be used. scratch_directory(str): The path in which tasks will be run, unless otherwise specified. log_directory(str): The directory to use for script output. 
Normally `ECF_HOME`, but may need to be changed on @@ -1068,24 +1097,26 @@ class TroikaHost(Host): pass """ - def __init__(self, name, user, **kwargs): - self.troika_exec = kwargs.pop("troika_exec", "troika") - self.troika_config = kwargs.pop("troika_config", "") - self.troika_version = tuple( - map(int, kwargs.pop("troika_version", "0.2.1").split(".")) - ) + def __init__( + self, + name, + user, + troika_exec="%TROIKA:troika%", + troika_config=None, + troika_version="0.2.3", + **kwargs, + ): + self.troika_exec = troika_exec + self.troika_config = troika_config + self.troika_version = tuple(map(int, troika_version.split("."))) super().__init__(name, user=user, **kwargs) def troika_command(self, command): cmd = " ".join( [ - f"%TROIKA:{self.troika_exec}%", + f"{self.troika_exec}", "-vv", - ( - f"-c %TROIKA_CONFIG:{self.troika_config}%" - if self.troika_config - else "" - ), + (f"-c {self.troika_config}" if self.troika_config else ""), f"{command}", f"-u {self.user}", ] @@ -1204,3 +1235,24 @@ def _translate_sthost(val): args.append("#TROIKA {}={}".format(arg, val)) return args + + +def host_factory(key, *args, **kwargs): + """ + Factory function to create host objects based on a key. + + Parameters: + key(str): The key specifying the type of host to create. + *args: Positional arguments to pass to the host constructor. + **kwargs: Keyword arguments to pass to the host + constructor. + Returns: + Host: The created host object. + """ + + if (target := HOST_REGISTRY.get(key)) is not None: + return target(*args, **kwargs) + else: + raise ValueError( + f"Unknown host type: {key}. 
Available host types are: {list(HOST_REGISTRY.keys())}" + ) diff --git a/tests/test_host.py b/tests/test_host.py index fc18260..9978600 100644 --- a/tests/test_host.py +++ b/tests/test_host.py @@ -2,6 +2,18 @@ import pyflow import pyflow.host +from pyflow.host import ( + HOST_REGISTRY, + LocalHost, + NullHost, + PBSHost, + SimpleSSHHost, + SLURMHost, + SSHHost, + TroikaHost, + host_factory, + register_host, +) def test_host_task(): @@ -258,10 +270,10 @@ def test_troika_host(): host1 = pyflow.TroikaHost( name="test_host", user="test_user", + troika_version="0.2.1", + troika_config="%TROIKA_CONFIG%", ) - host2 = pyflow.TroikaHost( - name="test_host", user="test_user", troika_version="2.2.2" - ) + host2 = pyflow.TroikaHost(name="test_host", user="test_user") submit_args = { "total_tasks": 2, @@ -284,11 +296,11 @@ def test_troika_host(): assert ( s.ECF_JOB_CMD.value - == "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" + == "%TROIKA:troika% -vv -c %TROIKA_CONFIG% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" ) assert ( s.ECF_KILL_CMD.value - == "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%" + == "%TROIKA:troika% -vv -c %TROIKA_CONFIG% kill -u test_user test_host %ECF_JOB%" ) t1_script = t1.generate_script() @@ -385,15 +397,34 @@ def test_troika_host_options(): assert ( s.ECF_JOB_CMD.value - == "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 + == "/path/to/troika -vv -c /path/to/troika.cfg submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 ) assert ( s.ECF_KILL_CMD.value - == "%TROIKA:/path/to/troika% -vv -c %TROIKA_CONFIG:/path/to/troika.cfg% kill -u test_user test_host %ECF_JOB%" # noqa: E501 + == "/path/to/troika -vv -c /path/to/troika.cfg kill -u test_user test_host %ECF_JOB%" # noqa: E501 ) assert s.host.troika_version == (2, 1, 3) +def test_troika_host_options_no_config(): + host = pyflow.TroikaHost( + 
name="test_host", + user="test_user", + troika_config=None, + ) + + s = pyflow.Suite("s", host=host) + + assert ( + s.ECF_JOB_CMD.value + == "%TROIKA:troika% -vv submit -u test_user -o %ECF_JOBOUT% test_host %ECF_JOB%" # noqa: E501 + ) + assert ( + s.ECF_KILL_CMD.value + == "%TROIKA:troika% -vv kill -u test_user test_host %ECF_JOB%" # noqa: E501 + ) + + def test_traps(): sigs = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13] with pyflow.Suite("s") as s1: @@ -416,6 +447,77 @@ def test_traps(): assert signal_list2 in s2 +@pytest.mark.parametrize( + "key,expected_class,kwargs", + [ + ("null", NullHost, {}), + ("localhost", LocalHost, {}), + ("ssh", SSHHost, {"name": "test"}), + ("ssh-simple", SimpleSSHHost, {"host": "test"}), + ("slurm", SLURMHost, {"name": "test"}), + ("pbs", PBSHost, {"name": "test"}), + ("troika", TroikaHost, {"name": "test", "user": "testuser"}), + ], +) +def test_host_factory_returns_correct_types(key, expected_class, kwargs): + result = host_factory(key, **kwargs) + assert isinstance(result, expected_class) + + +def test_host_factory_forwards_kwargs(): + result = host_factory("localhost", name="myhost", scratch_directory="/tmp/test") + assert result.name == "myhost" + assert result.scratch_directory == "/tmp/test" + + +def test_host_factory_raises_and_lists_available_types(): + with pytest.raises(ValueError, match="Unknown host type: bogus") as exc_info: + host_factory("bogus") + exc_str = str(exc_info.value) + for key in ("null", "localhost", "ssh", "ssh-simple", "slurm", "pbs", "troika"): + assert key in exc_str + + +def test_register_host_adds_to_registry(): + try: + + @register_host("test-dummy") + class DummyHost: + pass + + assert HOST_REGISTRY["test-dummy"] is DummyHost + finally: + del HOST_REGISTRY["test-dummy"] + + +def test_register_host_returns_class_unchanged(): + try: + + class DummyHost2: + pass + + result = register_host("test-dummy2")(DummyHost2) + assert result is DummyHost2 + finally: + del HOST_REGISTRY["test-dummy2"] + + +def 
test_register_host_duplicate_key_overwrites(): + try: + + @register_host("test-dup") + class DummyHostA: + pass + + @register_host("test-dup") + class DummyHostB: + pass + + assert HOST_REGISTRY["test-dup"] is DummyHostB + finally: + del HOST_REGISTRY["test-dup"] + + if __name__ == "__main__": from os import path