Skip to content

Commit 4f91875

Browse files
authored
BUG: concat_compat with 2D PeriodArray (#44598)
1 parent 7ca5a6b commit 4f91875

File tree

4 files changed

+46
-14
lines changed

4 files changed

+46
-14
lines changed

pandas/core/arrays/categorical.py

+18 -8
Original file line number | Diff line number | Diff line change
@@ -2291,18 +2291,28 @@ def _concat_same_type(
22912291
) -> CategoricalT:
22922292
from pandas.core.dtypes.concat import union_categoricals
22932293

2294-
result = union_categoricals(to_concat)
2295-
2296-
# in case we are concatenating along axis != 0, we need to reshape
2297-
# the result from union_categoricals
22982294
first = to_concat[0]
22992295
if axis >= first.ndim:
2300-
raise ValueError
2296+
raise ValueError(
2297+
f"axis {axis} is out of bounds for array of dimension {first.ndim}"
2298+
)
2299+
23012300
if axis == 1:
2302-
if not all(len(x) == len(first) for x in to_concat):
2301+
# Flatten, concatenate then reshape
2302+
if not all(x.ndim == 2 for x in to_concat):
23032303
raise ValueError
2304-
# TODO: Will this get contiguity wrong?
2305-
result = result.reshape(-1, len(to_concat), order="F")
2304+
2305+
# pass correctly-shaped to union_categoricals
2306+
tc_flat = []
2307+
for obj in to_concat:
2308+
tc_flat.extend([obj[:, i] for i in range(obj.shape[1])])
2309+
2310+
res_flat = cls._concat_same_type(tc_flat, axis=0)
2311+
2312+
result = res_flat.reshape(len(first), -1, order="F")
2313+
return result
2314+
2315+
result = union_categoricals(to_concat)
23062316
return result
23072317

23082318
# ------------------------------------------------------------------

pandas/core/dtypes/concat.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,9 @@ def is_nonempty(x) -> bool:
107107
to_concat = non_empties
108108

109109
kinds = {obj.dtype.kind for obj in to_concat}
110-
contains_datetime = any(kind in ["m", "M"] for kind in kinds)
110+
contains_datetime = any(kind in ["m", "M"] for kind in kinds) or any(
111+
isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat
112+
)
111113

112114
all_empty = not len(non_empties)
113115
single_dtype = len({x.dtype for x in to_concat}) == 1

pandas/tests/dtypes/test_concat.py

+18
Original file line number | Diff line number | Diff line change
@@ -26,3 +26,21 @@ def test_concat_single_dataframe_tz_aware(copy):
2626
expected = df.copy()
2727
result = pd.concat([df], copy=copy)
2828
tm.assert_frame_equal(result, expected)
29+
30+
31+
def test_concat_periodarray_2d():
32+
pi = pd.period_range("2016-01-01", periods=36, freq="D")
33+
arr = pi._data.reshape(6, 6)
34+
35+
result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)
36+
tm.assert_period_array_equal(result, arr)
37+
38+
result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)
39+
tm.assert_period_array_equal(result, arr)
40+
41+
msg = "all the input array dimensions for the concatenation axis must match exactly"
42+
with pytest.raises(ValueError, match=msg):
43+
_concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)
44+
45+
with pytest.raises(ValueError, match=msg):
46+
_concat.concat_compat([arr[:2], arr[2:]], axis=1)

pandas/tests/extension/base/dim2.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -122,21 +122,23 @@ def test_tolist_2d(self, data):
122122
assert result == expected
123123

124124
def test_concat_2d(self, data):
125-
left = data.reshape(-1, 1)
125+
left = type(data)._concat_same_type([data, data]).reshape(-1, 2)
126126
right = left.copy()
127127

128128
# axis=0
129129
result = left._concat_same_type([left, right], axis=0)
130-
expected = data._concat_same_type([data, data]).reshape(-1, 1)
130+
expected = data._concat_same_type([data] * 4).reshape(-1, 2)
131131
self.assert_extension_array_equal(result, expected)
132132

133133
# axis=1
134134
result = left._concat_same_type([left, right], axis=1)
135-
expected = data.repeat(2).reshape(-1, 2)
136-
self.assert_extension_array_equal(result, expected)
135+
assert result.shape == (len(data), 4)
136+
self.assert_extension_array_equal(result[:, :2], left)
137+
self.assert_extension_array_equal(result[:, 2:], right)
137138

138139
# axis > 1 -> invalid
139-
with pytest.raises(ValueError):
140+
msg = "axis 2 is out of bounds for array of dimension 2"
141+
with pytest.raises(ValueError, match=msg):
140142
left._concat_same_type([left, right], axis=2)
141143

142144
@pytest.mark.parametrize("method", ["backfill", "pad"])

0 commit comments

Comments
 (0)