-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathtest_arrow_queue.py
38 lines (33 loc) · 1.37 KB
/
test_arrow_queue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import unittest
import pytest
try:
import pyarrow as pa
except ImportError:
pa = None
from databricks.sql.utils import ArrowQueue
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
class ArrowQueueSuite(unittest.TestCase):
    """Tests for ``ArrowQueue`` row fetching (skipped when pyarrow is absent)."""

    @staticmethod
    def make_arrow_table(batch):
        """Build a pyarrow Table of uint32 columns from the row-major ``batch``.

        Columns are named ``col0``, ``col1``, ...; the column count is taken
        from the first row, and an empty ``batch`` gives a zero-column table.
        """
        if not batch:
            width = 0
        else:
            width = len(batch[0])

        names = [f"col{idx}" for idx in range(width)]
        schema = pa.schema({name: pa.uint32() for name in names})

        # Transpose rows into one list per column, keyed by column name.
        columns = {
            name: [row[idx] for row in batch] for idx, name in enumerate(names)
        }
        return pa.Table.from_pydict(columns, schema=schema)

    def test_fetchmany_respects_n_rows(self):
        """Successive ``next_n_rows`` calls never go past the row limit."""
        source = self.make_arrow_table(
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
        )
        # The second constructor argument (3) caps the rows the queue exposes:
        # the expectations below never see the fourth row (9, 10, 11).
        queue = ArrowQueue(source, 3)

        expected_first = self.make_arrow_table([[0, 1, 2], [3, 4, 5]])
        self.assertEqual(queue.next_n_rows(2), expected_first)

        # Asking for 2 more only yields the single remaining in-limit row.
        expected_second = self.make_arrow_table([[6, 7, 8]])
        self.assertEqual(queue.next_n_rows(2), expected_second)

    def test_fetch_remaining_rows_respects_n_rows(self):
        """``remaining_rows`` returns the rest of the rows up to the limit."""
        source = self.make_arrow_table(
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
        )
        queue = ArrowQueue(source, 3)

        expected_head = self.make_arrow_table([[0, 1, 2]])
        self.assertEqual(queue.next_n_rows(1), expected_head)

        # Only rows 2 and 3 remain within the limit of 3; row 4 is excluded.
        expected_tail = self.make_arrow_table([[3, 4, 5], [6, 7, 8]])
        self.assertEqual(queue.remaining_rows(), expected_tail)