Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update server-side State.data to accept arbitrary JSON #6896

Merged
merged 9 commits into from
Sep 21, 2022
4 changes: 2 additions & 2 deletions src/prefect/orion/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def flow_run_id(cls):
default=schemas.states.StateDetails,
nullable=False,
)
data = sa.Column(Pydantic(schemas.data.DataDocument), nullable=True)
data = sa.Column(sa.JSON, nullable=True)

@declared_attr
def flow_run(cls):
Expand Down Expand Up @@ -186,7 +186,7 @@ def task_run_id(cls):
default=schemas.states.StateDetails,
nullable=False,
)
data = sa.Column(Pydantic(schemas.data.DataDocument), nullable=True)
data = sa.Column(sa.JSON, nullable=True)

@declared_attr
def task_run(cls):
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/orion/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class StateCreate(ActionBaseModel):
type: schemas.states.StateType = FieldFrom(schemas.states.State)
name: Optional[str] = FieldFrom(schemas.states.State)
message: Optional[str] = FieldFrom(schemas.states.State)
data: Optional[schemas.data.DataDocument] = FieldFrom(schemas.states.State)
data: Optional[Any] = FieldFrom(schemas.states.State)
state_details: schemas.states.StateDetails = FieldFrom(schemas.states.State)

# DEPRECATED
Expand Down
9 changes: 6 additions & 3 deletions src/prefect/orion/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import datetime
import warnings
from typing import Generic, Optional, Type, TypeVar
from typing import Any, Generic, Optional, Type, TypeVar
from uuid import UUID

import pendulum
from pydantic import Field, root_validator, validator

from prefect.orion.schemas.data import DataDocument
from prefect.orion.utilities.schemas import DateTimeTZ, IDBaseModel, PrefectBaseModel
from prefect.utilities.collections import AutoEnum

Expand Down Expand Up @@ -58,8 +57,12 @@ class Config:
name: Optional[str] = Field(default=None)
timestamp: DateTimeTZ = Field(default_factory=lambda: pendulum.now("UTC"))
message: Optional[str] = Field(default=None, example="Run started")
data: Optional[DataDocument[R]] = Field(
data: Optional[Any] = Field(
default=None,
description=(
"Data associated with the state, e.g. a result. "
"Content must be storable as JSON."
),
)
state_details: StateDetails = Field(default_factory=StateDetails)

Expand Down
21 changes: 21 additions & 0 deletions tests/orion/api/test_flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,27 @@ async def test_set_flow_run_state_force_skips_orchestration(
assert response2.status_code == status.HTTP_201_CREATED
assert response2.json()["status"] == "ACCEPT"

@pytest.mark.parametrize("data", [1, "test", {"foo": "bar"}])
async def test_set_flow_run_state_accepts_any_jsonable_data(
self, flow_run, client, session, data
):
response = await client.post(
f"/flow_runs/{flow_run.id}/set_state",
json=dict(state=dict(type="COMPLETED", data=data)),
)
assert response.status_code == 201

api_response = OrchestrationResult.parse_obj(response.json())
assert api_response.status == responses.SetStateStatus.ACCEPT

flow_run_id = flow_run.id
session.expire(flow_run)

run = await models.flow_runs.read_flow_run(
session=session, flow_run_id=flow_run_id
)
assert run.state.data == data


class TestFlowRunHistory:
async def test_history_interval_must_be_one_second_or_larger(self, client):
Expand Down