Skip to content

Commit 2470581

Browse files
committed
Reformatted
1 parent 3318b04 commit 2470581

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

databricks_sql_connector_core/src/databricks/sql/client.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
     TSparkParameter,
 )
 
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 
 logger = logging.getLogger(__name__)
 
@@ -977,14 +981,14 @@ def fetchmany(self, size: int) -> List[Row]:
         else:
             raise Error("There is no active result set")
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchall_arrow()
         else:
             raise Error("There is no active result set")
 
-    def fetchmany_arrow(self, size) -> pyarrow.Table:
+    def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchmany_arrow(size)
@@ -1171,7 +1175,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index
 
-    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
 
@@ -1196,7 +1200,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:
 
         return results
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

databricks_sql_connector_core/src/databricks/sql/thrift_backend.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
     convert_column_based_set_to_arrow_table,
 )
 
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
 logger = logging.getLogger(__name__)
 
 unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -652,6 +657,12 @@ def _get_metadata_resp(self, op_handle):
 
     @staticmethod
     def _hive_schema_to_arrow_schema(t_table_schema):
+
+        if pyarrow is None:
+            raise ImportError(
+                "pyarrow is required to convert Hive schema to Arrow schema"
+            )
+
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -858,7 +869,7 @@ def execute_command(
                 getDirectResults=ttypes.TSparkGetDirectResults(
                     maxRows=max_rows, maxBytes=max_bytes
                 ),
-                canReadArrowResult=True,
+                canReadArrowResult=True if pyarrow else False,
                 canDecompressLZ4Result=lz4_compression,
                 canDownloadResult=use_cloud_fetch,
                 confOverlay={

databricks_sql_connector_core/src/databricks/sql/utils.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,21 @@
 
 import logging
 
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
 logger = logging.getLogger(__name__)
 
 
 class ResultSetQueue(ABC):
     @abstractmethod
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int):
         pass
 
     @abstractmethod
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self):
         pass
 
 
@@ -100,7 +105,7 @@ def build_queue(
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: pyarrow.Table,
+        arrow_table: "pyarrow.Table",
         n_valid_rows: int,
         start_row_index: int = 0,
     ):
@@ -115,7 +120,7 @@ def __init__(
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get upto the next n rows of the Arrow dataframe"""
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
@@ -124,7 +129,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         self.cur_row_index += slice.num_rows
         return slice
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         slice = self.arrow_table.slice(
             self.cur_row_index, self.n_valid_rows - self.cur_row_index
         )
@@ -184,7 +189,7 @@ def __init__(
         self.table = self._create_next_table()
         self.table_row_index = 0
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.
@@ -216,7 +221,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
         return results
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.
@@ -237,7 +242,7 @@ def remaining_rows(self) -> pyarrow.Table:
         self.table_row_index = 0
         return results
 
-    def _create_next_table(self) -> Union[pyarrow.Table, None]:
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -276,7 +281,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:
 
         return arrow_table
 
-    def _create_empty_table(self) -> pyarrow.Table:
+    def _create_empty_table(self) -> "pyarrow.Table":
         # Create a 0-row table with just the schema bytes
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
 
@@ -515,7 +520,7 @@ def transform_paramstyle(
         return output
 
 
-def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
+def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)
 
@@ -542,7 +547,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
     return arrow_table, n_rows
 
 
-def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
+def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
     for i, col in enumerate(table.itercolumns()):
         if description[i][1] == "decimal":
             decimal_col = col.to_pandas().apply(

0 commit comments

Comments
 (0)