Skip to content

Commit 7d1d9ee

Browse files
Backport PR #45242: CLN/PERF: avoid double-checks (#45266)
Co-authored-by: jbrockmendel <[email protected]>
1 parent ce648b6 commit 7d1d9ee

File tree

7 files changed

+62
-44
lines changed

7 files changed

+62
-44
lines changed

pandas/core/arrays/categorical.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1994,7 +1994,7 @@ def _formatter(self, boxed: bool = False):
19941994
# Defer to CategoricalFormatter's formatter.
19951995
return None
19961996

1997-
def _tidy_repr(self, max_vals=10, footer=True) -> str:
1997+
def _tidy_repr(self, max_vals: int = 10, footer: bool = True) -> str:
19981998
"""
19991999
a short repr displaying only max_vals and an optional (but default
20002000
footer)
@@ -2009,7 +2009,7 @@ def _tidy_repr(self, max_vals=10, footer=True) -> str:
20092009

20102010
return str(result)
20112011

2012-
def _repr_categories(self):
2012+
def _repr_categories(self) -> list[str]:
20132013
"""
20142014
return the base repr for the categories
20152015
"""

pandas/core/dtypes/cast.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2121,11 +2121,11 @@ def convert_scalar_for_putitemlike(scalar: Scalar, dtype: np.dtype) -> Scalar:
21212121
scalar = maybe_box_datetimelike(scalar, dtype)
21222122
return maybe_unbox_datetimelike(scalar, dtype)
21232123
else:
2124-
validate_numeric_casting(dtype, scalar)
2124+
_validate_numeric_casting(dtype, scalar)
21252125
return scalar
21262126

21272127

2128-
def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
2128+
def _validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
21292129
"""
21302130
Check that we can losslessly insert the given value into an array
21312131
with the given dtype.

pandas/core/frame.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -3867,9 +3867,13 @@ def _set_value(
38673867

38683868
series = self._get_item_cache(col)
38693869
loc = self.index.get_loc(index)
3870-
if not can_hold_element(series._values, value):
3871-
# We'll go through loc and end up casting.
3872-
raise TypeError
3870+
dtype = series.dtype
3871+
if isinstance(dtype, np.dtype) and dtype.kind not in ["m", "M"]:
3872+
# otherwise we have EA values, and this check will be done
3873+
# via setitem_inplace
3874+
if not can_hold_element(series._values, value):
3875+
# We'll go through loc and end up casting.
3876+
raise TypeError
38733877

38743878
series._mgr.setitem_inplace(loc, value)
38753879
# Note: trying to use series._set_value breaks tests in

pandas/core/indexes/base.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@
195195
str_t = str
196196

197197

198-
_o_dtype = np.dtype("object")
198+
_dtype_obj = np.dtype("object")
199199

200200

201201
def _maybe_return_indexers(meth: F) -> F:
@@ -487,7 +487,7 @@ def __new__(
487487
# maybe coerce to a sub-class
488488
arr = data
489489
else:
490-
arr = com.asarray_tuplesafe(data, dtype=np.dtype("object"))
490+
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
491491

492492
if dtype is None:
493493
arr = _maybe_cast_data_without_dtype(
@@ -524,7 +524,7 @@ def __new__(
524524
)
525525
# other iterable of some kind
526526

527-
subarr = com.asarray_tuplesafe(data, dtype=np.dtype("object"))
527+
subarr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
528528
if dtype is None:
529529
# with e.g. a list [1, 2, 3] casting to numeric is _not_ deprecated
530530
# error: Incompatible types in assignment (expression has type
@@ -609,9 +609,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
609609

610610
return Int64Index
611611

612-
# error: Non-overlapping equality check (left operand type: "dtype[Any]", right
613-
# operand type: "Type[object]")
614-
elif dtype == object: # type: ignore[comparison-overlap]
612+
elif dtype == _dtype_obj:
615613
# NB: assuming away MultiIndex
616614
return Index
617615

@@ -680,7 +678,7 @@ def _with_infer(cls, *args, **kwargs):
680678
warnings.filterwarnings("ignore", ".*the Index constructor", FutureWarning)
681679
result = cls(*args, **kwargs)
682680

683-
if result.dtype == object and not result._is_multi:
681+
if result.dtype == _dtype_obj and not result._is_multi:
684682
# error: Argument 1 to "maybe_convert_objects" has incompatible type
685683
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected
686684
# "ndarray[Any, Any]"
@@ -3248,7 +3246,7 @@ def _wrap_setop_result(self, other: Index, result) -> Index:
32483246
else:
32493247
result = self._shallow_copy(result, name=name)
32503248

3251-
if type(self) is Index and self.dtype != object:
3249+
if type(self) is Index and self.dtype != _dtype_obj:
32523250
# i.e. ExtensionArray-backed
32533251
# TODO(ExtensionIndex): revert this astype; it is a kludge to make
32543252
# it possible to split ExtensionEngine from ExtensionIndex PR.
@@ -5960,7 +5958,7 @@ def _find_common_type_compat(self, target) -> DtypeObj:
59605958
if is_signed_integer_dtype(self.dtype) or is_signed_integer_dtype(
59615959
target_dtype
59625960
):
5963-
return np.dtype("object")
5961+
return _dtype_obj
59645962

59655963
dtype = find_common_type([self.dtype, target_dtype])
59665964

@@ -5975,7 +5973,7 @@ def _find_common_type_compat(self, target) -> DtypeObj:
59755973
# FIXME: some cases where float64 cast can be lossy?
59765974
dtype = np.dtype(np.float64)
59775975
if dtype.kind == "c":
5978-
dtype = np.dtype(object)
5976+
dtype = _dtype_obj
59795977
return dtype
59805978

59815979
@final

pandas/core/internals/blocks.py

+37-21
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,14 @@ def shape(self) -> Shape:
355355
def dtype(self) -> DtypeObj:
356356
return self.values.dtype
357357

358-
def iget(self, i):
359-
return self.values[i]
358+
def iget(self, i: int | tuple[int, int] | tuple[slice, int]):
359+
# In the case where we have a tuple[slice, int], the slice will always
360+
# be slice(None)
361+
# Note: only reached with self.ndim == 2
362+
# Invalid index type "Union[int, Tuple[int, int], Tuple[slice, int]]"
363+
# for "Union[ndarray[Any, Any], ExtensionArray]"; expected type
364+
# "Union[int, integer[Any]]"
365+
return self.values[i] # type: ignore[index]
360366

361367
def set_inplace(self, locs, values) -> None:
362368
"""
@@ -1166,19 +1172,17 @@ def where(self, other, cond) -> list[Block]:
11661172
values = values.T
11671173

11681174
icond, noop = validate_putmask(values, ~cond)
1175+
if noop:
1176+
# GH-39595: Always return a copy; short-circuit up/downcasting
1177+
return self.copy()
11691178

11701179
if other is lib.no_default:
11711180
other = self.fill_value
11721181

11731182
if is_valid_na_for_dtype(other, self.dtype) and self.dtype != _dtype_obj:
11741183
other = self.fill_value
11751184

1176-
if noop:
1177-
# TODO: avoid the downcasting at the end in this case?
1178-
# GH-39595: Always return a copy
1179-
result = values.copy()
1180-
1181-
elif not self._can_hold_element(other):
1185+
if not self._can_hold_element(other):
11821186
# we cannot coerce, return a compat dtype
11831187
block = self.coerce_to_target_dtype(other)
11841188
blocks = block.where(orig_other, cond)
@@ -1350,11 +1354,7 @@ def where(self, other, cond) -> list[Block]:
13501354
try:
13511355
res_values = arr._where(cond, other).T
13521356
except (ValueError, TypeError) as err:
1353-
if isinstance(err, ValueError):
1354-
# TODO(2.0): once DTA._validate_setitem_value deprecation
1355-
# is enforced, stop catching ValueError here altogether
1356-
if "Timezones don't match" not in str(err):
1357-
raise
1357+
_catch_deprecated_value_error(err)
13581358

13591359
if is_interval_dtype(self.dtype):
13601360
# TestSetitemFloatIntervalWithIntIntervalValues
@@ -1397,10 +1397,7 @@ def putmask(self, mask, new) -> list[Block]:
13971397
# Caller is responsible for ensuring matching lengths
13981398
values._putmask(mask, new)
13991399
except (TypeError, ValueError) as err:
1400-
if isinstance(err, ValueError) and "Timezones don't match" not in str(err):
1401-
# TODO(2.0): remove catching ValueError at all since
1402-
# DTA raising here is deprecated
1403-
raise
1400+
_catch_deprecated_value_error(err)
14041401

14051402
if is_interval_dtype(self.dtype):
14061403
# Discussion about what we want to support in the general
@@ -1490,11 +1487,18 @@ def shape(self) -> Shape:
14901487
return (len(self.values),)
14911488
return len(self._mgr_locs), len(self.values)
14921489

1493-
def iget(self, col):
1490+
def iget(self, i: int | tuple[int, int] | tuple[slice, int]):
1491+
# In the case where we have a tuple[slice, int], the slice will always
1492+
# be slice(None)
1493+
# We _could_ make the annotation more specific, but mypy would
1494+
# complain about override mismatch:
1495+
# Literal[0] | tuple[Literal[0], int] | tuple[slice, int]
14941496

1495-
if self.ndim == 2 and isinstance(col, tuple):
1497+
# Note: only reached with self.ndim == 2
1498+
1499+
if isinstance(i, tuple):
14961500
# TODO(EA2D): unnecessary with 2D EAs
1497-
col, loc = col
1501+
col, loc = i
14981502
if not com.is_null_slice(col) and col != 0:
14991503
raise IndexError(f"{self} only contains one item")
15001504
elif isinstance(col, slice):
@@ -1503,7 +1507,7 @@ def iget(self, col):
15031507
return self.values[[loc]]
15041508
return self.values[loc]
15051509
else:
1506-
if col != 0:
1510+
if i != 0:
15071511
raise IndexError(f"{self} only contains one item")
15081512
return self.values
15091513

@@ -1829,6 +1833,18 @@ def fillna(
18291833
return [self.make_block_same_class(values=new_values)]
18301834

18311835

1836+
def _catch_deprecated_value_error(err: Exception) -> None:
1837+
"""
1838+
We catch ValueError for now, but only a specific one raised by DatetimeArray
1839+
which will no longer be raised in version.2.0.
1840+
"""
1841+
if isinstance(err, ValueError):
1842+
# TODO(2.0): once DTA._validate_setitem_value deprecation
1843+
# is enforced, stop catching ValueError here altogether
1844+
if "Timezones don't match" not in str(err):
1845+
raise
1846+
1847+
18321848
class DatetimeLikeBlock(NDArrayBackedExtensionBlock):
18331849
"""Block for datetime64[ns], timedelta64[ns]."""
18341850

pandas/core/reshape/reshape.py

-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pandas.core.dtypes.common import (
2020
ensure_platform_int,
2121
is_1d_only_ea_dtype,
22-
is_bool_dtype,
2322
is_extension_array_dtype,
2423
is_integer,
2524
is_integer_dtype,
@@ -279,9 +278,6 @@ def get_new_values(self, values, fill_value=None):
279278
if needs_i8_conversion(values.dtype):
280279
sorted_values = sorted_values.view("i8")
281280
new_values = new_values.view("i8")
282-
elif is_bool_dtype(values.dtype):
283-
sorted_values = sorted_values.astype("object")
284-
new_values = new_values.astype("object")
285281
else:
286282
sorted_values = sorted_values.astype(name, copy=False)
287283

pandas/core/series.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1145,8 +1145,12 @@ def __setitem__(self, key, value) -> None:
11451145

11461146
def _set_with_engine(self, key, value) -> None:
11471147
loc = self.index.get_loc(key)
1148-
if not can_hold_element(self._values, value):
1149-
raise ValueError
1148+
dtype = self.dtype
1149+
if isinstance(dtype, np.dtype) and dtype.kind not in ["m", "M"]:
1150+
# otherwise we have EA values, and this check will be done
1151+
# via setitem_inplace
1152+
if not can_hold_element(self._values, value):
1153+
raise ValueError
11501154

11511155
# this is equivalent to self._values[key] = value
11521156
self._mgr.setitem_inplace(loc, value)

0 commit comments

Comments
 (0)