Skip to content

PERF: avoid copies in lib.infer_dtype #45057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 53 additions & 38 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ from pandas._libs.missing cimport (
is_matching_na,
is_null_datetime64,
is_null_timedelta64,
isnaobj,
)
from pandas._libs.tslibs.conversion cimport convert_to_tsobject
from pandas._libs.tslibs.nattype cimport (
Expand Down Expand Up @@ -1454,6 +1453,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
ndarray values
bint seen_pdnat = False
bint seen_val = False
flatiter it

if util.is_array(value):
values = value
Expand Down Expand Up @@ -1491,24 +1491,22 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
# This should not be reached
values = values.astype(object)

# for f-contiguous array 1000 x 1000, passing order="K" gives 5000x speedup
values = values.ravel(order="K")

if skipna:
values = values[~isnaobj(values)]

n = cnp.PyArray_SIZE(values)
if n == 0:
return "empty"

# Iterate until we find our first valid value. We will use this
# value to decide which of the is_foo_array functions to call.
it = PyArray_IterNew(values)
for i in range(n):
val = values[i]
# The PyArray_GETITEM and PyArray_ITER_NEXT are faster
# equivalents to `val = values[i]`
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
PyArray_ITER_NEXT(it)

# do not use checknull to keep
# np.datetime64('nat') and np.timedelta64('nat')
if val is None or util.is_nan(val):
if val is None or util.is_nan(val) or val is C_NA:
pass
elif val is NaT:
seen_pdnat = True
Expand All @@ -1520,23 +1518,25 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
if seen_val is False and seen_pdnat is True:
return "datetime"
# float/object nan is handled in latter logic
if seen_val is False and skipna:
return "empty"

if util.is_datetime64_object(val):
if is_datetime64_array(values):
if is_datetime64_array(values, skipna=skipna):
return "datetime64"

elif is_timedelta(val):
if is_timedelta_or_timedelta64_array(values):
if is_timedelta_or_timedelta64_array(values, skipna=skipna):
return "timedelta"

elif util.is_integer_object(val):
# ordering matters here; this check must come after the is_timedelta
# check otherwise numpy timedelta64 objects would come through here

if is_integer_array(values):
if is_integer_array(values, skipna=skipna):
return "integer"
elif is_integer_float_array(values):
if is_integer_na_array(values):
elif is_integer_float_array(values, skipna=skipna):
if is_integer_na_array(values, skipna=skipna):
return "integer-na"
else:
return "mixed-integer-float"
Expand All @@ -1557,7 +1557,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
return "time"

elif is_decimal(val):
if is_decimal_array(values):
if is_decimal_array(values, skipna=skipna):
return "decimal"

elif util.is_complex_object(val):
Expand All @@ -1567,8 +1567,8 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
elif util.is_float_object(val):
if is_float_array(values):
return "floating"
elif is_integer_float_array(values):
if is_integer_na_array(values):
elif is_integer_float_array(values, skipna=skipna):
if is_integer_na_array(values, skipna=skipna):
return "integer-na"
else:
return "mixed-integer-float"
Expand All @@ -1586,15 +1586,18 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
return "bytes"

elif is_period_object(val):
if is_period_array(values):
if is_period_array(values, skipna=skipna):
return "period"

elif is_interval(val):
if is_interval_array(values):
return "interval"

cnp.PyArray_ITER_RESET(it)
for i in range(n):
val = values[i]
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
PyArray_ITER_NEXT(it)

if util.is_integer_object(val):
return "mixed-integer"

Expand Down Expand Up @@ -1823,10 +1826,11 @@ cdef class IntegerValidator(Validator):


# Note: only python-exposed for tests
cpdef bint is_integer_array(ndarray values):
cpdef bint is_integer_array(ndarray values, bint skipna=True):
cdef:
IntegerValidator validator = IntegerValidator(len(values),
values.dtype)
values.dtype,
skipna=skipna)
return validator.validate(values)


Expand All @@ -1837,10 +1841,10 @@ cdef class IntegerNaValidator(Validator):
or (util.is_nan(value) and util.is_float_object(value)))


cdef bint is_integer_na_array(ndarray values):
cdef bint is_integer_na_array(ndarray values, bint skipna=True):
cdef:
IntegerNaValidator validator = IntegerNaValidator(len(values),
values.dtype)
values.dtype, skipna=skipna)
return validator.validate(values)


Expand All @@ -1853,10 +1857,11 @@ cdef class IntegerFloatValidator(Validator):
return issubclass(self.dtype.type, np.integer)


cdef bint is_integer_float_array(ndarray values):
cdef bint is_integer_float_array(ndarray values, bint skipna=True):
cdef:
IntegerFloatValidator validator = IntegerFloatValidator(len(values),
values.dtype)
values.dtype,
skipna=skipna)
return validator.validate(values)


Expand Down Expand Up @@ -1900,9 +1905,11 @@ cdef class DecimalValidator(Validator):
return is_decimal(value)


cdef bint is_decimal_array(ndarray values):
cdef bint is_decimal_array(ndarray values, bint skipna=False):
cdef:
DecimalValidator validator = DecimalValidator(len(values), values.dtype)
DecimalValidator validator = DecimalValidator(
len(values), values.dtype, skipna=skipna
)
return validator.validate(values)


Expand Down Expand Up @@ -1997,10 +2004,10 @@ cdef class Datetime64Validator(DatetimeValidator):


# Note: only python-exposed for tests
cpdef bint is_datetime64_array(ndarray values):
cpdef bint is_datetime64_array(ndarray values, bint skipna=True):
cdef:
Datetime64Validator validator = Datetime64Validator(len(values),
skipna=True)
skipna=skipna)
return validator.validate(values)


Expand All @@ -2012,10 +2019,10 @@ cdef class AnyDatetimeValidator(DatetimeValidator):
)


cdef bint is_datetime_or_datetime64_array(ndarray values):
cdef bint is_datetime_or_datetime64_array(ndarray values, bint skipna=True):
cdef:
AnyDatetimeValidator validator = AnyDatetimeValidator(len(values),
skipna=True)
skipna=skipna)
return validator.validate(values)


Expand Down Expand Up @@ -2069,13 +2076,13 @@ cdef class AnyTimedeltaValidator(TimedeltaValidator):


# Note: only python-exposed for tests
cpdef bint is_timedelta_or_timedelta64_array(ndarray values):
cpdef bint is_timedelta_or_timedelta64_array(ndarray values, bint skipna=True):
"""
Infer with timedeltas and/or nat/none.
"""
cdef:
AnyTimedeltaValidator validator = AnyTimedeltaValidator(len(values),
skipna=True)
skipna=skipna)
return validator.validate(values)


Expand Down Expand Up @@ -2105,20 +2112,28 @@ cpdef bint is_time_array(ndarray values, bint skipna=False):
return validator.validate(values)


cdef bint is_period_array(ndarray[object] values):
# FIXME: actually use skipna
cdef bint is_period_array(ndarray values, bint skipna=True):
"""
Is this an ndarray of Period objects (or NaT) with a single `freq`?
"""
# values should be object-dtype, but ndarray[object] assumes 1D, while
# this _may_ be 2D.
cdef:
Py_ssize_t i, n = len(values)
Py_ssize_t i, N = values.size
int dtype_code = -10000 # i.e. c_FreqGroup.FR_UND
object val
flatiter it

if len(values) == 0:
if N == 0:
return False

for i in range(n):
val = values[i]
it = PyArray_IterNew(values)
for i in range(N):
# The PyArray_GETITEM and PyArray_ITER_NEXT are faster
# equivalents to `val = values[i]`
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
PyArray_ITER_NEXT(it)

if is_period_object(val):
if dtype_code == -10000:
Expand Down
1 change: 1 addition & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,7 @@ def any_numpy_dtype(request):
_any_skipna_inferred_dtype = [
("string", ["a", np.nan, "c"]),
("string", ["a", pd.NA, "c"]),
("mixed", ["a", pd.NaT, "c"]), # pd.NaT not considered valid by is_string_array
("bytes", [b"a", np.nan, b"c"]),
("empty", [np.nan, np.nan, np.nan]),
("empty", []),
Expand Down
8 changes: 1 addition & 7 deletions pandas/core/arrays/floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,7 @@ def coerce_to_array(
inferred_type = lib.infer_dtype(values, skipna=True)
if inferred_type == "empty":
pass
elif inferred_type not in [
"floating",
"integer",
"mixed-integer",
"integer-na",
"mixed-integer-float",
]:
elif inferred_type == "boolean":
raise TypeError(f"{values.dtype} cannot be converted to a FloatingDtype")

elif is_bool_dtype(values) and is_float_dtype(dtype):
Expand Down
12 changes: 2 additions & 10 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,8 @@ def coerce_to_array(
inferred_type = lib.infer_dtype(values, skipna=True)
if inferred_type == "empty":
pass
elif inferred_type not in [
"floating",
"integer",
"mixed-integer",
"integer-na",
"mixed-integer-float",
"string",
"unicode",
]:
raise TypeError(f"{values.dtype} cannot be converted to an IntegerDtype")
elif inferred_type == "boolean":
raise TypeError(f"{values.dtype} cannot be converted to a FloatingDtype")

elif is_bool_dtype(values) and is_integer_dtype(dtype):
values = np.array(values, dtype=int, copy=copy)
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/arrays/floating/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_to_array_error(values):
"cannot be converted to a FloatingDtype",
"values must be a 1D list-like",
"Cannot pass scalar",
r"float\(\) argument must be a string or a (real )?number, not 'dict'",
]
)
with pytest.raises((TypeError, ValueError), match=msg):
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/arrays/integer/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_to_integer_array_error(values):
r"invalid literal for int\(\) with base 10:",
r"values must be a 1D list-like",
r"Cannot pass scalar",
r"int\(\) argument must be a string",
]
)
with pytest.raises((ValueError, TypeError), match=msg):
Expand Down
18 changes: 16 additions & 2 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,26 @@ def test_constructor_raises(cls):
with pytest.raises(ValueError, match=msg):
cls(np.array([]))

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.datetime64("nat")], dtype=object))
if cls is pd.arrays.StringArray:
# GH#45057 np.nan and None do NOT raise, as they are considered valid NAs
# for string dtype
cls(np.array(["a", np.nan], dtype=object))
cls(np.array(["a", None], dtype=object))
else:
with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.nan], dtype=object))
with pytest.raises(ValueError, match=msg):
cls(np.array(["a", None], dtype=object))

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", pd.NaT], dtype=object))

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.datetime64("NaT", "ns")], dtype=object))

with pytest.raises(ValueError, match=msg):
cls(np.array(["a", np.timedelta64("NaT", "ns")], dtype=object))


@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
def test_constructor_nan_like(na):
Expand Down
19 changes: 19 additions & 0 deletions pandas/tests/dtypes/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,10 +1134,20 @@ def test_unicode(self):
# This could also return "string" or "mixed-string"
assert result == "mixed"

# even though we use skipna, we are only skipping those NAs that are
# considered matching by is_string_array
arr = ["a", np.nan, "c"]
result = lib.infer_dtype(arr, skipna=True)
assert result == "string"

arr = ["a", pd.NA, "c"]
result = lib.infer_dtype(arr, skipna=True)
assert result == "string"

arr = ["a", pd.NaT, "c"]
result = lib.infer_dtype(arr, skipna=True)
assert result == "mixed"

arr = ["a", "c"]
result = lib.infer_dtype(arr, skipna=False)
assert result == "string"
Expand Down Expand Up @@ -1544,15 +1554,24 @@ def test_is_string_array(self):
assert lib.is_string_array(
np.array(["foo", "bar", pd.NA], dtype=object), skipna=True
)
# we allow NaN/None in the StringArray constructor, so its allowed here
assert lib.is_string_array(
np.array(["foo", "bar", None], dtype=object), skipna=True
)
assert lib.is_string_array(
np.array(["foo", "bar", np.nan], dtype=object), skipna=True
)
# But not e.g. datetimelike or Decimal NAs
assert not lib.is_string_array(
np.array(["foo", "bar", pd.NaT], dtype=object), skipna=True
)
assert not lib.is_string_array(
np.array(["foo", "bar", np.datetime64("NaT")], dtype=object), skipna=True
)
assert not lib.is_string_array(
np.array(["foo", "bar", Decimal("NaN")], dtype=object), skipna=True
)

assert not lib.is_string_array(
np.array(["foo", "bar", None], dtype=object), skipna=False
)
Expand Down