From 6ad80fd036834e7291c572335e318096781a7ae4 Mon Sep 17 00:00:00 2001
From: Haejoon Lee
Date: Thu, 23 Nov 2023 22:05:52 -0800
Subject: [PATCH] [SPARK-46084][PS] Refactor data type casting operation for
 Categorical type

### What changes were proposed in this pull request?

This PR proposes to refactor the data type casting operation - especially `DataTypeOps.astype` - for the Categorical type.

### Why are the changes needed?

To improve performance, debuggability, and readability by using official APIs. We can leverage the PySpark APIs `coalesce` and `create_map` instead of implementing the logic in Python from scratch.

### Does this PR introduce _any_ user-facing change?

No, it's an internal optimization.

### How was this patch tested?

The existing CI should pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43993 from itholic/refactor_cat.

Authored-by: Haejoon Lee
Signed-off-by: Dongjoon Hyun
---
 python/pyspark/pandas/data_type_ops/base.py | 26 ++++++--------------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 4f57aa65be7c1..5a4cd7a1eb070 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -18,6 +18,7 @@
 import numbers
 from abc import ABCMeta
 from typing import Any, Optional, Union
+from itertools import chain
 
 import numpy as np
 import pandas as pd
@@ -129,26 +130,11 @@ def _as_categorical_type(
         if len(categories) == 0:
             scol = F.lit(-1)
         else:
-            scol = F.lit(-1)
-            if isinstance(
-                index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
-            ):
-                from pyspark.sql.functions import base64
-
-                stringified_column = base64(index_ops.spark.column)
-                for code, category in enumerate(categories):
-                    # Convert each category to base64 before comparison
-                    base64_category = F.base64(F.lit(category))
-                    scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
-                        scol
-                    )
-            else:
-                stringified_column = F.format_string("%s", index_ops.spark.column)
-
-                for code, category in enumerate(categories):
-                    scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
-                        scol
-                    )
+            kvs = chain(
+                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
         return index_ops._with_new_scol(
             scol.cast(spark_type),
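
For illustration, here is a minimal standalone sketch of the `create_map` + `coalesce` pattern the patch adopts. The local SparkSession, the toy `fruit` DataFrame, and the `categories` list are assumptions introduced only for this example; they are not part of the patch:

```python
from itertools import chain

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Hypothetical data and category order, used only for this sketch.
df = spark.createDataFrame([("apple",), ("banana",), ("kiwi",)], ["fruit"])
categories = ["apple", "banana", "cherry"]

# Flatten (category, code) literal pairs into one argument list and build
# a single literal map column {category -> code}, as the patched code does.
kvs = chain(*[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)])
map_scol = F.create_map(*kvs)

# Look each value up in the map; values absent from the map yield NULL,
# which coalesce replaces with the sentinel code -1.
result = df.select("fruit", F.coalesce(map_scol[F.col("fruit")], F.lit(-1)).alias("code"))
result.show()
# apple -> 0, banana -> 1, kiwi -> -1 (not among the categories)
```

A single map lookup replaces the nested `when`/`otherwise` expressions that the old Python loop built up one category at a time, keeping the resulting Spark expression flat regardless of the number of categories.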