Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guard calling after_transition when exception encountered #7156

Merged
merged 5 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Guard calling after_transition when exception encountered
  • Loading branch information
zangell44 committed Oct 13, 2022
commit 52a08f6dc763ceba734830f03e6865505908d70a
14 changes: 12 additions & 2 deletions src/prefect/orion/orchestration/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,12 +884,12 @@ async def __aexit__(
"""
Exit the async runtime context governed by this transform.

If the transition has been nullified upon exiting this transforms's context,
If the transition has been nullified or errorred upon exiting this transforms's context,
nothing happens. Otherwise, `self.after_transition` will fire on every non-null
proposed state.
"""

if not self.nullified_transition():
if not self.nullified_transition() and not self.errorred_transition():
await self.after_transition(self.context)
self.context.finalization_signature.append(str(self.__class__))

Expand Down Expand Up @@ -927,3 +927,13 @@ def nullified_transition(self) -> bool:
"""

return self.context.proposed_state is None

def errorred_transition(self) -> bool:
"""
Determines if the transition has encountered an exception.

Returns:
True if the transition is encountered an exception, False otherwise.
"""

return self.context.orchestration_error is not None
51 changes: 51 additions & 0 deletions tests/orion/orchestration/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,57 @@ async def after_transition(self, context):
assert before_hook.call_count == 0
assert after_hook.call_count == 0

@pytest.mark.parametrize(
"intended_transition",
list(product([*states.StateType, None], [*states.StateType])),
ids=transition_names,
)
async def test_universal_transforms_never_fire_after_transition_on_errored_transitions(
self, session, task_run, intended_transition
):
# nullified transitions occur when the proposed state becomes None
# and nothing is written to the database

side_effect = 0
before_hook = MagicMock()
after_hook = MagicMock()

class IllustrativeUniversalTransform(BaseUniversalTransform):
async def before_transition(self, context):
nonlocal side_effect
side_effect += 1
before_hook()

async def after_transition(self, context):
nonlocal side_effect
side_effect += 1
after_hook()

initial_state_type, proposed_state_type = intended_transition
initial_state = await commit_task_run_state(
session, task_run, initial_state_type
)
proposed_state = (
states.State(type=proposed_state_type) if proposed_state_type else None
)

ctx = OrchestrationContext(
session=session,
initial_state=initial_state,
proposed_state=proposed_state,
)

universal_transform = IllustrativeUniversalTransform(ctx)

async with universal_transform as ctx:
ctx.orchestration_error = Exception

assert side_effect == 1
assert before_hook.call_count == 1
assert (
after_hook.call_count == 0
), "after_transition should not be called if orchestration encountered errors."


@pytest.mark.parametrize("run_type", ["task", "flow"])
class TestOrchestrationContext:
Expand Down