Skip to content

Commit

Permalink
Tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
MetRonnie committed May 18, 2022
1 parent c392180 commit 97828a5
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 71 deletions.
17 changes: 5 additions & 12 deletions cylc/flow/id_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
TYPE_CHECKING,
Expand Down Expand Up @@ -112,7 +113,7 @@ def filter_ids(
_not_matched: 'List[str]' = []

# enable / disable pattern matching
match: 'Callable[[Any, Any], bool]'
match: Callable[[Any, Any], bool]
if pattern_match:
match = fnmatchcase
else:
Expand All @@ -131,7 +132,7 @@ def filter_ids(
]
_not_matched.extend(pattern_ids)

id_tokens_map = {}
id_tokens_map: Dict[str, Tokens] = {}
for id_ in ids:
try:
id_tokens_map[id_] = Tokens(id_, relative=True)
Expand Down Expand Up @@ -190,15 +191,7 @@ def filter_ids(
or match(itask.state.status, cycle_sel)
)
# check namespace name
and (
# task name
match(itask.tdef.name, task)
# family name
or any(
match(ns, task)
for ns in itask.tdef.namespace_hierarchy
)
)
and itask.name_match(task, match_func=match)
# check task selector
and (
(
Expand All @@ -223,7 +216,7 @@ def filter_ids(
_cycles.extend(cycles)
_tasks.extend(tasks)

ret: 'List[Any]' = []
ret: List[Any] = []
if out == IDTokens.Cycle:
_cycles.extend({
itask.point
Expand Down
33 changes: 12 additions & 21 deletions cylc/flow/task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@
"""Provide a class to represent a task proxy in a running workflow."""

from collections import Counter
from contextlib import suppress
from copy import copy
from fnmatch import fnmatchcase
from time import time
from typing import Any, Dict, List, Set, Tuple, Optional, TYPE_CHECKING
from typing import (
Any, Callable, Dict, List, Set, Tuple, Optional, TYPE_CHECKING
)

from metomi.isodatetime.timezone import get_local_time_zone

from cylc.flow import LOG
from cylc.flow.cycling.loader import standardise_point_string
from cylc.flow.exceptions import PointParsingError
from cylc.flow.id import Tokens
from cylc.flow.platforms import get_platform
from cylc.flow.task_action_timer import TimerFlags
Expand Down Expand Up @@ -420,30 +419,22 @@ def reset_try_timers(self):
for timer in self.try_timers.values():
timer.timeout = None

def point_match(self, point: Optional[str]) -> bool:
"""Return whether a string/glob matches the task's point.
None is treated as '*'.
"""
if point is None:
return True
with suppress(PointParsingError): # point_str may be a glob
point = standardise_point_string(point)
return fnmatchcase(str(self.point), point)

def status_match(self, status: Optional[str]) -> bool:
"""Return whether a string matches the task's status.
None/an empty string is treated as a match.
"""
return (not status) or self.state.status == status

def name_match(self, name: str) -> bool:
"""Return whether a string/glob matches the task's name."""
if fnmatchcase(self.tdef.name, name):
return True
return any(
fnmatchcase(ns, name) for ns in self.tdef.namespace_hierarchy
def name_match(
self,
value: str,
match_func: Callable[[Any, Any], bool] = fnmatchcase
) -> bool:
"""Return whether a string/pattern matches the task's name or any of
its parent family names."""
return match_func(self.tdef.name, value) or any(
match_func(ns, value) for ns in self.tdef.namespace_hierarchy
)

def merge_flows(self, flow_nums: Set) -> None:
Expand Down
32 changes: 18 additions & 14 deletions tests/unit/test_id_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable
from unittest.mock import create_autospec

import pytest

Expand All @@ -24,25 +24,29 @@
from cylc.flow.task_pool import Pool
from cylc.flow.cycling.integer import IntegerPoint, CYCLER_TYPE_INTEGER
from cylc.flow.cycling.iso8601 import ISO8601Point
from cylc.flow.task_proxy import TaskProxy
from cylc.flow.taskdef import TaskDef

if TYPE_CHECKING:
from cylc.flow.cycling import PointBase


def get_task_id(itask: TaskProxy) -> str:
return f"{itask.tokens.relative_id}:{itask.state.status}"


@pytest.fixture
def task_pool(set_cycling_type: Callable):
def _task_proxy(id_, hier):
tokens = Tokens(id_, relative=True)
hier = hier.get(tokens['task'], [])
hier.append('root')
return SimpleNamespace(
id_=id_,
point=IntegerPoint(tokens['cycle']),
state=SimpleNamespace(status=tokens['task_sel']),
tdef=SimpleNamespace(
name=tokens['task'],
namespace_hierarchy=hier
),
tdef = create_autospec(TaskDef, namespace_hierarchy=hier)
tdef.name = tokens['task']
return TaskProxy(
tdef,
start_point=IntegerPoint(tokens['cycle']),
status=tokens['task_sel'],
)

def _task_pool(pool, hier) -> 'Pool':
Expand Down Expand Up @@ -123,7 +127,7 @@ def test_filter_ids_task_mode(task_pool, ids, matched, not_matched):
)

_matched, _not_matched = filter_ids([pool], ids)
assert [itask.id_ for itask in _matched] == matched
assert [get_task_id(itask) for itask in _matched] == matched
assert _not_matched == not_matched


Expand Down Expand Up @@ -216,7 +220,7 @@ def test_filter_ids_pattern_match_off(task_pool):
out=IDTokens.Task,
pattern_match=False,
)
assert [itask.id_ for itask in _matched] == ['1/a:x']
assert [get_task_id(itask) for itask in _matched] == ['1/a:x']
assert _not_matched == []


Expand All @@ -238,7 +242,7 @@ def test_filter_ids_toggle_pattern_matching(task_pool, caplog):
out=IDTokens.Task,
pattern_match=True,
)
assert [itask.id_ for itask in _matched] == ['1/a:x']
assert [get_task_id(itask) for itask in _matched] == ['1/a:x']
assert _not_matched == []

# ensure pattern matching can be disabled
Expand All @@ -249,7 +253,7 @@ def test_filter_ids_toggle_pattern_matching(task_pool, caplog):
out=IDTokens.Task,
pattern_match=False,
)
assert [itask.id_ for itask in _matched] == []
assert [get_task_id(itask) for itask in _matched] == []
assert _not_matched == ['*/*']

# ensure the ID is logged
Expand Down Expand Up @@ -285,7 +289,7 @@ def test_filter_ids_namespace_hierarchy(task_pool, ids, matched, not_matched):
pattern_match=False,
)

assert [itask.id_ for itask in _matched] == matched
assert [get_task_id(itask) for itask in _matched] == matched
assert _not_matched == not_matched


Expand Down
24 changes: 0 additions & 24 deletions tests/unit/test_task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,10 @@
from unittest.mock import Mock

from cylc.flow.cycling import PointBase
from cylc.flow.cycling.integer import IntegerPoint
from cylc.flow.cycling.iso8601 import ISO8601Point
from cylc.flow.task_proxy import TaskProxy


@pytest.mark.parametrize(
'itask_point, point_str, expected',
[param(IntegerPoint(5), '5', True, id="Integer, basic"),
param(IntegerPoint(5), '*', True, id="Integer, glob"),
param(IntegerPoint(5), None, True, id="None same as glob(*)"),
param(ISO8601Point('2012'), '2012-01-01', True, id="ISO, basic"),
param(ISO8601Point('2012'), '2012*', True, id="ISO, glob"),
param(ISO8601Point('2012'), '2012-*', False,
id="ISO, bad glob (must use short datetime format)")]
)
def test_point_match(
itask_point: PointBase, point_str: Optional[str], expected: bool,
set_cycling_type: Callable
) -> None:
"""Test TaskProxy.point_match()."""
set_cycling_type(itask_point.TYPE)
mock_itask = Mock(point=itask_point.standardise())

assert TaskProxy.point_match(mock_itask, point_str) is expected, (
f"Does mock_task.point={mock_itask.point!r} match {point_str!r}?"
)


@pytest.mark.parametrize(
'itask_point, offset_str, expected',
[
Expand Down

0 comments on commit 97828a5

Please sign in to comment.