Skip to content

Commit 9e47ff7

Browse files
authored
BUG: incorrect casting ints to Period in GroupBy.agg (#39362)
1 parent a1e8b59 commit 9e47ff7

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ Plotting
343343

344344
Groupby/resample/rolling
345345
^^^^^^^^^^^^^^^^^^^^^^^^
346-
346+
- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` with :class:`PeriodDtype` columns incorrectly casting integer results to :class:`Period` (:issue:`38254`)
347347
- Bug in :meth:`SeriesGroupBy.value_counts` where unobserved categories in a grouped categorical series were not tallied (:issue:`38672`)
348348
- Bug in :meth:`SeriesGroupBy.value_counts` where error was raised on an empty series (:issue:`39172`)
349349
- Bug in :meth:`.GroupBy.indices` would contain non-existent indices when null values were present in the groupby keys (:issue:`9304`)

pandas/_libs/tslibs/period.pyx

+24
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,28 @@ cdef accessor _get_accessor_func(str field):
14021402
return NULL
14031403

14041404

1405+
@cython.wraparound(False)
1406+
@cython.boundscheck(False)
1407+
def from_ordinals(const int64_t[:] values, freq):
1408+
cdef:
1409+
Py_ssize_t i, n = len(values)
1410+
int64_t[:] result = np.empty(len(values), dtype="i8")
1411+
int64_t val
1412+
1413+
freq = to_offset(freq)
1414+
if not isinstance(freq, BaseOffset):
1415+
raise ValueError("freq not specified and cannot be inferred")
1416+
1417+
for i in range(n):
1418+
val = values[i]
1419+
if val == NPY_NAT:
1420+
result[i] = NPY_NAT
1421+
else:
1422+
result[i] = Period(val, freq=freq).ordinal
1423+
1424+
return result.base
1425+
1426+
14051427
@cython.wraparound(False)
14061428
@cython.boundscheck(False)
14071429
def extract_ordinals(ndarray[object] values, freq):
@@ -1419,6 +1441,8 @@ def extract_ordinals(ndarray[object] values, freq):
14191441

14201442
if is_null_datetimelike(p):
14211443
ordinals[i] = NPY_NAT
1444+
elif util.is_integer_object(p):
1445+
raise TypeError(p)
14221446
else:
14231447
try:
14241448
ordinals[i] = p.ordinal

pandas/core/arrays/period.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
is_datetime64_dtype,
3939
is_dtype_equal,
4040
is_float_dtype,
41+
is_integer_dtype,
4142
is_period_dtype,
4243
pandas_dtype,
4344
)
@@ -897,18 +898,23 @@ def period_array(
897898
if not isinstance(data, (np.ndarray, list, tuple, ABCSeries)):
898899
data = list(data)
899900

900-
data = np.asarray(data)
901+
arrdata = np.asarray(data)
901902

902903
dtype: Optional[PeriodDtype]
903904
if freq:
904905
dtype = PeriodDtype(freq)
905906
else:
906907
dtype = None
907908

908-
if is_float_dtype(data) and len(data) > 0:
909+
if is_float_dtype(arrdata) and len(arrdata) > 0:
909910
raise TypeError("PeriodIndex does not allow floating point in construction")
910911

911-
data = ensure_object(data)
912+
if is_integer_dtype(arrdata.dtype):
913+
arr = arrdata.astype(np.int64, copy=False)
914+
ordinals = libperiod.from_ordinals(arr, freq)
915+
return PeriodArray(ordinals, dtype=dtype)
916+
917+
data = ensure_object(arrdata)
912918

913919
return PeriodArray._from_sequence(data, dtype=dtype)
914920

pandas/tests/arrays/test_period.py

+11
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ def test_period_array_freq_mismatch():
102102
PeriodArray(arr, freq=pd.tseries.offsets.MonthEnd())
103103

104104

105+
def test_from_sequence_disallows_i8():
106+
arr = period_array(["2000", "2001"], freq="D")
107+
108+
msg = str(arr[0].ordinal)
109+
with pytest.raises(TypeError, match=msg):
110+
PeriodArray._from_sequence(arr.asi8, dtype=arr.dtype)
111+
112+
with pytest.raises(TypeError, match=msg):
113+
PeriodArray._from_sequence(list(arr.asi8), dtype=arr.dtype)
114+
115+
105116
def test_asi8():
106117
result = period_array(["2000", "2001", None], freq="D").asi8
107118
expected = np.array([10957, 11323, iNaT])

pandas/tests/groupby/aggregate/test_other.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,14 @@ def test_agg_over_numpy_arrays():
432432
tm.assert_frame_equal(result, expected)
433433

434434

435-
def test_agg_tzaware_non_datetime_result():
435+
@pytest.mark.parametrize("as_period", [True, False])
436+
def test_agg_tzaware_non_datetime_result(as_period):
436437
# discussed in GH#29589, fixed in GH#29641, operating on tzaware values
437438
# with function that is not dtype-preserving
438439
dti = pd.date_range("2012-01-01", periods=4, tz="UTC")
440+
if as_period:
441+
dti = dti.tz_localize(None).to_period("D")
442+
439443
df = DataFrame({"a": [0, 0, 1, 1], "b": dti})
440444
gb = df.groupby("a")
441445

@@ -454,6 +458,9 @@ def test_agg_tzaware_non_datetime_result():
454458
result = gb["b"].agg(lambda x: x.iloc[-1] - x.iloc[0])
455459
expected = Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b")
456460
expected.index.name = "a"
461+
if as_period:
462+
expected = Series([pd.offsets.Day(1), pd.offsets.Day(1)], name="b")
463+
expected.index.name = "a"
457464
tm.assert_series_equal(result, expected)
458465

459466

0 commit comments

Comments
 (0)