diff --git a/examples/experimental/sea_connector_test.py b/examples/experimental/sea_connector_test.py index abe6bd1a..87b62efe 100644 --- a/examples/experimental/sea_connector_test.py +++ b/examples/experimental/sea_connector_test.py @@ -6,34 +6,122 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) + +def test_sea_query_exec(): + """ + Test executing a query using the SEA backend with result compression. + + This function connects to a Databricks SQL endpoint using the SEA backend, + executes a simple query with result compression enabled and disabled, + and verifies that execution completes successfully. + """ + server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + access_token = os.environ.get("DATABRICKS_TOKEN") + catalog = os.environ.get("DATABRICKS_CATALOG") + + if not all([server_hostname, http_path, access_token]): + logger.error("Missing required environment variables.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) + sys.exit(1) + + try: + # Test with compression enabled + logger.info("Creating connection with LZ4 compression enabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=True, # Enable cloud fetch to use compression + enable_query_result_lz4_compression=True, # Enable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + logger.info(f"backend type: {type(connection.session.backend)}") + + # Execute a simple query with compression enabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query with LZ4 compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query with compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression enabled") + + # Test with compression disabled + logger.info("Creating connection with LZ4 compression disabled") + connection = Connection( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + catalog=catalog, + schema="default", + use_sea=True, + user_agent_entry="SEA-Test-Client", + use_cloud_fetch=False, # Enable cloud fetch + enable_query_result_lz4_compression=False, # Disable LZ4 compression + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" + ) + + # Execute a simple query with compression disabled + cursor = connection.cursor(arraysize=0, buffer_size_bytes=0) + logger.info("Executing query without compression: SELECT 1 as test_value") + cursor.execute("SELECT 1 as test_value") + logger.info("Query without compression executed successfully") + cursor.close() + connection.close() + logger.info("Successfully closed SEA session with compression disabled") + + except Exception as e: + logger.error(f"Error during SEA query execution test: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + sys.exit(1) + + logger.info("SEA query execution test with compression completed successfully") + + def test_sea_session(): """ Test opening and closing a SEA session using the connector. - + This function connects to a Databricks SQL endpoint using the SEA backend, opens a session, and then closes it. 
- + Required environment variables: - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint - DATABRICKS_TOKEN: Personal access token for authentication """ - server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME") http_path = os.environ.get("DATABRICKS_HTTP_PATH") access_token = os.environ.get("DATABRICKS_TOKEN") catalog = os.environ.get("DATABRICKS_CATALOG") - + if not all([server_hostname, http_path, access_token]): logger.error("Missing required environment variables.") - logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.") + logger.error( + "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." + ) sys.exit(1) - + logger.info(f"Connecting to {server_hostname}") logger.info(f"HTTP Path: {http_path}") if catalog: logger.info(f"Using catalog: {catalog}") - + try: logger.info("Creating connection with SEA backend...") connection = Connection( @@ -42,25 +130,33 @@ def test_sea_session(): access_token=access_token, catalog=catalog, schema="default", - use_sea=True, - user_agent_entry="SEA-Test-Client" # add custom user agent + use_sea=True, + user_agent_entry="SEA-Test-Client", # add custom user agent + ) + + logger.info( + f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}" ) - - logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}") logger.info(f"backend type: {type(connection.session.backend)}") - + # Close the connection logger.info("Closing the SEA session...") connection.close() logger.info("Successfully closed SEA session") - + except Exception as e: logger.error(f"Error testing SEA session: {str(e)}") import traceback + logger.error(traceback.format_exc()) sys.exit(1) - + logger.info("SEA session test completed successfully") + if __name__ == "__main__": + # Test session management test_sea_session() + + # Test query execution with compression + test_sea_query_exec() diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 20b059fa..bbca4c50 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -16,8 +16,6 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.backend.types import SessionId, CommandId, CommandState -from databricks.sql.utils import ExecuteResponse -from databricks.sql.types import SSLOptions # Forward reference for type hints from typing import TYPE_CHECKING @@ -88,34 +86,6 @@ def execute_command( async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: - """ - Executes a SQL command or query within the specified session. - - This method sends a SQL command to the server for execution and handles - the response. It can operate in both synchronous and asynchronous modes. 
- - Args: - operation: The SQL command or query to execute - session_id: The session identifier in which to execute the command - max_rows: Maximum number of rows to fetch in a single fetch batch - max_bytes: Maximum number of bytes to fetch in a single fetch batch - lz4_compression: Whether to use LZ4 compression for result data - cursor: The cursor object that will handle the results - use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets - parameters: List of parameters to bind to the query - async_op: Whether to execute the command asynchronously - enforce_embedded_schema_correctness: Whether to enforce schema correctness - - Returns: - If async_op is False, returns a ResultSet object containing the - query results and metadata. If async_op is True, returns None and the - results must be fetched later using get_execution_result(). - - Raises: - ValueError: If the session ID is invalid - OperationalError: If there's an error executing the command - ServerOperationError: If the server encounters an error during execution - """ pass @abstractmethod diff --git a/src/databricks/sql/backend/filters.py b/src/databricks/sql/backend/filters.py new file mode 100644 index 00000000..7f48b617 --- /dev/null +++ b/src/databricks/sql/backend/filters.py @@ -0,0 +1,143 @@ +""" +Client-side filtering utilities for Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import ResultSet + +from databricks.sql.result_set import SeaResultSet + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets that can be applied to any backend. + + This class provides methods to filter result sets based on various criteria, + similar to the client-side filtering in the JDBC connector. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] + ) -> "SeaResultSet": + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + # Create a filtered version of the result set + filtered_response = result_set._response.copy() + + # If there's a result with rows, filter them + if ( + "result" in filtered_response + and "data_array" in filtered_response["result"] + ): + rows = filtered_response["result"]["data_array"] + filtered_rows = [row for row in rows if filter_func(row)] + filtered_response["result"]["data_array"] = filtered_rows + + # Update row count if present + if "row_count" in filtered_response["result"]: + filtered_response["result"]["row_count"] = len(filtered_rows) + + # Create a new result set with the filtered data + return SeaResultSet( + connection=result_set.connection, + sea_response=filtered_response, + sea_client=result_set.backend, + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + ) + + @staticmethod + def filter_by_column_values( + result_set: "ResultSet", + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> "ResultSet": + """ + Filter a result set by values in a specific column. 
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + # Determine the type of result set and apply appropriate filtering + if isinstance(result_set, SeaResultSet): + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and isinstance(row[column_index], str) + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + # For other result set types, return the original (should be handled by specific implementations) + logger.warning( + f"Filtering not implemented for result set type: {type(result_set).__name__}" + ) + return result_set + + @staticmethod + def filter_tables_by_type( + result_set: "ResultSet", table_types: Optional[List[str]] = None + ) -> "ResultSet": + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is typically in the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=False + ) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a05..c7a4ed1b 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,34 @@ import logging import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +import uuid +import time +from typing import Dict, Set, Tuple, List, Optional, Any, Union, TYPE_CHECKING + +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError +from databricks.sql.exc import Error, NotSupportedError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, -) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -55,6 +66,9 @@ def 
_filter_session_configuration( class SeaDatabricksClient(DatabricksClient): """ Statement Execution API (SEA) implementation of the DatabricksClient interface. + + This implementation provides session management functionality for SEA, + while other operations raise NotImplementedError. """ # SEA API paths @@ -274,41 +288,222 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param.name, + value=param.value, + type=param.type if hasattr(param, "type") else None, + ) + ) + + format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY" + disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE" + result_compression = "LZ4_FRAME" if lz4_compression else "NONE" + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout="0s" if async_op else "10s", + on_wait_timeout="CONTINUE", + row_limit=max_rows if max_rows > 0 else None, + byte_limit=max_bytes if max_bytes > 0 else None, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, ) + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() + ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return None and let the client poll for results + if async_op: + return None + + # For synchronous operation, wait for the statement to complete + # Poll until the statement is done + status = response.status + state = status.state + + # Keep polling until we reach a terminal state + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(0.5) # add a small delay to avoid excessive API calls + state = self.get_query_state(command_id) + + if state != CommandState.SUCCEEDED: + raise ServerOperationError( + f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}", + { + 
"operation-id": command_id.to_sea_statement_id(), + "diagnostic-info": None, + }, + ) + + return self.get_execution_result(command_id, cursor) + def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. 
+ + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + return SeaResultSet( + connection=cursor.connection, + sea_response=response_data, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, ) # == Metadata Operations == @@ -319,9 +514,22 @@ def get_catalogs( max_rows: int, max_bytes: int, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation="SHOW CATALOGS", + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, @@ -331,9 +539,30 @@ def get_schemas( cursor: "Cursor", catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_schemas") + + operation = f"SHOW SCHEMAS IN `{catalog_name}`" + + if schema_name: + operation += f" LIKE '{schema_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, @@ -345,9 +574,43 @@ def get_tables( schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_tables") + + operation = "SHOW TABLES IN " + ( + "ALL CATALOGS" + if catalog_name in [None, "*", "%"] + else f"CATALOG `{catalog_name}`" + ) + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" LIKE '{table_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types if specified + from databricks.sql.backend.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, @@ -359,6 +622,33 @@ def get_columns( schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> "ResultSet": + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise ValueError("Catalog name is required for get_columns") + + operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`" + + if schema_name: + operation += f" SCHEMA LIKE '{schema_name}'" + + if table_name: + operation += f" TABLE LIKE '{table_name}'" + + if column_name: + operation += f" LIKE '{column_name}'" + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/models/__init__.py b/src/databricks/sql/backend/sea/models/__init__.py index c9310d36..b7c8bd39 100644 --- a/src/databricks/sql/backend/sea/models/__init__.py +++ b/src/databricks/sql/backend/sea/models/__init__.py @@ -4,19 +4,49 @@ This package contains data models for SEA API requests and responses. """ +from databricks.sql.backend.sea.models.base import ( + ServiceError, + StatementStatus, + ExternalLink, + ResultData, + ColumnInfo, + ResultManifest, +) + from databricks.sql.backend.sea.models.requests import ( + StatementParameter, + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, ) from databricks.sql.backend.sea.models.responses import ( + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) __all__ = [ + # Base models + "ServiceError", + "StatementStatus", + "ExternalLink", + "ResultData", + "ColumnInfo", + "ResultManifest", # Request models + "StatementParameter", + "ExecuteStatementRequest", + "GetStatementRequest", + "CancelStatementRequest", + "CloseStatementRequest", "CreateSessionRequest", "DeleteSessionRequest", # Response models + "ExecuteStatementResponse", + "GetStatementResponse", "CreateSessionResponse", ] diff --git a/src/databricks/sql/backend/sea/models/base.py b/src/databricks/sql/backend/sea/models/base.py new file mode 100644 index 00000000..671f7be1 --- /dev/null +++ b/src/databricks/sql/backend/sea/models/base.py @@ -0,0 +1,68 @@ +""" +Base models for the SEA (Statement Execution API) backend. + +These models define the common structures used in SEA API requests and responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState + + +@dataclass +class ServiceError: + """Error information returned by the SEA API.""" + + message: str + error_code: Optional[str] = None + + +@dataclass +class StatementStatus: + """Status information for a statement execution.""" + + state: CommandState + error: Optional[ServiceError] = None + sql_state: Optional[str] = None + + +@dataclass +class ExternalLink: + """External link information for result data.""" + + external_link: str + expiration: str + chunk_index: int + + +@dataclass +class ResultData: + """Result data from a statement execution.""" + + data: Optional[List[List[Any]]] = None + external_links: Optional[List[ExternalLink]] = None + + +@dataclass +class ColumnInfo: + """Information about a column in the result set.""" + + name: str + type_name: str + type_text: str + nullable: bool = True + precision: Optional[int] = None + scale: Optional[int] = None + ordinal_position: Optional[int] = None + + +@dataclass +class ResultManifest: + """Manifest information for a result set.""" + + schema: List[ColumnInfo] + total_row_count: int + total_byte_count: int + truncated: bool = False + chunk_count: Optional[int] = None diff --git a/src/databricks/sql/backend/sea/models/requests.py b/src/databricks/sql/backend/sea/models/requests.py index 7966cb50..1c519d93 100644 --- a/src/databricks/sql/backend/sea/models/requests.py +++ b/src/databricks/sql/backend/sea/models/requests.py @@ -1,5 +1,111 @@ -from typing import Dict, Any, Optional -from dataclasses import dataclass +""" +Request models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API requests. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + + +@dataclass +class StatementParameter: + """Parameter for a SQL statement.""" + + name: str + value: Optional[str] = None + type: Optional[str] = None + + +@dataclass +class ExecuteStatementRequest: + """Request to execute a SQL statement.""" + + warehouse_id: str + statement: str + session_id: str + disposition: str = "EXTERNAL_LINKS" + format: str = "JSON_ARRAY" + wait_timeout: str = "10s" + on_wait_timeout: str = "CONTINUE" + row_limit: Optional[int] = None + byte_limit: Optional[int] = None + parameters: Optional[List[StatementParameter]] = None + catalog: Optional[str] = None + schema: Optional[str] = None + result_compression: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "warehouse_id": self.warehouse_id, + "session_id": self.session_id, + "statement": self.statement, + "disposition": self.disposition, + "format": self.format, + "wait_timeout": self.wait_timeout, + "on_wait_timeout": self.on_wait_timeout, + } + + if self.row_limit is not None and self.row_limit > 0: + result["row_limit"] = self.row_limit + + if self.byte_limit is not None and self.byte_limit > 0: + result["byte_limit"] = self.byte_limit + + if self.catalog: + result["catalog"] = self.catalog + + if self.schema: + result["schema"] = self.schema + + if self.result_compression: + result["result_compression"] = self.result_compression + + if self.parameters: + result["parameters"] = [ + { + "name": param.name, + **({"value": param.value} if param.value is not None else {}), + **({"type": param.type} if param.type is not None else {}), + } + for param in self.parameters + ] + + return result + + +@dataclass +class GetStatementRequest: + """Request to get information about a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CancelStatementRequest: + """Request to cancel a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} + + +@dataclass +class CloseStatementRequest: + """Request to close a statement.""" + + statement_id: str + + def to_dict(self) -> Dict[str, Any]: + """Convert the request to a dictionary for JSON serialization.""" + return {"statement_id": self.statement_id} @dataclass diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 1bb54590..d70459b9 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -1,5 +1,96 @@ -from typing import Dict, Any -from dataclasses import dataclass +""" +Response models for the SEA (Statement Execution API) backend. + +These models define the structures used in SEA API responses. 
+""" + +from typing import Dict, List, Any, Optional, Union +from dataclasses import dataclass, field + +from databricks.sql.backend.types import CommandState +from databricks.sql.backend.sea.models.base import ( + StatementStatus, + ResultManifest, + ResultData, + ServiceError, +) + + +@dataclass +class ExecuteStatementResponse: + """Response from executing a SQL statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse": + """Create an ExecuteStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) + + +@dataclass +class GetStatementResponse: + """Response from getting information about a statement.""" + + statement_id: str + status: StatementStatus + manifest: Optional[ResultManifest] = None + result: Optional[ResultData] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse": + """Create a GetStatementResponse from a dictionary.""" + status_data = data.get("status", {}) + error = None + if "error" in status_data: + error_data = status_data["error"] + error = ServiceError( + message=error_data.get("message", ""), + error_code=error_data.get("error_code"), + ) + + state = CommandState.from_sea_state(status_data.get("state", "")) + if state is None: + raise ValueError(f"Invalid state: {status_data.get('state', '')}") + + status = StatementStatus( + state=state, + error=error, + sql_state=status_data.get("sql_state"), + ) + + return cls( + statement_id=data.get("statement_id", ""), + status=status, + manifest=data.get("manifest"), # We'll parse this more fully if needed + result=data.get("result"), # We'll parse this more fully if needed + ) @dataclass diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index de388f1d..e03d6f23 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -5,11 +5,10 @@ import time import uuid import threading -from typing import List, Optional, Union, Any, TYPE_CHECKING +from typing import List, Union, Any, TYPE_CHECKING if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState from databricks.sql.backend.types import ( @@ -17,8 +16,9 @@ SessionId, CommandId, BackendType, + guid_to_hex_id, + ExecuteResponse, ) -from databricks.sql.backend.utils import guid_to_hex_id try: import pyarrow @@ -42,7 +42,7 @@ ) from databricks.sql.utils import ( - ExecuteResponse, + ResultSetQueueFactory, _bound, RequestErrorInfo, NoRetryReason, @@ -53,6 +53,7 @@ ) from databricks.sql.types import SSLOptions from 
databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.result_set import ResultSet, ThriftResultSet logger = logging.getLogger(__name__) @@ -351,7 +352,6 @@ def make_request(self, method, request, retryable=True): Will stop retry attempts if total elapsed time + next retry delay would exceed _retry_stop_after_attempts_duration. """ - # basic strategy: build range iterator rep'ing number of available # retries. bounds can be computed from there. iterate over it with # retries until success or final failure achieved. @@ -797,23 +797,27 @@ def _results_message_to_execute_response(self, resp, operation_state): command_id = CommandId.from_thrift_handle(resp.operationHandle) - return ExecuteResponse( - arrow_queue=arrow_queue_opt, - status=CommandState.from_thrift_state(operation_state), - has_been_closed_server_side=has_been_closed_server_side, - has_more_rows=has_more_rows, - lz4_compressed=lz4_compressed, - is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, + status = CommandState.from_thrift_state(operation_state) + if status is None: + raise ValueError(f"Invalid operation state: {operation_state}") + + return ( + ExecuteResponse( + command_id=command_id, + status=status, + description=description, + has_more_rows=has_more_rows, + results_queue=arrow_queue_opt, + has_been_closed_server_side=has_been_closed_server_side, + lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, + ), + schema_bytes, ) def get_execution_result( self, command_id: CommandId, cursor: "Cursor" ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = command_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift command ID") @@ -863,15 +867,14 @@ def get_execution_result( ) execute_response = ExecuteResponse( - arrow_queue=queue, - status=CommandState.from_thrift_state(resp.status), - has_been_closed_server_side=False, + command_id=command_id, + status=resp.status, + description=description, has_more_rows=has_more_rows, + results_queue=queue, + has_been_closed_server_side=False, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_id=command_id, - description=description, - arrow_schema_bytes=schema_bytes, ) return ThriftResultSet( @@ -881,6 +884,7 @@ def get_execution_result( buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=schema_bytes, ) def _wait_until_command_done(self, op_handle, initial_operation_status_resp): @@ -909,10 +913,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: poll_resp = self._poll_for_status(thrift_handle) operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp) - state = CommandState.from_thrift_state(operation_state) - if state is None: - raise ValueError(f"Unknown command state: {operation_state}") - return state + return CommandState.from_thrift_state(operation_state) @staticmethod def _check_direct_results_for_error(t_spark_direct_results): @@ -947,8 +948,6 @@ def execute_command( async_op=False, enforce_embedded_schema_correctness=False, ) -> Union["ResultSet", None]: - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -995,7 +994,9 @@ def execute_command( 
self._handle_execute_response_async(resp, cursor) return None else: - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1004,6 +1005,7 @@ def execute_command( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_catalogs( @@ -1013,8 +1015,6 @@ def get_catalogs( max_bytes: int, cursor: "Cursor", ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1027,7 +1027,9 @@ def get_catalogs( ) resp = self.make_request(self._client.GetCatalogs, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1036,6 +1038,7 @@ def get_catalogs( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_schemas( @@ -1047,8 +1050,6 @@ def get_schemas( catalog_name=None, schema_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1063,7 +1064,9 @@ def get_schemas( ) resp = self.make_request(self._client.GetSchemas, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1072,6 +1075,7 @@ def get_schemas( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_tables( @@ -1085,8 +1089,6 @@ def get_tables( table_name=None, table_types=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1103,7 +1105,9 @@ def get_tables( ) resp = self.make_request(self._client.GetTables, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1112,6 +1116,7 @@ def get_tables( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def get_columns( @@ -1125,8 +1130,6 @@ def get_columns( table_name=None, column_name=None, ) -> "ResultSet": - from databricks.sql.result_set import ThriftResultSet - thrift_handle = session_id.to_thrift_handle() if not thrift_handle: raise ValueError("Not a valid Thrift session ID") @@ -1143,7 +1146,9 @@ def get_columns( ) resp = self.make_request(self._client.GetColumns, req) - execute_response = self._handle_execute_response(resp, cursor) + execute_response, arrow_schema_bytes = self._handle_execute_response( + resp, cursor + ) return ThriftResultSet( connection=cursor.connection, @@ -1152,6 +1157,7 @@ def get_columns( buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, + arrow_schema_bytes=arrow_schema_bytes, ) def 
_handle_execute_response(self, resp, cursor): @@ -1165,7 +1171,12 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + ( + execute_response, + arrow_schema_bytes, + ) = self._results_message_to_execute_response(resp, final_operation_state) + execute_response.command_id = command_id + return execute_response, arrow_schema_bytes def _handle_execute_response_async(self, resp, cursor): command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -1226,7 +1237,7 @@ def cancel_command(self, command_id: CommandId) -> None: if not thrift_handle: raise ValueError("Not a valid Thrift command ID") - logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid))) + logger.debug("Cancelling command {}".format(command_id.guid)) req = ttypes.TCancelOperationReq(thrift_handle) self.make_request(self._client.CancelOperation, req) diff --git a/src/databricks/sql/backend/types.py b/src/databricks/sql/backend/types.py index 9cd21b5e..5bf02e0e 100644 --- a/src/databricks/sql/backend/types.py +++ b/src/databricks/sql/backend/types.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from enum import Enum -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import logging from databricks.sql.backend.utils import guid_to_hex_id @@ -80,6 +81,28 @@ def from_thrift_state( else: return None + @classmethod + def from_sea_state(cls, state: str) -> Optional["CommandState"]: + """ + Map SEA state string to CommandState enum. + + Args: + state: SEA state string + + Returns: + CommandState: The corresponding CommandState enum value + """ + state_mapping = { + "PENDING": cls.PENDING, + "RUNNING": cls.RUNNING, + "SUCCEEDED": cls.SUCCEEDED, + "FAILED": cls.FAILED, + "CLOSED": cls.CLOSED, + "CANCELED": cls.CANCELLED, + } + + return state_mapping.get(state, None) + class BackendType(Enum): """ @@ -285,28 +308,6 @@ def __init__( self.has_result_set = has_result_set self.modified_row_count = modified_row_count - def __str__(self) -> str: - """ - Return a string representation of the CommandId. - - For SEA backend, returns the guid. - For Thrift backend, returns a format like "guid|secret". 
- - Returns: - A string representation of the command ID - """ - - if self.backend_type == BackendType.SEA: - return str(self.guid) - elif self.backend_type == BackendType.THRIFT: - secret_hex = ( - guid_to_hex_id(self.secret) - if isinstance(self.secret, bytes) - else str(self.secret) - ) - return f"{self.to_hex_guid()}|{secret_hex}" - return str(self.guid) - @classmethod def from_thrift_handle(cls, operation_handle): """ @@ -318,7 +319,6 @@ def from_thrift_handle(cls, operation_handle): Returns: A CommandId instance """ - if operation_handle is None: return None @@ -394,3 +394,19 @@ def to_hex_guid(self) -> str: return guid_to_hex_id(self.guid) else: return str(self.guid) + + +@dataclass +class ExecuteResponse: + """Response from executing a SQL command.""" + + command_id: CommandId + status: CommandState + description: Optional[ + List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]] + ] = None + has_more_rows: bool = False + results_queue: Optional[Any] = None + has_been_closed_server_side: bool = False + lz4_compressed: bool = True + is_staging_operation: bool = False diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f7c060a..e145e4e5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -24,7 +24,6 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.utils import ( - ExecuteResponse, ParamEscaper, inject_parameters, transform_paramstyle, diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index a0d8d357..fc859583 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -1,26 +1,23 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any, Union, TYPE_CHECKING +from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING import logging import time import pandas -from databricks.sql.backend.types import CommandId, CommandState - try: import pyarrow except ImportError: pyarrow = None if TYPE_CHECKING: - from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.client import Connection - from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import Row from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ColumnTable, ColumnQueue +from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse logger = logging.getLogger(__name__) @@ -34,32 +31,31 @@ class ResultSet(ABC): def __init__( self, - connection: "Connection", - backend: "DatabricksClient", - command_id: CommandId, - op_state: Optional[CommandState], - has_been_closed_server_side: bool, + connection, + backend, arraysize: int, buffer_size_bytes: int, + command_id=None, + status=None, + has_been_closed_server_side: bool = False, + has_more_rows: bool = False, + results_queue=None, + description=None, + is_staging_operation: bool = False, ): - """ - A ResultSet manages the results of a single command. 
- - :param connection: The parent connection that was used to execute this command - :param backend: The specialised backend client to be invoked in the fetch phase - :param execute_response: A `ExecuteResponse` class returned by a command execution - :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch - amount :param arraysize: The max number of rows to fetch at a time (PEP-249) - """ - self.command_id = command_id - self.op_state = op_state - self.has_been_closed_server_side = has_been_closed_server_side + """Initialize the base ResultSet with common properties.""" self.connection = connection - self.backend = backend + self.backend = backend # Store the backend client directly self.arraysize = arraysize self.buffer_size_bytes = buffer_size_bytes self._next_row_index = 0 - self.description = None + self.description = description + self.command_id = command_id + self.status = status + self.has_been_closed_server_side = has_been_closed_server_side + self._has_more_rows = has_more_rows + self.results = results_queue + self._is_staging_operation = is_staging_operation def __iter__(self): while True: @@ -74,10 +70,9 @@ def rownumber(self): return self._next_row_index @property - @abstractmethod def is_staging_operation(self) -> bool: """Whether this result set represents a staging operation.""" - pass + return self._is_staging_operation # Define abstract methods that concrete implementations must implement @abstractmethod @@ -101,12 +96,12 @@ def fetchall(self) -> List[Row]: pass @abstractmethod - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": + def fetchmany_arrow(self, size: int) -> Any: """Fetch the next set of rows as an Arrow table.""" pass @abstractmethod - def fetchall_arrow(self) -> "pyarrow.Table": + def fetchall_arrow(self) -> Any: """Fetch all remaining rows as an Arrow table.""" pass @@ -119,7 +114,7 @@ def close(self) -> None: """ try: if ( - self.op_state != CommandState.CLOSED + self.status != CommandState.CLOSED and not self.has_been_closed_server_side and self.connection.open ): @@ -129,7 +124,7 @@ def close(self) -> None: logger.info("Operation was canceled by a prior request") finally: self.has_been_closed_server_side = True - self.op_state = CommandState.CLOSED + self.status = CommandState.CLOSED class ThriftResultSet(ResultSet): @@ -138,11 +133,12 @@ class ThriftResultSet(ResultSet): def __init__( self, connection: "Connection", - execute_response: ExecuteResponse, + execute_response: "ExecuteResponse", thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, + arrow_schema_bytes: Optional[bytes] = None, ): """ Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient. 
@@ -154,37 +150,33 @@ def __init__( buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch use_cloud_fetch: Whether to use cloud fetch for retrieving results + arrow_schema_bytes: Arrow schema bytes for the result set """ - super().__init__( - connection, - thrift_client, - execute_response.command_id, - execute_response.status, - execute_response.has_been_closed_server_side, - arraysize, - buffer_size_bytes, - ) - # Initialize ThriftResultSet-specific attributes - self.has_been_closed_server_side = execute_response.has_been_closed_server_side - self.has_more_rows = execute_response.has_more_rows - self.lz4_compressed = execute_response.lz4_compressed - self.description = execute_response.description - self._arrow_schema_bytes = execute_response.arrow_schema_bytes + self._arrow_schema_bytes = arrow_schema_bytes self._use_cloud_fetch = use_cloud_fetch - self._is_staging_operation = execute_response.is_staging_operation + self.lz4_compressed = execute_response.lz4_compressed - # Initialize results queue - if execute_response.arrow_queue: - # In this case the server has taken the fast path and returned an initial batch of - # results - self.results = execute_response.arrow_queue - else: - # In this case, there are results waiting on the server so we fetch now for simplicity + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=thrift_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=execute_response.command_id, + status=execute_response.status, + has_been_closed_server_side=execute_response.has_been_closed_server_side, + has_more_rows=execute_response.has_more_rows, + results_queue=execute_response.results_queue, + description=execute_response.description, + is_staging_operation=execute_response.is_staging_operation, + ) + + # Initialize results queue if not provided + if not self.results: self._fill_results_buffer() def _fill_results_buffer(self): - # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.backend.fetch_results( command_id=self.command_id, max_rows=self.arraysize, @@ -196,7 +188,7 @@ def _fill_results_buffer(self): use_cloud_fetch=self._use_cloud_fetch, ) self.results = results - self.has_more_rows = has_more_rows + self._has_more_rows = has_more_rows def _convert_columnar_table(self, table): column_names = [c[0] for c in self.description] @@ -248,7 +240,7 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": + def merge_columnar(self, result1, result2): """ Function to merge / combining the columnar results into a single result :param result1: @@ -280,7 +272,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -305,7 +297,7 @@ def fetchmany_columnar(self, size: int): while ( n_remaining_rows > 0 and not self.has_been_closed_server_side - and self.has_more_rows + and self._has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) @@ -320,7 +312,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not 
self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() if isinstance(results, ColumnTable) and isinstance( @@ -346,7 +338,7 @@ def fetchall_columnar(self): results = self.results.remaining_rows() self._next_row_index += results.num_rows - while not self.has_been_closed_server_side and self.has_more_rows: + while not self.has_been_closed_server_side and self._has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() results = self.merge_columnar(results, partial_results) @@ -389,24 +381,110 @@ def fetchmany(self, size: int) -> List[Row]: else: return self._convert_arrow_table(self.fetchmany_arrow(size)) - @property - def is_staging_operation(self) -> bool: - """Whether this result set represents a staging operation.""" - return self._is_staging_operation - @staticmethod - def _get_schema_description(table_schema_message): +class SeaResultSet(ResultSet): + """ResultSet implementation for SEA backend.""" + + def __init__( + self, + connection, + sea_client, + buffer_size_bytes: int = 104857600, + arraysize: int = 10000, + execute_response=None, + sea_response=None, + ): """ - Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249 + Initialize a SeaResultSet with the response from a SEA query execution. + + Args: + connection: The parent connection + sea_client: The SeaDatabricksClient instance for direct access + buffer_size_bytes: Buffer size for fetching results + arraysize: Default number of rows to fetch + execute_response: Response from the execute command (new style) + sea_response: Direct SEA response (legacy style) """ + # Handle both initialization styles + if execute_response is not None: + # New style with ExecuteResponse + command_id = execute_response.command_id + status = execute_response.status + has_been_closed_server_side = execute_response.has_been_closed_server_side + has_more_rows = execute_response.has_more_rows + results_queue = execute_response.results_queue + description = execute_response.description + is_staging_operation = execute_response.is_staging_operation + self._response = getattr(execute_response, "sea_response", {}) + self.statement_id = command_id.to_sea_statement_id() if command_id else None + elif sea_response is not None: + # Legacy style with direct sea_response + self._response = sea_response + # Extract values from sea_response + command_id = CommandId.from_sea_statement_id( + sea_response.get("statement_id", "") + ) + self.statement_id = sea_response.get("statement_id", "") + + # Extract status + status_data = sea_response.get("status", {}) + status = CommandState.from_sea_state(status_data.get("state", "PENDING")) + + # Set defaults for other fields + has_been_closed_server_side = False + has_more_rows = False + results_queue = None + description = None + is_staging_operation = False + else: + raise ValueError("Either execute_response or sea_response must be provided") - def map_col_type(type_): - if type_.startswith("decimal"): - return "decimal" - else: - return type_ + # Call parent constructor with common attributes + super().__init__( + connection=connection, + backend=sea_client, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=status, + has_been_closed_server_side=has_been_closed_server_side, + has_more_rows=has_more_rows, + results_queue=results_queue, + description=description, + 
is_staging_operation=is_staging_operation, + ) - return [ - (column.name, map_col_type(column.datatype), None, None, None, None, None) - for column in table_schema_message.columns - ] + def _fill_results_buffer(self): + """Fill the results buffer from the backend.""" + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchone(self) -> Optional[Row]: + """ + Fetch the next row of a query result set, returning a single sequence, + or None when no more data is available. + """ + + raise NotImplementedError("fetchone is not implemented for SEA backend") + + def fetchmany(self, size: Optional[int] = None) -> List[Row]: + """ + Fetch the next set of rows of a query result, returning a list of rows. + + An empty sequence is returned when no more rows are available. + """ + + raise NotImplementedError("fetchmany is not implemented for SEA backend") + + def fetchall(self) -> List[Row]: + """ + Fetch all (remaining) rows of a query result, returning them as a list of rows. + """ + raise NotImplementedError("fetchall is not implemented for SEA backend") + + def fetchmany_arrow(self, size: int) -> Any: + """Fetch the next set of rows as an Arrow table.""" + raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend") + + def fetchall_arrow(self) -> Any: + """Fetch all remaining rows as an Arrow table.""" + raise NotImplementedError("fetchall_arrow is not implemented for SEA backend") diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 7c33d9b2..76aec467 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -10,7 +10,7 @@ from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId +from databricks.sql.backend.types import SessionId, BackendType logger = logging.getLogger(__name__) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2622b117..edb13ef6 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -349,13 +349,6 @@ def _create_empty_table(self) -> "pyarrow.Table": return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) -ExecuteResponse = namedtuple( - "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " - "command_id arrow_queue arrow_schema_bytes", -) - - def _bound(min_x, max_x, x): """Bound x by [min_x, max_x] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1a795087..090ec255 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,7 @@ from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState -from databricks.sql.utils import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse from tests.unit.test_fetches import FetchTests from tests.unit.test_thrift_backend import ThriftBackendTestSuite @@ -121,10 +121,10 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # Verify initial state self.assertEqual(real_result_set.has_been_closed_server_side, closed) - expected_op_state = ( + expected_status = ( CommandState.CLOSED if closed else CommandState.SUCCEEDED ) - self.assertEqual(real_result_set.op_state, expected_op_state) + 
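With the ExecuteResponse namedtuple removed from databricks.sql.utils, the tests now import it from databricks.sql.backend.types, and the old arrow_queue/arrow_schema_bytes fields are replaced by a single results_queue. A minimal construction sketch in the style of the updated fetch tests; the field values here are illustrative:

from unittest.mock import Mock

from databricks.sql.backend.types import ExecuteResponse

# Construct the relocated ExecuteResponse; results_queue replaces the old
# arrow_queue field, and the Arrow schema bytes no longer travel on it.
response = ExecuteResponse(
    command_id=None,
    status=None,
    has_been_closed_server_side=True,
    has_more_rows=False,
    description=Mock(),
    lz4_compressed=False,
    results_queue=None,
    is_staging_operation=False,
)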
self.assertEqual(real_result_set.status, expected_status) # Mock execute_command to return our real result set cursor.backend.execute_command = Mock(return_value=real_result_set) @@ -146,8 +146,8 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): # 1. has_been_closed_server_side should always be True after close() self.assertTrue(real_result_set.has_been_closed_server_side) - # 2. op_state should always be CLOSED after close() - self.assertEqual(real_result_set.op_state, CommandState.CLOSED) + # 2. status should always be CLOSED after close() + self.assertEqual(real_result_set.status, CommandState.CLOSED) # 3. Backend close_command should be called appropriately if not closed: @@ -556,7 +556,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() - @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) + @patch("%s.backend.types.ExecuteResponse" % PACKAGE_NAME) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_staging_operation_response_is_handled( @@ -678,10 +678,10 @@ def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ThriftResultSet.__new__(client.ThriftResultSet) result_set.backend = Mock() - result_set.backend.CLOSED_OP_STATE = "CLOSED" + result_set.backend.CLOSED_OP_STATE = CommandState.CLOSED result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = "RUNNING" + result_set.status = CommandState.RUNNING result_set.has_been_closed_server_side = False result_set.command_id = Mock() @@ -695,7 +695,7 @@ def __init__(self): try: try: if ( - result_set.op_state != result_set.backend.CLOSED_OP_STATE + result_set.status != result_set.backend.CLOSED_OP_STATE and not result_set.has_been_closed_server_side and result_set.connection.open ): @@ -705,7 +705,7 @@ def __init__(self): pass finally: result_set.has_been_closed_server_side = True - result_set.op_state = result_set.backend.CLOSED_OP_STATE + result_set.status = result_set.backend.CLOSED_OP_STATE result_set.backend.close_command.assert_called_once_with( result_set.command_id @@ -713,7 +713,7 @@ def __init__(self): assert result_set.has_been_closed_server_side is True - assert result_set.op_state == result_set.backend.CLOSED_OP_STATE + assert result_set.status == result_set.backend.CLOSED_OP_STATE finally: pass diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 030510a6..7249a59e 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -8,7 +8,8 @@ pa = None import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue from databricks.sql.backend.thrift_backend import ThriftDatabricksClient from databricks.sql.result_set import ThriftResultSet @@ -42,14 +43,13 @@ def make_dummy_result_set_from_initial_results(initial_results): rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=True, has_more_rows=False, description=Mock(), lz4_compressed=Mock(), - command_id=None, - arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes(), + results_queue=arrow_queue, is_staging_operation=False, ), thrift_client=None, @@ -88,6 
+88,7 @@ def fetch_results( rs = ThriftResultSet( connection=Mock(), execute_response=ExecuteResponse( + command_id=None, status=None, has_been_closed_server_side=False, has_more_rows=True, @@ -96,9 +97,7 @@ def fetch_results( for col_id in range(num_cols) ], lz4_compressed=Mock(), - command_id=None, - arrow_queue=None, - arrow_schema_bytes=None, + results_queue=None, is_staging_operation=False, ), thrift_client=mock_thrift_backend, diff --git a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index b302c00d..7e025cf8 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -10,7 +10,8 @@ import pytest import databricks.sql.client as client -from databricks.sql.utils import ExecuteResponse, ArrowQueue +from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.utils import ArrowQueue @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") diff --git a/tests/unit/test_result_set_filter.py b/tests/unit/test_result_set_filter.py new file mode 100644 index 00000000..e8eb2a75 --- /dev/null +++ b/tests/unit/test_result_set_filter.py @@ -0,0 +1,246 @@ +""" +Tests for the ResultSetFilter class. + +This module contains tests for the ResultSetFilter class, which provides +filtering capabilities for result sets returned by different backends. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.backend.filters import ResultSetFilter +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestResultSetFilter: + """Test suite for the ResultSetFilter class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response_with_tables(self): + """Create a sample SEA response with table data based on the server schema.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 7, + "columns": [ + {"name": "namespace", "type_text": "STRING", "position": 0}, + {"name": "tableName", "type_text": "STRING", "position": 1}, + {"name": "isTemporary", "type_text": "BOOLEAN", "position": 2}, + {"name": "information", "type_text": "STRING", "position": 3}, + {"name": "catalogName", "type_text": "STRING", "position": 4}, + {"name": "tableType", "type_text": "STRING", "position": 5}, + {"name": "remarks", "type_text": "STRING", "position": 6}, + ], + }, + "total_row_count": 4, + }, + "result": { + "data_array": [ + ["schema1", "table1", False, None, "catalog1", "TABLE", None], + ["schema1", "table2", False, None, "catalog1", "VIEW", None], + [ + "schema1", + "table3", + False, + None, + "catalog1", + "SYSTEM TABLE", + None, + ], + ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None], + ] + }, + } + + @pytest.fixture + def sea_result_set_with_tables( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Create a SeaResultSet with table data.""" + # Create a deep copy of the response to avoid test interference + import copy + + sea_response_copy = copy.deepcopy(sea_response_with_tables) + + return SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy, + buffer_size_bytes=1000, + arraysize=100, + ) + + def 
test_filter_tables_by_type_default(self, sea_result_set_with_tables): + """Test filtering tables by type with default types.""" + # Default types are TABLE, VIEW, SYSTEM TABLE + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables + ) + + # Verify that only the default types are included + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_custom(self, sea_result_set_with_tables): + """Test filtering tables by type with custom types.""" + # Filter for only TABLE and EXTERNAL + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["TABLE", "EXTERNAL"] + ) + + # Verify that only the specified types are included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "EXTERNAL" in table_types + assert "VIEW" not in table_types + assert "SYSTEM TABLE" not in table_types + + def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables): + """Test that table type filtering is case-insensitive.""" + # Filter for lowercase "table" and "view" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=["table", "view"] + ) + + # Verify that the matching types are included despite case differences + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" not in table_types + assert "EXTERNAL" not in table_types + + def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables): + """Test filtering tables with an empty type list (should use defaults).""" + filtered_result = ResultSetFilter.filter_tables_by_type( + sea_result_set_with_tables, table_types=[] + ) + + # Verify that default types are used + assert len(filtered_result._response["result"]["data_array"]) == 3 + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + assert "SYSTEM TABLE" in table_types + assert "EXTERNAL" not in table_types + + def test_filter_by_column_values(self, sea_result_set_with_tables): + """Test filtering by values in a specific column.""" + # Filter by namespace in column index 0 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=0, allowed_values=["schema1"] + ) + + # All rows have schema1 in namespace, so all should be included + assert len(filtered_result._response["result"]["data_array"]) == 4 + + # Filter by table name in column index 1 + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, + column_index=1, + allowed_values=["table1", "table3"], + ) + + # Only rows with table1 or table3 should be included + assert len(filtered_result._response["result"]["data_array"]) == 2 + table_names = [ + row[1] for row in filtered_result._response["result"]["data_array"] + ] + assert "table1" in table_names + assert "table3" in table_names + assert "table2" not in table_names + 
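A condensed sketch of the two ResultSetFilter entry points these tests exercise, assuming the filters module and SeaResultSet introduced in this change; the mocks and the response dict below are illustrative stand-ins for the fixtures:

import copy
from unittest.mock import Mock

from databricks.sql.backend.filters import ResultSetFilter
from databricks.sql.result_set import SeaResultSet

SEA_RESPONSE = {
    "statement_id": "stmt-1",
    "status": {"state": "SUCCEEDED"},
    "result": {
        "data_array": [
            ["schema1", "table1", False, None, "catalog1", "TABLE", None],
            ["schema1", "table2", False, None, "catalog1", "VIEW", None],
            ["schema1", "table4", False, None, "catalog1", "EXTERNAL", None],
        ]
    },
}

def fresh_result_set():
    # Build a result set from a copy of the response, mirroring the fixtures above.
    return SeaResultSet(
        connection=Mock(open=True),
        sea_client=Mock(),
        sea_response=copy.deepcopy(SEA_RESPONSE),
        buffer_size_bytes=1000,
        arraysize=100,
    )

# Keep only the default table types (TABLE, VIEW, SYSTEM TABLE); EXTERNAL is dropped.
defaults_only = ResultSetFilter.filter_tables_by_type(fresh_result_set())

# Keep rows whose tableType column (index 5) matches the allowed values exactly.
tables_only = ResultSetFilter.filter_by_column_values(
    fresh_result_set(), column_index=5, allowed_values=["TABLE"], case_sensitive=True
)

rows = tables_only._response["result"]["data_array"]  # the filtered JSON_ARRAY rows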
assert "table4" not in table_names + + def test_filter_by_column_values_case_sensitive( + self, mock_connection, mock_sea_client, sea_response_with_tables + ): + """Test case-sensitive filtering by column values.""" + import copy + + # Create a fresh result set for the first test + sea_response_copy1 = copy.deepcopy(sea_response_with_tables) + result_set1 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy1, + buffer_size_bytes=1000, + arraysize=100, + ) + + # First test: Case-sensitive filtering with lowercase values (should find no matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set1, + column_index=5, # tableType column + allowed_values=["table", "view"], # lowercase + case_sensitive=True, + ) + + # Verify no matches with lowercase values + assert len(filtered_result._response["result"]["data_array"]) == 0 + + # Create a fresh result set for the second test + sea_response_copy2 = copy.deepcopy(sea_response_with_tables) + result_set2 = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response_copy2, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Second test: Case-sensitive filtering with correct case (should find matches) + filtered_result = ResultSetFilter.filter_by_column_values( + result_set2, + column_index=5, # tableType column + allowed_values=["TABLE", "VIEW"], # correct case + case_sensitive=True, + ) + + # Verify matches with correct case + assert len(filtered_result._response["result"]["data_array"]) == 2 + + # Extract the table types from the filtered results + table_types = [ + row[5] for row in filtered_result._response["result"]["data_array"] + ] + assert "TABLE" in table_types + assert "VIEW" in table_types + + def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables): + """Test filtering with a column index that's out of bounds.""" + # Filter by column index 10 (out of bounds) + filtered_result = ResultSetFilter.filter_by_column_values( + sea_result_set_with_tables, column_index=10, allowed_values=["value"] + ) + + # No rows should match + assert len(filtered_result._response["result"]["data_array"]) == 0 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a6..c2078f73 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,20 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. 
+""" + +import json import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import Error, NotSupportedError class TestSeaBackend: @@ -41,6 +50,23 @@ def sea_client(self, mock_http_client): return client + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + return cursor + def test_init_extracts_warehouse_id(self, mock_http_client): """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" # Test with warehouses format @@ -175,109 +201,650 @@ def test_close_session_invalid_id_type(self, sea_client): assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" - ) - assert default_value == "true" + # Tests for command execution and management + + def test_execute_command_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command synchronously.""" + # Set up mock responses + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "warehouse_id" in kwargs["data"] + assert "session_id" in kwargs["data"] + assert "statement" in kwargs["data"] + assert kwargs["data"]["statement"] == "SELECT 1" + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" + + def 
test_execute_command_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command asynchronously.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + # Call the method + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, # Async mode + enforce_embedded_schema_correctness=False, ) - assert default_value is None - - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + + # Verify the result is None for async operation + assert result is None + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.STATEMENT_PATH + assert "wait_timeout" in kwargs["data"] + assert kwargs["data"]["wait_timeout"] == "0s" # Async mode uses 0s timeout + + # Verify the command ID was stored in the cursor + assert hasattr(mock_cursor, "active_command_id") + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + def test_execute_command_with_polling( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that requires polling.""" + # Set up mock responses for initial request and polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, + } + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + + # Configure mock to return different responses on subsequent calls + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + # Mock the get_execution_result method + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify the HTTP requests (initial and poll) + assert mock_http_client._make_request.call_count == 2 + + # Verify get_execution_result was called with the right command ID + mock_get_result.assert_called_once() + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-789" + + def test_execute_command_with_parameters( + self, 
sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command with parameters.""" + # Set up mock response + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, } - assert set(allowed_configs) == expected_keys + mock_http_client._make_request.return_value = execute_response - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + # Create parameter mock + param = MagicMock() + param.name = "param1" + param.value = "value1" + param.type = "STRING" - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + # Mock the get_execution_result method + with patch.object(sea_client, "get_execution_result") as mock_get_result: + # Call the method with parameters sea_client.execute_command( - operation="SELECT 1", - session_id=session_id, + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, - parameters=[], + parameters=[param], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) - - # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) - - # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) - - # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) - - # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) - - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) - - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) - - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 - - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + + # Verify the HTTP request contains parameters + mock_http_client._make_request.assert_called_once() + args, kwargs = 
mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + def test_execute_command_failure( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test executing a command that fails.""" + # Set up mock response for a failed execution + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + + # Configure the mock to return the error response for the initial request + # and then raise an exception when trying to poll (to simulate immediate failure) + mock_http_client._make_request.side_effect = [ + error_response, # Initial response + Error( + "Statement execution did not succeed: Syntax error in SQL" + ), # Will be raised during polling + ] + + # Mock time.sleep to avoid actual delays + with patch("time.sleep"): + # Call the method and expect an error + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert "Statement execution did not succeed" in str(excinfo.value) + + def test_cancel_command(self, sea_client, mock_http_client, sea_command_id): + """Test canceling a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.cancel_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "POST" + assert kwargs["path"] == sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_close_command(self, sea_client, mock_http_client, sea_command_id): + """Test closing a command.""" + # Set up mock response + mock_http_client._make_request.return_value = {} + + # Call the method + sea_client.close_command(sea_command_id) + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "DELETE" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + def test_get_query_state(self, sea_client, mock_http_client, sea_command_id): + """Test getting the state of a query.""" + # Set up mock response + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + + # Call the method + state = sea_client.get_query_state(sea_command_id) + + # Verify the result + assert state == CommandState.RUNNING + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_get_execution_result( + self, sea_client, mock_http_client, mock_cursor, sea_command_id + ): + """Test getting the 
result of a command execution.""" + # Set up mock response + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + + # Create a real result set to verify the implementation + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + print(result) + + # Verify basic properties of the result + assert result.statement_id == "test-statement-123" + assert result.status == CommandState.SUCCEEDED + + # Verify the HTTP request + mock_http_client._make_request.assert_called_once() + args, kwargs = mock_http_client._make_request.call_args + assert kwargs["method"] == "GET" + assert kwargs["path"] == sea_client.STATEMENT_PATH_WITH_ID.format( + "test-statement-123" + ) + + # Tests for metadata operations + + def test_get_catalogs(self, sea_client, mock_cursor, sea_session_id): + """Test getting catalogs.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["lz4_compression"] is False + assert kwargs["use_cloud_fetch"] is False + assert kwargs["parameters"] == [] + assert kwargs["async_op"] is False + assert kwargs["enforce_embedded_schema_correctness"] is False + + def test_get_schemas_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW SCHEMAS IN `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_with_catalog_and_schema_pattern( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting schemas with catalog and schema pattern.""" + # Mock execute_command to verify it's called with the right 
parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog and schema pattern + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW SCHEMAS IN `test_catalog` LIKE 'test_schema'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_schemas_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting schemas without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + ) + + assert "Catalog name is required for get_schemas" in str(excinfo.value) + + def test_get_tables_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_all_catalogs( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables from all catalogs using wildcard.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW TABLES IN ALL CATALOGS" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_with_schema_and_table_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting tables with schema and table patterns.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + 
sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with catalog, schema, and table patterns + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW TABLES IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' LIKE 'test_table'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_tables_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting tables without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_tables" in str(excinfo.value) + + def test_get_columns_with_catalog_only( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with only catalog name.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with only catalog + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert kwargs["operation"] == "SHOW COLUMNS IN CATALOG `test_catalog`" + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def test_get_columns_with_all_patterns( + self, sea_client, mock_cursor, sea_session_id + ): + """Test getting columns with all patterns specified.""" + # Mock execute_command to verify it's called with the right parameters + with patch.object( + sea_client, "execute_command", return_value="mock_result_set" + ) as mock_execute: + # Call the method with all patterns + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", + ) + + # Verify the result + assert result == "mock_result_set" + + # Verify execute_command was called with the right parameters + mock_execute.assert_called_once() + args, kwargs = mock_execute.call_args + assert ( + kwargs["operation"] + == "SHOW COLUMNS IN CATALOG `test_catalog` SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'" + ) + assert kwargs["session_id"] == sea_session_id + assert kwargs["max_rows"] == 100 + assert kwargs["max_bytes"] == 1000 + assert kwargs["cursor"] == mock_cursor + assert kwargs["async_op"] is False + + def 
test_get_columns_no_catalog_name(self, sea_client, mock_cursor, sea_session_id): + """Test getting columns without a catalog name raises an error.""" + with pytest.raises(ValueError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name=None, # No catalog name + schema_name="test_schema", + table_name="test_table", + ) + + assert "Catalog name is required for get_columns" in str(excinfo.value) diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py new file mode 100644 index 00000000..f666fd61 --- /dev/null +++ b/tests/unit/test_sea_result_set.py @@ -0,0 +1,275 @@ +""" +Tests for the SeaResultSet class. + +This module contains tests for the SeaResultSet class, which implements +the result set functionality for the SEA (Statement Execution API) backend. +""" + +import pytest +from unittest.mock import patch, MagicMock, Mock + +from databricks.sql.result_set import SeaResultSet +from databricks.sql.backend.types import CommandId, CommandState, BackendType + + +class TestSeaResultSet: + """Test suite for the SeaResultSet class.""" + + @pytest.fixture + def mock_connection(self): + """Create a mock connection.""" + connection = Mock() + connection.open = True + return connection + + @pytest.fixture + def mock_sea_client(self): + """Create a mock SEA client.""" + return Mock() + + @pytest.fixture + def sea_response(self): + """Create a sample SEA response.""" + return { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + + @pytest.fixture + def execute_response(self): + """Create a sample execute response.""" + mock_response = Mock() + mock_response.command_id = CommandId.from_sea_statement_id("test-statement-123") + mock_response.status = CommandState.SUCCEEDED + mock_response.has_been_closed_server_side = False + mock_response.has_more_rows = False + mock_response.results_queue = None + mock_response.description = [ + ("test_value", "INT", None, None, None, None, None) + ] + mock_response.is_staging_operation = False + mock_response.sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "result": {"data_array": [["1"]]}, + } + return mock_response + + def test_init_with_sea_response( + self, mock_connection, mock_sea_client, sea_response + ): + """Test initializing SeaResultSet with a SEA response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == sea_response + + def test_init_with_execute_response( 
+ self, mock_connection, mock_sea_client, execute_response + ): + """Test initializing SeaResultSet with an execute response.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + execute_response=execute_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Verify basic properties + assert result_set.statement_id == "test-statement-123" + assert result_set.status == CommandState.SUCCEEDED + assert result_set.command_id.guid == "test-statement-123" + assert result_set.command_id.backend_type == BackendType.SEA + assert result_set.connection == mock_connection + assert result_set.backend == mock_sea_client + assert result_set.buffer_size_bytes == 1000 + assert result_set.arraysize == 100 + assert result_set._response == execute_response.sea_response + + def test_init_with_no_response(self, mock_connection, mock_sea_client): + """Test that initialization fails when neither response type is provided.""" + with pytest.raises(ValueError) as excinfo: + SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + buffer_size_bytes=1000, + arraysize=100, + ) + assert "Either execute_response or sea_response must be provided" in str( + excinfo.value + ) + + def test_close(self, mock_connection, mock_sea_client, sea_response): + """Test closing a result set.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was called + mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_already_closed_server_side( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set that has already been closed server-side.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + result_set.has_been_closed_server_side = True + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_close_when_connection_closed( + self, mock_connection, mock_sea_client, sea_response + ): + """Test closing a result set when the connection is closed.""" + mock_connection.open = False + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Close the result set + result_set.close() + + # Verify the backend's close_command was NOT called + mock_sea_client.close_command.assert_not_called() + assert result_set.has_been_closed_server_side is True + assert result_set.status == CommandState.CLOSED + + def test_unimplemented_methods( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that unimplemented methods raise NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + # Test each unimplemented method individually with specific error messages + with pytest.raises( + NotImplementedError, match="fetchone is not 
implemented for SEA backend" + ): + result_set.fetchone() + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + result_set.fetchmany(10) + + with pytest.raises( + NotImplementedError, match="fetchmany is not implemented for SEA backend" + ): + # Test with default parameter value + result_set.fetchmany() + + with pytest.raises( + NotImplementedError, match="fetchall is not implemented for SEA backend" + ): + result_set.fetchall() + + with pytest.raises( + NotImplementedError, + match="fetchmany_arrow is not implemented for SEA backend", + ): + result_set.fetchmany_arrow(10) + + with pytest.raises( + NotImplementedError, + match="fetchall_arrow is not implemented for SEA backend", + ): + result_set.fetchall_arrow() + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test iteration protocol (calls fetchone internally) + next(iter(result_set)) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + # Test using the result set in a for loop + for row in result_set: + pass + + def test_fill_results_buffer_not_implemented( + self, mock_connection, mock_sea_client, sea_response + ): + """Test that _fill_results_buffer raises NotImplementedError.""" + result_set = SeaResultSet( + connection=mock_connection, + sea_client=mock_sea_client, + sea_response=sea_response, + buffer_size_bytes=1000, + arraysize=100, + ) + + with pytest.raises( + NotImplementedError, match="fetchone is not implemented for SEA backend" + ): + result_set._fill_results_buffer() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 92de8d8f..fef07036 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,11 @@ from unittest.mock import patch, Mock import gc +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, + TSessionHandle, + THandleIdentifier, +) from databricks.sql.backend.types import SessionId, BackendType import databricks.sql diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 57a2a61e..b8de970d 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -619,11 +619,18 @@ def test_handle_execute_response_sets_compression_in_direct_results( lz4Compressed = Mock() resultSet = MagicMock() resultSet.results.startRowOffset = 0 + + # Create a valid operation status + op_status = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + t_execute_resp = resp_type( status=Mock(), operationHandle=Mock(), directResults=ttypes.TSparkDirectResults( - operationStatus=Mock(), + operationStatus=op_status, resultSetMetadata=ttypes.TGetResultSetMetadataResp( status=self.okay_status, resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET, @@ -644,7 +651,7 @@ def test_handle_execute_response_sets_compression_in_direct_results( ssl_options=SSLOptions(), ) - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -878,11 +885,12 @@ def test_handle_execute_response_can_handle_without_direct_results( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - results_message_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) + 
self.assertEqual( - results_message_response.status, + execute_response.status, CommandState.SUCCEEDED, ) @@ -915,7 +923,9 @@ def test_handle_execute_response_can_handle_with_direct_results(self): auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -943,15 +953,21 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = ( t_get_result_set_metadata_resp ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response( t_execute_resp, Mock() ) - self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) + self.assertEqual(arrow_schema_bytes, arrow_schema_mock) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): @@ -971,6 +987,12 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): operationHandle=self.operation_handle, ) + # Mock the operation status response + op_state = ttypes.TGetOperationStatusResp( + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, + ) + tcli_service_instance.GetOperationStatus.return_value = op_state tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -1018,7 +1040,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response( + execute_response, _ = thrift_backend._handle_execute_response( execute_resp, Mock() ) @@ -1150,7 +1172,7 @@ def test_execute_statement_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.execute_command( @@ -1184,7 +1206,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock) @@ -1215,7 +1237,7 @@ def test_get_schemas_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_schemas( @@ -1255,7 +1277,7 @@ def test_get_tables_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + 
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_tables( @@ -1299,7 +1321,7 @@ def test_get_columns_calls_client_and_handle_execute_response( auth_provider=AuthProvider(), ssl_options=SSLOptions(), ) - thrift_backend._handle_execute_response = Mock() + thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock())) cursor_mock = Mock() result = thrift_backend.get_columns( @@ -1645,7 +1667,9 @@ def test_handle_execute_response_sets_active_op_handle(self): thrift_backend = self._make_fake_thrift_backend() thrift_backend._check_direct_results_for_error = Mock() thrift_backend._wait_until_command_done = Mock() - thrift_backend._results_message_to_execute_response = Mock() + thrift_backend._results_message_to_execute_response = Mock( + return_value=(Mock(), Mock()) + ) # Create a mock response with a real operation handle mock_resp = Mock() @@ -2204,7 +2228,8 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) @patch( - "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response" + "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response", + return_value=(Mock(), Mock()), ) def test_execute_command_sets_complex_type_fields_correctly( self, mock_handle_execute_response, tcli_service_class
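Most of the thrift test churn above traces back to one signature change: _handle_execute_response now returns a pair rather than a bare response, so every stub has to hand back a 2-tuple. A minimal sketch of the new calling convention, assuming a ThriftDatabricksClient named thrift_backend and a prepared t_execute_resp as in the tests above:

from unittest.mock import Mock

def unpack_handle_execute_response(thrift_backend, t_execute_resp):
    # The helper now yields the ExecuteResponse together with the raw Arrow
    # schema bytes, which previously rode along on the response object itself.
    execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response(
        t_execute_resp, Mock()
    )
    return execute_response, arrow_schema_bytes

# Stubs consequently return a pair as well:
# thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))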