diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 76903ccd2..bfc0c6c9e 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import time import re @@ -10,11 +12,12 @@ ResultDisposition, ResultCompression, WaitTimeout, + MetadataCommands, ) if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.result_set import ResultSet + from databricks.sql.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -24,7 +27,7 @@ BackendType, ExecuteResponse, ) -from databricks.sql.exc import DatabaseError, ServerOperationError +from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions @@ -169,7 +172,7 @@ def _extract_warehouse_id(self, http_path: str) -> str: f"Note: SEA only works for warehouses." ) logger.error(error_message) - raise ValueError(error_message) + raise ProgrammingError(error_message) @property def max_download_threads(self) -> int: @@ -241,14 +244,14 @@ def close_session(self, session_id: SessionId) -> None: session_id: The session identifier returned by open_session() Raises: - ValueError: If the session ID is invalid + ProgrammingError: If the session ID is invalid OperationalError: If there's an error closing the session """ logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") + raise ProgrammingError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() request_data = DeleteSessionRequest( @@ -400,12 +403,12 @@ def execute_command( max_rows: int, max_bytes: int, lz4_compression: bool, - cursor: "Cursor", + cursor: Cursor, use_cloud_fetch: bool, parameters: List[Dict[str, Any]], async_op: bool, enforce_embedded_schema_correctness: bool, - ) -> Union["ResultSet", None]: + ) -> Union[SeaResultSet, None]: """ Execute a SQL command using the SEA backend. 
@@ -426,7 +429,7 @@ def execute_command( """ if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") + raise ProgrammingError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() @@ -501,11 +504,11 @@ def cancel_command(self, command_id: CommandId) -> None: command_id: Command identifier to cancel Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -524,11 +527,11 @@ def close_command(self, command_id: CommandId) -> None: command_id: Command identifier to close Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -550,8 +553,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState: CommandState: The current state of the command Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") @@ -572,8 +575,8 @@ def get_query_state( def get_execution_result( self, command_id: CommandId, - cursor: "Cursor", - ) -> "ResultSet": + cursor: Cursor, + ) -> SeaResultSet: """ Get the result of a command execution. @@ -582,14 +585,14 @@ def get_execution_result( cursor: Cursor executing the command Returns: - ResultSet: A SeaResultSet instance with the execution results + SeaResultSet: A SeaResultSet instance with the execution results Raises: - ValueError: If the command ID is invalid + ProgrammingError: If the command ID is invalid """ if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") + raise ProgrammingError("Not a valid SEA command ID") sea_statement_id = command_id.to_sea_statement_id() @@ -626,47 +629,141 @@ def get_catalogs( session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError("get_catalogs is not yet implemented for SEA backend") + cursor: Cursor, + ) -> SeaResultSet: + """Get available catalogs by executing 'SHOW CATALOGS'.""" + result = self.execute_command( + operation=MetadataCommands.SHOW_CATALOGS.value, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_schemas( self, session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_schemas is not yet implemented for SEA backend") + ) -> SeaResultSet: + """Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.""" + if not catalog_name: + raise DatabaseError("Catalog name is required for get_schemas") + + operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name) + + if schema_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name) + + result = self.execute_command( +
operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result def get_tables( self, session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, table_types: Optional[List[str]] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_tables is not yet implemented for SEA backend") + ) -> SeaResultSet: + """Get tables by executing 'SHOW TABLES IN [ALL CATALOGS | CATALOG catalog] [SCHEMA LIKE pattern] [LIKE pattern]'.""" + operation = ( + MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value + if catalog_name in [None, "*", "%"] + else MetadataCommands.SHOW_TABLES.value.format( + MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name) + ) + ) + + if schema_name: + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) + + if table_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(table_name) + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + + # Apply client-side filtering by table_types + from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + result = ResultSetFilter.filter_tables_by_type(result, table_types) + + return result def get_columns( self, session_id: SessionId, max_rows: int, max_bytes: int, - cursor: "Cursor", + cursor: Cursor, catalog_name: Optional[str] = None, schema_name: Optional[str] = None, table_name: Optional[str] = None, column_name: Optional[str] = None, - ): - """Not implemented yet.""" - raise NotImplementedError("get_columns is not yet implemented for SEA backend") + ) -> SeaResultSet: + """Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.""" + if not catalog_name: + raise DatabaseError("Catalog name is required for get_columns") + + operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name) + + if schema_name: + operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name) + + if table_name: + operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name) + + if column_name: + operation += MetadataCommands.LIKE_PATTERN.value.format(column_name) + + result = self.execute_command( + operation=operation, + session_id=session_id, + max_rows=max_rows, + max_bytes=max_bytes, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result is not None, "execute_command returned None in synchronous mode" + return result diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 7481a90db..402da0de5 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -45,3 +45,23 @@ class WaitTimeout(Enum): ASYNC = "0s" SYNC = "10s" + + +class MetadataCommands(Enum): + """SQL commands used in the SEA
backend. + + These constants are used for metadata operations and other SQL queries + to ensure consistency and avoid string literal duplication. + """ + + SHOW_CATALOGS = "SHOW CATALOGS" + SHOW_SCHEMAS = "SHOW SCHEMAS IN {}" + SHOW_TABLES = "SHOW TABLES IN {}" + SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS" + SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}" + + LIKE_PATTERN = " LIKE '{}'" + SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN + TABLE_LIKE_PATTERN = " TABLE" + LIKE_PATTERN + + CATALOG_SPECIFIC = "CATALOG {}" diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py new file mode 100644 index 000000000..1b7660829 --- /dev/null +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -0,0 +1,152 @@ +""" +Client-side filtering utilities for the Databricks SQL connector. + +This module provides filtering capabilities for result sets returned by different backends. +""" + +from __future__ import annotations + +import logging +from typing import ( + List, + Optional, + Any, + Callable, + cast, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from databricks.sql.result_set import SeaResultSet + +from databricks.sql.backend.types import ExecuteResponse + +logger = logging.getLogger(__name__) + + +class ResultSetFilter: + """ + A general-purpose filter for result sets. + """ + + @staticmethod + def _filter_sea_result_set( + result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool] + ) -> SeaResultSet: + """ + Filter a SEA result set using the provided filter function. + + Args: + result_set: The SEA result set to filter + filter_func: Function that takes a row and returns True if the row should be included + + Returns: + A filtered SEA result set + """ + + # Get all remaining rows + all_rows = result_set.results.remaining_rows() + + # Filter rows + filtered_rows = [row for row in all_rows if filter_func(row)] + + # Reuse the command_id from the original result set + command_id = result_set.command_id + + # Create an ExecuteResponse with the filtered data + execute_response = ExecuteResponse( + command_id=command_id, + status=result_set.status, + description=result_set.description, + has_been_closed_server_side=result_set.has_been_closed_server_side, + lz4_compressed=result_set.lz4_compressed, + arrow_schema_bytes=result_set._arrow_schema_bytes, + is_staging_operation=False, + ) + + # Create a new ResultData object with filtered data + from databricks.sql.backend.sea.models.base import ResultData + + result_data = ResultData(data=filtered_rows, external_links=None) + + from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.result_set import SeaResultSet + + # Create a new SeaResultSet with the filtered data + filtered_result_set = SeaResultSet( + connection=result_set.connection, + execute_response=execute_response, + sea_client=cast(SeaDatabricksClient, result_set.backend), + buffer_size_bytes=result_set.buffer_size_bytes, + arraysize=result_set.arraysize, + result_data=result_data, + ) + + return filtered_result_set + + @staticmethod + def filter_by_column_values( + result_set: SeaResultSet, + column_index: int, + allowed_values: List[str], + case_sensitive: bool = False, + ) -> SeaResultSet: + """ + Filter a result set by values in a specific column.
+ + Args: + result_set: The result set to filter + column_index: The index of the column to filter on + allowed_values: List of allowed values for the column + case_sensitive: Whether to perform case-sensitive comparison + + Returns: + A filtered result set + """ + + # Convert to uppercase for case-insensitive comparison if needed + if not case_sensitive: + allowed_values = [v.upper() for v in allowed_values] + + return ResultSetFilter._filter_sea_result_set( + result_set, + lambda row: ( + len(row) > column_index + and ( + row[column_index].upper() + if not case_sensitive + else row[column_index] + ) + in allowed_values + ), + ) + + @staticmethod + def filter_tables_by_type( + result_set: SeaResultSet, table_types: Optional[List[str]] = None + ) -> SeaResultSet: + """ + Filter a result set of tables by the specified table types. + + This is a client-side filter that processes the result set after it has been + retrieved from the server. It filters out tables whose type does not match + any of the types in the table_types list. + + Args: + result_set: The original result set containing tables + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) + + Returns: + A filtered result set containing only tables of the specified types + """ + + # Default table types if none specified + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] + valid_types = ( + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES + ) + + # Table type is the 6th column (index 5) + return ResultSetFilter.filter_by_column_values( + result_set, 5, valid_types, case_sensitive=True + ) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 000000000..975376e13 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,160 @@ +""" +Tests for the ResultSetFilter class. 
+""" + +import unittest +from unittest.mock import MagicMock, patch + +from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + +class TestResultSetFilter(unittest.TestCase): + """Tests for the ResultSetFilter class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock SeaResultSet + self.mock_sea_result_set = MagicMock() + + # Set up the remaining_rows method on the results attribute + self.mock_sea_result_set.results = MagicMock() + self.mock_sea_result_set.results.remaining_rows.return_value = [ + ["catalog1", "schema1", "table1", "owner1", "2023-01-01", "TABLE", ""], + ["catalog1", "schema1", "table2", "owner1", "2023-01-01", "VIEW", ""], + [ + "catalog1", + "schema1", + "table3", + "owner1", + "2023-01-01", + "SYSTEM TABLE", + "", + ], + [ + "catalog1", + "schema1", + "table4", + "owner1", + "2023-01-01", + "EXTERNAL TABLE", + "", + ], + ] + + # Set up the connection and other required attributes + self.mock_sea_result_set.connection = MagicMock() + self.mock_sea_result_set.backend = MagicMock() + self.mock_sea_result_set.buffer_size_bytes = 1000 + self.mock_sea_result_set.arraysize = 100 + self.mock_sea_result_set.statement_id = "test-statement-id" + self.mock_sea_result_set.lz4_compressed = False + + # Create a mock CommandId + from databricks.sql.backend.types import CommandId, BackendType + + mock_command_id = CommandId(BackendType.SEA, "test-statement-id") + self.mock_sea_result_set.command_id = mock_command_id + + self.mock_sea_result_set.status = MagicMock() + self.mock_sea_result_set.description = [ + ("catalog_name", "string", None, None, None, None, True), + ("schema_name", "string", None, None, None, None, True), + ("table_name", "string", None, None, None, None, True), + ("owner", "string", None, None, None, None, True), + ("creation_time", "string", None, None, None, None, True), + ("table_type", "string", None, None, None, None, True), + ("remarks", "string", None, None, None, None, True), + ] + self.mock_sea_result_set.has_been_closed_server_side = False + self.mock_sea_result_set._arrow_schema_bytes = None + + def test_filter_by_column_values(self): + """Test filtering by column values with various options.""" + # Case 1: Case-sensitive filtering + allowed_values = ["table1", "table3"] + + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call filter_by_column_values on the table_name column (index 2) + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, 2, allowed_values, case_sensitive=True + ) + + # Verify the filter was applied correctly + mock_sea_result_set_class.assert_called_once() + + # Check the filtered data passed to the constructor + args, kwargs = mock_sea_result_set_class.call_args + result_data = kwargs.get("result_data") + self.assertIsNotNone(result_data) + self.assertEqual(len(result_data.data), 2) + self.assertIn(result_data.data[0][2], allowed_values) + self.assertIn(result_data.data[1][2], allowed_values) + + # Case 2: Case-insensitive filtering + mock_sea_result_set_class.reset_mock() + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch( + "databricks.sql.result_set.SeaResultSet" + ) as mock_sea_result_set_class: + mock_instance = MagicMock() + mock_sea_result_set_class.return_value = mock_instance + + # Call 
filter_by_column_values with case-insensitive matching + result = ResultSetFilter.filter_by_column_values( + self.mock_sea_result_set, + 2, + ["TABLE1", "TABLE3"], + case_sensitive=False, + ) + mock_sea_result_set_class.assert_called_once() + + def test_filter_tables_by_type(self): + """Test filtering tables by type with various options.""" + # Case 1: Specific table types + table_types = ["TABLE", "VIEW"] + + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + ResultSetFilter.filter_tables_by_type( + self.mock_sea_result_set, table_types + ) + args, kwargs = mock_filter.call_args + self.assertEqual(args[0], self.mock_sea_result_set) + self.assertEqual(args[1], 5) # Table type column index + self.assertEqual(args[2], table_types) + self.assertEqual(kwargs.get("case_sensitive"), True) + + # Case 2: Default table types (None or empty list) + with patch( + "databricks.sql.backend.sea.utils.filters.isinstance", return_value=True + ): + with patch.object( + ResultSetFilter, "filter_by_column_values" + ) as mock_filter: + # Test with None + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, None) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) + + # Test with empty list + ResultSetFilter.filter_tables_by_type(self.mock_sea_result_set, []) + args, kwargs = mock_filter.call_args + self.assertEqual(args[2], ["TABLE", "VIEW", "SYSTEM TABLE"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f30c92ed0..6847cded0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -18,6 +18,7 @@ from databricks.sql.exc import ( Error, NotSupportedError, + ProgrammingError, ServerOperationError, DatabaseError, ) @@ -129,7 +130,7 @@ def test_initialization(self, mock_http_client): assert client3.max_download_threads == 5 # Test with invalid HTTP path - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", port=443, @@ -195,7 +196,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i ) # Test close_session with invalid ID type - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.close_session(thrift_session_id) assert "Not a valid SEA session ID" in str(excinfo.value) @@ -244,7 +245,7 @@ def test_command_execution_sync( assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: mock_thrift_handle = MagicMock() mock_thrift_handle.sessionId.guid = b"guid" mock_thrift_handle.sessionId.secret = b"secret" @@ -448,7 +449,7 @@ def test_command_management( ) # Test cancel_command with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.cancel_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -462,7 +463,7 @@ def test_command_management( ) # Test close_command with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.close_command(thrift_command_id) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -521,7 +522,7 @@ def test_command_management( 
assert result.status == CommandState.SUCCEEDED # Test get_execution_result with invalid ID - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ProgrammingError) as excinfo: sea_client.get_execution_result(thrift_command_id, mock_cursor) assert "Not a valid SEA command ID" in str(excinfo.value) @@ -632,54 +633,247 @@ def test_utility_methods(self, sea_client): sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - def test_unimplemented_metadata_methods( - self, sea_client, sea_session_id, mock_cursor - ): - """Test that metadata methods raise NotImplementedError.""" - # Test get_catalogs - with pytest.raises(NotImplementedError): - sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas - with pytest.raises(NotImplementedError): - sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) - - # Test get_schemas with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_schemas( - sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" + def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor): + """Test the get_catalogs method.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Call get_catalogs + result = sea_client.get_catalogs( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + + # Verify execute_command was called with the correct parameters + mock_execute.assert_called_once_with( + operation="SHOW CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Verify the result is correct + assert result == mock_result_set + + def test_get_schemas(self, sea_client, sea_session_id, mock_cursor): + """Test the get_schemas method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 2: With catalog and schema names + result = sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + ) + + mock_execute.assert_called_with( + operation="SHOW SCHEMAS IN test_catalog LIKE 'test_schema'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise DatabaseError) + with pytest.raises(DatabaseError) as excinfo: + sea_client.get_schemas( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_schemas" in str(excinfo.value) + +
def test_get_tables(self, sea_client, sea_session_id, mock_cursor): + """Test the get_tables method with various parameter combinations.""" + # Mock the execute_command method + from databricks.sql.result_set import SeaResultSet + + mock_result_set = Mock(spec=SeaResultSet) + + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Mock the filter_tables_by_type method + with patch( + "databricks.sql.backend.sea.utils.filters.ResultSetFilter.filter_tables_by_type", + return_value=mock_result_set, + ) as mock_filter: + # Case 1: With catalog name only + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, None) + + # Case 2: With all parameters + table_types = ["TABLE", "VIEW"] + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + table_types=table_types, + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN CATALOG test_catalog SCHEMA LIKE 'test_schema' LIKE 'test_table'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + mock_filter.assert_called_with(mock_result_set, table_types) + + # Case 3: With wildcard catalog + result = sea_client.get_tables( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="*", + ) + + mock_execute.assert_called_with( + operation="SHOW TABLES IN ALL CATALOGS", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + def test_get_columns(self, sea_client, sea_session_id, mock_cursor): + """Test the get_columns method with various parameter combinations.""" + # Mock the execute_command method + mock_result_set = Mock() + with patch.object( + sea_client, "execute_command", return_value=mock_result_set + ) as mock_execute: + # Case 1: With catalog name only + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", ) - # Test get_tables - with pytest.raises(NotImplementedError): - sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) - - # Test get_tables with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_tables( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - table_types=["TABLE", "VIEW"], + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, ) - # Test get_columns - with 
pytest.raises(NotImplementedError): - sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) - - # Test get_columns with optional parameters - with pytest.raises(NotImplementedError): - sea_client.get_columns( - sea_session_id, - 100, - 1000, - mock_cursor, - catalog_name="catalog", - schema_name="schema", - table_name="table", - column_name="column", + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, ) + + # Case 2: With all parameters + result = sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + catalog_name="test_catalog", + schema_name="test_schema", + table_name="test_table", + column_name="test_column", ) + + mock_execute.assert_called_with( + operation="SHOW COLUMNS IN CATALOG test_catalog SCHEMA LIKE 'test_schema' TABLE LIKE 'test_table' LIKE 'test_column'", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + # Case 3: Without catalog name (should raise DatabaseError) + with pytest.raises(DatabaseError) as excinfo: + sea_client.get_columns( + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + cursor=mock_cursor, + ) + assert "Catalog name is required for get_columns" in str(excinfo.value)
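
The ValueError-to-ProgrammingError changes above are caller-visible: code that previously caught ValueError around close_session, cancel_command, close_command, or get_execution_result must now catch ProgrammingError. A minimal caller-side sketch, assuming the PEP 249-style hierarchy in databricks.sql.exc (ProgrammingError derives from DatabaseError, which derives from Error); the close_quietly helper is hypothetical:

from databricks.sql.exc import ProgrammingError

def close_quietly(client, session_id):
    """Close a session, tolerating an ID that belongs to another backend."""
    try:
        client.close_session(session_id)
    except ProgrammingError as exc:  # previously: except ValueError
        # e.g. "Not a valid SEA session ID"
        print(f"skipping close: {exc}")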
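
The new metadata entry points build their SQL purely by composing the MetadataCommands templates. Note that SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN is legal inside the Enum body because names assigned earlier in the class body are still plain strings when later assignments run; they become Enum members only after the body finishes executing. A standalone sketch of the string get_columns assembles, with the template literals copied from constants.py and hypothetical catalog, schema, table, and column names:

# Mirrors MetadataCommands; plain strings stand in for the enum values.
SHOW_COLUMNS = "SHOW COLUMNS IN CATALOG {}"
LIKE_PATTERN = " LIKE '{}'"
SCHEMA_LIKE_PATTERN = " SCHEMA" + LIKE_PATTERN
TABLE_LIKE_PATTERN = " TABLE" + LIKE_PATTERN

# Hypothetical identifiers, chained exactly as get_columns chains them.
operation = SHOW_COLUMNS.format("main")
operation += SCHEMA_LIKE_PATTERN.format("tpch")
operation += TABLE_LIKE_PATTERN.format("orders")
operation += LIKE_PATTERN.format("o_orderkey")

assert operation == (
    "SHOW COLUMNS IN CATALOG main"
    " SCHEMA LIKE 'tpch' TABLE LIKE 'orders' LIKE 'o_orderkey'"
)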
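
get_tables is the one entry point that branches on its catalog argument: None and the wildcards "*" and "%" select the ALL CATALOGS form, anything else the catalog-specific form. A sketch of that branching under the same assumptions (copied template literals, hypothetical names):

SHOW_TABLES = "SHOW TABLES IN {}"
SHOW_TABLES_ALL_CATALOGS = "SHOW TABLES IN ALL CATALOGS"
CATALOG_SPECIFIC = "CATALOG {}"
SCHEMA_LIKE_PATTERN = " SCHEMA LIKE '{}'"
LIKE_PATTERN = " LIKE '{}'"

def build_show_tables(catalog, schema=None, table=None):
    # Wildcard or missing catalog -> scan all catalogs, as in get_tables.
    operation = (
        SHOW_TABLES_ALL_CATALOGS
        if catalog in (None, "*", "%")
        else SHOW_TABLES.format(CATALOG_SPECIFIC.format(catalog))
    )
    if schema:
        operation += SCHEMA_LIKE_PATTERN.format(schema)
    if table:
        operation += LIKE_PATTERN.format(table)
    return operation

assert build_show_tables("*") == "SHOW TABLES IN ALL CATALOGS"
assert build_show_tables("main", "tpch", "ord%") == (
    "SHOW TABLES IN CATALOG main SCHEMA LIKE 'tpch' LIKE 'ord%'"
)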
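
Because SHOW TABLES has no server-side table-type predicate, filter_tables_by_type runs client-side after the rows come back: it keeps rows whose type column (index 5 in the SHOW TABLES layout) matches the requested types, defaulting to TABLE, VIEW, and SYSTEM TABLE. A plain-list sketch of the same logic; the sample rows are hypothetical, and the real filter rebuilds a SeaResultSet around the surviving rows rather than returning a list:

DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]

def filter_tables(rows, table_types=None):
    # Case-sensitive match on the 6th column, as in filter_tables_by_type.
    valid = table_types or DEFAULT_TABLE_TYPES
    return [row for row in rows if len(row) > 5 and row[5] in valid]

rows = [
    ["main", "tpch", "orders", "owner", "2023-01-01", "TABLE", ""],
    ["main", "tpch", "orders_v", "owner", "2023-01-01", "VIEW", ""],
    ["main", "tpch", "ext", "owner", "2023-01-01", "EXTERNAL TABLE", ""],
]

assert [r[2] for r in filter_tables(rows)] == ["orders", "orders_v"]
assert [r[2] for r in filter_tables(rows, ["VIEW"])] == ["orders_v"]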