Skip to content

Commit

Permalink
Reformatted all the files using black (#448)
Browse files Browse the repository at this point in the history
Reformatted the files using black
  • Loading branch information
jprakash-db authored Oct 3, 2024
1 parent 08f14a0 commit 97c815e
Show file tree
Hide file tree
Showing 36 changed files with 1,521 additions and 580 deletions.
22 changes: 13 additions & 9 deletions examples/custom_cred_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,27 @@
from databricks.sdk.oauth import OAuthClient
import os

oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
redirect_url=os.getenv("APP_REDIRECT_URL"),
scopes=['all-apis', 'offline_access'])
oauth_client = OAuthClient(
host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
redirect_url=os.getenv("APP_REDIRECT_URL"),
scopes=["all-apis", "offline_access"],
)

consent = oauth_client.initiate_consent()

creds = consent.launch_external_browser()

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=creds) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=creds,
) as connection:

for x in range(1, 5):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
cursor.execute("SELECT 1+1")
result = cursor.fetchall()
for row in result:
print(row)
Expand Down
26 changes: 14 additions & 12 deletions examples/insert_data.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from databricks import sql
import os

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:

with connection.cursor() as cursor:
cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)")
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)")

squares = [(i, i * i) for i in range(100)]
values = ",".join([f"({x}, {y})" for (x, y) in squares])
squares = [(i, i * i) for i in range(100)]
values = ",".join([f"({x}, {y})" for (x, y) in squares])

cursor.execute(f"INSERT INTO squares VALUES {values}")
cursor.execute(f"INSERT INTO squares VALUES {values}")

cursor.execute("SELECT * FROM squares LIMIT 10")
cursor.execute("SELECT * FROM squares LIMIT 10")

result = cursor.fetchall()
result = cursor.fetchall()

for row in result:
print(row)
for row in result:
print(row)
8 changes: 5 additions & 3 deletions examples/interactive_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
token across script executions.
"""

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH")) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
) as connection:

for x in range(1, 100):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
cursor.execute("SELECT 1+1")
result = cursor.fetchall()
for row in result:
print(row)
Expand Down
12 changes: 7 additions & 5 deletions examples/m2m_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@ def credential_provider():
# Service Principal UUID
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
# Service Principal Secret
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"))
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
)
return oauth_service_principal(config)


with sql.connect(
server_hostname=server_hostname,
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=credential_provider) as connection:
server_hostname=server_hostname,
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=credential_provider,
) as connection:
for x in range(1, 100):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
cursor.execute("SELECT 1+1")
result = cursor.fetchall()
for row in result:
print(row)
Expand Down
47 changes: 27 additions & 20 deletions examples/persistent_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,44 @@
from typing import Optional

from databricks import sql
from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken, DevOnlyFilePersistence
from databricks.sql.experimental.oauth_persistence import (
OAuthPersistence,
OAuthToken,
DevOnlyFilePersistence,
)


class SampleOAuthPersistence(OAuthPersistence):
def persist(self, hostname: str, oauth_token: OAuthToken):
"""To be implemented by the end user to persist in the preferred storage medium.
def persist(self, hostname: str, oauth_token: OAuthToken):
"""To be implemented by the end user to persist in the preferred storage medium.
OAuthToken has two properties:
1. OAuthToken.access_token
2. OAuthToken.refresh_token
OAuthToken has two properties:
1. OAuthToken.access_token
2. OAuthToken.refresh_token
Both should be persisted.
"""
pass
Both should be persisted.
"""
pass

def read(self, hostname: str) -> Optional[OAuthToken]:
"""To be implemented by the end user to fetch token from the preferred storage
def read(self, hostname: str) -> Optional[OAuthToken]:
"""To be implemented by the end user to fetch token from the preferred storage
Fetch the access_token and refresh_token for the given hostname.
Return OAuthToken(access_token, refresh_token)
"""
pass
Fetch the access_token and refresh_token for the given hostname.
Return OAuthToken(access_token, refresh_token)
"""
pass

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
auth_type="databricks-oauth",
experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json")) as connection:

with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
auth_type="databricks-oauth",
experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json"),
) as connection:

for x in range(1, 100):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
cursor.execute("SELECT 1+1")
result = cursor.fetchall()
for row in result:
print(row)
Expand Down
69 changes: 37 additions & 32 deletions examples/query_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,52 @@
The current operation of a cursor may be cancelled by calling its `.cancel()` method as shown in the example below.
"""

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:

with connection.cursor() as cursor:
def execute_really_long_query():
try:
cursor.execute("SELECT SUM(A.id - B.id) " +
"FROM range(1000000000) A CROSS JOIN range(100000000) B " +
"GROUP BY (A.id - B.id)")
except sql.exc.RequestError:
print("It looks like this query was cancelled.")
with connection.cursor() as cursor:

exec_thread = threading.Thread(target=execute_really_long_query)
def execute_really_long_query():
try:
cursor.execute(
"SELECT SUM(A.id - B.id) "
+ "FROM range(1000000000) A CROSS JOIN range(100000000) B "
+ "GROUP BY (A.id - B.id)"
)
except sql.exc.RequestError:
print("It looks like this query was cancelled.")

print("\n Beginning to execute long query")
exec_thread.start()
exec_thread = threading.Thread(target=execute_really_long_query)

# Make sure the query has started before cancelling
print("\n Waiting 15 seconds before canceling", end="", flush=True)
print("\n Beginning to execute long query")
exec_thread.start()

seconds_waited = 0
while seconds_waited < 15:
seconds_waited += 1
print(".", end="", flush=True)
time.sleep(1)
# Make sure the query has started before cancelling
print("\n Waiting 15 seconds before canceling", end="", flush=True)

print("\n Cancelling the cursor's operation. This can take a few seconds.")
cursor.cancel()
seconds_waited = 0
while seconds_waited < 15:
seconds_waited += 1
print(".", end="", flush=True)
time.sleep(1)

print("\n Now checking the cursor status:")
exec_thread.join(5)
print("\n Cancelling the cursor's operation. This can take a few seconds.")
cursor.cancel()

assert not exec_thread.is_alive()
print("\n The previous command was successfully canceled")
print("\n Now checking the cursor status:")
exec_thread.join(5)

print("\n Now reusing the cursor to run a separate query.")
assert not exec_thread.is_alive()
print("\n The previous command was successfully canceled")

# We can still execute a new command on the cursor
cursor.execute("SELECT * FROM range(3)")
print("\n Now reusing the cursor to run a separate query.")

print("\n Execution was successful. Results appear below:")
# We can still execute a new command on the cursor
cursor.execute("SELECT * FROM range(3)")

print(cursor.fetchall())
print("\n Execution was successful. Results appear below:")

print(cursor.fetchall())
18 changes: 10 additions & 8 deletions examples/query_execute.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from databricks import sql
import os

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()

for row in result:
print(row)
for row in result:
print(row)
20 changes: 11 additions & 9 deletions examples/set_user_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from databricks import sql
import os

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN"),
_user_agent_entry="ExamplePartnerTag") as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
_user_agent_entry="ExamplePartnerTag",
) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()

for row in result:
print(row)
for row in result:
print(row)
24 changes: 13 additions & 11 deletions examples/v3_retries_query_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@
#
# For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN"),
_enable_v3_retries = True,
_retry_dangerous_codes=[502,400],
_retry_max_redirects=2) as connection:
with sql.connect(
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
_enable_v3_retries=True,
_retry_dangerous_codes=[502, 400],
_retry_max_redirects=2,
) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()

for row in result:
print(row)
for row in result:
print(row)
16 changes: 10 additions & 6 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

import pandas

try:
import pyarrow
except ImportError:
Expand All @@ -26,7 +27,7 @@
inject_parameters,
transform_paramstyle,
ColumnTable,
ColumnQueue
ColumnQueue,
)
from databricks.sql.parameters.native import (
DbsqlParameterBase,
Expand Down Expand Up @@ -1155,7 +1156,7 @@ def _convert_columnar_table(self, table):
for row_index in range(table.num_rows):
curr_row = []
for col_index in range(table.num_columns):
curr_row.append(table.get_item(col_index, row_index))
curr_row.append(table.get_item(col_index, row_index))
result.append(ResultRow(*curr_row))

return result
Expand Down Expand Up @@ -1238,7 +1239,10 @@ def merge_columnar(self, result1, result2):
if result1.column_names != result2.column_names:
raise ValueError("The columns in the results don't match")

merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)]
merged_result = [
result1.column_table[i] + result2.column_table[i]
for i in range(result1.num_columns)
]
return ColumnTable(merged_result, result1.column_names)

def fetchmany_columnar(self, size: int):
Expand All @@ -1254,9 +1258,9 @@ def fetchmany_columnar(self, size: int):
self._next_row_index += results.num_rows

while (
n_remaining_rows > 0
and not self.has_been_closed_server_side
and self.has_more_rows
n_remaining_rows > 0
and not self.has_been_closed_server_side
and self.has_more_rows
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
Expand Down
Loading

0 comments on commit 97c815e

Please sign in to comment.