[SPARK-18621][PYTHON] Make sql type reprs eval-able
### What changes were proposed in this pull request?

These changes update the `__repr__` methods of the type classes in `pyspark.sql.types` so that they return string representations which are `eval`-able. In other words, any instance of a `DataType` produces a repr which can be passed to `eval()` to create an identical instance.

Similar changes previously submitted: #25495

### Why are the changes needed?

This [bug](https://issues.apache.org/jira/browse/SPARK-18621) has been around for a while. The current implementation returns a string representation which is valid in Scala rather than Python. These changes fix the repr to be valid Python.

The [motivation](https://docs.python.org/3/library/functions.html#repr) comes from the Python documentation for `repr`: the goal is "to return a string that would yield an object with the same value when passed to eval()".

### Does this PR introduce _any_ user-facing change?

Example:

Current implementation:

```python
from pyspark.sql.types import *

struct = StructType([StructField('f1', StringType(), True)])
repr(struct)
# StructType(List(StructField(f1,StringType,true)))
new_struct = eval(repr(struct))
# Traceback (most recent call last):
#   File "<input>", line 1, in <module>
#   File "<string>", line 1, in <module>
# NameError: name 'List' is not defined

struct_field = StructField('f1', StringType(), True)
repr(struct_field)
# StructField(f1,StringType,true)
new_struct_field = eval(repr(struct_field))
# Traceback (most recent call last):
#   File "<input>", line 1, in <module>
#   File "<string>", line 1, in <module>
# NameError: name 'f1' is not defined
```

With changes:

```python
from pyspark.sql.types import *

struct = StructType([StructField('f1', StringType(), True)])
repr(struct)
# StructType([StructField('f1', StringType(), True)])
new_struct = eval(repr(struct))
struct == new_struct
# True

struct_field = StructField('f1', StringType(), True)
repr(struct_field)
# StructField('f1', StringType(), True)
new_struct_field = eval(repr(struct_field))
struct_field == new_struct_field
# True
```
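The round trip also holds for nested, parameterized types. A minimal sketch under the new behavior (the output comments mirror the updated doctests in this diff; this particular `MapType` instance is an illustrative assumption, not an example from the patch):

```python
from pyspark.sql.types import *

# An illustrative nested type; parameterized types like DecimalType
# include their parameters in the repr.
nested = MapType(StringType(), ArrayType(DecimalType(38, 18), False), True)
repr(nested)
# MapType(StringType(), ArrayType(DecimalType(38,18), False), True)
eval(repr(nested)) == nested
# True
```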

### How was this patch tested?

The changes include a test which asserts that an instance of each type is equal to the `eval` of its `repr`, as in the above example.
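A minimal sketch of that round-trip assertion (illustrative only; the set of instances below is an assumption, and the test added by the patch may cover more types or be parametrized differently):

```python
from pyspark.sql.types import *

# Each instance should compare equal to the eval of its repr.
instances = [
    StringType(),
    IntegerType(),
    DoubleType(),
    DecimalType(10, 2),
    ArrayType(StringType(), False),
    MapType(StringType(), IntegerType(), True),
    StructField('f1', StringType(), True),
    StructType([StructField('f1', StringType(), True)]),
]
for instance in instances:
    assert eval(repr(instance)) == instance
```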

Closes #34320 from crflynn/sql-types-repr.

Lead-authored-by: flynn <crf204@gmail.com>
Co-authored-by: Flynn <crflynn@users.noreply.github.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
(cherry picked from commit c5ebdc6)
Signed-off-by: Sean Owen <srowen@gmail.com>
2 people authored and srowen committed Mar 23, 2022
1 parent 737077a commit eb5d8fa
Showing 12 changed files with 135 additions and 99 deletions.
1 change: 1 addition & 0 deletions python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst
@@ -23,3 +23,4 @@ Upgrading from PySpark 3.2 to 3.3
* In Spark 3.3, the ``pyspark.pandas.sql`` method follows [the standard Python string formatter](https://docs.python.org/3/library/string.html#format-string-syntax). To restore the previous behavior, set ``PYSPARK_PANDAS_SQL_LEGACY`` environment variable to ``1``.
* In Spark 3.3, the ``drop`` method of pandas API on Spark DataFrame supports dropping rows by ``index``, and sets dropping by index instead of column by default.
* In Spark 3.3, PySpark upgrades Pandas version, the new minimum required version changes from 0.23.2 to 1.0.5.
+* In Spark 3.3, the ``repr`` return values of SQL DataTypes have been changed to yield an object with the same value when passed to ``eval``.
8 changes: 4 additions & 4 deletions python/pyspark/ml/functions.py
@@ -58,11 +58,11 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
[Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]),
Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])]
>>> df1.schema.fields
-[StructField(vec,ArrayType(DoubleType,false),false),
-StructField(oldVec,ArrayType(DoubleType,false),false)]
+[StructField('vec', ArrayType(DoubleType(), False), False),
+StructField('oldVec', ArrayType(DoubleType(), False), False)]
>>> df2.schema.fields
-[StructField(vec,ArrayType(FloatType,false),false),
-StructField(oldVec,ArrayType(FloatType,false),false)]
+[StructField('vec', ArrayType(FloatType(), False), False),
+StructField('oldVec', ArrayType(FloatType(), False), False)]
"""
sc = SparkContext._active_spark_context
assert sc is not None and sc._jvm is not None
8 changes: 4 additions & 4 deletions python/pyspark/pandas/extensions.py
@@ -109,7 +109,7 @@ def __init__(self, pandas_on_spark_obj):
...
Traceback (most recent call last):
...
-ValueError: Cannot call DatetimeMethods on type StringType
+ValueError: Cannot call DatetimeMethods on type StringType()
Note: This function is not meant to be used directly - instead, use register_dataframe_accessor,
register_series_accessor, or register_index_accessor.
@@ -169,7 +169,7 @@ def register_dataframe_accessor(name: str) -> Callable[[Type[T]], Type[T]]:
...
Traceback (most recent call last):
...
-ValueError: Cannot call DatetimeMethods on type StringType
+ValueError: Cannot call DatetimeMethods on type StringType()
Examples
--------
@@ -250,7 +250,7 @@ def __init__(self, pandas_on_spark_obj):
...
Traceback (most recent call last):
...
-ValueError: Cannot call DatetimeMethods on type StringType
+ValueError: Cannot call DatetimeMethods on type StringType()
Examples
--------
@@ -322,7 +322,7 @@ def __init__(self, pandas_on_spark_obj):
...
Traceback (most recent call last):
...
-ValueError: Cannot call DatetimeMethods on type StringType
+ValueError: Cannot call DatetimeMethods on type StringType()
Examples
--------
78 changes: 43 additions & 35 deletions python/pyspark/pandas/internal.py
@@ -206,7 +206,7 @@ def __eq__(self, other: Any) -> bool:
)

def __repr__(self) -> str:
return "InternalField(dtype={dtype},struct_field={struct_field})".format(
return "InternalField(dtype={dtype}, struct_field={struct_field})".format(
dtype=self.dtype, struct_field=self.struct_field
)

@@ -293,13 +293,13 @@ class InternalFrame:
>>> internal.index_names
[None]
>>> internal.data_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=int64,struct_field=StructField(A,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(B,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(C,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(D,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(E,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('A', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('B', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('C', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('D', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('E', LongType(), False))]
>>> internal.index_fields
-[InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('__index_level_0__', LongType(), False))]
>>> internal.to_internal_spark_frame.show() # doctest: +NORMALIZE_WHITESPACE
+-----------------+---+---+---+---+---+
|__index_level_0__| A| B| C| D| E|
@@ -355,13 +355,13 @@ class InternalFrame:
['A', 'B', 'C', 'D', 'E']
>>> internal.index_names
[('A',)]
>>> internal.data_fields
[InternalField(dtype=int64,struct_field=StructField(B,LongType,false)),
InternalField(dtype=int64,struct_field=StructField(C,LongType,false)),
InternalField(dtype=int64,struct_field=StructField(D,LongType,false)),
InternalField(dtype=int64,struct_field=StructField(E,LongType,false))]
>>> internal.data_fields # doctest: +NORMALIZE_WHITESPACE
[InternalField(dtype=int64, struct_field=StructField('B', LongType(), False)),
InternalField(dtype=int64, struct_field=StructField('C', LongType(), False)),
InternalField(dtype=int64, struct_field=StructField('D', LongType(), False)),
InternalField(dtype=int64, struct_field=StructField('E', LongType(), False))]
>>> internal.index_fields
-[InternalField(dtype=int64,struct_field=StructField(A,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('A', LongType(), False))]
>>> internal.to_internal_spark_frame.show() # doctest: +NORMALIZE_WHITESPACE
+---+---+---+---+---+
| A| B| C| D| E|
@@ -419,13 +419,13 @@ class InternalFrame:
>>> internal.index_names
[None, ('A',)]
>>> internal.data_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=int64,struct_field=StructField(B,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(C,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(D,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(E,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('B', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('C', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('D', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('E', LongType(), False))]
>>> internal.index_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,false)),
-InternalField(dtype=int64,struct_field=StructField(A,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('__index_level_0__', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('A', LongType(), False))]
>>> internal.to_internal_spark_frame.show() # doctest: +NORMALIZE_WHITESPACE
+-----------------+---+---+---+---+---+
|__index_level_0__| A| B| C| D| E|
@@ -508,9 +508,9 @@ class InternalFrame:
>>> internal.index_names
[('A',)]
>>> internal.data_fields
-[InternalField(dtype=int64,struct_field=StructField(B,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('B', LongType(), False))]
>>> internal.index_fields
-[InternalField(dtype=int64,struct_field=StructField(A,LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('A', LongType(), False))]
>>> internal.to_internal_spark_frame.show() # doctest: +NORMALIZE_WHITESPACE
+---+---+
| A| B|
@@ -596,9 +596,12 @@ def __init__(
[('row_index_a',), ('row_index_b',), ('a', 'x')]
>>> internal.index_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=object,struct_field=StructField(__index_level_0__,StringType,false)),
-InternalField(dtype=object,struct_field=StructField(__index_level_1__,StringType,false)),
-InternalField(dtype=int64,struct_field=StructField((a, x),LongType,false))]
+[InternalField(dtype=object,
+struct_field=StructField('__index_level_0__', StringType(), False)),
+InternalField(dtype=object,
+struct_field=StructField('__index_level_1__', StringType(), False)),
+InternalField(dtype=int64,
+struct_field=StructField('(a, x)', LongType(), False))]
>>> internal.column_labels
[('a', 'y'), ('b', 'z')]
@@ -607,8 +610,8 @@ def __init__(
[Column<'(a, y)'>, Column<'(b, z)'>]
>>> internal.data_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=int64,struct_field=StructField((a, y),LongType,false)),
-InternalField(dtype=int64,struct_field=StructField((b, z),LongType,false))]
+[InternalField(dtype=int64, struct_field=StructField('(a, y)', LongType(), False)),
+InternalField(dtype=int64, struct_field=StructField('(b, z)', LongType(), False))]
>>> internal.column_label_names
[('column_labels_a',), ('column_labels_b',)]
@@ -1505,13 +1508,14 @@ def prepare_pandas_frame(
2 30 c 1
>>> index_columns
['__index_level_0__']
>>> index_fields
[InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,false))]
>>> index_fields # doctest: +NORMALIZE_WHITESPACE
[InternalField(dtype=int64, struct_field=StructField('__index_level_0__',
LongType(), False))]
>>> data_columns
['(x, a)', '(y, b)']
>>> data_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=object,struct_field=StructField((x, a),StringType,false)),
-InternalField(dtype=category,struct_field=StructField((y, b),ByteType,false))]
+[InternalField(dtype=object, struct_field=StructField('(x, a)', StringType(), False)),
+InternalField(dtype=category, struct_field=StructField('(y, b)', ByteType(), False))]
>>> import datetime
>>> pdf = pd.DataFrame({
@@ -1521,9 +1525,11 @@ def prepare_pandas_frame(
>>> _, _, _, _, data_fields = (
... InternalFrame.prepare_pandas_frame(pdf, prefer_timestamp_ntz=True)
... )
>>> data_fields
[InternalField(dtype=datetime64[ns],struct_field=StructField(dt,TimestampNTZType,false)),
InternalField(dtype=object,struct_field=StructField(dt_obj,TimestampNTZType,false))]
>>> data_fields # doctest: +NORMALIZE_WHITESPACE
[InternalField(dtype=datetime64[ns],
struct_field=StructField('dt', TimestampNTZType(), False)),
InternalField(dtype=object,
struct_field=StructField('dt_obj', TimestampNTZType(), False))]
>>> pdf = pd.DataFrame({
... "td": [datetime.timedelta(0)], "td_obj": [datetime.timedelta(0)]
@@ -1533,8 +1539,10 @@ def prepare_pandas_frame(
... InternalFrame.prepare_pandas_frame(pdf)
... )
>>> data_fields # doctest: +NORMALIZE_WHITESPACE
-[InternalField(dtype=timedelta64[ns],struct_field=StructField(td,DayTimeIntervalType(0,3),false)),
-InternalField(dtype=object,struct_field=StructField(td_obj,DayTimeIntervalType(0,3),false))]
+[InternalField(dtype=timedelta64[ns],
+struct_field=StructField('td', DayTimeIntervalType(0, 3), False)),
+InternalField(dtype=object,
+struct_field=StructField('td_obj', DayTimeIntervalType(0, 3), False))]
"""
pdf = pdf.copy()

23 changes: 15 additions & 8 deletions python/pyspark/pandas/spark/utils.py
@@ -52,7 +52,7 @@ def as_nullable_spark_type(dt: DataType) -> DataType:
>>> as_nullable_spark_type(StructType([
... StructField("A", IntegerType(), True),
... StructField("B", FloatType(), False)])) # doctest: +NORMALIZE_WHITESPACE
-StructType(List(StructField(A,IntegerType,true),StructField(B,FloatType,true)))
+StructType([StructField('A', IntegerType(), True), StructField('B', FloatType(), True)])
>>> as_nullable_spark_type(StructType([
... StructField("A",
@@ -62,9 +62,12 @@ def as_nullable_spark_type(dt: DataType) -> DataType:
... ArrayType(IntegerType(), False), False), False),
... StructField('b', StringType(), True)])),
... StructField("B", FloatType(), False)])) # doctest: +NORMALIZE_WHITESPACE
-StructType(List(StructField(A,StructType(List(StructField(a,MapType(IntegerType,ArrayType\
-(IntegerType,true),true),true),StructField(b,StringType,true))),true),\
-StructField(B,FloatType,true)))
+StructType([StructField('A',
+StructType([StructField('a',
+MapType(IntegerType(),
+ArrayType(IntegerType(), True), True), True),
+StructField('b', StringType(), True)]), True),
+StructField('B', FloatType(), True)])
"""
if isinstance(dt, StructType):
new_fields = []
@@ -132,7 +135,8 @@ def force_decimal_precision_scale(
>>> force_decimal_precision_scale(StructType([
... StructField("A", DecimalType(10, 0), True),
... StructField("B", DecimalType(14, 7), False)])) # doctest: +NORMALIZE_WHITESPACE
-StructType(List(StructField(A,DecimalType(38,18),true),StructField(B,DecimalType(38,18),false)))
+StructType([StructField('A', DecimalType(38,18), True),
+StructField('B', DecimalType(38,18), False)])
>>> force_decimal_precision_scale(StructType([
... StructField("A",
@@ -143,9 +147,12 @@ def force_decimal_precision_scale(
... StructField('b', StringType(), True)])),
... StructField("B", DecimalType(30, 15), False)]),
... precision=30, scale=15) # doctest: +NORMALIZE_WHITESPACE
-StructType(List(StructField(A,StructType(List(StructField(a,MapType(DecimalType(30,15),\
-ArrayType(DecimalType(30,15),false),false),false),StructField(b,StringType,true))),true),\
-StructField(B,DecimalType(30,15),false)))
+StructType([StructField('A',
+StructType([StructField('a',
+MapType(DecimalType(30,15),
+ArrayType(DecimalType(30,15), False), False), False),
+StructField('b', StringType(), True)]), True),
+StructField('B', DecimalType(30,15), False)])
"""
if isinstance(dt, StructType):
new_fields = []
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/test_groupby.py
@@ -2227,7 +2227,7 @@ def udf(col) -> int:
with self.assertRaisesRegex(
TypeError,
"Expected the return type of this function to be of Series type, "
"but found type ScalarType\\[LongType\\]",
"but found type ScalarType\\[LongType\\(\\)\\]",
):
psdf.groupby("a").transform(udf)

2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/test_series.py
@@ -2985,7 +2985,7 @@ def udf(col) -> ps.Series[int]:
with self.assertRaisesRegex(
ValueError,
r"Expected the return type of this function to be of scalar type, "
r"but found type SeriesType\[LongType\]",
r"but found type SeriesType\[LongType\(\)\]",
):
psser.apply(udf)
