Skip to content

Commit 45e45b5

Browse files
rhshadrach authored and noatamir committed
CLN: Simplify factorize (pandas-dev#48580)
* CLN: Simplify factorize
* Update type-hint and docstring
* Update test
* cleanup
1 parent 9c0f8fe commit 45e45b5

File tree

2 files changed

+17
-46
lines changed

2 files changed

+17
-46
lines changed

pandas/core/algorithms.py

+14 -43
Original file line numberDiff line numberDiff line change
@@ -762,10 +762,6 @@ def factorize(
762762
if not isinstance(values, ABCMultiIndex):
763763
values = extract_array(values, extract_numpy=True)
764764

765-
# GH35667, if na_sentinel=None, we will not dropna NaNs from the uniques
766-
# of values, assign na_sentinel=-1 to replace code value for NaN.
767-
dropna = na_sentinel is not None
768-
769765
if (
770766
isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray))
771767
and values.freq is not None
@@ -793,17 +789,8 @@ def factorize(
793789

794790
else:
795791
values = np.asarray(values) # convert DTA/TDA/MultiIndex
796-
# TODO: pass na_sentinel=na_sentinel to factorize_array. When sort is True and
797-
# na_sentinel is None we append NA on the end because safe_sort does not
798-
# handle null values in uniques.
799-
if na_sentinel is None and sort:
800-
na_sentinel_arg = -1
801-
elif na_sentinel is None:
802-
na_sentinel_arg = None
803-
else:
804-
na_sentinel_arg = na_sentinel
805792

806-
if not dropna and not sort and is_object_dtype(values):
793+
if na_sentinel is None and is_object_dtype(values):
807794
# factorize can now handle differentiating various types of null values.
808795
# These can only occur when the array has object dtype.
809796
# However, for backwards compatibility we only use the null for the
@@ -816,32 +803,15 @@ def factorize(
816803

817804
codes, uniques = factorize_array(
818805
values,
819-
na_sentinel=na_sentinel_arg,
806+
na_sentinel=na_sentinel,
820807
size_hint=size_hint,
821808
)
822809

823810
if sort and len(uniques) > 0:
824-
if na_sentinel is None:
825-
# TODO: Can remove when na_sentinel=na_sentinel as in TODO above
826-
na_sentinel = -1
827811
uniques, codes = safe_sort(
828812
uniques, codes, na_sentinel=na_sentinel, assume_unique=True, verify=False
829813
)
830814

831-
if not dropna and sort:
832-
# TODO: Can remove entire block when na_sentinel=na_sentinel as in TODO above
833-
if na_sentinel is None:
834-
na_sentinel_arg = -1
835-
else:
836-
na_sentinel_arg = na_sentinel
837-
code_is_na = codes == na_sentinel_arg
838-
if code_is_na.any():
839-
# na_value is set based on the dtype of uniques, and compat set to False is
840-
# because we do not want na_value to be 0 for integers
841-
na_value = na_value_for_dtype(uniques.dtype, compat=False)
842-
uniques = np.append(uniques, [na_value])
843-
codes = np.where(code_is_na, len(uniques) - 1, codes)
844-
845815
uniques = _reconstruct_data(uniques, original.dtype, original)
846816

847817
return _re_wrap_factorize(original, uniques, codes)
@@ -1796,7 +1766,7 @@ def diff(arr, n: int, axis: AxisInt = 0):
17961766
def safe_sort(
17971767
values,
17981768
codes=None,
1799-
na_sentinel: int = -1,
1769+
na_sentinel: int | None = -1,
18001770
assume_unique: bool = False,
18011771
verify: bool = True,
18021772
) -> np.ndarray | MultiIndex | tuple[np.ndarray | MultiIndex, np.ndarray]:
@@ -1813,8 +1783,8 @@ def safe_sort(
18131783
codes : list_like, optional
18141784
Indices to ``values``. All out of bound indices are treated as
18151785
"not found" and will be masked with ``na_sentinel``.
1816-
na_sentinel : int, default -1
1817-
Value in ``codes`` to mark "not found".
1786+
na_sentinel : int or None, default -1
1787+
Value in ``codes`` to mark "not found", or None to encode null values as normal.
18181788
Ignored when ``codes`` is None.
18191789
assume_unique : bool, default False
18201790
When True, ``values`` are assumed to be unique, which can speed up
@@ -1920,24 +1890,25 @@ def safe_sort(
19201890
# may deal with them here without performance loss using `mode='wrap'`
19211891
new_codes = reverse_indexer.take(codes, mode="wrap")
19221892

1923-
mask = codes == na_sentinel
1924-
if verify:
1925-
mask = mask | (codes < -len(values)) | (codes >= len(values))
1893+
if na_sentinel is not None:
1894+
mask = codes == na_sentinel
1895+
if verify:
1896+
mask = mask | (codes < -len(values)) | (codes >= len(values))
19261897

1927-
if mask is not None:
1898+
if na_sentinel is not None and mask is not None:
19281899
np.putmask(new_codes, mask, na_sentinel)
19291900

19301901
return ordered, ensure_platform_int(new_codes)
19311902

19321903

19331904
def _sort_mixed(values) -> np.ndarray:
1934-
"""order ints before strings in 1d arrays, safe in py3"""
1905+
"""order ints before strings before nulls in 1d arrays"""
19351906
str_pos = np.array([isinstance(x, str) for x in values], dtype=bool)
1936-
none_pos = np.array([x is None for x in values], dtype=bool)
1937-
nums = np.sort(values[~str_pos & ~none_pos])
1907+
null_pos = np.array([isna(x) for x in values], dtype=bool)
1908+
nums = np.sort(values[~str_pos & ~null_pos])
19381909
strs = np.sort(values[str_pos])
19391910
return np.concatenate(
1940-
[nums, np.asarray(strs, dtype=object), np.array(values[none_pos])]
1911+
[nums, np.asarray(strs, dtype=object), np.array(values[null_pos])]
19411912
)
19421913

19431914

pandas/tests/test_sorting.py

+3 -3
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,10 @@ def test_extension_array_codes(self, verify, na_sentinel):
506506
tm.assert_numpy_array_equal(codes, expected_codes)
507507

508508

509-
def test_mixed_str_nan():
510-
values = np.array(["b", np.nan, "a", "b"], dtype=object)
509+
def test_mixed_str_null(nulls_fixture):
510+
values = np.array(["b", nulls_fixture, "a", "b"], dtype=object)
511511
result = safe_sort(values)
512-
expected = np.array([np.nan, "a", "b", "b"], dtype=object)
512+
expected = np.array(["a", "b", "b", nulls_fixture], dtype=object)
513513
tm.assert_numpy_array_equal(result, expected)
514514

515515

0 commit comments

Comments
 (0)