Commit 0ddca9d

Refactored the test code and moved it to the respective folders
1 parent c576110 commit 0ddca9d

Large commits have some content hidden by default; only a subset of the changed files is shown below.

46 files changed: +588 −633 lines

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py

−3 lines

@@ -784,8 +784,6 @@ def execute(
             parameters=prepared_params,
         )
 
-        # print("Line 781")
-        # print(execute_response)
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
@@ -1141,7 +1139,6 @@ def _fill_results_buffer(self):
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
-        # print("Table\n",table)
         result = []
         for row_index in range(len(table[0])):
             curr_row = []
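
The second hunk trims a debug print out of _convert_columnar_table, which turns a column-oriented table into row objects. A minimal, self-contained sketch of that conversion (using namedtuple in place of the connector's Row factory, so this is illustrative rather than the connector's code):

    from collections import namedtuple

    def convert_columnar_table(table, description):
        # 'table' is a list of equal-length columns; 'description' is a
        # DB-API style sequence of (name, type, ...) tuples.
        column_names = [col[0] for col in description]
        ResultRow = namedtuple("ResultRow", column_names)
        result = []
        for row_index in range(len(table[0])):
            # Take the value at row_index from every column to form one row.
            result.append(ResultRow(*[col[row_index] for col in table]))
        return result

    rows = convert_columnar_table([[1, 2], ["a", "b"]], [("id",), ("name",)])
    assert rows[0].id == 1 and rows[1].name == "b"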

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py

+2 −13 lines

@@ -743,7 +743,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle)
 
-        # print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
@@ -880,18 +879,8 @@ def execute_command(
                 # We want to receive proper Timestamp arrow types.
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
             },
-            # useArrowNativeTypes=spark_arrow_types,
-            # canReadArrowResult=True,
-            # # canDecompressLZ4Result=lz4_compression,
-            # canDecompressLZ4Result=False,
-            # canDownloadResult=False,
-            # # confOverlay={
-            # #     # We want to receive proper Timestamp arrow types.
-            # #     "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
-            # # },
-            # resultDataFormat=TDBSqlResultFormat(None,None,True),
-            # # useArrowNativeTypes=spark_arrow_types,
-            parameters=parameters,
+            useArrowNativeTypes=spark_arrow_types,
+            parameters=parameters,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)
         return self._handle_execute_response(resp, cursor)
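
Net effect of the second hunk: the experimental commented-out request flags are dropped and useArrowNativeTypes is restored as a live keyword argument, so the execute request reduces to roughly the shape below. This is a hedged sketch; the request class name and the sessionHandle/statement fields are assumed from typical Thrift usage and are not part of this hunk:

    req = ttypes.TExecuteStatementReq(
        sessionHandle=session_handle,   # assumed field, not shown in this hunk
        statement=operation,            # assumed field, not shown in this hunk
        confOverlay={
            # We want to receive proper Timestamp arrow types.
            "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
        },
        useArrowNativeTypes=spark_arrow_types,
        parameters=parameters,
    )
    resp = self.make_request(self._client.ExecuteStatement, req)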

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py

−31 lines

@@ -74,18 +74,6 @@ def build_queue(
         Returns:
             ResultSetQueue
         """
-
-        # def trow_to_json(trow):
-        #     # Step 1: Serialize TRow using Thrift's TJSONProtocol
-        #     transport = TTransport.TMemoryBuffer()
-        #     protocol = TJSONProtocol.TJSONProtocol(transport)
-        #     trow.write(protocol)
-        #
-        #     # Step 2: Extract JSON string from the transport
-        #     json_str = transport.getvalue().decode('utf-8')
-        #
-        #     return json_str
-
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
@@ -95,30 +83,11 @@ def build_queue(
             )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
-            # print("Lin 79 ")
-            # print(type(t_row_set))
-            # print(t_row_set)
-            # json_str = json.loads(trow_to_json(t_row_set))
-            # pretty_json = json.dumps(json_str, indent=2)
-            # print(pretty_json)
-
             converted_column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns,
                 description)
-            # print(converted_column_table, column_names)
 
             return ColumnQueue(converted_column_table, column_names)
-
-            # print(columnQueue.next_n_rows(2))
-            # print(columnQueue.next_n_rows(2))
-            # print(columnQueue.remaining_rows())
-            # arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
-            #     t_row_set.columns, description
-            # )
-            # converted_arrow_table = convert_decimals_in_arrow_table(
-            #     arrow_table, description
-            # )
-            # return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.URL_BASED_SET:
             return CloudFetchQueue(
                 schema_bytes=arrow_schema_bytes,
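
The removed debug comments reference ColumnQueue's next_n_rows / remaining_rows interface. For orientation, a toy column-oriented queue with that interface (illustrative only, not the connector's ColumnQueue implementation):

    class ColumnQueueSketch:
        """Minimal column-oriented result queue; batches are produced by slicing each column."""

        def __init__(self, column_table, column_names):
            self.column_table = column_table
            self.column_names = column_names
            self.cur_row = 0
            self.num_rows = len(column_table[0]) if column_table else 0

        def next_n_rows(self, n):
            # Return up to n rows, still in column orientation, and advance the cursor.
            end = min(self.cur_row + n, self.num_rows)
            batch = [col[self.cur_row:end] for col in self.column_table]
            self.cur_row = end
            return batch

        def remaining_rows(self):
            # Return everything after the cursor and exhaust the queue.
            batch = [col[self.cur_row:] for col in self.column_table]
            self.cur_row = self.num_rows
            return batch

    q = ColumnQueueSketch([[1, 2, 3], ["a", "b", "c"]], ["id", "name"])
    assert q.next_n_rows(2) == [[1, 2], ["a", "b"]]
    assert q.remaining_rows() == [[3], ["c"]]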
File renamed without changes.

tests/e2e/common/predicates.py renamed to databricks_sql_connector_core/tests/e2e/common/predicates.py

+4 −4 lines

@@ -8,13 +8,13 @@
 
 
 def pysql_supports_arrow():
-    """Import databricks.sql and test whether Cursor has fetchall_arrow."""
-    from databricks.sql import Cursor
+    """Import databricks_sql_connector_core.sql and test whether Cursor has fetchall_arrow."""
+    from databricks_sql_connector_core.sql.client import Cursor
     return hasattr(Cursor, 'fetchall_arrow')
 
 
 def pysql_has_version(compare, version):
-    """Import databricks.sql, and return compare_module_version(...).
+    """Import databricks_sql_connector_core.sql, and return compare_module_version(...).
 
     Expected use:
         from common.predicates import pysql_has_version
@@ -98,4 +98,4 @@ def validate_version(version):
 
     mod_version = validate_version(module.__version__)
     req_version = validate_version(version)
-    return compare_versions(compare, mod_version, req_version)
+    return compare_versions(compare, mod_version, req_version)
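
The docstring's "Expected use" is as test gates; a hedged usage sketch (the skipif markers, the comparison string, and the test bodies are illustrative, only the import path comes from the docstring above):

    import pytest
    from common.predicates import pysql_has_version, pysql_supports_arrow

    @pytest.mark.skipif(not pysql_supports_arrow(), reason="Cursor lacks fetchall_arrow")
    def test_arrow_only_behavior():
        ...  # runs only when the installed connector exposes fetchall_arrow

    @pytest.mark.skipif(not pysql_has_version(">=", "2"), reason="needs connector >= 2")
    def test_version_gated_behavior():
        ...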

tests/e2e/common/retry_test_mixins.py renamed to databricks_sql_connector_core/tests/e2e/common/retry_test_mixins.py

+22 −22 lines

@@ -6,8 +6,8 @@
 import pytest
 from urllib3.exceptions import MaxRetryError
 
-from databricks.sql import DatabricksRetryPolicy
-from databricks.sql import (
+from databricks_sql_connector_core.sql.auth.retry import DatabricksRetryPolicy
+from databricks_sql_connector_core.sql.exc import (
     MaxRetryDurationError,
     NonRecoverableNetworkError,
     RequestError,
@@ -146,7 +146,7 @@ def test_retry_urllib3_settings_are_honored(self):
     def test_oserror_retries(self):
         """If a network error occurs during make_request, the request is retried according to policy"""
         with patch(
-            "urllib3.connectionpool.HTTPSConnectionPool._validate_conn",
+            "urllib3.connectionpool.HTTPSConnectionPool._validate_conn",
         ) as mock_validate_conn:
             mock_validate_conn.side_effect = OSError("Some arbitrary network error")
             with pytest.raises(MaxRetryError) as cm:
@@ -275,7 +275,7 @@ def test_retry_safe_execute_statement_retry_condition(self):
         ]
 
         with self.connection(
-            extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1}
+            extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1}
         ) as conn:
             with conn.cursor() as cursor:
                 # Code 502 is a Bad Gateway, which we commonly see in production under heavy load
@@ -318,9 +318,9 @@ def test_retry_abort_close_operation_on_404(self, caplog):
         with self.connection(extra_params={**self._retry_policy}) as conn:
             with conn.cursor() as curs:
                 with patch(
-                    "databricks.sql.utils.ExecuteResponse.has_been_closed_server_side",
-                    new_callable=PropertyMock,
-                    return_value=False,
+                    "databricks_sql_connector_core.sql.utils.ExecuteResponse.has_been_closed_server_side",
+                    new_callable=PropertyMock,
+                    return_value=False,
                 ):
                     # This call guarantees we have an open cursor at the server
                     curs.execute("SELECT 1")
@@ -340,10 +340,10 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self):
         with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj:
             with pytest.raises(MaxRetryError) as cm:
                 with self.connection(
-                    extra_params={
-                        **self._retry_policy,
-                        "_retry_max_redirects": max_redirects,
-                    }
+                    extra_params={
+                        **self._retry_policy,
+                        "_retry_max_redirects": max_redirects,
+                    }
                 ):
                     pass
         assert "too many redirects" == str(cm.value.reason)
@@ -362,9 +362,9 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self):
         with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj:
             with pytest.raises(MaxRetryError) as cm:
                 with self.connection(
-                    extra_params={
-                        **self._retry_policy,
-                    }
+                    extra_params={
+                        **self._retry_policy,
+                    }
                 ):
                     pass
 
@@ -394,13 +394,13 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self):
 
     def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog):
         with self.connection(
-            extra_params={
-                **self._retry_policy,
-                **{
-                    "_retry_max_redirects": 100,
-                    "_retry_stop_after_attempts_count": 1,
-                },
-            }
+            extra_params={
+                **self._retry_policy,
+                **{
+                    "_retry_max_redirects": 100,
+                    "_retry_stop_after_attempts_count": 1,
+                },
+            }
         ):
             assert "it will have no affect!" in caplog.text
 
@@ -433,4 +433,4 @@ def test_401_not_retried(self):
         with pytest.raises(RequestError) as cm:
             with self.connection(extra_params=self._retry_policy):
                 pass
-        assert isinstance(cm.value.args[1], NonRecoverableNetworkError)
+        assert isinstance(cm.value.args[1], NonRecoverableNetworkError)
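
The substantive change in this file is the import move; the other hunks appear to differ only in whitespace (the removed and added lines are textually identical apart from indentation). The old-to-new mapping, reconstructed from the import hunk above:

    # Before the refactor (removed lines):
    # from databricks.sql import DatabricksRetryPolicy
    # from databricks.sql import (MaxRetryDurationError, NonRecoverableNetworkError, RequestError)

    # After the refactor, as used by these tests:
    from databricks_sql_connector_core.sql.auth.retry import DatabricksRetryPolicy
    from databricks_sql_connector_core.sql.exc import (
        MaxRetryDurationError,
        NonRecoverableNetworkError,
        RequestError,
    )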

tests/e2e/common/staging_ingestion_tests.py renamed to databricks_sql_connector_core/tests/e2e/common/staging_ingestion_tests.py

+17 −17 lines

@@ -2,8 +2,8 @@
 import tempfile
 
 import pytest
-import databricks.sql as sql
-from databricks.sql import Error
+import databricks_sql_connector_core.sql as sql
+from databricks_sql_connector_core.sql import Error
 
 
 @pytest.fixture(scope="module", autouse=True)
@@ -100,7 +100,7 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, in
             cursor.execute(query)
 
     def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path(
-        self, ingestion_user
+        self, ingestion_user
     ):
 
         fh, temp_path = tempfile.mkstemp()
@@ -116,8 +116,8 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p
         base_path = os.path.join(base_path, "temp")
 
         with pytest.raises(
-            Error,
-            match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
+            Error,
+            match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
         ):
             with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn:
                 cursor = conn.cursor()
@@ -158,7 +158,7 @@ def perform_remove():
 
         # Try to put it again
         with pytest.raises(
-            sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"
+            sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"
         ):
             perform_put()
 
@@ -209,7 +209,7 @@ def perform_get():
             perform_get()
 
     def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path(
-        self, ingestion_user
+        self, ingestion_user
    ):
         """
         This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths.
@@ -222,11 +222,11 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe
         target_file = "/var/www/html/../html1/not_allowed.html"
 
         with pytest.raises(
-            Error,
-            match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
+            Error,
+            match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
         ):
             with self.connection(
-                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
+                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
             ) as conn:
                 cursor = conn.cursor()
                 query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
@@ -238,7 +238,7 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, inges
 
         with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"):
             with self.connection(
-                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
+                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
             ) as conn:
                 cursor = conn.cursor()
                 query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
@@ -250,14 +250,14 @@ def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_
 
         with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"):
             with self.connection(
-                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
+                extra_params={"staging_allowed_local_path": staging_allowed_local_path}
             ) as conn:
                 cursor = conn.cursor()
                 query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE"
                 cursor.execute(query)
 
     def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values(
-        self, ingestion_user
+        self, ingestion_user
     ):
         """staging_allowed_local_path may be either a path-like object or a list of path-like objects.
 
@@ -286,19 +286,19 @@ def generate_file_and_path_and_queries():
         fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries()
 
         with self.connection(
-            extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}
+            extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}
         ) as conn:
             cursor = conn.cursor()
 
             cursor.execute(put_query1)
             cursor.execute(put_query2)
 
             with pytest.raises(
-                Error,
-                match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
+                Error,
+                match="Local file operations are restricted to paths within the configured staging_allowed_local_path",
             ):
                 cursor.execute(put_query3)
 
             # Then clean up the files we made
             cursor.execute(remove_query1)
-            cursor.execute(remove_query2)
+            cursor.execute(remove_query2)
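
These tests drive the staging-ingestion path with PUT statements of the form shown above. A hedged end-to-end sketch follows; sql.connect() and its keyword arguments are assumed from the module import in this file, the host/path/token values are placeholders, and only the PUT syntax, the OVERWRITE modifier, and the staging_allowed_local_path option come from the tests themselves:

    import databricks_sql_connector_core.sql as sql

    with sql.connect(
        server_hostname="example.cloud.databricks.com",  # placeholder
        http_path="/sql/1.0/warehouses/abc123",          # placeholder
        access_token="dapi-example",                     # placeholder
        staging_allowed_local_path="/tmp/uploads",       # local dir PUT may read from
    ) as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                "PUT '/tmp/uploads/file1.csv' "
                "INTO 'stage://tmp/some.user@example.com/file1.csv' OVERWRITE"
            )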
