Skip to content

Commit 97c815e

Browse files
authored
Reformatted all the files using black (databricks#448)
Reformatted the files using black
1 parent 08f14a0 commit 97c815e

36 files changed

+1521
-580
lines changed

examples/custom_cred_provider.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,27 @@
44
from databricks.sdk.oauth import OAuthClient
55
import os
66

7-
oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
8-
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
9-
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
10-
redirect_url=os.getenv("APP_REDIRECT_URL"),
11-
scopes=['all-apis', 'offline_access'])
7+
oauth_client = OAuthClient(
8+
host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
9+
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
10+
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
11+
redirect_url=os.getenv("APP_REDIRECT_URL"),
12+
scopes=["all-apis", "offline_access"],
13+
)
1214

1315
consent = oauth_client.initiate_consent()
1416

1517
creds = consent.launch_external_browser()
1618

17-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
18-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
19-
credentials_provider=creds) as connection:
19+
with sql.connect(
20+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
21+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
22+
credentials_provider=creds,
23+
) as connection:
2024

2125
for x in range(1, 5):
2226
cursor = connection.cursor()
23-
cursor.execute('SELECT 1+1')
27+
cursor.execute("SELECT 1+1")
2428
result = cursor.fetchall()
2529
for row in result:
2630
print(row)

examples/insert_data.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from databricks import sql
22
import os
33

4-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
5-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
6-
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
4+
with sql.connect(
5+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
6+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
7+
access_token=os.getenv("DATABRICKS_TOKEN"),
8+
) as connection:
79

8-
with connection.cursor() as cursor:
9-
cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)")
10+
with connection.cursor() as cursor:
11+
cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)")
1012

11-
squares = [(i, i * i) for i in range(100)]
12-
values = ",".join([f"({x}, {y})" for (x, y) in squares])
13+
squares = [(i, i * i) for i in range(100)]
14+
values = ",".join([f"({x}, {y})" for (x, y) in squares])
1315

14-
cursor.execute(f"INSERT INTO squares VALUES {values}")
16+
cursor.execute(f"INSERT INTO squares VALUES {values}")
1517

16-
cursor.execute("SELECT * FROM squares LIMIT 10")
18+
cursor.execute("SELECT * FROM squares LIMIT 10")
1719

18-
result = cursor.fetchall()
20+
result = cursor.fetchall()
1921

20-
for row in result:
21-
print(row)
22+
for row in result:
23+
print(row)

examples/interactive_oauth.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
token across script executions.
1414
"""
1515

16-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
17-
http_path = os.getenv("DATABRICKS_HTTP_PATH")) as connection:
16+
with sql.connect(
17+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
18+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
19+
) as connection:
1820

1921
for x in range(1, 100):
2022
cursor = connection.cursor()
21-
cursor.execute('SELECT 1+1')
23+
cursor.execute("SELECT 1+1")
2224
result = cursor.fetchall()
2325
for row in result:
2426
print(row)

examples/m2m_oauth.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@ def credential_provider():
2222
# Service Principal UUID
2323
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
2424
# Service Principal Secret
25-
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"))
25+
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
26+
)
2627
return oauth_service_principal(config)
2728

2829

2930
with sql.connect(
30-
server_hostname=server_hostname,
31-
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
32-
credentials_provider=credential_provider) as connection:
31+
server_hostname=server_hostname,
32+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
33+
credentials_provider=credential_provider,
34+
) as connection:
3335
for x in range(1, 100):
3436
cursor = connection.cursor()
35-
cursor.execute('SELECT 1+1')
37+
cursor.execute("SELECT 1+1")
3638
result = cursor.fetchall()
3739
for row in result:
3840
print(row)

examples/persistent_oauth.py

+27-20
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,44 @@
1717
from typing import Optional
1818

1919
from databricks import sql
20-
from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken, DevOnlyFilePersistence
20+
from databricks.sql.experimental.oauth_persistence import (
21+
OAuthPersistence,
22+
OAuthToken,
23+
DevOnlyFilePersistence,
24+
)
2125

2226

2327
class SampleOAuthPersistence(OAuthPersistence):
24-
def persist(self, hostname: str, oauth_token: OAuthToken):
25-
"""To be implemented by the end user to persist in the preferred storage medium.
28+
def persist(self, hostname: str, oauth_token: OAuthToken):
29+
"""To be implemented by the end user to persist in the preferred storage medium.
2630
27-
OAuthToken has two properties:
28-
1. OAuthToken.access_token
29-
2. OAuthToken.refresh_token
31+
OAuthToken has two properties:
32+
1. OAuthToken.access_token
33+
2. OAuthToken.refresh_token
3034
31-
Both should be persisted.
32-
"""
33-
pass
35+
Both should be persisted.
36+
"""
37+
pass
3438

35-
def read(self, hostname: str) -> Optional[OAuthToken]:
36-
"""To be implemented by the end user to fetch token from the preferred storage
39+
def read(self, hostname: str) -> Optional[OAuthToken]:
40+
"""To be implemented by the end user to fetch token from the preferred storage
3741
38-
Fetch the access_token and refresh_token for the given hostname.
39-
Return OAuthToken(access_token, refresh_token)
40-
"""
41-
pass
42+
Fetch the access_token and refresh_token for the given hostname.
43+
Return OAuthToken(access_token, refresh_token)
44+
"""
45+
pass
4246

43-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
44-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
45-
auth_type="databricks-oauth",
46-
experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json")) as connection:
47+
48+
with sql.connect(
49+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
50+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
51+
auth_type="databricks-oauth",
52+
experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json"),
53+
) as connection:
4754

4855
for x in range(1, 100):
4956
cursor = connection.cursor()
50-
cursor.execute('SELECT 1+1')
57+
cursor.execute("SELECT 1+1")
5158
result = cursor.fetchall()
5259
for row in result:
5360
print(row)

examples/query_cancel.py

+37-32
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,52 @@
55
The current operation of a cursor may be cancelled by calling its `.cancel()` method as shown in the example below.
66
"""
77

8-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
9-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
10-
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
8+
with sql.connect(
9+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
10+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
11+
access_token=os.getenv("DATABRICKS_TOKEN"),
12+
) as connection:
1113

12-
with connection.cursor() as cursor:
13-
def execute_really_long_query():
14-
try:
15-
cursor.execute("SELECT SUM(A.id - B.id) " +
16-
"FROM range(1000000000) A CROSS JOIN range(100000000) B " +
17-
"GROUP BY (A.id - B.id)")
18-
except sql.exc.RequestError:
19-
print("It looks like this query was cancelled.")
14+
with connection.cursor() as cursor:
2015

21-
exec_thread = threading.Thread(target=execute_really_long_query)
16+
def execute_really_long_query():
17+
try:
18+
cursor.execute(
19+
"SELECT SUM(A.id - B.id) "
20+
+ "FROM range(1000000000) A CROSS JOIN range(100000000) B "
21+
+ "GROUP BY (A.id - B.id)"
22+
)
23+
except sql.exc.RequestError:
24+
print("It looks like this query was cancelled.")
2225

23-
print("\n Beginning to execute long query")
24-
exec_thread.start()
26+
exec_thread = threading.Thread(target=execute_really_long_query)
2527

26-
# Make sure the query has started before cancelling
27-
print("\n Waiting 15 seconds before canceling", end="", flush=True)
28+
print("\n Beginning to execute long query")
29+
exec_thread.start()
2830

29-
seconds_waited = 0
30-
while seconds_waited < 15:
31-
seconds_waited += 1
32-
print(".", end="", flush=True)
33-
time.sleep(1)
31+
# Make sure the query has started before cancelling
32+
print("\n Waiting 15 seconds before canceling", end="", flush=True)
3433

35-
print("\n Cancelling the cursor's operation. This can take a few seconds.")
36-
cursor.cancel()
34+
seconds_waited = 0
35+
while seconds_waited < 15:
36+
seconds_waited += 1
37+
print(".", end="", flush=True)
38+
time.sleep(1)
3739

38-
print("\n Now checking the cursor status:")
39-
exec_thread.join(5)
40+
print("\n Cancelling the cursor's operation. This can take a few seconds.")
41+
cursor.cancel()
4042

41-
assert not exec_thread.is_alive()
42-
print("\n The previous command was successfully canceled")
43+
print("\n Now checking the cursor status:")
44+
exec_thread.join(5)
4345

44-
print("\n Now reusing the cursor to run a separate query.")
46+
assert not exec_thread.is_alive()
47+
print("\n The previous command was successfully canceled")
4548

46-
# We can still execute a new command on the cursor
47-
cursor.execute("SELECT * FROM range(3)")
49+
print("\n Now reusing the cursor to run a separate query.")
4850

49-
print("\n Execution was successful. Results appear below:")
51+
# We can still execute a new command on the cursor
52+
cursor.execute("SELECT * FROM range(3)")
5053

51-
print(cursor.fetchall())
54+
print("\n Execution was successful. Results appear below:")
55+
56+
print(cursor.fetchall())

examples/query_execute.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from databricks import sql
22
import os
33

4-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
5-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
6-
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:
4+
with sql.connect(
5+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
6+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
7+
access_token=os.getenv("DATABRICKS_TOKEN"),
8+
) as connection:
79

8-
with connection.cursor() as cursor:
9-
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
10-
result = cursor.fetchall()
10+
with connection.cursor() as cursor:
11+
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
12+
result = cursor.fetchall()
1113

12-
for row in result:
13-
print(row)
14+
for row in result:
15+
print(row)

examples/set_user_agent.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from databricks import sql
22
import os
33

4-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
5-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
6-
access_token = os.getenv("DATABRICKS_TOKEN"),
7-
_user_agent_entry="ExamplePartnerTag") as connection:
4+
with sql.connect(
5+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
6+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
7+
access_token=os.getenv("DATABRICKS_TOKEN"),
8+
_user_agent_entry="ExamplePartnerTag",
9+
) as connection:
810

9-
with connection.cursor() as cursor:
10-
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
11-
result = cursor.fetchall()
11+
with connection.cursor() as cursor:
12+
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
13+
result = cursor.fetchall()
1214

13-
for row in result:
14-
print(row)
15+
for row in result:
16+
print(row)

examples/v3_retries_query_execute.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,18 @@
2828
#
2929
# For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend
3030

31-
with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
32-
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
33-
access_token = os.getenv("DATABRICKS_TOKEN"),
34-
_enable_v3_retries = True,
35-
_retry_dangerous_codes=[502,400],
36-
_retry_max_redirects=2) as connection:
31+
with sql.connect(
32+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
33+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
34+
access_token=os.getenv("DATABRICKS_TOKEN"),
35+
_enable_v3_retries=True,
36+
_retry_dangerous_codes=[502, 400],
37+
_retry_max_redirects=2,
38+
) as connection:
3739

38-
with connection.cursor() as cursor:
39-
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
40-
result = cursor.fetchall()
40+
with connection.cursor() as cursor:
41+
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
42+
result = cursor.fetchall()
4143

42-
for row in result:
43-
print(row)
44+
for row in result:
45+
print(row)

src/databricks/sql/client.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
22

33
import pandas
4+
45
try:
56
import pyarrow
67
except ImportError:
@@ -26,7 +27,7 @@
2627
inject_parameters,
2728
transform_paramstyle,
2829
ColumnTable,
29-
ColumnQueue
30+
ColumnQueue,
3031
)
3132
from databricks.sql.parameters.native import (
3233
DbsqlParameterBase,
@@ -1155,7 +1156,7 @@ def _convert_columnar_table(self, table):
11551156
for row_index in range(table.num_rows):
11561157
curr_row = []
11571158
for col_index in range(table.num_columns):
1158-
curr_row.append(table.get_item(col_index, row_index))
1159+
curr_row.append(table.get_item(col_index, row_index))
11591160
result.append(ResultRow(*curr_row))
11601161

11611162
return result
@@ -1238,7 +1239,10 @@ def merge_columnar(self, result1, result2):
12381239
if result1.column_names != result2.column_names:
12391240
raise ValueError("The columns in the results don't match")
12401241

1241-
merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)]
1242+
merged_result = [
1243+
result1.column_table[i] + result2.column_table[i]
1244+
for i in range(result1.num_columns)
1245+
]
12421246
return ColumnTable(merged_result, result1.column_names)
12431247

12441248
def fetchmany_columnar(self, size: int):
@@ -1254,9 +1258,9 @@ def fetchmany_columnar(self, size: int):
12541258
self._next_row_index += results.num_rows
12551259

12561260
while (
1257-
n_remaining_rows > 0
1258-
and not self.has_been_closed_server_side
1259-
and self.has_more_rows
1261+
n_remaining_rows > 0
1262+
and not self.has_been_closed_server_side
1263+
and self.has_more_rows
12601264
):
12611265
self._fill_results_buffer()
12621266
partial_results = self.results.next_n_rows(n_remaining_rows)

0 commit comments

Comments
 (0)