From 5ce3e1d6c21f7e6d539985d4cb229b7ce3ffa55d Mon Sep 17 00:00:00 2001 From: Varun Mittal Date: Tue, 10 Oct 2023 17:56:33 -0400 Subject: [PATCH] Implemented lookup operation for converting column values using a reference dataset. Signed-off-by: Varun Mittal --- .../focus_converter/configs/base_config.py | 18 +- .../conversion_functions/__init__.py | 3 + .../conversion_functions/lookup_function.py | 24 +++ .../focus_converter/converter.py | 16 ++ .../tests/converter_functions/__init__.py | 0 .../test_lookup_function.py | 156 ++++++++++++++++++ 6 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 focus_converter_base/focus_converter/conversion_functions/lookup_function.py create mode 100644 focus_converter_base/tests/converter_functions/__init__.py create mode 100644 focus_converter_base/tests/converter_functions/test_lookup_function.py diff --git a/focus_converter_base/focus_converter/configs/base_config.py b/focus_converter_base/focus_converter/configs/base_config.py index cb8fda70..b4170428 100644 --- a/focus_converter_base/focus_converter/configs/base_config.py +++ b/focus_converter_base/focus_converter/configs/base_config.py @@ -8,10 +8,10 @@ BaseModel, ConfigDict, Field, - FilePath, ValidationError, field_validator, ) +from pydantic import FilePath from pydantic_core.core_schema import FieldValidationInfo from pytz.exceptions import UnknownTimeZoneError from typing_extensions import Annotated @@ -25,6 +25,12 @@ class SQLConditionConversionArgs(BaseModel): default_value: Any +class LookupConversionArgs(BaseModel): + reference_dataset_path: FilePath + source_value: str + destination_value: str + + CONFIG_FILE_PATTERN = re.compile("D\d{3}_S\d{3}.yaml") @@ -58,7 +64,7 @@ class ConversionPlan(BaseModel): @field_validator("conversion_args") @classmethod - def double(cls, v: Any, field_info: FieldValidationInfo) -> str: + def conversion_args_validation(cls, v: Any, field_info: FieldValidationInfo) -> str: conversion_type: STATIC_CONVERSION_TYPES = field_info.data.get( "conversion_type" ) @@ -80,6 +86,14 @@ def double(cls, v: Any, field_info: FieldValidationInfo) -> str: raise ValueError( f"Invalid SQL condition specified in conversion_args for plan: {field_info.data}" ) + elif conversion_type == STATIC_CONVERSION_TYPES.LOOKUP: + try: + LookupConversionArgs.model_validate(v) + except ValidationError as e: + raise ValueError( + e, + f"Invalid lookup arg specified in conversion_args for plan: {field_info.data}", + ) return v @field_validator("column_prefix") diff --git a/focus_converter_base/focus_converter/conversion_functions/__init__.py b/focus_converter_base/focus_converter/conversion_functions/__init__.py index adcaec15..2749d41d 100644 --- a/focus_converter_base/focus_converter/conversion_functions/__init__.py +++ b/focus_converter_base/focus_converter/conversion_functions/__init__.py @@ -18,6 +18,9 @@ class STATIC_CONVERSION_TYPES(Enum): # unnest operation UNNEST_COLUMN = "unnest" + # lookup operation + LOOKUP = "lookup" + __all__ = [ "STATIC_CONVERSION_TYPES", diff --git a/focus_converter_base/focus_converter/conversion_functions/lookup_function.py b/focus_converter_base/focus_converter/conversion_functions/lookup_function.py new file mode 100644 index 00000000..d5905797 --- /dev/null +++ b/focus_converter_base/focus_converter/conversion_functions/lookup_function.py @@ -0,0 +1,24 @@ +from focus_converter.configs.base_config import ConversionPlan, LookupConversionArgs +import polars as pl + + +class LookupFunction: + @classmethod + def map_values_using_lookup(cls, plan: ConversionPlan, column_alias): + conversion_args = LookupConversionArgs.model_validate(plan.conversion_args) + reference_data_lf = ( + pl.scan_csv(conversion_args.reference_dataset_path) + .select([conversion_args.source_value, conversion_args.destination_value]) + .with_columns( + [ + pl.col(conversion_args.source_value).cast(pl.Utf8), + ] + ) + .rename({conversion_args.destination_value: column_alias}) + ) + return { + "other": reference_data_lf, + "left_on": plan.column, + "how": "left", + "right_on": conversion_args.source_value, + } diff --git a/focus_converter_base/focus_converter/converter.py b/focus_converter_base/focus_converter/converter.py index 229ab729..e20ebbb6 100644 --- a/focus_converter_base/focus_converter/converter.py +++ b/focus_converter_base/focus_converter/converter.py @@ -13,6 +13,7 @@ from focus_converter.conversion_functions.datetime_functions import ( DateTimeConversionFunctions, ) +from focus_converter.conversion_functions.lookup_function import LookupFunction from focus_converter.conversion_functions.sql_functions import SQLFunctions from focus_converter.data_loaders.data_exporter import DataExporter from focus_converter.data_loaders.data_loader import DataLoader @@ -84,6 +85,9 @@ def prepare_horizontal_conversion_plan(self, provider): # sql queries collected to be applied on the lazy frame self.h_sql_queries = sql_queries = [] + # lookup lazyframes arguments to be assembled later on the final source lazyframe + self.lookup_reference_args = [] + # add provider by default to our column expressions column_exprs.append(ColumnFunctions.add_provider(provider=provider)) @@ -144,6 +148,12 @@ def prepare_horizontal_conversion_plan(self, provider): column_exprs.append( ColumnFunctions.unnest(plan=plan, column_alias=column_alias) ) + elif plan.conversion_type == STATIC_CONVERSION_TYPES.LOOKUP: + self.lookup_reference_args.append( + LookupFunction.map_values_using_lookup( + plan=plan, column_alias=column_alias + ) + ) else: raise NotImplementedError( f"Plan: {plan.conversion_type} not implemented" @@ -166,6 +176,12 @@ def __apply_column_expressions__( lf = lf.with_columns_seq(expr) return lf + @staticmethod + def __apply_lookup_reference_plans__(lf: pl.LazyFrame, lookup_args): + for lookup_arg in lookup_args: + lf = lf.join(**lookup_arg) + return lf + def explain(self): # get batched data lazy frame, build the plan and then break return self.__network__.show_graph() diff --git a/focus_converter_base/tests/converter_functions/__init__.py b/focus_converter_base/tests/converter_functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/focus_converter_base/tests/converter_functions/test_lookup_function.py b/focus_converter_base/tests/converter_functions/test_lookup_function.py new file mode 100644 index 00000000..b1e095b2 --- /dev/null +++ b/focus_converter_base/tests/converter_functions/test_lookup_function.py @@ -0,0 +1,156 @@ +import os +import tempfile +from unittest import TestCase +from uuid import uuid4 + +import pandas as pd +import polars as pl +from jinja2 import Template +from pydantic import ValidationError + +from focus_converter.configs.base_config import ConversionPlan +from focus_converter.conversion_functions.lookup_function import LookupFunction +from focus_converter.converter import FocusConverter + +VALUE_LOOKUP_SAMPLE_TEMPLATE_YAML_JINJA = """ +plan_name: sample +priority: 1 +column: {{ random_column_alias }} +conversion_type: lookup +focus_column: Region +conversion_args: + reference_dataset_path: {{ test_reference_dataset_path }} + source_value: {{ source_value }} + destination_value: {{ destination_value }} +""" + + +VALUE_LOOKUP_SAMPLE_TEMPLATE_MISSING_VALUE_YAML = """ +plan_name: sample +priority: 1 +column: test_column +conversion_type: lookup +focus_column: Region +""" + +VALUE_MAPPING_SAMPLE_TEMPLATE_YAML = Template(VALUE_LOOKUP_SAMPLE_TEMPLATE_YAML_JINJA) + + +# noinspection DuplicatedCode +class TestMappingFunction(TestCase): + def test_map_not_defined(self): + with tempfile.TemporaryDirectory() as temp_dir: + sample_file_path = os.path.join(temp_dir, "D001_S001.yaml") + + with open(sample_file_path, "w") as fd: + fd.write(VALUE_LOOKUP_SAMPLE_TEMPLATE_MISSING_VALUE_YAML) + + with self.assertRaises(ValidationError) as cm: + ConversionPlan.load_yaml(sample_file_path) + self.assertEqual(len(cm.exception.errors()), 1) + self.assertEqual(cm.exception.errors()[0]["loc"], ("conversion_args",)) + + def test_lookup_value_with_bad_reference_data_path(self): + random_column_alias = str(uuid4()) + generated_yaml = VALUE_MAPPING_SAMPLE_TEMPLATE_YAML.render( + random_column_alias=random_column_alias, + test_reference_dataset_path="nonexistent_path", + ) + + with tempfile.TemporaryDirectory() as temp_dir: + sample_file_path = os.path.join(temp_dir, "D001_S001.yaml") + + with open(sample_file_path, "w") as fd: + fd.write(generated_yaml) + + with self.assertRaises(ValidationError) as cm: + ConversionPlan.load_yaml(sample_file_path) + self.assertEqual(len(cm.exception.errors()), 1) + self.assertEqual(cm.exception.errors()[0]["loc"], ("conversion_args",)) + + def test_lookup_value(self): + random_column_alias = str(uuid4()) + random_focus_colum = str(uuid4()) + + source_column_alias = str(uuid4()) + destination_column_alias = str(uuid4()) + + random_mapping_df = pd.DataFrame( + [ + { + source_column_alias: "1", + destination_column_alias: "1_mapped", + "ignore_column": "-", + }, + { + source_column_alias: "2", + destination_column_alias: "2_mapped", + "ignore_column": "-", + }, + { + source_column_alias: "3", + destination_column_alias: "3_mapped", + "ignore_column": "-", + }, + { + source_column_alias: "4", + destination_column_alias: "4_mapped", + "ignore_column": "-", + }, + ] + ) + with tempfile.NamedTemporaryFile(suffix=".csv") as mapping_csv: + random_mapping_df.to_csv(mapping_csv.name) + + generated_yaml = VALUE_MAPPING_SAMPLE_TEMPLATE_YAML.render( + random_column_alias=random_column_alias, + test_reference_dataset_path=mapping_csv.name, + source_value=source_column_alias, + destination_value=destination_column_alias, + ) + + df = pd.DataFrame( + [ + {"index_value": "1", random_column_alias: "1"}, + {"index_value": "2", random_column_alias: "2"}, + {"index_value": "3", random_column_alias: "3"}, + {"index_value": "4", random_column_alias: "4"}, + {"index_value": "5", random_column_alias: "5"}, + ] + ) + pl_df = pl.from_dataframe(df).lazy() + + with tempfile.TemporaryDirectory() as temp_dir: + sample_file_path = os.path.join(temp_dir, "D001_S001.yaml") + + with open(sample_file_path, "w") as fd: + fd.write(generated_yaml) + + conversion_plan = ConversionPlan.load_yaml(sample_file_path) + conversion_lookup_args = LookupFunction.map_values_using_lookup( + plan=conversion_plan, column_alias=random_focus_colum + ) + + modified_pl_df = FocusConverter.__apply_lookup_reference_plans__( + lf=pl_df, lookup_args=[conversion_lookup_args] + ).collect() + self.assertIn(random_column_alias, modified_pl_df.columns) + self.assertIn(random_focus_colum, modified_pl_df.columns) + self.assertIn("index_value", modified_pl_df.columns) + self.assertEqual(len(modified_pl_df.columns), 3) + + for index_value, _, mapped_value in modified_pl_df.iter_rows(): + if index_value == "1": + self.assertEqual(mapped_value, "1_mapped") + elif index_value == "2": + self.assertEqual(mapped_value, "2_mapped") + elif index_value == "3": + self.assertEqual(mapped_value, "3_mapped") + elif index_value == "4": + self.assertEqual(mapped_value, "4_mapped") + elif index_value == "5": + self.assertIsNone(mapped_value) + else: + raise self.failureException( + f"Invalid value, map function not mapped, key: {index_value}, value: {mapped_value}" + )