Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform for cf_attrs #6

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 80 additions & 7 deletions data_management_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,89 @@
from google.api_core.exceptions import NotFound
import apache_beam as beam
import zarr
import xarray as xr
import datetime
from dataclasses import dataclass


# ----------------------------------------------------------------
# ----------------- Functions -----------------------------------
# ----------------------------------------------------------------
# Wrapper function. TODO: confirm this can replace the repeated copy-and-paste code in the transform below.
def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str):
    """Insert a single entry for `iid` through a `BQInterface` on `table_id`.

    The entry is an `IIDEntry` built from `iid` and `store.path`.
    """
    # The interface is constructed before the entry, so a failure while
    # setting up the interface surfaces first.
    BQInterface(table_id=table_id).insert(IIDEntry(iid=iid, store=store.path))


# ----------------------------------------------------------------
# ----------------- Transforms ----------------------------------
# ----------------------------------------------------------------


@dataclass
class ValidateCFConventions(beam.PTransform):
    """
    Transform to validate CF conventions.

    Each element of the input PCollection is treated as a zarr store. The
    store is opened as an xarray dataset, checked with cf_xarray for X and Y
    axes (on the dataset and on every data variable), and then emitted
    unchanged.
    """

    @staticmethod
    def _get_dataset(store: zarr.storage.FSStore) -> xr.Dataset:
        """Open `store` with the zarr engine, `chunks={}` and cftime decoding."""
        # Function-scope import kept from the original implementation.
        import xarray as xr

        return xr.open_dataset(store, engine="zarr", chunks={}, use_cftime=True)

    @staticmethod
    def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]:
        """Retrieve the CF axes of every variable in the dataset.

        Returns:
            A dict keyed by variable name. Each value maps an axis key
            reported by cf_xarray to the first name listed for that axis.
        """
        # Imported so that the `.cf` accessor is available on xarray objects.
        import cf_xarray  # noqa

        # Build fresh dicts instead of overwriting the accessor's result
        # while iterating over it.
        return {
            variable: {
                axis: names[0] for axis, names in ds[variable].cf.axes.items()
            }
            for variable in ds.variables
        }

    def _test_attributes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore:
        """Validate that the dataset and all data variables have X and Y axes.

        Args:
            store: zarr store to open and validate.

        Returns:
            The same `store`, so the check can be chained in a pipeline.

        Raises:
            AssertionError: if the dataset, or any of its data variables,
                has no X or no Y axis according to cf_xarray.
        """
        # Imported so that the `.cf` accessor is available on xarray objects.
        import cf_xarray  # noqa

        ds = self._get_dataset(store)
        dataset_axes = ds.cf.axes

        # Raise explicitly rather than using `assert`, so the validation is
        # still performed when Python runs with assertions disabled (-O).
        if not dataset_axes.get("X"):
            raise AssertionError(
                "According to cf_xarray, this dataset is missing an X Axis"
            )
        if not dataset_axes.get("Y"):
            raise AssertionError(
                "According to cf_xarray, this dataset is missing a Y Axis"
            )

        # Check that every data variable carries both axes. `.get` is used
        # because a missing axis is simply absent from the mapping; indexing
        # would raise a bare KeyError and hide the message below.
        for var in ds.data_vars:
            var_axes = ds[var].cf.axes
            if not var_axes.get("X"):
                raise AssertionError(f"{var} does not have an X axis")
            if not var_axes.get("Y"):
                raise AssertionError(f"{var} does not have a Y axis")

        return store

    def _update_cf_axes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore:
        """Set `attrs["cf_axes"]` on the dataset opened from `store`.

        NOTE(review): only the in-memory dataset object is modified here and
        nothing in this method writes back to `store`, so the update looks
        like it is discarded when the method returns. Confirm whether the
        attrs are meant to be persisted to the zarr store.
        """
        ds = self._get_dataset(store)
        ds.attrs["cf_axes"] = self._retrieve_CF_axes(ds)

        return store

    def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
        """Apply the attribute check, then the cf_axes update, to each element."""
        return (
            pcoll
            | "Testing - Check CF Attrs" >> beam.Map(self._test_attributes)
            | "Update CF_Axes" >> beam.Map(self._update_cf_axes)
        )


@dataclass
class IIDEntry:
"""Single row/entry for an iid
Expand Down Expand Up @@ -160,13 +240,6 @@ def iid_list_exists(self, iids: List[str]) -> List[str]:
return list(set([r["instance_id"] for r in results]))


# Wrapper function. TODO: confirm this can replace the repeated copy-and-paste code in the transform below.
def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str):
    """Write one `IIDEntry` (from `iid` and `store.path`) via `BQInterface`.

    The `BQInterface` is bound to `table_id` and receives the entry through
    its `insert` method.
    """
    interface = BQInterface(table_id=table_id)
    interface.insert(IIDEntry(iid=iid, store=store.path))


@dataclass
class LogToBigQuery(beam.PTransform):
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ classifiers = [
]
dependencies = [
"apache-beam",
"cf_xarray",
"google-cloud-bigquery",
"google-api-core",
"zarr",
Expand Down