Commit a151df2

[PECO-1926] Create a non-pyarrow flow to handle small results for the column set (#440)

* Implemented the columnar flow for non-Arrow users
* Minor fixes
* Introduced the ColumnTable structure
* Added a test for the new column table
* Minor fix
* Removed unnecessary files

1 parent d31063c
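
The ColumnTable and ColumnQueue helpers imported below are defined in one of the changed files this page does not show (presumably src/databricks/sql/utils.py, alongside ParamEscaper and the other imported helpers). A minimal hypothetical sketch of the ColumnTable interface, inferred only from how the diff below calls it:

# Hypothetical sketch of the ColumnTable interface, inferred from its call
# sites in the diff below; the real definition lives in a changed file that
# is not shown on this page. Only the attributes and methods exercised by
# client.py are assumed.
class ColumnTable:
    def __init__(self, column_table, column_names):
        self.column_table = column_table  # column-major data: one list per column
        self.column_names = column_names  # column names, in column order

    @property
    def num_rows(self):
        return len(self.column_table[0]) if self.column_table else 0

    @property
    def num_columns(self):
        return len(self.column_names)

    def get_item(self, col_index, row_index):
        return self.column_table[col_index][row_index]

Storing the data column-major (one Python list per column) is what makes the merge in merge_columnar below a simple per-column list concatenation.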

4 files changed (+263, -32 lines)


src/databricks/sql/client.py (+88, -8)
@@ -1,7 +1,10 @@
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

 import pandas
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 import requests
 import json
 import os
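
The guarded import is the crux of the change: pyarrow stays bound as a name either way, so every Arrow-dependent code path can branch on its truthiness. A minimal self-contained sketch of the pattern:

# Minimal sketch of the optional-dependency pattern used throughout this
# commit: pyarrow is None when the package is missing, and feature checks
# reduce to a truthiness test.
try:
    import pyarrow
except ImportError:
    pyarrow = None

def result_format():
    return "arrow" if pyarrow else "columnar"
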
@@ -22,6 +25,8 @@
     ParamEscaper,
     inject_parameters,
     transform_paramstyle,
+    ColumnTable,
+    ColumnQueue
 )
 from databricks.sql.parameters.native import (
     DbsqlParameterBase,
@@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]:
         else:
             raise Error("There is no active result set")

-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchall_arrow()
         else:
             raise Error("There is no active result set")

-    def fetchmany_arrow(self, size) -> pyarrow.Table:
+    def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchmany_arrow(size)
@@ -1143,6 +1148,18 @@ def _fill_results_buffer(self):
         self.results = results
         self.has_more_rows = has_more_rows

+    def _convert_columnar_table(self, table):
+        column_names = [c[0] for c in self.description]
+        ResultRow = Row(*column_names)
+        result = []
+        for row_index in range(table.num_rows):
+            curr_row = []
+            for col_index in range(table.num_columns):
+                curr_row.append(table.get_item(col_index, row_index))
+            result.append(ResultRow(*curr_row))
+
+        return result
+
     def _convert_arrow_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
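
_convert_columnar_table walks the column-major table row index by row index to rebuild row tuples. A hypothetical standalone rendering of the same loop, runnable given the ColumnTable sketch above, with namedtuple standing in for databricks.sql.types.Row:

# Hypothetical standalone rendering of the _convert_columnar_table loop,
# using the ColumnTable sketch above; namedtuple stands in for Row.
from collections import namedtuple

table = ColumnTable([[1, 2], ["a", "b"]], ["id", "name"])
ResultRow = namedtuple("ResultRow", table.column_names)
rows = [
    ResultRow(*(table.get_item(c, r) for c in range(table.num_columns)))
    for r in range(table.num_rows)
]
# rows == [ResultRow(id=1, name='a'), ResultRow(id=2, name='b')]
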
@@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index

-    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.

@@ -1210,7 +1227,46 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

         return results

-    def fetchall_arrow(self) -> pyarrow.Table:
+    def merge_columnar(self, result1, result2):
+        """
+        Merge two columnar results into a single result.
+        :param result1:
+        :param result2:
+        :return: the combined ColumnTable
+        """
+
+        if result1.column_names != result2.column_names:
+            raise ValueError("The columns in the results don't match")
+
+        merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)]
+        return ColumnTable(merged_result, result1.column_names)
+
+    def fetchmany_columnar(self, size: int):
+        """
+        Fetch the next set of rows of a query result, returning a ColumnTable.
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)
+
+        results = self.results.next_n_rows(size)
+        n_remaining_rows = size - results.num_rows
+        self._next_row_index += results.num_rows
+
+        while (
+            n_remaining_rows > 0
+            and not self.has_been_closed_server_side
+            and self.has_more_rows
+        ):
+            self._fill_results_buffer()
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = self.merge_columnar(results, partial_results)
+            n_remaining_rows -= partial_results.num_rows
+            self._next_row_index += partial_results.num_rows
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
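
merge_columnar requires matching column names and then concatenates the two tables column by column. A worked toy example against the hypothetical ColumnTable sketch above:

# Worked example of the column-wise merge performed by merge_columnar,
# using the hypothetical ColumnTable sketch from earlier.
t1 = ColumnTable([[1, 2], ["a", "b"]], ["id", "name"])
t2 = ColumnTable([[3], ["c"]], ["id", "name"])
merged = ColumnTable(
    [t1.column_table[i] + t2.column_table[i] for i in range(t1.num_columns)],
    t1.column_names,
)
# merged.column_table == [[1, 2, 3], ['a', 'b', 'c']]
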
@@ -1223,12 +1279,30 @@ def fetchall_arrow(self) -> pyarrow.Table:

         return results

+    def fetchall_columnar(self):
+        """Fetch all (remaining) rows of a query result, returning them as a ColumnTable."""
+        results = self.results.remaining_rows()
+        self._next_row_index += results.num_rows
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            results = self.merge_columnar(results, partial_results)
+            self._next_row_index += partial_results.num_rows
+
+        return results
+
     def fetchone(self) -> Optional[Row]:
         """
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
-        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
+        if isinstance(self.results, ColumnQueue):
+            res = self._convert_columnar_table(self.fetchmany_columnar(1))
+        else:
+            res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
         if len(res) > 0:
             return res[0]
         else:
@@ -1238,15 +1312,21 @@ def fetchall(self) -> List[Row]:
         """
         Fetch all (remaining) rows of a query result, returning them as a list of rows.
         """
-        return self._convert_arrow_table(self.fetchall_arrow())
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchall_columnar())
+        else:
+            return self._convert_arrow_table(self.fetchall_arrow())

     def fetchmany(self, size: int) -> List[Row]:
         """
         Fetch the next set of rows of a query result, returning a list of rows.

         An empty sequence is returned when no more rows are available.
         """
-        return self._convert_arrow_table(self.fetchmany_arrow(size))
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchmany_columnar(size))
+        else:
+            return self._convert_arrow_table(self.fetchmany_arrow(size))

     def close(self) -> None:
         """

src/databricks/sql/thrift_backend.py (+17, -8)
@@ -7,7 +7,10 @@
 import threading
 from typing import List, Union

-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.TSocket
@@ -621,6 +624,7 @@ def _get_metadata_resp(self, op_handle):

     @staticmethod
     def _hive_schema_to_arrow_schema(t_table_schema):
+
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -726,12 +730,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
-        schema_bytes = (
-            t_result_set_metadata_resp.arrowSchema
-            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
-            .serialize()
-            .to_pybytes()
-        )
+
+        if pyarrow:
+            schema_bytes = (
+                t_result_set_metadata_resp.arrowSchema
+                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+                .serialize()
+                .to_pybytes()
+            )
+        else:
+            schema_bytes = None
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
@@ -827,7 +836,7 @@ def execute_command(
             getDirectResults=ttypes.TSparkGetDirectResults(
                 maxRows=max_rows, maxBytes=max_bytes
            ),
-            canReadArrowResult=True,
+            canReadArrowResult=True if pyarrow else False,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
             confOverlay={
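
This last hunk is what routes a session onto the new flow: a client without pyarrow no longer advertises Arrow capability, so the server can fall back to columnar result data, which is what the ColumnQueue path in client.py consumes. In sketch form:

# Sketch of the capability flag: with pyarrow = None (the ImportError
# fallback above), the client stops advertising Arrow support.
pyarrow = None
can_read_arrow = True if pyarrow else False
assert can_read_arrow is False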
