Skip to content

Commit 30941cb

Browse files
committed
Example with _from_factorize
1 parent cd5c2db commit 30941cb

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

pandas/core/arrays/base.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,21 @@ def _constructor_from_sequence(cls, scalars):
7777
raise AbstractMethodError(cls)
7878

7979
@classmethod
80-
def _constructor_from_simple_ndarray(cls, values, instance):
80+
def _from_factorized(cls, values, original):
81+
"""Reconstruct an ExtensionArray after factorization.
82+
83+
Parameters
84+
----------
85+
values : ndarray
86+
An integer ndarray with the factorized values.
87+
original : ExtensionArray
88+
The original ExtensionArray that was factorized.
89+
90+
See Also
91+
--------
92+
pandas.factorize
93+
ExtensionArray.factorize
94+
"""
8195
raise AbstractMethodError(cls)
8296

8397
# ------------------------------------------------------------------------
@@ -305,7 +319,16 @@ def unique(self):
305319
uniques = unique(self.astype(object))
306320
return self._constructor_from_sequence(uniques)
307321

308-
def _simple_ndarray(self):
322+
def _values_for_factorize(self):
323+
"""Return an array suitable for factorization.
324+
325+
Returns
326+
-------
327+
ndarray
328+
An array suitable for factorization. This should maintain order
329+
and be a supported dtype.
330+
331+
"""
309332
return self.astype(object)
310333

311334
def factorize(self, na_sentinel=-1):
@@ -337,17 +360,17 @@ def factorize(self, na_sentinel=-1):
337360
-----
338361
:meth:`pandas.factorize` offers a `sort` keyword as well.
339362
"""
340-
# Implementor note: make sure to exclude missing values from your
341-
# `uniques`. It should only contain non-NA values.
363+
# Implementor notes: There are two options for overriding the
364+
# behavior of `factorize`: here and `_values_for_factorize`.
342365
from pandas.core.algorithms import _factorize_array
343366

344367
mask = self.isna()
345-
arr = self._simple_ndarray()
368+
arr = self._values_for_factorize()
346369
arr[mask] = np.nan
347370

348371
labels, uniques = _factorize_array(arr, check_nulls=True,
349372
na_sentinel=na_sentinel)
350-
uniques = self._constructor_from_simple_ndarray(uniques, instance=arr)
373+
uniques = self._from_factorized(uniques, arr)
351374
return labels, uniques
352375

353376
# ------------------------------------------------------------------------

pandas/tests/extension/decimal/array.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def __init__(self, values):
3636
def _constructor_from_sequence(cls, scalars):
3737
return cls(scalars)
3838

39+
@classmethod
40+
def _from_factorized(cls, values, original):
41+
return cls(values)
42+
3943
def __getitem__(self, item):
4044
if isinstance(item, numbers.Integral):
4145
return self.values[item]

pandas/tests/extension/json/array.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import numpy as np
99

10-
import pandas as pd
1110
from pandas.core.dtypes.base import ExtensionDtype
1211
from pandas.core.arrays import ExtensionArray
1312

@@ -38,6 +37,10 @@ def __init__(self, values):
3837
def _constructor_from_sequence(cls, scalars):
3938
return cls(scalars)
4039

40+
@classmethod
41+
def _from_factorized(cls, values, original):
42+
return cls([collections.UserDict(x) for x in values if x != ()])
43+
4144
def __getitem__(self, item):
4245
if isinstance(item, numbers.Integral):
4346
return self.data[item]
@@ -105,20 +108,9 @@ def _concat_same_type(cls, to_concat):
105108
data = list(itertools.chain.from_iterable([x.data for x in to_concat]))
106109
return cls(data)
107110

108-
def factorize(self, na_sentinel=-1):
111+
def _values_for_factorize(self):
109112
frozen = tuple(tuple(x.items()) for x in self)
110-
labels, uniques = pd.factorize(frozen)
111-
112-
# fixup NA
113-
if self.isna().any():
114-
na_code = labels[self.isna()][0]
115-
116-
labels[labels == na_code] = na_sentinel
117-
labels[labels > na_code] -= 1
118-
119-
uniques = JSONArray([collections.UserDict(x)
120-
for x in uniques if x != ()])
121-
return labels, uniques
113+
return np.array(frozen, dtype=object)
122114

123115

124116
def make_data():

0 commit comments

Comments
 (0)