diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py new file mode 100644 index 00000000..edff1015 --- /dev/null +++ b/src/databricks/sql/backend/databricks_client.py @@ -0,0 +1,344 @@ +""" +Abstract client interface for interacting with Databricks SQL services. + +Implementations of this class are responsible for: +- Managing connections to Databricks SQL services +- Executing SQL queries and commands +- Retrieving query results +- Fetching metadata about catalogs, schemas, tables, and columns +""" + +from abc import ABC, abstractmethod +from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.backend.types import SessionId, CommandId +from databricks.sql.utils import ExecuteResponse +from databricks.sql.types import SSLOptions + + +class DatabricksClient(ABC): + # == Connection and Session Management == + @abstractmethod + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + """ + Opens a new session with the Databricks SQL service. + + This method establishes a new session with the server and returns a session + identifier that can be used for subsequent operations. + + Args: + session_configuration: Optional dictionary of configuration parameters for the session + catalog: Optional catalog name to use as the initial catalog for the session + schema: Optional schema name to use as the initial schema for the session + + Returns: + SessionId: A session identifier object that can be used for subsequent operations + + Raises: + Error: If the session configuration is invalid + OperationalError: If there's an error establishing the session + InvalidServerResponseError: If the server response is invalid or unexpected + """ + pass + + @abstractmethod + def close_session(self, session_id: SessionId) -> None: + """ + Closes an existing session with the Databricks SQL service. + + This method terminates the session identified by the given session ID and + releases any resources associated with it. + + Args: + session_id: The session identifier returned by open_session() + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error closing the session + """ + pass + + # == Query Execution, Command Management == + @abstractmethod + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + ) -> Optional[ExecuteResponse]: + """ + Executes a SQL command or query within the specified session. + + This method sends a SQL command to the server for execution and handles + the response. It can operate in both synchronous and asynchronous modes. + + Args: + operation: The SQL command or query to execute + session_id: The session identifier in which to execute the command + max_rows: Maximum number of rows to fetch in a single fetch batch + max_bytes: Maximum number of bytes to fetch in a single fetch batch + lz4_compression: Whether to use LZ4 compression for result data + cursor: The cursor object that will handle the results + use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets + parameters: List of parameters to bind to the query + async_op: Whether to execute the command asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + If async_op is False, returns an ExecuteResponse object containing the + query results and metadata. If async_op is True, returns None and the + results must be fetched later using get_execution_result(). + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error executing the command + ServerOperationError: If the server encounters an error during execution + """ + pass + + @abstractmethod + def cancel_command(self, command_id: CommandId) -> None: + """ + Cancels a running command or query. + + This method attempts to cancel a command that is currently being executed. + It can be called from a different thread than the one executing the command. + + Args: + command_id: The command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error canceling the command + """ + pass + + @abstractmethod + def close_command(self, command_id: CommandId) -> ttypes.TStatus: + """ + Closes a command and releases associated resources. + + This method informs the server that the client is done with the command + and any resources associated with it can be released. + + Args: + command_id: The command identifier to close + + Returns: + ttypes.TStatus: The status of the close operation + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error closing the command + """ + pass + + @abstractmethod + def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState: + """ + Gets the current state of a query or command. + + This method retrieves the current execution state of a command from the server. + + Args: + command_id: The command identifier to check + + Returns: + ttypes.TOperationState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the state + ServerOperationError: If the command is in an error state + DatabaseError: If the command has been closed unexpectedly + """ + pass + + @abstractmethod + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ) -> ExecuteResponse: + """ + Retrieves the results of a previously executed command. + + This method fetches the results of a command that was executed asynchronously + or retrieves additional results from a command that has more rows available. + + Args: + command_id: The command identifier for which to retrieve results + cursor: The cursor object that will handle the results + + Returns: + ExecuteResponse: An object containing the query results and metadata + + Raises: + ValueError: If the command ID is invalid + OperationalError: If there's an error retrieving the results + """ + pass + + # == Metadata Operations == + @abstractmethod + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> ExecuteResponse: + """ + Retrieves a list of available catalogs. + + This method fetches metadata about all catalogs available in the current + session's context. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + + Returns: + ExecuteResponse: An object containing the catalog metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the catalogs + """ + pass + + @abstractmethod + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of schemas, optionally filtered by catalog and schema name patterns. + + This method fetches metadata about schemas available in the specified catalog + or all catalogs if no catalog is specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + + Returns: + ExecuteResponse: An object containing the schema metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the schemas + """ + pass + + @abstractmethod + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of tables, optionally filtered by catalog, schema, table name, and table types. + + This method fetches metadata about tables available in the specified catalog + and schema, or all catalogs and schemas if not specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW']) + + Returns: + ExecuteResponse: An object containing the table metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the tables + """ + pass + + @abstractmethod + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> ExecuteResponse: + """ + Retrieves a list of columns, optionally filtered by catalog, schema, table, and column name patterns. + + This method fetches metadata about columns available in the specified table, + or all tables if not specified. + + Args: + session_id: The session identifier + max_rows: Maximum number of rows to fetch in a single batch + max_bytes: Maximum number of bytes to fetch in a single batch + cursor: The cursor object that will handle the results + catalog_name: Optional catalog name pattern to filter by + schema_name: Optional schema name pattern to filter by + table_name: Optional table name pattern to filter by + column_name: Optional column name pattern to filter by + + Returns: + ExecuteResponse: An object containing the column metadata + + Raises: + ValueError: If the session ID is invalid + OperationalError: If there's an error retrieving the columns + """ + pass + + @property + @abstractmethod + def max_download_threads(self) -> int: + """ + Gets the maximum number of download threads for cloud fetch operations. + + Returns: + int: The maximum number of download threads + """ + pass diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py similarity index 87% rename from src/databricks/sql/thrift_backend.py rename to src/databricks/sql/backend/thrift_backend.py index e3dc38ad..c09397c2 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,9 +5,18 @@ import time import uuid import threading -from typing import List, Union +from typing import List, Optional, Union, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from databricks.sql.client import Cursor from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState +from databricks.sql.backend.types import ( + SessionId, + CommandId, + BackendType, +) +from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -41,6 +50,7 @@ convert_column_based_set_to_arrow_table, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.databricks_client import DatabricksClient logger = logging.getLogger(__name__) @@ -73,7 +83,7 @@ } -class ThriftBackend: +class ThriftDatabricksClient(DatabricksClient): CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE @@ -91,7 +101,6 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, - staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): # Internal arguments in **kwargs: @@ -150,7 +159,6 @@ def __init__( else: raise ValueError("No valid connection settings.") - self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -161,7 +169,7 @@ def __init__( ) # Cloud fetch - self.max_download_threads = kwargs.get("max_download_threads", 10) + self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options @@ -224,6 +232,10 @@ def __init__( self._request_lock = threading.RLock() + @property + def max_download_threads(self) -> int: + return self._max_download_threads + # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918) def _initialize_retry_args(self, kwargs): # Configure retries & timing: use user-settings or defaults, and bound @@ -446,8 +458,10 @@ def attempt_request(attempt): logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) - error_message = ThriftBackend._extract_error_message_from_headers( - getattr(self._transport, "headers", {}) + error_message = ( + ThriftDatabricksClient._extract_error_message_from_headers( + getattr(self._transport, "headers", {}) + ) ) finally: # Calling `close()` here releases the active HTTP connection back to the pool @@ -483,7 +497,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftBackend._check_response_for_error(response) + ThriftDatabricksClient._check_response_for_error(response) return response error_info = response_or_error_info @@ -534,7 +548,7 @@ def _check_session_configuration(self, session_configuration): ) ) - def open_session(self, session_configuration, catalog, schema): + def open_session(self, session_configuration, catalog, schema) -> SessionId: try: self._transport.open() session_configuration = { @@ -562,13 +576,22 @@ def open_session(self, session_configuration, catalog, schema): response = self.make_request(self._client.OpenSession, open_session_req) self._check_initial_namespace(catalog, schema, response) self._check_protocol_version(response) - return response + properties = ( + {"serverProtocolVersion": response.serverProtocolVersion} + if response.serverProtocolVersion + else {} + ) + return SessionId.from_thrift_handle(response.sessionHandle, properties) except: self._transport.close() raise - def close_session(self, session_handle) -> None: - req = ttypes.TCloseSessionReq(sessionHandle=session_handle) + def close_session(self, session_id: SessionId) -> None: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") + + req = ttypes.TCloseSessionReq(sessionHandle=thrift_handle) try: self.make_request(self._client.CloseSession, req) finally: @@ -583,7 +606,7 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.displayMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, ) @@ -592,18 +615,18 @@ def _check_command_not_in_error_or_closed_state( get_operations_resp.errorMessage, { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid), + and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( "Command {} unexpectedly closed server side".format( - op_handle and self.guid_to_hex_id(op_handle.operationId.guid) + op_handle and guid_to_hex_id(op_handle.operationId.guid) ), { "operation-id": op_handle - and self.guid_to_hex_id(op_handle.operationId.guid) + and guid_to_hex_id(op_handle.operationId.guid) }, ) @@ -707,7 +730,8 @@ def _col_to_description(col): @staticmethod def _hive_schema_to_description(t_table_schema): return [ - ThriftBackend._col_to_description(col) for col in t_table_schema.columns + ThriftDatabricksClient._col_to_description(col) + for col in t_table_schema.columns ] def _results_message_to_execute_response(self, resp, operation_state): @@ -767,6 +791,9 @@ def _results_message_to_execute_response(self, resp, operation_state): ) else: arrow_queue_opt = None + + command_id = CommandId.from_thrift_handle(resp.operationHandle) + return ExecuteResponse( arrow_queue=arrow_queue_opt, status=operation_state, @@ -774,21 +801,24 @@ def _results_message_to_execute_response(self, resp, operation_state): has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=resp.operationHandle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) - def get_execution_result(self, op_handle, cursor): - - assert op_handle is not None + def get_execution_result( + self, command_id: CommandId, cursor: "Cursor" + ) -> ExecuteResponse: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=cursor.arraysize, maxBytes=cursor.buffer_size_bytes, @@ -834,7 +864,7 @@ def get_execution_result(self, op_handle, cursor): has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=op_handle, + command_id=command_id, description=description, arrow_schema_bytes=schema_bytes, ) @@ -857,51 +887,57 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_state(self, op_handle) -> "TOperationState": - poll_resp = self._poll_for_status(op_handle) + def get_query_state(self, command_id: CommandId) -> "TOperationState": + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") + + poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState - self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) + self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) return operation_state @staticmethod def _check_direct_results_for_error(t_spark_direct_results): if t_spark_direct_results: if t_spark_direct_results.operationStatus: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus ) if t_spark_direct_results.resultSetMetadata: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata ) if t_spark_direct_results.resultSet: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet ) if t_spark_direct_results.closeOperation: - ThriftBackend._check_response_for_error( + ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation ) def execute_command( self, - operation, - session_handle, - max_rows, - max_bytes, - lz4_compression, - cursor, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", use_cloud_fetch=True, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, - ): - assert session_handle is not None + ) -> Optional[ExecuteResponse]: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") logger.debug( "ThriftBackend.execute_command(operation=%s, session_handle=%s)", operation, - session_handle, + thrift_handle, ) spark_arrow_types = ttypes.TSparkArrowTypes( @@ -913,7 +949,7 @@ def execute_command( intervalTypesAsArrow=False, ) req = ttypes.TExecuteStatementReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, statement=operation, runAsync=True, # For async operation we don't want the direct results @@ -938,14 +974,23 @@ def execute_command( if async_op: self._handle_execute_response_async(resp, cursor) + return None else: return self._handle_execute_response(resp, cursor) - def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): - assert session_handle is not None + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetCatalogsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -955,17 +1000,19 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor): def get_schemas( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetSchemasReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -977,19 +1024,21 @@ def get_schemas( def get_tables( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, table_types=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetTablesReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1003,19 +1052,21 @@ def get_tables( def get_columns( self, - session_handle, - max_rows, - max_bytes, - cursor, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", catalog_name=None, schema_name=None, table_name=None, column_name=None, - ): - assert session_handle is not None + ) -> ExecuteResponse: + thrift_handle = session_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift session ID") req = ttypes.TGetColumnsReq( - sessionHandle=session_handle, + sessionHandle=thrift_handle, getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), @@ -1028,7 +1079,9 @@ def get_columns( return self._handle_execute_response(resp, cursor) def _handle_execute_response(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) final_operation_state = self._wait_until_command_done( @@ -1039,28 +1092,31 @@ def _handle_execute_response(self, resp, cursor): return self._results_message_to_execute_response(resp, final_operation_state) def _handle_execute_response_async(self, resp, cursor): - cursor.active_op_handle = resp.operationHandle + command_id = CommandId.from_thrift_handle(resp.operationHandle) + cursor.active_command_id = command_id self._check_direct_results_for_error(resp.directResults) def fetch_results( self, - op_handle, - max_rows, - max_bytes, - expected_row_start_offset, - lz4_compressed, + command_id: CommandId, + max_rows: int, + max_bytes: int, + expected_row_start_offset: int, + lz4_compressed: bool, arrow_schema_bytes, description, use_cloud_fetch=True, ): - assert op_handle is not None + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") req = ttypes.TFetchResultsReq( operationHandle=ttypes.TOperationHandle( - op_handle.operationId, - op_handle.operationType, + thrift_handle.operationId, + thrift_handle.operationType, False, - op_handle.modifiedRowCount, + thrift_handle.modifiedRowCount, ), maxRows=max_rows, maxBytes=max_bytes, @@ -1089,46 +1145,21 @@ def fetch_results( return queue, resp.hasMoreRows - def close_command(self, op_handle): - logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) - req = ttypes.TCloseOperationReq(operationHandle=op_handle) - resp = self.make_request(self._client.CloseOperation, req) - return resp.status + def cancel_command(self, command_id: CommandId) -> None: + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - def cancel_command(self, active_op_handle): - logger.debug( - "Cancelling command {}".format( - self.guid_to_hex_id(active_op_handle.operationId.guid) - ) - ) - req = ttypes.TCancelOperationReq(active_op_handle) + logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) - @staticmethod - def handle_to_id(session_handle): - return session_handle.sessionId.guid - - @staticmethod - def handle_to_hex_id(session_handle: TCLIService.TSessionHandle): - this_uuid = uuid.UUID(bytes=session_handle.sessionId.guid) - return str(this_uuid) + def close_command(self, command_id: CommandId): + thrift_handle = command_id.to_thrift_handle() + if not thrift_handle: + raise ValueError("Not a valid Thrift command ID") - @staticmethod - def guid_to_hex_id(guid: bytes) -> str: - """Return a hexadecimal string instead of bytes - - Example: - IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' - OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' - - If conversion to hexadecimal fails, the original bytes are returned - """ - - this_uuid: Union[bytes, uuid.UUID] - - try: - this_uuid = uuid.UUID(bytes=guid) - except Exception as e: - logger.debug(f"Unable to convert bytes to UUID: {bytes} -- {str(e)}") - this_uuid = guid - return str(this_uuid) + logger.debug("ThriftBackend.close_command(command_id=%s)", command_id) + req = ttypes.TCloseOperationReq(operationHandle=thrift_handle) + resp = self.make_request(self._client.CloseOperation, req) + return resp.status diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py new file mode 100644 index 00000000..740be019 --- /dev/null +++ b/src/databricks/sql/backend/types.py @@ -0,0 +1,306 @@ +from enum import Enum +from typing import Dict, Optional, Any, Union +import logging + +from databricks.sql.backend.utils import guid_to_hex_id + +logger = logging.getLogger(__name__) + + +class BackendType(Enum): + """ + Enum representing the type of backend + """ + + THRIFT = "thrift" + SEA = "sea" + + +class SessionId: + """ + A normalized session identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TSessionHandle and + SEA's session ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + properties: Optional[Dict[str, Any]] = None, + ): + """ + Initialize a SessionId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the session + secret: The secret part of the identifier (only used for Thrift) + properties: Additional information about the session + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.properties = properties or {} + + def __str__(self) -> str: + """ + Return a string representation of the SessionId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the session ID + """ + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.get_hex_guid()}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle( + cls, session_handle, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a Thrift session handle. + + Args: + session_handle: A TSessionHandle object from the Thrift API + + Returns: + A SessionId instance + """ + if session_handle is None: + return None + + guid_bytes = session_handle.sessionId.guid + secret_bytes = session_handle.sessionId.secret + + if session_handle.serverProtocolVersion is not None: + if properties is None: + properties = {} + properties["serverProtocolVersion"] = session_handle.serverProtocolVersion + + return cls(BackendType.THRIFT, guid_bytes, secret_bytes, properties) + + @classmethod + def from_sea_session_id( + cls, session_id: str, properties: Optional[Dict[str, Any]] = None + ): + """ + Create a SessionId from a SEA session ID. + + Args: + session_id: The SEA session ID string + + Returns: + A SessionId instance + """ + return cls(BackendType.SEA, session_id, properties=properties) + + def to_thrift_handle(self): + """ + Convert this SessionId to a Thrift TSessionHandle. + + Returns: + A TSessionHandle object or None if this is not a Thrift session ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + server_protocol_version = self.properties.get("serverProtocolVersion") + return ttypes.TSessionHandle( + sessionId=handle_identifier, serverProtocolVersion=server_protocol_version + ) + + def to_sea_session_id(self): + """ + Get the SEA session ID string. + + Returns: + The session ID string or None if this is not a SEA session ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def get_guid(self) -> Any: + """ + Get the ID of the session. + """ + return self.guid + + def get_hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the session ID. + + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) + + def get_protocol_version(self): + """ + Get the server protocol version for this session. + + Returns: + The server protocol version or None if it does not exist + It is not expected to exist for SEA sessions. + """ + return self.properties.get("serverProtocolVersion") + + +class CommandId: + """ + A normalized command identifier that works with both Thrift and SEA backends. + + This class abstracts away the differences between Thrift's TOperationHandle and + SEA's statement ID string, providing a consistent interface for the connector. + """ + + def __init__( + self, + backend_type: BackendType, + guid: Any, + secret: Optional[Any] = None, + operation_type: Optional[int] = None, + has_result_set: bool = False, + modified_row_count: Optional[int] = None, + ): + """ + Initialize a CommandId. + + Args: + backend_type: The type of backend (THRIFT or SEA) + guid: The primary identifier for the command + secret: The secret part of the identifier (only used for Thrift) + operation_type: The operation type (only used for Thrift) + has_result_set: Whether the command has a result set + modified_row_count: The number of rows modified by the command + """ + self.backend_type = backend_type + self.guid = guid + self.secret = secret + self.operation_type = operation_type + self.has_result_set = has_result_set + self.modified_row_count = modified_row_count + + def __str__(self) -> str: + """ + Return a string representation of the CommandId. + + For SEA backend, returns the guid. + For Thrift backend, returns a format like "guid|secret". + + Returns: + A string representation of the command ID + """ + if self.backend_type == BackendType.SEA: + return str(self.guid) + elif self.backend_type == BackendType.THRIFT: + secret_hex = ( + guid_to_hex_id(self.secret) + if isinstance(self.secret, bytes) + else str(self.secret) + ) + return f"{self.to_hex_guid()}|{secret_hex}" + return str(self.guid) + + @classmethod + def from_thrift_handle(cls, operation_handle): + """ + Create a CommandId from a Thrift operation handle. + + Args: + operation_handle: A TOperationHandle object from the Thrift API + + Returns: + A CommandId instance + """ + if operation_handle is None: + return None + + guid_bytes = operation_handle.operationId.guid + secret_bytes = operation_handle.operationId.secret + + return cls( + BackendType.THRIFT, + guid_bytes, + secret_bytes, + operation_handle.operationType, + operation_handle.hasResultSet, + operation_handle.modifiedRowCount, + ) + + @classmethod + def from_sea_statement_id(cls, statement_id: str): + """ + Create a CommandId from a SEA statement ID. + + Args: + statement_id: The SEA statement ID string + + Returns: + A CommandId instance + """ + return cls(BackendType.SEA, statement_id) + + def to_thrift_handle(self): + """ + Convert this CommandId to a Thrift TOperationHandle. + + Returns: + A TOperationHandle object or None if this is not a Thrift command ID + """ + if self.backend_type != BackendType.THRIFT: + return None + + from databricks.sql.thrift_api.TCLIService import ttypes + + handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret) + return ttypes.TOperationHandle( + operationId=handle_identifier, + operationType=self.operation_type, + hasResultSet=self.has_result_set, + modifiedRowCount=self.modified_row_count, + ) + + def to_sea_statement_id(self): + """ + Get the SEA statement ID string. + + Returns: + The statement ID string or None if this is not a SEA statement ID + """ + if self.backend_type != BackendType.SEA: + return None + + return self.guid + + def to_hex_guid(self) -> str: + """ + Get a hexadecimal string representation of the command ID. + + Returns: + A hexadecimal string representation + """ + if isinstance(self.guid, bytes): + return guid_to_hex_id(self.guid) + else: + return str(self.guid) diff --git a/src/databricks/sql/backend/utils/__init__.py b/src/databricks/sql/backend/utils/__init__.py new file mode 100644 index 00000000..3d601e5e --- /dev/null +++ b/src/databricks/sql/backend/utils/__init__.py @@ -0,0 +1,3 @@ +from .guid_utils import guid_to_hex_id + +__all__ = ["guid_to_hex_id"] diff --git a/src/databricks/sql/backend/utils/guid_utils.py b/src/databricks/sql/backend/utils/guid_utils.py new file mode 100644 index 00000000..28975171 --- /dev/null +++ b/src/databricks/sql/backend/utils/guid_utils.py @@ -0,0 +1,22 @@ +import uuid +import logging + +logger = logging.getLogger(__name__) + + +def guid_to_hex_id(guid: bytes) -> str: + """Return a hexadecimal string instead of bytes + + Example: + IN b'\x01\xee\x1d)\xa4\x19\x1d\xb6\xa9\xc0\x8d\xf1\xfe\xbaB\xdd' + OUT '01ee1d29-a419-1db6-a9c0-8df1feba42dd' + + If conversion to hexadecimal fails, a string representation of the original + bytes is returned + """ + try: + this_uuid = uuid.UUID(bytes=guid) + except Exception as e: + logger.debug(f"Unable to convert bytes to UUID: {guid!r} -- {str(e)}") + return str(guid) + return str(this_uuid) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d6a9e6b0..1c384c73 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -21,7 +21,8 @@ CursorAlreadyClosedError, ) from databricks.sql.thrift_api.TCLIService import ttypes -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( ExecuteResponse, ParamEscaper, @@ -46,6 +47,7 @@ from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence from databricks.sql.session import Session +from databricks.sql.backend.types import CommandId, BackendType from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -230,7 +232,6 @@ def read(self) -> Optional[OAuthToken]: self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) self._cursors = [] # type: List[Cursor] - # Create the session self.session = Session( server_hostname, http_path, @@ -243,14 +244,10 @@ def read(self) -> Optional[OAuthToken]: ) self.session.open() - logger.info( - "Successfully opened connection with session " - + str(self.get_session_id_hex()) - ) - self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) ) + self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): """Valid values are True, False, and "silent" @@ -305,11 +302,11 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - """Get the session ID from the Session object""" + """Get the raw session ID (backend-specific)""" return self.session.get_id() def get_session_id_hex(self): - """Get the session ID in hex format from the Session object""" + """Get the session ID in hex format""" return self.session.get_id_hex() @staticmethod @@ -347,7 +344,7 @@ def cursor( cursor = Cursor( self, - self.session.thrift_backend, + self.session.backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -380,7 +377,7 @@ class Cursor: def __init__( self, connection: Connection, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = DEFAULT_ARRAY_SIZE, ) -> None: @@ -399,8 +396,8 @@ def __init__( # Note that Cursor closed => active result set closed, but not vice versa self.open = True self.executing_command_id = None - self.thrift_backend = thrift_backend - self.active_op_handle = None + self.backend = backend + self.active_command_id = None self.escaper = ParamEscaper() self.lastrowid = None @@ -774,9 +771,9 @@ def execute( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.execute_command( + execute_response = self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -786,10 +783,12 @@ def execute( async_op=False, enforce_embedded_schema_correctness=enforce_embedded_schema_correctness, ) + assert execute_response is not None # async_op = False above + self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, self.connection.use_cloud_fetch, @@ -797,7 +796,7 @@ def execute( if execute_response.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -837,9 +836,9 @@ def execute_async( self._check_not_closed() self._close_and_clear_active_result_set() - self.thrift_backend.execute_command( + self.backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session.get_handle(), + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -859,7 +858,9 @@ def get_query_state(self) -> "TOperationState": :return: """ self._check_not_closed() - return self.thrift_backend.get_query_state(self.active_op_handle) + if self.active_command_id is None: + raise Error("No active command to get state for") + return self.backend.get_query_state(self.active_command_id) def is_query_pending(self): """ @@ -889,20 +890,20 @@ def get_async_execution_result(self): operation_state = self.get_query_state() if operation_state == ttypes.TOperationState.FINISHED_STATE: - execute_response = self.thrift_backend.get_execution_result( - self.active_op_handle, self + execute_response = self.backend.get_execution_result( + self.active_command_id, self ) self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, ) if execute_response.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path ) return self @@ -934,8 +935,8 @@ def catalogs(self) -> "Cursor": """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_catalogs( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -943,9 +944,10 @@ def catalogs(self) -> "Cursor": self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -960,8 +962,8 @@ def schemas( """ self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_schemas( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -971,9 +973,10 @@ def schemas( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -993,8 +996,8 @@ def tables( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_tables( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_tables( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1006,9 +1009,10 @@ def tables( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -1028,8 +1032,8 @@ def columns( self._check_not_closed() self._close_and_clear_active_result_set() - execute_response = self.thrift_backend.get_columns( - session_handle=self.connection.session.get_handle(), + execute_response = self.backend.get_columns( + session_id=self.connection.session.get_session_id(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1041,9 +1045,10 @@ def columns( self.active_result_set = ResultSet( self.connection, execute_response, - self.thrift_backend, + self.backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) return self @@ -1117,8 +1122,8 @@ def cancel(self) -> None: The command should be closed to free resources from the server. This method can be called from another thread. """ - if self.active_op_handle is not None: - self.thrift_backend.cancel_command(self.active_op_handle) + if self.active_command_id is not None: + self.backend.cancel_command(self.active_command_id) else: logger.warning( "Attempting to cancel a command, but there is no " @@ -1130,9 +1135,9 @@ def close(self) -> None: self.open = False # Close active operation handle if it exists - if self.active_op_handle: + if self.active_command_id: try: - self.thrift_backend.close_command(self.active_op_handle) + self.backend.close_command(self.active_command_id) except RequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") @@ -1141,7 +1146,7 @@ def close(self) -> None: except Exception as e: logging.warning(f"Error closing operation handle: {e}") finally: - self.active_op_handle = None + self.active_command_id = None if self.active_result_set: self._close_and_clear_active_result_set() @@ -1154,8 +1159,8 @@ def query_id(self) -> Optional[str]: This attribute will be ``None`` if the cursor has not had an operation invoked via the execute method yet, or if cursor was closed. """ - if self.active_op_handle is not None: - return str(UUID(bytes=self.active_op_handle.operationId.guid)) + if self.active_command_id is not None: + return self.active_command_id.to_hex_guid() return None @property @@ -1207,7 +1212,7 @@ def __init__( self, connection: Connection, execute_response: ExecuteResponse, - thrift_backend: ThriftBackend, + backend: DatabricksClient, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -1217,18 +1222,20 @@ def __init__( :param connection: The parent connection that was used to execute this command :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param backend: The DatabricksClient instance to use for fetching results + :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount + :param arraysize: The max number of rows to fetch at a time (PEP-249) + :param use_cloud_fetch: Whether to use cloud fetch for retrieving results """ self.connection = connection - self.command_id = execute_response.command_handle + self.command_id = execute_response.command_id self.op_state = execute_response.status self.has_been_closed_server_side = execute_response.has_been_closed_server_side self.has_more_rows = execute_response.has_more_rows self.buffer_size_bytes = result_buffer_size_bytes self.lz4_compressed = execute_response.lz4_compressed self.arraysize = arraysize - self.thrift_backend = thrift_backend + self.backend = backend self.description = execute_response.description self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 @@ -1251,9 +1258,16 @@ def __iter__(self): break def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available - results, has_more_rows = self.thrift_backend.fetch_results( - op_handle=self.command_id, + if not isinstance(self.backend, ThriftDatabricksClient): + # currently, we are assuming only the Thrift backend exists + raise NotImplementedError( + "Fetching further result batches is currently only implemented for the Thrift backend." + ) + + # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results + thrift_backend_instance = self.backend # type: ThriftDatabricksClient + results, has_more_rows = thrift_backend_instance.fetch_results( + command_id=self.command_id, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, expected_row_start_offset=self._next_row_index, @@ -1468,19 +1482,21 @@ def close(self) -> None: If the connection has not been closed, and the cursor has not already been closed on the server for some other reason, issue a request to the server to close it. """ + # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to + # when we generalise the ResultSet try: if ( - self.op_state != self.thrift_backend.CLOSED_OP_STATE + self.op_state != ttypes.TOperationState.CLOSED_STATE and not self.has_been_closed_server_side and self.connection.open ): - self.thrift_backend.close_command(self.command_id) + self.backend.close_command(self.command_id) except RequestError as e: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = self.thrift_backend.CLOSED_OP_STATE + self.op_state = ttypes.TOperationState.CLOSED_STATE @staticmethod def _get_schema_description(table_schema_message): diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f2f38d57..2ee5e53f 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -7,7 +7,9 @@ from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) @@ -71,7 +73,7 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - self.thrift_backend = ThriftBackend( + self.backend: DatabricksClient = ThriftDatabricksClient( self.host, self.port, http_path, @@ -82,31 +84,21 @@ def __init__( **kwargs, ) - self._handle = None self.protocol_version = None - def open(self) -> None: - self._open_session_resp = self.thrift_backend.open_session( - self.session_configuration, self.catalog, self.schema + def open(self): + self._session_id = self.backend.open_session( + session_configuration=self.session_configuration, + catalog=self.catalog, + schema=self.schema, ) - self._handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.protocol_version = self.get_protocol_version(self._session_id) self.is_open = True logger.info("Successfully opened session " + str(self.get_id_hex())) @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_protocol_version(session_id: SessionId): + return session_id.get_protocol_version() @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -118,20 +110,17 @@ def server_parameterized_queries_enabled(protocolVersion): else: return False - def get_handle(self): - return self._handle + def get_session_id(self) -> SessionId: + """Get the normalized session ID""" + return self._session_id def get_id(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_id(handle) + """Get the raw session ID (backend-specific)""" + return self._session_id.get_guid() - def get_id_hex(self): - handle = self.get_handle() - if handle is None: - return None - return self.thrift_backend.handle_to_hex_id(handle) + def get_id_hex(self) -> str: + """Get the session ID in hex format""" + return self._session_id.get_hex_guid() def close(self) -> None: """Close the underlying session.""" @@ -141,7 +130,7 @@ def close(self) -> None: return try: - self.thrift_backend.close_session(self.get_handle()) + self.backend.close_session(self._session_id) except RequestError as e: if isinstance(e.args[1], SessionAlreadyClosedError): logger.info("Session was closed by a prior request") diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 186f13dd..c541ad3f 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -26,6 +26,7 @@ TSparkRowSetType, ) from databricks.sql.types import SSLOptions +from databricks.sql.backend.types import CommandId from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter @@ -345,7 +346,7 @@ def _create_empty_table(self) -> "pyarrow.Table": ExecuteResponse = namedtuple( "ExecuteResponse", "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_handle arrow_queue arrow_schema_bytes", + "command_id arrow_queue arrow_schema_bytes", ) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index abe0e22d..c446b671 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -822,11 +822,10 @@ def test_close_connection_closes_cursors(self): # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( - operationHandle=ars.command_id, getProgressUpdate=False - ) - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( - status_request + operationHandle=ars.command_id.to_thrift_handle(), + getProgressUpdate=False, ) + op_status_at_server = ars.backend._client.GetOperationStatus(status_request) assert ( op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE @@ -836,7 +835,7 @@ def test_close_connection_closes_cursors(self): # When connection closes, any cursor operations should no longer exist at the server with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + op_status_at_server = ars.backend._client.GetOperationStatus( status_request ) @@ -866,9 +865,9 @@ def test_cursor_close_properly_closes_operation(self): cursor = conn.cursor() try: cursor.execute("SELECT 1 AS test") - assert cursor.active_op_handle is not None + assert cursor.active_command_id is not None cursor.close() - assert cursor.active_op_handle is None + assert cursor.active_command_id is None assert not cursor.open finally: if cursor.open: @@ -894,19 +893,19 @@ def test_nested_cursor_context_managers(self): with self.connection() as conn: with conn.cursor() as cursor1: cursor1.execute("SELECT 1 AS test1") - assert cursor1.active_op_handle is not None + assert cursor1.active_command_id is not None with conn.cursor() as cursor2: cursor2.execute("SELECT 2 AS test2") - assert cursor2.active_op_handle is not None + assert cursor2.active_command_id is not None # After inner context manager exit, cursor2 should be not open assert not cursor2.open - assert cursor2.active_op_handle is None + assert cursor2.active_command_id is None # After outer context manager exit, cursor1 should be not open assert not cursor1.open - assert cursor1.active_op_handle is None + assert cursor1.active_command_id is None def test_cursor_error_handling(self): """Test that cursor close handles errors properly to prevent orphaned operations.""" @@ -915,12 +914,12 @@ def test_cursor_error_handling(self): cursor.execute("SELECT 1 AS test") - op_handle = cursor.active_op_handle + op_handle = cursor.active_command_id assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.session.thrift_backend.close_command(op_handle) + conn.session.backend.close_command(op_handle) cursor.close() @@ -940,7 +939,7 @@ def test_result_set_close(self): result_set.close() - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state == result_set.backend.CLOSED_OP_STATE assert result_set.op_state != initial_op_state # Closing the result set again should be a no-op and not raise exceptions diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a9c7a43a..fa6fae1d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,23 +15,24 @@ THandleIdentifier, TOperationType, ) -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row +from databricks.sql.client import CommandId from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite from tests.unit.test_arrow_queue import ArrowQueueSuite -class ThriftBackendMockFactory: +class ThriftDatabricksClientMockFactory: @classmethod def new(cls): - ThriftBackendMock = Mock(spec=ThriftBackend) + ThriftBackendMock = Mock(spec=ThriftDatabricksClient) ThriftBackendMock.return_value = ThriftBackendMock cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None) @@ -42,7 +43,7 @@ def new(cls): description=None, arrow_queue=None, is_staging_operation=False, - command_handle=b"\x22", + command_id=None, has_been_closed_server_side=True, has_more_rows=True, lz4_compressed=True, @@ -81,7 +82,10 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -98,7 +102,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): ) mock_result_set_class.return_value.close.assert_called_once_with() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -108,7 +112,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -125,7 +129,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_backend = Mock() result_set = client.ResultSet( connection=mock_connection, - thrift_backend=mock_backend, + backend=mock_backend, execute_response=Mock(), ) # Setup session mock on the mock_connection @@ -155,7 +159,7 @@ def test_closing_result_set_hard_closes_commands(self): result_set.close() mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle + mock_results_response.command_id ) @patch("%s.client.ResultSet" % PACKAGE_NAME) @@ -167,7 +171,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command( mock_result_set_class.side_effect = mock_result_sets cursor = client.Cursor( - connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() + connection=Mock(), backend=ThriftDatabricksClientMockFactory.new() ) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -205,11 +209,11 @@ def test_context_manager_closes_cursor(self): mock_close.assert_called_once_with() cursor = client.Cursor(Mock(), Mock()) - cursor.close = Mock() + cursor.close = Mock() try: with self.assertRaises(KeyboardInterrupt): - with cursor: + with cursor: raise KeyboardInterrupt("Simulated interrupt") finally: cursor.close.assert_called() @@ -226,7 +230,7 @@ def dict_product(self, dicts): """ return (dict(zip(dicts.keys(), x)) for x in itertools.product(*dicts.values())) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -247,7 +251,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -270,7 +274,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen for k, v in req_args.items(): self.assertEqual(v, call_args[k]) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backend): req_args_combinations = self.dict_product( dict( @@ -296,10 +300,10 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe def test_cancel_command_calls_the_backend(self): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) - mock_op_handle = Mock() - cursor.active_op_handle = mock_op_handle + mock_command_id = Mock() + cursor.active_command_id = mock_command_id cursor.cancel() - mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle) + mock_thrift_backend.cancel_command.assert_called_with(mock_command_id) @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @@ -321,7 +325,7 @@ def test_version_is_canonical(self): self.assertIsNotNone(re.match(canonical_version_re, version)) def test_execute_parameter_passthrough(self): - mock_thrift_backend = ThriftBackendMockFactory.new() + mock_thrift_backend = ThriftDatabricksClientMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) tests = [ @@ -345,16 +349,16 @@ def test_execute_parameter_passthrough(self): expected_query, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend + self, mock_result_set_class ): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] mock_result_set_class.side_effect = mock_result_set_instances - mock_thrift_backend = ThriftBackendMockFactory.new() - cursor = client.Cursor(Mock(), mock_thrift_backend()) + mock_backend = ThriftDatabricksClientMockFactory.new() + + cursor = client.Cursor(Mock(), mock_backend) params = [{"x": None}, {"x": "foo1"}, {"x": "bar2"}] expected_queries = ["SELECT NULL", "SELECT 'foo1'", "SELECT 'bar2'"] @@ -362,13 +366,13 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), + len(mock_backend.execute_command.call_args_list), len(expected_queries), "Expected execute_command to be called the same number of times as params were passed", ) for expected_query, call_args in zip( - expected_queries, mock_thrift_backend.execute_command.call_args_list + expected_queries, mock_backend.execute_command.call_args_list ): self.assertEqual(call_args[1]["operation"], expected_query) @@ -379,7 +383,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -392,14 +396,14 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): c.rollback() @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): def make_fake_row_slice(n_rows): mock_slice = Mock() @@ -424,7 +428,7 @@ def make_fake_row_slice(n_rows): self.assertEqual(cursor.rownumber, 29) @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_disable_pandas_respected(self, mock_thrift_backend_class): mock_thrift_backend = mock_thrift_backend_class.return_value mock_table = Mock() @@ -477,7 +481,7 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -496,13 +500,13 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called - ThriftBackendMockFactory.apply_property_to_mock( + ThriftDatabricksClientMockFactory.apply_property_to_mock( mock_execute_response, is_staging_operation=True ) mock_client_class.execute_command.return_value = mock_execute_response @@ -515,7 +519,10 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch( + "%s.session.ThriftDatabricksClient" % PACKAGE_NAME, + ThriftDatabricksClientMockFactory.new(), + ) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" @@ -524,9 +531,13 @@ def test_access_current_query_id(self): self.assertIsNone(cursor.query_id) - cursor.active_op_handle = TOperationHandle( - operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT, + cursor.active_command_id = CommandId.from_thrift_handle( + TOperationHandle( + operationId=THandleIdentifier( + guid=UUID(operation_id).bytes, secret=0x00 + ), + operationType=TOperationType.EXECUTE_STATEMENT, + ) ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) @@ -537,18 +548,18 @@ def test_cursor_close_handles_exception(self): """Test that Cursor.close() handles exceptions from close_command properly.""" mock_backend = Mock() mock_connection = Mock() - mock_op_handle = Mock() + mock_command_id = Mock() mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) - cursor.active_op_handle = mock_op_handle + cursor.active_command_id = mock_command_id cursor.close() - mock_backend.close_command.assert_called_once_with(mock_op_handle) + mock_backend.close_command.assert_called_once_with(mock_command_id) - self.assertIsNone(cursor.active_op_handle) + self.assertIsNone(cursor.active_command_id) self.assertFalse(cursor.open) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 71766f2c..1c6a1b18 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -9,6 +9,7 @@ import databricks.sql.client as client from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") @@ -39,14 +40,14 @@ def make_dummy_result_set_from_initial_results(initial_results): arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) rs = client.ResultSet( connection=Mock(), - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), is_staging_operation=False, @@ -64,7 +65,7 @@ def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 def fetch_results( - op_handle, + command_id, max_rows, max_bytes, expected_row_start_offset, @@ -79,13 +80,13 @@ def fetch_results( return results, batch_index < len(batch_list) - mock_thrift_backend = Mock() + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 rs = client.ResultSet( connection=Mock(), - thrift_backend=mock_thrift_backend, + backend=mock_thrift_backend, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=False, @@ -95,7 +96,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_handle=None, + command_id=None, arrow_queue=None, arrow_schema_bytes=None, is_staging_operation=False, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index 55287222..b302c00d 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -31,13 +31,13 @@ def make_dummy_result_set_from_initial_results(arrow_table): arrow_queue = ArrowQueue(arrow_table, arrow_table.num_rows, 0) rs = client.ResultSet( connection=None, - thrift_backend=None, + backend=None, execute_response=ExecuteResponse( status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), - command_handle=None, + command_id=None, arrow_queue=arrow_queue, arrow_schema=arrow_table.schema, ), diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index eec921e4..949230d1 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -22,6 +22,7 @@ TinyIntParameter, VoidParameter, ) +from databricks.sql.backend.types import SessionId from databricks.sql.parameters.native import ( TDbsqlParameter, TSparkParameterValue, @@ -42,7 +43,10 @@ class TestSessionHandleChecks(object): ( TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, - sessionHandle=TSessionHandle(1, None), + sessionHandle=TSessionHandle( + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=None, + ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, ), @@ -51,7 +55,8 @@ class TestSessionHandleChecks(object): TOpenSessionResp( serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, sessionHandle=TSessionHandle( - 1, ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + sessionId=ttypes.THandleIdentifier(guid=0x36, secret=0x37), + serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, ), ), ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8, @@ -59,7 +64,13 @@ class TestSessionHandleChecks(object): ], ) def test_get_protocol_version_fallback_behavior(self, test_input, expected): - assert Connection.get_protocol_version(test_input) == expected + properties = ( + {"serverProtocolVersion": test_input.serverProtocolVersion} + if test_input.serverProtocolVersion + else {} + ) + session_id = SessionId.from_thrift_handle(test_input.sessionHandle, properties) + assert Connection.get_protocol_version(session_id) == expected @pytest.mark.parametrize( "test_input,expected", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index eb392a22..858119f9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4,7 +4,10 @@ from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, + TSessionHandle, + THandleIdentifier, ) +from databricks.sql.backend.types import SessionId, BackendType import databricks.sql @@ -21,22 +24,23 @@ class SessionTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.close() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -63,7 +67,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -71,7 +75,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -87,7 +91,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -108,22 +112,23 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: pass - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -133,54 +138,62 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() + + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) + # Check that open_session was called with the correct session_configuration as keyword argument + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["session_configuration"], mock_session_config) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + mock_client_class.return_value.open_session.return_value = mock_session_id + databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + # Check that open_session was called with the correct catalog and schema as keyword arguments + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + self.assertEqual(call_kwargs["catalog"], mock_cat) + self.assertEqual(call_kwargs["schema"], mock_schem) + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp + # Create a mock SessionId that will be returned by open_session + mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33") + instance.open_session.return_value = mock_session_id databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) # not strictly necessary as the refcount is 0, but just to be sure gc.collect() - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") + # Check that close_session was called with the correct SessionId + close_session_call_args = instance.close_session.call_args[0][0] + self.assertEqual(close_session_call_args.guid, b"\x22") + self.assertEqual(close_session_call_args.secret, b"\x33") if __name__ == "__main__": diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 458ea9a8..41a2a580 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -17,7 +17,8 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.thrift_backend import ThriftBackend +from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.types import CommandId, SessionId, BackendType def retry_policy_factory(): @@ -51,6 +52,7 @@ class ThriftBackendTestSuite(unittest.TestCase): open_session_resp = ttypes.TOpenSessionResp( status=okay_status, serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, + sessionHandle=session_handle, ) metadata_resp = ttypes.TGetResultSetMetadataResp( @@ -73,7 +75,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -92,7 +94,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -126,14 +128,16 @@ def test_hive_schema_to_arrow_schema_preserves_column_names(self): ] t_table_schema = ttypes.TTableSchema(columns) - arrow_schema = ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + arrow_schema = ThriftDatabricksClient._hive_schema_to_arrow_schema( + t_table_schema + ) self.assertEqual(arrow_schema.field(0).name, "column 1") self.assertEqual(arrow_schema.field(1).name, "column 2") self.assertEqual(arrow_schema.field(2).name, "column 2") self.assertEqual(arrow_schema.field(3).name, "") - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value bad_protocol_versions = [ @@ -163,7 +167,7 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): "expected server to use a protocol version", str(cm.exception) ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): t_http_client_instance = tcli_service_client_cass.return_value good_protocol_versions = [ @@ -174,7 +178,9 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): for protocol_version in good_protocol_versions: t_http_client_instance.OpenSession.return_value = ttypes.TOpenSessionResp( - status=self.okay_status, serverProtocolVersion=protocol_version + status=self.okay_status, + serverProtocolVersion=protocol_version, + sessionHandle=self.session_handle, ) thrift_backend = self._make_fake_thrift_backend() @@ -182,7 +188,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -229,7 +235,7 @@ def test_tls_cert_args_are_propagated( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -315,7 +321,7 @@ def test_tls_no_verify_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -339,7 +345,7 @@ def test_tls_verify_hostname_is_respected( mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend( + ThriftDatabricksClient( "foo", 123, "bar", @@ -356,7 +362,7 @@ def test_tls_verify_hostname_is_respected( @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -371,7 +377,7 @@ def test_port_and_host_are_respected(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname", 123, "path_value", @@ -386,7 +392,7 @@ def test_host_with_https_does_not_duplicate(self, t_http_client_class): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "https://hostname/", 123, "path_value", @@ -401,7 +407,7 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -413,7 +419,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -423,7 +429,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -434,7 +440,7 @@ def test_socket_timeout_is_propagated(self, t_http_client_class): self.assertEqual( t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 ) - ThriftBackend( + ThriftDatabricksClient( "hostname", 123, "path_value", @@ -467,9 +473,9 @@ def test_non_primitive_types_raise_error(self): t_table_schema = ttypes.TTableSchema(columns) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_arrow_schema(t_table_schema) + ThriftDatabricksClient._hive_schema_to_arrow_schema(t_table_schema) with self.assertRaises(OperationalError): - ThriftBackend._hive_schema_to_description(t_table_schema) + ThriftDatabricksClient._hive_schema_to_description(t_table_schema) def test_hive_schema_to_description_preserves_column_names_and_types(self): # Full coverage of all types is done in integration tests, this is just a @@ -493,7 +499,7 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, @@ -532,7 +538,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ] t_table_schema = ttypes.TTableSchema(columns) - description = ThriftBackend._hive_schema_to_description(t_table_schema) + description = ThriftDatabricksClient._hive_schema_to_description(t_table_schema) self.assertEqual( description, [ @@ -545,7 +551,7 @@ def test_make_request_checks_status_code(self): ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -589,7 +595,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -628,7 +634,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( closeOperation=None, ), ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -642,7 +648,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_checks_operation_state_in_polls( self, tcli_service_class ): @@ -672,7 +678,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( ) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -686,7 +692,7 @@ def test_handle_execute_response_checks_operation_state_in_polls( if op_state_resp.errorMessage: self.assertIn(op_state_resp.errorMessage, str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -710,7 +716,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -724,7 +730,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): self.assertEqual(display_message, str(cm.exception)) self.assertIn(diagnostic_info, str(cm.exception.message_with_context())) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_direct_results_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value @@ -750,7 +756,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -812,7 +818,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -825,7 +831,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): thrift_backend._handle_execute_response(error_resp, Mock()) self.assertIn("this is a bad error", str(cm.exception)) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_can_handle_without_direct_results( self, tcli_service_class ): @@ -863,7 +869,7 @@ def test_handle_execute_response_can_handle_without_direct_results( op_state_2, op_state_3, ] - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -900,7 +906,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): operationHandle=self.operation_handle, ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -917,7 +923,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ttypes.TOperationState.FINISHED_STATE, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_use_arrow_schema_if_available(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value arrow_schema_mock = MagicMock(name="Arrow schema mock") @@ -946,7 +952,7 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value hive_schema_mock = MagicMock(name="Hive schema mock") @@ -976,7 +982,7 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue ): @@ -1020,7 +1026,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( @patch( "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue ): @@ -1064,7 +1070,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( thrift_backend._handle_execute_response(execute_resp, Mock()) _, has_more_rows_resp = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1075,7 +1081,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( self.assertEqual(has_more_rows, has_more_rows_resp) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_arrow_batches_row_count_are_respected(self, tcli_service_class): # make some semi-real arrow batches and check the number of rows is correct in the queue tcli_service_instance = tcli_service_class.return_value @@ -1108,7 +1114,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1117,7 +1123,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): ssl_options=SSLOptions(), ) arrow_queue, has_more_results = thrift_backend.fetch_results( - op_handle=Mock(), + command_id=Mock(), max_rows=1, max_bytes=1, expected_row_start_offset=0, @@ -1128,14 +1134,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_execute_statement_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1157,14 +1163,14 @@ def test_execute_statement_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_catalogs_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1185,14 +1191,14 @@ def test_get_catalogs_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_schemas_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1222,14 +1228,14 @@ def test_get_schemas_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_tables_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1263,14 +1269,14 @@ def test_get_tables_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_get_columns_calls_client_and_handle_execute_response( self, tcli_service_class ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1304,12 +1310,12 @@ def test_get_columns_calls_client_and_handle_execute_response( response, cursor_mock ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1320,10 +1326,10 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1331,16 +1337,17 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_command(self.operation_handle) + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.close_command(command_id) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1348,13 +1355,14 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend.close_session(self.session_handle) + session_id = SessionId.from_thrift_handle(self.session_handle) + thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception( self, tcli_service_class ): @@ -1392,7 +1400,7 @@ def test_non_arrow_non_column_based_set_triggers_exception( def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1403,12 +1411,16 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") - @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") + @patch( + "databricks.sql.backend.thrift_backend.convert_arrow_based_set_to_arrow_table" + ) + @patch( + "databricks.sql.backend.thrift_backend.convert_column_based_set_to_arrow_table" + ) def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1443,7 +1455,7 @@ def test_create_arrow_table_calls_correct_conversion_method( def test_convert_arrow_based_set_to_arrow_table( self, open_stream_mock, lz4_decompress_mock ): - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1597,17 +1609,18 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): self.assertEqual(arrow_table.column(2).to_pylist(), [1.15, 2.2, 3.3]) self.assertEqual(arrow_table.column(3).to_pylist(), [b"\x11", b"\x22", b"\x33"]) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cancel_command_uses_active_op_handle(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value thrift_backend = self._make_fake_thrift_backend() - active_op_handle_mock = Mock() - thrift_backend.cancel_command(active_op_handle_mock) + # Create a proper CommandId from the existing operation_handle + command_id = CommandId.from_thrift_handle(self.operation_handle) + thrift_backend.cancel_command(command_id) self.assertEqual( tcli_service_instance.CancelOperation.call_args[0][0].operationHandle, - active_op_handle_mock, + self.operation_handle, ) def test_handle_execute_response_sets_active_op_handle(self): @@ -1615,19 +1628,27 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() thrift_backend._results_message_to_execute_response = Mock() + + # Create a mock response with a real operation handle mock_resp = Mock() + mock_resp.operationHandle = ( + self.operation_handle + ) # Use the real operation handle from the test class mock_cursor = Mock() thrift_backend._handle_execute_response(mock_resp, mock_cursor) - self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + self.assertEqual( + mock_resp.operationHandle, mock_cursor.active_command_id.to_thrift_handle() + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class @@ -1654,7 +1675,7 @@ def test_make_request_will_retry_GetOperationStatus( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1681,7 +1702,7 @@ def test_make_request_will_retry_GetOperationStatus( ) with self.assertLogs( - "databricks.sql.thrift_backend", level=logging.WARNING + "databricks.sql.backend.thrift_backend", level=logging.WARNING ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1702,7 +1723,8 @@ def test_make_request_will_retry_GetOperationStatus( "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" ) @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos @@ -1731,7 +1753,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( EXPECTED_RETRIES = 2 - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1763,7 +1785,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1779,7 +1801,8 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503( @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch( - "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + "databricks.sql.backend.thrift_backend._retry_policy", + new_callable=retry_policy_factory, ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class @@ -1791,7 +1814,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1820,7 +1843,7 @@ def test_make_request_will_read_error_message_headers_if_set( mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1944,7 +1967,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100, } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1959,7 +1982,12 @@ def test_retry_args_passthrough(self, mock_http_client): @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} - for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): + for k, ( + _, + _, + min, + max, + ) in databricks.sql.backend.thrift_backend._retry_policy.items(): retry_delay_test_args_and_expected_values[k] = ( (min - 1, min), (max + 1, max), @@ -1970,7 +1998,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -1986,7 +2014,7 @@ def test_retry_args_bounding(self, mock_http_client): for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_configuration_passthrough(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp @@ -1998,7 +2026,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2011,12 +2039,12 @@ def test_configuration_passthrough(self, tcli_client_class): open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertEqual(open_session_req.configuration, expected_config) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2036,13 +2064,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4, canUseMultipleCatalogs=can_use_multiple_cats, initialNamespace=ttypes.TNamespace(catalogName=cat, schemaName=schem), + sessionHandle=self.session_handle, ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2066,14 +2095,14 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_set_in_open_session_req( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2086,13 +2115,13 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req( open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( self, tcli_client_class ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2126,7 +2155,7 @@ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( ) backend.open_session({}, cat, schem) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value @@ -2135,9 +2164,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): serverProtocolVersion=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3, canUseMultipleCatalogs=True, initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), + sessionHandle=self.session_handle, ) - backend = ThriftBackend( + backend = ThriftDatabricksClient( "foobar", 443, "path", @@ -2154,8 +2184,10 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): str(cm.exception), ) - @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - @patch("databricks.sql.thrift_backend.ThriftBackend._handle_execute_response") + @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) + @patch( + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class ): @@ -2172,7 +2204,7 @@ def test_execute_command_sets_complex_type_fields_correctly( if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend( + thrift_backend = ThriftDatabricksClient( "foobar", 443, "path",