Skip to content

Commit

Permalink
Only calculate lengths of iterable arguments to map (PrefectHQ#6513)
Browse files Browse the repository at this point in the history
* Implicitly mark default kewyord arguments to the task function as being unmapped.

* Only calculate lengths of iterable arguments to `map`

* Update `map` docstring to detail new behavior

* Update `map` section of task concepts
  • Loading branch information
bunchesofdonald authored and darrida committed Aug 25, 2022
1 parent accf166 commit 0a1e3f7
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 22 deletions.
24 changes: 20 additions & 4 deletions docs/concepts/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,10 @@ def map_flow(nums):

map_flow([1,2,3,5,8,13])
```

Prefect also supports `unmapped` arguments, allowing to pass static values that don't get mapped over.
Prefect also supports static arguments, allowing you to pass static values that don't get mapped over.

```python
from prefect import flow, task, unmapped
from prefect import flow, task

@task
def add_together(x, y):
Expand All @@ -276,7 +275,24 @@ def sum_it(numbers, static_value):
futures = add_together.map(numbers, static_value)
return futures

sum_it([1, 2, 3], unmapped(5))
sum_it([1, 2, 3], 5)
```

If your static argument is an iterable, you'll need to wrap it with `unmapped` to tell Prefect that it should be treated as a static value.

```python
from prefect import flow, task, unmapped

@task
def sum_plus(x, static_iterable):
return x + sum(static_iterable)

@flow
def sum_it(numbers, static_iterable):
futures = sum_plus.map(numbers, static_iterable)
return futures

sum_it([4, 5, 6], unmapped([1, 2, 3]))
```

## Async tasks
Expand Down
44 changes: 32 additions & 12 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
TaskRunContext,
)
from prefect.deployments import load_flow_from_flow_run
from prefect.exceptions import Abort, MappingLengthMismatch, UpstreamTaskError
from prefect.exceptions import (
Abort,
MappingLengthMismatch,
MappingMissingIterable,
UpstreamTaskError,
)
from prefect.filesystems import LocalFileSystem, WritableFileSystem
from prefect.flows import Flow
from prefect.futures import PrefectFuture, call_repr
Expand Down Expand Up @@ -89,7 +94,7 @@
run_sync_in_worker_thread,
)
from prefect.utilities.callables import parameters_to_args_kwargs
from prefect.utilities.collections import Quote, visit_collection
from prefect.utilities.collections import Quote, isiterable, visit_collection
from prefect.utilities.pydantic import PartialModel

R = TypeVar("R")
Expand Down Expand Up @@ -737,24 +742,39 @@ async def begin_task_map(
# Resolve any futures / states that are in the parameters as we need to
# validate the lengths of those values before proceeding.
parameters.update(await resolve_inputs(parameters))
parameter_lengths = {
key: len(val)
for key, val in parameters.items()
if not isinstance(val, unmapped)
}

lengths = set(parameter_lengths.values())
iterable_parameters = {}
static_parameters = {}
for key, val in parameters.items():
if isinstance(val, unmapped):
static_parameters[key] = val.value
elif isiterable(val):
iterable_parameters[key] = val
else:
static_parameters[key] = val

if not len(iterable_parameters):
raise MappingMissingIterable(
"No iterable parameters were received. Parameters for map must "
f"include at least one iterable. Parameters: {parameters}"
)

iterable_parameter_lengths = {
key: len(val) for key, val in iterable_parameters.items()
}
lengths = set(iterable_parameter_lengths.values())
if len(lengths) > 1:
raise MappingLengthMismatch(
"Received parameters with different lengths. Parameters for map "
f"must all be the same length. Got lengths: {parameter_lengths}"
"Received iterable parameters with different lengths. Parameters "
f"for map must all be the same length. Got lengths: {iterable_parameter_lengths}"
)

map_length = list(lengths)[0] if lengths else 1
map_length = list(lengths)[0]

task_runs = []
for i in range(map_length):
call_parameters = {key: value[i] for key, value in parameters.items()}
call_parameters = {key: value[i] for key, value in iterable_parameters.items()}
call_parameters.update({key: value for key, value in static_parameters.items()})
task_runs.append(
partial(
create_task_run_then_submit,
Expand Down
6 changes: 6 additions & 0 deletions src/prefect/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,9 @@ class MappingLengthMismatch(PrefectException):
"""
Raised when attempting to call Task.map with arguments of different lengths.
"""


class MappingMissingIterable(PrefectException):
"""
Raised when attempting to call Task.map with all static arguments
"""
31 changes: 26 additions & 5 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,9 @@ def map(
Must be called within a flow function. If writing an async task, this
call must be awaited.
Must be called with an iterable per task function argument. All
iterables must be the same length.
Must be called with at least one iterable and all iterables must be
the same length. Any arguments that are not iterable will be treated as
a static value and each task run will recieve the same value.
Will create as many task runs as the length of the iterable(s) in the
backing API and submit the task runs to the flow's task runner. This
Expand All @@ -557,10 +558,11 @@ def map(
for sync tasks and they are fully resolved on submission.
Args:
*args: Iterable arguments to run the tasks with
*args: Iterable and static arguments to run the tasks with
return_state: Return a list of Prefect States that wrap the results
of each task run.
wait_for: Upstream task futures to wait for before starting the task
of each task run.
wait_for: Upstream task futures to wait for before starting the
task
**kwargs: Keyword iterable arguments to run the task with
Returns:
Expand Down Expand Up @@ -616,6 +618,25 @@ def map(
>>>
>>> # task 2 will wait for task_1 to complete
>>> y = task_2.map([1, 2, 3], wait_for=[x])
Use static arguments
>>> @task
>>> def add_y(x, y):
>>> return x + y
>>>
>>> @flow
>>> def my_flow():
>>> futures = add_something.map([1, 2, 3], 5)
>>>
>>> # collect results
>>> result = []
>>> for future in futures:
>>> result.append(future.result())
>>> return result
>>>
>>> my_flow()
[6, 7, 8]
"""

from prefect.engine import enter_task_run_engine
Expand Down
9 changes: 9 additions & 0 deletions src/prefect/utilities/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def flatdict_to_dict(
T = TypeVar("T")


def isiterable(obj: Any) -> bool:
try:
iter(obj)
except TypeError:
return False
else:
return True


def ensure_iterable(obj: Union[T, Iterable[T]]) -> Iterable[T]:
if isinstance(obj, Sequence) or isinstance(obj, Set):
return obj
Expand Down
27 changes: 26 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from prefect import flow, get_run_logger, tags
from prefect.context import PrefectObjectRegistry
from prefect.engine import get_state_for_result
from prefect.exceptions import MappingLengthMismatch, ReservedArgumentError
from prefect.exceptions import (
MappingLengthMismatch,
MappingMissingIterable,
ReservedArgumentError,
)
from prefect.futures import PrefectFuture
from prefect.orion.schemas.core import TaskRunResult
from prefect.orion.schemas.data import DataDocument
Expand Down Expand Up @@ -1965,6 +1969,14 @@ def my_flow():
futures = my_flow()
assert [future.result() for future in futures] == [5, 7, 9]

def test_missing_iterable_argument(self):
@flow
def my_flow():
return TestTaskMap.add_together.map(5, 6)

with pytest.raises(MappingMissingIterable):
assert my_flow()

def test_mismatching_input_lengths(self):
@flow
def my_flow():
Expand Down Expand Up @@ -2008,3 +2020,16 @@ def my_flow():

futures = my_flow()
assert [future.result() for future in futures] == [6, 7, 8]

async def test_with_default_kwargs(self):
@task
def add_some(x, y=5):
return x + y

@flow
def my_flow():
numbers = [1, 2, 3]
return add_some.map(numbers)

futures = my_flow()
assert [future.result() for future in futures] == [6, 7, 8]
11 changes: 11 additions & 0 deletions tests/utilities/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AutoEnum,
dict_to_flatdict,
flatdict_to_dict,
isiterable,
remove_nested_keys,
visit_collection,
)
Expand Down Expand Up @@ -409,3 +410,13 @@ def test_passes_through_non_dict(self):
assert remove_nested_keys(["foo"], 1) == 1
assert remove_nested_keys(["foo"], "foo") == "foo"
assert remove_nested_keys(["foo"], b"foo") == b"foo"


class TestIsIterable:
@pytest.mark.parametrize("obj", [[1, 2, 3], (1, 2, 3), "hello"])
def test_is_iterable(self, obj):
assert isiterable(obj)

@pytest.mark.parametrize("obj", [5, Exception(), True])
def test_not_iterable(self, obj):
assert not isiterable(obj)

0 comments on commit 0a1e3f7

Please sign in to comment.