Commit 335fc0c

Implemented ColumnQueue to test the fetchall without pyarrow
Removed token
1 parent b438c38 commit 335fc0c

File tree

12 files changed, +227 -7 lines changed


.idea/.gitignore

Lines changed: 8 additions & 0 deletions

.idea/codeStyles/Project.xml

Lines changed: 7 additions & 0 deletions

.idea/codeStyles/codeStyleConfig.xml

Lines changed: 5 additions & 0 deletions

.idea/databricks-sql-python.iml

Lines changed: 11 additions & 0 deletions

.idea/misc.xml

Lines changed: 9 additions & 0 deletions

.idea/modules.xml

Lines changed: 8 additions & 0 deletions

.idea/vcs.xml

Lines changed: 6 additions & 0 deletions

check.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import os
+import sys
+# import logging
+#
+# logging.basicConfig(level=logging.DEBUG)
+
+#
+# # Get the parent directory of the current file
+# target_folder_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "databricks-sql-python", "src"))
+#
+# # Add the parent directory to sys.path
+# sys.path.append(target_folder_path)
+
+from src.databricks import sql
+
+# from dotenv import load_dotenv
+
+# export DATABRICKS_TOKEN=whatever
+
+
+# Load environment variables from .env file
+# load_dotenv()
+
+host = "e2-dogfood.staging.cloud.databricks.com"
+http_path = "/sql/1.0/warehouses/dd43ee29fedd958d"
+
+access_token = ""
+connection = sql.connect(
+    server_hostname=host,
+    http_path=http_path,
+    access_token=access_token)
+
+
+cursor = connection.cursor()
+cursor.execute('SELECT :param `p`, * FROM RANGE(10)', {"param": "foo"})
+# cursor.execute('SELECT 1')
+result = cursor.fetchall()
+for row in result:
+    print(row)
+
+cursor.close()
+connection.close()
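
For local runs of this script, a minimal sketch of the same smoke test that reads credentials from environment variables instead of hardcoding them. The DATABRICKS_HOST and DATABRICKS_HTTP_PATH names are assumptions for illustration; DATABRICKS_TOKEN echoes the commented-out export above.

import os

from src.databricks import sql

connection = sql.connect(
    server_hostname=os.environ["DATABRICKS_HOST"],    # assumed variable name
    http_path=os.environ["DATABRICKS_HTTP_PATH"],     # assumed variable name
    access_token=os.environ["DATABRICKS_TOKEN"],
)

cursor = connection.cursor()
cursor.execute('SELECT :param `p`, * FROM RANGE(10)', {"param": "foo"})
for row in cursor.fetchall():
    print(row)

cursor.close()
connection.close()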

src/databricks/sql/client.py

Lines changed: 25 additions & 0 deletions
@@ -777,6 +777,9 @@ def execute(
             use_cloud_fetch=self.connection.use_cloud_fetch,
             parameters=prepared_params,
         )
+
+        print("Line 781")
+        print(execute_response)
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
@@ -1129,6 +1132,20 @@ def _fill_results_buffer(self):
         self.results = results
         self.has_more_rows = has_more_rows
 
+    def _convert_columnar_table(self, table):
+        column_names = [c[0] for c in self.description]
+        ResultRow = Row(*column_names)
+
+        result = []
+        for row_index in range(len(table[0])):
+            curr_row = []
+            for col_index in range(len(table)-1, -1, -1):
+                curr_row.append(table[col_index][row_index])
+            result.append(ResultRow(*curr_row))
+
+        return result
+
+
     def _convert_arrow_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
@@ -1209,6 +1226,11 @@ def fetchall_arrow(self) -> pyarrow.Table:
 
         return results
 
+    def fetchall_columnar(self):
+        results = self.results.remaining_rows()
+        self._next_row_index += len(results[0])
+        return results
+
     def fetchone(self) -> Optional[Row]:
         """
         Fetch the next row of a query result set, returning a single sequence,
@@ -1224,6 +1246,9 @@ def fetchall(self) -> List[Row]:
         """
         Fetch all (remaining) rows of a query result, returning them as a list of rows.
         """
+
+        return self._convert_columnar_table(self.fetchall_columnar())
+
         return self._convert_arrow_table(self.fetchall_arrow())
 
     def fetchmany(self, size: int) -> List[Row]:
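
For intuition, a small self-contained sketch of the idea behind _convert_columnar_table: transposing a column-oriented table into per-row records. namedtuple stands in for the client's Row factory and the toy data is made up, so this is illustrative rather than the committed implementation.

from collections import namedtuple

# One sequence per column, as the columnar fetch path returns them.
column_names = ["id", "name"]
column_table = [(1, 2, 3), ("a", "b", "c")]

# namedtuple stands in for the client's Row factory in this sketch.
ResultRow = namedtuple("ResultRow", column_names)

# Row i is built from the i-th element of every column.
rows = [ResultRow(*values) for values in zip(*column_table)]
for row in rows:
    print(row)
# ResultRow(id=1, name='a')
# ResultRow(id=2, name='b')
# ResultRow(id=3, name='c')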

src/databricks/sql/thrift_api/TCLIService/ttypes.py

Lines changed: 1 addition & 1 deletion

src/databricks/sql/thrift_backend.py

Lines changed: 16 additions & 3 deletions
@@ -37,6 +37,8 @@
     convert_column_based_set_to_arrow_table,
 )
 
+from src.databricks.sql.thrift_api.TCLIService.ttypes import TDBSqlResultFormat
+
 logger = logging.getLogger(__name__)
 
 unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -734,6 +736,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle)
 
+        print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
@@ -858,15 +861,25 @@ def execute_command(
             getDirectResults=ttypes.TSparkGetDirectResults(
                 maxRows=max_rows, maxBytes=max_bytes
             ),
-            canReadArrowResult=True,
+            canReadArrowResult=False,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
             confOverlay={
                 # We want to receive proper Timestamp arrow types.
                 "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
             },
-            useArrowNativeTypes=spark_arrow_types,
-            parameters=parameters,
+            # useArrowNativeTypes=spark_arrow_types,
+            # canReadArrowResult=True,
+            # # canDecompressLZ4Result=lz4_compression,
+            # canDecompressLZ4Result=False,
+            # canDownloadResult=False,
+            # # confOverlay={
+            # # # We want to receive proper Timestamp arrow types.
+            # # "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
+            # # },
+            # resultDataFormat=TDBSqlResultFormat(None,None,True),
+            # # useArrowNativeTypes=spark_arrow_types,
+            parameters=parameters,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)
         return self._handle_execute_response(resp, cursor)
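
Setting canReadArrowResult=False tells the server the client will not consume Arrow batches, so results presumably come back as COLUMN_BASED_SET and flow through the new ColumnQueue path in utils.py. A hedged sketch of branching on the reported format, using the Thrift names that appear in this diff; the helper itself is hypothetical and only mirrors the debug print above.

from databricks.sql.thrift_api.TCLIService import ttypes

def describe_result_format(t_result_set_metadata_resp):
    # Same field the debug print above inspects.
    fmt = t_result_set_metadata_resp.resultFormat
    if fmt == ttypes.TSparkRowSetType.COLUMN_BASED_SET:
        return "columnar result, handled by ColumnQueue"
    if fmt == ttypes.TSparkRowSetType.ARROW_BASED_SET:
        return "arrow result, handled by ArrowQueue"
    return "other result format: {}".format(fmt)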

src/databricks/sql/utils.py

Lines changed: 89 additions & 3 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations
-
+import json
+from thrift.protocol import TJSONProtocol
+from thrift.transport import TTransport
 import datetime
 import decimal
 from abc import ABC, abstractmethod
@@ -33,15 +35,17 @@
 
 class ResultSetQueue(ABC):
     @abstractmethod
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int):
         pass
 
     @abstractmethod
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self):
         pass
 
 
 class ResultSetQueueFactory(ABC):
+
+
     @staticmethod
     def build_queue(
         row_set_type: TSparkRowSetType,
@@ -67,6 +71,18 @@ def build_queue(
         Returns:
             ResultSetQueue
         """
+
+        def trow_to_json(trow):
+            # Step 1: Serialize TRow using Thrift's TJSONProtocol
+            transport = TTransport.TMemoryBuffer()
+            protocol = TJSONProtocol.TJSONProtocol(transport)
+            trow.write(protocol)
+
+            # Step 2: Extract JSON string from the transport
+            json_str = transport.getvalue().decode('utf-8')
+
+            return json_str
+
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
@@ -76,6 +92,23 @@ def build_queue(
             )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
+            print("Lin 79 ")
+            print(type(t_row_set))
+            print(t_row_set)
+            json_str = json.loads(trow_to_json(t_row_set))
+            pretty_json = json.dumps(json_str, indent=2)
+            print(pretty_json)
+
+            converted_column_table, column_names = convert_column_based_set_to_column_table(
+                t_row_set.columns,
+                description)
+            print(converted_column_table, column_names)
+
+            return ColumnQueue(converted_column_table, column_names)
+
+            print(columnQueue.next_n_rows(2))
+            print(columnQueue.next_n_rows(2))
+            print(columnQueue.remaining_rows())
             arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
                 t_row_set.columns, description
             )
@@ -97,6 +130,28 @@ def build_queue(
             raise AssertionError("Row set type is not valid")
 
 
+class ColumnQueue(ResultSetQueue):
+    def __init__(
+        self,
+        columnar_table, column_names):
+        self.columnar_table = columnar_table
+        self.cur_row_index = 0
+        self.n_valid_rows = len(columnar_table[0])
+        self.column_names = column_names
+
+    def next_n_rows(self, num_rows):
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        # Slicing using the default python slice
+        next_data = [column[self.cur_row_index:self.cur_row_index+length] for column in self.columnar_table]
+        self.cur_row_index += length
+        return next_data
+
+    def remaining_rows(self):
+        next_data = [column[self.cur_row_index:] for column in self.columnar_table]
+        self.cur_row_index += len(next_data[0])
+        return next_data
+
+
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
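
As a quick sanity check, a minimal sketch of how ColumnQueue slices a column-oriented table. The toy table below is made up; this simply exercises the class defined above outside of the Thrift plumbing.

# Two columns ("id", "name"), three rows, stored column-wise.
columns = [(1, 2, 3), ("a", "b", "c")]
queue = ColumnQueue(columns, column_names=["id", "name"])

print(queue.next_n_rows(2))    # [(1, 2), ('a', 'b')]
print(queue.remaining_rows())  # [(3,), ('c',)]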
@@ -570,6 +625,13 @@ def convert_column_based_set_to_arrow_table(columns, description):
     )
     return arrow_table, arrow_table.num_rows
 
+def convert_column_based_set_to_column_table(columns, description):
+    column_names = [c[0] for c in description]
+    column_table = [_covert_column_to_list(c) for c in columns]
+
+    return column_table, column_names
+
+
 
 def _convert_column_to_arrow_array(t_col):
     """
@@ -594,6 +656,15 @@ def _convert_column_to_arrow_array(t_col):
 
     raise OperationalError("Empty TColumn instance {}".format(t_col))
 
+def _covert_column_to_list(t_col):
+    supported_field_types = ("boolVal", "byteVal", "i16Val", "i32Val", "i64Val", "doubleVal", "stringVal", "binaryVal")
+
+    for field in supported_field_types:
+        wrapper = getattr(t_col, field)
+        if wrapper:
+            return _create_python_tuple(wrapper)
+
+    raise OperationalError("Empty TColumn instance {}".format(t_col))
 
 def _create_arrow_array(t_col_value_wrapper, arrow_type):
     result = t_col_value_wrapper.values
@@ -609,3 +680,18 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
             result[i] = None
 
     return pyarrow.array(result, type=arrow_type)
+
+def _create_python_tuple(t_col_value_wrapper):
+    result = t_col_value_wrapper.values
+    nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
+    assert isinstance(nulls, bytes)
+
+    # The number of bits in nulls can be both larger or smaller than the number of
+    # elements in result, so take the minimum of both to iterate over.
+    length = min(len(result), len(nulls) * 8)
+
+    for i in range(length):
+        if nulls[i >> 3] & BIT_MASKS[i & 0x7]:
+            result[i] = None
+
+    return tuple(result)
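
To make the null decoding concrete, a short worked example of the bitmask logic _create_python_tuple relies on. BIT_MASKS is assumed to be the low-bit-first table [1, 2, 4, ..., 128] already defined in utils.py, and mask_nulls is a hypothetical standalone copy of the loop above, not part of the commit.

BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]  # assumed definition

def mask_nulls(values, nulls):
    # Bit i of the nulls bitfield marks values[i] as NULL.
    out = list(values)
    length = min(len(out), len(nulls) * 8)
    for i in range(length):
        if nulls[i >> 3] & BIT_MASKS[i & 0x7]:
            out[i] = None
    return tuple(out)

# 0b00000101 sets bits 0 and 2, so rows 0 and 2 become None.
print(mask_nulls(["a", "b", "c", "d"], bytes([0b00000101])))
# (None, 'b', None, 'd')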
