diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 0c9a08a8..da1177f4 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]:
         # (True by default)
         # use_cloud_fetch
         # Enable use of cloud fetch to extract large query results in parallel via cloud storage
+        # _arrow_pandas_type_override
+        # Override the default pandas dtype mapping for Arrow types.
+        # This is a dictionary of Arrow types to pandas dtypes.
+        # _arrow_to_pandas_kwargs
+        # Additional or modified keyword arguments to pass to pyarrow.Table.to_pandas.
 
         logger.debug(
             "Connection.__init__(server_hostname=%s, http_path=%s)",
@@ -1346,7 +1351,9 @@ def _convert_arrow_table(self, table):
         # Need to use nullable types, as otherwise type can change when there are missing values.
         # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
         # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        dtype_mapping = {
+        DEFAULT_DTYPE_MAPPING: Dict[
+            pyarrow.DataType, pandas.api.extensions.ExtensionDtype
+        ] = {
             pyarrow.int8(): pandas.Int8Dtype(),
             pyarrow.int16(): pandas.Int16Dtype(),
             pyarrow.int32(): pandas.Int32Dtype(),
@@ -1361,13 +1368,35 @@ def _convert_arrow_table(self, table):
             pyarrow.string(): pandas.StringDtype(),
         }
 
+        arrow_pandas_type_override = self.connection._arrow_pandas_type_override
+        if not isinstance(arrow_pandas_type_override, dict):
+            logger.debug(
+                "_arrow_pandas_type_override on connection was not a dict, using default type mapping"
+            )
+            arrow_pandas_type_override = {}
+
+        dtype_mapping = {
+            **DEFAULT_DTYPE_MAPPING,
+            **arrow_pandas_type_override,
+        }
+
+        to_pandas_kwargs: Dict[str, Any] = {
+            "types_mapper": dtype_mapping.get,
+            "date_as_object": True,
+            "timestamp_as_object": True,
+        }
+
+        arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs
+        if isinstance(arrow_to_pandas_kwargs, dict):
+            to_pandas_kwargs.update(arrow_to_pandas_kwargs)
+        else:
+            logger.debug(
+                "_arrow_to_pandas_kwargs on connection was not a dict, using default arguments"
+            )
+
         # Need to rename columns, as the to_pandas function cannot handle duplicate column names
         table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
-        df = table_renamed.to_pandas(
-            types_mapper=dtype_mapping.get,
-            date_as_object=True,
-            timestamp_as_object=True,
-        )
+        df = table_renamed.to_pandas(**to_pandas_kwargs)
 
         res = df.to_numpy(na_value=None, dtype="object")
         return [ResultRow(*v) for v in res]
diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py
new file mode 100644
index 00000000..78d43635
--- /dev/null
+++ b/tests/unit/test_arrow_conversion.py
@@ -0,0 +1,184 @@
+import pytest
+
+try:
+    import pyarrow as pa
+except ImportError:
+    pa = None
+import pandas
+import datetime
+import unittest
+from unittest.mock import MagicMock
+
+from databricks.sql.client import ResultSet, Connection, ExecuteResponse
+from databricks.sql.types import Row
+from databricks.sql.utils import ArrowQueue
+
+@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
+class ArrowConversionTests(unittest.TestCase):
+    @staticmethod
+    def mock_connection_static():
+        conn = MagicMock(spec=Connection)
+        conn.disable_pandas = False
+        conn._arrow_pandas_type_override = {}
+        conn._arrow_to_pandas_kwargs = {}
+        return conn
+
+    @staticmethod
+    def sample_arrow_table_static():
+        data = [
+            pa.array([1, 2, 3], type=pa.int32()),
+            pa.array(["a", "b", "c"], type=pa.string()),
+        ]
+        schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())])
+        return pa.Table.from_arrays(data, schema=schema)
+
+    @staticmethod
+    def mock_thrift_backend_static():
+        sample_table = ArrowConversionTests.sample_arrow_table_static()
+        tb = MagicMock()
+        empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema]
+        empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema)
+        tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
+        return tb
+
+    @staticmethod
+    def mock_raw_execute_response_static():
+        er = MagicMock(spec=ExecuteResponse)
+        er.description = [
+            ("col_int", "int", None, None, None, None, None),
+            ("col_str", "string", None, None, None, None, None),
+        ]
+        er.arrow_schema_bytes = None
+        er.arrow_queue = None
+        er.has_more_rows = False
+        er.lz4_compressed = False
+        er.command_handle = MagicMock()
+        er.status = MagicMock()
+        er.has_been_closed_server_side = False
+        er.is_staging_operation = False
+        return er
+
+    def test_convert_arrow_table_default(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result_one = rs.fetchone()
+        self.assertIsInstance(result_one, Row)
+        self.assertEqual(result_one.col_int, 1)
+        self.assertEqual(result_one.col_str, "a")
+
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result_all = rs.fetchall()
+        self.assertEqual(len(result_all), 3)
+        self.assertIsInstance(result_all[0], Row)
+        self.assertEqual(result_all[0].col_int, 1)
+        self.assertEqual(result_all[1].col_str, "b")
+
+    def test_convert_arrow_table_disable_pandas(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_connection.disable_pandas = True
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result = rs.fetchall()
+        self.assertEqual(len(result), 3)
+        self.assertIsInstance(result[0], Row)
+        self.assertEqual(result[0].col_int, 1)
+        self.assertEqual(result[0].col_str, "a")
+        self.assertIsInstance(result[0].col_int, int)
+        self.assertIsInstance(result[0].col_str, str)
+
+    def test_convert_arrow_table_type_override(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
+        mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        mock_connection._arrow_pandas_type_override = {
+            pa.int32(): pandas.Float64Dtype()
+        }
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            sample_arrow_table, sample_arrow_table.num_rows
+        )
+        rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
+        result = rs.fetchall()
+        self.assertEqual(len(result), 3)
+        self.assertIsInstance(result[0].col_int, float)
+        self.assertEqual(result[0].col_int, 1.0)
+        self.assertEqual(result[0].col_str, "a")
+
+    def test_convert_arrow_table_to_pandas_kwargs(self):
+        mock_connection = ArrowConversionTests.mock_connection_static()
+        mock_thrift_backend = (
+            ArrowConversionTests.mock_thrift_backend_static()
+        )  # Does not use sample_arrow_table
+        mock_raw_execute_response = (
+            ArrowConversionTests.mock_raw_execute_response_static()
+        )
+
+        dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
+        ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC"))
+        ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))])
+        ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema)
+
+        mock_raw_execute_response.description = [
+            ("col_ts", "timestamp", None, None, None, None, None)
+        ]
+        mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
+
+        # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
+        mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True}
+        rs_ts_true = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_true = rs_ts_true.fetchall()
+        self.assertEqual(len(result_true), 1)
+        self.assertIsInstance(result_true[0].col_ts, datetime.datetime)
+
+        # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            ts_table, ts_table.num_rows
+        )  # Reset queue
+        mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False}
+        rs_ts_false = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_false = rs_ts_false.fetchall()
+        self.assertEqual(len(result_false), 1)
+        self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp)
+
+        # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
+        mock_raw_execute_response.arrow_queue = ArrowQueue(
+            ts_table, ts_table.num_rows
+        )  # Reset queue
+        mock_connection._arrow_to_pandas_kwargs = {}
+        rs_ts_default = ResultSet(
+            mock_connection, mock_raw_execute_response, mock_thrift_backend
+        )
+        result_default = rs_ts_default.fetchall()
+        self.assertEqual(len(result_default), 1)
+        self.assertIsInstance(result_default[0].col_ts, datetime.datetime)
+
+
+if __name__ == "__main__":
+    unittest.main()