Skip to content

Commit 7e730db

Browse files
move common state to ResultSet parent
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent ac34732 commit 7e730db

File tree

3 files changed

+106
-37
lines changed

3 files changed

+106
-37
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
if TYPE_CHECKING:
1111
from databricks.sql.client import Cursor
12+
from databricks.sql.result_set import ResultSet, ThriftResultSet
1213

1314
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1415
from databricks.sql.backend.types import (
@@ -52,7 +53,6 @@
5253
)
5354
from databricks.sql.types import SSLOptions
5455
from databricks.sql.backend.databricks_client import DatabricksClient
55-
from databricks.sql.result_set import ResultSet, ThriftResultSet
5656

5757
logger = logging.getLogger(__name__)
5858

@@ -811,6 +811,8 @@ def _results_message_to_execute_response(self, resp, operation_state):
811811
def get_execution_result(
812812
self, command_id: CommandId, cursor: "Cursor"
813813
) -> "ResultSet":
814+
from databricks.sql.result_set import ThriftResultSet
815+
814816
thrift_handle = command_id.to_thrift_handle()
815817
if not thrift_handle:
816818
raise ValueError("Not a valid Thrift command ID")
@@ -906,7 +908,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
906908
poll_resp = self._poll_for_status(thrift_handle)
907909
operation_state = poll_resp.operationState
908910
self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
909-
return CommandState.from_thrift_state(operation_state)
911+
state = CommandState.from_thrift_state(operation_state)
912+
if state is None:
913+
raise ValueError(f"Unknown command state: {operation_state}")
914+
return state
910915

911916
@staticmethod
912917
def _check_direct_results_for_error(t_spark_direct_results):
@@ -941,6 +946,8 @@ def execute_command(
941946
async_op=False,
942947
enforce_embedded_schema_correctness=False,
943948
) -> Union["ResultSet", None]:
949+
from databricks.sql.result_set import ThriftResultSet
950+
944951
thrift_handle = session_id.to_thrift_handle()
945952
if not thrift_handle:
946953
raise ValueError("Not a valid Thrift session ID")
@@ -1005,6 +1012,8 @@ def get_catalogs(
10051012
max_bytes: int,
10061013
cursor: "Cursor",
10071014
) -> "ResultSet":
1015+
from databricks.sql.result_set import ThriftResultSet
1016+
10081017
thrift_handle = session_id.to_thrift_handle()
10091018
if not thrift_handle:
10101019
raise ValueError("Not a valid Thrift session ID")
@@ -1037,6 +1046,8 @@ def get_schemas(
10371046
catalog_name=None,
10381047
schema_name=None,
10391048
) -> "ResultSet":
1049+
from databricks.sql.result_set import ThriftResultSet
1050+
10401051
thrift_handle = session_id.to_thrift_handle()
10411052
if not thrift_handle:
10421053
raise ValueError("Not a valid Thrift session ID")
@@ -1073,6 +1084,8 @@ def get_tables(
10731084
table_name=None,
10741085
table_types=None,
10751086
) -> "ResultSet":
1087+
from databricks.sql.result_set import ThriftResultSet
1088+
10761089
thrift_handle = session_id.to_thrift_handle()
10771090
if not thrift_handle:
10781091
raise ValueError("Not a valid Thrift session ID")
@@ -1111,6 +1124,8 @@ def get_columns(
11111124
table_name=None,
11121125
column_name=None,
11131126
) -> "ResultSet":
1127+
from databricks.sql.result_set import ThriftResultSet
1128+
11141129
thrift_handle = session_id.to_thrift_handle()
11151130
if not thrift_handle:
11161131
raise ValueError("Not a valid Thrift session ID")

src/databricks/sql/backend/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ class CommandState(Enum):
3333
CANCELLED = "CANCELLED"
3434

3535
@classmethod
36-
def from_thrift_state(cls, state: ttypes.TOperationState) -> "CommandState":
36+
def from_thrift_state(
37+
cls, state: ttypes.TOperationState
38+
) -> Optional["CommandState"]:
3739
"""
3840
Convert a Thrift TOperationState to a normalized CommandState.
3941
@@ -75,7 +77,7 @@ def from_thrift_state(cls, state: ttypes.TOperationState) -> "CommandState":
7577
elif state == ttypes.TOperationState.CANCELED_STATE:
7678
return cls.CANCELLED
7779
else:
78-
raise ValueError(f"Unknown command state: {state}")
80+
return None
7981

8082

8183
class BackendType(Enum):

src/databricks/sql/result_set.py

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Any, Union
2+
from typing import List, Optional, Any, Union, TYPE_CHECKING
33

44
import logging
55
import time
66
import pandas
77

8+
from databricks.sql.backend.types import CommandId, CommandState
9+
810
try:
911
import pyarrow
1012
except ImportError:
1113
pyarrow = None
1214

15+
if TYPE_CHECKING:
16+
from databricks.sql.backend.databricks_client import DatabricksClient
17+
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
18+
from databricks.sql.client import Connection
19+
1320
from databricks.sql.thrift_api.TCLIService import ttypes
1421
from databricks.sql.types import Row
1522
from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
@@ -25,10 +32,30 @@ class ResultSet(ABC):
2532
This class defines the interface that all concrete result set implementations must follow.
2633
"""
2734

28-
def __init__(self, connection, backend, arraysize: int, buffer_size_bytes: int):
29-
"""Initialize the base ResultSet with common properties."""
35+
def __init__(
36+
self,
37+
connection: "Connection",
38+
backend: "DatabricksClient",
39+
command_id: CommandId,
40+
op_state: Optional[CommandState],
41+
has_been_closed_server_side: bool,
42+
arraysize: int,
43+
buffer_size_bytes: int,
44+
):
45+
"""
46+
A ResultSet manages the results of a single command.
47+
48+
:param connection: The parent connection that was used to execute this command
49+
:param backend: The specialised backend client to be invoked in the fetch phase
50+
:param execute_response: A `ExecuteResponse` class returned by a command execution
51+
:param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount
52+
:param arraysize: The max number of rows to fetch at a time (PEP-249)
53+
"""
54+
self.command_id = command_id
55+
self.op_state = op_state
56+
self.has_been_closed_server_side = has_been_closed_server_side
3057
self.connection = connection
31-
self.backend = backend # Store the backend client directly
58+
self.backend = backend
3259
self.arraysize = arraysize
3360
self.buffer_size_bytes = buffer_size_bytes
3461
self._next_row_index = 0
@@ -83,20 +110,36 @@ def fetchall_arrow(self) -> Any:
83110
"""Fetch all remaining rows as an Arrow table."""
84111
pass
85112

86-
@abstractmethod
87113
def close(self) -> None:
88-
"""Close the result set and release any resources."""
89-
pass
114+
"""
115+
Close the result set.
116+
117+
If the connection has not been closed, and the result set has not already
118+
been closed on the server for some other reason, issue a request to the server to close it.
119+
"""
120+
try:
121+
if (
122+
self.op_state != CommandState.CLOSED
123+
and not self.has_been_closed_server_side
124+
and self.connection.open
125+
):
126+
self.backend.close_command(self.command_id)
127+
except RequestError as e:
128+
if isinstance(e.args[1], CursorAlreadyClosedError):
129+
logger.info("Operation was canceled by a prior request")
130+
finally:
131+
self.has_been_closed_server_side = True
132+
self.op_state = CommandState.CLOSED
90133

91134

92135
class ThriftResultSet(ResultSet):
93136
"""ResultSet implementation for the Thrift backend."""
94137

95138
def __init__(
96139
self,
97-
connection,
140+
connection: "Connection",
98141
execute_response: ExecuteResponse,
99-
thrift_client, # Pass the specific ThriftDatabricksClient instance
142+
thrift_client: "ThriftDatabricksClient",
100143
buffer_size_bytes: int = 104857600,
101144
arraysize: int = 10000,
102145
use_cloud_fetch: bool = True,
@@ -112,11 +155,20 @@ def __init__(
112155
arraysize: Default number of rows to fetch
113156
use_cloud_fetch: Whether to use cloud fetch for retrieving results
114157
"""
115-
super().__init__(connection, thrift_client, arraysize, buffer_size_bytes)
158+
command_id = execute_response.command_id
159+
op_state = CommandState.from_thrift_state(execute_response.status)
160+
has_been_closed_server_side = execute_response.has_been_closed_server_side
161+
super().__init__(
162+
connection,
163+
thrift_client,
164+
command_id,
165+
op_state,
166+
has_been_closed_server_side,
167+
arraysize,
168+
buffer_size_bytes,
169+
)
116170

117171
# Initialize ThriftResultSet-specific attributes
118-
self.command_id = execute_response.command_id
119-
self.op_state = execute_response.status
120172
self.has_been_closed_server_side = execute_response.has_been_closed_server_side
121173
self.has_more_rows = execute_response.has_more_rows
122174
self.lz4_compressed = execute_response.lz4_compressed
@@ -127,11 +179,15 @@ def __init__(
127179

128180
# Initialize results queue
129181
if execute_response.arrow_queue:
182+
# In this case the server has taken the fast path and returned an initial batch of
183+
# results
130184
self.results = execute_response.arrow_queue
131185
else:
186+
# In this case, there are results waiting on the server so we fetch now for simplicity
132187
self._fill_results_buffer()
133188

134189
def _fill_results_buffer(self):
190+
# At initialization or if the server does not have cloud fetch result links available
135191
results, has_more_rows = self.backend.fetch_results(
136192
command_id=self.command_id,
137193
max_rows=self.arraysize,
@@ -336,28 +392,24 @@ def fetchmany(self, size: int) -> List[Row]:
336392
else:
337393
return self._convert_arrow_table(self.fetchmany_arrow(size))
338394

339-
def close(self) -> None:
340-
"""
341-
Close the cursor.
342-
343-
If the connection has not been closed, and the cursor has not already
344-
been closed on the server for some other reason, issue a request to the server to close it.
345-
"""
346-
try:
347-
if (
348-
self.op_state != ttypes.TOperationState.CLOSED_STATE
349-
and not self.has_been_closed_server_side
350-
and self.connection.open
351-
):
352-
self.backend.close_command(self.command_id)
353-
except RequestError as e:
354-
if isinstance(e.args[1], CursorAlreadyClosedError):
355-
logger.info("Operation was canceled by a prior request")
356-
finally:
357-
self.has_been_closed_server_side = True
358-
self.op_state = ttypes.TOperationState.CLOSED_STATE
359-
360395
@property
361396
def is_staging_operation(self) -> bool:
362397
"""Whether this result set represents a staging operation."""
363398
return self._is_staging_operation
399+
400+
@staticmethod
401+
def _get_schema_description(table_schema_message):
402+
"""
403+
Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249
404+
"""
405+
406+
def map_col_type(type_):
407+
if type_.startswith("decimal"):
408+
return "decimal"
409+
else:
410+
return type_
411+
412+
return [
413+
(column.name, map_col_type(column.datatype), None, None, None, None, None)
414+
for column in table_schema_message.columns
415+
]

0 commit comments

Comments
 (0)