Skip to content

Commit b8d5e9c

Browse files
authored
REF: hold PeriodArray in NDArrayBackedExtensionBlock (#44681)
1 parent 447ef57 commit b8d5e9c

File tree

9 files changed

+55
-7
lines changed

9 files changed

+55
-7
lines changed

pandas/core/dtypes/common.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -1408,11 +1408,12 @@ def is_1d_only_ea_obj(obj: Any) -> bool:
14081408
from pandas.core.arrays import (
14091409
DatetimeArray,
14101410
ExtensionArray,
1411+
PeriodArray,
14111412
TimedeltaArray,
14121413
)
14131414

14141415
return isinstance(obj, ExtensionArray) and not isinstance(
1415-
obj, (DatetimeArray, TimedeltaArray)
1416+
obj, (DatetimeArray, TimedeltaArray, PeriodArray)
14161417
)
14171418

14181419

@@ -1424,7 +1425,9 @@ def is_1d_only_ea_dtype(dtype: DtypeObj | None) -> bool:
14241425
# here too.
14251426
# NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype
14261427
# to exclude ArrowTimestampUSDtype
1427-
return isinstance(dtype, ExtensionDtype) and not isinstance(dtype, DatetimeTZDtype)
1428+
return isinstance(dtype, ExtensionDtype) and not isinstance(
1429+
dtype, (DatetimeTZDtype, PeriodDtype)
1430+
)
14281431

14291432

14301433
def is_extension_array_dtype(arr_or_dtype) -> bool:

pandas/core/internals/api.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
from pandas.core.dtypes.common import (
1717
is_datetime64tz_dtype,
18+
is_period_dtype,
1819
pandas_dtype,
1920
)
2021

@@ -62,8 +63,9 @@ def make_block(
6263
placement = BlockPlacement(placement)
6364

6465
ndim = maybe_infer_ndim(values, placement, ndim)
65-
if is_datetime64tz_dtype(values.dtype):
66+
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
6667
# GH#41168 ensure we can pass 1D dt64tz values
68+
# More generally, any EA dtype that isn't is_1d_only_ea_dtype
6769
values = extract_array(values, extract_numpy=True)
6870
values = ensure_block_shape(values, ndim)
6971

pandas/core/internals/blocks.py

+12
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@
5858
CategoricalDtype,
5959
ExtensionDtype,
6060
PandasDtype,
61+
PeriodDtype,
6162
)
6263
from pandas.core.dtypes.generic import (
6364
ABCDataFrame,
@@ -1728,6 +1729,12 @@ class NDArrayBackedExtensionBlock(libinternals.NDArrayBackedBlock, EABackedBlock
17281729

17291730
values: NDArrayBackedExtensionArray
17301731

1732+
# error: Signature of "is_extension" incompatible with supertype "Block"
1733+
@cache_readonly
1734+
def is_extension(self) -> bool: # type: ignore[override]
1735+
# i.e. datetime64tz, PeriodDtype
1736+
return not isinstance(self.dtype, np.dtype)
1737+
17311738
@property
17321739
def is_view(self) -> bool:
17331740
"""return a boolean if I am possibly a view"""
@@ -1756,6 +1763,9 @@ def where(self, other, cond) -> list[Block]:
17561763
try:
17571764
res_values = arr.T._where(cond, other).T
17581765
except (ValueError, TypeError):
1766+
if isinstance(self.dtype, PeriodDtype):
1767+
# TODO: don't special-case
1768+
raise
17591769
blk = self.coerce_to_target_dtype(other)
17601770
nbs = blk.where(other, cond)
17611771
return self._maybe_downcast(nbs, "infer")
@@ -1949,6 +1959,8 @@ def get_block_type(dtype: DtypeObj):
19491959
cls = CategoricalBlock
19501960
elif vtype is Timestamp:
19511961
cls = DatetimeTZBlock
1962+
elif isinstance(dtype, PeriodDtype):
1963+
cls = NDArrayBackedExtensionBlock
19521964
elif isinstance(dtype, ExtensionDtype):
19531965
# Note: need to be sure PandasArray is unwrapped before we get here
19541966
cls = ExtensionBlock

pandas/core/internals/construction.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -318,7 +318,7 @@ def ndarray_to_mgr(
318318
return arrays_to_mgr(values, columns, index, dtype=dtype, typ=typ)
319319

320320
elif is_extension_array_dtype(vdtype) and not is_1d_only_ea_dtype(vdtype):
321-
# i.e. Datetime64TZ
321+
# i.e. Datetime64TZ, PeriodDtype
322322
values = extract_array(values, extract_numpy=True)
323323
if copy:
324324
values = values.copy()

pandas/tests/arithmetic/test_period.py

+3
Original file line number | Diff line number | Diff line change
@@ -1262,6 +1262,9 @@ def test_parr_add_timedeltalike_scalar(self, three_days, box_with_array):
12621262
)
12631263

12641264
obj = tm.box_expected(ser, box_with_array)
1265+
if box_with_array is pd.DataFrame:
1266+
assert (obj.dtypes == "Period[D]").all()
1267+
12651268
expected = tm.box_expected(expected, box_with_array)
12661269

12671270
result = obj + three_days

pandas/tests/arrays/period/test_arrow_compat.py

+11
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
from pandas.compat import pa_version_under2p0
4+
35
from pandas.core.dtypes.dtypes import PeriodDtype
46

57
import pandas as pd
@@ -69,6 +71,9 @@ def test_arrow_array_missing():
6971
assert result.storage.equals(expected)
7072

7173

74+
@pytest.mark.xfail(
75+
pa_version_under2p0, reason="pyarrow incorrectly uses pandas internals API"
76+
)
7277
def test_arrow_table_roundtrip():
7378
from pandas.core.arrays._arrow_utils import ArrowPeriodType
7479

@@ -88,6 +93,9 @@ def test_arrow_table_roundtrip():
8893
tm.assert_frame_equal(result, expected)
8994

9095

96+
@pytest.mark.xfail(
97+
pa_version_under2p0, reason="pyarrow incorrectly uses pandas internals API"
98+
)
9199
def test_arrow_load_from_zero_chunks():
92100
# GH-41040
93101

@@ -106,6 +114,9 @@ def test_arrow_load_from_zero_chunks():
106114
tm.assert_frame_equal(result, df)
107115

108116

117+
@pytest.mark.xfail(
118+
pa_version_under2p0, reason="pyarrow incorrectly uses pandas internals API"
119+
)
109120
def test_arrow_table_roundtrip_without_metadata():
110121
arr = PeriodArray([1, 2, 3], freq="H")
111122
arr[1] = pd.NaT

pandas/tests/internals/test_internals.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1275,7 +1275,7 @@ def test_interval_can_hold_element(self, dtype, element):
12751275

12761276
def test_period_can_hold_element_emptylist(self):
12771277
pi = period_range("2016", periods=3, freq="A")
1278-
blk = new_block(pi._data, [1], ndim=2)
1278+
blk = new_block(pi._data.reshape(1, 3), [1], ndim=2)
12791279

12801280
assert blk._can_hold_element([])
12811281

pandas/tests/io/test_feather.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
import pytest
44

5+
from pandas.compat.pyarrow import pa_version_under2p0
6+
57
import pandas as pd
68
import pandas._testing as tm
79

@@ -85,7 +87,11 @@ def test_basic(self):
8587
),
8688
}
8789
)
88-
df["periods"] = pd.period_range("2013", freq="M", periods=3)
90+
if not pa_version_under2p0:
91+
# older pyarrow incorrectly uses pandas internal API, so
92+
# constructs invalid Block
93+
df["periods"] = pd.period_range("2013", freq="M", periods=3)
94+
8995
df["timedeltas"] = pd.timedelta_range("1 day", periods=3)
9096
df["intervals"] = pd.interval_range(0, 3, 3)
9197

pandas/tests/io/test_parquet.py

+12 -1
Original file line number | Diff line number | Diff line change
@@ -648,7 +648,15 @@ def test_use_nullable_dtypes(self, engine, request):
648648
"object",
649649
"datetime64[ns, UTC]",
650650
"float",
651-
"period[D]",
651+
pytest.param(
652+
"period[D]",
653+
# Note: I don't know exactly what version the cutoff is;
654+
# On the CI it fails with 1.0.1
655+
marks=pytest.mark.xfail(
656+
pa_version_under2p0,
657+
reason="pyarrow uses pandas internal API incorrectly",
658+
),
659+
),
652660
"Float64",
653661
"string",
654662
],
@@ -887,6 +895,9 @@ def test_pyarrow_backed_string_array(self, pa, string_storage):
887895
check_round_trip(df, pa, expected=df.astype(f"string[{string_storage}]"))
888896

889897
@td.skip_if_no("pyarrow")
898+
@pytest.mark.xfail(
899+
pa_version_under2p0, reason="pyarrow uses pandas internal API incorrectly"
900+
)
890901
def test_additional_extension_types(self, pa):
891902
# test additional ExtensionArrays that are supported through the
892903
# __arrow_array__ protocol + by defining a custom ExtensionType

0 commit comments

Comments
 (0)