Skip to content

Commit

Permalink
Add async support to State.result() (#7071)
Browse files Browse the repository at this point in the history
Co-authored-by: Bill Palombi <bill@prefect.io>
Co-authored-by: Terrence Dorsey <terrence@prefect.io>
  • Loading branch information
3 people committed Oct 6, 2022
1 parent 7752f87 commit fa694c4
Show file tree
Hide file tree
Showing 18 changed files with 474 additions and 169 deletions.
77 changes: 77 additions & 0 deletions docs/concepts/results.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,83 @@ def my_flow():
my_flow()
```

### Working with async results

When **calling** flows or tasks, the result is returned directly:

```python
import asyncio
from prefect import flow, task

@task
async def my_task():
return 1

@flow
async def my_flow():
task_result = await my_task()
return task_result + 1

result = asyncio.run(my_flow())
assert result == 2
```

When working with flow and task states, the result can be retrieved with the `State.result()` method:

```python
import asyncio
from prefect import flow, task

@task
async def my_task():
return 1

@flow
async def my_flow():
state = await my_task(return_state=True)
result = await state.result(fetch=True)
return result + 1

async def main():
state = await my_flow(return_state=True)
assert await state.result(fetch=True) == 2

asyncio.run(main())
```

!!! important "Resolving results"
Prefect 2.6.0 added automatic retrieval of persisted results.
Prior to this version, `State.result()` did not require an `await`.
For backwards compatibility, when used from an asynchronous context, `State.result()` returns a raw result type.

You may opt-in to the new behavior by passing `fetch=True` as shown in the example above.
If you would like this behavior to be used automatically, you may enable the `PREFECT_ASYNC_FETCH_STATE_RESULT` setting.
If you do not opt-in to this behavior, you will see a warning.

You may also opt-out by setting `fetch=False`.
This will silence the warning, but you will need to retrieve your result manually from the result type.

When submitting tasks to a runner, the result can be retrieved with the `Future.result()` method:

```python
import asyncio
from prefect import flow, task

@task
async def my_task():
return 1

@flow
async def my_flow():
future = await my_task.submit()
result = await future.result()
return result + 1

result = asyncio.run(my_flow())
assert result == 2
```


## Persisting results

The Prefect API does not store your results [except in special cases](#storage-of-results-in-prefect). Instead, the result is _persisted_ to a storage location in your infrastructure and Prefect stores a _reference_ to the result.
Expand Down
15 changes: 15 additions & 0 deletions flows/hello_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from prefect import flow, get_run_logger, task


@task
def say_hello(name: str):
    """Log an info-level greeting addressed to `name`.

    The message is emitted through the logger returned by `get_run_logger()`.
    """
    logger = get_run_logger()
    logger.info(f"Hello {name}!")


@flow
def hello(name: str = "world", count: int = 1):
    """Map `say_hello` over `count` names of the form `<name>-<i>`.

    Args:
        name: prefix used for every generated name.
        count: how many names to generate; indices run from 0 to `count - 1`.
    """
    # Kept as a generator expression so `say_hello.map` receives the same kind
    # of argument as before.
    greeting_targets = (f"{name}-{i}" for i in range(count))
    say_hello.map(greeting_targets)


# When executed as a script, run the flow with count=3 so that three names are
# mapped over `say_hello`.
if __name__ == "__main__":
    hello(count=3)
67 changes: 58 additions & 9 deletions src/prefect/client/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import datetime
import warnings
from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, Union, overload

from pydantic import Field

from prefect.orion import schemas
from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT
from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible

if TYPE_CHECKING:
from prefect.deprecated.data_documents import DataDocument
Expand Down Expand Up @@ -31,19 +34,23 @@ def result(self: "State[R]", raise_on_failure: bool = True) -> R:
def result(self: "State[R]", raise_on_failure: bool = False) -> Union[R, Exception]:
...

def result(self, raise_on_failure: bool = True):
def result(self, raise_on_failure: bool = True, fetch: Optional[bool] = None):
"""
Convenience method for access the data on the state's data document.
Retrieve the result
Args:
raise_on_failure: a boolean specifying whether to raise an exception
if the state is of type `FAILED` and the underlying data is an exception
fetch: a boolean specifying whether to resolve references to persisted
results into data. For synchronous users, this defaults to `True`.
For asynchronous users, this defaults to `False` for backwards
compatibility.
Raises:
TypeError: if the state is failed but without an exception
TypeError: If the state is failed but the result is not an exception.
Returns:
The underlying decoded data
The result of the run
Examples:
>>> from prefect import flow, task
Expand All @@ -67,39 +74,81 @@ def result(self, raise_on_failure: bool = True):
>>> @flow
>>> def my_flow():
>>> return "hello"
>>> my_flow().result()
>>> my_flow(return_state=True).result()
hello
Get the result from a failed state
>>> @flow
>>> def my_flow():
>>> raise ValueError("oh no!")
>>> state = my_flow() # Error is wrapped in FAILED state
>>> state = my_flow(return_state=True) # Error is wrapped in FAILED state
>>> state.result() # Raises `ValueError`
Get the result from a failed state without erroring
>>> @flow
>>> def my_flow():
>>> raise ValueError("oh no!")
>>> state = my_flow()
>>> state = my_flow(return_state=True)
>>> result = state.result(raise_on_failure=False)
>>> print(result)
ValueError("oh no!")
Get the result from a flow state in an async context
>>> @flow
>>> async def my_flow():
>>> return "hello"
>>> state = await my_flow(return_state=True)
>>> await state.result(fetch=True)
hello
"""
from prefect.deprecated.data_documents import (
DataDocument,
result_from_state_with_data_document,
)
from prefect.results import BaseResult

if fetch is None and (
PREFECT_ASYNC_FETCH_STATE_RESULT or not in_async_main_thread()
):
# Fetch defaults to `True` for sync users or async users who have opted in
fetch = True

if not fetch:
if fetch is None and in_async_main_thread():
warnings.warn(
"State.result() was called from an async context but not awaited. "
"This method will be updated to return a coroutine by default in "
"the future. Pass `fetch=True` and `await` the call to get rid of "
"this warning.",
DeprecationWarning,
stacklevel=2,
)
# Backwards compatibility
if isinstance(self.data, DataDocument):
return result_from_state_with_data_document(
self, raise_on_failure=raise_on_failure
)
else:
return self.data
else:
return self._result(raise_on_failure=raise_on_failure)

@sync_compatible
async def _result(self, raise_on_failure: bool):
from prefect.deprecated.data_documents import (
DataDocument,
result_from_state_with_data_document,
)

if isinstance(self.data, DataDocument):
return result_from_state_with_data_document(
self, raise_on_failure=raise_on_failure
)
elif isinstance(self.data, BaseResult):
return self.data.load()
return await self.data.get()
else:
raise ValueError(
f"State data is of unknown result type {type(self.data).__name__!r}."
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/deprecated/data_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ def result_from_state_with_data_document(state: "State", raise_on_failure: bool)
)
return data
elif isinstance(data, State):
data.result()
data.result(fetch=False)
elif isinstance(data, Iterable) and all([isinstance(o, State) for o in data]):
# raise the first failure we find
for state in data:
state.result()
state.result(fetch=False)

# we don't make this an else in case any of the above conditionals doesn't raise
raise TypeError(
Expand Down
8 changes: 4 additions & 4 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def create_then_begin_flow_run(
if return_type == "state":
return state
elif return_type == "result":
return state.result()
return await state.result(fetch=True)
else:
raise ValueError(f"Invalid return type for flow engine {return_type!r}.")

Expand Down Expand Up @@ -520,7 +520,7 @@ async def create_and_begin_subflow_run(
if return_type == "state":
return terminal_state
elif return_type == "result":
return terminal_state.result()
return await terminal_state.result(fetch=True)
else:
raise ValueError(f"Invalid return type for flow engine {return_type!r}.")

Expand Down Expand Up @@ -1319,7 +1319,7 @@ async def wait_for_task_runs_and_report_crashes(
if not state.type == StateType.CRASHED:
continue

exception = state.result(raise_on_failure=False)
exception = await state.result(raise_on_failure=False, fetch=True)

logger.info(f"Crash detected! {state.message}")
logger.debug("Crash details:", exc_info=exception)
Expand Down Expand Up @@ -1412,7 +1412,7 @@ def resolve_input(expr):
)

# Only retrieve the result if requested as it may be expensive
return state.result() if return_data else None
return state._result(raise_on_failure=True) if return_data else None

return await run_sync_in_worker_thread(
visit_collection,
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ async def _result(self, timeout: float = None, raise_on_failure: bool = True):
final_state = await self._wait(timeout=timeout)
if not final_state:
raise TimeoutError("Call timed out before task finished.")
return final_state.result(raise_on_failure=raise_on_failure)
return await final_state._result(raise_on_failure=raise_on_failure)

@overload
def get_state(
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/orion/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def copy(self, *, update: dict = None, reset_fields: bool = False, **kwargs):
update.setdefault("timestamp", self.__fields__["timestamp"].get_default())
return super().copy(reset_fields=reset_fields, update=update, **kwargs)

def result(self, raise_on_failure: bool = True):
def result(self, raise_on_failure: bool = True, fetch: Optional[bool] = None):
from prefect.client.schemas import State

warnings.warn(
Expand All @@ -136,7 +136,7 @@ def result(self, raise_on_failure: bool = True):
)

state = State.parse_obj(self)
return state.result(raise_on_failure=raise_on_failure)
return state.result(raise_on_failure=raise_on_failure, fetch=fetch)

def __repr__(self) -> str:
"""
Expand Down
18 changes: 18 additions & 0 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,24 @@ def warn_on_database_password_value_without_usage(values):
prefetched. Defaults to `10`.""",
)

# Opt-in switch for async callers of `State.result()`. It defaults to `False`
# so existing async code keeps the pre-2.6.0 behavior unless it passes
# `fetch=True` per call or enables this setting globally.
PREFECT_ASYNC_FETCH_STATE_RESULT = Setting(
    bool,
    default=False,
    description=textwrap.dedent(
        """
        Determines whether `State.result()` fetches results automatically or not.
        In Prefect 2.6.0, the `State.result()` method was updated to be async
        to facilitate automatic retrieval of results from storage, which means when
        writing async code you must `await` the call. For backwards compatibility,
        the result is not retrieved by default for async users. You may opt into this
        per call by passing `fetch=True` or toggle this setting to change the behavior
        globally.
        This setting does not affect users writing synchronous tasks and flows.
        This setting does not affect retrieval of results when using `Future.result()`.
        """
    ),
)

PREFECT_ORION_BLOCKS_REGISTER_ON_START = Setting(
bool,
default=True,
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/testing/standard_test_suites/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ async def fake_orchestrate_task_run(example_kwarg):
state = await task_runner.wait(task_run.id, 5)
assert state is not None, "wait timed out"
assert isinstance(state, State), "wait should return a state"
assert state.result() == 1
assert await state.result() == 1

@pytest.mark.parametrize("exception", [KeyboardInterrupt(), ValueError("test")])
async def test_wait_captures_exceptions_as_crashed_state(
Expand All @@ -367,7 +367,7 @@ async def fake_orchestrate_task_run():
assert state is not None, "wait timed out"
assert isinstance(state, State), "wait should return a state"
assert state.type == StateType.CRASHED
result = state.result(raise_on_failure=False)
result = await state.result(raise_on_failure=False)

assert exceptions_equal(result, exception)

Expand Down
23 changes: 2 additions & 21 deletions src/prefect/utilities/asyncutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import ctypes
import inspect
import sys
import threading
import warnings
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -182,26 +181,8 @@ def sync_compatible(async_fn: T) -> T:
@wraps(async_fn)
def wrapper(*args, **kwargs):
if in_async_main_thread():
caller_frame = sys._getframe(1)
caller_module = caller_frame.f_globals.get("__name__", "unknown")
caller_async = caller_frame.f_code.co_flags & inspect.CO_COROUTINE
if caller_async or any(
# Add exceptions for the internals anyio/asyncio which can run
# coroutines from synchronous functions
caller_module.startswith(f"{module}.")
for module in ["asyncio", "anyio"]
):
# In the main async context; return the coro for them to await
return async_fn(*args, **kwargs)
else:
# In the main thread but call was made from a sync method
raise RuntimeError(
"A 'sync_compatible' method was called from a context that was "
"previously async but is now sync. The sync call must be changed "
"to run in a worker thread to support sending the coroutine for "
f"{async_fn.__name__!r} to the main thread."
)

# In the main async context; return the coro for them to await
return async_fn(*args, **kwargs)
elif in_async_worker_thread():
# In a sync context but we can access the event loop thread; send the async
# call to the parent
Expand Down
Loading

0 comments on commit fa694c4

Please sign in to comment.