Skip to content

Commit

Permalink
Create WorkflowRuntimeClientBase with shared logic for ZMQ & SSH cl…
Browse files Browse the repository at this point in the history
…ients (cylc#4742)

client: move to shared base class with common logic for all clients

* Move shared WorkflowRuntimeClient stuff into ABC
* Tidy and prefer str mode of Popen
* Make WorkflowRuntImeClient timeout handling more shared in ABC
* Remove duplication in WorkflowRuntimeClient timeout handler
* Reduce duplicated code
  • Loading branch information
MetRonnie authored and datamel committed Oct 19, 2022
1 parent c1b1875 commit 657de97
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 212 deletions.
2 changes: 1 addition & 1 deletion cylc/flow/host_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def _get_metrics(hosts, metrics, data=None):
if proc.poll() is None:
continue
del proc_map[host]
out, err = (f.decode().strip() for f in proc.communicate())
out, err = proc.communicate()
if proc.wait():
# Command failed
LOG.warning(
Expand Down
3 changes: 2 additions & 1 deletion cylc/flow/job_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def write(self, local_job_file_path, job_conf, check_syntax=True):
['/usr/bin/env', 'bash', '-n', tmp_name],
stderr=PIPE,
stdin=DEVNULL,
text=True
# * the purpose of this is to evaluate user defined code
# prior to it being executed
) as proc:
if proc.wait():
# This will leave behind the temporary file,
# which is useful for debugging syntax errors, etc.
raise RuntimeError(proc.communicate()[1].decode())
raise RuntimeError(proc.communicate()[1])
except OSError as exc:
# Popen has a bad habit of not telling you anything if it fails
# to run the executable.
Expand Down
8 changes: 4 additions & 4 deletions cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,15 @@ def get_location(workflow: str) -> Tuple[str, int, int]:
Args:
workflow: workflow name
Returns:
tuple with the host name and port numbers.
Tuple (host name, port number, publish port number)
Raises:
WorkflowStopped: if the workflow is not running.
CylcVersionError: if target is a Cylc 7 (or earlier) workflow.
"""
try:
contact = load_contact_file(workflow)
except ServiceFileError:
except (IOError, ValueError, ServiceFileError):
# Contact file does not exist or corrupted, workflow should be dead
raise WorkflowStopped(workflow)

host = contact[ContactFileFields.HOST]
Expand All @@ -84,8 +85,7 @@ def get_location(workflow: str) -> Tuple[str, int, int]:
if ContactFileFields.PUBLISH_PORT in contact:
pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
else:
version = (
contact['CYLC_VERSION'] if 'CYLC_VERSION' in contact else None)
version = contact.get('CYLC_VERSION', None)
raise CylcVersionError(version=version)
return host, port, pub_port

Expand Down
228 changes: 123 additions & 105 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Client for workflow runtime API."""

from functools import partial
from abc import ABCMeta, abstractmethod
import asyncio
import os
from shutil import which
import socket
import sys
from typing import TYPE_CHECKING, Any, Optional, Union, Dict
from typing import Any, Optional, Union, Dict

import zmq
import zmq.asyncio
Expand All @@ -33,6 +34,7 @@
ServiceFileError,
WorkflowStopped,
)
from cylc.flow.hostuserutil import get_fqdn_by_host
from cylc.flow.network import (
encode_,
decode_,
Expand All @@ -42,16 +44,122 @@
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.workflow_files import (
ContactFileFields,
detect_old_contact_file,
load_contact_file
)

if TYPE_CHECKING:
import asyncio

class WorkflowRuntimeClientBase(metaclass=ABCMeta):
"""Base class for WorkflowRuntimeClients.
class WorkflowRuntimeClient(ZMQSocketBase):
WorkflowRuntimeClients that inherit from this must implement an async
method ``async_request()``. This base class provides a ``serial_request()``
method based on the ``async_request()`` method, callable by ``__call__``.
It also provides a comms timeout handler method.
"""

DEFAULT_TIMEOUT = 5 # seconds

def __init__(
self,
workflow: str,
host: Optional[str] = None,
port: Union[int, str, None] = None,
timeout: Union[float, str, None] = None
):
self.workflow = workflow
if not host or not port:
host, port, _ = get_location(workflow)
else:
port = int(port)
self.host = self._orig_host = host
self.port = self._orig_port = port
self.timeout = (
float(timeout) if timeout is not None else self.DEFAULT_TIMEOUT
)

@abstractmethod
async def async_request(
self,
command: str,
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
"""Send an asynchronous request."""
...

def serial_request(
self,
command: str,
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
"""Send a request.
For convenience use ``__call__`` to call this method.
Args:
command: The name of the endpoint to call.
args: Arguments to pass to the endpoint function.
timeout: Override the default timeout (seconds).
Raises:
ClientTimeout: If a response takes longer than timeout to arrive.
ClientError: Coverall for all other issues including failed auth.
Returns:
object: The data exactly as returned from the endpoint function,
nothing more, nothing less.
"""
loop = getattr(self, 'loop', asyncio.new_event_loop())
task = loop.create_task(
self.async_request(command, args, timeout, req_meta)
)
loop.run_until_complete(task)
if not hasattr(self, 'loop'):
# (If inheriting class does have an event loop, don't mess with it)
loop.close()
return task.result()

__call__ = serial_request

def timeout_handler(self) -> None:
"""Handle the eventuality of a communication timeout with the workflow.
Raises:
WorkflowStopped: if the workflow has already stopped.
CyclError: if the workflow has moved to different host/port.
"""
contact_host, contact_port, _ = get_location(self.workflow)
if (
contact_host != get_fqdn_by_host(self._orig_host)
or contact_port != self._orig_port
):
raise CylcError(
'The workflow is no longer running at '
f'{self._orig_host}:{self._orig_port}\n'
f'It has moved to {contact_host}:{contact_port}'
)

# Cannot connect, perhaps workflow is no longer running and is leaving
# behind a contact file?
try:
detect_old_contact_file(self.workflow)
except (AssertionError, ServiceFileError):
# old contact file exists and the workflow process still alive
return
else:
# the workflow has stopped
raise WorkflowStopped(self.workflow)


class WorkflowRuntimeClient( # type: ignore[misc]
ZMQSocketBase, WorkflowRuntimeClientBase
):
# (Ignoring mypy 'definition of "host" in base class "ZMQSocketBase" is
# incompatible with definition in base class "WorkflowRuntimeClientBase"')
"""Initiate a client to the scheduler API.
Initiates the REQ part of a ZMQ REQ-REP pair.
Expand Down Expand Up @@ -118,33 +226,21 @@ class WorkflowRuntimeClient(ZMQSocketBase):
"""
# socket & event loop not None - get assigned on init by self.start():
socket: zmq.asyncio.Socket
loop: 'asyncio.AbstractEventLoop'

DEFAULT_TIMEOUT = 5. # 5 seconds
loop: asyncio.AbstractEventLoop

def __init__(
self,
workflow: str,
host: Optional[str] = None,
port: Optional[int] = None,
context: Optional[zmq.asyncio.Context] = None,
port: Union[int, str, None] = None,
timeout: Union[float, str, None] = None,
context: Optional[zmq.asyncio.Context] = None,
srv_public_key_loc: Optional[str] = None
):
super().__init__(zmq.REQ, workflow, context=context)
if not host or not port:
host, port, _ = get_location(workflow)
else:
port = int(port)
self.host = host
self.port = port
if timeout is None:
timeout = self.DEFAULT_TIMEOUT
else:
timeout = float(timeout)
self.timeout = timeout * 1000
self.timeout_handler = partial(
self._timeout_handler, workflow, host, port)
ZMQSocketBase.__init__(self, zmq.REQ, workflow, context=context)
WorkflowRuntimeClientBase.__init__(self, workflow, host, port, timeout)
# convert to milliseconds:
self.timeout *= 1000
self.poller: Any = None
# Connect the ZMQ socket on instantiation
self.start(self.host, self.port, srv_public_key_loc)
Expand Down Expand Up @@ -199,11 +295,7 @@ async def async_request(
if self.poller.poll(timeout):
res = await self.socket.recv()
else:
if callable(self.timeout_handler):
self.timeout_handler()
host, port, _ = get_location(self.workflow)
if host != self.host or port != self.port:
raise WorkflowStopped(self.workflow)
self.timeout_handler()
raise ClientTimeout(
'Timeout waiting for server response.'
' This could be due to network or server issues.'
Expand All @@ -228,36 +320,6 @@ async def async_request(
error.get('traceback'),
)

def serial_request(
self,
command: str,
args: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
req_meta: Optional[Dict[str, Any]] = None
) -> object:
"""Send a request.
For convenience use ``__call__`` to call this method.
Args:
command: The name of the endpoint to call.
args: Arguments to pass to the endpoint function.
timeout: Override the default timeout (seconds).
Raises:
ClientTimeout: If a response takes longer than timeout to arrive.
ClientError: Coverall for all other issues including failed auth.
Returns:
object: The data exactly as returned from the endpoint function,
nothing more, nothing less.
"""
task = self.loop.create_task(
self.async_request(command, args, timeout, req_meta))
self.loop.run_until_complete(task)
return task.result()

def get_header(self) -> dict:
"""Return "header" data to attach to each request for traceability.
Expand Down Expand Up @@ -292,47 +354,3 @@ def get_header(self) -> dict:
)
}
}

@staticmethod
def _timeout_handler(workflow: str, host: str, port: Union[int, str]):
"""Handle the eventuality of a communication timeout with the workflow.
Args:
workflow (str): workflow name
host (str): host name
port (Union[int, str]): port number
Raises:
ClientError: if the workflow has already stopped.
"""
if workflow is None:
return

try:
contact_data: Dict[str, str] = load_contact_file(workflow)
except (IOError, ValueError, ServiceFileError):
# Contact file does not exist or corrupted, workflow should be dead
return

contact_host: str = contact_data.get(ContactFileFields.HOST, '?')
contact_port: str = contact_data.get(ContactFileFields.PORT, '?')
if (
contact_host != host
or contact_port != str(port)
):
raise CylcError(
f'The workflow is no longer running at {host}:{port}\n'
f'It has moved to {contact_host}:{contact_port}'
)

# Cannot connect, perhaps workflow is no longer running and is leaving
# behind a contact file?
try:
detect_old_contact_file(workflow, contact_data)
except (AssertionError, ServiceFileError):
# old contact file exists and the workflow process still alive
return
else:
# the workflow has stopped
raise WorkflowStopped(workflow)

__call__ = serial_request
12 changes: 10 additions & 2 deletions cylc/flow/network/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

from enum import Enum
import os
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from cylc.flow.network.client import WorkflowRuntimeClientBase


class CommsMeth(Enum):
Expand All @@ -33,7 +37,11 @@ def get_comms_method() -> CommsMeth:
)


def get_runtime_client(comms_method: CommsMeth, workflow, timeout=None):
def get_runtime_client(
comms_method: CommsMeth,
workflow: str,
timeout: Union[float, str, None] = None
) -> 'WorkflowRuntimeClientBase':
"""Return client for the provided communication method.
Args:
Expand All @@ -43,7 +51,7 @@ def get_runtime_client(comms_method: CommsMeth, workflow, timeout=None):
if comms_method == CommsMeth.SSH:
from cylc.flow.network.ssh_client import WorkflowRuntimeClient
else:
from cylc.flow.network.client import ( # type: ignore
from cylc.flow.network.client import ( # type: ignore[no-redef]
WorkflowRuntimeClient
)
return WorkflowRuntimeClient(workflow, timeout=timeout)
Expand Down
Loading

0 comments on commit 657de97

Please sign in to comment.