Skip to content

Commit

Permalink
Merge pull request #506 from rabernat/fix-target_chunks-bug
Browse files Browse the repository at this point in the history
fix target chunks bug
  • Loading branch information
cisaacstern authored Jun 14, 2023
2 parents e2f3bff + 8113c08 commit 5edf5ec
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pangeo_forge_recipes/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class XarraySchema(TypedDict):

def dataset_to_schema(ds: xr.Dataset) -> XarraySchema:
"""Convert the output of `dataset.to_dict(data=False, encoding=True)` to a schema
(Basically justs adds chunks, which is not part of the Xarray ouput).
(Basically just adds chunks, which is not part of the Xarray output).
"""

# Remove redundant encoding options
Expand Down Expand Up @@ -233,12 +233,16 @@ def determine_target_chunks(
) -> Dict[str, int]:
# if the schema is chunked, use that
target_chunks = {dim: dimchunks[0] for dim, dimchunks in schema["chunks"].items()}
if include_all_dims:
for dim, dimsize in schema["dims"].items():
if dim not in target_chunks:
target_chunks[dim] = dimsize
# finally override with any specified chunks
for dim, dimsize in schema["dims"].items():
if dim not in target_chunks:
target_chunks[dim] = dimsize
# override with any specified chunks
target_chunks.update(specified_chunks or {})
if not include_all_dims:
# remove chunks with the same size as their dimension
dims_to_remove = [dim for dim, cs in target_chunks.items() if cs == schema["dims"][dim]]
for dim in dims_to_remove:
del target_chunks[dim]
return target_chunks


Expand Down
26 changes: 26 additions & 0 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DatasetCombineError,
XarrayCombineAccumulator,
dataset_to_schema,
determine_target_chunks,
schema_to_template_ds,
)

Expand Down Expand Up @@ -39,6 +40,31 @@ def test_schema_to_template_ds(specified_chunks):
assert schema == schema2


@pytest.mark.parametrize(
"specified_chunks",
[{}, {"time": 1}, {"time": 2}, {"time": 2, "lon": 9}, {"time": 3}, {"time": 3, "lon": 7}],
)
@pytest.mark.parametrize("include_all_dims", [True, False])
def test_determine_target_chunks(specified_chunks, include_all_dims):
nt = 3
ds = make_ds(nt=nt)
schema = dataset_to_schema(ds)

chunks = determine_target_chunks(schema, specified_chunks, include_all_dims)

if include_all_dims:
for name, default_chunk in schema["dims"].items():
assert name in chunks
if name in specified_chunks:
assert chunks[name] == specified_chunks[name]
else:
assert chunks[name] == default_chunk
else:
for name, cs in specified_chunks.items():
if name in chunks and name in schema["dims"]:
assert chunks[name] != schema["dims"][name]


def test_schema_to_template_ds_cftime():
ds = xr.decode_cf(
xr.DataArray(
Expand Down

0 comments on commit 5edf5ec

Please sign in to comment.