Skip to content

Commit 3f18b40

Browse files
jbrockmendelim-vinicius
authored and
im-vinicius
committed
CLN: more accurate is_scalar checks (pandas-dev#52971)
* REF: avoid is_scalar * comment * infer_dtype->is_bool_array * fix invalid refs
1 parent a4d5ef6 commit 3f18b40

File tree

17 files changed

+53
-65
lines changed

17 files changed

+53
-65
lines changed

pandas/core/algorithms.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
is_integer_dtype,
4848
is_list_like,
4949
is_object_dtype,
50-
is_scalar,
5150
is_signed_integer_dtype,
5251
needs_i8_conversion,
5352
)
@@ -1321,15 +1320,15 @@ def searchsorted(
13211320
# Before searching below, we therefore try to give `value` the
13221321
# same dtype as `arr`, while guarding against integer overflows.
13231322
iinfo = np.iinfo(arr.dtype.type)
1324-
value_arr = np.array([value]) if is_scalar(value) else np.array(value)
1323+
value_arr = np.array([value]) if is_integer(value) else np.array(value)
13251324
if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all():
13261325
# value within bounds, so no overflow, so can convert value dtype
13271326
# to dtype of arr
13281327
dtype = arr.dtype
13291328
else:
13301329
dtype = value_arr.dtype
13311330

1332-
if is_scalar(value):
1331+
if is_integer(value):
13331332
# We know that value is int
13341333
value = cast(int, dtype.type(value))
13351334
else:

pandas/core/array_algos/replace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import numpy as np
1515

1616
from pandas.core.dtypes.common import (
17+
is_bool,
1718
is_re,
1819
is_re_compilable,
19-
is_scalar,
2020
)
2121
from pandas.core.dtypes.missing import isna
2222

@@ -72,7 +72,7 @@ def _check_comparison_types(
7272
Raises an error if the two arrays (a,b) cannot be compared.
7373
Otherwise, returns the comparison result as expected.
7474
"""
75-
if is_scalar(result) and isinstance(a, np.ndarray):
75+
if is_bool(result) and isinstance(a, np.ndarray):
7676
type_names = [type(a).__name__, type(b).__name__]
7777

7878
type_names[0] = f"ndarray(dtype={a.dtype})"

pandas/core/arrays/sparse/array.py

-2
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,6 @@ def searchsorted(
11201120
) -> npt.NDArray[np.intp] | np.intp:
11211121
msg = "searchsorted requires high memory usage."
11221122
warnings.warn(msg, PerformanceWarning, stacklevel=find_stack_level())
1123-
if not is_scalar(v):
1124-
v = np.asarray(v)
11251123
v = np.asarray(v)
11261124
return np.asarray(self, dtype=self.dtype.subtype).searchsorted(v, side, sorter)
11271125

pandas/core/common.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
ABCSeries,
4545
)
4646
from pandas.core.dtypes.inference import iterable_not_string
47-
from pandas.core.dtypes.missing import isna
4847

4948
if TYPE_CHECKING:
5049
from pandas._typing import (
@@ -129,7 +128,7 @@ def is_bool_indexer(key: Any) -> bool:
129128

130129
if not lib.is_bool_array(key_array):
131130
na_msg = "Cannot mask with non-boolean array containing NA / NaN values"
132-
if lib.infer_dtype(key_array) == "boolean" and isna(key_array).any():
131+
if lib.is_bool_array(key_array, skipna=True):
133132
# Don't raise on e.g. ["A", "B", np.nan], see
134133
# test_loc_getitem_list_of_labels_categoricalindex_with_na
135134
raise ValueError(na_msg)

pandas/core/dtypes/cast.py

-2
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,6 @@ def maybe_cast_pointwise_result(
455455
result maybe casted to the dtype.
456456
"""
457457

458-
assert not is_scalar(result)
459-
460458
if isinstance(dtype, ExtensionDtype):
461459
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
462460
# TODO: avoid this special-casing

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3914,7 +3914,7 @@ def isetitem(self, loc, value) -> None:
39143914
``frame[frame.columns[i]] = value``.
39153915
"""
39163916
if isinstance(value, DataFrame):
3917-
if is_scalar(loc):
3917+
if is_integer(loc):
39183918
loc = [loc]
39193919

39203920
if len(loc) != len(value.columns):

pandas/core/generic.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -8346,9 +8346,7 @@ def clip(
83468346
lower, upper = min(lower, upper), max(lower, upper)
83478347

83488348
# fast-path for scalars
8349-
if (lower is None or (is_scalar(lower) and is_number(lower))) and (
8350-
upper is None or (is_scalar(upper) and is_number(upper))
8351-
):
8349+
if (lower is None or is_number(lower)) and (upper is None or is_number(upper)):
83528350
return self._clip_with_scalar(lower, upper, inplace=inplace)
83538351

83548352
result = self

pandas/core/indexes/base.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,9 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs):
925925
if ufunc.nout == 2:
926926
# i.e. np.divmod, np.modf, np.frexp
927927
return tuple(self.__array_wrap__(x) for x in result)
928+
elif method == "reduce":
929+
result = lib.item_from_zerodim(result)
930+
return result
928931

929932
if result.dtype == np.float16:
930933
result = result.astype(np.float32)
@@ -937,11 +940,9 @@ def __array_wrap__(self, result, context=None):
937940
Gets called after a ufunc and other functions e.g. np.split.
938941
"""
939942
result = lib.item_from_zerodim(result)
940-
if (
941-
(not isinstance(result, Index) and is_bool_dtype(result.dtype))
942-
or lib.is_scalar(result)
943-
or np.ndim(result) > 1
944-
):
943+
if (not isinstance(result, Index) and is_bool_dtype(result.dtype)) or np.ndim(
944+
result
945+
) > 1:
945946
# exclude Index to avoid warning from is_bool_dtype deprecation;
946947
# in the Index case it doesn't matter which path we go down.
947948
# reached in plotting tests with e.g. np.nonzero(index)

pandas/core/nanops.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
is_integer,
4141
is_numeric_dtype,
4242
is_object_dtype,
43-
is_scalar,
4443
needs_i8_conversion,
4544
pandas_dtype,
4645
)
@@ -291,7 +290,6 @@ def _get_values(
291290
# In _get_values is only called from within nanops, and in all cases
292291
# with scalar fill_value. This guarantee is important for the
293292
# np.where call below
294-
assert is_scalar(fill_value)
295293

296294
mask = _maybe_get_mask(values, skipna, mask)
297295

@@ -876,12 +874,15 @@ def _get_counts_nanvar(
876874
d = count - dtype.type(ddof)
877875

878876
# always return NaN, never inf
879-
if is_scalar(count):
877+
if is_float(count):
880878
if count <= ddof:
881-
count = np.nan
879+
# error: Incompatible types in assignment (expression has type
880+
# "float", variable has type "Union[floating[Any], ndarray[Any,
881+
# dtype[floating[Any]]]]")
882+
count = np.nan # type: ignore[assignment]
882883
d = np.nan
883884
else:
884-
# count is not narrowed by is_scalar check
885+
# count is not narrowed by is_float check
885886
count = cast(np.ndarray, count)
886887
mask = count <= ddof
887888
if mask.any():
@@ -1444,8 +1445,8 @@ def _get_counts(
14441445
values_shape: Shape,
14451446
mask: npt.NDArray[np.bool_] | None,
14461447
axis: AxisInt | None,
1447-
dtype: np.dtype = np.dtype(np.float64),
1448-
) -> float | np.ndarray:
1448+
dtype: np.dtype[np.floating] = np.dtype(np.float64),
1449+
) -> np.floating | npt.NDArray[np.floating]:
14491450
"""
14501451
Get the count of non-null values along an axis
14511452
@@ -1476,7 +1477,7 @@ def _get_counts(
14761477
else:
14771478
count = values_shape[axis]
14781479

1479-
if is_scalar(count):
1480+
if is_integer(count):
14801481
return dtype.type(count)
14811482
return count.astype(dtype, copy=False)
14821483

pandas/core/ops/missing.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _fill_zeros(result, x, y):
5353
is_scalar_type = is_scalar(y)
5454

5555
if not is_variable_type and not is_scalar_type:
56+
# e.g. test_series_ops_name_retention with mod we get here with list/tuple
5657
return result
5758

5859
if is_scalar_type:

pandas/core/series.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1019,10 +1019,7 @@ def _get_with(self, key):
10191019
if not isinstance(key, (list, np.ndarray, ExtensionArray, Series, Index)):
10201020
key = list(key)
10211021

1022-
if isinstance(key, Index):
1023-
key_type = key.inferred_type
1024-
else:
1025-
key_type = lib.infer_dtype(key, skipna=False)
1022+
key_type = lib.infer_dtype(key, skipna=False)
10261023

10271024
# Note: The key_type == "boolean" case should be caught by the
10281025
# com.is_bool_indexer check in __getitem__

pandas/core/strings/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
TYPE_CHECKING,
66
Callable,
77
Literal,
8+
Sequence,
89
)
910

1011
import numpy as np
@@ -79,7 +80,7 @@ def _str_replace(
7980
pass
8081

8182
@abc.abstractmethod
82-
def _str_repeat(self, repeats):
83+
def _str_repeat(self, repeats: int | Sequence[int]):
8384
pass
8485

8586
@abc.abstractmethod

pandas/core/strings/object_array.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
TYPE_CHECKING,
99
Callable,
1010
Literal,
11+
Sequence,
12+
cast,
1113
)
1214
import unicodedata
1315

@@ -17,7 +19,6 @@
1719
import pandas._libs.missing as libmissing
1820
import pandas._libs.ops as libops
1921

20-
from pandas.core.dtypes.common import is_scalar
2122
from pandas.core.dtypes.missing import isna
2223

2324
from pandas.core.strings.base import BaseStringArrayMethods
@@ -177,14 +178,15 @@ def _str_replace(
177178

178179
return self._str_map(f, dtype=str)
179180

180-
def _str_repeat(self, repeats):
181-
if is_scalar(repeats):
181+
def _str_repeat(self, repeats: int | Sequence[int]):
182+
if lib.is_integer(repeats):
183+
rint = cast(int, repeats)
182184

183185
def scalar_rep(x):
184186
try:
185-
return bytes.__mul__(x, repeats)
187+
return bytes.__mul__(x, rint)
186188
except TypeError:
187-
return str.__mul__(x, repeats)
189+
return str.__mul__(x, rint)
188190

189191
return self._str_map(scalar_rep, dtype=str)
190192
else:
@@ -198,8 +200,11 @@ def rep(x, r):
198200
except TypeError:
199201
return str.__mul__(x, r)
200202

201-
repeats = np.asarray(repeats, dtype=object)
202-
result = libops.vec_binop(np.asarray(self), repeats, rep)
203+
result = libops.vec_binop(
204+
np.asarray(self),
205+
np.asarray(repeats, dtype=object),
206+
rep,
207+
)
203208
if isinstance(self, BaseStringArray):
204209
# Not going through map, so we have to do this here.
205210
result = type(self)._from_sequence(result)

pandas/core/tools/datetimes.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
is_integer_dtype,
5757
is_list_like,
5858
is_numeric_dtype,
59-
is_scalar,
6059
)
6160
from pandas.core.dtypes.dtypes import DatetimeTZDtype
6261
from pandas.core.dtypes.generic import (
@@ -599,8 +598,7 @@ def _adjust_to_origin(arg, origin, unit):
599598
else:
600599
# arg must be numeric
601600
if not (
602-
(is_scalar(arg) and (is_integer(arg) or is_float(arg)))
603-
or is_numeric_dtype(np.asarray(arg))
601+
(is_integer(arg) or is_float(arg)) or is_numeric_dtype(np.asarray(arg))
604602
):
605603
raise ValueError(
606604
f"'{arg}' is not compatible with origin='{origin}'; "

pandas/io/parsers/base_parser.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -250,23 +250,18 @@ def _has_complex_date_col(self) -> bool:
250250

251251
@final
252252
def _should_parse_dates(self, i: int) -> bool:
253-
if isinstance(self.parse_dates, bool):
254-
return self.parse_dates
253+
if lib.is_bool(self.parse_dates):
254+
return bool(self.parse_dates)
255255
else:
256256
if self.index_names is not None:
257257
name = self.index_names[i]
258258
else:
259259
name = None
260260
j = i if self.index_col is None else self.index_col[i]
261261

262-
if is_scalar(self.parse_dates):
263-
return (j == self.parse_dates) or (
264-
name is not None and name == self.parse_dates
265-
)
266-
else:
267-
return (j in self.parse_dates) or (
268-
name is not None and name in self.parse_dates
269-
)
262+
return (j in self.parse_dates) or (
263+
name is not None and name in self.parse_dates
264+
)
270265

271266
@final
272267
def _extract_multi_indexer_columns(
@@ -1370,13 +1365,12 @@ def _validate_parse_dates_arg(parse_dates):
13701365
"for the 'parse_dates' parameter"
13711366
)
13721367

1373-
if parse_dates is not None:
1374-
if is_scalar(parse_dates):
1375-
if not lib.is_bool(parse_dates):
1376-
raise TypeError(msg)
1377-
1378-
elif not isinstance(parse_dates, (list, dict)):
1379-
raise TypeError(msg)
1368+
if not (
1369+
parse_dates is None
1370+
or lib.is_bool(parse_dates)
1371+
or isinstance(parse_dates, (list, dict))
1372+
):
1373+
raise TypeError(msg)
13801374

13811375
return parse_dates
13821376

pandas/tests/frame/methods/test_reindex.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
)
2424
import pandas._testing as tm
2525
from pandas.api.types import CategoricalDtype as CDT
26-
import pandas.core.common as com
2726

2827

2928
class TestReindexSetIndex:
@@ -355,7 +354,7 @@ def test_reindex_frame_add_nat(self):
355354
result = df.reindex(range(15))
356355
assert np.issubdtype(result["B"].dtype, np.dtype("M8[ns]"))
357356

358-
mask = com.isna(result)["B"]
357+
mask = isna(result)["B"]
359358
assert mask[-5:].all()
360359
assert not mask[:-5].any()
361360

pandas/tests/frame/test_arithmetic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Series,
2222
)
2323
import pandas._testing as tm
24-
import pandas.core.common as com
2524
from pandas.core.computation import expressions as expr
2625
from pandas.core.computation.expressions import (
2726
_MIN_ELEMENTS,
@@ -1246,12 +1245,12 @@ def test_operators_none_as_na(self, op):
12461245
filled = df.fillna(np.nan)
12471246
result = op(df, 3)
12481247
expected = op(filled, 3).astype(object)
1249-
expected[com.isna(expected)] = None
1248+
expected[pd.isna(expected)] = None
12501249
tm.assert_frame_equal(result, expected)
12511250

12521251
result = op(df, df)
12531252
expected = op(filled, filled).astype(object)
1254-
expected[com.isna(expected)] = None
1253+
expected[pd.isna(expected)] = None
12551254
tm.assert_frame_equal(result, expected)
12561255

12571256
result = op(df, df.fillna(7))

0 commit comments

Comments
 (0)