Skip to content

Commit 0b1b05b

Browse files
committed
fmt
1 parent 048af73 commit 0b1b05b

File tree

2 files changed

+67
-28
lines changed

2 files changed

+67
-28
lines changed

src/databricks/sql/client.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1351,7 +1351,9 @@ def _convert_arrow_table(self, table):
13511351
# Need to use nullable types, as otherwise type can change when there are missing values.
13521352
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types
13531353
# NOTE: This API is experimental: https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
1354-
DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = {
1354+
DEFAULT_DTYPE_MAPPING: Dict[
1355+
pyarrow.DataType, pandas.api.extensions.ExtensionDtype
1356+
] = {
13551357
pyarrow.int8(): pandas.Int8Dtype(),
13561358
pyarrow.int16(): pandas.Int16Dtype(),
13571359
pyarrow.int32(): pandas.Int32Dtype(),
@@ -1365,7 +1367,10 @@ def _convert_arrow_table(self, table):
13651367
pyarrow.float64(): pandas.Float64Dtype(),
13661368
pyarrow.string(): pandas.StringDtype(),
13671369
}
1368-
dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override}
1370+
dtype_mapping = {
1371+
**DEFAULT_DTYPE_MAPPING,
1372+
**self.connection._arrow_pandas_type_override,
1373+
}
13691374

13701375
to_pandas_kwargs: dict[str, Any] = {
13711376
"types_mapper": dtype_mapping.get,

tests/unit/test_arrow_conversion.py

Lines changed: 60 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -15,23 +15,31 @@ def mock_connection():
1515
conn.disable_pandas = False
1616
conn._arrow_pandas_type_override = {}
1717
conn._arrow_to_pandas_kwargs = {}
18-
if not hasattr(conn, '_arrow_to_pandas_kwargs'):
18+
if not hasattr(conn, "_arrow_to_pandas_kwargs"):
1919
conn._arrow_to_pandas_kwargs = {}
2020
return conn
2121

22+
2223
@pytest.fixture
2324
def mock_thrift_backend(sample_arrow_table):
2425
tb = MagicMock()
25-
empty_arrays = [pyarrow.array([], type=field.type) for field in sample_arrow_table.schema]
26-
empty_table = pyarrow.Table.from_arrays(empty_arrays, schema=sample_arrow_table.schema)
27-
tb.fetch_results.return_value = (ArrowQueue(empty_table, 0) , False)
26+
empty_arrays = [
27+
pyarrow.array([], type=field.type) for field in sample_arrow_table.schema
28+
]
29+
empty_table = pyarrow.Table.from_arrays(
30+
empty_arrays, schema=sample_arrow_table.schema
31+
)
32+
tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
2833
return tb
2934

35+
3036
@pytest.fixture
3137
def mock_raw_execute_response():
3238
er = MagicMock(spec=ExecuteResponse)
33-
er.description = [("col_int", "int", None, None, None, None, None),
34-
("col_str", "string", None, None, None, None, None)]
39+
er.description = [
40+
("col_int", "int", None, None, None, None, None),
41+
("col_str", "string", None, None, None, None, None),
42+
]
3543
er.arrow_schema_bytes = None
3644
er.arrow_queue = None
3745
er.has_more_rows = False
@@ -42,27 +50,33 @@ def mock_raw_execute_response():
4250
er.is_staging_operation = False
4351
return er
4452

53+
4554
@pytest.fixture
4655
def sample_arrow_table():
4756
data = [
4857
pyarrow.array([1, 2, 3], type=pyarrow.int32()),
49-
pyarrow.array(["a", "b", "c"], type=pyarrow.string())
58+
pyarrow.array(["a", "b", "c"], type=pyarrow.string()),
5059
]
51-
schema = pyarrow.schema([
52-
('col_int', pyarrow.int32()),
53-
('col_str', pyarrow.string())
54-
])
60+
schema = pyarrow.schema(
61+
[("col_int", pyarrow.int32()), ("col_str", pyarrow.string())]
62+
)
5563
return pyarrow.Table.from_arrays(data, schema=schema)
5664

5765

58-
def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
59-
mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
66+
def test_convert_arrow_table_default(
67+
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
68+
):
69+
mock_raw_execute_response.arrow_queue = ArrowQueue(
70+
sample_arrow_table, sample_arrow_table.num_rows
71+
)
6072
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
6173
result_one = rs.fetchone()
6274
assert isinstance(result_one, Row)
6375
assert result_one.col_int == 1
6476
assert result_one.col_str == "a"
65-
mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
77+
mock_raw_execute_response.arrow_queue = ArrowQueue(
78+
sample_arrow_table, sample_arrow_table.num_rows
79+
)
6680
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
6781
result_all = rs.fetchall()
6882
assert len(result_all) == 3
@@ -71,9 +85,13 @@ def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_
7185
assert result_all[1].col_str == "b"
7286

7387

74-
def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
88+
def test_convert_arrow_table_disable_pandas(
89+
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
90+
):
7591
mock_connection.disable_pandas = True
76-
mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
92+
mock_raw_execute_response.arrow_queue = ArrowQueue(
93+
sample_arrow_table, sample_arrow_table.num_rows
94+
)
7795
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
7896
result = rs.fetchall()
7997
assert len(result) == 3
@@ -84,9 +102,15 @@ def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend
84102
assert isinstance(sample_arrow_table.column(1)[0].as_py(), str)
85103

86104

87-
def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table):
88-
mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()}
89-
mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows)
105+
def test_convert_arrow_table_type_override(
106+
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
107+
):
108+
mock_connection._arrow_pandas_type_override = {
109+
pyarrow.int32(): pandas.Float64Dtype()
110+
}
111+
mock_raw_execute_response.arrow_queue = ArrowQueue(
112+
sample_arrow_table, sample_arrow_table.num_rows
113+
)
90114
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
91115
result = rs.fetchall()
92116
assert len(result) == 3
@@ -95,34 +119,44 @@ def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend,
95119
assert result[0].col_str == "a"
96120

97121

98-
def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response):
122+
def test_convert_arrow_table_to_pandas_kwargs(
123+
mock_connection, mock_thrift_backend, mock_raw_execute_response
124+
):
99125
dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
100-
ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp('us', tz='UTC'))
101-
ts_schema = pyarrow.schema([('col_ts', pyarrow.timestamp('us', tz='UTC'))])
126+
ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC"))
127+
ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))])
102128
ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema)
103129

104-
mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)]
130+
mock_raw_execute_response.description = [
131+
("col_ts", "timestamp", None, None, None, None, None)
132+
]
105133
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
106134

107135
# Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
108136
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True}
109-
rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
137+
rs_ts_true = ResultSet(
138+
mock_connection, mock_raw_execute_response, mock_thrift_backend
139+
)
110140
result_true = rs_ts_true.fetchall()
111141
assert len(result_true) == 1
112142
assert isinstance(result_true[0].col_ts, datetime.datetime)
113143

114144
# Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
115145
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
116146
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False}
117-
rs_ts_false = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
147+
rs_ts_false = ResultSet(
148+
mock_connection, mock_raw_execute_response, mock_thrift_backend
149+
)
118150
result_false = rs_ts_false.fetchall()
119151
assert len(result_false) == 1
120152
assert isinstance(result_false[0].col_ts, pandas.Timestamp)
121153

122154
# Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
123155
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
124156
mock_connection._arrow_to_pandas_kwargs = {}
125-
rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
157+
rs_ts_true = ResultSet(
158+
mock_connection, mock_raw_execute_response, mock_thrift_backend
159+
)
126160
result_true = rs_ts_true.fetchall()
127161
assert len(result_true) == 1
128162
assert isinstance(result_true[0].col_ts, datetime.datetime)

0 commit comments

Comments
 (0)