Skip to content

Commit 4084eb3

Browse files
jbrockmendelproost
authored andcommitted
CLN: simplify core.algorithms (pandas-dev#29199)
1 parent 06aca8b commit 4084eb3

File tree

4 files changed

+42
-64
lines changed

4 files changed

+42
-64
lines changed

pandas/core/algorithms.py

+35-42
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from pandas.core.dtypes.missing import isna, na_value_for_dtype
4848

4949
from pandas.core import common as com
50-
from pandas.core.construction import array
50+
from pandas.core.construction import array, extract_array
5151
from pandas.core.indexers import validate_indices
5252

5353
_shared_docs = {} # type: Dict[str, str]
@@ -82,9 +82,12 @@ def _ensure_data(values, dtype=None):
8282
"""
8383

8484
# we check some simple dtypes first
85+
if is_object_dtype(dtype):
86+
return ensure_object(np.asarray(values)), "object", "object"
87+
elif is_object_dtype(values) and dtype is None:
88+
return ensure_object(np.asarray(values)), "object", "object"
89+
8590
try:
86-
if is_object_dtype(dtype):
87-
return ensure_object(np.asarray(values)), "object", "object"
8891
if is_bool_dtype(values) or is_bool_dtype(dtype):
8992
# we are actually coercing to uint64
9093
# until our algos support uint8 directly (see TODO)
@@ -95,8 +98,6 @@ def _ensure_data(values, dtype=None):
9598
return ensure_uint64(values), "uint64", "uint64"
9699
elif is_float_dtype(values) or is_float_dtype(dtype):
97100
return ensure_float64(values), "float64", "float64"
98-
elif is_object_dtype(values) and dtype is None:
99-
return ensure_object(np.asarray(values)), "object", "object"
100101
elif is_complex_dtype(values) or is_complex_dtype(dtype):
101102

102103
# ignore the fact that we are casting to float
@@ -207,11 +208,11 @@ def _ensure_arraylike(values):
207208

208209

209210
_hashtables = {
210-
"float64": (htable.Float64HashTable, htable.Float64Vector),
211-
"uint64": (htable.UInt64HashTable, htable.UInt64Vector),
212-
"int64": (htable.Int64HashTable, htable.Int64Vector),
213-
"string": (htable.StringHashTable, htable.ObjectVector),
214-
"object": (htable.PyObjectHashTable, htable.ObjectVector),
211+
"float64": htable.Float64HashTable,
212+
"uint64": htable.UInt64HashTable,
213+
"int64": htable.Int64HashTable,
214+
"string": htable.StringHashTable,
215+
"object": htable.PyObjectHashTable,
215216
}
216217

217218

@@ -223,11 +224,9 @@ def _get_hashtable_algo(values):
223224
224225
Returns
225226
-------
226-
tuples(hashtable class,
227-
vector class,
228-
values,
229-
dtype,
230-
ndtype)
227+
htable : HashTable subclass
228+
values : ndarray
229+
dtype : str or dtype
231230
"""
232231
values, dtype, ndtype = _ensure_data(values)
233232

@@ -238,23 +237,21 @@ def _get_hashtable_algo(values):
238237
# StringHashTable and ObjectHashtable
239238
if lib.infer_dtype(values, skipna=False) in ["string"]:
240239
ndtype = "string"
241-
else:
242-
ndtype = "object"
243240

244-
htable, table = _hashtables[ndtype]
245-
return (htable, table, values, dtype, ndtype)
241+
htable = _hashtables[ndtype]
242+
return htable, values, dtype
246243

247244

248245
def _get_values_for_rank(values):
249246
if is_categorical_dtype(values):
250247
values = values._values_for_rank()
251248

252-
values, dtype, ndtype = _ensure_data(values)
253-
return values, dtype, ndtype
249+
values, _, ndtype = _ensure_data(values)
250+
return values, ndtype
254251

255252

256-
def _get_data_algo(values, func_map):
257-
values, dtype, ndtype = _get_values_for_rank(values)
253+
def _get_data_algo(values):
254+
values, ndtype = _get_values_for_rank(values)
258255

259256
if ndtype == "object":
260257

@@ -264,7 +261,7 @@ def _get_data_algo(values, func_map):
264261
if lib.infer_dtype(values, skipna=False) in ["string"]:
265262
ndtype = "string"
266263

267-
f = func_map.get(ndtype, func_map["object"])
264+
f = _hashtables.get(ndtype, _hashtables["object"])
268265

269266
return f, values
270267

@@ -295,7 +292,7 @@ def match(to_match, values, na_sentinel=-1):
295292
match : ndarray of integers
296293
"""
297294
values = com.asarray_tuplesafe(values)
298-
htable, _, values, dtype, ndtype = _get_hashtable_algo(values)
295+
htable, values, dtype = _get_hashtable_algo(values)
299296
to_match, _, _ = _ensure_data(to_match, dtype)
300297
table = htable(min(len(to_match), 1000000))
301298
table.map_locations(values)
@@ -398,7 +395,7 @@ def unique(values):
398395
return values.unique()
399396

400397
original = values
401-
htable, _, values, dtype, ndtype = _get_hashtable_algo(values)
398+
htable, values, _ = _get_hashtable_algo(values)
402399

403400
table = htable(len(values))
404401
uniques = table.unique(values)
@@ -480,7 +477,8 @@ def isin(comps, values):
480477

481478

482479
def _factorize_array(values, na_sentinel=-1, size_hint=None, na_value=None):
483-
"""Factorize an array-like to labels and uniques.
480+
"""
481+
Factorize an array-like to labels and uniques.
484482
485483
This doesn't do any coercion of types or unboxing before factorization.
486484
@@ -498,9 +496,10 @@ def _factorize_array(values, na_sentinel=-1, size_hint=None, na_value=None):
498496
499497
Returns
500498
-------
501-
labels, uniques : ndarray
499+
labels : ndarray
500+
uniques : ndarray
502501
"""
503-
(hash_klass, _), values = _get_data_algo(values, _hashtables)
502+
hash_klass, values = _get_data_algo(values)
504503

505504
table = hash_klass(size_hint or len(values))
506505
uniques, labels = table.factorize(
@@ -652,17 +651,13 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
652651
original = values
653652

654653
if is_extension_array_dtype(values):
655-
values = getattr(values, "_values", values)
654+
values = extract_array(values)
656655
labels, uniques = values.factorize(na_sentinel=na_sentinel)
657656
dtype = original.dtype
658657
else:
659658
values, dtype, _ = _ensure_data(values)
660659

661-
if (
662-
is_datetime64_any_dtype(original)
663-
or is_timedelta64_dtype(original)
664-
or is_period_dtype(original)
665-
):
660+
if original.dtype.kind in ["m", "M"]:
666661
na_value = na_value_for_dtype(original.dtype)
667662
else:
668663
na_value = None
@@ -835,7 +830,7 @@ def duplicated(values, keep="first"):
835830
duplicated : ndarray
836831
"""
837832

838-
values, dtype, ndtype = _ensure_data(values)
833+
values, _, ndtype = _ensure_data(values)
839834
f = getattr(htable, "duplicated_{dtype}".format(dtype=ndtype))
840835
return f(values, keep=keep)
841836

@@ -872,7 +867,7 @@ def mode(values, dropna: bool = True):
872867
mask = values.isnull()
873868
values = values[~mask]
874869

875-
values, dtype, ndtype = _ensure_data(values)
870+
values, _, ndtype = _ensure_data(values)
876871

877872
f = getattr(htable, "mode_{dtype}".format(dtype=ndtype))
878873
result = f(values, dropna=dropna)
@@ -910,7 +905,7 @@ def rank(values, axis=0, method="average", na_option="keep", ascending=True, pct
910905
(e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1).
911906
"""
912907
if values.ndim == 1:
913-
values, _, _ = _get_values_for_rank(values)
908+
values, _ = _get_values_for_rank(values)
914909
ranks = algos.rank_1d(
915910
values,
916911
ties_method=method,
@@ -919,7 +914,7 @@ def rank(values, axis=0, method="average", na_option="keep", ascending=True, pct
919914
pct=pct,
920915
)
921916
elif values.ndim == 2:
922-
values, _, _ = _get_values_for_rank(values)
917+
values, _ = _get_values_for_rank(values)
923918
ranks = algos.rank_2d(
924919
values,
925920
axis=axis,
@@ -1634,9 +1629,7 @@ def take_nd(
16341629
if is_extension_array_dtype(arr):
16351630
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
16361631

1637-
if isinstance(arr, (ABCIndexClass, ABCSeries)):
1638-
arr = arr._values
1639-
1632+
arr = extract_array(arr)
16401633
arr = np.asarray(arr)
16411634

16421635
if indexer is None:

pandas/core/arrays/categorical.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,7 @@
4747
from pandas.core import ops
4848
from pandas.core.accessor import PandasDelegate, delegate_names
4949
import pandas.core.algorithms as algorithms
50-
from pandas.core.algorithms import (
51-
_get_data_algo,
52-
_hashtables,
53-
factorize,
54-
take,
55-
take_1d,
56-
unique1d,
57-
)
50+
from pandas.core.algorithms import _get_data_algo, factorize, take, take_1d, unique1d
5851
from pandas.core.base import NoNewAttributesMixin, PandasObject, _shared_docs
5952
import pandas.core.common as com
6053
from pandas.core.construction import array, extract_array, sanitize_array
@@ -2097,7 +2090,6 @@ def __setitem__(self, key, value):
20972090
"""
20982091
Item assignment.
20992092
2100-
21012093
Raises
21022094
------
21032095
ValueError
@@ -2631,8 +2623,8 @@ def _get_codes_for_values(values, categories):
26312623
values = ensure_object(values)
26322624
categories = ensure_object(categories)
26332625

2634-
(hash_klass, vec_klass), vals = _get_data_algo(values, _hashtables)
2635-
(_, _), cats = _get_data_algo(categories, _hashtables)
2626+
hash_klass, vals = _get_data_algo(values)
2627+
_, cats = _get_data_algo(categories)
26362628
t = hash_klass(len(cats))
26372629
t.map_locations(cats)
26382630
return coerce_indexer_dtype(t.lookup(vals), cats)

pandas/core/dtypes/cast.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ensure_str,
2222
is_bool,
2323
is_bool_dtype,
24-
is_categorical_dtype,
2524
is_complex,
2625
is_complex_dtype,
2726
is_datetime64_dtype,
@@ -1325,14 +1324,10 @@ def construct_1d_arraylike_from_scalar(value, length, dtype):
13251324
np.ndarray / pandas type of length, filled with value
13261325
13271326
"""
1328-
if is_datetime64tz_dtype(dtype):
1329-
from pandas import DatetimeIndex
1330-
1331-
subarr = DatetimeIndex([value] * length, dtype=dtype)
1332-
elif is_categorical_dtype(dtype):
1333-
from pandas import Categorical
1327+
if is_extension_array_dtype(dtype):
1328+
cls = dtype.construct_array_type()
1329+
subarr = cls._from_sequence([value] * length, dtype=dtype)
13341330

1335-
subarr = Categorical([value] * length, dtype=dtype)
13361331
else:
13371332
if not isinstance(dtype, (np.dtype, type(np.dtype))):
13381333
dtype = dtype.dtype

pandas/core/sorting.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,7 @@ def sort_mixed(values):
484484

485485
if sorter is None:
486486
# mixed types
487-
(hash_klass, _), values = algorithms._get_data_algo(
488-
values, algorithms._hashtables
489-
)
487+
hash_klass, values = algorithms._get_data_algo(values)
490488
t = hash_klass(len(values))
491489
t.map_locations(values)
492490
sorter = ensure_platform_int(t.lookup(ordered))

0 commit comments

Comments
 (0)