
Commit 4099939

[PECO-1803] Splitting the PySql connector into the core and the non core part (#417)

* Implemented ColumnQueue to test fetchall without pyarrow; removed token
* Corrected the order of fields in Row
* Changed the folder structure and tested that the basic setup works
* Refactored the code to make the connector work
* Basic setup of connector, core and sqlalchemy is working
* Basic integration of core, connect and sqlalchemy is working
* Set up a working dynamic switch from ColumnQueue to ArrowQueue
* Refactored the test code and moved it to the respective folders
* Added the unit test for column_queue; fixed __version__
* Added venv_main to .gitignore
* Added code for merging columnar tables
* Fixed the retry_close session test issue with logging
* Fixed the databricks_sqlalchemy tests and introduced pytest.ini for the SQLAlchemy testing
* Added the pyarrow_test mark to pytest
* Fixed databricks.sqlalchemy to databricks_sqlalchemy imports
* Added poetry.lock
* Added the dist folder
* Changed pyproject.toml
* Minor fix
* Added the pyarrow skip tag on unit tests and verified they work
* Fixed the Decimal and timestamp conversion issue in the non-Arrow pipeline
* Removed files that were not required and reformatted
* Fixed the test_retry error
* Changed the folder structure to src/databricks
* Moved the columnar non-Arrow flow to another PR
* Moved the README to the root
* Removed the ColumnQueue instance
* Removed the databricks_sqlalchemy dependency from core
* Changed the pysql_supports_arrow predicate and updated pyproject.toml
* Ran the black formatter with the original version
* Removed the extra .py from the __init__.py file names
* Undo formatting check
* Check
* BIG UPDATE
* Refactored code
* Fixed versioning
* Minor refactoring
1 parent 9cb1ea3 commit 4099939

File tree

89 files changed: +162 −4232 lines

+23

@@ -0,0 +1,23 @@
+[tool.poetry]
+name = "databricks-sql-connector"
+version = "3.5.0"
+description = "Databricks SQL Connector for Python"
+authors = ["Databricks <[email protected]>"]
+license = "Apache-2.0"
+
+
+[tool.poetry.dependencies]
+databricks_sql_connector_core = { version = ">=1.0.0", extras=["all"]}
+databricks_sqlalchemy = { version = ">=1.0.0", optional = true }
+
+[tool.poetry.extras]
+databricks_sqlalchemy = ["databricks_sqlalchemy"]
+
+[tool.poetry.urls]
+"Homepage" = "https://github.com/databricks/databricks-sql-python"
+"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
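The new top-level pyproject.toml above turns databricks-sql-connector into a thin meta-package: it depends on databricks_sql_connector_core and exposes databricks_sqlalchemy as an optional extra. As a hedged illustration (the helper name below is hypothetical and not part of the commit), downstream code can probe for the optional plugin before relying on it:

import importlib.util

def sqlalchemy_plugin_available() -> bool:
    """Return True when the optional databricks_sqlalchemy distribution is importable."""
    return importlib.util.find_spec("databricks_sqlalchemy") is not None

if __name__ == "__main__":
    print("databricks_sqlalchemy installed:", sqlalchemy_plugin_available())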
File renamed without changes.
@@ -1,36 +1,26 @@
 [tool.poetry]
-name = "databricks-sql-connector"
-version = "3.3.0"
-description = "Databricks SQL Connector for Python"
+name = "databricks-sql-connector-core"
+version = "1.0.0"
+description = "Databricks SQL Connector core for Python"
 authors = ["Databricks <[email protected]>"]
-license = "Apache-2.0"
-readme = "README.md"
 packages = [{ include = "databricks", from = "src" }]
-include = ["CHANGELOG.md"]

 [tool.poetry.dependencies]
 python = "^3.8.0"
 thrift = ">=0.16.0,<0.21.0"
 pandas = [
     { version = ">=1.2.5,<2.3.0", python = ">=3.8" }
 ]
-pyarrow = ">=14.0.1,<17"
-
 lz4 = "^4.0.2"
 requests = "^2.18.1"
 oauthlib = "^3.1.0"
-numpy = [
-    { version = "^1.16.6", python = ">=3.8,<3.11" },
-    { version = "^1.23.4", python = ">=3.11" },
-]
-sqlalchemy = { version = ">=2.0.21", optional = true }
 openpyxl = "^3.0.10"
 alembic = { version = "^1.0.11", optional = true }
 urllib3 = ">=1.26"
+pyarrow = {version = ">=14.0.1,<17", optional = true}

 [tool.poetry.extras]
-sqlalchemy = ["sqlalchemy"]
-alembic = ["sqlalchemy", "alembic"]
+pyarrow = ["pyarrow"]

 [tool.poetry.dev-dependencies]
 pytest = "^7.1.2"
@@ -43,8 +33,6 @@ pytest-dotenv = "^0.5.2"
 "Homepage" = "https://github.com/databricks/databricks-sql-python"
 "Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues"

-[tool.poetry.plugins."sqlalchemy.dialects"]
-"databricks" = "databricks.sqlalchemy:DatabricksDialect"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -62,5 +50,5 @@ markers = {"reviewed" = "Test case has been reviewed by Databricks"}
 minversion = "6.0"
 log_cli = "false"
 log_cli_level = "INFO"
-testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
+testpaths = ["tests", "databricks_sql_connector_core/tests"]
 env_files = ["test.env"]
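With this change pyarrow moves from a hard dependency of the core package to an opt-in extra (pyarrow = ["pyarrow"]). A small sketch, assuming nothing beyond the standard library, of how an application could report whether the extra was installed alongside databricks-sql-connector-core:

from importlib.metadata import version, PackageNotFoundError

try:
    print("pyarrow", version("pyarrow"), "is installed; Arrow result sets are available")
except PackageNotFoundError:
    print("pyarrow is not installed; Arrow-based fetch methods will not be usable")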

src/databricks/sql/client.py renamed to databricks_sql_connector_core/src/databricks/sql/client.py

+8 −5

@@ -1,7 +1,6 @@
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

 import pandas
-import pyarrow
 import requests
 import json
 import os
@@ -43,6 +42,10 @@
     TSparkParameter,
 )

+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None

 logger = logging.getLogger(__name__)

@@ -977,14 +980,14 @@ def fetchmany(self, size: int) -> List[Row]:
         else:
             raise Error("There is no active result set")

-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchall_arrow()
         else:
             raise Error("There is no active result set")

-    def fetchmany_arrow(self, size) -> pyarrow.Table:
+    def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchmany_arrow(size)
@@ -1171,7 +1174,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index

-    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.

@@ -1196,7 +1199,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:

         return results

-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
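The pattern above, a guarded import that falls back to pyarrow = None plus string annotations such as "pyarrow.Table", keeps client.py importable when the extra is absent, because quoted annotations are never evaluated at import time. A minimal standalone sketch of the same idea (the function below is illustrative, not the connector's code):

try:
    import pyarrow
except ImportError:  # pyarrow is now an optional extra
    pyarrow = None

def rows_to_arrow(rows) -> "pyarrow.Table":
    # The quoted return annotation lets this definition load even without pyarrow;
    # the dependency is only required when the function is actually called.
    if pyarrow is None:
        raise ImportError("pyarrow is required for Arrow result sets")
    return pyarrow.Table.from_pylist(rows)

if __name__ == "__main__" and pyarrow is not None:
    print(rows_to_arrow([{"a": 1}, {"a": 2}]).num_rows)  # 2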

src/databricks/sql/thrift_backend.py renamed to databricks_sql_connector_core/src/databricks/sql/thrift_backend.py

+12 −2

@@ -8,7 +8,6 @@
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union

-import pyarrow
 import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.TSocket
@@ -37,6 +36,11 @@
     convert_column_based_set_to_arrow_table,
 )

+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
 logger = logging.getLogger(__name__)

 unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -652,6 +656,12 @@ def _get_metadata_resp(self, op_handle):

     @staticmethod
     def _hive_schema_to_arrow_schema(t_table_schema):
+
+        if pyarrow is None:
+            raise ImportError(
+                "pyarrow is required to convert Hive schema to Arrow schema"
+            )
+
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -858,7 +868,7 @@ def execute_command(
             getDirectResults=ttypes.TSparkGetDirectResults(
                 maxRows=max_rows, maxBytes=max_bytes
             ),
-            canReadArrowResult=True,
+            canReadArrowResult=True if pyarrow else False,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
             confOverlay={
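Two defensive changes appear here: _hive_schema_to_arrow_schema now fails fast with an ImportError when pyarrow is missing, and canReadArrowResult is only advertised to the server when pyarrow is importable. A hedged sketch of how such a capability flag can be derived from the optional import (names below are assumptions for illustration, not the connector's API):

try:
    import pyarrow
except ImportError:
    pyarrow = None

def build_result_capabilities(lz4_compression: bool, use_cloud_fetch: bool) -> dict:
    """Simplified mirror of the request flags set in execute_command above."""
    return {
        "canReadArrowResult": pyarrow is not None,  # same effect as `True if pyarrow else False`
        "canDecompressLZ4Result": lz4_compression,
        "canDownloadResult": use_cloud_fetch,
    }

print(build_result_capabilities(lz4_compression=True, use_cloud_fetch=False))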

src/databricks/sql/utils.py renamed to databricks_sql_connector_core/src/databricks/sql/utils.py

+16 −12

@@ -12,7 +12,6 @@
 from ssl import SSLContext

 import lz4.frame
-import pyarrow

 from databricks.sql import OperationalError, exc
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -28,16 +27,21 @@

 import logging

+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
 logger = logging.getLogger(__name__)


 class ResultSetQueue(ABC):
     @abstractmethod
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int):
         pass

     @abstractmethod
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self):
         pass


@@ -100,7 +104,7 @@ def build_queue(
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: pyarrow.Table,
+        arrow_table: "pyarrow.Table",
         n_valid_rows: int,
         start_row_index: int = 0,
     ):
@@ -115,7 +119,7 @@ def __init__(
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows

-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get upto the next n rows of the Arrow dataframe"""
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
@@ -124,7 +128,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         self.cur_row_index += slice.num_rows
         return slice

-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         slice = self.arrow_table.slice(
             self.cur_row_index, self.n_valid_rows - self.cur_row_index
         )
@@ -184,7 +188,7 @@ def __init__(
         self.table = self._create_next_table()
         self.table_row_index = 0

-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.

@@ -216,7 +220,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
         return results

-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.

@@ -237,7 +241,7 @@ def remaining_rows(self) -> pyarrow.Table:
         self.table_row_index = 0
         return results

-    def _create_next_table(self) -> Union[pyarrow.Table, None]:
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -276,7 +280,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:

         return arrow_table

-    def _create_empty_table(self) -> pyarrow.Table:
+    def _create_empty_table(self) -> "pyarrow.Table":
         # Create a 0-row table with just the schema bytes
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

@@ -515,7 +519,7 @@ def transform_paramstyle(
     return output


-def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
+def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
    arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
    return convert_decimals_in_arrow_table(arrow_table, description)

@@ -542,7 +546,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
    return arrow_table, n_rows


-def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
+def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
    for i, col in enumerate(table.itercolumns()):
        if description[i][1] == "decimal":
            decimal_col = col.to_pandas().apply(
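The abstract ResultSetQueue methods drop their pyarrow return annotations so the base class can be defined without the extra, while ArrowQueue keeps quoted annotations and relies on pyarrow.Table.slice(offset, length) for paging. A short usage sketch of that slicing behaviour, assuming pyarrow is installed (illustrative only, not the connector's code):

import pyarrow

table = pyarrow.table({"x": [1, 2, 3, 4, 5]})
cur_row_index, n_valid_rows = 0, table.num_rows

# Mirror ArrowQueue.next_n_rows: slice(offset, length) never over-reads the table.
length = min(2, n_valid_rows - cur_row_index)
first_batch = table.slice(cur_row_index, length)
cur_row_index += first_batch.num_rows

remaining = table.slice(cur_row_index, n_valid_rows - cur_row_index)
print(first_batch.to_pydict())   # {'x': [1, 2]}
print(remaining.to_pydict())     # {'x': [3, 4, 5]}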
@@ -0,0 +1,6 @@
+try:
+    from databricks_sqlalchemy import *
+except:
+    import warnings
+
+    warnings.warn("Install databricks-sqlalchemy plugin before using this")
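This new __init__.py is a compatibility shim: it re-exports everything from the separately packaged databricks_sqlalchemy distribution and warns when the plugin is not installed. A minimal sketch of the same idea, narrowed to ImportError (the committed shim uses a bare except:, so catching ImportError here is an assumption, not the author's code):

try:
    from databricks_sqlalchemy import *  # provided by the optional databricks-sqlalchemy package
except ImportError:
    import warnings

    warnings.warn("Install databricks-sqlalchemy plugin before using this")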
File renamed without changes.

tests/e2e/common/decimal_tests.py renamed to databricks_sql_connector_core/tests/e2e/common/decimal_tests.py

+22 −8

@@ -1,11 +1,20 @@
 from decimal import Decimal

-import pyarrow
 import pytest

+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None

-class DecimalTestsMixin:
-    decimal_and_expected_results = [
+from tests.e2e.common.predicates import pysql_supports_arrow
+
+def decimal_and_expected_results():
+
+    if pyarrow is None:
+        return []
+
+    return [
         ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
         ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
         ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
@@ -17,7 +26,12 @@ class DecimalTestsMixin:
         ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
     ]

-    multi_decimals_and_expected_results = [
+def multi_decimals_and_expected_results():
+
+    if pyarrow is None:
+        return []
+
+    return [
         (
             ["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
             [Decimal("1.00"), Decimal("100.001"), None],
@@ -30,7 +44,9 @@ class DecimalTestsMixin:
         ),
     ]

-    @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
+@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
+class DecimalTestsMixin:
+    @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
     def test_decimals(self, decimal, expected_value, expected_type):
         with self.cursor({}) as cursor:
             query = "SELECT CAST ({})".format(decimal)
@@ -39,9 +55,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
             assert table.field(0).type == expected_type
             assert table.to_pydict().popitem()[1][0] == expected_value

-    @pytest.mark.parametrize(
-        "decimals, expected_values, expected_type", multi_decimals_and_expected_results
-    )
+    @pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
     def test_multi_decimals(self, decimals, expected_values, expected_type):
         with self.cursor({}) as cursor:
             union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])
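Because @pytest.mark.parametrize arguments are evaluated at collection time, the expected-value tables become functions that return an empty list when pyarrow is missing, and the whole mixin is additionally skipped via skipif. A condensed sketch of that pattern (the test class and cases below are illustrative, not part of the suite):

import pytest

try:
    import pyarrow
except ImportError:
    pyarrow = None

def decimal_cases():
    # Returning [] keeps collection from touching pyarrow when it is absent.
    if pyarrow is None:
        return []
    return [("100.001", pyarrow.decimal128(6, 3))]

@pytest.mark.skipif(pyarrow is None, reason="pyarrow is not installed")
class TestDecimalSketch:
    @pytest.mark.parametrize("literal, expected_type", decimal_cases())
    def test_case_types(self, literal, expected_type):
        assert isinstance(expected_type, pyarrow.Decimal128Type)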

tests/e2e/common/large_queries_mixin.py renamed to databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py

+5
@@ -1,6 +1,10 @@
 import logging
 import math
 import time
+from unittest import skipUnless
+
+import pytest
+from tests.e2e.common.predicates import pysql_supports_arrow

 log = logging.getLogger(__name__)

@@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
             + "assuming 10K fetch size."
         )

+    @pytest.mark.skipif(not pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported")
     def test_query_with_large_wide_result_set(self):
         resultSize = 300 * 1000 * 1000  # 300 MB
         width = 8192  # B

tests/e2e/common/predicates.py renamed to databricks_sql_connector_core/tests/e2e/common/predicates.py

+7 −3

@@ -8,9 +8,13 @@


 def pysql_supports_arrow():
-    """Import databricks.sql and test whether Cursor has fetchall_arrow."""
-    from databricks.sql.client import Cursor
-    return hasattr(Cursor, 'fetchall_arrow')
+    """Checks if the pyarrow library is installed or not"""
+    try:
+        import pyarrow
+
+        return True
+    except ImportError:
+        return False


 def pysql_has_version(compare, version):
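pysql_supports_arrow() now simply reports whether pyarrow can be imported, rather than inspecting Cursor for fetchall_arrow, which makes it usable as a collection-time skip condition across the test suite. A hedged usage sketch (the test below is illustrative only; note that pytest expects the skip reason as a keyword argument):

import pytest

from tests.e2e.common.predicates import pysql_supports_arrow

@pytest.mark.skipif(not pysql_supports_arrow(), reason="pyarrow is not installed")
def test_arrow_roundtrip_placeholder():
    import pyarrow

    assert pyarrow.table({"x": [1]}).num_rows == 1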
