From feae56bb9bf3d5a5674ecfb4c9bcd566face4b2d Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Thu, 11 Jan 2024 07:29:15 -0700 Subject: [PATCH 1/4] Transform for cf_attrs --- data_management_utils/utils.py | 62 +++++++++++++++++++++++++++++++--- pyproject.toml | 1 + 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/data_management_utils/utils.py b/data_management_utils/utils.py index f3bb247..f475c8c 100644 --- a/data_management_utils/utils.py +++ b/data_management_utils/utils.py @@ -6,9 +6,66 @@ from google.api_core.exceptions import NotFound import apache_beam as beam import zarr +import xarray as xr import datetime from dataclasses import dataclass +import dask +#---------------------------------------------------------------- +# ----------------- Functions ----------------------------------- +#---------------------------------------------------------------- +# wrapper functions (not sure if this works instead of the repeated copy and paste in the transform below) +def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str): + bq_interface = BQInterface(table_id=table_id) + iid_entry = IIDEntry(iid=iid, store=store.path) + bq_interface.insert(iid_entry) + + + + + +#---------------------------------------------------------------- +# ----------------- Transforms ---------------------------------- +#---------------------------------------------------------------- + +@dataclass +class ValidateCFConventions(beam.PTransform): + """ + Transform to validate CF conventions + """ + @staticmethod + def _get_dataset(store: zarr.storage.FSStore) -> xr.Dataset: + import xarray as xr + return xr.open_dataset(store, engine='zarr', chunks={}, use_cftime=True) + + @staticmethod + def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: + """Retrieve the CF dimensions from the dataset""" + import cf_xarray # noqa + + results = {} + for variable in ds.variables: + axes = ds[variable].cf.axes + for key, value in axes.items(): + axes[key] = value[0] + results[variable] = axes + return results + + def _test_attributes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: + ds = self._get_dataset(store) + cf_attrs = self._retrieve_CF_axes(ds) + + # How should we validate cf_attrs from here? + + import pdb; pdb.set_trace() + + return store + + def expand(self, pcoll: beam.PCollection) -> beam.PCollection: + return (pcoll + | "Testing - Check CF Attrs" >> beam.Map(self._test_attributes) + ) + @dataclass class IIDEntry: """Single row/entry for an iid @@ -160,11 +217,6 @@ def iid_list_exists(self, iids: List[str]) -> List[str]: return list(set([r["instance_id"] for r in results])) -# wrapper functions (not sure if this works instead of the repeated copy and paste in the transform below) -def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str): - bq_interface = BQInterface(table_id=table_id) - iid_entry = IIDEntry(iid=iid, store=store.path) - bq_interface.insert(iid_entry) @dataclass diff --git a/pyproject.toml b/pyproject.toml index f7e5533..244c313 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ ] dependencies = [ "apache-beam", + "cf_xarray", "google-cloud-bigquery", "google-api-core", "zarr", From 981e1b35420170bfb7b186832c92501bb07db786 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Fri, 12 Jan 2024 16:13:01 -0700 Subject: [PATCH 2/4] cf_axes validation --- data_management_utils/utils.py | 65 ++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/data_management_utils/utils.py b/data_management_utils/utils.py index f475c8c..b4e6478 100644 --- a/data_management_utils/utils.py +++ b/data_management_utils/utils.py @@ -6,13 +6,14 @@ from google.api_core.exceptions import NotFound import apache_beam as beam import zarr -import xarray as xr +import xarray as xr import datetime from dataclasses import dataclass -import dask -#---------------------------------------------------------------- + + +# ---------------------------------------------------------------- # ----------------- Functions ----------------------------------- -#---------------------------------------------------------------- +# ---------------------------------------------------------------- # wrapper functions (not sure if this works instead of the repeated copy and paste in the transform below) def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str): bq_interface = BQInterface(table_id=table_id) @@ -20,24 +21,23 @@ def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str): bq_interface.insert(iid_entry) - - - -#---------------------------------------------------------------- +# ---------------------------------------------------------------- # ----------------- Transforms ---------------------------------- -#---------------------------------------------------------------- +# ---------------------------------------------------------------- @dataclass class ValidateCFConventions(beam.PTransform): """ - Transform to validate CF conventions + Transform to validate CF conventions """ + @staticmethod def _get_dataset(store: zarr.storage.FSStore) -> xr.Dataset: import xarray as xr - return xr.open_dataset(store, engine='zarr', chunks={}, use_cftime=True) - + + return xr.open_dataset(store, engine="zarr", chunks={}, use_cftime=True) + @staticmethod def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: """Retrieve the CF dimensions from the dataset""" @@ -50,22 +50,45 @@ def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: axes[key] = value[0] results[variable] = axes return results - + def _test_attributes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: + """cf_axes validation""" + import cf_xarray # noqa + + ds = self._get_dataset(store) + cf_axes_dict = ds.cf.axes + # Check X coord + assert cf_axes_dict.get( + "X" + ), "According to cf_xarray, this dataset is missing an X Axis" + # Check Y coord + assert cf_axes_dict.get( + "Y" + ), "According to cf_xarray, this dataset is missing a Y Axis" + # Check that data variables contains Axes + data_vars = [var for var in ds.data_vars] + for var in data_vars: + assert ds[var].cf.axes["X"], f"{var} does not have an X axis" + assert ds[var].cf.axes["Y"], f"{var} does not have a Y axis" + + return store + + def _update_cf_axes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: + """Updates the ds attrs with cf_axes information""" ds = self._get_dataset(store) cf_attrs = self._retrieve_CF_axes(ds) + ds.attrs["cf_axes"] = cf_attrs - # How should we validate cf_attrs from here? - - import pdb; pdb.set_trace() - return store - + def expand(self, pcoll: beam.PCollection) -> beam.PCollection: - return (pcoll + return ( + pcoll | "Testing - Check CF Attrs" >> beam.Map(self._test_attributes) + | "Update CF_Axes" >> beam.Map(self._update_cf_axes) ) - + + @dataclass class IIDEntry: """Single row/entry for an iid @@ -217,8 +240,6 @@ def iid_list_exists(self, iids: List[str]) -> List[str]: return list(set([r["instance_id"] for r in results])) - - @dataclass class LogToBigQuery(beam.PTransform): """ From 35f69125f64de60eab19e787dd7c05041a306dfd Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Sun, 14 Jan 2024 18:46:54 -0700 Subject: [PATCH 3/4] updated naming of _validate_cf_attributes --- data_management_utils/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/data_management_utils/utils.py b/data_management_utils/utils.py index b4e6478..0dcebdc 100644 --- a/data_management_utils/utils.py +++ b/data_management_utils/utils.py @@ -51,7 +51,9 @@ def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: results[variable] = axes return results - def _test_attributes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: + def _validate_cf_attributes( + self, store: zarr.storage.FSStore + ) -> zarr.storage.FSStore: """cf_axes validation""" import cf_xarray # noqa @@ -84,7 +86,7 @@ def _update_cf_axes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: def expand(self, pcoll: beam.PCollection) -> beam.PCollection: return ( pcoll - | "Testing - Check CF Attrs" >> beam.Map(self._test_attributes) + | "Testing - Check CF Attrs" >> beam.Map(self._validate_cf_attributes) | "Update CF_Axes" >> beam.Map(self._update_cf_axes) ) From f98df5b99f5defcb209f738088e33467a97ab0f9 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Mon, 15 Jan 2024 10:04:01 -0700 Subject: [PATCH 4/4] moved _retrieve_CF_axes to top level --- data_management_utils/utils.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/data_management_utils/utils.py b/data_management_utils/utils.py index 0dcebdc..0a17965 100644 --- a/data_management_utils/utils.py +++ b/data_management_utils/utils.py @@ -21,6 +21,19 @@ def log_to_bq(iid: str, store: zarr.storage.FSStore, table_id: str): bq_interface.insert(iid_entry) +def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: + """Retrieve the CF dimensions from the dataset""" + import cf_xarray # noqa + + results = {} + for variable in ds.variables: + axes = ds[variable].cf.axes + for key, value in axes.items(): + axes[key] = value[0] + results[variable] = axes + return results + + # ---------------------------------------------------------------- # ----------------- Transforms ---------------------------------- # ---------------------------------------------------------------- @@ -38,19 +51,6 @@ def _get_dataset(store: zarr.storage.FSStore) -> xr.Dataset: return xr.open_dataset(store, engine="zarr", chunks={}, use_cftime=True) - @staticmethod - def _retrieve_CF_axes(ds: xr.Dataset) -> dict[str, dict[str, str]]: - """Retrieve the CF dimensions from the dataset""" - import cf_xarray # noqa - - results = {} - for variable in ds.variables: - axes = ds[variable].cf.axes - for key, value in axes.items(): - axes[key] = value[0] - results[variable] = axes - return results - def _validate_cf_attributes( self, store: zarr.storage.FSStore ) -> zarr.storage.FSStore: @@ -78,7 +78,7 @@ def _validate_cf_attributes( def _update_cf_axes(self, store: zarr.storage.FSStore) -> zarr.storage.FSStore: """Updates the ds attrs with cf_axes information""" ds = self._get_dataset(store) - cf_attrs = self._retrieve_CF_axes(ds) + cf_attrs = _retrieve_CF_axes(ds) ds.attrs["cf_axes"] = cf_attrs return store