Skip to content

Commit d71ec09

Browse files
authored
REF: simplify CategoricalIndex.__new__ (#38605)
1 parent fc1df2e commit d71ec09

File tree

4 files changed

+29
-23
lines changed

4 files changed

+29
-23
lines changed

pandas/core/arrays/categorical.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,13 @@ class Categorical(NDArrayBackedExtensionArray, PandasObject, ObjectStringArrayMi
298298
_can_hold_na = True
299299

300300
def __init__(
301-
self, values, categories=None, ordered=None, dtype=None, fastpath=False
301+
self,
302+
values,
303+
categories=None,
304+
ordered=None,
305+
dtype=None,
306+
fastpath=False,
307+
copy: bool = True,
302308
):
303309

304310
dtype = CategoricalDtype._from_values_or_dtype(
@@ -359,9 +365,9 @@ def __init__(
359365
dtype = CategoricalDtype(categories, dtype.ordered)
360366

361367
elif is_categorical_dtype(values.dtype):
362-
old_codes = extract_array(values).codes
368+
old_codes = extract_array(values)._codes
363369
codes = recode_for_categories(
364-
old_codes, values.dtype.categories, dtype.categories
370+
old_codes, values.dtype.categories, dtype.categories, copy=copy
365371
)
366372

367373
else:
@@ -389,7 +395,7 @@ def _constructor(self) -> Type["Categorical"]:
389395

390396
@classmethod
391397
def _from_sequence(cls, scalars, *, dtype=None, copy=False):
392-
return Categorical(scalars, dtype=dtype)
398+
return Categorical(scalars, dtype=dtype, copy=copy)
393399

394400
def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
395401
"""

pandas/core/indexes/category.py

+7-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
is_categorical_dtype,
1616
is_scalar,
1717
)
18-
from pandas.core.dtypes.dtypes import CategoricalDtype
1918
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna, notna
2019

2120
from pandas.core import accessor
@@ -184,28 +183,18 @@ def __new__(
184183
cls, data=None, categories=None, ordered=None, dtype=None, copy=False, name=None
185184
):
186185

187-
dtype = CategoricalDtype._from_values_or_dtype(data, categories, ordered, dtype)
188-
189186
name = maybe_extract_name(name, data, cls)
190187

191-
if not is_categorical_dtype(data):
188+
if is_scalar(data):
192189
# don't allow scalars
193190
# if data is None, then categories must be provided
194-
if is_scalar(data):
195-
if data is not None or categories is None:
196-
raise cls._scalar_data_error(data)
197-
data = []
198-
199-
assert isinstance(dtype, CategoricalDtype), dtype
200-
data = extract_array(data, extract_numpy=True)
191+
if data is not None or categories is None:
192+
raise cls._scalar_data_error(data)
193+
data = []
201194

202-
if not isinstance(data, Categorical):
203-
data = Categorical(data, dtype=dtype)
204-
elif isinstance(dtype, CategoricalDtype) and dtype != data.dtype:
205-
# we want to silently ignore dtype='category'
206-
data = data._set_dtype(dtype)
207-
208-
data = data.copy() if copy else data
195+
data = Categorical(
196+
data, categories=categories, ordered=ordered, dtype=dtype, copy=copy
197+
)
209198

210199
return cls._simple_new(data, name=name)
211200

pandas/tests/arrays/categorical/test_constructors.py

+11
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,14 @@ def test_categorical_extension_array_nullable(self, nulls_fixture):
699699
result = Categorical(arr)
700700
expected = Categorical(Series([pd.NA, pd.NA], dtype="object"))
701701
tm.assert_categorical_equal(result, expected)
702+
703+
def test_from_sequence_copy(self):
704+
cat = Categorical(np.arange(5).repeat(2))
705+
result = Categorical._from_sequence(cat, dtype=None, copy=False)
706+
707+
# more generally, we'd be OK with a view
708+
assert result._codes is cat._codes
709+
710+
result = Categorical._from_sequence(cat, dtype=None, copy=True)
711+
712+
assert not np.shares_memory(result._codes, cat._codes)

pandas/tests/indexes/categorical/test_category.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def test_ensure_copied_data(self, index):
304304
assert _base(index.values) is not _base(result.values)
305305

306306
result = CategoricalIndex(index.values, copy=False)
307-
assert _base(index.values) is _base(result.values)
307+
assert result._data._codes is index._data._codes
308308

309309
def test_frame_repr(self):
310310
df = pd.DataFrame({"A": [1, 2, 3]}, index=CategoricalIndex(["a", "b", "c"]))

0 commit comments

Comments
 (0)