Skip to content

Commit f4f27e3

Browse files
Merge remote-tracking branch 'origin/backend-interface' into fetch-interface
2 parents 9267ef9 + 8da84e8 commit f4f27e3

File tree

7 files changed

+32
-15
lines changed

7 files changed

+32
-15
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,11 @@ def _handle_execute_response_async(self, resp, cursor):
11661166
cursor.active_command_id = command_id
11671167
self._check_direct_results_for_error(resp.directResults)
11681168

1169+
def _handle_execute_response_async(self, resp, cursor):
1170+
command_id = CommandId.from_thrift_handle(resp.operationHandle)
1171+
cursor.active_command_id = command_id
1172+
self._check_direct_results_for_error(resp.directResults)
1173+
11691174
def fetch_results(
11701175
self,
11711176
command_id: CommandId,

src/databricks/sql/backend/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def guid_to_hex_id(guid: bytes) -> str:
4747
try:
4848
this_uuid = uuid.UUID(bytes=guid)
4949
except Exception as e:
50-
logger.debug(f"Unable to convert bytes to UUID: {guid} -- {str(e)}")
50+
logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}")
5151
return str(guid)
5252
return str(this_uuid)
5353

src/databricks/sql/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ def read(self) -> Optional[OAuthToken]:
232232
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
233233
self._cursors = [] # type: List[Cursor]
234234

235-
# Create the session
236235
self.session = Session(
237236
server_hostname,
238237
http_path,
@@ -245,11 +244,6 @@ def read(self) -> Optional[OAuthToken]:
245244
)
246245
self.session.open()
247246

248-
logger.info(
249-
"Successfully opened connection with session "
250-
+ str(self.get_session_id_hex())
251-
)
252-
253247
self.use_inline_params = self._set_use_inline_params_with_warning(
254248
kwargs.get("use_inline_params", False)
255249
)
@@ -788,6 +782,14 @@ def execute(
788782
async_op=False,
789783
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
790784
)
785+
self.active_result_set = ResultSet(
786+
self.connection,
787+
execute_response,
788+
self.backend,
789+
self.buffer_size_bytes,
790+
self.arraysize,
791+
self.connection.use_cloud_fetch,
792+
)
791793

792794
if self.active_result_set.is_staging_operation:
793795
self._handle_staging_operation(
@@ -853,6 +855,8 @@ def get_query_state(self) -> CommandState:
853855
:return:
854856
"""
855857
self._check_not_closed()
858+
if self.active_command_id is None:
859+
raise Error("No active command to get state for")
856860
return self.backend.get_query_state(self.active_command_id)
857861

858862
def is_query_pending(self):

src/databricks/sql/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __init__(
8484
**kwargs,
8585
)
8686

87+
self.protocol_version = None
88+
8789
def open(self):
8890
self._session_id = self.backend.open_session(
8991
session_configuration=self.session_configuration,

tests/e2e/test_complex_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def table_fixture(self, connection_details):
1414
# Create the table
1515
cursor.execute(
1616
"""
17-
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
17+
CREATE TABLE IF NOT EXISTS pysql_e2e_test_complex_types_table (
1818
array_col ARRAY<STRING>,
1919
map_col MAP<STRING, INTEGER>,
2020
struct_col STRUCT<field1: STRING, field2: INTEGER>
@@ -24,7 +24,7 @@ def table_fixture(self, connection_details):
2424
# Insert a record
2525
cursor.execute(
2626
"""
27-
INSERT INTO pysql_test_complex_types_table
27+
INSERT INTO pysql_e2e_test_complex_types_table
2828
VALUES (
2929
ARRAY('a', 'b', 'c'),
3030
MAP('a', 1, 'b', 2, 'c', 3),
@@ -34,7 +34,7 @@ def table_fixture(self, connection_details):
3434
)
3535
yield
3636
# Clean up the table after the test
37-
cursor.execute("DROP TABLE IF EXISTS pysql_test_complex_types_table")
37+
cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table")
3838

3939
@pytest.mark.parametrize(
4040
"field,expected_type",
@@ -45,7 +45,7 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
4545

4646
with self.cursor() as cursor:
4747
result = cursor.execute(
48-
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
48+
"SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1"
4949
).fetchone()
5050

5151
assert isinstance(result[field], expected_type)
@@ -57,7 +57,7 @@ def test_read_complex_types_as_string(self, field, table_fixture):
5757
extra_params={"_use_arrow_native_complex_types": False}
5858
) as cursor:
5959
result = cursor.execute(
60-
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
60+
"SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1"
6161
).fetchone()
6262

6363
assert isinstance(result[field], str)

tests/unit/test_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ def test_closing_connection_closes_commands(self, mock_result_set_class):
102102
# Close the connection
103103
connection.close()
104104

105-
# After closing the connection, the close method should have been called on the result set
106-
mock_result_set.close.assert_called_once_with()
105+
self.assertTrue(
106+
mock_result_set_class.return_value.has_been_closed_server_side
107+
)
108+
mock_result_set_class.return_value.close.assert_called_once_with()
107109

108110
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
109111
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):

tests/unit/test_thrift_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass):
178178

179179
for protocol_version in good_protocol_versions:
180180
t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp(
181-
status=self.okay_status, serverProtocolVersion=protocol_version
181+
status=self.okay_status,
182+
serverProtocolVersion=protocol_version,
183+
sessionHandle=self.session_handle,
182184
)
183185

184186
thrift_backend = self._make_fake_thrift_backend()
@@ -2079,6 +2081,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch
20792081
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
20802082
canUseMultipleCatalogs=can_use_multiple_cats,
20812083
initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem),
2084+
sessionHandle=self.session_handle,
20822085
)
20832086

20842087
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
@@ -2178,6 +2181,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
21782181
serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3,
21792182
canUseMultipleCatalogs=True,
21802183
initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"),
2184+
sessionHandle=self.session_handle,
21812185
)
21822186

21832187
backend = ThriftDatabricksClient(

0 commit comments

Comments
 (0)