Skip to content

CLN: Simplify factorize #48580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits on Oct 11, 2022
57 changes: 14 additions & 43 deletions pandas/core/algorithms.py
Original file line number | Diff line number | Diff line change
@@ -762,10 +762,6 @@ def factorize(
if not isinstance(values, ABCMultiIndex):
values = extract_array(values, extract_numpy=True)

# GH35667, if na_sentinel=None, we will not dropna NaNs from the uniques
# of values, assign na_sentinel=-1 to replace code value for NaN.
dropna = na_sentinel is not None

if (
isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray))
and values.freq is not None
@@ -793,17 +789,8 @@ def factorize(

else:
values = np.asarray(values) # convert DTA/TDA/MultiIndex
# TODO: pass na_sentinel=na_sentinel to factorize_array. When sort is True and
# na_sentinel is None we append NA on the end because safe_sort does not
# handle null values in uniques.
if na_sentinel is None and sort:
na_sentinel_arg = -1
elif na_sentinel is None:
na_sentinel_arg = None
else:
na_sentinel_arg = na_sentinel

if not dropna and not sort and is_object_dtype(values):
if na_sentinel is None and is_object_dtype(values):
# factorize can now handle differentiating various types of null values.
# These can only occur when the array has object dtype.
# However, for backwards compatibility we only use the null for the
@@ -816,32 +803,15 @@ def factorize(

codes, uniques = factorize_array(
values,
na_sentinel=na_sentinel_arg,
na_sentinel=na_sentinel,
size_hint=size_hint,
)

if sort and len(uniques) > 0:
if na_sentinel is None:
# TODO: Can remove when na_sentinel=na_sentinel as in TODO above
na_sentinel = -1
uniques, codes = safe_sort(
uniques, codes, na_sentinel=na_sentinel, assume_unique=True, verify=False
)

if not dropna and sort:
# TODO: Can remove entire block when na_sentinel=na_sentinel as in TODO above
if na_sentinel is None:
na_sentinel_arg = -1
else:
na_sentinel_arg = na_sentinel
code_is_na = codes == na_sentinel_arg
if code_is_na.any():
# na_value is set based on the dtype of uniques, and compat set to False is
# because we do not want na_value to be 0 for integers
na_value = na_value_for_dtype(uniques.dtype, compat=False)
uniques = np.append(uniques, [na_value])
codes = np.where(code_is_na, len(uniques) - 1, codes)

uniques = _reconstruct_data(uniques, original.dtype, original)

return _re_wrap_factorize(original, uniques, codes)
@@ -1796,7 +1766,7 @@ def diff(arr, n: int, axis: AxisInt = 0):
def safe_sort(
values,
codes=None,
na_sentinel: int = -1,
na_sentinel: int | None = -1,
assume_unique: bool = False,
verify: bool = True,
) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
@@ -1813,8 +1783,8 @@ def safe_sort(
codes : list_like, optional
Indices to ``values``. All out of bound indices are treated as
"not found" and will be masked with ``na_sentinel``.
na_sentinel : int, default -1
Value in ``codes`` to mark "not found".
na_sentinel : int or None, default -1
Value in ``codes`` to mark "not found", or None to encode null values as normal.
Ignored when ``codes`` is None.
assume_unique : bool, default False
When True, ``values`` are assumed to be unique, which can speed up
@@ -1920,24 +1890,25 @@ def safe_sort(
# may deal with them here without performance loss using `mode='wrap'`
new_codes = reverse_indexer.take(codes, mode="wrap")

mask = codes == na_sentinel
if verify:
mask = mask | (codes < -len(values)) | (codes >= len(values))
if na_sentinel is not None:
mask = codes == na_sentinel
if verify:
mask = mask | (codes < -len(values)) | (codes >= len(values))

if mask is not None:
if na_sentinel is not None and mask is not None:
np.putmask(new_codes, mask, na_sentinel)

return ordered, ensure_platform_int(new_codes)


def _sort_mixed(values) -> np.ndarray:
"""order ints before strings in 1d arrays, safe in py3"""
"""order ints before strings before nulls in 1d arrays"""
str_pos = np.array([isinstance(x, str) for x in values], dtype=bool)
none_pos = np.array([x is None for x in values], dtype=bool)
nums = np.sort(values[~str_pos & ~none_pos])
null_pos = np.array([isna(x) for x in values], dtype=bool)
nums = np.sort(values[~str_pos & ~null_pos])
strs = np.sort(values[str_pos])
return np.concatenate(
[nums, np.asarray(strs, dtype=object), np.array(values[none_pos])]
[nums, np.asarray(strs, dtype=object), np.array(values[null_pos])]
)


Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/test_sorting.py
Original file line number | Diff line number | Diff line change
@@ -506,10 +506,10 @@ def test_extension_array_codes(self, verify, na_sentinel):
tm.assert_numpy_array_equal(codes, expected_codes)


def test_mixed_str_nan():
values = np.array(["b", np.nan, "a", "b"], dtype=object)
def test_mixed_str_null(nulls_fixture):
values = np.array(["b", nulls_fixture, "a", "b"], dtype=object)
result = safe_sort(values)
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
expected = np.array(["a", "b", "b", nulls_fixture], dtype=object)
tm.assert_numpy_array_equal(result, expected)


Expand Down