Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor executors again #112

Merged
merged 16 commits into from
Mar 23, 2022
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand All @@ -30,10 +30,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.6
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
161 changes: 78 additions & 83 deletions rechunker/executors/dask.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,98 @@
from functools import reduce
from typing import Iterable
from __future__ import annotations

from typing import Any, Dict, Set, Tuple, Union

import dask
import dask.array
from dask.blockwise import BlockwiseDepDict, blockwise
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph

from rechunker.types import (
MultiStagePipeline,
ParallelPipelines,
PipelineExecutor,
Stage,
)


class DaskPipelineExecutor(PipelineExecutor[Delayed]):
"""An execution engine based on dask.

Supports zarr and dask arrays as inputs. Outputs must be zarr arrays.

Execution plans for DaskExecutors are dask.delayed objects.
"""

def pipelines_to_plan(self, pipelines: ParallelPipelines) -> Delayed:
return _make_pipelines(pipelines)

def execute_plan(self, plan: Delayed, **kwargs):
return plan.compute(**kwargs)

from rechunker.types import ParallelPipelines, Pipeline, PipelineExecutor

def _make_pipelines(pipelines: ParallelPipelines) -> Delayed:
pipelines_delayed = [_make_pipeline(pipeline) for pipeline in pipelines]
return _merge(*pipelines_delayed)

def wrap_map_task(function):
# dependencies are dummy args used to create dependence between stages
def wrapped(map_arg, config, *dependencies):
return function(map_arg, config=config)

def _make_pipeline(pipeline: MultiStagePipeline) -> Delayed:
stages_delayed = [_make_stage(stage) for stage in pipeline]
d = reduce(_add_upstream, stages_delayed)
return d
return wrapped


def _make_stage(stage: Stage) -> Delayed:
if stage.map_args is None:
return dask.delayed(stage.func)()
else:
name = stage.func.__name__ + "-" + dask.base.tokenize(stage.func)
dsk = {(name, i): (stage.func, arg) for i, arg in enumerate(stage.map_args)}
# create a barrier
top_key = "stage-" + dask.base.tokenize(stage.func, stage.map_args)
def wrap_standalone_task(function):
def wrapped(config, *dependencies):
return function(config=config)

def merge_all(*args):
# this function is dependent on its arguments but doesn't actually do anything
return None
return wrapped

dsk.update({top_key: (merge_all, *list(dsk))})
return Delayed(top_key, dsk)

def checkpoint(*args):
    """No-op barrier task.

    Its only purpose is to depend on every mapped output key of a stage,
    so downstream stages can depend on this single key instead.
    """
    return None

def _merge_task(*args):
pass

def append_token(task_name: str, token: str) -> str:
    """Build a dask graph key name by suffixing *task_name* with *token*."""
    return "-".join((task_name, token))

def _merge(*args: Iterable[Delayed]) -> Delayed:
    """Combine several Delayed objects into one Delayed depending on all of them.

    A new no-op node (``_merge_task``), keyed on the tokenized inputs, is
    placed on top of the union of every input's graph, so computing the
    result forces computation of each input.
    """
    name = "merge-" + dask.base.tokenize(*args)
    # mypy doesn't like arg.key
    keys = [getattr(arg, "key") for arg in args]
    # the merge node's task tuple: callable followed by its dependency keys
    new_task = (_merge_task, *keys)
    # mypy doesn't like arg.dask
    graph = dask.base.merge(
        *[dask.utils.ensure_dict(getattr(arg, "dask")) for arg in args]
    )
    graph[name] = new_task
    d = Delayed(name, graph)
    return d

class DaskPipelineExecutor(PipelineExecutor[Delayed]):
"""An execution engine based on dask.

def _add_upstream(first: Delayed, second: Delayed):
upstream_key = first.key
dsk = second.dask
top_layer = _get_top_layer(dsk)
new_top_layer = {}

for key, value in top_layer.items():
new_top_layer[key] = ((lambda a, b: a), value, upstream_key)

dsk_new = dask.base.merge(
dask.utils.ensure_dict(first.dask), dask.utils.ensure_dict(dsk), new_top_layer
)
Supports zarr and dask arrays as inputs. Outputs must be zarr arrays.

return Delayed(second.key, dsk_new)
Execution plans for DaskExecutors are dask.delayed objects.
"""

def pipelines_to_plan(self, pipelines: ParallelPipelines) -> list[Delayed]:
    """Compile each pipeline into its own dask Delayed object.

    Returns one Delayed per pipeline.  The previous annotation claimed a
    single ``Delayed``, but the body returns a list — ``execute_plan``
    unpacks it with ``dask.compute(*plan)``.  The ``list[...]`` syntax is
    safe here because the module uses ``from __future__ import annotations``.
    """
    return [_make_pipeline(pipeline) for pipeline in pipelines]

def _get_top_layer(dsk):
if hasattr(dsk, "layers"):
# this is a HighLevelGraph
top_layer_key = list(dsk.layers)[0]
top_layer = dsk.layers[top_layer_key]
else:
# could this go wrong?
first_key = next(iter(dsk))
first_task = first_key[0].split("-")[0]
top_layer = {k: v for k, v in dsk.items() if k[0].startswith(first_task + "-")}
return top_layer
def execute_plan(self, plan: Delayed, **kwargs):
    """Execute the plan with dask, forwarding **kwargs to ``dask.compute``.

    NOTE(review): ``plan`` is actually a list of Delayed objects (one per
    pipeline, as built by ``pipelines_to_plan``), hence the unpacking.
    """
    return dask.compute(*plan, **kwargs)


def _make_pipeline(pipeline: Pipeline) -> Delayed:
    """Compile one rechunker Pipeline into a single dask Delayed.

    Builds a HighLevelGraph by hand: a standalone config layer, then one
    layer per stage (blockwise-mapped stages get an extra "checkpoint"
    barrier layer), chained via ``prev_key`` so stages run strictly in order.
    """
    token = dask.base.tokenize(pipeline)

    # we are constructing a HighLevelGraph from scratch
    # https://docs.dask.org/en/latest/high-level-graphs.html
    layers = dict()  # type: Dict[str, Dict[Union[str, Tuple[str, int]], Any]]
    dependencies = dict()  # type: Dict[str, Set[str]]

    # start with just the config as a standalone layer
    # create a custom delayed object for the config
    config_key = append_token("config", token)
    layers[config_key] = {config_key: pipeline.config}
    dependencies[config_key] = set()

    # prev_key chains each stage behind the previous one (initially the config)
    prev_key: str = config_key
    for stage in pipeline.stages:
        if stage.mappable is None:
            # standalone stage: a single task depending on config + prior stage
            stage_key = append_token(stage.name, token)
            func = wrap_standalone_task(stage.function)
            layers[stage_key] = {stage_key: (func, config_key, prev_key)}
            dependencies[stage_key] = {config_key, prev_key}
        else:
            # mapped stage: one blockwise task per element of stage.mappable
            func = wrap_map_task(stage.function)
            map_key = append_token(stage.name, token)
            layers[map_key] = map_layer = blockwise(
                func,
                map_key,
                "x",  # <-- dimension name doesn't matter
                BlockwiseDepDict({(i,): x for i, x in enumerate(stage.mappable)}),
                # ^ this is extra annoying. `BlockwiseDepList` at least would be nice.
                "x",
                config_key,
                None,
                prev_key,
                None,
                numblocks={},
                # ^ also annoying; the default of None breaks Blockwise
            )
            dependencies[map_key] = {config_key, prev_key}

            # checkpoint barrier: a single key depending on every mapped output,
            # so the next stage can depend on one key instead of all of them
            stage_key = f"{stage.name}-checkpoint-{token}"
            layers[stage_key] = {stage_key: (checkpoint, *map_layer.get_output_keys())}
            dependencies[stage_key] = {map_key}
        prev_key = stage_key

    hlg = HighLevelGraph(layers, dependencies)
    delayed = Delayed(prev_key, hlg)
    return delayed
70 changes: 24 additions & 46 deletions rechunker/executors/prefect.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,35 @@
import prefect
from typing import List

from prefect import Flow, task, unmapped

from rechunker.types import ParallelPipelines, PipelineExecutor


class PrefectPipelineExecutor(PipelineExecutor[prefect.Flow]):
class PrefectPipelineExecutor(PipelineExecutor[Flow]):
"""An execution engine based on Prefect.

Supports copying between any arrays that implement ``__getitem__`` and
``__setitem__`` for tuples of ``slice`` objects. Array must also be
serializable by Prefect (i.e., with pickle).

Execution plans for PrefectExecutor are prefect.Flow objects.
"""

def pipelines_to_plan(self, pipelines: ParallelPipelines) -> prefect.Flow:
return _make_flow(pipelines)

def execute_plan(self, plan: prefect.Flow, **kwargs):
def pipelines_to_plan(self, pipelines: ParallelPipelines) -> Flow:
    """Translate *pipelines* into a single prefect Flow.

    Each stage becomes a prefect task (mapped over ``stage.mappable`` when
    present); ``upstream_tasks`` is threaded through the loop so each stage
    waits for the previous one within its pipeline.
    """
    with Flow("rechunker") as flow:
        for pipeline in pipelines:
            # tasks the next stage must wait on; starts empty per pipeline
            upstream_tasks = []  # type: List[task]
            for stage in pipeline.stages:
                stage_task = task(stage.function, name=stage.name)
                if stage.mappable is not None:
                    stage_task_called = stage_task.map(
                        list(stage.mappable),  # prefect doesn't accept a generator
                        config=unmapped(pipeline.config),
                        upstream_tasks=[unmapped(t) for t in upstream_tasks],
                    )
                else:
                    stage_task_called = stage_task(
                        config=pipeline.config, upstream_tasks=upstream_tasks
                    )
                # the just-built task becomes the dependency of the next stage
                upstream_tasks = [stage_task_called]
    return flow

def execute_plan(self, plan: Flow, **kwargs):
    """Run the prefect flow and return the resulting State object."""
    return plan.run(**kwargs)


class MappedTaskWrapper(prefect.Task):
    """Prefect task that applies a stage's function to a single mapped key."""

    def __init__(self, stage, **kwargs):
        # store the stage before prefect.Task initialization, as the original did
        self.stage = stage
        super().__init__(**kwargs)

    def run(self, key):
        stage_func = self.stage.func
        return stage_func(key)


class SingleTaskWrapper(prefect.Task):
    """Prefect task that runs a stage's function once, with no arguments."""

    def __init__(self, stage, **kwargs):
        # store the stage before prefect.Task initialization, as the original did
        self.stage = stage
        super().__init__(**kwargs)

    def run(self):
        stage_func = self.stage.func
        return stage_func()


def _make_flow(pipelines: ParallelPipelines) -> prefect.Flow:
    """Build a prefect Flow that runs every pipeline's stages in order."""
    with prefect.Flow("Rechunker") as flow:
        # iterate over different arrays in the group
        for pipeline in pipelines:
            stage_tasks = []
            # iterate over the different stages of the array copying
            for stage in pipeline:
                if stage.map_args is None:
                    stage_tasks.append(SingleTaskWrapper(stage))
                else:
                    stage_tasks.append(MappedTaskWrapper(stage).map(stage.map_args))
            # create dependence between consecutive stages
            for upstream, downstream in zip(stage_tasks, stage_tasks[1:]):
                downstream.set_upstream(upstream)
    return flow
30 changes: 11 additions & 19 deletions rechunker/executors/python.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from typing import Callable, Iterable
from typing import Callable

from rechunker.types import ParallelPipelines, PipelineExecutor

Expand All @@ -11,27 +10,20 @@
class PythonPipelineExecutor(PipelineExecutor[Task]):
"""An execution engine based on Python loops.

Supports copying between any arrays that implement ``__getitem__`` and
``__setitem__`` for tuples of ``slice`` objects.

Execution plans for PythonExecutor are functions that accept no arguments.
"""

def pipelines_to_plan(self, pipelines: ParallelPipelines) -> Task:
tasks = []
for pipeline in pipelines:
for stage in pipeline:
if stage.map_args is None:
tasks.append(stage.func)
else:
for arg in stage.map_args:
tasks.append(partial(stage.func, arg))
return partial(_execute_all, tasks)
def plan():
for pipeline in pipelines:
for stage in pipeline.stages:
if stage.mappable is not None:
for m in stage.mappable:
stage.function(m, config=pipeline.config)
else:
stage.function(config=pipeline.config)

return plan

def execute_plan(self, plan: Task, **kwargs):
    """Run the plan — a zero-argument callable built by ``pipelines_to_plan``.

    NOTE(review): **kwargs is accepted for interface parity with the other
    executors but is not used here.
    """
    plan()


def _execute_all(tasks: Iterable[Task]) -> None:
    """Invoke each task (a zero-argument callable) sequentially, in order."""
    for run_task in tasks:
        run_task()
Loading