Wrap scalar generation into spark session in integration test #9405

Merged · 7 commits · Oct 18, 2023
Changes from 2 commits
20 changes: 12 additions & 8 deletions integration_tests/src/main/python/data_gen.py
@@ -21,7 +21,7 @@
 from pyspark.sql.types import *
 import pyspark.sql.functions as f
 import random
-from spark_session import is_tz_utc, is_before_spark_340
+from spark_session import is_tz_utc, is_before_spark_340, with_cpu_session
 import sre_yield
 import struct
 from conftest import skip_unless_precommit_tests
@@ -596,11 +596,12 @@ def __init__(self, start=None, end=None, nullable=True, tzinfo=timezone.utc):
         self._epoch = datetime(1970, 1, 1, tzinfo=tzinfo)
         self._start_time = self._to_us_since_epoch(start)
         self._end_time = self._to_us_since_epoch(end)
+        self._tzinfo = tzinfo
         if (self._epoch >= start and self._epoch <= end):
             self.with_special_case(self._epoch)

     def _cache_repr(self):
-        return super()._cache_repr() + '(' + str(self._start_time) + ',' + str(self._end_time) + ')'
+        return super()._cache_repr() + '(' + str(self._start_time) + ',' + str(self._end_time) + ',' + str(self._tzinfo) + ')'

     _us = timedelta(microseconds=1)

@@ -831,11 +832,15 @@ def _gen_scalars_common(data_gen, count, seed=0):

 def gen_scalars(data_gen, count, seed=0, force_no_nulls=False):
     """Generate scalar values."""
-    if force_no_nulls:
-        assert(not isinstance(data_gen, NullGen))
-    src = _gen_scalars_common(data_gen, count, seed=seed)
-    data_type = src.data_type
-    return (_mark_as_lit(src.gen(force_no_nulls=force_no_nulls), data_type) for i in range(0, count))
+    def gen_scalars_help(data_gen, count, seed, force_no_nulls):
+        if force_no_nulls:
+            assert(not isinstance(data_gen, NullGen))
+        src = _gen_scalars_common(data_gen, count, seed=seed)
+        data_type = src.data_type
+        return (_mark_as_lit(src.gen(force_no_nulls=force_no_nulls), data_type) for i in range(0, count))
+    return with_cpu_session(lambda spark: gen_scalars_help(data_gen=data_gen,
@revans2 (Collaborator) commented on Oct 9, 2023:

Putting this in a cpu_session fixes the current problem, but it adds a new one. If gen_scalars is called from inside a with_*_session it will have other problems. with_spark_session calls reset_spark_session_conf which does more than just reset the conf. It clears out the catalog too with no way to get the original config or catalog back after it exits. That means with_gpu_session -> gen_scalars will result in the query running on the CPU after the gen_scalars.

I see a few ways to properly fix this.

  1. We set spark.sql.legacy.allowNegativeScaleOfDecimal when launching Spark and have the test framework throw an exception if it is not set. Then we remove references to it in all of the tests for consistency. Then we file a follow-on issue to fix with_spark_session to not allow nesting and to throw an exception if it is nested.
  2. We fix with_spark_session to throw an exception if it is ever nested, do what you are doing today, and update the docs for it to be clear that it can never be called from within a with_spark_session.
  3. We fix the test to call gen_scalar from within a with_spark_session and add a doc fix for gen_scalar to indicate that negative-scale decimals can have problems if it is called from outside of a with_spark_session block. Then we file a follow-on issue to fix with_spark_session to not allow nesting and to throw an exception if it is nested.

I personally prefer option 1, but I am fine with option 2 or 3. Talking to @jlowe, he really prefers option 3. The main difference between option 3 and option 2 for me is really the amount of code that needs to change. If we just fix the one test and add some docs, that feels like a really small change. If we have to fix nesting, etc., that feels a bit larger, but it is something we need to do either way, and it would mean all tests that use gen_scalar would be able to deal with all decimal values properly.
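
To make the nesting pitfall concrete, here is a hypothetical sketch; the test, generator, and imports are illustrative assumptions, not code from this PR:

```python
# Hypothetical illustration of the nesting problem described above.
from data_gen import gen_scalars, IntegerGen
from spark_session import with_gpu_session

def test_example():
    def run_on_gpu(spark):
        # With this change, gen_scalars opens a nested with_cpu_session internally.
        # That inner call resets the conf and clears out the catalog of the session
        # we are already inside, so the query below can silently run on the CPU.
        scalars = list(gen_scalars(IntegerGen(), 2))
        return spark.range(10).select(*scalars).collect()

    with_gpu_session(run_on_gpu)
```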

Member commented:

I'm not a fan of option 2. It's again surprising behavior (who would expect it to spawn a Spark session?). I'm fine with either option 1 or 3, and even with option 1, I still think we should fix the test(s). We should be putting all data generation inside a Spark session context of some kind.

@thirtiseven (Collaborator, Author) commented on Oct 10, 2023:

Thanks, I've updated the code to option 3.

Now I wrap all scalar generation in a with_cpu_session, whether it calls f.lit or uses DecimalGen. I'm not sure if we want to move only the cases that could fail into Spark sessions.
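
For context, a rough sketch of the option-3 pattern described here; the generator, conf, and helper usage below are illustrative assumptions, not code from this diff:

```python
# Hypothetical sketch: generate the scalar inside a CPU session so the
# negative-scale DecimalGen literal is created while a session (and the
# conf it needs) is active; the test body then uses the scalar as before.
from data_gen import gen_scalar, DecimalGen
from spark_session import with_cpu_session

scalar = with_cpu_session(
    lambda spark: gen_scalar(DecimalGen(precision=7, scale=-3), force_no_nulls=True),
    conf={'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'})
```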

@thirtiseven (Collaborator, Author) commented:

Follow-on issue: #9412

+                                                           count=count, seed=seed,
+                                                           force_no_nulls=force_no_nulls))

 def gen_scalar(data_gen, seed=0, force_no_nulls=False):
     """Generate a single scalar value."""
@@ -1172,4 +1177,3 @@ def get_25_partitions_df(spark):
                          StructField("c3", IntegerType())])
     data = [[i, j, k] for i in range(0, 5) for j in range(0, 5) for k in range(0, 100)]
     return spark.createDataFrame(data, schema)