Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jprakash-db committed Sep 18, 2024
1 parent 2470581 commit a58c97f
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 24 deletions.
22 changes: 14 additions & 8 deletions databricks_sql_connector_core/tests/e2e/common/decimal_tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from decimal import Decimal

import pyarrow
import pytest

try:
import pyarrow
except ImportError:
pyarrow = None

class DecimalTestsMixin:
decimal_and_expected_results = [
from tests.e2e.common.predicates import pysql_supports_arrow

def decimal_and_expected_results():
return [
("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
Expand All @@ -17,7 +22,8 @@ class DecimalTestsMixin:
("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
]

multi_decimals_and_expected_results = [
def multi_decimals_and_expected_results():
return [
(
["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
[Decimal("1.00"), Decimal("100.001"), None],
Expand All @@ -30,7 +36,9 @@ class DecimalTestsMixin:
),
]

@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class DecimalTestsMixin:
@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
def test_decimals(self, decimal, expected_value, expected_type):
with self.cursor({}) as cursor:
query = "SELECT CAST ({})".format(decimal)
Expand All @@ -39,9 +47,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
assert table.field(0).type == expected_type
assert table.to_pydict().popitem()[1][0] == expected_value

@pytest.mark.parametrize(
"decimals, expected_values, expected_type", multi_decimals_and_expected_results
)
@pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
def test_multi_decimals(self, decimals, expected_values, expected_type):
with self.cursor({}) as cursor:
union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import logging
import math
import time
from unittest import skipUnless

import pytest
from tests.e2e.common.predicates import pysql_supports_arrow

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
+ "assuming 10K fetch size."
)

@skipUnless(pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported")
def test_query_with_large_wide_result_set(self):
resultSize = 300 * 1000 * 1000 # 300 MB
width = 8192 # B
Expand Down
10 changes: 7 additions & 3 deletions databricks_sql_connector_core/tests/e2e/common/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@


def pysql_supports_arrow():
"""Import databricks.sql and test whether Cursor has fetchall_arrow."""
from databricks.sql.client import Cursor
return hasattr(Cursor, 'fetchall_arrow')
"""Checks if the pyarrow library is installed or not"""
try:
import pyarrow

return True
except ImportError:
return False


def pysql_has_version(compare, version):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from numpy import ndarray

from tests.e2e.test_driver import PySQLPytestTestCase
from tests.e2e.common.predicates import pysql_supports_arrow


@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class TestComplexTypes(PySQLPytestTestCase):
@pytest.fixture(scope="class")
def table_fixture(self, connection_details):
Expand Down
21 changes: 13 additions & 8 deletions databricks_sql_connector_core/tests/e2e/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from uuid import uuid4

import numpy as np
import pyarrow
import pytz
import thrift
import pytest
Expand All @@ -35,6 +34,7 @@
pysql_supports_arrow,
compare_dbr_versions,
is_thrift_v5_plus,
)
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
Expand All @@ -48,6 +48,11 @@

from databricks.sql.exc import SessionAlreadyClosedError

try:
import pyarrow
except ImportError:
pyarrow = None

log = logging.getLogger(__name__)

unsafe_logger = logging.getLogger("databricks.sql.unsafe")
Expand Down Expand Up @@ -591,7 +596,7 @@ def test_ssp_passthrough(self):
cursor.execute("SET ansi_mode")
assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)]

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_timestamps_arrow(self):
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
for timestamp, expected in self.timestamp_and_expected_results:
Expand All @@ -611,7 +616,7 @@ def test_timestamps_arrow(self):
aware_timestamp and aware_timestamp.timestamp() * 1000000
), "timestamp {} did not match {}".format(timestamp, expected)

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_multi_timestamps_arrow(self):
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
query, expected = self.multi_query()
Expand All @@ -627,7 +632,7 @@ def test_multi_timestamps_arrow(self):
]
assert result == expected

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_timezone_with_timestamp(self):
if self.should_add_timezone():
with self.cursor() as cursor:
Expand All @@ -646,7 +651,7 @@ def test_timezone_with_timestamp(self):
assert arrow_result_table.field(0).type == ts_type
assert arrow_result_value == expected.timestamp() * 1000000

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_can_flip_compression(self):
with self.cursor() as cursor:
cursor.execute("SELECT array(1,2,3,4)")
Expand All @@ -663,7 +668,7 @@ def test_can_flip_compression(self):
def _should_have_native_complex_types(self):
return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_arrays_are_not_returned_as_strings_arrow(self):
if self._should_have_native_complex_types():
with self.cursor() as cursor:
Expand All @@ -674,7 +679,7 @@ def test_arrays_are_not_returned_as_strings_arrow(self):
assert pyarrow.types.is_list(list_type)
assert pyarrow.types.is_integer(list_type.value_type)

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_structs_are_not_returned_as_strings_arrow(self):
if self._should_have_native_complex_types():
with self.cursor() as cursor:
Expand All @@ -684,7 +689,7 @@ def test_structs_are_not_returned_as_strings_arrow(self):
struct_type = arrow_df.field(0).type
assert pyarrow.types.is_struct(struct_type)

@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
def test_decimal_not_returned_as_strings_arrow(self):
if self._should_have_native_complex_types():
with self.cursor() as cursor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
VoidParameter,
)
from tests.e2e.test_driver import PySQLPytestTestCase
from tests.e2e.common.predicates import pysql_supports_arrow


class ParamStyle(Enum):
Expand Down Expand Up @@ -284,6 +285,8 @@ def test_primitive_single(
(PrimitiveExtra.TINYINT, TinyIntParameter),
],
)

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Without pyarrow TIMESTAMP_NTZ datatype cannot be inferred")
def test_dbsqlparameter_single(
self,
primitive: Primitive,
Expand Down
9 changes: 8 additions & 1 deletion databricks_sql_connector_core/tests/unit/test_arrow_queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import unittest

import pyarrow as pa
import pytest

from databricks.sql.utils import ArrowQueue

try:
import pyarrow as pa
except ImportError:
pa = None

from tests.e2e.common.predicates import pysql_supports_arrow

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class ArrowQueueSuite(unittest.TestCase):
@staticmethod
def make_arrow_table(batch):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import pyarrow
import pytest
import unittest
from unittest.mock import MagicMock, patch
from ssl import create_default_context

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
import databricks.sql.utils as utils
from tests.e2e.common.predicates import pysql_supports_arrow

try:
import pyarrow
except ImportError:
pyarrow = None

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class CloudFetchQueueSuite(unittest.TestCase):

def create_result_link(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import unittest
from unittest.mock import patch, MagicMock
import pytest

from ssl import create_default_context

import databricks.sql.cloudfetch.download_manager as download_manager
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

from tests.e2e.common.predicates import pysql_supports_arrow

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class DownloadManagerTests(unittest.TestCase):
"""
Unit tests for checking download manager logic.
Expand Down
9 changes: 7 additions & 2 deletions databricks_sql_connector_core/tests/unit/test_fetches.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import unittest
from unittest.mock import Mock

import pyarrow as pa
import pytest

import databricks.sql.client as client
from databricks.sql.utils import ExecuteResponse, ArrowQueue
from tests.e2e.common.predicates import pysql_supports_arrow

try:
import pyarrow as pa
except ImportError:
pa = None

@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
class FetchTests(unittest.TestCase):
"""
Unit tests for checking the fetch logic.
Expand Down

0 comments on commit a58c97f

Please sign in to comment.