
Commit 950de1e

Improvements to fetch results databricks#1
1 parent f9b7f43 commit 950de1e

1 file changed: src/databricks/sql/utils.py (+20, -9)
@@ -611,21 +611,32 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
 
 
 def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
+    new_columns = []
+    new_fields = []
+
     for i, col in enumerate(table.itercolumns()):
+        field = table.field(i)
+
         if description[i][1] == "decimal":
-            decimal_col = col.to_pandas().apply(
-                lambda v: v if v is None else Decimal(v)
-            )
             precision, scale = description[i][4], description[i][5]
             assert scale is not None
             assert precision is not None
-            # Spark limits decimal to a maximum scale of 38,
-            # so 128 is guaranteed to be big enough
+            # create the target decimal type
             dtype = pyarrow.decimal128(precision, scale)
-            col_data = pyarrow.array(decimal_col, type=dtype)
-            field = table.field(i).with_type(dtype)
-            table = table.set_column(i, field, col_data)
-    return table
+
+            # convert the column directly using PyArrow's cast operation
+            new_col = col.cast(dtype)
+            new_field = field.with_type(dtype)
+
+            new_columns.append(new_col)
+            new_fields.append(new_field)
+        else:
+            new_columns.append(col)
+            new_fields.append(field)
+
+    new_schema = pyarrow.schema(new_fields)
+
+    return pyarrow.Table.from_arrays(new_columns, schema=new_schema)
 
 
 def convert_to_assigned_datatypes_in_column_table(column_table, description):
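
The change drops the pandas round-trip (to_pandas() plus a Python-level Decimal() per value, then a set_column() per decimal column) in favor of PyArrow's native cast kernel, collecting every column and field in one pass and assembling a fresh table at the end. Below is a minimal sketch of calling the updated function, assuming a PyArrow version whose compute kernels support string-to-decimal casts; the description tuples only mimic the cursor metadata layout the function reads (index 1 = type name, index 4 = precision, index 5 = scale), and the column values are illustrative.

import pyarrow

from databricks.sql.utils import convert_decimals_in_arrow_table

# Decimal results arrive as strings; non-decimal columns should pass through unchanged.
table = pyarrow.table({
    "id": [1, 2, 3],
    "price": ["1.10", "2.50", "99.99"],
})
description = [
    ("id", "int", None, None, None, None, True),
    ("price", "decimal", None, None, 10, 2, True),
]

converted = convert_decimals_in_arrow_table(table, description)
print(converted.schema)
# id: int64
# price: decimal128(10, 2)

Casting with Arrow's kernels keeps the conversion in native code instead of materializing one Python Decimal per value, and building the table once from the collected columns avoids the repeated copies that per-column set_column() calls incurred.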
