Skip to content

Commit dac08f2

Browse files
enforce ResultSet return in exec commands in backend client
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent b8e1bbd commit dac08f2

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def execute_command(
4242
parameters: List[ttypes.TSparkParameter],
4343
async_op: bool,
4444
enforce_embedded_schema_correctness: bool,
45-
) -> "ResultSet": # Changed return type to ResultSet
45+
) -> "ResultSet":
4646
pass
4747

4848
@abstractmethod
@@ -73,7 +73,7 @@ def get_catalogs(
7373
max_rows: int,
7474
max_bytes: int,
7575
cursor: Any,
76-
) -> Any:
76+
) -> "ResultSet":
7777
pass
7878

7979
@abstractmethod
@@ -85,7 +85,7 @@ def get_schemas(
8585
cursor: Any,
8686
catalog_name: Optional[str] = None,
8787
schema_name: Optional[str] = None,
88-
) -> Any:
88+
) -> "ResultSet":
8989
pass
9090

9191
@abstractmethod
@@ -99,7 +99,7 @@ def get_tables(
9999
schema_name: Optional[str] = None,
100100
table_name: Optional[str] = None,
101101
table_types: Optional[List[str]] = None,
102-
) -> Any:
102+
) -> "ResultSet":
103103
pass
104104

105105
@abstractmethod
@@ -113,7 +113,7 @@ def get_columns(
113113
schema_name: Optional[str] = None,
114114
table_name: Optional[str] = None,
115115
column_name: Optional[str] = None,
116-
) -> Any:
116+
) -> "ResultSet":
117117
pass
118118

119119
# == Utility Methods ==

src/databricks/sql/backend/thrift_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
)
4949
from databricks.sql.types import SSLOptions
5050
from databricks.sql.backend.databricks_client import DatabricksClient
51-
from databricks.sql.result_set import ThriftResultSet
51+
from databricks.sql.result_set import ResultSet, ThriftResultSet
5252

5353
logger = logging.getLogger(__name__)
5454

@@ -945,7 +945,7 @@ def execute_command(
945945
parameters=[],
946946
async_op=False,
947947
enforce_embedded_schema_correctness=False,
948-
):
948+
) -> "ResultSet":
949949
thrift_handle = session_id.to_thrift_handle()
950950
if not thrift_handle:
951951
raise ValueError("Not a valid Thrift session ID")
@@ -1009,7 +1009,7 @@ def get_catalogs(
10091009
max_rows: int,
10101010
max_bytes: int,
10111011
cursor: Any,
1012-
):
1012+
) -> "ResultSet":
10131013
thrift_handle = session_id.to_thrift_handle()
10141014
if not thrift_handle:
10151015
raise ValueError("Not a valid Thrift session ID")
@@ -1041,7 +1041,7 @@ def get_schemas(
10411041
cursor: Any,
10421042
catalog_name=None,
10431043
schema_name=None,
1044-
):
1044+
) -> "ResultSet":
10451045
thrift_handle = session_id.to_thrift_handle()
10461046
if not thrift_handle:
10471047
raise ValueError("Not a valid Thrift session ID")
@@ -1077,7 +1077,7 @@ def get_tables(
10771077
schema_name=None,
10781078
table_name=None,
10791079
table_types=None,
1080-
):
1080+
) -> "ResultSet":
10811081
thrift_handle = session_id.to_thrift_handle()
10821082
if not thrift_handle:
10831083
raise ValueError("Not a valid Thrift session ID")
@@ -1115,7 +1115,7 @@ def get_columns(
11151115
schema_name=None,
11161116
table_name=None,
11171117
column_name=None,
1118-
):
1118+
) -> "ResultSet":
11191119
thrift_handle = session_id.to_thrift_handle()
11201120
if not thrift_handle:
11211121
raise ValueError("Not a valid Thrift session ID")

0 commit comments

Comments
 (0)