Skip to content

Commit

Permalink
🐛 ✅ Fix KeyError on MemoryDataset outputs & add tests (#69)
Browse files: browse the repository at this point in the history
  • Loading branch information
michal-mmm authored and Galileo-Galilei committed Jun 3, 2024
1 parent e40fd99 commit 9c0bb10
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
3 changes: 2 additions & 1 deletion kedro_pandera/framework/hooks/pandera_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def _validate_datasets(
self, node: Node, catalog: DataCatalog, datasets: Dict[str, Any]
):
for name, data in datasets.items():
metadata = getattr(catalog._datasets[name], "metadata", None)
dataset = catalog._datasets.get(name)
metadata = getattr(dataset, "metadata", None)
if (
metadata is not None
and "pandera" in metadata
Expand Down
28 changes: 26 additions & 2 deletions tests/framework/hooks/test_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest
from kedro.io import DataCatalog
from kedro.pipeline import node
from kedro.framework.hooks import _create_hook_manager
from kedro.framework.hooks.manager import _register_hooks
from kedro.io import DataCatalog, LambdaDataset
from kedro.pipeline import node, pipeline
from kedro.runner import SequentialRunner
from kedro_datasets.pandas import CSVDataset
from pandera.errors import SchemaError
from pandera.io import from_yaml
Expand Down Expand Up @@ -118,3 +121,24 @@ def test_validate_only_once(caplog):
)
# should only be validated once
assert caplog.text.count("successfully validated") == 1


def test_no_exception_on_memory_dataset_output():
    """Run a two-node pipeline whose intermediate dataset is not in the catalog.

    Only ``Input`` and ``Output`` are declared in the catalog; ``MemOutput``
    is produced by the first node and consumed by the second without being
    registered. The test passes when the hook is registered and the run
    completes without raising.
    """
    hook_manager = _create_hook_manager()
    pandera_hook = _get_test_hook()
    _register_hooks(hook_manager, (pandera_hook,))

    # Both declared datasets load a constant string and discard whatever
    # is saved to them.
    catalog = DataCatalog(
        {
            dataset_name: LambdaDataset(
                load=lambda: "data", save=lambda data: None
            )
            for dataset_name in ("Input", "Output")
        }
    )

    first_node = node(
        func=lambda x: x, inputs="Input", outputs="MemOutput", name="node1"
    )
    second_node = node(
        func=lambda x: x, inputs="MemOutput", outputs="Output", name="node2"
    )
    two_step_pipeline = pipeline([first_node, second_node])

    assert hook_manager.is_registered(pandera_hook)
    SequentialRunner().run(
        two_step_pipeline, catalog, hook_manager=hook_manager
    )

0 comments on commit 9c0bb10

Please sign in to comment.