Skip to content

Commit a58c97f

Browse files
committed
Refactor
1 parent 2470581 commit a58c97f

10 files changed

+70
-24
lines changed

databricks_sql_connector_core/tests/e2e/common/decimal_tests.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from decimal import Decimal
22

3-
import pyarrow
43
import pytest
54

5+
try:
6+
import pyarrow
7+
except ImportError:
8+
pyarrow = None
69

7-
class DecimalTestsMixin:
8-
decimal_and_expected_results = [
10+
from tests.e2e.common.predicates import pysql_supports_arrow
11+
12+
def decimal_and_expected_results():
13+
return [
914
("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)),
1015
("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)),
1116
("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)),
@@ -17,7 +22,8 @@ class DecimalTestsMixin:
1722
("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)),
1823
]
1924

20-
multi_decimals_and_expected_results = [
25+
def multi_decimals_and_expected_results():
26+
return [
2127
(
2228
["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"],
2329
[Decimal("1.00"), Decimal("100.001"), None],
@@ -30,7 +36,9 @@ class DecimalTestsMixin:
3036
),
3137
]
3238

33-
@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results)
39+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
40+
class DecimalTestsMixin:
41+
@pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results())
3442
def test_decimals(self, decimal, expected_value, expected_type):
3543
with self.cursor({}) as cursor:
3644
query = "SELECT CAST ({})".format(decimal)
@@ -39,9 +47,7 @@ def test_decimals(self, decimal, expected_value, expected_type):
3947
assert table.field(0).type == expected_type
4048
assert table.to_pydict().popitem()[1][0] == expected_value
4149

42-
@pytest.mark.parametrize(
43-
"decimals, expected_values, expected_type", multi_decimals_and_expected_results
44-
)
50+
@pytest.mark.parametrize("decimals, expected_values, expected_type", multi_decimals_and_expected_results())
4551
def test_multi_decimals(self, decimals, expected_values, expected_type):
4652
with self.cursor({}) as cursor:
4753
union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals])

databricks_sql_connector_core/tests/e2e/common/large_queries_mixin.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import logging
22
import math
33
import time
4+
from unittest import skipUnless
5+
6+
import pytest
7+
from tests.e2e.common.predicates import pysql_supports_arrow
48

59
log = logging.getLogger(__name__)
610

@@ -40,6 +44,7 @@ def fetch_rows(self, cursor, row_count, fetchmany_size):
4044
+ "assuming 10K fetch size."
4145
)
4246

47+
@skipUnless(pysql_supports_arrow(), "Without pyarrow lz4 compression is not supported")
4348
def test_query_with_large_wide_result_set(self):
4449
resultSize = 300 * 1000 * 1000 # 300 MB
4550
width = 8192 # B

databricks_sql_connector_core/tests/e2e/common/predicates.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88

99

1010
def pysql_supports_arrow():
11-
"""Import databricks.sql and test whether Cursor has fetchall_arrow."""
12-
from databricks.sql.client import Cursor
13-
return hasattr(Cursor, 'fetchall_arrow')
11+
"""Checks if the pyarrow library is installed or not"""
12+
try:
13+
import pyarrow
14+
15+
return True
16+
except ImportError:
17+
return False
1418

1519

1620
def pysql_has_version(compare, version):

databricks_sql_connector_core/tests/e2e/test_complex_types.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from numpy import ndarray
33

44
from tests.e2e.test_driver import PySQLPytestTestCase
5+
from tests.e2e.common.predicates import pysql_supports_arrow
56

6-
7+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
78
class TestComplexTypes(PySQLPytestTestCase):
89
@pytest.fixture(scope="class")
910
def table_fixture(self, connection_details):

databricks_sql_connector_core/tests/e2e/test_driver.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from uuid import uuid4
1313

1414
import numpy as np
15-
import pyarrow
1615
import pytz
1716
import thrift
1817
import pytest
@@ -35,6 +34,7 @@
3534
pysql_supports_arrow,
3635
compare_dbr_versions,
3736
is_thrift_v5_plus,
37+
pysql_supports_arrow
3838
)
3939
from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
4040
from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
@@ -48,6 +48,11 @@
4848

4949
from databricks.sql.exc import SessionAlreadyClosedError
5050

51+
try:
52+
import pyarrow
53+
except ImportError:
54+
pyarrow = None
55+
5156
log = logging.getLogger(__name__)
5257

5358
unsafe_logger = logging.getLogger("databricks.sql.unsafe")
@@ -591,7 +596,7 @@ def test_ssp_passthrough(self):
591596
cursor.execute("SET ansi_mode")
592597
assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)]
593598

594-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
599+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
595600
def test_timestamps_arrow(self):
596601
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
597602
for timestamp, expected in self.timestamp_and_expected_results:
@@ -611,7 +616,7 @@ def test_timestamps_arrow(self):
611616
aware_timestamp and aware_timestamp.timestamp() * 1000000
612617
), "timestamp {} did not match {}".format(timestamp, expected)
613618

614-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
619+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
615620
def test_multi_timestamps_arrow(self):
616621
with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor:
617622
query, expected = self.multi_query()
@@ -627,7 +632,7 @@ def test_multi_timestamps_arrow(self):
627632
]
628633
assert result == expected
629634

630-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
635+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
631636
def test_timezone_with_timestamp(self):
632637
if self.should_add_timezone():
633638
with self.cursor() as cursor:
@@ -646,7 +651,7 @@ def test_timezone_with_timestamp(self):
646651
assert arrow_result_table.field(0).type == ts_type
647652
assert arrow_result_value == expected.timestamp() * 1000000
648653

649-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
654+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
650655
def test_can_flip_compression(self):
651656
with self.cursor() as cursor:
652657
cursor.execute("SELECT array(1,2,3,4)")
@@ -663,7 +668,7 @@ def test_can_flip_compression(self):
663668
def _should_have_native_complex_types(self):
664669
return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments)
665670

666-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
671+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
667672
def test_arrays_are_not_returned_as_strings_arrow(self):
668673
if self._should_have_native_complex_types():
669674
with self.cursor() as cursor:
@@ -674,7 +679,7 @@ def test_arrays_are_not_returned_as_strings_arrow(self):
674679
assert pyarrow.types.is_list(list_type)
675680
assert pyarrow.types.is_integer(list_type.value_type)
676681

677-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
682+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
678683
def test_structs_are_not_returned_as_strings_arrow(self):
679684
if self._should_have_native_complex_types():
680685
with self.cursor() as cursor:
@@ -684,7 +689,7 @@ def test_structs_are_not_returned_as_strings_arrow(self):
684689
struct_type = arrow_df.field(0).type
685690
assert pyarrow.types.is_struct(struct_type)
686691

687-
@skipUnless(pysql_supports_arrow(), "arrow test needs arrow support")
692+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
688693
def test_decimal_not_returned_as_strings_arrow(self):
689694
if self._should_have_native_complex_types():
690695
with self.cursor() as cursor:

databricks_sql_connector_core/tests/e2e/test_parameterized_queries.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
VoidParameter,
2929
)
3030
from tests.e2e.test_driver import PySQLPytestTestCase
31+
from tests.e2e.common.predicates import pysql_supports_arrow
3132

3233

3334
class ParamStyle(Enum):
@@ -284,6 +285,8 @@ def test_primitive_single(
284285
(PrimitiveExtra.TINYINT, TinyIntParameter),
285286
],
286287
)
288+
289+
@pytest.mark.skipif(not pysql_supports_arrow(),reason="Without pyarrow TIMESTAMP_NTZ datatype cannot be inferred",)
287290
def test_dbsqlparameter_single(
288291
self,
289292
primitive: Primitive,

databricks_sql_connector_core/tests/unit/test_arrow_queue.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import unittest
22

3-
import pyarrow as pa
3+
import pytest
44

55
from databricks.sql.utils import ArrowQueue
66

7+
try:
8+
import pyarrow as pa
9+
except ImportError:
10+
pa = None
711

12+
from tests.e2e.common.predicates import pysql_supports_arrow
13+
14+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
815
class ArrowQueueSuite(unittest.TestCase):
916
@staticmethod
1017
def make_arrow_table(batch):

databricks_sql_connector_core/tests/unit/test_cloud_fetch_queue.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1-
import pyarrow
1+
import pytest
22
import unittest
33
from unittest.mock import MagicMock, patch
44
from ssl import create_default_context
55

66
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
77
import databricks.sql.utils as utils
8+
from tests.e2e.common.predicates import pysql_supports_arrow
89

10+
try:
11+
import pyarrow
12+
except ImportError:
13+
pyarrow = None
14+
15+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
916
class CloudFetchQueueSuite(unittest.TestCase):
1017

1118
def create_result_link(

databricks_sql_connector_core/tests/unit/test_download_manager.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import unittest
22
from unittest.mock import patch, MagicMock
3+
import pytest
34

45
from ssl import create_default_context
56

67
import databricks.sql.cloudfetch.download_manager as download_manager
78
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
89

10+
from tests.e2e.common.predicates import pysql_supports_arrow
911

12+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
1013
class DownloadManagerTests(unittest.TestCase):
1114
"""
1215
Unit tests for checking download manager logic.

databricks_sql_connector_core/tests/unit/test_fetches.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import unittest
22
from unittest.mock import Mock
3-
4-
import pyarrow as pa
3+
import pytest
54

65
import databricks.sql.client as client
76
from databricks.sql.utils import ExecuteResponse, ArrowQueue
7+
from tests.e2e.common.predicates import pysql_supports_arrow
88

9+
try:
10+
import pyarrow as pa
11+
except ImportError:
12+
pa = None
913

14+
@pytest.mark.skipif(not pysql_supports_arrow(), reason="Skipping because pyarrow is not installed")
1015
class FetchTests(unittest.TestCase):
1116
"""
1217
Unit tests for checking the fetch logic.

0 commit comments

Comments
 (0)