Skip to content

REF: avoid object-casting in _get_codes_for_values #45117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _get_values_for_rank(values: ArrayLike) -> np.ndarray:
return values


def get_data_algo(values: ArrayLike):
def _get_data_algo(values: ArrayLike):
values = _get_values_for_rank(values)

ndtype = _check_object_for_strings(values)
Expand Down Expand Up @@ -555,7 +555,7 @@ def factorize_array(
codes : ndarray[np.intp]
uniques : ndarray
"""
hash_klass, values = get_data_algo(values)
hash_klass, values = _get_data_algo(values)

table = hash_klass(size_hint or len(values))
uniques, codes = table.factorize(
Expand Down Expand Up @@ -1747,7 +1747,7 @@ def safe_sort(

if sorter is None:
# mixed types
hash_klass, values = get_data_algo(values)
hash_klass, values = _get_data_algo(values)
t = hash_klass(len(values))
t.map_locations(values)
sorter = ensure_platform_int(t.lookup(ordered))
Expand Down
31 changes: 3 additions & 28 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
)
from pandas.core.dtypes.common import (
ensure_int64,
ensure_object,
ensure_platform_int,
is_categorical_dtype,
is_datetime64_dtype,
Expand Down Expand Up @@ -93,7 +92,6 @@
import pandas.core.algorithms as algorithms
from pandas.core.algorithms import (
factorize,
get_data_algo,
take_nd,
unique1d,
)
Expand Down Expand Up @@ -2749,8 +2747,6 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:

If `values` is known to be a Categorical, use recode_for_categories instead.
"""
dtype_equal = is_dtype_equal(values.dtype, categories.dtype)

if values.ndim > 1:
flat = values.ravel()
codes = _get_codes_for_values(flat, categories)
Expand All @@ -2762,30 +2758,9 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
# Categorical(array[Period, Period], categories=PeriodIndex(...))
cls = categories.dtype.construct_array_type()
values = maybe_cast_to_extension_array(cls, values)
if not isinstance(values, cls):
# exception raised in _from_sequence
values = ensure_object(values)
# error: Incompatible types in assignment (expression has type
# "ndarray", variable has type "Index")
categories = ensure_object(categories) # type: ignore[assignment]
elif not dtype_equal:
values = ensure_object(values)
# error: Incompatible types in assignment (expression has type "ndarray",
# variable has type "Index")
categories = ensure_object(categories) # type: ignore[assignment]

if isinstance(categories, ABCIndex):
return coerce_indexer_dtype(categories.get_indexer_for(values), categories)

# Only hit here when we've already coerced to object dtypee.

hash_klass, vals = get_data_algo(values)
# pandas/core/arrays/categorical.py:2661: error: Argument 1 to "get_data_algo" has
# incompatible type "Index"; expected "Union[ExtensionArray, ndarray]" [arg-type]
_, cats = get_data_algo(categories) # type: ignore[arg-type]
t = hash_klass(len(cats))
t.map_locations(cats)
return coerce_indexer_dtype(t.lookup(vals), cats)

codes = categories.get_indexer_for(values)
return coerce_indexer_dtype(codes, categories)


def recode_for_categories(
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/io/parser/dtypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ def test_categorical_coerces_timestamp(all_parsers):
tm.assert_frame_equal(result, expected)


@xfail_pyarrow
def test_categorical_coerces_timedelta(all_parsers):
parser = all_parsers
dtype = {"b": CategoricalDtype(pd.to_timedelta(["1H", "2H", "3H"]))}
Expand Down