diff --git a/examples/custom_cred_provider.py b/examples/custom_cred_provider.py index 4c43280f..67945f23 100644 --- a/examples/custom_cred_provider.py +++ b/examples/custom_cred_provider.py @@ -4,23 +4,27 @@ from databricks.sdk.oauth import OAuthClient import os -oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"), - client_id=os.getenv("DATABRICKS_CLIENT_ID"), - client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), - redirect_url=os.getenv("APP_REDIRECT_URL"), - scopes=['all-apis', 'offline_access']) +oauth_client = OAuthClient( + host=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + client_id=os.getenv("DATABRICKS_CLIENT_ID"), + client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), + redirect_url=os.getenv("APP_REDIRECT_URL"), + scopes=["all-apis", "offline_access"], +) consent = oauth_client.initiate_consent() creds = consent.launch_external_browser() -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - credentials_provider=creds) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + credentials_provider=creds, +) as connection: for x in range(1, 5): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/insert_data.py b/examples/insert_data.py index b304a0e9..053ed158 100644 --- a/examples/insert_data.py +++ b/examples/insert_data.py @@ -1,21 +1,23 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)") + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)") - squares = [(i, i * i) for i in range(100)] - values = ",".join([f"({x}, {y})" for (x, y) in squares]) + squares = [(i, i * i) for i in range(100)] + values = ",".join([f"({x}, {y})" for (x, y) in squares]) - cursor.execute(f"INSERT INTO squares VALUES {values}") + cursor.execute(f"INSERT INTO squares VALUES {values}") - cursor.execute("SELECT * FROM squares LIMIT 10") + cursor.execute("SELECT * FROM squares LIMIT 10") - result = cursor.fetchall() + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/interactive_oauth.py b/examples/interactive_oauth.py index dad5cac6..8dbc8c47 100644 --- a/examples/interactive_oauth.py +++ b/examples/interactive_oauth.py @@ -13,12 +13,14 @@ token across script executions. """ -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/m2m_oauth.py b/examples/m2m_oauth.py index eba2095c..1c8c7278 100644 --- a/examples/m2m_oauth.py +++ b/examples/m2m_oauth.py @@ -22,17 +22,19 @@ def credential_provider(): # Service Principal UUID client_id=os.getenv("DATABRICKS_CLIENT_ID"), # Service Principal Secret - client_secret=os.getenv("DATABRICKS_CLIENT_SECRET")) + client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), + ) return oauth_service_principal(config) with sql.connect( - server_hostname=server_hostname, - http_path=os.getenv("DATABRICKS_HTTP_PATH"), - credentials_provider=credential_provider) as connection: + server_hostname=server_hostname, + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + credentials_provider=credential_provider, +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/persistent_oauth.py b/examples/persistent_oauth.py index 0f2ba077..1a2eded2 100644 --- a/examples/persistent_oauth.py +++ b/examples/persistent_oauth.py @@ -17,37 +17,44 @@ from typing import Optional from databricks import sql -from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken, DevOnlyFilePersistence +from databricks.sql.experimental.oauth_persistence import ( + OAuthPersistence, + OAuthToken, + DevOnlyFilePersistence, +) class SampleOAuthPersistence(OAuthPersistence): - def persist(self, hostname: str, oauth_token: OAuthToken): - """To be implemented by the end user to persist in the preferred storage medium. + def persist(self, hostname: str, oauth_token: OAuthToken): + """To be implemented by the end user to persist in the preferred storage medium. - OAuthToken has two properties: - 1. OAuthToken.access_token - 2. OAuthToken.refresh_token + OAuthToken has two properties: + 1. OAuthToken.access_token + 2. OAuthToken.refresh_token - Both should be persisted. - """ - pass + Both should be persisted. + """ + pass - def read(self, hostname: str) -> Optional[OAuthToken]: - """To be implemented by the end user to fetch token from the preferred storage + def read(self, hostname: str) -> Optional[OAuthToken]: + """To be implemented by the end user to fetch token from the preferred storage - Fetch the access_token and refresh_token for the given hostname. - Return OAuthToken(access_token, refresh_token) - """ - pass + Fetch the access_token and refresh_token for the given hostname. + Return OAuthToken(access_token, refresh_token) + """ + pass -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - auth_type="databricks-oauth", - experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json")) as connection: + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + auth_type="databricks-oauth", + experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json"), +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/query_cancel.py b/examples/query_cancel.py index 4e0b74a5..b67fc085 100644 --- a/examples/query_cancel.py +++ b/examples/query_cancel.py @@ -5,47 +5,52 @@ The current operation of a cursor may be cancelled by calling its `.cancel()` method as shown in the example below. """ -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - def execute_really_long_query(): - try: - cursor.execute("SELECT SUM(A.id - B.id) " + - "FROM range(1000000000) A CROSS JOIN range(100000000) B " + - "GROUP BY (A.id - B.id)") - except sql.exc.RequestError: - print("It looks like this query was cancelled.") + with connection.cursor() as cursor: - exec_thread = threading.Thread(target=execute_really_long_query) + def execute_really_long_query(): + try: + cursor.execute( + "SELECT SUM(A.id - B.id) " + + "FROM range(1000000000) A CROSS JOIN range(100000000) B " + + "GROUP BY (A.id - B.id)" + ) + except sql.exc.RequestError: + print("It looks like this query was cancelled.") - print("\n Beginning to execute long query") - exec_thread.start() + exec_thread = threading.Thread(target=execute_really_long_query) - # Make sure the query has started before cancelling - print("\n Waiting 15 seconds before canceling", end="", flush=True) + print("\n Beginning to execute long query") + exec_thread.start() - seconds_waited = 0 - while seconds_waited < 15: - seconds_waited += 1 - print(".", end="", flush=True) - time.sleep(1) + # Make sure the query has started before cancelling + print("\n Waiting 15 seconds before canceling", end="", flush=True) - print("\n Cancelling the cursor's operation. This can take a few seconds.") - cursor.cancel() + seconds_waited = 0 + while seconds_waited < 15: + seconds_waited += 1 + print(".", end="", flush=True) + time.sleep(1) - print("\n Now checking the cursor status:") - exec_thread.join(5) + print("\n Cancelling the cursor's operation. This can take a few seconds.") + cursor.cancel() - assert not exec_thread.is_alive() - print("\n The previous command was successfully canceled") + print("\n Now checking the cursor status:") + exec_thread.join(5) - print("\n Now reusing the cursor to run a separate query.") + assert not exec_thread.is_alive() + print("\n The previous command was successfully canceled") - # We can still execute a new command on the cursor - cursor.execute("SELECT * FROM range(3)") + print("\n Now reusing the cursor to run a separate query.") - print("\n Execution was successful. Results appear below:") + # We can still execute a new command on the cursor + cursor.execute("SELECT * FROM range(3)") - print(cursor.fetchall()) + print("\n Execution was successful. Results appear below:") + + print(cursor.fetchall()) diff --git a/examples/query_execute.py b/examples/query_execute.py index a851ab50..38d2f17a 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -1,13 +1,15 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/set_user_agent.py b/examples/set_user_agent.py index 449692cf..93eb2e0b 100644 --- a/examples/set_user_agent.py +++ b/examples/set_user_agent.py @@ -1,14 +1,16 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN"), - _user_agent_entry="ExamplePartnerTag") as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + _user_agent_entry="ExamplePartnerTag", +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/v3_retries_query_execute.py b/examples/v3_retries_query_execute.py index 4b6772fe..aaab47d1 100644 --- a/examples/v3_retries_query_execute.py +++ b/examples/v3_retries_query_execute.py @@ -28,16 +28,18 @@ # # For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN"), - _enable_v3_retries = True, - _retry_dangerous_codes=[502,400], - _retry_max_redirects=2) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + _enable_v3_retries=True, + _retry_dangerous_codes=[502, 400], + _retry_max_redirects=2, +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4df67a08..4e0ab941 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,6 +1,7 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas + try: import pyarrow except ImportError: @@ -26,7 +27,7 @@ inject_parameters, transform_paramstyle, ColumnTable, - ColumnQueue + ColumnQueue, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -1155,7 +1156,7 @@ def _convert_columnar_table(self, table): for row_index in range(table.num_rows): curr_row = [] for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) + curr_row.append(table.get_item(col_index, row_index)) result.append(ResultRow(*curr_row)) return result @@ -1238,7 +1239,10 @@ def merge_columnar(self, result1, result2): if result1.column_names != result2.column_names: raise ValueError("The columns in the results don't match") - merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)] + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] return ColumnTable(merged_result, result1.column_names) def fetchmany_columnar(self, size: int): @@ -1254,9 +1258,9 @@ def fetchmany_columnar(self, size: int): self._next_row_index += results.num_rows while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7f6ada9d..cf5cd906 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -624,7 +624,6 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): - def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -733,10 +732,10 @@ def _results_message_to_execute_response(self, resp, operation_state): if pyarrow: schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() ) else: schema_bytes = None diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ffeaeaf0..cd655c4e 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -12,6 +12,7 @@ import re import lz4.frame + try: import pyarrow except ImportError: @@ -103,6 +104,7 @@ def build_queue( else: raise AssertionError("Row set type is not valid") + class ColumnTable: def __init__(self, column_table, column_names): self.column_table = column_table @@ -123,11 +125,17 @@ def get_item(self, col_index, row_index): return self.column_table[col_index][row_index] def slice(self, curr_index, length): - sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table] + sliced_column_table = [ + column[curr_index : curr_index + length] for column in self.column_table + ] return ColumnTable(sliced_column_table, self.column_names) def __eq__(self, other): - return self.column_table == other.column_table and self.column_names == other.column_names + return ( + self.column_table == other.column_table + and self.column_names == other.column_names + ) + class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): @@ -143,7 +151,9 @@ def next_n_rows(self, num_rows): return slice def remaining_rows(self): - slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) + slice = self.column_table.slice( + self.cur_row_index, self.n_valid_rows - self.cur_row_index + ) self.cur_row_index += slice.num_rows return slice @@ -571,7 +581,9 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": +def create_arrow_table_from_arrow_file( + file_bytes: bytes, description +) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -621,22 +633,26 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description): converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": - converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) + converted_column_table.append( + tuple(v if v is None else Decimal(v) for v in col) + ) elif description[i][1] == "date": - converted_column_table.append(tuple( - v if v is None else datetime.date.fromisoformat(v) for v in col - )) + converted_column_table.append( + tuple(v if v is None else datetime.date.fromisoformat(v) for v in col) + ) elif description[i][1] == "timestamp": - converted_column_table.append(tuple( - ( - v - if v is None - else datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S.%f").replace( - tzinfo=pytz.UTC + converted_column_table.append( + tuple( + ( + v + if v is None + else datetime.datetime.strptime( + v, "%Y-%m-%d %H:%M:%S.%f" + ).replace(tzinfo=pytz.UTC) ) + for v in col ) - for v in col - )) + ) else: converted_column_table.append(col) @@ -734,4 +750,4 @@ def _create_python_tuple(t_col_value_wrapper): if nulls[i >> 3] & BIT_MASKS[i & 0x7]: result[i] = None - return tuple(result) \ No newline at end of file + return tuple(result) diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py index e89289ef..3f0fdc05 100644 --- a/tests/e2e/common/core_tests.py +++ b/tests/e2e/common/core_tests.py @@ -4,15 +4,18 @@ TypeFailure = namedtuple( "TypeFailure", - "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf", ) ResultFailure = namedtuple( "ResultFailure", - "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf", ) ExecFailure = namedtuple( "ExecFailure", - "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf,error", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf,error", ) @@ -58,7 +61,9 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.range_queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + self.run_query( + cursor, query, columnType, rowValueType, answer, default_conf + ) ) failures.extend( self.run_range_query( @@ -69,7 +74,9 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + self.run_query( + cursor, query, columnType, rowValueType, answer, default_conf + ) ) if failures: @@ -84,7 +91,9 @@ def run_query(self, cursor, query, columnType, rowValueType, answer, conf): try: cursor.execute(full_query) (result,) = cursor.fetchone() - if not all(cursor.description[0][1] == type for type in expected_column_types): + if not all( + cursor.description[0][1] == type for type in expected_column_types + ): return [ TypeFailure( full_query, @@ -150,7 +159,10 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con if len(rows) <= 0: break for index, (result, id) in enumerate(rows): - if not all(cursor.description[0][1] == type for type in expected_column_types): + if not all( + cursor.description[0][1] == type + for type in expected_column_types + ): return [ TypeFailure( full_query, @@ -163,7 +175,10 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con conf, ) ] - if self.validate_row_value_type and type(result) is not rowValueType: + if ( + self.validate_row_value_type + and type(result) is not rowValueType + ): return [ TypeFailure( full_query, diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py index 5005cdf1..0029f30c 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/tests/e2e/common/decimal_tests.py @@ -7,8 +7,16 @@ class DecimalTestsMixin: decimal_and_expected_results = [ ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)), - ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), - ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), + ( + "1000000.0000 AS DECIMAL(11, 4)", + Decimal("1000000.0000"), + pyarrow.decimal128(11, 4), + ), + ( + "-10.2343 AS DECIMAL(10, 6)", + Decimal("-10.234300"), + pyarrow.decimal128(10, 6), + ), # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False # ("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)), @@ -30,7 +38,9 @@ class DecimalTestsMixin: ), ] - @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results) + @pytest.mark.parametrize( + "decimal, expected_value, expected_type", decimal_and_expected_results + ) def test_decimals(self, decimal, expected_value, expected_type): with self.cursor({}) as cursor: query = "SELECT CAST ({})".format(decimal) @@ -44,7 +54,9 @@ def test_decimals(self, decimal, expected_value, expected_type): ) def test_multi_decimals(self, decimals, expected_values, expected_type): with self.cursor({}) as cursor: - union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) + union_str = " UNION ".join( + ["(SELECT CAST ({}))".format(dec) for dec in decimals] + ) query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) cursor.execute(query) diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 9ebc3f01..41ef029b 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -36,7 +36,9 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): num_fetches = max(math.ceil(n / 10000), 1) latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 print( - "Fetched {} rows with an avg latency of {} per fetch, ".format(n, latency_ms) + "Fetched {} rows with an avg latency of {} per fetch, ".format( + n, latency_ms + ) + "assuming 10K fetch size." ) @@ -55,10 +57,14 @@ def test_query_with_large_wide_result_set(self): cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) cursor.execute( - "SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows) + "SELECT id, {uuids} FROM RANGE({rows})".format( + uuids=uuids, rows=rows + ) ) assert lz4_compression == cursor.active_result_set.lz4_compressed - for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): + for row_id, row in enumerate( + self.fetch_rows(cursor, rows, fetchmany_size) + ): assert row[0] == row_id # Verify no rows are dropped in the middle. assert len(row[1]) == 36 diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py index 88b14961..61de69fd 100644 --- a/tests/e2e/common/predicates.py +++ b/tests/e2e/common/predicates.py @@ -10,7 +10,8 @@ def pysql_supports_arrow(): """Import databricks.sql and test whether Cursor has fetchall_arrow.""" from databricks.sql.client import Cursor - return hasattr(Cursor, 'fetchall_arrow') + + return hasattr(Cursor, "fetchall_arrow") def pysql_has_version(compare, version): @@ -25,6 +26,7 @@ def test_some_pyhive_v1_stuff(): ... """ from databricks import sql + return compare_module_version(sql, compare, version) @@ -38,7 +40,7 @@ def is_endpoint_test(cli_args=None): def compare_dbr_versions(cli_args, compare, major_version, minor_version): if MAJOR_DBR_V_KEY in cli_args and MINOR_DBR_V_KEY in cli_args: if cli_args[MINOR_DBR_V_KEY] == "x": - actual_minor_v = float('inf') + actual_minor_v = float("inf") else: actual_minor_v = int(cli_args[MINOR_DBR_V_KEY]) dbr_version = (int(cli_args[MAJOR_DBR_V_KEY]), actual_minor_v) @@ -47,8 +49,10 @@ def compare_dbr_versions(cli_args, compare, major_version, minor_version): if not is_endpoint_test(): raise ValueError( - "DBR version not provided for non-endpoint test. Please pass the {} and {} params". - format(MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY)) + "DBR version not provided for non-endpoint test. Please pass the {} and {} params".format( + MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY + ) + ) def is_thrift_v5_plus(cli_args): @@ -56,18 +60,18 @@ def is_thrift_v5_plus(cli_args): _compare_fns = { - '<': '__lt__', - '<=': '__le__', - '>': '__gt__', - '>=': '__ge__', - '==': '__eq__', - '!=': '__ne__', + "<": "__lt__", + "<=": "__le__", + ">": "__gt__", + ">=": "__ge__", + "==": "__eq__", + "!=": "__ne__", } def compare_versions(compare, v1_tuple, v2_tuple): compare_fn_name = _compare_fns.get(compare) - assert compare_fn_name, 'Received invalid compare string: ' + compare + assert compare_fn_name, "Received invalid compare string: " + compare return getattr(v1_tuple, compare_fn_name)(v2_tuple) @@ -87,13 +91,15 @@ def test_some_pyhive_v1_stuff(): NOTE: This comparison leverages packaging.version.parse, and compares _release_ versions, thus ignoring pre/post release tags (eg -rc1, -dev, etc). """ - assert module, 'Received invalid module: ' + module - assert getattr(module, '__version__'), 'Received module with no version: ' + module + assert module, "Received invalid module: " + module + assert getattr(module, "__version__"), "Received module with no version: " + module def validate_version(version): v = parse_version(str(version)) # assert that we get a PEP-440 Version back -- LegacyVersion doesn't have major/minor. - assert hasattr(v, 'major'), 'Module has incompatible "Legacy" version: ' + version + assert hasattr(v, "major"), ( + 'Module has incompatible "Legacy" version: ' + version + ) return (v.major, v.minor, v.micro) mod_version = validate_version(module.__version__) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index 106a8fb5..7dd5f745 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -59,7 +59,9 @@ def _test_retry_disabled_with_message(self, error_msg_substring, exception_type) @contextmanager -def mocked_server_response(status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None): +def mocked_server_response( + status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None +): """Context manager for patching urllib3 responses""" # When mocking mocking a BaseHTTPResponse for urllib3 the mock must include @@ -97,7 +99,9 @@ def mock_sequential_server_responses(responses: List[dict]): # Each resp should have these members: for resp in responses: - _mock = MagicMock(headers=resp["headers"], msg=resp["headers"], status=resp["status"]) + _mock = MagicMock( + headers=resp["headers"], msg=resp["headers"], status=resp["status"] + ) _mock.get_redirect_location.return_value = ( False if resp["redirect_location"] is None else resp["redirect_location"] ) @@ -176,7 +180,9 @@ def test_retry_exponential_backoff(self): retry_policy["_retry_delay_min"] = 1 time_start = time.time() - with mocked_server_response(status=429, headers={"Retry-After": "3"}) as mock_obj: + with mocked_server_response( + status=429, headers={"Retry-After": "3"} + ) as mock_obj: with pytest.raises(RequestError) as cm: with self.connection(extra_params=retry_policy) as conn: pass @@ -256,7 +262,9 @@ def test_retry_dangerous_codes(self): assert isinstance(cm.value.args[1], UnsafeToRetryError) # Prove that these codes are retried if forced by the user - with self.connection(extra_params={**self._retry_policy, **additional_settings}) as conn: + with self.connection( + extra_params={**self._retry_policy, **additional_settings} + ) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): @@ -326,7 +334,9 @@ def test_retry_abort_close_operation_on_404(self, caplog): curs.execute("SELECT 1") with mock_sequential_server_responses(responses): curs.close() - assert "Operation was canceled by a prior request" in caplog.text + assert ( + "Operation was canceled by a prior request" in caplog.text + ) def test_retry_max_redirects_raises_too_many_redirects_exception(self): """GIVEN the connector is configured with a custom max_redirects @@ -337,7 +347,9 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): max_redirects, expected_call_count = 1, 2 # Code 302 is a redirect - with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj: + with mocked_server_response( + status=302, redirect_location="/foo.bar" + ) as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -359,7 +371,9 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): _stop_after_attempts_count is enforced. """ # Code 302 is a redirect - with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj: + with mocked_server_response( + status=302, redirect_location="/foo.bar/" + ) as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -385,7 +399,9 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): - with self.connection(extra_params={**self._retry_policy, **additional_settings}): + with self.connection( + extra_params={**self._retry_policy, **additional_settings} + ): pass # The error should be the result of the 500, not because of too many requests. @@ -405,9 +421,12 @@ def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog) assert "it will have no affect!" in caplog.text def test_retry_legacy_behavior_warns_user(self, caplog): - with self.connection(extra_params={**self._retry_policy, "_enable_v3_retries": False}): - assert "Legacy retry behavior is enabled for this connection." in caplog.text - + with self.connection( + extra_params={**self._retry_policy, "_enable_v3_retries": False} + ): + assert ( + "Legacy retry behavior is enabled for this connection." in caplog.text + ) def test_403_not_retried(self): """GIVEN the server returns a code 403 diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index d8d0429f..008055e3 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -41,7 +41,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -51,7 +53,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): new_fh, new_temp_path = tempfile.mkstemp() - with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": new_temp_path} + ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -71,17 +75,19 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): + with pytest.raises( + Error, match="Staging operation over HTTP was unsuccessful: 404" + ): cursor = conn.cursor() - query = ( - f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" - ) + query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" cursor.execute(query) os.remove(temp_path) os.remove(new_temp_path) - def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, ingestion_user): + def test_staging_ingestion_put_fails_without_staging_allowed_local_path( + self, ingestion_user + ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -93,7 +99,9 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, in with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): + with pytest.raises( + Error, match="You must provide at least one staging_allowed_local_path" + ): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -119,12 +127,16 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": base_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, ingestion_user): + def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( + self, ingestion_user + ): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -135,16 +147,22 @@ def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" + remove_query = ( + f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" + ) - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -178,7 +196,9 @@ def test_staging_ingestion_fails_to_modify_another_staging_user(self): fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv' OVERWRITE" cursor.execute(query) @@ -186,12 +206,16 @@ def perform_put(): def perform_remove(): remove_query = f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) def perform_get(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{some_other_user}/tmp/11/15/file1.csv' TO '{temp_path}'" cursor.execute(query) @@ -232,7 +256,9 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, ingestion_user): + def test_staging_ingestion_empty_local_path_fails_to_parse_at_server( + self, ingestion_user + ): staging_allowed_local_path = "/var/www/html" target_file = "" @@ -244,7 +270,9 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, inges query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_user): + def test_staging_ingestion_invalid_staging_path_fails_at_server( + self, ingestion_user + ): staging_allowed_local_path = "/var/www/html" target_file = "index.html" @@ -278,12 +306,29 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + remove_query = ( + f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + ) return fh, temp_path, put_query, remove_query - fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() - fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() - fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() + ( + fh1, + temp_path1, + put_query1, + remove_query1, + ) = generate_file_and_path_and_queries() + ( + fh2, + temp_path2, + put_query2, + remove_query2, + ) = generate_file_and_path_and_queries() + ( + fh3, + temp_path3, + put_query3, + remove_query3, + ) = generate_file_and_path_and_queries() with self.connection( extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} diff --git a/tests/e2e/common/timestamp_tests.py b/tests/e2e/common/timestamp_tests.py index f25aed7e..70ded7d0 100644 --- a/tests/e2e/common/timestamp_tests.py +++ b/tests/e2e/common/timestamp_tests.py @@ -15,7 +15,10 @@ class TimestampTestsMixin: ] timestamp_and_expected_results = [ - ("2021-09-30 11:27:35.123+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)), + ( + "2021-09-30 11:27:35.123+04:00", + datetime.datetime(2021, 9, 30, 7, 27, 35, 123000), + ), ("2021-09-30 11:27:35+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35)), ("2021-09-30 11:27:35.123", datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)), ("2021-09-30 11:27:35", datetime.datetime(2021, 9, 30, 11, 27, 35)), @@ -45,18 +48,24 @@ def assertTimestampsEqual(self, result, expected): def multi_query(self, n_rows=10): row_sql = "SELECT " + ", ".join( - ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results] + [ + "TIMESTAMP('{}')".format(ts) + for (ts, _) in self.timestamp_and_expected_results + ] ) query = " UNION ALL ".join([row_sql for _ in range(n_rows)]) expected_matrix = [ - [dt for (_, dt) in self.timestamp_and_expected_results] for _ in range(n_rows) + [dt for (_, dt) in self.timestamp_and_expected_results] + for _ in range(n_rows) ] return query, expected_matrix def test_timestamps(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + cursor.execute( + "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) + ) result = cursor.fetchone()[0] self.assertTimestampsEqual(result, expected) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 21e43036..72e2f502 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -40,19 +40,21 @@ def test_uc_volume_life_cycle(self, catalog, schema): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() - query = ( - f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" - ) + query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) # GET should succeed new_fh, new_temp_path = tempfile.mkstemp() - with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": new_temp_path} + ) as conn: cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -72,7 +74,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): + with pytest.raises( + Error, match="Staging operation over HTTP was unsuccessful: 404" + ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -80,7 +84,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): os.remove(temp_path) os.remove(new_temp_path) - def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, schema): + def test_uc_volume_put_fails_without_staging_allowed_local_path( + self, catalog, schema + ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -92,7 +98,9 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, s with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): + with pytest.raises( + Error, match="You must provide at least one staging_allowed_local_path" + ): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -118,12 +126,16 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": base_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, schema): + def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( + self, catalog, schema + ): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -134,16 +146,22 @@ def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" + remove_query = ( + f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" + ) - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -212,7 +230,9 @@ def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self, catalog, schema): + def test_uc_volume_supports_multiple_staging_allowed_local_path_values( + self, catalog, schema + ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. This test confirms that two configured base paths: @@ -232,12 +252,29 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" + remove_query = ( + f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" + ) return fh, temp_path, put_query, remove_query - fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() - fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() - fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() + ( + fh1, + temp_path1, + put_query1, + remove_query1, + ) = generate_file_and_path_and_queries() + ( + fh2, + temp_path2, + put_query2, + remove_query2, + ) = generate_file_and_path_and_queries() + ( + fh3, + temp_path3, + put_query3, + remove_query3, + ) = generate_file_and_path_and_queries() with self.connection( extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index 0a7f514a..446a6b50 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -53,7 +53,9 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): @pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")]) def test_read_complex_types_as_string(self, field, table_fixture): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor(extra_params={"_use_arrow_native_complex_types": False}) as cursor: + with self.cursor( + extra_params={"_use_arrow_native_complex_types": False} + ) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c23e4f79..cfd1e969 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -40,7 +40,10 @@ from tests.e2e.common.large_queries_mixin import LargeQueriesMixin from tests.e2e.common.timestamp_tests import TimestampTestsMixin from tests.e2e.common.decimal_tests import DecimalTestsMixin -from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin +from tests.e2e.common.retry_test_mixins import ( + Client429ResponseMixin, + Client503ResponseMixin, +) from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin @@ -57,7 +60,9 @@ # manually decorate DecimalTestsMixin to need arrow support for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(fn) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) setattr(DecimalTestsMixin, name, decorated) @@ -68,7 +73,9 @@ class PySQLPytestTestCase: error_type = Error conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} - conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} + conf_to_disable_temporarily_unavailable_retries = { + "_retry_stop_after_attempts_count": 1 + } arraysize = 1000 buffer_size_bytes = 104857600 @@ -105,7 +112,9 @@ def connection(self, extra_params=()): @contextmanager def cursor(self, extra_params=()): with self.connection(extra_params) as conn: - cursor = conn.cursor(arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes) + cursor = conn.cursor( + arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes + ) try: yield cursor finally: @@ -144,7 +153,9 @@ def test_cloud_fetch(self): limits, threads, [True, False] ): with self.subTest( - num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression + num_limit=num_limit, + num_threads=num_threads, + lz4_compression=lz4_compression, ): cf_result, noop_result = None, None query = base_query + "LIMIT " + str(num_limit) @@ -289,7 +300,15 @@ def test_get_tables(self): ("TYPE_CAT", "string", None, None, None, None, None), ("TYPE_SCHEM", "string", None, None, None, None, None), ("TYPE_NAME", "string", None, None, None, None, None), - ("SELF_REFERENCING_COL_NAME", "string", None, None, None, None, None), + ( + "SELF_REFERENCING_COL_NAME", + "string", + None, + None, + None, + None, + None, + ), ("REF_GENERATION", "string", None, None, None, None, None), ] assert tables_desc == expected @@ -390,15 +409,21 @@ def test_escape_single_quotes(self): table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly cursor.execute( - "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name) + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format( + table_name + ) + ) + cursor.execute( + "SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name) ) - cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name)) rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" # Test escape syntax in parameter cursor.execute( - "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), + "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format( + table_name, table_name + ), parameters={"var": "you're"}, ) rows = cursor.fetchall() @@ -427,7 +452,9 @@ def test_get_catalogs(self): cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description - assert catalogs_desc == [("TABLE_CAT", "string", None, None, None, None, None)] + assert catalogs_desc == [ + ("TABLE_CAT", "string", None, None, None, None, None) + ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") def test_get_arrow(self): @@ -564,7 +591,8 @@ def test_temp_view_fetch(self): @skipIf(pysql_has_version("<", "2"), "requires pysql v2") @skipIf( - True, "Unclear the purpose of this test since urllib3 does not complain when timeout == 0" + True, + "Unclear the purpose of this test since urllib3 does not complain when timeout == 0", ) def test_socket_timeout(self): # We expect to see a BlockingIO error when the socket is opened @@ -587,7 +615,9 @@ def test_socket_timeout_user_defined(self): def test_ssp_passthrough(self): for enable_ansi in (True, False): - with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: + with self.cursor( + {"session_configuration": {"ansi_mode": enable_ansi}} + ) as cursor: cursor.execute("SET ansi_mode") assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] @@ -595,7 +625,9 @@ def test_ssp_passthrough(self): def test_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + cursor.execute( + "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) + ) arrow_table = cursor.fetchmany_arrow(1) if self.should_add_timezone(): ts_type = pyarrow.timestamp("us", tz="Etc/UTC") @@ -606,7 +638,9 @@ def test_timestamps_arrow(self): # To work consistently across different local timezones, we specify the timezone # of the expected result to # be UTC (what it should be by default on the server) - aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) + aware_timestamp = expected and expected.replace( + tzinfo=datetime.timezone.utc + ) assert result_value == ( aware_timestamp and aware_timestamp.timestamp() * 1000000 ), "timestamp {} did not match {}".format(timestamp, expected) @@ -616,14 +650,16 @@ def test_multi_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: query, expected = self.multi_query() expected = [ - [self.maybe_add_timezone_to_timestamp(ts) for ts in row] for row in expected + [self.maybe_add_timezone_to_timestamp(ts) for ts in row] + for row in expected ] cursor.execute(query) table = cursor.fetchall_arrow() # Transpose columnar result to list of rows list_of_cols = [c.to_pylist() for c in table] result = [ - [col[row_index] for col in list_of_cols] for row_index in range(table.num_rows) + [col[row_index] for col in list_of_cols] + for row_index in range(table.num_rows) ] assert result == expected @@ -640,7 +676,9 @@ def test_timezone_with_timestamp(self): cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") arrow_result_table = cursor.fetchmany_arrow(1) - arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value + arrow_result_value = ( + arrow_result_table.column(0).combine_chunks()[0].value + ) ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam") assert arrow_result_table.field(0).type == ts_type @@ -700,7 +738,9 @@ def test_close_connection_closes_cursors(self): with self.connection() as conn: cursor = conn.cursor() - cursor.execute("SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()") + cursor.execute( + "SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()" + ) ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True @@ -709,14 +749,21 @@ def test_close_connection_closes_cursors(self): status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False ) - op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) - assert op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE + op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + status_request + ) + assert ( + op_status_at_server.operationState + != ttypes.TOperationState.CLOSED_STATE + ) conn.close() # When connection closes, any cursor operations should no longer exist at the server with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) + op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + status_request + ) def test_closing_a_closed_connection_doesnt_fail(self, caplog): caplog.set_level(logging.DEBUG) @@ -737,7 +784,9 @@ class HTTP429Suite(Client429ResponseMixin, PySQLPytestTestCase): class HTTP503Suite(Client503ResponseMixin, PySQLPytestTestCase): # 503Response suite gets custom error here vs PyODBC def test_retry_disabled(self): - self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError) + self._test_retry_disabled_with_message( + "TEMPORARILY_UNAVAILABLE", OperationalError + ) class TestPySQLUnityCatalogSuite(PySQLPytestTestCase): diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 47dfc38c..d346ad5c 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -167,12 +167,8 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): This is a no-op but is included to make the test-code easier to read. """ target_column = self._get_inline_table_column(params.get("p")) - INSERT_QUERY = ( - f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" - ) - SELECT_QUERY = ( - f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" - ) + INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" + SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" with self.connection(extra_params={"use_inline_params": True}) as conn: @@ -308,11 +304,15 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) If a user explicitly sets use_inline_params, don't warn them about it. """ - extra_args = {"use_inline_params": use_inline_params} if use_inline_params else {} + extra_args = ( + {"use_inline_params": use_inline_params} if use_inline_params else {} + ) with self.connection(extra_params=extra_args) as conn: with conn.cursor() as cursor: - with self.patch_server_supports_native_params(supports_native_params=True): + with self.patch_server_supports_native_params( + supports_native_params=True + ): cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: assert ( @@ -402,7 +402,9 @@ def test_inline_ordinals_can_break_sql(self): query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] with self.cursor(extra_params={"use_inline_params": True}) as cursor: - with pytest.raises(TypeError, match="not enough arguments for format string"): + with pytest.raises( + TypeError, match="not enough arguments for format string" + ): cursor.execute(query, parameters=params) def test_inline_named_dont_break_sql(self): diff --git a/tests/unit/test_arrow_queue.py b/tests/unit/test_arrow_queue.py index 6834cc9c..b3dff45f 100644 --- a/tests/unit/test_arrow_queue.py +++ b/tests/unit/test_arrow_queue.py @@ -14,13 +14,21 @@ def make_arrow_table(batch): return pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) def test_fetchmany_respects_n_rows(self): - arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) + arrow_table = self.make_arrow_table( + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + ) aq = ArrowQueue(arrow_table, 3) - self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]])) + self.assertEqual( + aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]]) + ) self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[6, 7, 8]])) def test_fetch_remaining_rows_respects_n_rows(self): - arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) + arrow_table = self.make_arrow_table( + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + ) aq = ArrowQueue(arrow_table, 3) self.assertEqual(aq.next_n_rows(1), self.make_arrow_table([[0, 1, 2]])) - self.assertEqual(aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]])) + self.assertEqual( + aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]]) + ) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index d6541525..d5b06bbf 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -9,7 +9,10 @@ ExternalAuthProvider, AuthType, ) -from databricks.sql.auth.auth import get_python_sql_connector_auth_provider, PYSQL_OAUTH_CLIENT_ID +from databricks.sql.auth.auth import ( + get_python_sql_connector_auth_provider, + PYSQL_OAUTH_CLIENT_ID, +) from databricks.sql.auth.oauth import OAuthManager from databricks.sql.auth.authenticators import DatabricksOAuthProvider from databricks.sql.auth.endpoint import ( @@ -177,12 +180,13 @@ def test_get_python_sql_connector_basic_auth(self): } with self.assertRaises(ValueError) as e: get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs) - self.assertIn("Username/password authentication is no longer supported", str(e.exception)) + self.assertIn( + "Username/password authentication is no longer supported", str(e.exception) + ) @patch.object(DatabricksOAuthProvider, "_initial_get_token") def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" auth_provider = get_python_sql_connector_auth_provider(hostname) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") - self.assertTrue(auth_provider._client_id,PYSQL_OAUTH_CLIENT_ID) - + self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c86a9f7f..0ff660d5 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,7 +13,7 @@ TExecuteStatementResp, TOperationHandle, THandleIdentifier, - TOperationType + TOperationType, ) from databricks.sql.thrift_backend import ThriftBackend @@ -26,8 +26,8 @@ from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite -class ThriftBackendMockFactory: +class ThriftBackendMockFactory: @classmethod def new(cls): ThriftBackendMock = Mock(spec=ThriftBackend) @@ -68,10 +68,6 @@ def apply_property_to_mock(self, mock_obj, **kwargs): setattr(type(mock_obj), key, prop) - - - - class ClientTestSuite(unittest.TestCase): """ Unit tests for isolated client behaviour. @@ -89,7 +85,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -97,7 +93,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): # Check the close session request has an id of x22 close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b'\x22') + self.assertEqual(close_session_id, b"\x22") @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): @@ -155,13 +151,19 @@ def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) http_headers = mock_client_class.call_args[0][3] - user_agent_header = ("User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, - databricks.sql.__version__)) + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) self.assertIn(user_agent_header, http_headers) databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, _user_agent_entry="foobar") - user_agent_header_with_entry = ("User-Agent", "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar")) + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) @@ -177,7 +179,9 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): cursor.execute("SELECT 1;") connection.close() - self.assertTrue(mock_result_set_class.return_value.has_been_closed_server_side) + self.assertTrue( + mock_result_set_class.return_value.has_been_closed_server_side + ) mock_result_set_class.return_value.close.assert_called_once_with() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @@ -192,7 +196,9 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) - def test_arraysize_buffer_size_passthrough(self, mock_cursor_class, mock_client_class): + def test_arraysize_buffer_size_passthrough( + self, mock_cursor_class, mock_client_class + ): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.cursor(arraysize=999, buffer_size_bytes=1234) kwargs = mock_cursor_class.call_args[1] @@ -204,7 +210,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() result_set = client.ResultSet( - connection=mock_connection, thrift_backend=mock_backend, execute_response=Mock()) + connection=mock_connection, + thrift_backend=mock_backend, + execute_response=Mock(), + ) mock_connection.open = False result_set.close() @@ -218,20 +227,27 @@ def test_closing_result_set_hard_closes_commands(self): mock_connection = Mock() mock_thrift_backend = Mock() mock_connection.open = True - result_set = client.ResultSet(mock_connection, mock_results_response, mock_thrift_backend) + result_set = client.ResultSet( + mock_connection, mock_results_response, mock_thrift_backend + ) result_set.close() mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle) + mock_results_response.command_handle + ) @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command(self, mock_result_set_class): + def test_executing_multiple_commands_uses_the_most_recent_command( + self, mock_result_set_class + ): mock_result_sets = [Mock(), Mock()] mock_result_set_class.side_effect = mock_result_sets - cursor = client.Cursor(connection=Mock(), thrift_backend=ThriftBackendMockFactory.new()) + cursor = client.Cursor( + connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() + ) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -272,7 +288,7 @@ def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: @@ -280,7 +296,7 @@ def test_context_manager_closes_connection(self, mock_client_class): # Check the close session request has an id of x22 close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b'\x22') + self.assertEqual(close_session_id, b"\x22") def dict_product(self, dicts): """ @@ -299,7 +315,9 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe req_args_combinations = self.dict_product( dict( catalog_name=["NOT_SET", None, "catalog_pattern"], - schema_name=["NOT_SET", None, "schema_pattern"])) + schema_name=["NOT_SET", None, "schema_pattern"], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -320,7 +338,9 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen catalog_name=["NOT_SET", None, "catalog_pattern"], schema_name=["NOT_SET", None, "schema_pattern"], table_name=["NOT_SET", None, "table_pattern"], - table_types=["NOT_SET", [], ["type1", "type2"]])) + table_types=["NOT_SET", [], ["type1", "type2"]], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -341,7 +361,9 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe catalog_name=["NOT_SET", None, "catalog_pattern"], schema_name=["NOT_SET", None, "schema_pattern"], table_name=["NOT_SET", None, "table_pattern"], - column_name=["NOT_SET", None, "column_pattern"])) + column_name=["NOT_SET", None, "column_pattern"], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -365,7 +387,8 @@ def test_cancel_command_calls_the_backend(self): @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( - self, logger_instance): + self, logger_instance + ): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.cancel() @@ -375,9 +398,13 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect(_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS) + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) - self.assertEqual(mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54) + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): @@ -386,27 +413,38 @@ def test_socket_timeout_passthrough(self, mock_client_class): def test_version_is_canonical(self): version = databricks.sql.__version__ - canonical_version_re = r'^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)' \ - r'(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$' + canonical_version_re = ( + r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)" + r"(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" + ) self.assertIsNotNone(re.match(canonical_version_re, version)) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS) + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][1], mock_cat) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][2], mock_schem) + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() @@ -436,7 +474,8 @@ def test_execute_parameter_passthrough(self): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend): + self, mock_result_set_class, mock_thrift_backend + ): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] mock_result_set_class.side_effect = mock_result_set_instances @@ -449,17 +488,22 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), len(expected_queries), - "Expected execute_command to be called the same number of times as params were passed") + len(mock_thrift_backend.execute_command.call_args_list), + len(expected_queries), + "Expected execute_command to be called the same number of times as params were passed", + ) - for expected_query, call_args in zip(expected_queries, - mock_thrift_backend.execute_command.call_args_list): + for expected_query, call_args in zip( + expected_queries, mock_thrift_backend.execute_command.call_args_list + ): self.assertEqual(call_args[1]["operation"], expected_query) self.assertEqual( - cursor.active_result_set, mock_result_set_instances[2], + cursor.active_result_set, + mock_result_set_instances[2], "Expected the active result set to be the result set corresponding to the" - "last operation") + "last operation", + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): @@ -495,7 +539,7 @@ def make_fake_row_slice(n_rows): mock_thrift_backend.fetch_results.return_value = (mock_aq, True) cursor = client.Cursor(Mock(), mock_thrift_backend) - cursor.execute('foo') + cursor.execute("foo") self.assertEqual(cursor.rownumber, 0) cursor.fetchmany_arrow(10) @@ -516,12 +560,14 @@ def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_aq = Mock() mock_aq.remaining_rows.return_value = mock_table mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq - mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = True + mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = ( + True + ) mock_con = Mock() mock_con.disable_pandas = True cursor = client.Cursor(mock_con, mock_thrift_backend) - cursor.execute('foo') + cursor.execute("foo") cursor.fetchall() mock_table.itercolumns.assert_called_once_with() @@ -548,18 +594,21 @@ def test_column_name_api(self): self.assertEqual(row[1], expected[1]) self.assertEqual(row[2], expected[2]) - self.assertEqual(row.asDict(), { - "first_col": expected[0], - "second_col": expected[1], - "third_col": expected[2] - }) + self.assertEqual( + row.asDict(), + { + "first_col": expected[0], + "second_col": expected[1], + "third_col": expected[2], + }, + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -569,14 +618,14 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): # Check the close session request has an id of x22 close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b'\x22') + self.assertEqual(close_session_id, b"\x22") @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -591,11 +640,14 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) - def test_staging_operation_response_is_handled(self, mock_client_class, mock_handle_staging_operation, mock_execute_response): + def test_staging_operation_response_is_handled( + self, mock_client_class, mock_handle_staging_operation, mock_execute_response + ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called - - ThriftBackendMockFactory.apply_property_to_mock(mock_execute_response, is_staging_operation=True) + ThriftBackendMockFactory.apply_property_to_mock( + mock_execute_response, is_staging_operation=True + ) mock_client_class.execute_command.return_value = mock_execute_response mock_client_class.return_value = mock_client_class @@ -608,7 +660,7 @@ def test_staging_operation_response_is_handled(self, mock_client_class, mock_han @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): - operation_id = 'EE6A8778-21FC-438B-92D8-96AC51EE3821' + operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -617,17 +669,23 @@ def test_access_current_query_id(self): cursor.active_op_handle = TOperationHandle( operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT) + operationType=TOperationType.EXECUTE_STATEMENT, + ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) cursor.close() self.assertIsNone(cursor.query_id) -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() - test_classes = [ClientTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite] + test_classes = [ + ClientTestSuite, + FetchTests, + ThriftBackendTestSuite, + ArrowQueueSuite, + ] suites_list = [] for test_class in test_classes: suite = loader.loadTestsFromTestCase(test_class) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index acd0c392..01d8a79b 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -6,22 +6,26 @@ import databricks.sql.utils as utils from databricks.sql.types import SSLOptions -class CloudFetchQueueSuite(unittest.TestCase): +class CloudFetchQueueSuite(unittest.TestCase): def create_result_link( - self, - file_link: str = "fileLink", - start_row_offset: int = 0, - row_count: int = 8000, - bytes_num: int = 20971520 + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520, ): - return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + return TSparkArrowResultLink( + file_link, None, start_row_offset, row_count, bytes_num + ) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] for i in range(num_files): file_link = "fileLink_" + str(i) - result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_link = self.create_result_link( + file_link=file_link, start_row_offset=start_row_offset + ) result_links.append(result_link) start_row_offset += result_link.rowCount return result_links @@ -42,8 +46,10 @@ def get_schema_bytes(): writer.close() return sink.getvalue().to_pybytes() - - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None]) + @patch( + "databricks.sql.utils.CloudFetchQueue._create_next_table", + return_value=[None, None], + ) def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) @@ -72,7 +78,10 @@ def test_initializer_no_links_to_add(self): assert len(queue.download_manager._download_tasks) == 0 assert queue.table is None - @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) + @patch( + "databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + return_value=None, + ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): queue = utils.CloudFetchQueue( MagicMock(), @@ -85,9 +94,13 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): mock_get_next_downloaded_file.assert_called_with(0) @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") - @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", - return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) - def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): + @patch( + "databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + return_value=MagicMock(file_bytes=b"1234567890", row_count=4), + ) + def test_initializer_create_next_table_success( + self, mock_get_next_downloaded_file, mock_create_arrow_table + ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue( @@ -169,7 +182,12 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): result = queue.next_n_rows(7) assert result.num_rows == 7 assert queue.table_row_index == 3 - assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] + assert ( + result + == pyarrow.concat_tables( + [self.make_arrow_table(), self.make_arrow_table()] + )[:7] + ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): @@ -265,8 +283,14 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result == self.make_arrow_table() @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") - def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): - mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None] + def test_remaining_rows_multiple_tables_fully_returned( + self, mock_create_next_table + ): + mock_create_next_table.side_effect = [ + self.make_arrow_table(), + self.make_arrow_table(), + None, + ] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue( schema_bytes, @@ -282,7 +306,12 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta result = queue.remaining_rows() assert mock_create_next_table.call_count == 3 assert result.num_rows == 5 - assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] + assert ( + result + == pyarrow.concat_tables( + [self.make_arrow_table(), self.make_arrow_table()] + )[3:] + ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_remaining_rows_empty_table(self, mock_create_next_table): diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py index 130b589b..234af88e 100644 --- a/tests/unit/test_column_queue.py +++ b/tests/unit/test_column_queue.py @@ -8,14 +8,18 @@ def make_column_table(table): return ColumnTable(table, [f"col_{i}" for i in range(n_cols)]) def test_fetchmany_respects_n_rows(self): - column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_table = self.make_column_table( + [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] + ) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) assert column_queue.next_n_rows(2) == column_table.slice(2, 2) def test_fetch_remaining_rows_respects_n_rows(self): - column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_table = self.make_column_table( + [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] + ) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index a11bc8d4..64edbdeb 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -11,7 +11,9 @@ class DownloadManagerTests(unittest.TestCase): Unit tests for checking download manager logic. """ - def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True): + def create_download_manager( + self, links, max_download_threads=10, lz4_compressed=True + ): return download_manager.ResultFileDownloadManager( links, max_download_threads, @@ -20,19 +22,23 @@ def create_download_manager(self, links, max_download_threads=10, lz4_compressed ) def create_result_link( - self, - file_link: str = "fileLink", - start_row_offset: int = 0, - row_count: int = 8000, - bytes_num: int = 20971520 + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520, ): - return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + return TSparkArrowResultLink( + file_link, None, start_row_offset, row_count, bytes_num + ) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] for i in range(num_files): file_link = "fileLink_" + str(i) - result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_link = self.create_result_link( + file_link=file_link, start_row_offset=start_row_offset + ) result_links.append(result_link) start_row_offset += result_link.rowCount return result_links @@ -41,7 +47,9 @@ def test_add_file_links_zero_row_count(self): links = [self.create_result_link(row_count=0, bytes_num=0)] manager = self.create_download_manager(links) - assert len(manager._pending_links) == 0 # the only link supplied contains no data, so should be skipped + assert ( + len(manager._pending_links) == 0 + ) # the only link supplied contains no data, so should be skipped assert len(manager._download_tasks) == 0 def test_add_file_links_success(self): @@ -55,7 +63,9 @@ def test_add_file_links_success(self): def test_schedule_downloads(self, mock_submit): max_download_threads = 4 links = self.create_result_links(num_files=10) - manager = self.create_download_manager(links, max_download_threads=max_download_threads) + manager = self.create_download_manager( + links, max_download_threads=max_download_threads + ) manager._schedule_downloads() assert mock_submit.call_count == max_download_threads diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 7075ef6c..2a3b715b 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -20,36 +20,40 @@ class DownloaderTests(unittest.TestCase): Unit tests for checking downloader logic. """ - @patch('time.time', return_value=1000) + @patch("time.time", return_value=1000) def test_run_link_expired(self, mock_time): settings = Mock() result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(Error) as context: d.run() - self.assertTrue('link has expired' in context.exception.message) + self.assertTrue("link has expired" in context.exception.message) mock_time.assert_called_once() - @patch('time.time', return_value=1000) + @patch("time.time", return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(Error) as context: d.run() - self.assertTrue('link has expired' in context.exception.message) + self.assertTrue("link has expired" in context.exception.message) mock_time.assert_called_once() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) - @patch('time.time', return_value=1000) + @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) + @patch("time.time", return_value=1000) def test_run_get_response_not_ok(self, mock_time, mock_session): mock_session.return_value.get.return_value = create_response(status_code=404) @@ -58,62 +62,81 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() - self.assertTrue('404' in str(context.exception)) + self.assertTrue("404" in str(context.exception)) - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) - @patch('time.time', return_value=1000) + @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) + @patch("time.time", return_value=1000) def test_run_uncompressed_successful(self, mock_time, mock_session): file_bytes = b"1234567890" * 10 - mock_session.return_value.get.return_value = create_response(status_code=200, _content=file_bytes) + mock_session.return_value.get.return_value = create_response( + status_code=200, _content=file_bytes + ) settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) file = d.run() assert file.file_bytes == b"1234567890" * 10 - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) - @patch('time.time', return_value=1000) + @patch( + "requests.Session", + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))), + ) + @patch("time.time", return_value=1000) def test_run_compressed_successful(self, mock_time, mock_session): file_bytes = b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - mock_session.return_value.get.return_value = create_response(status_code=200, _content=compressed_bytes) + mock_session.return_value.get.return_value = create_response( + status_code=200, _content=compressed_bytes + ) settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) file = d.run() assert file.file_bytes == b"1234567890" * 10 - @patch('requests.Session.get', side_effect=ConnectionError('foo')) - @patch('time.time', return_value=1000) + @patch("requests.Session.get", side_effect=ConnectionError("foo")) + @patch("time.time", return_value=1000) def test_download_connection_error(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) + settings = Mock( + link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True + ) result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(ConnectionError): d.run() - @patch('requests.Session.get', side_effect=TimeoutError('foo')) - @patch('time.time', return_value=1000) + @patch("requests.Session.get", side_effect=TimeoutError("foo")) + @patch("time.time", return_value=1000) def test_download_timeout(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) + settings = Mock( + link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True + ) result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7d5686f8..e9a58acd 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -17,7 +17,9 @@ def make_arrow_table(batch): n_cols = len(batch[0]) if batch else 0 schema = pa.schema({"col%s" % i: pa.uint32() for i in range(n_cols)}) cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)] - return schema, pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) + return schema, pa.Table.from_pydict( + dict(zip(schema.names, cols)), schema=schema + ) @staticmethod def make_arrow_queue(batch): @@ -42,18 +44,29 @@ def make_dummy_result_set_from_initial_results(initial_results): command_handle=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), - is_staging_operation=False)) + is_staging_operation=False, + ), + ) num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [(f'col{col_id}', 'integer', None, None, None, None, None) - for col_id in range(num_cols)] + rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 - def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed, - arrow_schema_bytes, description): + def fetch_results( + op_handle, + max_rows, + max_bytes, + expected_row_start_offset, + lz4_compressed, + arrow_schema_bytes, + description, + ): nonlocal batch_index results = FetchTests.make_arrow_queue(batch_list[batch_index]) batch_index += 1 @@ -71,13 +84,17 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4 status=None, has_been_closed_server_side=False, has_more_rows=True, - description=[(f'col{col_id}', 'integer', None, None, None, None, None) - for col_id in range(num_cols)], + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], lz4_compressed=Mock(), command_handle=None, arrow_queue=None, arrow_schema_bytes=None, - is_staging_operation=False)) + is_staging_operation=False, + ), + ) return rs def assertEqualRowValues(self, actual, expected): @@ -87,30 +104,44 @@ def assertEqualRowValues(self, actual, expected): def test_fetchmany_with_initial_results(self): # Fetch all in one go - initial_results_1 = [[1], [2], [3]] # This is a list of rows, each row with 1 col - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + initial_results_1 = [ + [1], + [2], + [3], + ] # This is a list of rows, each row with 1 col + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) # Fetch in small amounts initial_results_2 = [[1], [2], [3], [4]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_2) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_2 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[1]]) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[2], [3]]) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[4]]) # Fetch too many initial_results_3 = [[2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_3) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_3 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(5), [[2], [3]]) # Empty results initial_results_4 = [[]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_4) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_4 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(0), []) def test_fetch_many_without_initial_results(self): # Fetch all in one go; single batch - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) @@ -140,7 +171,9 @@ def test_fetch_many_without_initial_results(self): # Fetch too many; multiple batches batch_list_6 = [[[1]], [[2], [3], [4]], [[5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_6) - self.assertEqualRowValues(dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]]) + self.assertEqualRowValues( + dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]] + ) # Fetch 0; 1 empty batch batch_list_7 = [[]] @@ -154,19 +187,25 @@ def test_fetch_many_without_initial_results(self): def test_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) def test_fetchall_without_initial_results(self): # Fetch all, single batch - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) # Fetch all, multiple batches batch_list_2 = [[[1], [2]], [[3]], [[4], [5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) - self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]]) + self.assertEqualRowValues( + dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]] + ) batch_list_3 = [[]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_3) @@ -174,12 +213,16 @@ def test_fetchall_without_initial_results(self): def test_fetchmany_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) def test_fetchmany_fetchall_without_initial_results(self): - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) @@ -191,7 +234,9 @@ def test_fetchmany_fetchall_without_initial_results(self): def test_fetchone_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertSequenceEqual(dummy_result_set.fetchone(), [1]) self.assertSequenceEqual(dummy_result_set.fetchone(), [2]) self.assertSequenceEqual(dummy_result_set.fetchone(), [3]) @@ -210,5 +255,5 @@ def test_fetchone_without_initial_results(self): self.assertEqual(dummy_result_set.fetchone(), None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e322b44a..9382c3b3 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -35,12 +35,18 @@ def make_dummy_result_set_from_initial_results(arrow_table): description=Mock(), command_handle=None, arrow_queue=arrow_queue, - arrow_schema=arrow_table.schema)) - rs.description = [(f'col{col_id}', 'string', None, None, None, None, None) - for col_id in range(arrow_table.num_columns)] + arrow_schema=arrow_table.schema, + ), + ) + rs.description = [ + (f"col{col_id}", "string", None, None, None, None, None) + for col_id in range(arrow_table.num_columns) + ] return rs - @pytest.mark.skip(reason="Test has not been updated for latest connector API (June 2022)") + @pytest.mark.skip( + reason="Test has not been updated for latest connector API (June 2022)" + ) def test_benchmark_fetchall(self): print("preparing dummy arrow table") arrow_table = FetchBenchmarkTests.make_arrow_table(10, 25000) @@ -50,7 +56,9 @@ def test_benchmark_fetchall(self): start_time = time.time() count = 0 while time.time() < start_time + benchmark_seconds: - dummy_result_set = self.make_dummy_result_set_from_initial_results(arrow_table) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + arrow_table + ) res = dummy_result_set.fetchall() for _ in res: pass @@ -59,5 +67,5 @@ def test_benchmark_fetchall(self): print(f"Executed query {count} times, in {time.time() - start_time} seconds") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 28b3cab3..a8ceb14e 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -1,16 +1,17 @@ - import unittest -from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence, OAuthToken +from databricks.sql.experimental.oauth_persistence import ( + DevOnlyFilePersistence, + OAuthToken, +) import tempfile import os class OAuthPersistenceTests(unittest.TestCase): - def test_DevOnlyFilePersistence_read_my_write(self): with tempfile.TemporaryDirectory() as tempdir: - test_json_file_path = os.path.join(tempdir, 'test.json') + test_json_file_path = os.path.join(tempdir, "test.json") persistence_manager = DevOnlyFilePersistence(test_json_file_path) access_token = "abc#$%%^&^*&*()()_=-/" refresh_token = "#$%%^^&**()+)_gter243]xyz" @@ -23,7 +24,7 @@ def test_DevOnlyFilePersistence_read_my_write(self): def test_DevOnlyFilePersistence_file_does_not_exist(self): with tempfile.TemporaryDirectory() as tempdir: - test_json_file_path = os.path.join(tempdir, 'test.json') + test_json_file_path = os.path.join(tempdir, "test.json") persistence_manager = DevOnlyFilePersistence(test_json_file_path) new_token = persistence_manager.read("https://randomserver") diff --git a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py index 472a0843..925fcea5 100644 --- a/tests/unit/test_param_escaper.py +++ b/tests/unit/test_param_escaper.py @@ -3,7 +3,12 @@ from typing import Any, Dict from databricks.sql.parameters.native import dbsql_parameter_from_primitive -from databricks.sql.utils import ParamEscaper, inject_parameters, transform_paramstyle, ParameterStructure +from databricks.sql.utils import ( + ParamEscaper, + inject_parameters, + transform_paramstyle, + ParameterStructure, +) pe = ParamEscaper() @@ -200,26 +205,31 @@ class TestInlineToNativeTransformer(object): "query with like wildcard", 'select * from table where field like "%"', {}, - 'select * from table where field like "%"' + 'select * from table where field like "%"', ), ( "query with named param and like wildcard", 'select :param from table where field like "%"', {"param": None}, - 'select :param from table where field like "%"' + 'select :param from table where field like "%"', ), ( "query with doubled wildcards", - 'select 1 where '' like "%%"', + "select 1 where " ' like "%%"', {"param": None}, - 'select 1 where '' like "%%"', - ) + "select 1 where " ' like "%%"', + ), ), ) def test_transformer( self, label: str, query: str, params: Dict[str, Any], expected: str ): - _params = [dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items()] - output = transform_paramstyle(query, _params, param_structure=ParameterStructure.NAMED) + _params = [ + dbsql_parameter_from_primitive(value=value, name=name) + for name, value in params.items() + ] + output = transform_paramstyle( + query, _params, param_structure=ParameterStructure.NAMED + ) assert output == expected diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 798bac2e..2108af4f 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -8,7 +8,6 @@ class TestRetry: - @pytest.fixture() def retry_policy(self) -> DatabricksRetryPolicy: return DatabricksRetryPolicy( @@ -41,7 +40,9 @@ def test_sleep__retry_after_is_binding(self, t_mock, retry_policy, error_history t_mock.assert_called_with(3) @patch("time.sleep") - def test_sleep__retry_after_present_but_not_binding(self, t_mock, retry_policy, error_history): + def test_sleep__retry_after_present_but_not_binding( + self, t_mock, retry_policy, error_history + ): retry_policy._retry_start_time = time.time() retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "1"})) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0333766c..293467af 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -69,7 +69,14 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -79,7 +86,14 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -89,13 +103,16 @@ def _make_fake_thrift_backend(self): def test_hive_schema_to_arrow_schema_preserves_column_names(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) @@ -136,7 +153,9 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) - self.assertIn("expected server to use a protocol version", str(cm.exception)) + self.assertIn( + "expected server to use a protocol version", str(cm.exception) + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @@ -157,8 +176,17 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider(), ssl_options=SSLOptions()) - t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) + ThriftBackend( + "foo", + 123, + "bar", + [("header", "value")], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + t_http_client_class.return_value.setCustomHeaders.assert_called_with( + {"header": "value"} + ) def test_proxy_headers_are_set(self): @@ -174,11 +202,13 @@ def test_proxy_headers_are_set(self): assert False assert isinstance(result, type(dict())) - assert isinstance(result.get('proxy-authorization'), type(str())) + assert isinstance(result.get("proxy-authorization"), type(str())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): + def test_tls_cert_args_are_propagated( + self, mock_create_default_context, t_http_client_class + ): mock_cert_key_file = Mock() mock_cert_key_password = Mock() mock_trusted_ca_file = Mock() @@ -203,11 +233,15 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ ) mock_ssl_context.load_cert_chain.assert_called_once_with( - certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password + certfile=mock_cert_file, + keyfile=mock_cert_key_file, + password=mock_cert_key_password, ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.types.create_default_context") def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): @@ -232,7 +266,7 @@ def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context ssl_options=mock_ssl_options, ) - self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.scheme, "https") self.assertEqual(http_client.certfile, mock_ssl_options.tls_client_cert_file) self.assertEqual(http_client.keyfile, mock_ssl_options.tls_client_cert_key_file) self.assertIsNotNone(http_client.certfile) @@ -246,7 +280,9 @@ def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) - self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) + self.assertEqual( + conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password + ) def test_tls_no_verify_is_respected_by_http_client(self): from databricks.sql.auth.thrift_http_client import THttpClient @@ -256,7 +292,7 @@ def test_tls_no_verify_is_respected_by_http_client(self): uri_or_host="https://example.com", ssl_options=SSLOptions(tls_verify=False), ) - self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.scheme, "https") http_client.open() @@ -266,16 +302,27 @@ def test_tls_no_verify_is_respected_by_http_client(self): @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): + def test_tls_no_verify_is_respected( + self, mock_create_default_context, t_http_client_class + ): mock_ssl_options = SSLOptions(tls_verify=False) mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options) + ThriftBackend( + "foo", + 123, + "bar", + [], + auth_provider=AuthProvider(), + ssl_options=mock_ssl_options, + ) self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") @@ -287,61 +334,126 @@ def test_tls_verify_hostname_is_respected( mock_create_default_context.assert_called() ThriftBackend( - "foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options + "foo", + 123, + "bar", + [], + auth_provider=AuthProvider(), + ssl_options=mock_ssl_options, ) self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "https://hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "https://hostname/", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=129 + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=129, + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=0 + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=None + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 + ) + ThriftBackend( + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=None, + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], None ) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( columnName="column 2", typeDesc=ttypes.TTypeDesc( types=[ - ttypes.TTypeEntry(userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo")) + ttypes.TTypeEntry( + userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo") + ) ] ), ), @@ -358,13 +470,16 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): # canary test columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE), ), ttypes.TColumnDesc( columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.STRUCT_TYPE) @@ -395,8 +510,12 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): type=ttypes.TTypeId.DECIMAL_TYPE, typeQualifiers=ttypes.TTypeQualifiers( qualifiers={ - "precision": ttypes.TTypeQualifierValue(i32Value=10), - "scale": ttypes.TTypeQualifierValue(i32Value=100), + "precision": ttypes.TTypeQualifierValue( + i32Value=10 + ), + "scale": ttypes.TTypeQualifierValue( + i32Value=100 + ), } ), ) @@ -416,8 +535,18 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ) def test_make_request_checks_status_code(self): - error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + error_codes = [ + ttypes.TStatusCode.ERROR_STATUS, + ttypes.TStatusCode.INVALID_HANDLE_STATUS, + ] + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) for code in error_codes: mock_error_response = Mock() @@ -455,15 +584,24 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): ), ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) - def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) + def test_handle_execute_response_sets_compression_in_direct_results( + self, build_queue + ): for resp_type in self.execute_response_types: lz4Compressed = Mock() resultSet = MagicMock() @@ -484,13 +622,24 @@ def test_handle_execute_response_sets_compression_in_direct_results(self, build_ closeOperation=None, ), ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) - execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + execute_response = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_service_class): + def test_handle_execute_response_checks_operation_state_in_polls( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value error_resp = ttypes.TGetOperationStatusResp( @@ -506,7 +655,9 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv for op_state_resp, exec_resp_type in itertools.product( [error_resp, closed_resp], self.execute_response_types ): - with self.subTest(op_state_resp=op_state_resp, exec_resp_type=exec_resp_type): + with self.subTest( + op_state_resp=op_state_resp, exec_resp_type=exec_resp_type + ): tcli_service_instance = tcli_service_class.return_value t_execute_resp = exec_resp_type( status=self.okay_status, @@ -516,7 +667,12 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv tcli_service_instance.GetOperationStatus.return_value = op_state_resp thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: @@ -539,12 +695,23 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) t_execute_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, + ) + tcli_service_instance.GetOperationStatus.return_value = ( + t_get_operation_status_resp ) - tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -577,7 +744,14 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -589,7 +763,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resp_1 = resp_type( status=self.okay_status, directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp(status=self.bad_status), + operationStatus=ttypes.TGetOperationStatusResp( + status=self.bad_status + ), resultSetMetadata=None, resultSet=None, closeOperation=None, @@ -600,7 +776,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): status=self.okay_status, directResults=ttypes.TSparkDirectResults( operationStatus=None, - resultSetMetadata=ttypes.TGetResultSetMetadataResp(status=self.bad_status), + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + status=self.bad_status + ), resultSet=None, closeOperation=None, ), @@ -629,7 +807,12 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: @@ -637,7 +820,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): self.assertIn("this is a bad error", str(cm.exception)) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_can_handle_without_direct_results(self, tcli_service_class): + def test_handle_execute_response_can_handle_without_direct_results( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value for resp_type in self.execute_response_types: @@ -660,23 +845,32 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se ) op_state_3 = ttypes.TGetOperationStatusResp( - status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ) - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) tcli_service_instance.GetOperationStatus.side_effect = [ op_state_1, op_state_2, op_state_3, ] thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) self.assertEqual( - results_message_response.status, ttypes.TOperationState.FINISHED_STATE + results_message_response.status, + ttypes.TOperationState.FINISHED_STATE, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -701,7 +895,12 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -731,9 +930,13 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = t_get_result_set_metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + t_get_result_set_metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + execute_response = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @@ -760,10 +963,13 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( - hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0] + hive_schema_mock, + thrift_backend._hive_schema_to_arrow_schema.call_args[0][0], ) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue @@ -794,14 +1000,20 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response(execute_resp, Mock()) + execute_response = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) self.assertEqual(has_more_rows, execute_response.has_more_rows) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue @@ -836,8 +1048,12 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( ) tcli_service_instance.FetchResults.return_value = fetch_results_resp - tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetOperationStatus.return_value = ( + operation_status_resp + ) + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -864,7 +1080,8 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): startRowOffset=0, rows=[], arrowBatches=[ - ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) + ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) + for _ in range(10) ], ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( @@ -885,7 +1102,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -899,11 +1123,20 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_execute_statement_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -914,14 +1147,25 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s self.assertEqual(req.getDirectResults, get_direct_results) self.assertEqual(req.statement, "foo") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_catalogs_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -931,14 +1175,25 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) self.assertEqual(req.getDirectResults, get_direct_results) # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_schemas_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -957,14 +1212,25 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_tables_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -987,14 +1253,25 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.tableTypes, ["type1", "type2"]) # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_columns_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -1017,21 +1294,37 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.columnName, "column_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.close_command(self.operation_handle) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, @@ -1041,20 +1334,32 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.close_session(self.session_handle) self.assertEqual( - tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle + tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, + self.session_handle, ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_class): + def test_non_arrow_non_column_based_set_triggers_exception( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 execute_statement_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, ) metadata_resp = ttypes.TGetResultSetMetadataResp( @@ -1075,11 +1380,20 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) - self.assertIn("Expected results to be in Arrow or column based format", str(cm.exception)) + self.assertIn( + "Expected results to be in Arrow or column based format", str(cm.exception) + ) def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1088,7 +1402,14 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1099,18 +1420,31 @@ def test_create_arrow_table_calls_correct_conversion_method( description = Mock() t_col_set = ttypes.TRowSet(columns=cols) - thrift_backend._create_arrow_table(t_col_set, lz4_compressed, schema, description) + thrift_backend._create_arrow_table( + t_col_set, lz4_compressed, schema, description + ) convert_arrow_mock.assert_not_called() convert_col_mock.assert_called_once_with(cols, description) t_arrow_set = ttypes.TRowSet(arrowBatches=arrow_batches) thrift_backend._create_arrow_table(t_arrow_set, lz4_compressed, schema, Mock()) - convert_arrow_mock.assert_called_once_with(arrow_batches, lz4_compressed, schema) + convert_arrow_mock.assert_called_once_with( + arrow_batches, lz4_compressed, schema + ) @patch("lz4.frame.decompress") @patch("pyarrow.ipc.open_stream") - def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + def test_convert_arrow_based_set_to_arrow_table( + self, open_stream_mock, lz4_decompress_mock + ): + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1142,15 +1476,23 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes(1) + ) ), - ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1175,19 +1517,29 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self): description = [(name,) for name in field_names] t_cols = [ - ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1]))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes([2])) + i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1])) ), ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes([4])) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes([2]) + ) ), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3])) + doubleVal=ttypes.TDoubleColumn( + values=[1.15, 2.2, 3.3], nulls=bytes([4]) + ) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3]) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check data @@ -1203,15 +1555,23 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes(1) + ) ), - ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1257,8 +1617,12 @@ def test_handle_execute_response_sets_active_op_handle(self): self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" + ) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class ): @@ -1270,7 +1634,9 @@ def test_make_request_will_retry_GetOperationStatus( this_gos_name = "GetOperationStatus" mock_GetOperationStatus.__name__ = this_gos_name - mock_GetOperationStatus.side_effect = OSError(errno.ETIMEDOUT, "Connection timed out") + mock_GetOperationStatus.side_effect = OSError( + errno.ETIMEDOUT, "Connection timed out" + ) protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class) client = Client(protocol) @@ -1299,12 +1665,18 @@ def test_make_request_will_retry_GetOperationStatus( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) + self.assertEqual( + f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] + ) # Unusual OSError code - mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") + mock_GetOperationStatus.side_effect = OSError( + errno.EEXIST, "File does not exist" + ) - with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: + with self.assertLogs( + "databricks.sql.thrift_backend", level=logging.WARNING + ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1320,8 +1692,12 @@ def test_make_request_will_retry_GetOperationStatus( cm.output[0], ) - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" + ) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos ): @@ -1366,10 +1742,14 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) + self.assertEqual( + f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] + ) @patch("thrift.transport.THttpClient.THttpClient") - def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_class): + def test_make_request_wont_retry_if_error_code_not_429_or_503( + self, t_transport_class + ): t_transport_instance = t_transport_class.return_value t_transport_instance.code = 430 t_transport_instance.headers = {"Retry-After": "1"} @@ -1377,7 +1757,14 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1385,7 +1772,9 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ self.assertIn("This method fails", str(cm.exception.message_with_context())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class ): @@ -1417,13 +1806,22 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self.assertEqual(mock_method.call_count, 14) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - def test_make_request_will_read_error_message_headers_if_set(self, t_transport_class): + def test_make_request_will_read_error_message_headers_if_set( + self, t_transport_class + ): t_transport_instance = t_transport_class.return_value mock_method = Mock() mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) error_headers = [ [("x-thriftserver-error-message", "thrift server error message")], @@ -1457,12 +1855,18 @@ def make_table_and_desc( ): int_col = [int_constant for _ in range(height)] decimal_col = [decimal_constant for _ in range(height)] - data = OrderedDict({"col{}".format(i): int_col for i in range(width - n_decimal_cols)}) - decimals = OrderedDict({"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)}) + data = OrderedDict( + {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} + ) + decimals = OrderedDict( + {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} + ) data.update(decimals) int_desc = [("", "int")] * (width - n_decimal_cols) - decimal_desc = [("", "decimal", None, None, precision, scale, None)] * n_decimal_cols + decimal_desc = [ + ("", "decimal", None, None, precision, scale, None) + ] * n_decimal_cols description = int_desc + decimal_desc table = pyarrow.Table.from_pydict(data) @@ -1495,25 +1899,36 @@ def test_arrow_decimal_conversion(self): if height > 0: if i < width - n_decimal_cols: self.assertEqual( - decimal_converted_table.field(i).type, pyarrow.int64() + decimal_converted_table.field(i).type, + pyarrow.int64(), ) else: self.assertEqual( decimal_converted_table.field(i).type, - pyarrow.decimal128(precision=precision, scale=scale), + pyarrow.decimal128( + precision=precision, scale=scale + ), ) int_col = [int_constant for _ in range(height)] decimal_col = [Decimal(decimal_constant) for _ in range(height)] expected_result = OrderedDict( - {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} + { + "col{}".format(i): int_col + for i in range(width - n_decimal_cols) + } ) decimals = OrderedDict( - {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} + { + "col_dec{}".format(i): decimal_col + for i in range(n_decimal_cols) + } ) expected_result.update(decimals) - self.assertEqual(decimal_converted_table.to_pydict(), expected_result) + self.assertEqual( + decimal_converted_table.to_pydict(), expected_result + ) @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_passthrough(self, mock_http_client): @@ -1524,7 +1939,13 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_duration": 100, } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **retry_delay_args, ) for arg, val in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1533,17 +1954,28 @@ def test_retry_args_passthrough(self, mock_http_client): def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): - retry_delay_test_args_and_expected_values[k] = ((min - 1, min), (max + 1, max)) + retry_delay_test_args_and_expected_values[k] = ( + (min - 1, min), + (max + 1, max), + ) for i in range(2): retry_delay_args = { - k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][0] + for (k, v) in retry_delay_test_args_and_expected_values.items() } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **retry_delay_args, ) retry_delay_expected_vals = { - k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][1] + for (k, v) in retry_delay_test_args_and_expected_values.items() } for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) @@ -1560,7 +1992,14 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1571,7 +2010,14 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1590,7 +2036,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1601,26 +2054,46 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): backend.open_session({}, cat, schem) - open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] + open_session_req = tcli_client_class.return_value.OpenSession.call_args[ + 0 + ][0] self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_class): + def test_can_use_multiple_catalogs_is_set_in_open_session_req( + self, tcli_client_class + ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): + def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( + self, tcli_client_class + ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1658,13 +2131,21 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), ) - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") self.assertIn( - "Setting initial namespace not supported by the DBR version", str(cm.exception) + "Setting initial namespace not supported by the DBR version", + str(cm.exception), ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) @@ -1686,10 +2167,18 @@ def test_execute_command_sets_complex_type_fields_correctly( complex_arg_types["_use_arrow_native_decimals"] = decimals thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **complex_arg_types + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **complex_arg_types, ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) - t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] + t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ + 0 + ][0] # If the value is unset, the native type should default to True self.assertEqual( t_execute_statement_req.useArrowNativeTypes.timestampAsArrow, @@ -1703,7 +2192,9 @@ def test_execute_command_sets_complex_type_fields_correctly( t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow, complex_arg_types.get("_use_arrow_native_complex_types", True), ) - self.assertFalse(t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow) + self.assertFalse( + t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow + ) if __name__ == "__main__":