Skip to content

Commit 00d9aeb

Browse files
correct typing
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 7a47dd0 commit 00d9aeb

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def execute_command(
4242
parameters: List[ttypes.TSparkParameter],
4343
async_op: bool,
4444
enforce_embedded_schema_correctness: bool,
45-
) -> "ResultSet":
45+
) -> Union["ResultSet", None]:
4646
pass
4747

4848
@abstractmethod
@@ -62,7 +62,7 @@ def get_execution_result(
6262
self,
6363
command_id: CommandId,
6464
cursor: Any,
65-
) -> ExecuteResponse:
65+
) -> "ResultSet":
6666
pass
6767

6868
# == Metadata Operations ==

src/databricks/sql/backend/thrift_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def execute_command(
944944
parameters=[],
945945
async_op=False,
946946
enforce_embedded_schema_correctness=False,
947-
) -> "ResultSet":
947+
) -> Union["ResultSet", None]:
948948
thrift_handle = session_id.to_thrift_handle()
949949
if not thrift_handle:
950950
raise ValueError("Not a valid Thrift session ID")

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def execute(
783783
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
784784
)
785785

786-
if self.active_result_set.is_staging_operation:
786+
if self.active_result_set and self.active_result_set.is_staging_operation:
787787
self._handle_staging_operation(
788788
staging_allowed_local_path=self.backend.staging_allowed_local_path
789789
)
@@ -879,7 +879,7 @@ def get_async_execution_result(self):
879879
self.active_op_handle, self
880880
)
881881

882-
if self.active_result_set.is_staging_operation:
882+
if self.active_result_set and self.active_result_set.is_staging_operation:
883883
self._handle_staging_operation(
884884
staging_allowed_local_path=self.backend.staging_allowed_local_path
885885
)

src/databricks/sql/result_set.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ def __iter__(self):
4646
def rownumber(self):
4747
return self._next_row_index
4848

49+
@property
50+
@abstractmethod
51+
def is_staging_operation(self) -> bool:
52+
"""Whether this result set represents a staging operation."""
53+
pass
54+
4955
# Define abstract methods that concrete implementations must implement
5056
@abstractmethod
5157
def _fill_results_buffer(self):
@@ -117,7 +123,7 @@ def __init__(
117123
self.description = execute_response.description
118124
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
119125
self._use_cloud_fetch = use_cloud_fetch
120-
self.is_staging_operation = execute_response.is_staging_operation
126+
self._is_staging_operation = execute_response.is_staging_operation
121127

122128
# Initialize results queue
123129
if execute_response.arrow_queue:
@@ -350,3 +356,8 @@ def close(self) -> None:
350356
finally:
351357
self.has_been_closed_server_side = True
352358
self.op_state = ttypes.TOperationState.CLOSED_STATE
359+
360+
@property
361+
def is_staging_operation(self) -> bool:
362+
"""Whether this result set represents a staging operation."""
363+
return self._is_staging_operation

0 commit comments

Comments
 (0)