Skip to content

Commit a87e2cb

Browse files
committed
Introduced the Column Table structure
1 parent 6a8646d commit a87e2cb

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed

src/databricks/sql/client.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ParamEscaper,
2626
inject_parameters,
2727
transform_paramstyle,
28-
ArrowQueue,
28+
ColumnTable,
2929
ColumnQueue
3030
)
3131
from databricks.sql.parameters.native import (
@@ -1152,10 +1152,10 @@ def _convert_columnar_table(self, table):
11521152
column_names = [c[0] for c in self.description]
11531153
ResultRow = Row(*column_names)
11541154
result = []
1155-
for row_index in range(len(table[0])):
1155+
for row_index in range(table.num_rows):
11561156
curr_row = []
1157-
for col_index in range(len(table)):
1158-
curr_row.append(table[col_index][row_index])
1157+
for col_index in range(table.num_columns):
1158+
curr_row.append(table.get_item(col_index, row_index))
11591159
result.append(ResultRow(*curr_row))
11601160

11611161
return result
@@ -1235,11 +1235,11 @@ def merge_columnar(self, result1, result2):
12351235
:return:
12361236
"""
12371237

1238-
if len(result1) != len(result2):
1239-
raise ValueError("The number of columns in both results must be the same")
1238+
if result1.column_names != result2.column_names:
1239+
raise ValueError("The columns in the results don't match")
12401240

1241-
merged_result = [result1[i] + result2[i] for i in range(len(result1))]
1242-
return merged_result
1241+
merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)]
1242+
return ColumnTable(merged_result, result1.column_names)
12431243

12441244
def fetchmany_columnar(self, size: int):
12451245
"""
@@ -1250,8 +1250,8 @@ def fetchmany_columnar(self, size: int):
12501250
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
12511251

12521252
results = self.results.next_n_rows(size)
1253-
n_remaining_rows = size - len(results[0])
1254-
self._next_row_index += len(results[0])
1253+
n_remaining_rows = size - results.num_rows
1254+
self._next_row_index += results.num_rows
12551255

12561256
while (
12571257
n_remaining_rows > 0
@@ -1261,8 +1261,8 @@ def fetchmany_columnar(self, size: int):
12611261
self._fill_results_buffer()
12621262
partial_results = self.results.next_n_rows(n_remaining_rows)
12631263
results = self.merge_columnar(results, partial_results)
1264-
n_remaining_rows -= len(partial_results[0])
1265-
self._next_row_index += len(partial_results[0])
1264+
n_remaining_rows -= partial_results.num_rows
1265+
self._next_row_index += partial_results.num_rows
12661266

12671267
return results
12681268

@@ -1282,13 +1282,13 @@ def fetchall_arrow(self) -> "pyarrow.Table":
12821282
def fetchall_columnar(self):
    """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
    # Drain whatever is already buffered locally.
    results = self.results.remaining_rows()
    self._next_row_index += results.num_rows

    # Keep pulling buffers from the server until it reports no more rows.
    while not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        batch = self.results.remaining_rows()
        results = self.merge_columnar(results, batch)
        self._next_row_index += batch.num_rows

    return results
12941294

src/databricks/sql/utils.py

+35-15
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def build_queue(
8888
column_table, description
8989
)
9090

91-
return ColumnQueue(converted_column_table, column_names)
91+
return ColumnQueue(ColumnTable(converted_column_table, column_names))
9292
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
9393
return CloudFetchQueue(
9494
schema_bytes=arrow_schema_bytes,
@@ -102,27 +102,47 @@ def build_queue(
102102
else:
103103
raise AssertionError("Row set type is not valid")
104104

105+
class ColumnTable:
    """Columnar result container: parallel per-column value lists plus their names."""

    def __init__(self, column_table, column_names):
        # column_table[i] holds every value of the column named column_names[i].
        self.column_table = column_table
        self.column_names = column_names

    @property
    def num_rows(self):
        """Row count, derived from the first column (0 when there are no columns)."""
        return len(self.column_table[0]) if self.column_table else 0

    @property
    def num_columns(self):
        """Number of columns, taken from the name list."""
        return len(self.column_names)

    def get_item(self, col_index, row_index):
        """Return the single cell value at (col_index, row_index)."""
        return self.column_table[col_index][row_index]

    def slice(self, curr_index, length):
        """Return a new ColumnTable holding rows [curr_index, curr_index + length)."""
        window = [col[curr_index : curr_index + length] for col in self.column_table]
        return ColumnTable(window, self.column_names)
127+
128+
105129
class ColumnQueue(ResultSetQueue):
    """Queue over a ColumnTable that serves rows in order as sliced sub-tables."""

    def __init__(self, column_table: ColumnTable):
        self.column_table = column_table
        # Index of the next row to serve; rows before it have been consumed.
        self.cur_row_index = 0
        self.n_valid_rows = column_table.num_rows

    def next_n_rows(self, num_rows):
        """Return up to num_rows rows as a ColumnTable and advance the cursor.

        Fewer rows (possibly zero) come back when the queue is nearly exhausted.
        """
        # Clamp to what is actually left so the slice never runs past the end.
        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
        # Renamed local from `slice` to `chunk`: `slice` shadows the builtin.
        chunk = self.column_table.slice(self.cur_row_index, length)
        self.cur_row_index += chunk.num_rows
        return chunk

    def remaining_rows(self):
        """Return every not-yet-served row as a ColumnTable, exhausting the queue."""
        chunk = self.column_table.slice(
            self.cur_row_index, self.n_valid_rows - self.cur_row_index
        )
        self.cur_row_index += chunk.num_rows
        return chunk
126146

127147

128148
class ArrowQueue(ResultSetQueue):

0 commit comments

Comments
 (0)