-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
/
Copy pathtest_arrow_interface.py
93 lines (65 loc) · 2.83 KB
/
test_arrow_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import ctypes
import pytest
from pandas._config import using_string_dtype
import pandas.util._test_decorators as td
import pandas as pd
import pandas._testing as tm
pa = pytest.importorskip("pyarrow")
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@td.skip_if_no("pyarrow", min_version="14.0")
def test_dataframe_arrow_interface():
    """Check that a DataFrame exports itself through ``__arrow_c_stream__``.

    Verifies that the returned object is a PyCapsule named
    ``arrow_array_stream`` and that ``pa.table`` can consume the DataFrame
    both without and with an explicitly requested schema.
    """
    df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

    # The protocol method must hand back a capsule carrying the stream name.
    stream_capsule = df.__arrow_c_stream__()
    capsule_is_valid = ctypes.pythonapi.PyCapsule_IsValid(
        ctypes.py_object(stream_capsule), b"arrow_array_stream"
    )
    assert capsule_is_valid == 1

    # Import with the schema pyarrow infers on its own.
    imported = pa.table(df)
    reference = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    assert imported.equals(reference)

    # Import again while requesting a narrower integer type for column "a".
    requested = pa.schema([("a", pa.int8()), ("b", pa.string())])
    imported = pa.table(df, schema=requested)
    reference = reference.cast(requested)
    assert imported.equals(reference)
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@td.skip_if_no("pyarrow", min_version="15.0")
def test_dataframe_to_arrow():
    """Check that ``RecordBatchReader.from_stream`` can read a DataFrame.

    The DataFrame is consumed once with the default schema and once with an
    explicitly requested schema; both results are compared to a reference
    pyarrow table.
    """
    df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})

    # Default schema.
    reader = pa.RecordBatchReader.from_stream(df)
    streamed = reader.read_all()
    reference = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    assert streamed.equals(reference)

    # Requested schema: column "a" narrowed to int8.
    requested = pa.schema([("a", pa.int8()), ("b", pa.string())])
    reader = pa.RecordBatchReader.from_stream(df, schema=requested)
    streamed = reader.read_all()
    reference = reference.cast(requested)
    assert streamed.equals(reference)
class ArrowArrayWrapper:
    """Minimal holder that exposes only ``__arrow_c_array__``.

    The wrapped object is stored on ``self.array`` and every call to
    ``__arrow_c_array__`` is forwarded to it unchanged.
    """

    def __init__(self, batch):
        self.array = batch

    def __arrow_c_array__(self, requested_schema=None):
        """Forward the export request to the wrapped object."""
        export = self.array.__arrow_c_array__
        return export(requested_schema)
class ArrowStreamWrapper:
    """Minimal holder that exposes only ``__arrow_c_stream__``.

    The wrapped object is stored on ``self.stream`` and every call to
    ``__arrow_c_stream__`` is forwarded to it unchanged.
    """

    def __init__(self, table):
        self.stream = table

    def __arrow_c_stream__(self, requested_schema=None):
        """Forward the export request to the wrapped object."""
        export = self.stream.__arrow_c_stream__
        return export(requested_schema)
@td.skip_if_no("pyarrow", min_version="14.0")
def test_dataframe_from_arrow():
    """Check ``DataFrame.from_arrow`` against stream and array exporters.

    Covers four accepted inputs -- a pyarrow Table, a non-pyarrow object
    exposing ``__arrow_c_stream__``, a pyarrow RecordBatch, and a
    non-pyarrow object exposing ``__arrow_c_array__`` -- and confirms that a
    plain dict is rejected with ``TypeError``.
    """
    # objects with __arrow_c_stream__
    table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})

    result = pd.DataFrame.from_arrow(table)
    expected = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
    tm.assert_frame_equal(result, expected)

    # not only pyarrow objects are supported
    result = pd.DataFrame.from_arrow(ArrowStreamWrapper(table))
    tm.assert_frame_equal(result, expected)

    # objects with __arrow_c_array__
    batch = pa.record_batch([[1, 2, 3], ["a", "b", "c"]], names=["a", "b"])

    # Pass the record batch itself; re-using ``table`` here would leave the
    # plain RecordBatch input untested.
    result = pd.DataFrame.from_arrow(batch)
    tm.assert_frame_equal(result, expected)

    result = pd.DataFrame.from_arrow(ArrowArrayWrapper(batch))
    tm.assert_frame_equal(result, expected)

    # only accept actual Arrow objects
    with pytest.raises(TypeError, match="Expected an Arrow-compatible tabular object"):
        pd.DataFrame.from_arrow({"a": [1, 2, 3], "b": ["a", "b", "c"]})