Skip to content

Commit

Permalink
return list from split_fragment
Browse files Browse the repository at this point in the history
  • Loading branch information
cisaacstern committed Jul 27, 2023
1 parent f0c7dac commit 0d5f97a
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions pangeo_forge_recipes/rechunking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import itertools
import operator
from typing import Dict, Iterator, List, Tuple
from typing import Dict, List, Tuple

import numpy as np
import xarray as xr
Expand All @@ -20,7 +20,7 @@ def split_fragment(
fragment: Tuple[Index, xr.Dataset],
target_chunks: Optional[Dict[str, int]] = None,
schema: Optional[XarraySchema] = None,
) -> Iterator[Tuple[GroupKey, Tuple[Index, xr.Dataset]]]:
) -> List[Tuple[GroupKey, Tuple[Index, xr.Dataset]]]:
"""Split a single indexed dataset fragment into sub-fragments, according to the
specified target chunks
Expand Down Expand Up @@ -94,6 +94,7 @@ def split_fragment(
]
)

splits = []
# this iteration yields new fragments, indexed by their target chunk group
for target_chunk_group in all_chunks:
# now we need to figure out which piece of the fragment belongs in which chunk
Expand All @@ -115,12 +116,15 @@ def split_fragment(
)
sub_fragment_ds = ds.isel(**sub_fragment_indexer)

yield (
# append the `merge_dim_positions` to the target_chunk_group before returning,
# to ensure correct grouping of merge dims. e.g., `(("time", 0), ("variable", 0))`.
tuple(sorted(target_chunk_group) + merge_dim_positions),
(sub_fragment_index, sub_fragment_ds),
splits.append(
(
# append the `merge_dim_positions` to the target_chunk_group before returning,
# to ensure correct grouping of merge dims. e.g., `(("time", 0), ("variable", 0))`.
tuple(sorted(target_chunk_group) + merge_dim_positions),
(sub_fragment_index, sub_fragment_ds),
)
)
return splits


def _sort_index_key(item):
Expand Down

0 comments on commit 0d5f97a

Please sign in to comment.