Skip to content

Commit 912c2d0

Browse files
committed
BUG/API: _values_for_factorize/_from_factorized round-trip
1 parent 138337b commit 912c2d0

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

pandas/core/arrays/boolean.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ def _values_for_factorize(self) -> Tuple[np.ndarray, int]:
320320

321321
@classmethod
322322
def _from_factorized(cls, values, original: "BooleanArray") -> "BooleanArray":
323-
return cls._from_sequence(values, dtype=original.dtype)
323+
mask = values == -1
324+
return cls(values.astype(bool), mask)
324325

325326
_HANDLED_TYPES = (np.ndarray, numbers.Number, bool, np.bool_)
326327

pandas/tests/extension/base/reshaping.py

+7
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,10 @@ def test_transpose(self, data):
324324
self.assert_frame_equal(result, expected)
325325
self.assert_frame_equal(np.transpose(np.transpose(df)), df)
326326
self.assert_frame_equal(np.transpose(np.transpose(df[["A"]])), df[["A"]])
327+
328+
def test_factorize_roundtrip(self, data):
329+
# GH#32673
330+
values = data._values_for_factorize()[0]
331+
result = type(data)._from_factorized(values, data)
332+
333+
self.assert_equal(result, data)

pandas/tests/extension/json/array.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
6565

6666
@classmethod
6767
def _from_factorized(cls, values, original):
68-
return cls([UserDict(x) for x in values if x != ()])
68+
return cls(
69+
[UserDict(x) if x != () else original.dtype.na_value for x in values]
70+
)
6971

7072
def __getitem__(self, item):
7173
if isinstance(item, numbers.Integral):

pandas/tests/extension/test_datetime.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas.core.dtypes.dtypes import DatetimeTZDtype
55

66
import pandas as pd
7+
import pandas._testing as tm
78
from pandas.core.arrays import DatetimeArray
89
from pandas.tests.extension import base
910

@@ -201,6 +202,13 @@ def test_unstack(self, obj):
201202
result = ser.unstack(0)
202203
self.assert_equal(result, expected)
203204

205+
def test_factorize_roundtrip(self, data):
206+
# GH#32673, for DTA we don't preserve freq
207+
values = data._values_for_factorize()[0]
208+
result = type(data)._from_factorized(values, data)
209+
210+
tm.assert_numpy_array_equal(result.asi8, data.asi8)
211+
204212

205213
class TestSetitem(BaseDatetimeTests, base.BaseSetitemTests):
206214
pass

0 commit comments

Comments
 (0)