diff --git a/rechunker/executors/beam.py b/rechunker/executors/beam.py index 588bcf4..876f342 100644 --- a/rechunker/executors/beam.py +++ b/rechunker/executors/beam.py @@ -1,7 +1,8 @@ import uuid -from typing import Iterable, Optional, Mapping, Tuple +from typing import Iterable, NamedTuple, Optional, Mapping, Sequence, Tuple import apache_beam as beam +import numpy as np from rechunker.executors.util import chunk_keys from rechunker.types import ( @@ -28,14 +29,14 @@ class BeamExecutor(Executor[beam.PTransform]): # This would offer a cleaner API and would perhaps be faster, too. def prepare_plan(self, specs: Iterable[StagedCopySpec]) -> beam.PTransform: - return "Rechunker" >> _Rechunker(specs) + return "Rechunker" >> _OnDiskRechunker(specs) def execute_plan(self, plan: beam.PTransform, **kwargs): with beam.Pipeline(**kwargs) as pipeline: pipeline | plan -class _Rechunker(beam.PTransform): +class _OnDiskRechunker(beam.PTransform): def __init__(self, specs: Iterable[StagedCopySpec]): super().__init__() self.specs = tuple(specs) @@ -55,11 +56,11 @@ def expand(self, pcoll): k: v.stages[stage] if stage < len(v.stages) else None for k, v in specs_map.items() } - pcoll = pcoll | f"Stage{stage}" >> _CopyStage(specs_by_target) + pcoll = pcoll | f"Stage{stage}" >> _OnDiskStage(specs_by_target) return pcoll -class _CopyStage(beam.PTransform): +class _OnDiskStage(beam.PTransform): def __init__(self, specs_by_target: Mapping[str, CopySpec]): super().__init__() self.specs_by_target = specs_by_target @@ -101,3 +102,94 @@ def _copy_chunk( ) -> str: target[key] = source[key] return target_id + + +class _DirectCopySpec(NamedTuple): + uuid: str + source: ReadableArray + target: WriteableArray + read_chunks: Tuple[int, ...] + intermediate_chunks: Tuple[int, ...] + write_chunks: Tuple[int, ...] 


class _DirectRechunker(beam.PTransform):
    """Rechunk arrays in one pass by shuffling pieces instead of staging them.

    For every spec the pipeline reads the source in ``read_chunks``-sized
    blocks, cuts each block into ``intermediate_chunks``-sized pieces, groups
    the pieces by the ``write_chunks`` grid cell they start in, stitches each
    group into one array and assigns it into the target.
    """

    def __init__(self, specs: Iterable[_DirectCopySpec]):
        super().__init__()
        # Materialize once so the transform can be expanded more than once.
        self.specs = tuple(specs)

    def expand(self, pcoll):
        return (
            pcoll
            | "Create" >> beam.Create(self.specs)
            # Elements here are whole specs. FlatMap (not FlatMapTuple) keeps
            # the NamedTuple intact instead of unpacking its six fields.
            | "CreateTasks" >> beam.FlatMap(_create_tasks)
            | "Reshuffle" >> beam.Reshuffle()
            # From here on elements are plain tuples, so the *Tuple variants
            # are needed to spread them over the positional parameters.
            | "ReadChunks" >> beam.MapTuple(_read_chunk)
            | "SplitChunks" >> beam.FlatMapTuple(_split_chunks)
            | "AddTargetIndex" >> beam.MapTuple(_prepend_target_index)
            # Group rather than CombinePerKey: a callable combiner may be
            # applied to partial groups and then re-fed its own output, and a
            # partially filled buffer from _consolidate_into_chunk contains
            # uninitialized regions that could overwrite real data on merge.
            | "GroupChunks" >> beam.GroupByKey()
            | "DropTargetIndex" >> beam.Values()
            | "ConsolidateChunks" >> beam.Map(_combine_chunks)
            | "WriteChunks" >> beam.MapTuple(_write_chunk)
        )


def _create_tasks(spec):
    """Yield one ``(spec, key)`` task per read chunk of ``spec.source``."""
    for key in chunk_keys(spec.source.shape, spec.read_chunks):
        yield spec, key


def _read_chunk(spec, key):
    """Load the region ``key`` from the source; return ``(spec, key, data)``."""
    return spec, key, spec.source[key]


def _split_chunks(spec, key, value):
    """Cut one read chunk into ``spec.intermediate_chunks``-sized pieces."""
    for sub_key, sub_value in _split_into_chunks(
        key, value, spec.intermediate_chunks
    ):
        yield spec, sub_key, sub_value


def _prepend_target_index(spec, key, value):
    """Key a piece by ``(spec.uuid, index of the write chunk it starts in)``."""
    # _DirectCopySpec has no ``target_chunks`` field; the target grid is
    # described by ``write_chunks``.
    index = _chunk_index(key, spec.write_chunks)
    return (spec.uuid, index), (spec, key, value)


def _combine_chunks(triplets):
    """Merge all ``(spec, key, value)`` pieces of one group into one triplet.

    ``triplets`` must be non-empty; every triplet is expected to carry the
    same spec, so the first one is returned.
    """
    identical_specs, keys, values = zip(*triplets)
    key, value = _consolidate_into_chunk(keys, values)
    return identical_specs[0], key, value


def _write_chunk(spec, key, value):
    """Assign ``value`` into the region ``key`` of ``spec.target``."""
    spec.target[key] = value


def _chunk_index(key: Tuple[slice, ...], chunks: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return, per dimension, the index of the chunk that ``key`` starts in."""
    return tuple(k.start // c for k, c in zip(key, chunks))


def _split_into_chunks(
    key: Tuple[slice, ...], value: ReadableArray, chunks: Tuple[int, ...],
) -> Iterable[Tuple[Tuple[slice, ...], ReadableArray]]:
    """Yield ``(global_key, piece)`` pairs tiling ``value`` with ``chunks``.

    ``key`` is the location of ``value`` in the full array. Pieces are cut
    relative to the start of ``value`` and their keys are translated back to
    full-array coordinates.
    """
    # NOTE(review): pieces are aligned to the start of ``key``, not to a
    # global ``chunks`` grid. If ``key`` does not start on a multiple of
    # ``chunks``, a piece can straddle a grid boundary - confirm callers only
    # pass aligned chunk sizes.
    for local_key in chunk_keys(value.shape, chunks):
        global_key = tuple(
            slice(outer.start + inner.start, min(outer.start + inner.stop, outer.stop))
            for outer, inner in zip(key, local_key)
        )
        yield global_key, value[local_key]


def _consolidate_into_chunk(
    keys: Sequence[Tuple[slice, ...]], values: Sequence[ReadableArray],
) -> Tuple[Tuple[slice, ...], np.ndarray]:
    """Paste ``values`` at ``keys`` into one array covering their bounding box.

    Returns ``(overall_key, result)``, where ``overall_key`` is the bounding
    box of all ``keys`` in full-array coordinates. All values must share one
    dtype. The buffer comes from ``np.empty``, so any part of the bounding box
    not covered by some key is left uninitialized.
    """
    lower = tuple(map(min, zip(*[[k.start for k in key] for key in keys])))
    upper = tuple(map(max, zip(*[[k.stop for k in key] for key in keys])))
    overall_key = tuple(map(slice, lower, upper))

    shape = tuple(hi - lo for lo, hi in zip(lower, upper))
    dtype = values[0].dtype
    assert all(dtype == v.dtype for v in values[1:])
    result = np.empty(shape, dtype)

    for key, value in zip(keys, values):
        # Shift the global key so it is relative to the bounding-box origin.
        local_key = tuple(
            slice(k.start - lo, k.stop - lo) for k, lo in zip(key, lower)
        )
        result[local_key] = value

    return overall_key, result


# Backward-compatible alias for the original (misspelled) helper name.
_conslidate_into_chunk = _consolidate_into_chunk