Skip to content

Commit

Permalink
Merge pull request #7152 from PrefectHQ/flow-restarts
Browse files Browse the repository at this point in the history
Manual flow retries
  • Loading branch information
anticorrelator committed Oct 24, 2022
2 parents 0240284 + 83d98cc commit 32264d2
Show file tree
Hide file tree
Showing 16 changed files with 805 additions and 63 deletions.
37 changes: 37 additions & 0 deletions flows/flow_retries_with_subflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from prefect import flow

child_flow_run_count = 0
flow_run_count = 0


@flow
def child_flow():
global child_flow_run_count
child_flow_run_count += 1

# Fail on the first flow run but not the retry
if flow_run_count < 2:
raise ValueError()

return "hello"


@flow(retries=10)
def parent_flow():
global flow_run_count
flow_run_count += 1

result = child_flow()

# It is important that the flow run fails after the child flow run is created
if flow_run_count < 3:
raise ValueError()

return result


if __name__ == "__main__":
result = parent_flow()
assert result == "hello", f"Got {result}"
assert flow_run_count == 3, f"Got {flow_run_count}"
assert child_flow_run_count == 2, f"Got {child_flow_run_count}"
10 changes: 6 additions & 4 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ async def create_and_begin_subflow_run(
parent_logger.debug(f"Resolving inputs to {flow.name!r}")
task_inputs = {k: await collect_task_run_inputs(v) for k, v in parameters.items()}

rerunning = parent_flow_run_context.flow_run.run_count > 1

# Generate a task in the parent flow run to represent the result of the subflow run
dummy_task = Task(name=flow.name, fn=flow.fn, version=flow.version)
parent_task_run = await client.create_task_run(
Expand All @@ -413,8 +415,9 @@ async def create_and_begin_subflow_run(
# Resolve any task futures in the input
parameters = await resolve_inputs(parameters)

if parent_task_run.state.is_final():

if parent_task_run.state.is_final() and not (
rerunning and not parent_task_run.state.is_completed()
):
# Retrieve the most recent flow run from the database
flow_runs = await client.read_flow_runs(
flow_run_filter=FlowRunFilter(
Expand All @@ -433,7 +436,7 @@ async def create_and_begin_subflow_run(
flow,
parameters=flow.serialize_parameters(parameters),
parent_task_run_id=parent_task_run.id,
state=parent_task_run.state,
state=parent_task_run.state if not rerunning else Pending(),
tags=TagsContext.get().current_tags,
)

Expand Down Expand Up @@ -469,7 +472,6 @@ async def create_and_begin_subflow_run(
report_flow_run_crashes(flow_run=flow_run, client=client)
)
task_runner = await stack.enter_async_context(flow.task_runner.start())

terminal_state = await orchestrate_flow_run(
flow,
flow_run=flow_run,
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/orion/api/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ async def set_flow_run_state(
flow_policy: BaseOrchestrationPolicy = Depends(
orchestration_dependencies.provide_flow_policy
),
api_version=Depends(dependencies.provide_request_api_version),
) -> OrchestrationResult:
"""Set a flow run state, invoking any orchestration rules."""

Expand All @@ -249,6 +250,7 @@ async def set_flow_run_state(
state=schemas.states.State.parse_obj(state),
force=force,
flow_policy=flow_policy,
api_version=api_version,
)

# set the 201 because a new state was created
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/orion/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
API_TITLE = "Prefect Orion API"
UI_TITLE = "Prefect Orion UI"
API_VERSION = prefect.__version__
ORION_API_VERSION = "0.8.2"
ORION_API_VERSION = "0.8.3"

logger = get_logger("orion")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Add retry and restart metadata
Revision ID: 8ea825da948d
Revises: ad4b1b4d1e9d
Create Date: 2022-10-19 16:51:10.239643
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "8ea825da948d"
down_revision = "3ced59d8806b"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"task_run",
sa.Column(
"flow_run_run_count", sa.Integer(), server_default="0", nullable=False
),
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("task_run", "flow_run_run_count")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Add retry and restart metadata
Revision ID: af52717cf201
Revises: ad4b1b4d1e9d
Create Date: 2022-10-19 15:58:10.016251
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "af52717cf201"
down_revision = "3ced59d8806b"
branch_labels = None
depends_on = None


def upgrade():
with op.batch_alter_table("task_run", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"flow_run_run_count", sa.Integer(), server_default="0", nullable=False
)
)

# ### end Alembic commands ###


def downgrade():
with op.batch_alter_table("task_run", schema=None) as batch_op:
batch_op.drop_column("flow_run_run_count")

# ### end Alembic commands ###
3 changes: 3 additions & 0 deletions src/prefect/orion/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def flow_run_id(cls):
cache_key = sa.Column(sa.String)
cache_expiration = sa.Column(Timestamp())
task_version = sa.Column(sa.String)
flow_run_run_count = sa.Column(
sa.Integer, server_default="0", default=0, nullable=False
)
empirical_policy = sa.Column(
Pydantic(schemas.core.TaskRunPolicy),
server_default="{}",
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/orion/models/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pendulum
import sqlalchemy as sa
from packaging.version import Version
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only
Expand Down Expand Up @@ -377,6 +378,7 @@ async def set_flow_run_state(
state: schemas.states.State,
force: bool = False,
flow_policy: BaseOrchestrationPolicy = None,
api_version: Version = None,
) -> OrchestrationResult:
"""
Creates a new orchestrated flow run state.
Expand Down Expand Up @@ -426,6 +428,9 @@ async def set_flow_run_state(
proposed_state=state,
)

# pass the request version to the orchestration engine to support compatibility code
context.parameters["api-version"] = api_version

# apply orchestration rules and create the new flow run state
async with contextlib.AsyncExitStack() as stack:
for rule in orchestration_rules:
Expand Down
Loading

0 comments on commit 32264d2

Please sign in to comment.