
Commit 882e080

Changed the folder structure and tested the basic setup to work
1 parent ad2b014 commit 882e080

89 files changed: +259, -116 lines


.idea/databricks-sql-python.iml

Lines changed: 2 additions & 4 deletions
Some generated files are not rendered by default.

.idea/vcs.xml

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default.

check.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 # # Add the parent directory to sys.path
 # sys.path.append(target_folder_path)
 
-from src.databricks import sql
+from databricks import sql
 
 # from dotenv import load_dotenv
 
@@ -22,7 +22,7 @@
 # load_dotenv()
 
 host = "e2-dogfood.staging.cloud.databricks.com"
-http_path = "/sql/1.0/warehouses/dd43ee29fedd958d"
+http_path = "/sql/1.0/warehouses/58aa1b363649e722"
 
 access_token = ""
 connection = sql.connect(
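
For context, check.py exercises the connector against the new package layout roughly as follows; a minimal sketch, with hostname, warehouse path, and token as placeholders rather than real workspace values:

# Hedged smoke-test sketch: assumes the restructured connector is installed
# and importable as `databricks.sql`; the credentials below are placeholders.
from databricks import sql

connection = sql.connect(
    server_hostname="your-workspace.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/<warehouse-id>",
    access_token="<personal-access-token>",
)

with connection.cursor() as cursor:
    cursor.execute("SELECT 1")
    print(cursor.fetchall())

connection.close()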

README.md renamed to databricks_sql_connector_core/README.md

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789
 
 ## Contributing
 
-See [CONTRIBUTING.md](CONTRIBUTING.md)
+See [CONTRIBUTING.md](../CONTRIBUTING.md)
 
 ## License
 
-[Apache License 2.0](LICENSE)
+[Apache License 2.0](../LICENSE)

pyproject.toml renamed to databricks_sql_connector_core/pyproject.toml

Lines changed: 16 additions & 16 deletions
@@ -1,7 +1,7 @@
 [tool.poetry]
-name = "databricks-sql-connector"
-version = "3.3.0"
-description = "Databricks SQL Connector for Python"
+name = "databricks-sql-connector-core"
+version = "1.0.0"
+description = "Databricks SQL Connector core for Python"
 authors = ["Databricks <[email protected]>"]
 license = "Apache-2.0"
 readme = "README.md"
@@ -14,15 +14,15 @@ thrift = ">=0.16.0,<0.21.0"
 pandas = [
     { version = ">=1.2.5,<2.2.0", python = ">=3.8" }
 ]
-pyarrow = ">=14.0.1,<17"
+#pyarrow = ">=14.0.1,<17"
 
 lz4 = "^4.0.2"
 requests = "^2.18.1"
 oauthlib = "^3.1.0"
-numpy = [
-    { version = "^1.16.6", python = ">=3.8,<3.11" },
-    { version = "^1.23.4", python = ">=3.11" },
-]
+#numpy = [
+#    { version = "^1.16.6", python = ">=3.8,<3.11" },
+#    { version = "^1.23.4", python = ">=3.11" },
+#]
 sqlalchemy = { version = ">=2.0.21", optional = true }
 openpyxl = "^3.0.10"
 alembic = { version = "^1.0.11", optional = true }
@@ -56,11 +56,11 @@ exclude = ['ttypes\.py$', 'TCLIService\.py$']
 
 [tool.black]
 exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/'
-
-[tool.pytest.ini_options]
-markers = {"reviewed" = "Test case has been reviewed by Databricks"}
-minversion = "6.0"
-log_cli = "false"
-log_cli_level = "INFO"
-testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
-env_files = ["test.env"]
+#
+#[tool.pytest.ini_options]
+#markers = {"reviewed" = "Test case has been reviewed by Databricks"}
+#minversion = "6.0"
+#log_cli = "false"
+#log_cli_level = "INFO"
+#testpaths = ["tests", "src/databricks/sqlalchemy/test_local"]
+#env_files = ["test.env"]

src/databricks/sql/client.py renamed to databricks_sql_connector_core/src/databricks/sql/client.py

Lines changed: 26 additions & 19 deletions
@@ -1,7 +1,11 @@
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
 
 import pandas
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
 import requests
 import json
 import os
@@ -982,14 +986,14 @@ def fetchmany(self, size: int) -> List[Row]:
         else:
             raise Error("There is no active result set")
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchall_arrow()
         else:
             raise Error("There is no active result set")
 
-    def fetchmany_arrow(self, size) -> pyarrow.Table:
+    def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchmany_arrow(size)
@@ -1160,20 +1164,23 @@ def _convert_arrow_table(self, table):
         # Need to use nullable types, as otherwise type can change when there are missing values.
         # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
         # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        dtype_mapping = {
-            pyarrow.int8(): pandas.Int8Dtype(),
-            pyarrow.int16(): pandas.Int16Dtype(),
-            pyarrow.int32(): pandas.Int32Dtype(),
-            pyarrow.int64(): pandas.Int64Dtype(),
-            pyarrow.uint8(): pandas.UInt8Dtype(),
-            pyarrow.uint16(): pandas.UInt16Dtype(),
-            pyarrow.uint32(): pandas.UInt32Dtype(),
-            pyarrow.uint64(): pandas.UInt64Dtype(),
-            pyarrow.bool_(): pandas.BooleanDtype(),
-            pyarrow.float32(): pandas.Float32Dtype(),
-            pyarrow.float64(): pandas.Float64Dtype(),
-            pyarrow.string(): pandas.StringDtype(),
-        }
+        try:
+            dtype_mapping = {
+                pyarrow.int8(): pandas.Int8Dtype(),
+                pyarrow.int16(): pandas.Int16Dtype(),
+                pyarrow.int32(): pandas.Int32Dtype(),
+                pyarrow.int64(): pandas.Int64Dtype(),
+                pyarrow.uint8(): pandas.UInt8Dtype(),
+                pyarrow.uint16(): pandas.UInt16Dtype(),
+                pyarrow.uint32(): pandas.UInt32Dtype(),
+                pyarrow.uint64(): pandas.UInt64Dtype(),
+                pyarrow.bool_(): pandas.BooleanDtype(),
+                pyarrow.float32(): pandas.Float32Dtype(),
+                pyarrow.float64(): pandas.Float64Dtype(),
+                pyarrow.string(): pandas.StringDtype(),
+            }
+        except AttributeError:
+            print("pyarrow is not present")
 
         # Need to rename columns, as the to_pandas function cannot handle duplicate column names
         table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
@@ -1190,7 +1197,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index
 
-    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
 
@@ -1215,7 +1222,7 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:
 
         return results
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
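
The change applied throughout client.py is the optional-import pattern: bind pyarrow to None when it is missing and switch annotations to string literals so they are never evaluated at import time. A stripped-down sketch of the idea (the function and data below are illustrative, not the connector's API):

# Optional-dependency sketch: the module stays importable without pyarrow,
# and the string annotation "pyarrow.Table" is not resolved at import time.
try:
    import pyarrow
except ImportError:
    pyarrow = None


def rows_to_arrow(rows: list) -> "pyarrow.Table":
    # Guard before touching pyarrow attributes, since pyarrow may be None.
    if pyarrow is None:
        raise ImportError("pyarrow is required for Arrow results")
    return pyarrow.Table.from_pylist(rows)


# Example: rows_to_arrow([{"id": 1}, {"id": 2}]) returns a 2-row Arrow table
# when pyarrow is installed, and raises ImportError otherwise.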

src/databricks/sql/thrift_backend.py renamed to databricks_sql_connector_core/src/databricks/sql/thrift_backend.py

Lines changed: 20 additions & 8 deletions
@@ -8,7 +8,10 @@
 from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
 from typing import List, Union
 
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.TSocket
@@ -37,7 +40,7 @@
     convert_column_based_set_to_arrow_table,
 )
 
-from src.databricks.sql.thrift_api.TCLIService.ttypes import TDBSqlResultFormat
+# from databricks.sql import TDBSqlResultFormat
 
 logger = logging.getLogger(__name__)
 
@@ -654,6 +657,10 @@ def _get_metadata_resp(self, op_handle):
 
     @staticmethod
     def _hive_schema_to_arrow_schema(t_table_schema):
+
+        if pyarrow is None:
+            raise ImportError("pyarrow is required to convert Hive schema to Arrow schema")
+
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -760,12 +767,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
-        schema_bytes = (
-            t_result_set_metadata_resp.arrowSchema
-            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
-            .serialize()
-            .to_pybytes()
-        )
+
+        if pyarrow:
+            schema_bytes = (
+                t_result_set_metadata_resp.arrowSchema
+                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+                .serialize()
+                .to_pybytes()
+            )
+        else:
+            schema_bytes = None
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
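
The schema_bytes branch above relies on Arrow IPC schema serialization; a small standalone round-trip of that mechanism (assuming pyarrow is installed, with an illustrative two-column schema):

import pyarrow

# Serialize a schema to bytes, as the backend does when the server does not
# return a pre-serialized arrowSchema.
schema = pyarrow.schema([("id", pyarrow.int64()), ("name", pyarrow.string())])
schema_bytes = schema.serialize().to_pybytes()

# The bytes can later be read back into an equivalent schema object.
restored = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
assert restored.equals(schema)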

src/databricks/sql/utils.py renamed to databricks_sql_connector_core/src/databricks/sql/utils.py

Lines changed: 13 additions & 10 deletions
@@ -14,7 +14,10 @@
 from ssl import SSLContext
 
 import lz4.frame
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 
 from databricks.sql import OperationalError, exc
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
@@ -155,7 +158,7 @@ def remaining_rows(self):
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: pyarrow.Table,
+        arrow_table: "pyarrow.Table",
         n_valid_rows: int,
         start_row_index: int = 0,
     ):
@@ -170,7 +173,7 @@ def __init__(
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> 'pyarrow.Table':
         """Get upto the next n rows of the Arrow dataframe"""
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
@@ -179,7 +182,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         self.cur_row_index += slice.num_rows
         return slice
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> 'pyarrow.Table':
         slice = self.arrow_table.slice(
             self.cur_row_index, self.n_valid_rows - self.cur_row_index
         )
@@ -239,7 +242,7 @@ def __init__(
         self.table = self._create_next_table()
         self.table_row_index = 0
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> 'pyarrow.Table':
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.
 
@@ -271,7 +274,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
         return results
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> 'pyarrow.Table':
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.
 
@@ -292,7 +295,7 @@ def remaining_rows(self) -> pyarrow.Table:
         self.table_row_index = 0
         return results
 
-    def _create_next_table(self) -> Union[pyarrow.Table, None]:
+    def _create_next_table(self) -> Union['pyarrow.Table', None]:
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -331,7 +334,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:
 
         return arrow_table
 
-    def _create_empty_table(self) -> pyarrow.Table:
+    def _create_empty_table(self) -> 'pyarrow.Table':
         # Create a 0-row table with just the schema bytes
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
 
@@ -570,7 +573,7 @@ def transform_paramstyle(
     return output
 
 
-def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
+def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> 'pyarrow.Table':
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)
 
@@ -597,7 +600,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
     return arrow_table, n_rows
 
 
-def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
+def convert_decimals_in_arrow_table(table, description) -> 'pyarrow.Table':
     for i, col in enumerate(table.itercolumns()):
         if description[i][1] == "decimal":
             decimal_col = col.to_pandas().apply(
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+try:
+    from databricks_sqlalchemy import *
+except:
+    import warnings
+    warnings.warn("Install databricks-sqlalchemy plugin before using this")
