[SPARK-46084][PS] Refactor data type casting operation for Categorical type

### What changes were proposed in this pull request?

The PR proposes to refactor the data type casting operation, especially `DataTypeOps.astype`, for the Categorical type.

### Why are the changes needed?

To improve performance, debuggability, and readability by using official APIs. Instead of implementing the mapping logic from scratch in Python, we can leverage the PySpark functions `coalesce` and `create_map`, as sketched below.
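
For illustration, here is a minimal, self-contained sketch of the technique (the `SparkSession` setup, the `df` DataFrame, and its `value` column are hypothetical examples, not part of this patch):

```python
from itertools import chain

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a",), ("b",), ("x",)], ["value"])

categories = ["a", "b", "c"]

# Build a single literal map column, {"a": 0, "b": 1, "c": 2}, by flattening
# alternating key/value literals into create_map.
kvs = chain(*[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)])
map_scol = F.create_map(*kvs)

# Looking up a key that is absent from the map yields NULL;
# coalesce replaces it with -1, pandas' code for an unknown category.
df.select(F.coalesce(map_scol[df["value"]], F.lit(-1)).alias("code")).show()
# +----+
# |code|
# +----+
# |   0|
# |   1|
# |  -1|
# +----+
```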

### Does this PR introduce _any_ user-facing change?

No, it is an internal optimization.

### How was this patch tested?

The existing CI should pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#43993 from itholic/refactor_cat.

Authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
itholic authored and dongjoon-hyun committed Nov 24, 2023
1 parent 1131a56 commit 6ad80fd
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions python/pyspark/pandas/data_type_ops/base.py
```diff
@@ -18,6 +18,7 @@
 import numbers
 from abc import ABCMeta
 from typing import Any, Optional, Union
+from itertools import chain
 
 import numpy as np
 import pandas as pd
@@ -129,26 +130,11 @@ def _as_categorical_type(
         if len(categories) == 0:
             scol = F.lit(-1)
         else:
-            scol = F.lit(-1)
-            if isinstance(
-                index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
-            ):
-                from pyspark.sql.functions import base64
-
-                stringified_column = base64(index_ops.spark.column)
-                for code, category in enumerate(categories):
-                    # Convert each category to base64 before comparison
-                    base64_category = F.base64(F.lit(category))
-                    scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
-                        scol
-                    )
-            else:
-                stringified_column = F.format_string("%s", index_ops.spark.column)
-
-                for code, category in enumerate(categories):
-                    scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
-                        scol
-                    )
+            kvs = chain(
+                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
 
         return index_ops._with_new_scol(
             scol.cast(spark_type),
```
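
Design note: the previous implementation grew one nested `CASE WHEN` per category, with a separate base64-encoded comparison path for `BinaryType` columns. The refactored version instead evaluates a single map lookup per row and handles the out-of-category case uniformly, since a missing key yields `NULL` and `coalesce` turns that into the sentinel code `-1`.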
