Skip to content

Commit 048af73

Browse files
committed
Enhance Arrow to Pandas conversion with type overrides and additional kwargs
* Introduced `_arrow_pandas_type_override` and `_arrow_to_pandas_kwargs` in the Connection class for customizable dtype mapping and DataFrame construction parameters.
* Updated ResultSet to use these new options when converting Arrow tables to Pandas DataFrames.
* Added unit tests validating the new functionality, including scenarios for type overrides and additional kwargs handling.
1 parent 0947b9a commit 048af73

File tree

2 files changed

+143
-6
lines changed

2 files changed

+143
-6
lines changed

src/databricks/sql/client.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]:
213213
# (True by default)
214214
# use_cloud_fetch
215215
# Enable use of cloud fetch to extract large query results in parallel via cloud storage
216+
# _arrow_pandas_type_override
217+
# Override the default pandas dtype mapping for Arrow types.
218+
# This is a dictionary of Arrow types to pandas dtypes.
219+
# _arrow_to_pandas_kwargs
220+
# Additional or modified arguments to pass to pandas.DataFrame constructor.
216221

217222
logger.debug(
218223
"Connection.__init__(server_hostname=%s, http_path=%s)",
@@ -1346,7 +1351,7 @@ def _convert_arrow_table(self, table):
13461351
# Need to use nullable types, as otherwise type can change when there are missing values.
13471352
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
13481353
# NOTE: This API is experimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
1349-
dtype_mapping = {
1354+
DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = {
13501355
pyarrow.int8(): pandas.Int8Dtype(),
13511356
pyarrow.int16(): pandas.Int16Dtype(),
13521357
pyarrow.int32(): pandas.Int32Dtype(),
@@ -1360,14 +1365,18 @@ def _convert_arrow_table(self, table):
13601365
pyarrow.float64(): pandas.Float64Dtype(),
13611366
pyarrow.string(): pandas.StringDtype(),
13621367
}
1368+
dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override}
1369+
1370+
to_pandas_kwargs: dict[str, Any] = {
1371+
"types_mapper": dtype_mapping.get,
1372+
"date_as_object": True,
1373+
"timestamp_as_object": True,
1374+
}
1375+
to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs)
13631376

13641377
# Need to rename columns, as the to_pandas function cannot handle duplicate column names
13651378
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
1366-
df = table_renamed.to_pandas(
1367-
types_mapper=dtype_mapping.get,
1368-
date_as_object=True,
1369-
timestamp_as_object=True,
1370-
)
1379+
df = table_renamed.to_pandas(**to_pandas_kwargs)
13711380

13721381
res = df.to_numpy(na_value=None, dtype="object")
13731382
return [ResultRow(*v) for v in res]

tests/unit/test_arrow_conversion.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import pytest
2+
import pyarrow
3+
import pandas
4+
import datetime
5+
from unittest.mock import MagicMock, patch
6+
7+
from databricks.sql.client import ResultSet, Connection, ExecuteResponse
8+
from databricks.sql.types import Row
9+
from databricks.sql.utils import ArrowQueue
10+
11+
12+
@pytest.fixture
def mock_connection():
    """Connection double with pandas enabled and no Arrow conversion overrides.

    Defaults mirror a Connection created without the private ``_arrow_*``
    kwargs, so tests that need overrides assign them explicitly.
    """
    conn = MagicMock(spec=Connection)
    conn.disable_pandas = False
    conn._arrow_pandas_type_override = {}
    # No hasattr() guard needed: the attribute is assigned right here, and a
    # MagicMock auto-creates attributes anyway, so such a guard is dead code.
    conn._arrow_to_pandas_kwargs = {}
    return conn
21+
22+
@pytest.fixture
def mock_thrift_backend(sample_arrow_table):
    """Thrift backend double whose fetch_results yields an empty final batch."""
    backend = MagicMock()
    schema = sample_arrow_table.schema
    no_rows = pyarrow.Table.from_arrays(
        [pyarrow.array([], type=field.type) for field in schema],
        schema=schema,
    )
    # (queue, has_more_rows): an exhausted queue and no further pages.
    backend.fetch_results.return_value = (ArrowQueue(no_rows, 0), False)
    return backend
29+
30+
@pytest.fixture
def mock_raw_execute_response():
    """ExecuteResponse double describing a two-column (int, string) result set."""
    response = MagicMock(spec=ExecuteResponse)
    response.description = [
        ("col_int", "int", None, None, None, None, None),
        ("col_str", "string", None, None, None, None, None),
    ]
    response.arrow_schema_bytes = None
    response.arrow_queue = None
    response.has_more_rows = False
    response.lz4_compressed = False
    response.command_handle = MagicMock()
    response.status = MagicMock()
    response.has_been_closed_server_side = False
    response.is_staging_operation = False
    return response
44+
45+
@pytest.fixture
def sample_arrow_table():
    """Three-row Arrow table with an int32 column and a string column."""
    schema = pyarrow.schema(
        [("col_int", pyarrow.int32()), ("col_str", pyarrow.string())]
    )
    columns = [
        pyarrow.array([1, 2, 3], type=pyarrow.int32()),
        pyarrow.array(["a", "b", "c"], type=pyarrow.string()),
    ]
    return pyarrow.Table.from_arrays(columns, schema=schema)
56+
57+
58+
def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
    """fetchone/fetchall return Row objects with expected values under defaults."""
    mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
    result_set = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
    first = result_set.fetchone()
    assert isinstance(first, Row)
    assert first.col_int == 1
    assert first.col_str == "a"

    # Re-arm the queue — fetchone above consumed the first row.
    mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
    result_set = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
    rows = result_set.fetchall()
    assert len(rows) == 3
    assert isinstance(rows[0], Row)
    assert rows[0].col_int == 1
    assert rows[1].col_str == "b"
72+
73+
74+
def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
    """With disable_pandas=True, rows still come back with native Python values."""
    mock_connection.disable_pandas = True
    mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
    rows = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend).fetchall()
    assert len(rows) == 3
    assert isinstance(rows[0], Row)
    assert rows[0].col_int == 1
    assert rows[0].col_str == "a"
    # Sanity check: the source table itself unwraps to plain Python int/str.
    assert isinstance(sample_arrow_table.column(0)[0].as_py(), int)
    assert isinstance(sample_arrow_table.column(1)[0].as_py(), str)
85+
86+
87+
def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
    """_arrow_pandas_type_override remaps the int32 column to a float dtype."""
    mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()}
    mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
    rows = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend).fetchall()
    assert len(rows) == 3
    assert isinstance(rows[0].col_int, float)
    assert rows[0].col_int == 1.0
    assert rows[0].col_str == "a"
96+
97+
98+
def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response):
    """_arrow_to_pandas_kwargs forwards timestamp_as_object to to_pandas()."""
    moment = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
    ts_type = pyarrow.timestamp("us", tz="UTC")
    ts_table = pyarrow.Table.from_arrays(
        [pyarrow.array([moment], type=ts_type)],
        schema=pyarrow.schema([("col_ts", ts_type)]),
    )
    mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)]

    scenarios = [
        # (kwargs override, type observed on the Row)
        ({"timestamp_as_object": True}, datetime.datetime),
        ({"timestamp_as_object": False}, pandas.Timestamp),
        # No override: timestamp_as_object defaults to True in the driver.
        ({}, datetime.datetime),
    ]
    for to_pandas_overrides, expected_type in scenarios:
        mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
        mock_connection._arrow_to_pandas_kwargs = to_pandas_overrides
        rows = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend).fetchall()
        assert len(rows) == 1
        assert isinstance(rows[0].col_ts, expected_type)

0 commit comments

Comments
 (0)