Skip to content

Commit a37451f

Browse files
authored
ENH: add validate parameter to Categorical.from_codes (#53122)
* ENH: add validate parameter to Categorical.from_codes
* add GH number
* simplify a bit
* add segfault warning
1 parent 2c164f0 commit a37451f

File tree

7 files changed

+59
-25
lines changed

7 files changed

+59
-25
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Other enhancements
9595
- Improved error message when creating a DataFrame with empty data (0 rows), no index and an incorrect number of columns. (:issue:`52084`)
9696
- Let :meth:`DataFrame.to_feather` accept a non-default :class:`Index` and non-string column names (:issue:`51787`)
9797
- Performance improvement in :func:`read_csv` (:issue:`52632`) with ``engine="c"``
98+
- :meth:`Categorical.from_codes` now accepts a ``validate`` parameter; pass ``validate=False`` to skip validation of the codes (:issue:`50975`)
9899
- Performance improvement in :func:`concat` with homogeneous ``np.float64`` or ``np.float32`` dtypes (:issue:`52685`)
99100
- Performance improvement in :meth:`DataFrame.filter` when ``items`` is given (:issue:`52941`)
100101
-

pandas/core/arrays/categorical.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,12 @@ def _from_inferred_categories(
649649

650650
@classmethod
651651
def from_codes(
652-
cls, codes, categories=None, ordered=None, dtype: Dtype | None = None
652+
cls,
653+
codes,
654+
categories=None,
655+
ordered=None,
656+
dtype: Dtype | None = None,
657+
validate: bool = True,
653658
) -> Self:
654659
"""
655660
Make a Categorical type from codes and categories or dtype.
@@ -677,6 +682,12 @@ def from_codes(
677682
dtype : CategoricalDtype or "category", optional
678683
If :class:`CategoricalDtype`, cannot be used together with
679684
`categories` or `ordered`.
685+
validate : bool, default True
686+
If True, validate that the codes are valid for the dtype.
687+
If False, don't validate that the codes are valid. Be careful about skipping
688+
validation, as invalid codes can lead to severe problems, such as segfaults.
689+
690+
.. versionadded:: 2.1.0
680691
681692
Returns
682693
-------
@@ -699,18 +710,9 @@ def from_codes(
699710
)
700711
raise ValueError(msg)
701712

702-
if isinstance(codes, ExtensionArray) and is_integer_dtype(codes.dtype):
703-
# Avoid the implicit conversion of Int to object
704-
if isna(codes).any():
705-
raise ValueError("codes cannot contain NA values")
706-
codes = codes.to_numpy(dtype=np.int64)
707-
else:
708-
codes = np.asarray(codes)
709-
if len(codes) and codes.dtype.kind not in "iu":
710-
raise ValueError("codes need to be array-like integers")
711-
712-
if len(codes) and (codes.max() >= len(dtype.categories) or codes.min() < -1):
713-
raise ValueError("codes need to be between -1 and len(categories)-1")
713+
if validate:
714+
# beware: invalid codes that bypass this check may cause segfaults
715+
codes = cls._validate_codes_for_dtype(codes, dtype=dtype)
714716

715717
return cls._simple_new(codes, dtype=dtype)
716718

@@ -1325,7 +1327,7 @@ def map(
13251327

13261328
if new_categories.is_unique and not new_categories.hasnans and na_val is np.nan:
13271329
new_dtype = CategoricalDtype(new_categories, ordered=self.ordered)
1328-
return self.from_codes(self._codes.copy(), dtype=new_dtype)
1330+
return self.from_codes(self._codes.copy(), dtype=new_dtype, validate=False)
13291331

13301332
if has_nans:
13311333
new_categories = new_categories.insert(len(new_categories), na_val)
@@ -1378,6 +1380,22 @@ def _validate_scalar(self, fill_value):
13781380
) from None
13791381
return fill_value
13801382

1383+
@classmethod
1384+
def _validate_codes_for_dtype(cls, codes, *, dtype: CategoricalDtype) -> np.ndarray:
1385+
if isinstance(codes, ExtensionArray) and is_integer_dtype(codes.dtype):
1386+
# Avoid the implicit conversion of Int to object
1387+
if isna(codes).any():
1388+
raise ValueError("codes cannot contain NA values")
1389+
codes = codes.to_numpy(dtype=np.int64)
1390+
else:
1391+
codes = np.asarray(codes)
1392+
if len(codes) and codes.dtype.kind not in "iu":
1393+
raise ValueError("codes need to be array-like integers")
1394+
1395+
if len(codes) and (codes.max() >= len(dtype.categories) or codes.min() < -1):
1396+
raise ValueError("codes need to be between -1 and len(categories)-1")
1397+
return codes
1398+
13811399
# -------------------------------------------------------------
13821400

13831401
@ravel_compat
@@ -2724,7 +2742,7 @@ def factorize_from_iterable(values) -> tuple[np.ndarray, Index]:
27242742
# The Categorical we want to build has the same categories
27252743
# as values but its codes are by definition [0, ..., n_categories - 1]
27262744
cat_codes = np.arange(len(values.categories), dtype=values.codes.dtype)
2727-
cat = Categorical.from_codes(cat_codes, dtype=values.dtype)
2745+
cat = Categorical.from_codes(cat_codes, dtype=values.dtype, validate=False)
27282746

27292747
categories = CategoricalIndex(cat)
27302748
codes = values.codes

pandas/core/groupby/grouper.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def group_index(self) -> Index:
721721
if self._sort and (codes == len(uniques)).any():
722722
# Add NA value on the end when sorting
723723
uniques = Categorical.from_codes(
724-
np.append(uniques.codes, [-1]), uniques.categories
724+
np.append(uniques.codes, [-1]), uniques.categories, validate=False
725725
)
726726
elif len(codes) > 0:
727727
# Need to determine proper placement of NA value when not sorting
@@ -730,8 +730,9 @@ def group_index(self) -> Index:
730730
if cat.codes[na_idx] < 0:
731731
# count number of unique codes that comes before the nan value
732732
na_unique_idx = algorithms.nunique_ints(cat.codes[:na_idx])
733+
new_codes = np.insert(uniques.codes, na_unique_idx, -1)
733734
uniques = Categorical.from_codes(
734-
np.insert(uniques.codes, na_unique_idx, -1), uniques.categories
735+
new_codes, uniques.categories, validate=False
735736
)
736737
return Index._with_infer(uniques, name=self.name)
737738

@@ -754,7 +755,7 @@ def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
754755
ucodes = np.arange(len(categories))
755756

756757
uniques = Categorical.from_codes(
757-
codes=ucodes, categories=categories, ordered=cat.ordered
758+
codes=ucodes, categories=categories, ordered=cat.ordered, validate=False
758759
)
759760

760761
codes = cat.codes
@@ -800,7 +801,8 @@ def _codes_and_uniques(self) -> tuple[npt.NDArray[np.signedinteger], ArrayLike]:
800801

801802
@cache_readonly
802803
def groups(self) -> dict[Hashable, np.ndarray]:
803-
return self._index.groupby(Categorical.from_codes(self.codes, self.group_index))
804+
cats = Categorical.from_codes(self.codes, self.group_index, validate=False)
805+
return self._index.groupby(cats)
804806

805807

806808
def get_grouper(

pandas/core/indexes/multi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2393,7 +2393,7 @@ def cats(level_codes):
23932393
)
23942394

23952395
return [
2396-
Categorical.from_codes(level_codes, cats(level_codes), ordered=True)
2396+
Categorical.from_codes(level_codes, cats(level_codes), True, validate=False)
23972397
for level_codes in self.codes
23982398
]
23992399

@@ -2583,7 +2583,7 @@ def _get_indexer_level_0(self, target) -> npt.NDArray[np.intp]:
25832583
"""
25842584
lev = self.levels[0]
25852585
codes = self._codes[0]
2586-
cat = Categorical.from_codes(codes=codes, categories=lev)
2586+
cat = Categorical.from_codes(codes=codes, categories=lev, validate=False)
25872587
ci = Index(cat)
25882588
return ci.get_indexer_for(target)
25892589

pandas/core/reshape/tile.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ def _bins_to_cuts(
411411
if isinstance(bins, IntervalIndex):
412412
# we have a fast-path here
413413
ids = bins.get_indexer(x)
414-
result = Categorical.from_codes(ids, categories=bins, ordered=True)
414+
cat_dtype = CategoricalDtype(bins, ordered=True)
415+
result = Categorical.from_codes(ids, dtype=cat_dtype, validate=False)
415416
return result, bins
416417

417418
unique_bins = algos.unique(bins)

pandas/io/pytables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2520,7 +2520,7 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str):
25202520
codes[codes != -1] -= mask.astype(int).cumsum()._values
25212521

25222522
converted = Categorical.from_codes(
2523-
codes, categories=categories, ordered=ordered
2523+
codes, categories=categories, ordered=ordered, validate=False
25242524
)
25252525

25262526
else:

pandas/tests/arrays/categorical/test_constructors.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -509,12 +509,13 @@ def test_construction_with_null(self, klass, nulls_fixture):
509509

510510
tm.assert_categorical_equal(result, expected)
511511

512-
def test_from_codes_nullable_int_categories(self, any_numeric_ea_dtype):
512+
@pytest.mark.parametrize("validate", [True, False])
513+
def test_from_codes_nullable_int_categories(self, any_numeric_ea_dtype, validate):
513514
# GH#39649
514515
cats = pd.array(range(5), dtype=any_numeric_ea_dtype)
515516
codes = np.random.randint(5, size=3)
516517
dtype = CategoricalDtype(cats)
517-
arr = Categorical.from_codes(codes, dtype=dtype)
518+
arr = Categorical.from_codes(codes, dtype=dtype, validate=validate)
518519
assert arr.categories.dtype == cats.dtype
519520
tm.assert_index_equal(arr.categories, Index(cats))
520521

@@ -525,6 +526,17 @@ def test_from_codes_empty(self):
525526

526527
tm.assert_categorical_equal(result, expected)
527528

529+
@pytest.mark.parametrize("validate", [True, False])
530+
def test_from_codes_validate(self, validate):
531+
# GH53122
532+
dtype = CategoricalDtype(["a", "b"])
533+
if validate:
534+
with pytest.raises(ValueError, match="codes need to be between "):
535+
Categorical.from_codes([4, 5], dtype=dtype, validate=validate)
536+
else:
537+
# passes despite the incorrect codes; with validate=False that is the user's responsibility
538+
Categorical.from_codes([4, 5], dtype=dtype, validate=validate)
539+
528540
def test_from_codes_too_few_categories(self):
529541
dtype = CategoricalDtype(categories=[1, 2])
530542
msg = "codes need to be between "

0 commit comments

Comments
 (0)