From f48502fd2678ec18a6d3c0b9fffda4396d17e5a2 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sun, 19 Nov 2023 18:08:13 +1100 Subject: [PATCH 01/34] updates --- doc/source/whatsnew/v2.2.0.rst | 19 +++ pandas/__init__.py | 2 + pandas/core/api.py | 2 + pandas/core/case_when.py | 218 +++++++++++++++++++++++++++++++++ pandas/core/series.py | 79 +++++++++++- pandas/tests/api/test_api.py | 1 + pandas/tests/test_case_when.py | 168 +++++++++++++++++++++++++ 7 files changed, 488 insertions(+), 1 deletion(-) create mode 100644 pandas/core/case_when.py create mode 100644 pandas/tests/test_case_when.py diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 16cfc3efc1024..9daa83cf6b85f 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -14,6 +14,25 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ +.. _whatsnew_220.enhancements.case_when: + +Create a pandas Series based on one or more conditions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) + +.. ipython:: python + + import pandas as pd + + df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) + pd.case_when( + (df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement + default = 'default', # optional + ) + + .. _whatsnew_220.enhancements.calamine: Calamine engine for :func:`read_excel` diff --git a/pandas/__init__.py b/pandas/__init__.py index 7fab662ed2de4..08cb057090b23 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -73,6 +73,7 @@ notnull, # indexes Index, + case_when, CategoricalIndex, RangeIndex, MultiIndex, @@ -253,6 +254,7 @@ "ArrowDtype", "BooleanDtype", "Categorical", + "case_when", "CategoricalDtype", "CategoricalIndex", "DataFrame", diff --git a/pandas/core/api.py b/pandas/core/api.py index 2cfe5ffc0170d..5c7271c155459 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -42,6 +42,7 @@ UInt64Dtype, ) from pandas.core.arrays.string_ import StringDtype +from pandas.core.case_when import case_when from pandas.core.construction import array from pandas.core.flags import Flags from pandas.core.groupby import ( @@ -86,6 +87,7 @@ "bdate_range", "BooleanDtype", "Categorical", + "case_when", "CategoricalDtype", "CategoricalIndex", "DataFrame", diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py new file mode 100644 index 0000000000000..d2ab59379ac03 --- /dev/null +++ b/pandas/core/case_when.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.cast import ( + construct_1d_arraylike_from_scalar, + find_common_type, + infer_dtype_from, +) +from pandas.core.dtypes.common import is_scalar +from pandas.core.dtypes.generic import ABCSeries + +from pandas.core.construction import array as pd_array + +if TYPE_CHECKING: + from pandas._typing import Series + + +def case_when( + *args: tuple[tuple], + default=lib.no_default, +) -> Series: + """ + Replace values where the conditions are True. + + Parameters + ---------- + *args : tuple(s) of array-like, scalar + Variable argument of tuples of conditions and expected replacements. + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array. + When multiple boolean conditions are satisfied, + the first replacement is used. + If ``condition`` is a Series, and the equivalent ``replacement`` + is a Series, they must have the same index. + If there are multiple replacement options, + and they are Series, they must have the same index. + + default : scalar, array-like, default None + If provided, it is the replacement value to use + if all conditions evaluate to False. + If not specified, entries will be filled with the + corresponding NULL value. + + .. versionadded:: 2.2.0 + + Returns + ------- + Series + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> df = pd.DataFrame({ + ... "a": [0,0,1,2], + ... "b": [0,3,4,5], + ... "c": [6,7,8,9] + ... }) + >>> df + a b c + 0 0 0 6 + 1 0 3 7 + 2 1 4 8 + 3 2 5 9 + + >>> pd.case_when((df.a.gt(0), df.a), # condition, replacement + ... (df.b.gt(0), df.b), # condition, replacement + ... default=df.c) # optional + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 + """ + from pandas import Series + from pandas._testing.asserters import assert_index_equal + + args = validate_case_when(args=args) + + conditions, replacements = zip(*args) + common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements] + + if default is not lib.no_default: + arg_dtype, _ = infer_dtype_from(default) + common_dtypes.append(arg_dtype) + else: + default = None + if len(set(common_dtypes)) > 1: + common_dtypes = find_common_type(common_dtypes) + updated_replacements = [] + for condition, replacement in zip(conditions, replacements): + if is_scalar(replacement): + replacement = construct_1d_arraylike_from_scalar( + value=replacement, length=len(condition), dtype=common_dtypes + ) + elif isinstance(replacement, ABCSeries): + replacement = replacement.astype(common_dtypes) + else: + replacement = pd_array(replacement, dtype=common_dtypes) + updated_replacements.append(replacement) + replacements = updated_replacements + if (default is not None) and isinstance(default, ABCSeries): + default = default.astype(common_dtypes) + else: + common_dtypes = common_dtypes[0] + if not isinstance(default, ABCSeries): + cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)] + replacement_indices = [ + replacement + for replacement in replacements + if isinstance(replacement, ABCSeries) + ] + cond_length = None + if replacement_indices: + for left, right in zip(replacement_indices, replacement_indices[1:]): + try: + assert_index_equal(left.index, right.index, check_order=False) + except AssertionError: + raise AssertionError( + "All replacement objects must have the same index." + ) + if cond_indices: + for left, right in zip(cond_indices, cond_indices[1:]): + try: + assert_index_equal(left.index, right.index, check_order=False) + except AssertionError: + raise AssertionError( + "All condition objects must have the same index." + ) + if replacement_indices: + try: + assert_index_equal( + replacement_indices[0].index, + cond_indices[0].index, + check_order=False, + ) + except AssertionError: + raise AssertionError( + "All replacement objects and condition objects " + "should have the same index." + ) + else: + conditions = [ + np.asanyarray(cond) if not hasattr(cond, "shape") else cond + for cond in conditions + ] + cond_length = {len(cond) for cond in conditions} + if len(cond_length) > 1: + raise ValueError("The boolean conditions should have the same length.") + cond_length = len(conditions[0]) + if not is_scalar(default): + if len(default) != cond_length: + raise ValueError( + "length of `default` does not match the length " + "of any of the conditions." + ) + if not replacement_indices: + for num, replacement in enumerate(replacements): + if is_scalar(replacement): + continue + if not hasattr(replacement, "shape"): + replacement = np.asanyarray(replacement) + if len(replacement) != cond_length: + raise ValueError( + f"Length of condition{num} does not match " + f"the length of replacement{num}; " + f"{cond_length} != {len(replacement)}" + ) + if cond_indices: + default_index = cond_indices[0].index + elif replacement_indices: + default_index = replacement_indices[0].index + else: + default_index = range(cond_length) + default = Series(default, index=default_index, dtype=common_dtypes) + counter = reversed(range(len(conditions))) + for position, condition, replacement in zip( + counter, conditions[::-1], replacements[::-1] + ): + try: + default = default.mask( + condition, other=replacement, axis=0, inplace=False, level=None + ) + except Exception as error: + raise ValueError( + f"condition{position} and replacement{position} failed to evaluate." + ) from error + return default + + +def validate_case_when(args: tuple) -> tuple: + """ + Validates the variable arguments for the case_when function. + """ + if not len(args): + raise ValueError( + "provide at least one boolean condition, " + "with a corresponding replacement." + ) + + for num, entry in enumerate(args): + if not isinstance(entry, tuple): + raise TypeError(f"Argument {num} must be a tuple; got {type(entry)}.") + if len(entry) != 2: + raise ValueError( + f"Argument {num} must have length 2; " + "a condition and replacement; " + f"got length {len(entry)}." + ) + return args \ No newline at end of file diff --git a/pandas/core/series.py b/pandas/core/series.py index 1bbd10429ea22..0b3bf829a6296 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -28,7 +28,10 @@ from pandas._config import using_copy_on_write from pandas._config.config import _get_option - +from pandas.core.case_when import ( + case_when, + validate_case_when, +) from pandas._libs import ( lib, properties, @@ -5554,6 +5557,80 @@ def between( return lmask & rmask + def case_when( + self, + *args: tuple[tuple], + ) -> Series: + """ + Replace values where the conditions are True. + + Parameters + ---------- + *args : tuple(s) of array-like, scalar. + Variable argument of tuples of conditions and expected replacements. + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array-like object + or a callable. If ``condition`` is a callable, + it is computed on the Series + and should return a boolean Series or array. + The callable must not change the input Series + (though pandas doesn`t check it). ``replacement`` should be a + 1-D array-like object, a scalar or a callable. + If ``replacement`` is a callable, it is computed on the Series + and should return a scalar or Series. The callable + must not change the input Series + (though pandas doesn`t check it). + If ``condition`` is a Series, and the equivalent ``replacement`` + is a Series, they must have the same index. + If there are multiple replacement options, + and they are Series, they must have the same index. + + level : int, default None + Alignment level if needed. + + .. versionadded:: 2.2.0 + + Returns + ------- + Series + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> df = pd.DataFrame({ + ... "a": [0,0,1,2], + ... "b": [0,3,4,5], + ... "c": [6,7,8,9] + ... }) + >>> df + a b c + 0 0 0 6 + 1 0 3 7 + 2 1 4 8 + 3 2 5 9 + + >>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement + ... (df.b.gt(0), df.b)) + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 + """ + args = validate_case_when(args) + args = [ + ( + com.apply_if_callable(condition, self), + com.apply_if_callable(replacement, self), + ) + for condition, replacement in args + ] + return case_when(*args, default=self, level=None) + # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] def isna(self) -> Series: diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index 60bcb97aaa364..ae4fe5d56ebf6 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -106,6 +106,7 @@ class TestPDApi(Base): funcs = [ "array", "bdate_range", + "case_when", "concat", "crosstab", "cut", diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py new file mode 100644 index 0000000000000..869e952584ae8 --- /dev/null +++ b/pandas/tests/test_case_when.py @@ -0,0 +1,168 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + array as pd_array, + case_when, + date_range, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + """ + base dataframe for testing + """ + return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +def test_case_when_no_args(): + """ + Raise ValueError if no args is provided. + """ + msg = "provide at least one boolean condition, " + msg += "with a corresponding replacement." + with pytest.raises(ValueError, match=msg): + case_when() + + +def test_case_when_odd_args(df): + """ + Raise ValueError if no of args is odd. + """ + msg = "Argument 0 must have length 2; " + msg += "a condition and replacement; got length 3." + + with pytest.raises(ValueError, match=msg): + case_when((df["a"].eq(1), 1, df.a.gt(1))) + + +def test_case_when_raise_error_from_mask(df): + """ + Raise Error from within Series.mask + """ + msg = "condition0 and replacement0 failed to evaluate." + with pytest.raises(ValueError, match=msg): + case_when((df["a"].eq(1), df)) + + +def test_case_when_error_multiple_replacements_series(df): + """ + Test output when the replacements indices are different. + """ + with pytest.raises( + AssertionError, match="All replacement objects must have the same index." + ): + case_when( + ([True, False, False], Series(1)), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + + +def test_case_when_error_multiple_conditions_series(df): + """ + Test output when the conditions indices are different. + """ + with pytest.raises( + AssertionError, match="All condition objects must have the same index." + ): + case_when( + (Series([True, False, False], index=[2, 3, 4]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + + +def test_case_when_raise_error_different_index_condition_and_replacements(df): + """ + Raise if the replacement index and condition index are different. + """ + msg = "All replacement objects and condition objects " + msg += "should have the same index." + with pytest.raises(AssertionError, match=msg): + case_when( + (df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2)) + ) + + +def test_case_when_single_condition(df): + """ + Test output on a single condition. + """ + result = case_when((df.a.eq(1), 1)) + expected = Series([1, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions(df): + """ + Test output when booleans are derived from a computation + """ + result = case_when((df.a.eq(1), 1), (Series([False, True, False]), 2)) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_list(df): + """ + Test output when replacement is a list + """ + result = case_when( + ([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3]) + ) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_extension_dtype(df): + """ + Test output when replacement has an extension dtype + """ + result = case_when( + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")), + ) + expected = Series([1, 2, np.nan], dtype="Int64") + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_replacement_series(df): + """ + Test output when replacement is a Series + """ + result = case_when( + (np.array([True, False, False]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ) + expected = Series([1, 2, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_case_when_multiple_conditions_default_is_not_none(df): + """ + Test output when default is not None + """ + result = case_when( + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + default=-1, + ) + expected = Series([1, 2, -1]) + tm.assert_series_equal(result, expected) + + +def test_case_when_non_range_index(): + """ + Test output if index is not RangeIndex + """ + rng = np.random.default_rng(seed=123) + dates = date_range("1/1/2000", periods=8) + df = DataFrame( + rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] + ) + result = case_when((df.A.gt(0), df.B), default=5) + result = Series(result, name="A") + expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) + tm.assert_series_equal(result, expected) \ No newline at end of file From 40057c723974eafe56360207aeeaf5bc51e3e716 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sun, 19 Nov 2023 18:16:05 +1100 Subject: [PATCH 02/34] add test for default if Series --- pandas/tests/test_case_when.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py index 869e952584ae8..619fe15011176 100644 --- a/pandas/tests/test_case_when.py +++ b/pandas/tests/test_case_when.py @@ -82,9 +82,7 @@ def test_case_when_raise_error_different_index_condition_and_replacements(df): msg = "All replacement objects and condition objects " msg += "should have the same index." with pytest.raises(AssertionError, match=msg): - case_when( - (df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2)) - ) + case_when((df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2))) def test_case_when_single_condition(df): @@ -153,6 +151,19 @@ def test_case_when_multiple_conditions_default_is_not_none(df): tm.assert_series_equal(result, expected) +def test_case_when_multiple_conditions_default_is_a_series(df): + """ + Test output when default is not None + """ + result = case_when( + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + default=Series(-1, index=df.index), + ) + expected = Series([1, 2, -1]) + tm.assert_series_equal(result, expected) + + def test_case_when_non_range_index(): """ Test output if index is not RangeIndex @@ -165,4 +176,4 @@ def test_case_when_non_range_index(): result = case_when((df.A.gt(0), df.B), default=5) result = Series(result, name="A") expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) - tm.assert_series_equal(result, expected) \ No newline at end of file + tm.assert_series_equal(result, expected) From 089bbe6809ffc380ff708544cdf422fdb4c8b398 Mon Sep 17 00:00:00 2001 From: samukweku Date: Thu, 23 Nov 2023 13:38:57 +1100 Subject: [PATCH 03/34] updates based on feedback --- pandas/core/case_when.py | 23 ++++++----------------- pandas/core/series.py | 13 +++++++------ pandas/tests/test_case_when.py | 4 ++-- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py index d2ab59379ac03..22cbb2aae4764 100644 --- a/pandas/core/case_when.py +++ b/pandas/core/case_when.py @@ -81,9 +81,8 @@ def case_when( Name: c, dtype: int64 """ from pandas import Series - from pandas._testing.asserters import assert_index_equal - args = validate_case_when(args=args) + validate_case_when(args=args) conditions, replacements = zip(*args) common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements] @@ -121,28 +120,18 @@ def case_when( cond_length = None if replacement_indices: for left, right in zip(replacement_indices, replacement_indices[1:]): - try: - assert_index_equal(left.index, right.index, check_order=False) - except AssertionError: + if not left.index.equals(right.index): raise AssertionError( "All replacement objects must have the same index." ) if cond_indices: for left, right in zip(cond_indices, cond_indices[1:]): - try: - assert_index_equal(left.index, right.index, check_order=False) - except AssertionError: + if not left.index.equals(right.index): raise AssertionError( "All condition objects must have the same index." ) if replacement_indices: - try: - assert_index_equal( - replacement_indices[0].index, - cond_indices[0].index, - check_order=False, - ) - except AssertionError: + if not replacement_indices[0].index.equals(cond_indices[0].index): raise AssertionError( "All replacement objects and condition objects " "should have the same index." @@ -191,7 +180,7 @@ def case_when( ) except Exception as error: raise ValueError( - f"condition{position} and replacement{position} failed to evaluate." + f"Failed to apply condition{position} and replacement{position}." ) from error return default @@ -215,4 +204,4 @@ def validate_case_when(args: tuple) -> tuple: "a condition and replacement; " f"got length {len(entry)}." ) - return args \ No newline at end of file + return None diff --git a/pandas/core/series.py b/pandas/core/series.py index 4dbb39163b00e..e12d28c6df6d2 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -28,10 +28,7 @@ from pandas._config import using_copy_on_write from pandas._config.config import _get_option -from pandas.core.case_when import ( - case_when, - validate_case_when, -) + from pandas._libs import ( lib, properties, @@ -110,6 +107,10 @@ from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.sparse import SparseAccessor from pandas.core.arrays.string_ import StringDtype +from pandas.core.case_when import ( + case_when, + validate_case_when, +) from pandas.core.construction import ( extract_array, sanitize_array, @@ -5634,7 +5635,7 @@ def case_when( 3 2 Name: c, dtype: int64 """ - args = validate_case_when(args) + validate_case_when(args) args = [ ( com.apply_if_callable(condition, self), @@ -5643,7 +5644,7 @@ def case_when( for condition, replacement in args ] return case_when(*args, default=self, level=None) - + # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] def isna(self) -> Series: diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py index 619fe15011176..87977025078be 100644 --- a/pandas/tests/test_case_when.py +++ b/pandas/tests/test_case_when.py @@ -25,7 +25,7 @@ def test_case_when_no_args(): """ msg = "provide at least one boolean condition, " msg += "with a corresponding replacement." - with pytest.raises(ValueError, match=msg): + with pytest.raises(ValueError, match=msg): # GH39154 case_when() @@ -44,7 +44,7 @@ def test_case_when_raise_error_from_mask(df): """ Raise Error from within Series.mask """ - msg = "condition0 and replacement0 failed to evaluate." + msg = "Failed to apply condition0 and replacement0." with pytest.raises(ValueError, match=msg): case_when((df["a"].eq(1), df)) From b95ce55b8e129aff5c8cb0c24e33452e1a6dc139 Mon Sep 17 00:00:00 2001 From: samukweku Date: Thu, 30 Nov 2023 23:11:10 +1100 Subject: [PATCH 04/34] updates based on feedback --- doc/source/whatsnew/v2.2.0.rst | 17 +++++++++++++++++ pandas/core/case_when.py | 11 ++++++++--- pandas/core/series.py | 2 +- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 8cb4b3f24d435..d88020a5400c7 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -13,6 +13,23 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ +.. _whatsnew_220.enhancements.case_when: + +Create a pandas Series based on one or more conditions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) + +.. ipython:: python + + import pandas as pd + + df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) + pd.case_when( + (df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement + default = 'default', # optional + ) .. _whatsnew_220.enhancements.adbc_support: diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py index 22cbb2aae4764..30a5b1939e876 100644 --- a/pandas/core/case_when.py +++ b/pandas/core/case_when.py @@ -17,12 +17,16 @@ from pandas.core.construction import array as pd_array if TYPE_CHECKING: - from pandas._typing import Series + from pandas._typing import ( + ArrayLike, + Scalar, + Series, + ) def case_when( - *args: tuple[tuple], - default=lib.no_default, + *args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]], + default: ArrayLike | Scalar = lib.no_default, ) -> Series: """ Replace values where the conditions are True. @@ -189,6 +193,7 @@ def validate_case_when(args: tuple) -> tuple: """ Validates the variable arguments for the case_when function. """ + if not len(args): raise ValueError( "provide at least one boolean condition, " diff --git a/pandas/core/series.py b/pandas/core/series.py index 377eb0b5023b9..c2bb0b55d203a 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5588,7 +5588,7 @@ def between( def case_when( self, - *args: tuple[tuple], + *args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]], ) -> Series: """ Replace values where the conditions are True. From 8be4349c134b3a42988d7928ef22ad772879ae03 Mon Sep 17 00:00:00 2001 From: samukweku Date: Fri, 1 Dec 2023 09:19:58 +1100 Subject: [PATCH 05/34] update typing hints for *args, based on feedback --- pandas/core/case_when.py | 2 +- pandas/core/series.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py index 30a5b1939e876..4bff4ab167f60 100644 --- a/pandas/core/case_when.py +++ b/pandas/core/case_when.py @@ -25,7 +25,7 @@ def case_when( - *args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]], + *args: tuple[tuple[tuple[ArrayLike, ArrayLike | Scalar]]], default: ArrayLike | Scalar = lib.no_default, ) -> Series: """ diff --git a/pandas/core/series.py b/pandas/core/series.py index 615e39ee765b9..808926046bb5e 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5592,7 +5592,7 @@ def between( def case_when( self, - *args: tuple[tuple[tuple[ArrayLike], tuple[ArrayLike | Scalar]]], + *args: tuple[tuple[tuple[ArrayLike, ArrayLike | Scalar]]], ) -> Series: """ Replace values where the conditions are True. From ec180868c1181e5de6bb83c56ecc32eedb369381 Mon Sep 17 00:00:00 2001 From: samukweku Date: Fri, 1 Dec 2023 17:36:50 +1100 Subject: [PATCH 06/34] update typehints; add caselist argument - based on feedback --- doc/source/whatsnew/v2.2.0.rst | 6 ++-- pandas/core/case_when.py | 33 +++++++++-------- pandas/core/series.py | 17 ++++----- pandas/tests/test_case_when.py | 66 +++++++++++++++++++++------------- 4 files changed, 70 insertions(+), 52 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index e8d833a70f3db..086b2c14f7aa7 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -26,9 +26,9 @@ The :func:`case_when` function has been added to create a Series object based on df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) pd.case_when( - (df.a == 1, 'first'), # condition, replacement - (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement - default = 'default', # optional + caselist = [(df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second')], # condition, replacement + default = 'default', # optional ) .. _whatsnew_220.enhancements.adbc_support: diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py index 4bff4ab167f60..cb304605a5076 100644 --- a/pandas/core/case_when.py +++ b/pandas/core/case_when.py @@ -25,7 +25,7 @@ def case_when( - *args: tuple[tuple[tuple[ArrayLike, ArrayLike | Scalar]]], + caselist: list[tuple[ArrayLike, ArrayLike | Scalar]], default: ArrayLike | Scalar = lib.no_default, ) -> Series: """ @@ -33,8 +33,8 @@ def case_when( Parameters ---------- - *args : tuple(s) of array-like, scalar - Variable argument of tuples of conditions and expected replacements. + caselist : A list of tuple(s) of array-like, scalar + List of tuples of conditions and expected replacements. Takes the form: ``(condition0, replacement0)``, ``(condition1, replacement1)``, ... . ``condition`` should be a 1-D boolean array. @@ -75,9 +75,8 @@ def case_when( 2 1 4 8 3 2 5 9 - >>> pd.case_when((df.a.gt(0), df.a), # condition, replacement - ... (df.b.gt(0), df.b), # condition, replacement - ... default=df.c) # optional + >>> caselist = [(df.a.gt(0), df.a), (df.b.gt(0), df.b)] # condition, replacement + >>> pd.case_when(caselist=caselist, default=df.c) # default is optional 0 6 1 3 2 1 @@ -86,9 +85,9 @@ def case_when( """ from pandas import Series - validate_case_when(args=args) + validate_case_when(caselist=caselist) - conditions, replacements = zip(*args) + conditions, replacements = zip(*caselist) common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements] if default is not lib.no_default: @@ -189,24 +188,30 @@ def case_when( return default -def validate_case_when(args: tuple) -> tuple: +def validate_case_when(caselist: list) -> None: """ Validates the variable arguments for the case_when function. """ - if not len(args): + if not isinstance(caselist, list): + raise TypeError( + f"The caselist argument should be a list; instead got {type(caselist)}" + ) + + if not len(caselist): raise ValueError( "provide at least one boolean condition, " "with a corresponding replacement." ) - for num, entry in enumerate(args): + for num, entry in enumerate(caselist): if not isinstance(entry, tuple): - raise TypeError(f"Argument {num} must be a tuple; got {type(entry)}.") + raise TypeError( + f"Argument {num} must be a tuple; instead got {type(entry)}." + ) if len(entry) != 2: raise ValueError( f"Argument {num} must have length 2; " "a condition and replacement; " - f"got length {len(entry)}." + f"instead got length {len(entry)}." ) - return None diff --git a/pandas/core/series.py b/pandas/core/series.py index 808926046bb5e..72e86db4733d2 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5592,15 +5592,15 @@ def between( def case_when( self, - *args: tuple[tuple[tuple[ArrayLike, ArrayLike | Scalar]]], + caselist: list[tuple[ArrayLike, ArrayLike | Scalar]], ) -> Series: """ Replace values where the conditions are True. Parameters ---------- - *args : tuple(s) of array-like, scalar. - Variable argument of tuples of conditions and expected replacements. + caselist : A list of tuple(s) of array-like, scalar. + A list of tuples of conditions and expected replacements. Takes the form: ``(condition0, replacement0)``, ``(condition1, replacement1)``, ... . ``condition`` should be a 1-D boolean array-like object @@ -5619,9 +5619,6 @@ def case_when( If there are multiple replacement options, and they are Series, they must have the same index. - level : int, default None - Alignment level if needed. - .. versionadded:: 2.2.0 Returns @@ -5654,15 +5651,15 @@ def case_when( 3 2 Name: c, dtype: int64 """ - validate_case_when(args) - args = [ + validate_case_when(caselist) + caselist = [ ( com.apply_if_callable(condition, self), com.apply_if_callable(replacement, self), ) - for condition, replacement in args + for condition, replacement in caselist ] - return case_when(*args, default=self, level=None) + return case_when(caselist=caselist, default=self, level=None) # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] diff --git a/pandas/tests/test_case_when.py b/pandas/tests/test_case_when.py index 87977025078be..8a96e5b260851 100644 --- a/pandas/tests/test_case_when.py +++ b/pandas/tests/test_case_when.py @@ -19,25 +19,35 @@ def df(): return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) -def test_case_when_no_args(): +def test_case_when_caselist_is_not_a_list(): """ - Raise ValueError if no args is provided. + Raise ValueError if caselist is not a list. + """ + msg = "The caselist argument should be a list; " + msg += "instead got.+" + with pytest.raises(TypeError, match=msg): # GH39154 + case_when(caselist=()) + + +def test_case_when_no_caselist(): + """ + Raise ValueError if no caselist is provided. """ msg = "provide at least one boolean condition, " msg += "with a corresponding replacement." with pytest.raises(ValueError, match=msg): # GH39154 - case_when() + case_when([]) -def test_case_when_odd_args(df): +def test_case_when_odd_caselist(df): """ - Raise ValueError if no of args is odd. + Raise ValueError if no of caselist is odd. """ msg = "Argument 0 must have length 2; " - msg += "a condition and replacement; got length 3." + msg += "a condition and replacement; instead got length 3." with pytest.raises(ValueError, match=msg): - case_when((df["a"].eq(1), 1, df.a.gt(1))) + case_when([(df["a"].eq(1), 1, df.a.gt(1))]) def test_case_when_raise_error_from_mask(df): @@ -46,7 +56,7 @@ def test_case_when_raise_error_from_mask(df): """ msg = "Failed to apply condition0 and replacement0." with pytest.raises(ValueError, match=msg): - case_when((df["a"].eq(1), df)) + case_when([(df["a"].eq(1), df)]) def test_case_when_error_multiple_replacements_series(df): @@ -57,8 +67,10 @@ def test_case_when_error_multiple_replacements_series(df): AssertionError, match="All replacement objects must have the same index." ): case_when( - ([True, False, False], Series(1)), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + [ + ([True, False, False], Series(1)), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ], ) @@ -70,8 +82,10 @@ def test_case_when_error_multiple_conditions_series(df): AssertionError, match="All condition objects must have the same index." ): case_when( - (Series([True, False, False], index=[2, 3, 4]), 1), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + [ + (Series([True, False, False], index=[2, 3, 4]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ], ) @@ -82,14 +96,14 @@ def test_case_when_raise_error_different_index_condition_and_replacements(df): msg = "All replacement objects and condition objects " msg += "should have the same index." with pytest.raises(AssertionError, match=msg): - case_when((df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2))) + case_when([(df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2))]) def test_case_when_single_condition(df): """ Test output on a single condition. """ - result = case_when((df.a.eq(1), 1)) + result = case_when([(df.a.eq(1), 1)]) expected = Series([1, np.nan, np.nan]) tm.assert_series_equal(result, expected) @@ -98,7 +112,7 @@ def test_case_when_multiple_conditions(df): """ Test output when booleans are derived from a computation """ - result = case_when((df.a.eq(1), 1), (Series([False, True, False]), 2)) + result = case_when([(df.a.eq(1), 1), (Series([False, True, False]), 2)]) expected = Series([1, 2, np.nan]) tm.assert_series_equal(result, expected) @@ -108,7 +122,7 @@ def test_case_when_multiple_conditions_replacement_list(df): Test output when replacement is a list """ result = case_when( - ([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3]) + [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])] ) expected = Series([1, 2, np.nan]) tm.assert_series_equal(result, expected) @@ -119,8 +133,10 @@ def test_case_when_multiple_conditions_replacement_extension_dtype(df): Test output when replacement has an extension dtype """ result = case_when( - ([True, False, False], 1), - (df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")), + [ + ([True, False, False], 1), + (df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")), + ], ) expected = Series([1, 2, np.nan], dtype="Int64") tm.assert_series_equal(result, expected) @@ -131,8 +147,10 @@ def test_case_when_multiple_conditions_replacement_series(df): Test output when replacement is a Series """ result = case_when( - (np.array([True, False, False]), 1), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + [ + (np.array([True, False, False]), 1), + (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + ], ) expected = Series([1, 2, np.nan]) tm.assert_series_equal(result, expected) @@ -143,8 +161,7 @@ def test_case_when_multiple_conditions_default_is_not_none(df): Test output when default is not None """ result = case_when( - ([True, False, False], 1), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], default=-1, ) expected = Series([1, 2, -1]) @@ -156,8 +173,7 @@ def test_case_when_multiple_conditions_default_is_a_series(df): Test output when default is not None """ result = case_when( - ([True, False, False], 1), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), + [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], default=Series(-1, index=df.index), ) expected = Series([1, 2, -1]) @@ -173,7 +189,7 @@ def test_case_when_non_range_index(): df = DataFrame( rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] ) - result = case_when((df.A.gt(0), df.B), default=5) + result = case_when([(df.A.gt(0), df.B)], default=5) result = Series(result, name="A") expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) tm.assert_series_equal(result, expected) From 0b72fbb5c6e2ae51afd52691d4b1f52abf8998ad Mon Sep 17 00:00:00 2001 From: samukweku Date: Fri, 1 Dec 2023 17:39:40 +1100 Subject: [PATCH 07/34] cleanup docstrings --- pandas/core/case_when.py | 5 ++--- pandas/core/series.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py index cb304605a5076..2a1ead72c52a8 100644 --- a/pandas/core/case_when.py +++ b/pandas/core/case_when.py @@ -33,8 +33,7 @@ def case_when( Parameters ---------- - caselist : A list of tuple(s) of array-like, scalar - List of tuples of conditions and expected replacements. + caselist : List of tuples of conditions and expected replacements. Takes the form: ``(condition0, replacement0)``, ``(condition1, replacement1)``, ... . ``condition`` should be a 1-D boolean array. @@ -190,7 +189,7 @@ def case_when( def validate_case_when(caselist: list) -> None: """ - Validates the variable arguments for the case_when function. + Validates the arguments for the case_when function. """ if not isinstance(caselist, list): diff --git a/pandas/core/series.py b/pandas/core/series.py index 72e86db4733d2..0cd5ba346dcd7 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5599,8 +5599,7 @@ def case_when( Parameters ---------- - caselist : A list of tuple(s) of array-like, scalar. - A list of tuples of conditions and expected replacements. + caselist : A list of tuples of conditions and expected replacements. Takes the form: ``(condition0, replacement0)``, ``(condition1, replacement1)``, ... . ``condition`` should be a 1-D boolean array-like object @@ -5659,7 +5658,7 @@ def case_when( ) for condition, replacement in caselist ] - return case_when(caselist=caselist, default=self, level=None) + return case_when(caselist=caselist, default=self) # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] From a44148129292b3b75757b5ee386b41c4aa5eee10 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 16 Dec 2023 00:07:57 +1100 Subject: [PATCH 08/34] support method only for case_when --- doc/source/whatsnew/v2.2.0.rst | 4 +- pandas/__init__.py | 2 - pandas/core/api.py | 2 - pandas/core/case_when.py | 216 ------------------ pandas/core/series.py | 73 +++++- pandas/tests/api/test_api.py | 1 - .../{ => series/methods}/test_case_when.py | 75 ++---- 7 files changed, 85 insertions(+), 288 deletions(-) delete mode 100644 pandas/core/case_when.py rename pandas/tests/{ => series/methods}/test_case_when.py (65%) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 4881962882c88..e103fa3edf5a7 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -25,10 +25,10 @@ The :func:`case_when` function has been added to create a Series object based on import pandas as pd df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) - pd.case_when( + default=pd.Series(['default'],index=df.index) + default.case_when( caselist = [(df.a == 1, 'first'), # condition, replacement (df.a.gt(1) & df.b.eq(5), 'second')], # condition, replacement - default = 'default', # optional ) .. _whatsnew_220.enhancements.adbc_support: diff --git a/pandas/__init__.py b/pandas/__init__.py index 08cb057090b23..7fab662ed2de4 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -73,7 +73,6 @@ notnull, # indexes Index, - case_when, CategoricalIndex, RangeIndex, MultiIndex, @@ -254,7 +253,6 @@ "ArrowDtype", "BooleanDtype", "Categorical", - "case_when", "CategoricalDtype", "CategoricalIndex", "DataFrame", diff --git a/pandas/core/api.py b/pandas/core/api.py index 5c7271c155459..2cfe5ffc0170d 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -42,7 +42,6 @@ UInt64Dtype, ) from pandas.core.arrays.string_ import StringDtype -from pandas.core.case_when import case_when from pandas.core.construction import array from pandas.core.flags import Flags from pandas.core.groupby import ( @@ -87,7 +86,6 @@ "bdate_range", "BooleanDtype", "Categorical", - "case_when", "CategoricalDtype", "CategoricalIndex", "DataFrame", diff --git a/pandas/core/case_when.py b/pandas/core/case_when.py deleted file mode 100644 index 2a1ead72c52a8..0000000000000 --- a/pandas/core/case_when.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from pandas._libs import lib - -from pandas.core.dtypes.cast import ( - construct_1d_arraylike_from_scalar, - find_common_type, - infer_dtype_from, -) -from pandas.core.dtypes.common import is_scalar -from pandas.core.dtypes.generic import ABCSeries - -from pandas.core.construction import array as pd_array - -if TYPE_CHECKING: - from pandas._typing import ( - ArrayLike, - Scalar, - Series, - ) - - -def case_when( - caselist: list[tuple[ArrayLike, ArrayLike | Scalar]], - default: ArrayLike | Scalar = lib.no_default, -) -> Series: - """ - Replace values where the conditions are True. - - Parameters - ---------- - caselist : List of tuples of conditions and expected replacements. - Takes the form: ``(condition0, replacement0)``, - ``(condition1, replacement1)``, ... . - ``condition`` should be a 1-D boolean array. - When multiple boolean conditions are satisfied, - the first replacement is used. - If ``condition`` is a Series, and the equivalent ``replacement`` - is a Series, they must have the same index. - If there are multiple replacement options, - and they are Series, they must have the same index. - - default : scalar, array-like, default None - If provided, it is the replacement value to use - if all conditions evaluate to False. - If not specified, entries will be filled with the - corresponding NULL value. - - .. versionadded:: 2.2.0 - - Returns - ------- - Series - - See Also - -------- - Series.mask : Replace values where the condition is True. - - Examples - -------- - >>> df = pd.DataFrame({ - ... "a": [0,0,1,2], - ... "b": [0,3,4,5], - ... "c": [6,7,8,9] - ... }) - >>> df - a b c - 0 0 0 6 - 1 0 3 7 - 2 1 4 8 - 3 2 5 9 - - >>> caselist = [(df.a.gt(0), df.a), (df.b.gt(0), df.b)] # condition, replacement - >>> pd.case_when(caselist=caselist, default=df.c) # default is optional - 0 6 - 1 3 - 2 1 - 3 2 - Name: c, dtype: int64 - """ - from pandas import Series - - validate_case_when(caselist=caselist) - - conditions, replacements = zip(*caselist) - common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements] - - if default is not lib.no_default: - arg_dtype, _ = infer_dtype_from(default) - common_dtypes.append(arg_dtype) - else: - default = None - if len(set(common_dtypes)) > 1: - common_dtypes = find_common_type(common_dtypes) - updated_replacements = [] - for condition, replacement in zip(conditions, replacements): - if is_scalar(replacement): - replacement = construct_1d_arraylike_from_scalar( - value=replacement, length=len(condition), dtype=common_dtypes - ) - elif isinstance(replacement, ABCSeries): - replacement = replacement.astype(common_dtypes) - else: - replacement = pd_array(replacement, dtype=common_dtypes) - updated_replacements.append(replacement) - replacements = updated_replacements - if (default is not None) and isinstance(default, ABCSeries): - default = default.astype(common_dtypes) - else: - common_dtypes = common_dtypes[0] - if not isinstance(default, ABCSeries): - cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)] - replacement_indices = [ - replacement - for replacement in replacements - if isinstance(replacement, ABCSeries) - ] - cond_length = None - if replacement_indices: - for left, right in zip(replacement_indices, replacement_indices[1:]): - if not left.index.equals(right.index): - raise AssertionError( - "All replacement objects must have the same index." - ) - if cond_indices: - for left, right in zip(cond_indices, cond_indices[1:]): - if not left.index.equals(right.index): - raise AssertionError( - "All condition objects must have the same index." - ) - if replacement_indices: - if not replacement_indices[0].index.equals(cond_indices[0].index): - raise AssertionError( - "All replacement objects and condition objects " - "should have the same index." - ) - else: - conditions = [ - np.asanyarray(cond) if not hasattr(cond, "shape") else cond - for cond in conditions - ] - cond_length = {len(cond) for cond in conditions} - if len(cond_length) > 1: - raise ValueError("The boolean conditions should have the same length.") - cond_length = len(conditions[0]) - if not is_scalar(default): - if len(default) != cond_length: - raise ValueError( - "length of `default` does not match the length " - "of any of the conditions." - ) - if not replacement_indices: - for num, replacement in enumerate(replacements): - if is_scalar(replacement): - continue - if not hasattr(replacement, "shape"): - replacement = np.asanyarray(replacement) - if len(replacement) != cond_length: - raise ValueError( - f"Length of condition{num} does not match " - f"the length of replacement{num}; " - f"{cond_length} != {len(replacement)}" - ) - if cond_indices: - default_index = cond_indices[0].index - elif replacement_indices: - default_index = replacement_indices[0].index - else: - default_index = range(cond_length) - default = Series(default, index=default_index, dtype=common_dtypes) - counter = reversed(range(len(conditions))) - for position, condition, replacement in zip( - counter, conditions[::-1], replacements[::-1] - ): - try: - default = default.mask( - condition, other=replacement, axis=0, inplace=False, level=None - ) - except Exception as error: - raise ValueError( - f"Failed to apply condition{position} and replacement{position}." - ) from error - return default - - -def validate_case_when(caselist: list) -> None: - """ - Validates the arguments for the case_when function. - """ - - if not isinstance(caselist, list): - raise TypeError( - f"The caselist argument should be a list; instead got {type(caselist)}" - ) - - if not len(caselist): - raise ValueError( - "provide at least one boolean condition, " - "with a corresponding replacement." - ) - - for num, entry in enumerate(caselist): - if not isinstance(entry, tuple): - raise TypeError( - f"Argument {num} must be a tuple; instead got {type(entry)}." - ) - if len(entry) != 2: - raise ValueError( - f"Argument {num} must have length 2; " - "a condition and replacement; " - f"instead got length {len(entry)}." - ) diff --git a/pandas/core/series.py b/pandas/core/series.py index 0a86272ae8daf..95151c738af06 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -67,6 +67,9 @@ from pandas.core.dtypes.astype import astype_is_view from pandas.core.dtypes.cast import ( LossySetitemError, + construct_1d_arraylike_from_scalar, + find_common_type, + infer_dtype_from, maybe_box_native, maybe_cast_pointwise_result, ) @@ -84,7 +87,10 @@ CategoricalDtype, ExtensionDtype, ) -from pandas.core.dtypes.generic import ABCDataFrame +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) from pandas.core.dtypes.inference import is_hashable from pandas.core.dtypes.missing import ( isna, @@ -112,11 +118,8 @@ from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.sparse import SparseAccessor from pandas.core.arrays.string_ import StringDtype -from pandas.core.case_when import ( - case_when, - validate_case_when, -) from pandas.core.construction import ( + array as pd_array, extract_array, sanitize_array, ) @@ -5680,7 +5683,28 @@ def case_when( 3 2 Name: c, dtype: int64 """ - validate_case_when(caselist) + if not isinstance(caselist, list): + raise TypeError( + f"The caselist argument should be a list; instead got {type(caselist)}" + ) + + if not len(caselist): + raise ValueError( + "provide at least one boolean condition, " + "with a corresponding replacement." + ) + + for num, entry in enumerate(caselist): + if not isinstance(entry, tuple): + raise TypeError( + f"Argument {num} must be a tuple; instead got {type(entry)}." + ) + if len(entry) != 2: + raise ValueError( + f"Argument {num} must have length 2; " + "a condition and replacement; " + f"instead got length {len(entry)}." + ) caselist = [ ( com.apply_if_callable(condition, self), @@ -5688,7 +5712,42 @@ def case_when( ) for condition, replacement in caselist ] - return case_when(caselist=caselist, default=self) + default = self.copy() + conditions, replacements = zip(*caselist) + common_dtypes = [ + infer_dtype_from(replacement)[0] for replacement in replacements + ] + arg_dtype, _ = infer_dtype_from(default) + common_dtypes.append(arg_dtype) + if len(set(common_dtypes)) > 1: + common_dtypes = find_common_type(common_dtypes) + updated_replacements = [] + for condition, replacement in zip(conditions, replacements): + if is_scalar(replacement): + replacement = construct_1d_arraylike_from_scalar( + value=replacement, length=len(condition), dtype=common_dtypes + ) + elif isinstance(replacement, ABCSeries): + replacement = replacement.astype(common_dtypes) + else: + replacement = pd_array(replacement, dtype=common_dtypes) + updated_replacements.append(replacement) + replacements = updated_replacements + default = default.astype(common_dtypes) + + counter = reversed(range(len(conditions))) + for position, condition, replacement in zip( + counter, conditions[::-1], replacements[::-1] + ): + try: + default = default.mask( + condition, other=replacement, axis=0, inplace=False, level=None + ) + except Exception as error: + raise ValueError( + f"Failed to apply condition{position} and replacement{position}." + ) from error + return default # error: Cannot determine type of 'isna' @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type] diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index ae4fe5d56ebf6..60bcb97aaa364 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -106,7 +106,6 @@ class TestPDApi(Base): funcs = [ "array", "bdate_range", - "case_when", "concat", "crosstab", "cut", diff --git a/pandas/tests/test_case_when.py b/pandas/tests/series/methods/test_case_when.py similarity index 65% rename from pandas/tests/test_case_when.py rename to pandas/tests/series/methods/test_case_when.py index 8a96e5b260851..bc6d319a5e003 100644 --- a/pandas/tests/test_case_when.py +++ b/pandas/tests/series/methods/test_case_when.py @@ -5,7 +5,6 @@ DataFrame, Series, array as pd_array, - case_when, date_range, ) import pandas._testing as tm @@ -19,24 +18,24 @@ def df(): return DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) -def test_case_when_caselist_is_not_a_list(): +def test_case_when_caselist_is_not_a_list(df): """ Raise ValueError if caselist is not a list. """ msg = "The caselist argument should be a list; " msg += "instead got.+" with pytest.raises(TypeError, match=msg): # GH39154 - case_when(caselist=()) + df["a"].case_when(caselist=()) -def test_case_when_no_caselist(): +def test_case_when_no_caselist(df): """ Raise ValueError if no caselist is provided. """ msg = "provide at least one boolean condition, " msg += "with a corresponding replacement." with pytest.raises(ValueError, match=msg): # GH39154 - case_when([]) + df["a"].case_when([]) def test_case_when_odd_caselist(df): @@ -47,7 +46,7 @@ def test_case_when_odd_caselist(df): msg += "a condition and replacement; instead got length 3." with pytest.raises(ValueError, match=msg): - case_when([(df["a"].eq(1), 1, df.a.gt(1))]) + df["a"].case_when([(df["a"].eq(1), 1, df.a.gt(1))]) def test_case_when_raise_error_from_mask(df): @@ -56,54 +55,14 @@ def test_case_when_raise_error_from_mask(df): """ msg = "Failed to apply condition0 and replacement0." with pytest.raises(ValueError, match=msg): - case_when([(df["a"].eq(1), df)]) - - -def test_case_when_error_multiple_replacements_series(df): - """ - Test output when the replacements indices are different. - """ - with pytest.raises( - AssertionError, match="All replacement objects must have the same index." - ): - case_when( - [ - ([True, False, False], Series(1)), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), - ], - ) - - -def test_case_when_error_multiple_conditions_series(df): - """ - Test output when the conditions indices are different. - """ - with pytest.raises( - AssertionError, match="All condition objects must have the same index." - ): - case_when( - [ - (Series([True, False, False], index=[2, 3, 4]), 1), - (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), - ], - ) - - -def test_case_when_raise_error_different_index_condition_and_replacements(df): - """ - Raise if the replacement index and condition index are different. - """ - msg = "All replacement objects and condition objects " - msg += "should have the same index." - with pytest.raises(AssertionError, match=msg): - case_when([(df.a.eq(1), Series(1)), (Series([False, True, False]), Series(2))]) + df["a"].case_when([(df["a"].eq(1), [1, 2])]) def test_case_when_single_condition(df): """ Test output on a single condition. """ - result = case_when([(df.a.eq(1), 1)]) + result = Series([np.nan, np.nan, np.nan]).case_when([(df.a.eq(1), 1)]) expected = Series([1, np.nan, np.nan]) tm.assert_series_equal(result, expected) @@ -112,7 +71,9 @@ def test_case_when_multiple_conditions(df): """ Test output when booleans are derived from a computation """ - result = case_when([(df.a.eq(1), 1), (Series([False, True, False]), 2)]) + result = Series([np.nan, np.nan, np.nan]).case_when( + [(df.a.eq(1), 1), (Series([False, True, False]), 2)] + ) expected = Series([1, 2, np.nan]) tm.assert_series_equal(result, expected) @@ -121,7 +82,7 @@ def test_case_when_multiple_conditions_replacement_list(df): """ Test output when replacement is a list """ - result = case_when( + result = Series([np.nan, np.nan, np.nan]).case_when( [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), [1, 2, 3])] ) expected = Series([1, 2, np.nan]) @@ -132,13 +93,13 @@ def test_case_when_multiple_conditions_replacement_extension_dtype(df): """ Test output when replacement has an extension dtype """ - result = case_when( + result = Series([np.nan, np.nan, np.nan]).case_when( [ ([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), pd_array([1, 2, 3], dtype="Int64")), ], ) - expected = Series([1, 2, np.nan], dtype="Int64") + expected = Series([1, 2, np.nan], dtype="Float64") tm.assert_series_equal(result, expected) @@ -146,7 +107,7 @@ def test_case_when_multiple_conditions_replacement_series(df): """ Test output when replacement is a Series """ - result = case_when( + result = Series([np.nan, np.nan, np.nan]).case_when( [ (np.array([True, False, False]), 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3])), @@ -160,9 +121,8 @@ def test_case_when_multiple_conditions_default_is_not_none(df): """ Test output when default is not None """ - result = case_when( + result = Series([-1, -1, -1]).case_when( [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], - default=-1, ) expected = Series([1, 2, -1]) tm.assert_series_equal(result, expected) @@ -172,9 +132,8 @@ def test_case_when_multiple_conditions_default_is_a_series(df): """ Test output when default is not None """ - result = case_when( + result = Series(-1, index=df.index).case_when( [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], - default=Series(-1, index=df.index), ) expected = Series([1, 2, -1]) tm.assert_series_equal(result, expected) @@ -189,7 +148,7 @@ def test_case_when_non_range_index(): df = DataFrame( rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] ) - result = case_when([(df.A.gt(0), df.B)], default=5) + result = Series(5, index=df.index).case_when([(df.A.gt(0), df.B)]) result = Series(result, name="A") expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) tm.assert_series_equal(result, expected) From 29ad697cc5448eb725461ae3a9f3aec9a88adfae Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 16 Dec 2023 00:17:52 +1100 Subject: [PATCH 09/34] minor update --- pandas/core/series.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index 95151c738af06..aab8e0c8b7167 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5714,11 +5714,7 @@ def case_when( ] default = self.copy() conditions, replacements = zip(*caselist) - common_dtypes = [ - infer_dtype_from(replacement)[0] for replacement in replacements - ] - arg_dtype, _ = infer_dtype_from(default) - common_dtypes.append(arg_dtype) + common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]] if len(set(common_dtypes)) > 1: common_dtypes = find_common_type(common_dtypes) updated_replacements = [] From bf740f9be5c3c33ae09ab33cb8e17100e13069dc Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 16 Dec 2023 00:19:38 +1100 Subject: [PATCH 10/34] fix test --- pandas/tests/series/methods/test_case_when.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandas/tests/series/methods/test_case_when.py b/pandas/tests/series/methods/test_case_when.py index bc6d319a5e003..45f31dfacc567 100644 --- a/pandas/tests/series/methods/test_case_when.py +++ b/pandas/tests/series/methods/test_case_when.py @@ -148,7 +148,6 @@ def test_case_when_non_range_index(): df = DataFrame( rng.standard_normal(size=(8, 4)), index=dates, columns=["A", "B", "C", "D"] ) - result = Series(5, index=df.index).case_when([(df.A.gt(0), df.B)]) - result = Series(result, name="A") + result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)]) expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) tm.assert_series_equal(result, expected) From 264a675a0fc1eaa7c7b10afd24e0a3f46092bf2a Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 16 Dec 2023 00:22:03 +1100 Subject: [PATCH 11/34] remove redundant tests --- pandas/tests/series/methods/test_case_when.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/pandas/tests/series/methods/test_case_when.py b/pandas/tests/series/methods/test_case_when.py index 45f31dfacc567..24f15648a7138 100644 --- a/pandas/tests/series/methods/test_case_when.py +++ b/pandas/tests/series/methods/test_case_when.py @@ -117,28 +117,6 @@ def test_case_when_multiple_conditions_replacement_series(df): tm.assert_series_equal(result, expected) -def test_case_when_multiple_conditions_default_is_not_none(df): - """ - Test output when default is not None - """ - result = Series([-1, -1, -1]).case_when( - [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], - ) - expected = Series([1, 2, -1]) - tm.assert_series_equal(result, expected) - - -def test_case_when_multiple_conditions_default_is_a_series(df): - """ - Test output when default is not None - """ - result = Series(-1, index=df.index).case_when( - [([True, False, False], 1), (df["a"].gt(1) & df["b"].eq(5), Series([1, 2, 3]))], - ) - expected = Series([1, 2, -1]) - tm.assert_series_equal(result, expected) - - def test_case_when_non_range_index(): """ Test output if index is not RangeIndex From 2a3035e61a516c45a35dddc10dfae5b1714cf28a Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 16 Dec 2023 00:23:43 +1100 Subject: [PATCH 12/34] cleanup docs --- pandas/core/series.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index aab8e0c8b7167..f33b48ad33d36 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5646,10 +5646,6 @@ def case_when( and should return a scalar or Series. The callable must not change the input Series (though pandas doesn`t check it). - If ``condition`` is a Series, and the equivalent ``replacement`` - is a Series, they must have the same index. - If there are multiple replacement options, - and they are Series, they must have the same index. .. versionadded:: 2.2.0 From 5e333045da2d744b166f4da1573ae993ee8122c0 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 23 Dec 2023 06:53:55 +1100 Subject: [PATCH 13/34] use singular version - common_dtype --- pandas/core/series.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index f33b48ad33d36..ccd97dd293a72 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5712,20 +5712,20 @@ def case_when( conditions, replacements = zip(*caselist) common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]] if len(set(common_dtypes)) > 1: - common_dtypes = find_common_type(common_dtypes) + common_dtype = find_common_type(common_dtypes) updated_replacements = [] for condition, replacement in zip(conditions, replacements): if is_scalar(replacement): replacement = construct_1d_arraylike_from_scalar( - value=replacement, length=len(condition), dtype=common_dtypes + value=replacement, length=len(condition), dtype=common_dtype ) elif isinstance(replacement, ABCSeries): - replacement = replacement.astype(common_dtypes) + replacement = replacement.astype(common_dtype) else: - replacement = pd_array(replacement, dtype=common_dtypes) + replacement = pd_array(replacement, dtype=common_dtype) updated_replacements.append(replacement) replacements = updated_replacements - default = default.astype(common_dtypes) + default = default.astype(common_dtype) counter = reversed(range(len(conditions))) for position, condition, replacement in zip( From 8569cd1e124abc191dc124ea252bbbaa1ec2be29 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 23 Dec 2023 13:47:15 +1100 Subject: [PATCH 14/34] fix doctest failure --- pandas/core/series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index 2cc1a5190528f..c57244af51d0e 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5684,8 +5684,8 @@ def case_when( 2 1 4 8 3 2 5 9 - >>> df.c.case_when((df.a.gt(0), df.a), # condition, replacement - ... (df.b.gt(0), df.b)) + >>> df.c.case_when(caselist=[(df.a.gt(0), df.a), # condition, replacement + ... (df.b.gt(0), df.b)]) 0 6 1 3 2 1 From bbb588786c8c47597158963dbc7048183a62c14c Mon Sep 17 00:00:00 2001 From: samukweku Date: Sat, 23 Dec 2023 13:51:35 +1100 Subject: [PATCH 15/34] fix for whatnew --- doc/source/whatsnew/v2.2.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 2d467ba745f06..90d9903939018 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -96,6 +96,7 @@ This future dtype inference logic can be enabled with: Enhancements ~~~~~~~~~~~~ + .. _whatsnew_220.enhancements.case_when: Create a pandas Series based on one or more conditions From e03e3dc97942d8360000f893217415fee844eed5 Mon Sep 17 00:00:00 2001 From: Samuel Oranyeli Date: Sun, 24 Dec 2023 00:01:05 +1100 Subject: [PATCH 16/34] Update doc/source/whatsnew/v2.2.0.rst Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com> --- doc/source/whatsnew/v2.2.0.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 90d9903939018..5da37c99b525c 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -109,10 +109,12 @@ The :func:`case_when` function has been added to create a Series object based on import pandas as pd df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) - default=pd.Series(['default'],index=df.index) + default=pd.Series('default', index=df.index) default.case_when( - caselist = [(df.a == 1, 'first'), # condition, replacement - (df.a.gt(1) & df.b.eq(5), 'second')], # condition, replacement + caselist=[ + (df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement + ], ) .. _whatsnew_220.enhancements.adbc_support: From 283488fc530313307f5ff3f8f5b463c76a75a662 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 23 Dec 2023 19:20:16 +0100 Subject: [PATCH 17/34] Update v2.2.0.rst --- doc/source/whatsnew/v2.2.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 5da37c99b525c..e4b7ceda84cce 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -114,7 +114,7 @@ The :func:`case_when` function has been added to create a Series object based on caselist=[ (df.a == 1, 'first'), # condition, replacement (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement - ], + ], ) .. _whatsnew_220.enhancements.adbc_support: From 7a8694c7c0edcf75e13f27a2346a112163edc2bd Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 23 Dec 2023 19:20:56 +0100 Subject: [PATCH 18/34] Update v2.2.0.rst --- doc/source/whatsnew/v2.2.0.rst | 40 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index e4b7ceda84cce..dbc4d13b85606 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -97,26 +97,6 @@ This future dtype inference logic can be enabled with: Enhancements ~~~~~~~~~~~~ -.. _whatsnew_220.enhancements.case_when: - -Create a pandas Series based on one or more conditions -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) - -.. ipython:: python - - import pandas as pd - - df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) - default=pd.Series('default', index=df.index) - default.case_when( - caselist=[ - (df.a == 1, 'first'), # condition, replacement - (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement - ], - ) - .. _whatsnew_220.enhancements.adbc_support: ADBC Driver support in to_sql and read_sql @@ -208,6 +188,26 @@ For a full list of ADBC drivers and their development status, see the `ADBC Driv Implementation Status `_ documentation. +.. _whatsnew_220.enhancements.case_when: + +Create a pandas Series based on one or more conditions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) + +.. ipython:: python + + import pandas as pd + + df = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6])) + default=pd.Series('default', index=df.index) + default.case_when( + caselist=[ + (df.a == 1, 'first'), # condition, replacement + (df.a.gt(1) & df.b.eq(5), 'second'), # condition, replacement + ], + ) + .. _whatsnew_220.enhancements.to_numpy_ea: ``to_numpy`` for NumPy nullable and Arrow types converts to suitable NumPy dtype From 67dfcaa978a7bf210186b2a5dbbb81d08b7bc704 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sun, 24 Dec 2023 16:40:31 +1100 Subject: [PATCH 19/34] improve typing and add test for callable --- pandas/core/series.py | 7 ++++++- pandas/tests/series/methods/test_case_when.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index c57244af51d0e..ec14da3b8f7f8 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5638,7 +5638,12 @@ def between( def case_when( self, - caselist: list[tuple[ArrayLike, ArrayLike | Scalar]], + caselist: list[ + tuple[ + ArrayLike | Callable[[Series | np.ndarray | Sequence[bool]]], + ArrayLike | Callable[[Series | np.ndarray]] | Scalar, + ] + ], ) -> Series: """ Replace values where the conditions are True. diff --git a/pandas/tests/series/methods/test_case_when.py b/pandas/tests/series/methods/test_case_when.py index 24f15648a7138..7cb60a11644a3 100644 --- a/pandas/tests/series/methods/test_case_when.py +++ b/pandas/tests/series/methods/test_case_when.py @@ -129,3 +129,20 @@ def test_case_when_non_range_index(): result = Series(5, index=df.index, name="A").case_when([(df.A.gt(0), df.B)]) expected = df.A.mask(df.A.gt(0), df.B).where(df.A.gt(0), 5) tm.assert_series_equal(result, expected) + + +def test_case_when_callable(): + """ + Test output on a callable + """ + # https://numpy.org/doc/stable/reference/generated/numpy.piecewise.html + x = np.linspace(-2.5, 2.5, 6) + ser = Series(x) + result = ser.case_when( + caselist=[ + (lambda df: df < 0, lambda df: -df), + (lambda df: df >= 0, lambda df: df), + ] + ) + expected = np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x]) + tm.assert_series_equal(result, Series(expected)) From 3da7cf280d6e7d704f38732f06abcbe1e3d279d1 Mon Sep 17 00:00:00 2001 From: samukweku Date: Sun, 24 Dec 2023 18:06:45 +1100 Subject: [PATCH 20/34] fix typing error --- pandas/core/series.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index ec14da3b8f7f8..962a18a6b272c 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5640,9 +5640,9 @@ def case_when( self, caselist: list[ tuple[ - ArrayLike | Callable[[Series | np.ndarray | Sequence[bool]]], - ArrayLike | Callable[[Series | np.ndarray]] | Scalar, - ] + ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]], + ArrayLike | Scalar | Callable[[Series], Series | np.ndarray], + ], ], ) -> Series: """ From bdc54f62fb92f9a5887eef4ae443a9785a52b72e Mon Sep 17 00:00:00 2001 From: Samuel Oranyeli Date: Mon, 25 Dec 2023 22:21:26 +1100 Subject: [PATCH 21/34] Update pandas/core/series.py Co-authored-by: Xiao Yuan --- pandas/core/series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index 962a18a6b272c..586facf757c5c 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5702,7 +5702,7 @@ def case_when( f"The caselist argument should be a list; instead got {type(caselist)}" ) - if not len(caselist): + if not caselist: raise ValueError( "provide at least one boolean condition, " "with a corresponding replacement." From b68d20e9f91b96738406097b3b6bbd8da32f90fd Mon Sep 17 00:00:00 2001 From: Samuel Oranyeli Date: Thu, 28 Dec 2023 11:19:31 +1100 Subject: [PATCH 22/34] Update doc/source/whatsnew/v2.2.0.rst Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> --- doc/source/whatsnew/v2.2.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 6a6cc07409692..0f26474833234 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -193,7 +193,7 @@ documentation. Create a pandas Series based on one or more conditions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The :func:`case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) +The :meth:`Series.case_when` function has been added to create a Series object based on one or more conditions. (:issue:`39154`) .. ipython:: python From b4de208502700fc7e52eba65d95ac4d3c95c67e6 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 27 Dec 2023 10:52:01 -0800 Subject: [PATCH 23/34] PERF: resolution, is_normalized (#56637) PERF: resolution --- pandas/_libs/tslibs/vectorized.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/_libs/tslibs/vectorized.pyx b/pandas/_libs/tslibs/vectorized.pyx index 0a19092f57706..1e09874639d4f 100644 --- a/pandas/_libs/tslibs/vectorized.pyx +++ b/pandas/_libs/tslibs/vectorized.pyx @@ -234,7 +234,7 @@ def get_resolution( for i in range(n): # Analogous to: utc_val = stamps[i] - utc_val = cnp.PyArray_GETITEM(stamps, cnp.PyArray_ITER_DATA(it)) + utc_val = (cnp.PyArray_ITER_DATA(it))[0] if utc_val == NPY_NAT: pass @@ -331,7 +331,7 @@ def is_date_array_normalized(ndarray stamps, tzinfo tz, NPY_DATETIMEUNIT reso) - for i in range(n): # Analogous to: utc_val = stamps[i] - utc_val = cnp.PyArray_GETITEM(stamps, cnp.PyArray_ITER_DATA(it)) + utc_val = (cnp.PyArray_ITER_DATA(it))[0] local_val = info.utc_val_to_local_val(utc_val, &pos) From 5966bfe38e58713e458845fb1ea60ab9ab09a761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Wed, 27 Dec 2023 13:54:54 -0500 Subject: [PATCH 24/34] TYP: more simple return types from ruff (#56628) TYP: more return types from ruff --- pandas/core/apply.py | 2 +- pandas/core/array_algos/replace.py | 2 +- pandas/core/arrays/arrow/accessors.py | 2 +- pandas/core/arrays/arrow/extension_types.py | 2 +- pandas/core/arrays/categorical.py | 6 ++++-- pandas/core/arrays/datetimelike.py | 4 ++-- pandas/core/arrays/sparse/accessor.py | 6 +++--- pandas/core/arrays/sparse/array.py | 4 +++- pandas/core/arrays/sparse/scipy_sparse.py | 2 +- pandas/core/arrays/string_.py | 2 +- pandas/core/computation/eval.py | 8 ++++---- pandas/core/computation/ops.py | 2 +- pandas/core/dtypes/cast.py | 2 +- pandas/core/frame.py | 4 ++-- pandas/core/generic.py | 4 ++-- pandas/core/indexes/base.py | 4 ++-- pandas/core/indexes/multi.py | 2 +- pandas/core/indexing.py | 14 ++++++++------ pandas/core/internals/base.py | 2 +- pandas/core/internals/managers.py | 6 ++++-- pandas/core/ops/array_ops.py | 2 +- pandas/core/reshape/concat.py | 2 +- pandas/core/reshape/encoding.py | 2 +- pandas/core/reshape/merge.py | 2 +- pandas/core/reshape/reshape.py | 2 +- pandas/core/window/rolling.py | 6 +++--- pandas/io/formats/style_render.py | 4 ++-- pandas/io/pytables.py | 2 +- 28 files changed, 55 insertions(+), 47 deletions(-) diff --git a/pandas/core/apply.py b/pandas/core/apply.py index 25a71ce5b5f4f..784e11415ade6 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -827,7 +827,7 @@ def generate_numba_apply_func( def apply_with_numba(self): pass - def validate_values_for_numba(self): + def validate_values_for_numba(self) -> None: # Validate column dtyps all OK for colname, dtype in self.obj.dtypes.items(): if not is_numeric_dtype(dtype): diff --git a/pandas/core/array_algos/replace.py b/pandas/core/array_algos/replace.py index 5f377276be480..60fc172139f13 100644 --- a/pandas/core/array_algos/replace.py +++ b/pandas/core/array_algos/replace.py @@ -67,7 +67,7 @@ def compare_or_regex_search( def _check_comparison_types( result: ArrayLike | bool, a: ArrayLike, b: Scalar | Pattern - ): + ) -> None: """ Raises an error if the two arrays (a,b) cannot be compared. Otherwise, returns the comparison result as expected. diff --git a/pandas/core/arrays/arrow/accessors.py b/pandas/core/arrays/arrow/accessors.py index 7f88267943526..23825faa70095 100644 --- a/pandas/core/arrays/arrow/accessors.py +++ b/pandas/core/arrays/arrow/accessors.py @@ -39,7 +39,7 @@ def __init__(self, data, validation_msg: str) -> None: def _is_valid_pyarrow_dtype(self, pyarrow_dtype) -> bool: pass - def _validate(self, data): + def _validate(self, data) -> None: dtype = data.dtype if not isinstance(dtype, ArrowDtype): # Raise AttributeError so that inspect can handle non-struct Series. diff --git a/pandas/core/arrays/arrow/extension_types.py b/pandas/core/arrays/arrow/extension_types.py index 72bfd6f2212f8..d52b60df47adc 100644 --- a/pandas/core/arrays/arrow/extension_types.py +++ b/pandas/core/arrays/arrow/extension_types.py @@ -135,7 +135,7 @@ def to_pandas_dtype(self) -> IntervalDtype: """ -def patch_pyarrow(): +def patch_pyarrow() -> None: # starting from pyarrow 14.0.1, it has its own mechanism if not pa_version_under14p1: return diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 065a942cae768..8a88227ad54a3 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2164,7 +2164,9 @@ def __contains__(self, key) -> bool: # ------------------------------------------------------------------ # Rendering Methods - def _formatter(self, boxed: bool = False): + # error: Return type "None" of "_formatter" incompatible with return + # type "Callable[[Any], str | None]" in supertype "ExtensionArray" + def _formatter(self, boxed: bool = False) -> None: # type: ignore[override] # Returning None here will cause format_array to do inference. return None @@ -2890,7 +2892,7 @@ def __init__(self, data) -> None: self._freeze() @staticmethod - def _validate(data): + def _validate(data) -> None: if not isinstance(data.dtype, CategoricalDtype): raise AttributeError("Can only use .cat accessor with a 'category' dtype") diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 11a0c7bf18fcb..e04fcb84d51a0 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -2058,7 +2058,7 @@ def freq(self, value) -> None: self._freq = value @final - def _maybe_pin_freq(self, freq, validate_kwds: dict): + def _maybe_pin_freq(self, freq, validate_kwds: dict) -> None: """ Constructor helper to pin the appropriate `freq` attribute. Assumes that self._freq is currently set to any freq inferred in @@ -2092,7 +2092,7 @@ def _maybe_pin_freq(self, freq, validate_kwds: dict): @final @classmethod - def _validate_frequency(cls, index, freq: BaseOffset, **kwargs): + def _validate_frequency(cls, index, freq: BaseOffset, **kwargs) -> None: """ Validate that a frequency is compatible with the values of a given Datetime Array/Index or Timedelta Array/Index diff --git a/pandas/core/arrays/sparse/accessor.py b/pandas/core/arrays/sparse/accessor.py index fc7debb1f31e4..3dd7ebf564ca1 100644 --- a/pandas/core/arrays/sparse/accessor.py +++ b/pandas/core/arrays/sparse/accessor.py @@ -30,7 +30,7 @@ def __init__(self, data=None) -> None: self._parent = data self._validate(data) - def _validate(self, data): + def _validate(self, data) -> None: raise NotImplementedError @@ -50,7 +50,7 @@ class SparseAccessor(BaseAccessor, PandasDelegate): array([2, 2, 2]) """ - def _validate(self, data): + def _validate(self, data) -> None: if not isinstance(data.dtype, SparseDtype): raise AttributeError(self._validation_msg) @@ -243,7 +243,7 @@ class SparseFrameAccessor(BaseAccessor, PandasDelegate): 0.5 """ - def _validate(self, data): + def _validate(self, data) -> None: dtypes = data.dtypes if not all(isinstance(t, SparseDtype) for t in dtypes): raise AttributeError(self._validation_msg) diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 5db77db2a9c66..7a3ea85dde2b4 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1830,7 +1830,9 @@ def __repr__(self) -> str: pp_index = printing.pprint_thing(self.sp_index) return f"{pp_str}\nFill: {pp_fill}\n{pp_index}" - def _formatter(self, boxed: bool = False): + # error: Return type "None" of "_formatter" incompatible with return + # type "Callable[[Any], str | None]" in supertype "ExtensionArray" + def _formatter(self, boxed: bool = False) -> None: # type: ignore[override] # Defer to the formatter from the GenericArrayFormatter calling us. # This will infer the correct formatter from the dtype of the values. return None diff --git a/pandas/core/arrays/sparse/scipy_sparse.py b/pandas/core/arrays/sparse/scipy_sparse.py index 71b71a9779da5..31e09c923d933 100644 --- a/pandas/core/arrays/sparse/scipy_sparse.py +++ b/pandas/core/arrays/sparse/scipy_sparse.py @@ -27,7 +27,7 @@ ) -def _check_is_partition(parts: Iterable, whole: Iterable): +def _check_is_partition(parts: Iterable, whole: Iterable) -> None: whole = set(whole) parts = [set(x) for x in parts] if set.intersection(*parts) != set(): diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 00197a150fb97..f451ebc352733 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -364,7 +364,7 @@ def __init__(self, values, copy: bool = False) -> None: self._validate() NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python")) - def _validate(self): + def _validate(self) -> None: """Validate that we only store NA or strings.""" if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): raise ValueError("StringArray requires a sequence of strings or pandas.NA") diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py index f1fe528de06f8..6313c2e2c98de 100644 --- a/pandas/core/computation/eval.py +++ b/pandas/core/computation/eval.py @@ -72,7 +72,7 @@ def _check_engine(engine: str | None) -> str: return engine -def _check_parser(parser: str): +def _check_parser(parser: str) -> None: """ Make sure a valid parser is passed. @@ -91,7 +91,7 @@ def _check_parser(parser: str): ) -def _check_resolvers(resolvers): +def _check_resolvers(resolvers) -> None: if resolvers is not None: for resolver in resolvers: if not hasattr(resolver, "__getitem__"): @@ -102,7 +102,7 @@ def _check_resolvers(resolvers): ) -def _check_expression(expr): +def _check_expression(expr) -> None: """ Make sure an expression is not an empty string @@ -149,7 +149,7 @@ def _convert_expression(expr) -> str: return s -def _check_for_locals(expr: str, stack_level: int, parser: str): +def _check_for_locals(expr: str, stack_level: int, parser: str) -> None: at_top_of_stack = stack_level == 0 not_pandas_parser = parser != "pandas" diff --git a/pandas/core/computation/ops.py b/pandas/core/computation/ops.py index 95ac20ba39edc..9422434b5cde3 100644 --- a/pandas/core/computation/ops.py +++ b/pandas/core/computation/ops.py @@ -491,7 +491,7 @@ def stringify(value): v = v.tz_convert("UTC") self.lhs.update(v) - def _disallow_scalar_only_bool_ops(self): + def _disallow_scalar_only_bool_ops(self) -> None: rhs = self.rhs lhs = self.lhs diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 7a088bf84c48e..72c33e95f68a0 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -231,7 +231,7 @@ def _maybe_unbox_datetimelike(value: Scalar, dtype: DtypeObj) -> Scalar: return value -def _disallow_mismatched_datetimelike(value, dtype: DtypeObj): +def _disallow_mismatched_datetimelike(value, dtype: DtypeObj) -> None: """ numpy allows np.array(dt64values, dtype="timedelta64[ns]") and vice-versa, but we do not want to allow this, so we need to diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 3e2e589440bd9..a46e42b9241ff 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -4316,7 +4316,7 @@ def _setitem_array(self, key, value): else: self._iset_not_inplace(key, value) - def _iset_not_inplace(self, key, value): + def _iset_not_inplace(self, key, value) -> None: # GH#39510 when setting with df[key] = obj with a list-like key and # list-like value, we iterate over those listlikes and set columns # one at a time. This is different from dispatching to @@ -4360,7 +4360,7 @@ def igetitem(obj, i: int): finally: self.columns = orig_columns - def _setitem_frame(self, key, value): + def _setitem_frame(self, key, value) -> None: # support boolean setting with DataFrame input, e.g. # df[df > df2] = 0 if isinstance(key, np.ndarray): diff --git a/pandas/core/generic.py b/pandas/core/generic.py index de25a02c6b37c..91a150c63c5b6 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4394,7 +4394,7 @@ def _check_is_chained_assignment_possible(self) -> bool_t: return False @final - def _check_setitem_copy(self, t: str = "setting", force: bool_t = False): + def _check_setitem_copy(self, t: str = "setting", force: bool_t = False) -> None: """ Parameters @@ -4510,7 +4510,7 @@ def __delitem__(self, key) -> None: # Unsorted @final - def _check_inplace_and_allows_duplicate_labels(self, inplace: bool_t): + def _check_inplace_and_allows_duplicate_labels(self, inplace: bool_t) -> None: if inplace and not self.flags.allows_duplicate_labels: raise ValueError( "Cannot specify 'inplace=True' when " diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 88a08dd55f739..d262dcd144d79 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3209,7 +3209,7 @@ def _get_reconciled_name_object(self, other): return self @final - def _validate_sort_keyword(self, sort): + def _validate_sort_keyword(self, sort) -> None: if sort not in [None, False, True]: raise ValueError( "The 'sort' keyword only takes the values of " @@ -6051,7 +6051,7 @@ def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]: # by RangeIndex, MultIIndex return self._data.argsort(*args, **kwargs) - def _check_indexing_error(self, key): + def _check_indexing_error(self, key) -> None: if not is_scalar(key): # if key is not a scalar, directly raise an error (the code below # would convert to numpy arrays and raise later any way) - GH29926 diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 2a4e027e2b806..56e3899eae6f6 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -1571,7 +1571,7 @@ def _format_multi( def _get_names(self) -> FrozenList: return FrozenList(self._names) - def _set_names(self, names, *, level=None, validate: bool = True): + def _set_names(self, names, *, level=None, validate: bool = True) -> None: """ Set new names on index. Each name has to be a hashable type. diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 4be7e17035128..a7dd3b486ab11 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -911,7 +911,7 @@ def __setitem__(self, key, value) -> None: iloc = self if self.name == "iloc" else self.obj.iloc iloc._setitem_with_indexer(indexer, value, self.name) - def _validate_key(self, key, axis: AxisInt): + def _validate_key(self, key, axis: AxisInt) -> None: """ Ensure that key is valid for current indexer. @@ -1225,7 +1225,7 @@ class _LocIndexer(_LocationIndexer): # Key Checks @doc(_LocationIndexer._validate_key) - def _validate_key(self, key, axis: Axis): + def _validate_key(self, key, axis: Axis) -> None: # valid for a collection of labels (we check their presence later) # slice of labels (where start-end in labels) # slice of integers (only if in the labels) @@ -1572,7 +1572,7 @@ class _iLocIndexer(_LocationIndexer): # ------------------------------------------------------------------- # Key Checks - def _validate_key(self, key, axis: AxisInt): + def _validate_key(self, key, axis: AxisInt) -> None: if com.is_bool_indexer(key): if hasattr(key, "index") and isinstance(key.index, Index): if key.index.inferred_type == "integer": @@ -1783,7 +1783,7 @@ def _get_setitem_indexer(self, key): # ------------------------------------------------------------------- - def _setitem_with_indexer(self, indexer, value, name: str = "iloc"): + def _setitem_with_indexer(self, indexer, value, name: str = "iloc") -> None: """ _setitem_with_indexer is for setting values on a Series/DataFrame using positional indexers. @@ -2038,7 +2038,7 @@ def _setitem_with_indexer_split_path(self, indexer, value, name: str): for loc in ilocs: self._setitem_single_column(loc, value, pi) - def _setitem_with_indexer_2d_value(self, indexer, value): + def _setitem_with_indexer_2d_value(self, indexer, value) -> None: # We get here with np.ndim(value) == 2, excluding DataFrame, # which goes through _setitem_with_indexer_frame_value pi = indexer[0] @@ -2060,7 +2060,9 @@ def _setitem_with_indexer_2d_value(self, indexer, value): value_col = value_col.tolist() self._setitem_single_column(loc, value_col, pi) - def _setitem_with_indexer_frame_value(self, indexer, value: DataFrame, name: str): + def _setitem_with_indexer_frame_value( + self, indexer, value: DataFrame, name: str + ) -> None: ilocs = self._ensure_iterable_column_indexer(indexer[1]) sub_indexer = list(indexer) diff --git a/pandas/core/internals/base.py b/pandas/core/internals/base.py index ae91f167205a0..8f16a6623c8cb 100644 --- a/pandas/core/internals/base.py +++ b/pandas/core/internals/base.py @@ -53,7 +53,7 @@ class _AlreadyWarned: - def __init__(self): + def __init__(self) -> None: # This class is used on the manager level to the block level to # ensure that we warn only once. The block method can update the # warned_already option without returning a value to keep the diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index 3719bf1f77f85..5f38720135efa 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -1940,13 +1940,15 @@ def _post_setstate(self) -> None: def _block(self) -> Block: return self.blocks[0] + # error: Cannot override writeable attribute with read-only property @property - def _blknos(self): + def _blknos(self) -> None: # type: ignore[override] """compat with BlockManager""" return None + # error: Cannot override writeable attribute with read-only property @property - def _blklocs(self): + def _blklocs(self) -> None: # type: ignore[override] """compat with BlockManager""" return None diff --git a/pandas/core/ops/array_ops.py b/pandas/core/ops/array_ops.py index 4b762a359d321..8ccd7c84cb05c 100644 --- a/pandas/core/ops/array_ops.py +++ b/pandas/core/ops/array_ops.py @@ -591,7 +591,7 @@ def maybe_prepare_scalar_for_op(obj, shape: Shape): } -def _bool_arith_check(op, a: np.ndarray, b): +def _bool_arith_check(op, a: np.ndarray, b) -> None: """ In contrast to numpy, pandas raises an error for certain operations with booleans. diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index aacea92611697..31859c7d04e04 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -765,7 +765,7 @@ def _get_concat_axis(self) -> Index: return concat_axis - def _maybe_check_integrity(self, concat_index: Index): + def _maybe_check_integrity(self, concat_index: Index) -> None: if self.verify_integrity: if not concat_index.is_unique: overlap = concat_index[concat_index.duplicated()].unique() diff --git a/pandas/core/reshape/encoding.py b/pandas/core/reshape/encoding.py index 3ed67bb7b7c02..44158227d903b 100644 --- a/pandas/core/reshape/encoding.py +++ b/pandas/core/reshape/encoding.py @@ -169,7 +169,7 @@ def get_dummies( data_to_encode = data[columns] # validate prefixes and separator to avoid silently dropping cols - def check_len(item, name: str): + def check_len(item, name: str) -> None: if is_list_like(item): if not len(item) == data_to_encode.shape[1]: len_msg = ( diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 690e3c2700c6c..f4903023e8059 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -2091,7 +2091,7 @@ def _maybe_require_matching_dtypes( ) -> None: # TODO: why do we do this for AsOfMerge but not the others? - def _check_dtype_match(left: ArrayLike, right: ArrayLike, i: int): + def _check_dtype_match(left: ArrayLike, right: ArrayLike, i: int) -> None: if left.dtype != right.dtype: if isinstance(left.dtype, CategoricalDtype) and isinstance( right.dtype, CategoricalDtype diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py index 7a49682d7c57c..3493f1c78da91 100644 --- a/pandas/core/reshape/reshape.py +++ b/pandas/core/reshape/reshape.py @@ -188,7 +188,7 @@ def _make_sorted_values(self, values: np.ndarray) -> np.ndarray: return sorted_values return values - def _make_selectors(self): + def _make_selectors(self) -> None: new_levels = self.new_index_levels # make the mask diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index e78bd258c11ff..fa5b84fefb883 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1143,7 +1143,7 @@ class Window(BaseWindow): "method", ] - def _validate(self): + def _validate(self) -> None: super()._validate() if not isinstance(self.win_type, str): @@ -1861,7 +1861,7 @@ class Rolling(RollingAndExpandingMixin): "method", ] - def _validate(self): + def _validate(self) -> None: super()._validate() # we allow rolling on a datetimelike index @@ -2906,7 +2906,7 @@ def _get_window_indexer(self) -> GroupbyIndexer: ) return window_indexer - def _validate_datetimelike_monotonic(self): + def _validate_datetimelike_monotonic(self) -> None: """ Validate that each group in self._on is monotonic """ diff --git a/pandas/io/formats/style_render.py b/pandas/io/formats/style_render.py index 416b263ba8497..55541e5262719 100644 --- a/pandas/io/formats/style_render.py +++ b/pandas/io/formats/style_render.py @@ -2288,12 +2288,12 @@ def _parse_latex_css_conversion(styles: CSSList) -> CSSList: Ignore conversion if tagged with `--latex` option, skipped if no conversion found. """ - def font_weight(value, arg): + def font_weight(value, arg) -> tuple[str, str] | None: if value in ("bold", "bolder"): return "bfseries", f"{arg}" return None - def font_style(value, arg): + def font_style(value, arg) -> tuple[str, str] | None: if value == "italic": return "itshape", f"{arg}" if value == "oblique": diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 1139519d2bcd3..c30238e412450 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -3207,7 +3207,7 @@ class SeriesFixed(GenericFixed): name: Hashable @property - def shape(self): + def shape(self) -> tuple[int] | None: try: return (len(self.group.values),) except (TypeError, AttributeError): From 3e404fa8f6ea6944c70b5965812390309dda7636 Mon Sep 17 00:00:00 2001 From: Caden Gobat <36030084+cgobat@users.noreply.github.com> Date: Wed, 27 Dec 2023 10:59:52 -0800 Subject: [PATCH 25/34] ENH: Update CFF with publication reference, Zenodo DOI, and other details (#56589) Add McKinney (2010) reference, DOI, team website, & more keywords --- CITATION.cff | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index 741e7e7ac8c85..11f45b0d87ec7 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,12 +3,50 @@ title: 'pandas-dev/pandas: Pandas' message: 'If you use this software, please cite it as below.' authors: - name: "The pandas development team" + website: "https://pandas.pydata.org/about/team.html" abstract: "Pandas is a powerful data structures for data analysis, time series, and statistics." +doi: 10.5281/zenodo.3509134 license: BSD-3-Clause license-url: "https://github.com/pandas-dev/pandas/blob/main/LICENSE" repository-code: "https://github.com/pandas-dev/pandas" keywords: - python - data science + - flexible + - pandas + - alignment + - data analysis type: software -url: "https://github.com/pandas-dev/pandas" +url: "https://pandas.pydata.org/" +references: + - type: article + authors: + - given-names: Wes + family-names: McKinney + affiliation: AQR Capital Management, LLC + email: wesmckinn@gmail.com + title: Data Structures for Statistical Computing in Python + doi: 10.25080/Majora-92bf1922-00a + license: CC-BY-3.0 + start: 56 + end: 61 + year: 2010 + collection-title: Proceedings of the 9th Python in Science Conference + collection-doi: 10.25080/Majora-92bf1922-012 + collection-type: proceedings + editors: + - given-names: Stéfan + name-particle: van der + family-names: Walt + - given-names: Jarrod + family-names: Millman + conference: + name: 9th Python in Science Conference (SciPy 2010) + city: Austin, TX + country: US + date-start: "2010-06-28" + date-end: "2010-07-03" + keywords: + - data structure + - statistics + - R From 21659bc29b5d8e8b666beb647ca2d83f9629a8eb Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 27 Dec 2023 20:02:54 +0100 Subject: [PATCH 26/34] DOC: Fixup CoW userguide (#56636) --- doc/source/user_guide/copy_on_write.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/user_guide/copy_on_write.rst b/doc/source/user_guide/copy_on_write.rst index 050c3901c3420..a083297925007 100644 --- a/doc/source/user_guide/copy_on_write.rst +++ b/doc/source/user_guide/copy_on_write.rst @@ -317,7 +317,7 @@ you are modifying one object inplace. .. ipython:: python df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - df2 = df.reset_index() + df2 = df.reset_index(drop=True) df2.iloc[0, 0] = 100 This creates two objects that share data and thus the setitem operation will trigger a @@ -328,7 +328,7 @@ held by the object. .. ipython:: python df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - df = df.reset_index() + df = df.reset_index(drop=True) df.iloc[0, 0] = 100 No copy is necessary in this example. From f6d8cd047a75854c24392bb8022ab10ffbc7e633 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 27 Dec 2023 11:09:17 -0800 Subject: [PATCH 27/34] REF: check monotonicity inside _can_use_libjoin (#55342) * REF: fix can_use_libjoin check * DOC: docstring for can_use_libjoin * Make can_use_libjoin checks more-correct * avoid allocating mapping in monotonic cases * fix categorical memory usage tests * catch decimal.InvalidOperation --------- Co-authored-by: Luke Manley --- pandas/_libs/index.pyx | 12 +++++++++++- pandas/core/frame.py | 2 +- pandas/core/indexes/base.py | 20 ++++++++++---------- pandas/tests/extension/test_categorical.py | 5 ----- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index 0dc139781f58d..675288e20d1f8 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -43,6 +43,8 @@ from pandas._libs.missing cimport ( is_matching_na, ) +from decimal import InvalidOperation + # Defines shift of MultiIndex codes to avoid negative codes (missing values) multiindex_nulls_shift = 2 @@ -248,6 +250,10 @@ cdef class IndexEngine: @property def is_unique(self) -> bool: + # for why we check is_monotonic_increasing here, see + # https://github.com/pandas-dev/pandas/pull/55342#discussion_r1361405781 + if self.need_monotonic_check: + self.is_monotonic_increasing if self.need_unique_check: self._do_unique_check() @@ -281,7 +287,7 @@ cdef class IndexEngine: values = self.values self.monotonic_inc, self.monotonic_dec, is_strict_monotonic = \ self._call_monotonic(values) - except TypeError: + except (TypeError, InvalidOperation): self.monotonic_inc = 0 self.monotonic_dec = 0 is_strict_monotonic = 0 @@ -843,6 +849,10 @@ cdef class SharedEngine: @property def is_unique(self) -> bool: + # for why we check is_monotonic_increasing here, see + # https://github.com/pandas-dev/pandas/pull/55342#discussion_r1361405781 + if self.need_monotonic_check: + self.is_monotonic_increasing if self.need_unique_check: arr = self.values.unique() self.unique = len(arr) == len(self.values) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index a46e42b9241ff..c24ef4d6d6d42 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3711,7 +3711,7 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Series: many repeated values. >>> df['object'].astype('category').memory_usage(deep=True) - 5244 + 5136 """ result = self._constructor_sliced( [c.memory_usage(index=False, deep=deep) for col, c in self.items()], diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index d262dcd144d79..166d6946beacf 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3382,9 +3382,7 @@ def _union(self, other: Index, sort: bool | None): if ( sort in (None, True) - and self.is_monotonic_increasing - and other.is_monotonic_increasing - and not (self.has_duplicates and other.has_duplicates) + and (self.is_unique or other.is_unique) and self._can_use_libjoin and other._can_use_libjoin ): @@ -3536,12 +3534,7 @@ def _intersection(self, other: Index, sort: bool = False): """ intersection specialized to the case with matching dtypes. """ - if ( - self.is_monotonic_increasing - and other.is_monotonic_increasing - and self._can_use_libjoin - and other._can_use_libjoin - ): + if self._can_use_libjoin and other._can_use_libjoin: try: res_indexer, indexer, _ = self._inner_indexer(other) except TypeError: @@ -4980,7 +4973,10 @@ def _get_leaf_sorter(labels: list[np.ndarray]) -> npt.NDArray[np.intp]: def _join_monotonic( self, other: Index, how: JoinHow = "left" ) -> tuple[Index, npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: - # We only get here with matching dtypes and both monotonic increasing + # We only get here with (caller is responsible for ensuring): + # 1) matching dtypes + # 2) both monotonic increasing + # 3) other.is_unique or self.is_unique assert other.dtype == self.dtype assert self._can_use_libjoin and other._can_use_libjoin @@ -5062,6 +5058,10 @@ def _can_use_libjoin(self) -> bool: making a copy. If we cannot, this negates the performance benefit of using libjoin. """ + if not self.is_monotonic_increasing: + # The libjoin functions all assume monotonicity. + return False + if type(self) is Index: # excludes EAs, but include masks, we get here with monotonic # values only, meaning no NA diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index 6f33b18b19c51..1b322b1797144 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -75,11 +75,6 @@ def data_for_grouping(): class TestCategorical(base.ExtensionTests): - @pytest.mark.xfail(reason="Memory usage doesn't match") - def test_memory_usage(self, data): - # TODO: Is this deliberate? - super().test_memory_usage(data) - def test_contains(self, data, data_missing): # GH-37867 # na value handling in Categorical.__contains__ is deprecated. From becc626219196e6885299efb67a4773b7ba6bc65 Mon Sep 17 00:00:00 2001 From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com> Date: Wed, 27 Dec 2023 14:19:25 -0500 Subject: [PATCH 28/34] DOC: Minor fixups for 2.2.0 whatsnew (#56632) --- doc/source/whatsnew/v2.2.0.rst | 118 ++++++++++++--------------------- 1 file changed, 43 insertions(+), 75 deletions(-) diff --git a/doc/source/whatsnew/v2.2.0.rst b/doc/source/whatsnew/v2.2.0.rst index 0f26474833234..82e2b18db6114 100644 --- a/doc/source/whatsnew/v2.2.0.rst +++ b/doc/source/whatsnew/v2.2.0.rst @@ -123,7 +123,7 @@ nullability handling. with pg_dbapi.connect(uri) as conn: df.to_sql("pandas_table", conn, index=False) - # for roundtripping + # for round-tripping with pg_dbapi.connect(uri) as conn: df2 = pd.read_sql("pandas_table", conn) @@ -176,7 +176,7 @@ leverage the ``dtype_backend="pyarrow"`` argument of :func:`~pandas.read_sql` .. code-block:: ipython - # for roundtripping + # for round-tripping with pg_dbapi.connect(uri) as conn: df2 = pd.read_sql("pandas_table", conn, dtype_backend="pyarrow") @@ -326,22 +326,21 @@ Other enhancements - :meth:`~DataFrame.to_sql` with method parameter set to ``multi`` works with Oracle on the backend - :attr:`Series.attrs` / :attr:`DataFrame.attrs` now uses a deepcopy for propagating ``attrs`` (:issue:`54134`). - :func:`get_dummies` now returning extension dtypes ``boolean`` or ``bool[pyarrow]`` that are compatible with the input dtype (:issue:`56273`) -- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`) +- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"`` (:issue:`54480`) - :func:`read_sas` returns ``datetime64`` dtypes with resolutions better matching those stored natively in SAS, and avoids returning object-dtype in cases that cannot be stored with ``datetime64[ns]`` dtype (:issue:`56127`) -- :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs`. (:issue:`54264`) +- :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs` (:issue:`54264`) - :func:`tseries.api.guess_datetime_format` is now part of the public API (:issue:`54727`) +- :meth:`DataFrame.apply` now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`) - :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`) - :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area`` (:issue:`56492`) - Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`) -- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`) - Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`) - Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`) -- Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as "BMS" (:issue:`56243`) -- Improved error message when constructing :class:`Period` with invalid offsets such as "QS" (:issue:`55785`) +- Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as ``"BMS"`` (:issue:`56243`) +- Improved error message when constructing :class:`Period` with invalid offsets such as ``"QS"`` (:issue:`55785`) - The dtypes ``string[pyarrow]`` and ``string[pyarrow_numpy]`` now both utilize the ``large_string`` type from PyArrow to avoid overflow for long columns (:issue:`56259`) - .. --------------------------------------------------------------------------- .. _whatsnew_220.notable_bug_fixes: @@ -406,6 +405,8 @@ index levels when joining on two indexes with different levels (:issue:`34133`). left = pd.DataFrame({"left": 1}, index=pd.MultiIndex.from_tuples([("x", 1), ("x", 2)], names=["A", "B"])) right = pd.DataFrame({"right": 2}, index=pd.MultiIndex.from_tuples([(1, 1), (2, 2)], names=["B", "C"])) + left + right result = left.join(right) *Old Behavior* @@ -435,15 +436,6 @@ Backwards incompatible API changes Increased minimum versions for dependencies ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Some minimum supported versions of dependencies were updated. -If installed, we now require: - -+-----------------+-----------------+----------+---------+ -| Package | Minimum Version | Required | Changed | -+=================+=================+==========+=========+ -| | | X | X | -+-----------------+-----------------+----------+---------+ - For `optional libraries `_ the general recommendation is to use the latest version. The following table lists the lowest version per library that is currently being tested throughout the development of pandas. Optional libraries below the lowest tested version may still work, but are not considered supported. @@ -453,8 +445,6 @@ Optional libraries below the lowest tested version may still work, but are not c +=================+=================+=========+ | mypy (dev) | 1.8.0 | X | +-----------------+-----------------+---------+ -| | | X | -+-----------------+-----------------+---------+ See :ref:`install.dependencies` and :ref:`install.optional_dependencies` for more. @@ -626,20 +616,20 @@ Other Deprecations - Deprecated ``year``, ``month``, ``quarter``, ``day``, ``hour``, ``minute``, and ``second`` keywords in the :class:`PeriodIndex` constructor, use :meth:`PeriodIndex.from_fields` instead (:issue:`55960`) - Deprecated accepting a type as an argument in :meth:`Index.view`, call without any arguments instead (:issue:`55709`) - Deprecated allowing non-integer ``periods`` argument in :func:`date_range`, :func:`timedelta_range`, :func:`period_range`, and :func:`interval_range` (:issue:`56036`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_clipboard`. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_csv` except ``path_or_buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_dict`. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_excel` except ``excel_writer``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_gbq` except ``destination_table``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_hdf` except ``path_or_buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_html` except ``buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_json` except ``path_or_buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_latex` except ``buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_markdown` except ``buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_parquet` except ``path``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_pickle` except ``path``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_string` except ``buf``. (:issue:`54229`) -- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_xml` except ``path_or_buffer``. (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_clipboard` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_csv` except ``path_or_buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_dict` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_excel` except ``excel_writer`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_gbq` except ``destination_table`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_hdf` except ``path_or_buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_html` except ``buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_json` except ``path_or_buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_latex` except ``buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_markdown` except ``buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_parquet` except ``path`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_pickle` except ``path`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_string` except ``buf`` (:issue:`54229`) +- Deprecated allowing non-keyword arguments in :meth:`DataFrame.to_xml` except ``path_or_buffer`` (:issue:`54229`) - Deprecated allowing passing :class:`BlockManager` objects to :class:`DataFrame` or :class:`SingleBlockManager` objects to :class:`Series` (:issue:`52419`) - Deprecated behavior of :meth:`Index.insert` with an object-dtype index silently performing type inference on the result, explicitly call ``result.infer_objects(copy=False)`` for the old behavior instead (:issue:`51363`) - Deprecated casting non-datetimelike values (mainly strings) in :meth:`Series.isin` and :meth:`Index.isin` with ``datetime64``, ``timedelta64``, and :class:`PeriodDtype` dtypes (:issue:`53111`) @@ -712,31 +702,30 @@ Bug fixes Categorical ^^^^^^^^^^^ - :meth:`Categorical.isin` raising ``InvalidIndexError`` for categorical containing overlapping :class:`Interval` values (:issue:`34974`) -- Bug in :meth:`CategoricalDtype.__eq__` returning false for unordered categorical data with mixed types (:issue:`55468`) -- +- Bug in :meth:`CategoricalDtype.__eq__` returning ``False`` for unordered categorical data with mixed types (:issue:`55468`) Datetimelike ^^^^^^^^^^^^ - Bug in :class:`DatetimeIndex` construction when passing both a ``tz`` and either ``dayfirst`` or ``yearfirst`` ignoring dayfirst/yearfirst (:issue:`55813`) - Bug in :class:`DatetimeIndex` when passing an object-dtype ndarray of float objects and a ``tz`` incorrectly localizing the result (:issue:`55780`) - Bug in :func:`Series.isin` with :class:`DatetimeTZDtype` dtype and comparison values that are all ``NaT`` incorrectly returning all-``False`` even if the series contains ``NaT`` entries (:issue:`56427`) -- Bug in :func:`concat` raising ``AttributeError`` when concatenating all-NA DataFrame with :class:`DatetimeTZDtype` dtype DataFrame. (:issue:`52093`) +- Bug in :func:`concat` raising ``AttributeError`` when concatenating all-NA DataFrame with :class:`DatetimeTZDtype` dtype DataFrame (:issue:`52093`) - Bug in :func:`testing.assert_extension_array_equal` that could use the wrong unit when comparing resolutions (:issue:`55730`) - Bug in :func:`to_datetime` and :class:`DatetimeIndex` when passing a list of mixed-string-and-numeric types incorrectly raising (:issue:`55780`) - Bug in :func:`to_datetime` and :class:`DatetimeIndex` when passing mixed-type objects with a mix of timezones or mix of timezone-awareness failing to raise ``ValueError`` (:issue:`55693`) +- Bug in :meth:`.Tick.delta` with very large ticks raising ``OverflowError`` instead of ``OutOfBoundsTimedelta`` (:issue:`55503`) - Bug in :meth:`DatetimeIndex.shift` with non-nanosecond resolution incorrectly returning with nanosecond resolution (:issue:`56117`) - Bug in :meth:`DatetimeIndex.union` returning object dtype for tz-aware indexes with the same timezone but different units (:issue:`55238`) - Bug in :meth:`Index.is_monotonic_increasing` and :meth:`Index.is_monotonic_decreasing` always caching :meth:`Index.is_unique` as ``True`` when first value in index is ``NaT`` (:issue:`55755`) - Bug in :meth:`Index.view` to a datetime64 dtype with non-supported resolution incorrectly raising (:issue:`55710`) - Bug in :meth:`Series.dt.round` with non-nanosecond resolution and ``NaT`` entries incorrectly raising ``OverflowError`` (:issue:`56158`) - Bug in :meth:`Series.fillna` with non-nanosecond resolution dtypes and higher-resolution vector values returning incorrect (internally-corrupted) results (:issue:`56410`) -- Bug in :meth:`Tick.delta` with very large ticks raising ``OverflowError`` instead of ``OutOfBoundsTimedelta`` (:issue:`55503`) - Bug in :meth:`Timestamp.unit` being inferred incorrectly from an ISO8601 format string with minute or hour resolution and a timezone offset (:issue:`56208`) -- Bug in ``.astype`` converting from a higher-resolution ``datetime64`` dtype to a lower-resolution ``datetime64`` dtype (e.g. ``datetime64[us]->datetim64[ms]``) silently overflowing with values near the lower implementation bound (:issue:`55979`) +- Bug in ``.astype`` converting from a higher-resolution ``datetime64`` dtype to a lower-resolution ``datetime64`` dtype (e.g. ``datetime64[us]->datetime64[ms]``) silently overflowing with values near the lower implementation bound (:issue:`55979`) - Bug in adding or subtracting a :class:`Week` offset to a ``datetime64`` :class:`Series`, :class:`Index`, or :class:`DataFrame` column with non-nanosecond resolution returning incorrect results (:issue:`55583`) - Bug in addition or subtraction of :class:`BusinessDay` offset with ``offset`` attribute to non-nanosecond :class:`Index`, :class:`Series`, or :class:`DataFrame` column giving incorrect results (:issue:`55608`) - Bug in addition or subtraction of :class:`DateOffset` objects with microsecond components to ``datetime64`` :class:`Index`, :class:`Series`, or :class:`DataFrame` columns with non-nanosecond resolution (:issue:`55595`) -- Bug in addition or subtraction of very large :class:`Tick` objects with :class:`Timestamp` or :class:`Timedelta` objects raising ``OverflowError`` instead of ``OutOfBoundsTimedelta`` (:issue:`55503`) +- Bug in addition or subtraction of very large :class:`.Tick` objects with :class:`Timestamp` or :class:`Timedelta` objects raising ``OverflowError`` instead of ``OutOfBoundsTimedelta`` (:issue:`55503`) - Bug in creating a :class:`Index`, :class:`Series`, or :class:`DataFrame` with a non-nanosecond :class:`DatetimeTZDtype` and inputs that would be out of bounds with nanosecond resolution incorrectly raising ``OutOfBoundsDatetime`` (:issue:`54620`) - Bug in creating a :class:`Index`, :class:`Series`, or :class:`DataFrame` with a non-nanosecond ``datetime64`` (or :class:`DatetimeTZDtype`) from mixed-numeric inputs treating those as nanoseconds instead of as multiples of the dtype's unit (which would happen with non-mixed numeric inputs) (:issue:`56004`) - Bug in creating a :class:`Index`, :class:`Series`, or :class:`DataFrame` with a non-nanosecond ``datetime64`` dtype and inputs that would be out of bounds for a ``datetime64[ns]`` incorrectly raising ``OutOfBoundsDatetime`` (:issue:`55756`) @@ -759,14 +748,12 @@ Numeric ^^^^^^^ - Bug in :func:`read_csv` with ``engine="pyarrow"`` causing rounding errors for large integers (:issue:`52505`) - Bug in :meth:`Series.pow` not filling missing values correctly (:issue:`55512`) -- Conversion ^^^^^^^^^^ - Bug in :meth:`DataFrame.astype` when called with ``str`` on unpickled array - the array might change in-place (:issue:`54654`) - Bug in :meth:`DataFrame.astype` where ``errors="ignore"`` had no effect for extension types (:issue:`54654`) - Bug in :meth:`Series.convert_dtypes` not converting all NA column to ``null[pyarrow]`` (:issue:`55346`) -- Strings ^^^^^^^ @@ -783,13 +770,12 @@ Strings Interval ^^^^^^^^ -- Bug in :class:`Interval` ``__repr__`` not displaying UTC offsets for :class:`Timestamp` bounds. Additionally the hour, minute and second components will now be shown. (:issue:`55015`) +- Bug in :class:`Interval` ``__repr__`` not displaying UTC offsets for :class:`Timestamp` bounds. Additionally the hour, minute and second components will now be shown (:issue:`55015`) - Bug in :meth:`IntervalIndex.factorize` and :meth:`Series.factorize` with :class:`IntervalDtype` with datetime64 or timedelta64 intervals not preserving non-nanosecond units (:issue:`56099`) - Bug in :meth:`IntervalIndex.from_arrays` when passed ``datetime64`` or ``timedelta64`` arrays with mismatched resolutions constructing an invalid ``IntervalArray`` object (:issue:`55714`) - Bug in :meth:`IntervalIndex.get_indexer` with datetime or timedelta intervals incorrectly matching on integer targets (:issue:`47772`) - Bug in :meth:`IntervalIndex.get_indexer` with timezone-aware datetime intervals incorrectly matching on a sequence of timezone-naive targets (:issue:`47772`) - Bug in setting values on a :class:`Series` with an :class:`IntervalIndex` using a slice incorrectly raising (:issue:`54722`) -- Indexing ^^^^^^^^ @@ -801,25 +787,23 @@ Indexing Missing ^^^^^^^ - Bug in :meth:`DataFrame.update` wasn't updating in-place for tz-aware datetime64 dtypes (:issue:`56227`) -- MultiIndex ^^^^^^^^^^ - Bug in :meth:`MultiIndex.get_indexer` not raising ``ValueError`` when ``method`` provided and index is non-monotonic (:issue:`53452`) -- I/O ^^^ -- Bug in :func:`read_csv` where ``engine="python"`` did not respect ``chunksize`` arg when ``skiprows`` was specified. (:issue:`56323`) -- Bug in :func:`read_csv` where ``engine="python"`` was causing a ``TypeError`` when a callable ``skiprows`` and a chunk size was specified. (:issue:`55677`) -- Bug in :func:`read_csv` where ``on_bad_lines="warn"`` would write to ``stderr`` instead of raise a Python warning. This now yields a :class:`.errors.ParserWarning` (:issue:`54296`) +- Bug in :func:`read_csv` where ``engine="python"`` did not respect ``chunksize`` arg when ``skiprows`` was specified (:issue:`56323`) +- Bug in :func:`read_csv` where ``engine="python"`` was causing a ``TypeError`` when a callable ``skiprows`` and a chunk size was specified (:issue:`55677`) +- Bug in :func:`read_csv` where ``on_bad_lines="warn"`` would write to ``stderr`` instead of raising a Python warning; this now yields a :class:`.errors.ParserWarning` (:issue:`54296`) - Bug in :func:`read_csv` with ``engine="pyarrow"`` where ``quotechar`` was ignored (:issue:`52266`) -- Bug in :func:`read_csv` with ``engine="pyarrow"`` where ``usecols`` wasn't working with a csv with no headers (:issue:`54459`) -- Bug in :func:`read_excel`, with ``engine="xlrd"`` (``xls`` files) erroring when file contains NaNs/Infs (:issue:`54564`) +- Bug in :func:`read_csv` with ``engine="pyarrow"`` where ``usecols`` wasn't working with a CSV with no headers (:issue:`54459`) +- Bug in :func:`read_excel`, with ``engine="xlrd"`` (``xls`` files) erroring when the file contains ``NaN`` or ``Inf`` (:issue:`54564`) - Bug in :func:`read_json` not handling dtype conversion properly if ``infer_string`` is set (:issue:`56195`) -- Bug in :meth:`DataFrame.to_excel`, with ``OdsWriter`` (``ods`` files) writing boolean/string value (:issue:`54994`) +- Bug in :meth:`DataFrame.to_excel`, with ``OdsWriter`` (``ods`` files) writing Boolean/string value (:issue:`54994`) - Bug in :meth:`DataFrame.to_hdf` and :func:`read_hdf` with ``datetime64`` dtypes with non-nanosecond resolution failing to round-trip correctly (:issue:`55622`) -- Bug in :meth:`~pandas.read_excel` with ``engine="odf"`` (``ods`` files) when string contains annotation (:issue:`55200`) +- Bug in :meth:`~pandas.read_excel` with ``engine="odf"`` (``ods`` files) when a string cell contains an annotation (:issue:`55200`) - Bug in :meth:`~pandas.read_excel` with an ODS file without cached formatted cell for float values (:issue:`55219`) - Bug where :meth:`DataFrame.to_json` would raise an ``OverflowError`` instead of a ``TypeError`` with unsupported NumPy types (:issue:`55403`) @@ -828,12 +812,11 @@ Period - Bug in :class:`PeriodIndex` construction when more than one of ``data``, ``ordinal`` and ``**fields`` are passed failing to raise ``ValueError`` (:issue:`55961`) - Bug in :class:`Period` addition silently wrapping around instead of raising ``OverflowError`` (:issue:`55503`) - Bug in casting from :class:`PeriodDtype` with ``astype`` to ``datetime64`` or :class:`DatetimeTZDtype` with non-nanosecond unit incorrectly returning with nanosecond unit (:issue:`55958`) -- Plotting ^^^^^^^^ -- Bug in :meth:`DataFrame.plot.box` with ``vert=False`` and a matplotlib ``Axes`` created with ``sharey=True`` (:issue:`54941`) -- Bug in :meth:`DataFrame.plot.scatter` discaring string columns (:issue:`56142`) +- Bug in :meth:`DataFrame.plot.box` with ``vert=False`` and a Matplotlib ``Axes`` created with ``sharey=True`` (:issue:`54941`) +- Bug in :meth:`DataFrame.plot.scatter` discarding string columns (:issue:`56142`) - Bug in :meth:`Series.plot` when reusing an ``ax`` object failing to raise when a ``how`` keyword is passed (:issue:`55953`) Groupby/resample/rolling @@ -841,9 +824,9 @@ Groupby/resample/rolling - Bug in :class:`.Rolling` where duplicate datetimelike indexes are treated as consecutive rather than equal with ``closed='left'`` and ``closed='neither'`` (:issue:`20712`) - Bug in :meth:`.DataFrameGroupBy.idxmin`, :meth:`.DataFrameGroupBy.idxmax`, :meth:`.SeriesGroupBy.idxmin`, and :meth:`.SeriesGroupBy.idxmax` would not retain :class:`.Categorical` dtype when the index was a :class:`.CategoricalIndex` that contained NA values (:issue:`54234`) - Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` when ``observed=False`` and ``f="idxmin"`` or ``f="idxmax"`` would incorrectly raise on unobserved categories (:issue:`54234`) -- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_count` could result in incorrect sorting if the columns of the DataFrame or name of the Series are integers (:issue:`55951`) -- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_count` would not respect ``sort=False`` in :meth:`DataFrame.groupby` and :meth:`Series.groupby` (:issue:`55951`) -- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_count` would sort by proportions rather than frequencies when ``sort=True`` and ``normalize=True`` (:issue:`55951`) +- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_counts` could result in incorrect sorting if the columns of the DataFrame or name of the Series are integers (:issue:`55951`) +- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_counts` would not respect ``sort=False`` in :meth:`DataFrame.groupby` and :meth:`Series.groupby` (:issue:`55951`) +- Bug in :meth:`.DataFrameGroupBy.value_counts` and :meth:`.SeriesGroupBy.value_counts` would sort by proportions rather than frequencies when ``sort=True`` and ``normalize=True`` (:issue:`55951`) - Bug in :meth:`DataFrame.asfreq` and :meth:`Series.asfreq` with a :class:`DatetimeIndex` with non-nanosecond resolution incorrectly converting to nanosecond resolution (:issue:`55958`) - Bug in :meth:`DataFrame.ewm` when passed ``times`` with non-nanosecond ``datetime64`` or :class:`DatetimeTZDtype` dtype (:issue:`56262`) - Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` where grouping by a combination of ``Decimal`` and NA values would fail when ``sort=True`` (:issue:`54847`) @@ -865,22 +848,11 @@ Reshaping - Bug in :meth:`DataFrame.melt` where it would not preserve the datetime (:issue:`55254`) - Bug in :meth:`DataFrame.pivot_table` where the row margin is incorrect when the columns have numeric names (:issue:`26568`) - Bug in :meth:`DataFrame.pivot` with numeric columns and extension dtype for data (:issue:`56528`) -- Bug in :meth:`DataFrame.stack` and :meth:`Series.stack` with ``future_stack=True`` would not preserve NA values in the index (:issue:`56573`) +- Bug in :meth:`DataFrame.stack` with ``future_stack=True`` would not preserve NA values in the index (:issue:`56573`) Sparse ^^^^^^ - Bug in :meth:`SparseArray.take` when using a different fill value than the array's fill value (:issue:`55181`) -- - -ExtensionArray -^^^^^^^^^^^^^^ -- -- - -Styler -^^^^^^ -- -- Other ^^^^^ @@ -891,15 +863,11 @@ Other - Bug in :meth:`DataFrame.apply` where passing ``raw=True`` ignored ``args`` passed to the applied function (:issue:`55009`) - Bug in :meth:`DataFrame.from_dict` which would always sort the rows of the created :class:`DataFrame`. (:issue:`55683`) - Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` raising a ``ValueError`` (:issue:`56478`) -- Bug in rendering ``inf`` values inside a a :class:`DataFrame` with the ``use_inf_as_na`` option enabled (:issue:`55483`) +- Bug in rendering ``inf`` values inside a :class:`DataFrame` with the ``use_inf_as_na`` option enabled (:issue:`55483`) - Bug in rendering a :class:`Series` with a :class:`MultiIndex` when one of the index level's names is 0 not having that name displayed (:issue:`55415`) - Bug in the error message when assigning an empty :class:`DataFrame` to a column (:issue:`55956`) - Bug when time-like strings were being cast to :class:`ArrowDtype` with ``pyarrow.time64`` type (:issue:`56463`) -.. ***DO NOT USE THIS SECTION*** - -- -- .. --------------------------------------------------------------------------- .. _whatsnew_220.contributors: From 918a19e3625a6c301c9030593d149f6919a0b9f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Wed, 27 Dec 2023 18:38:22 -0500 Subject: [PATCH 29/34] TYP: Fix some PythonParser and Plotting types (#56643) * TYP: fix some annotations * pyupgrade --- pandas/core/interchange/from_dataframe.py | 37 +++++++++++++++- pandas/io/parsers/python_parser.py | 51 ++++++++++++----------- pandas/plotting/_matplotlib/boxplot.py | 4 +- pandas/plotting/_matplotlib/hist.py | 4 +- pandas/plotting/_matplotlib/style.py | 40 +++++++++++++++++- pandas/plotting/_matplotlib/tools.py | 9 ++-- pyright_reportGeneralTypeIssues.json | 2 +- 7 files changed, 106 insertions(+), 41 deletions(-) diff --git a/pandas/core/interchange/from_dataframe.py b/pandas/core/interchange/from_dataframe.py index d45ae37890ba7..73f492c83c2ff 100644 --- a/pandas/core/interchange/from_dataframe.py +++ b/pandas/core/interchange/from_dataframe.py @@ -2,7 +2,10 @@ import ctypes import re -from typing import Any +from typing import ( + Any, + overload, +) import numpy as np @@ -459,12 +462,42 @@ def buffer_to_ndarray( return np.array([], dtype=ctypes_type) +@overload +def set_nulls( + data: np.ndarray, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> np.ndarray: + ... + + +@overload +def set_nulls( + data: pd.Series, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> pd.Series: + ... + + +@overload +def set_nulls( + data: np.ndarray | pd.Series, + col: Column, + validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, + allow_modify_inplace: bool = ..., +) -> np.ndarray | pd.Series: + ... + + def set_nulls( data: np.ndarray | pd.Series, col: Column, validity: tuple[Buffer, tuple[DtypeKind, int, str, str]] | None, allow_modify_inplace: bool = True, -): +) -> np.ndarray | pd.Series: """ Set null values for the data according to the column null kind. diff --git a/pandas/io/parsers/python_parser.py b/pandas/io/parsers/python_parser.py index 79e7554a5744c..c1880eb815032 100644 --- a/pandas/io/parsers/python_parser.py +++ b/pandas/io/parsers/python_parser.py @@ -4,12 +4,6 @@ abc, defaultdict, ) -from collections.abc import ( - Hashable, - Iterator, - Mapping, - Sequence, -) import csv from io import StringIO import re @@ -50,15 +44,24 @@ ) if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Iterator, + Mapping, + Sequence, + ) + from pandas._typing import ( ArrayLike, ReadCsvBuffer, Scalar, + T, ) from pandas import ( Index, MultiIndex, + Series, ) # BOM character (byte order mark) @@ -77,7 +80,7 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None: """ super().__init__(kwds) - self.data: Iterator[str] | None = None + self.data: Iterator[list[str]] | list[list[Scalar]] = [] self.buf: list = [] self.pos = 0 self.line_pos = 0 @@ -116,10 +119,11 @@ def __init__(self, f: ReadCsvBuffer[str] | list, **kwds) -> None: # Set self.data to something that can read lines. if isinstance(f, list): - # read_excel: f is a list - self.data = cast(Iterator[str], f) + # read_excel: f is a nested list, can contain non-str + self.data = f else: assert hasattr(f, "readline") + # yields list of str self.data = self._make_reader(f) # Get columns in two steps: infer from data, then @@ -179,7 +183,7 @@ def num(self) -> re.Pattern: ) return re.compile(regex) - def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]): + def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> Iterator[list[str]]: sep = self.delimiter if sep is None or len(sep) == 1: @@ -246,7 +250,9 @@ def _read(): def read( self, rows: int | None = None ) -> tuple[ - Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike] + Index | None, + Sequence[Hashable] | MultiIndex, + Mapping[Hashable, ArrayLike | Series], ]: try: content = self._get_lines(rows) @@ -326,7 +332,9 @@ def _exclude_implicit_index( def get_chunk( self, size: int | None = None ) -> tuple[ - Index | None, Sequence[Hashable] | MultiIndex, Mapping[Hashable, ArrayLike] + Index | None, + Sequence[Hashable] | MultiIndex, + Mapping[Hashable, ArrayLike | Series], ]: if size is None: # error: "PythonParser" has no attribute "chunksize" @@ -689,7 +697,7 @@ def _check_for_bom(self, first_row: list[Scalar]) -> list[Scalar]: new_row_list: list[Scalar] = [new_row] return new_row_list + first_row[1:] - def _is_line_empty(self, line: list[Scalar]) -> bool: + def _is_line_empty(self, line: Sequence[Scalar]) -> bool: """ Check if a line is empty or not. @@ -730,8 +738,6 @@ def _next_line(self) -> list[Scalar]: else: while self.skipfunc(self.pos): self.pos += 1 - # assert for mypy, data is Iterator[str] or None, would error in next - assert self.data is not None next(self.data) while True: @@ -800,12 +806,10 @@ def _next_iter_line(self, row_num: int) -> list[Scalar] | None: The row number of the line being parsed. """ try: - # assert for mypy, data is Iterator[str] or None, would error in next - assert self.data is not None + assert not isinstance(self.data, list) line = next(self.data) - # for mypy - assert isinstance(line, list) - return line + # lie about list[str] vs list[Scalar] to minimize ignores + return line # type: ignore[return-value] except csv.Error as e: if self.on_bad_lines in ( self.BadLineHandleMethod.ERROR, @@ -855,7 +859,7 @@ def _check_comments(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: ret.append(rl) return ret - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]: """ Iterate through the lines and remove any that are either empty or contain only one whitespace value @@ -1121,9 +1125,6 @@ def _get_lines(self, rows: int | None = None) -> list[list[Scalar]]: row_ct = 0 offset = self.pos if self.pos is not None else 0 while row_ct < rows: - # assert for mypy, data is Iterator[str] or None, would - # error in next - assert self.data is not None new_row = next(self.data) if not self.skipfunc(offset + row_index): row_ct += 1 @@ -1338,7 +1339,7 @@ def _make_reader(self, f: IO[str] | ReadCsvBuffer[str]) -> FixedWidthReader: self.infer_nrows, ) - def _remove_empty_lines(self, lines: list[list[Scalar]]) -> list[list[Scalar]]: + def _remove_empty_lines(self, lines: list[list[T]]) -> list[list[T]]: """ Returns the list of lines without the empty ones. With fixed-width fields, empty lines become arrays of empty strings. diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index d2b76decaa75d..084452ec23719 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -371,8 +371,8 @@ def _get_colors(): # num_colors=3 is required as method maybe_color_bp takes the colors # in positions 0 and 2. # if colors not provided, use same defaults as DataFrame.plot.box - result = get_standard_colors(num_colors=3) - result = np.take(result, [0, 0, 2]) + result_list = get_standard_colors(num_colors=3) + result = np.take(result_list, [0, 0, 2]) result = np.append(result, "k") colors = kwds.pop("color", None) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index e610f1adb602c..898abc9b78e3f 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -457,10 +457,8 @@ def hist_series( ax.grid(grid) axes = np.array([ax]) - # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any, - # dtype[Any]]"; expected "Axes | Sequence[Axes]" set_ticks_props( - axes, # type: ignore[arg-type] + axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py index bf4e4be3bfd82..45a077a6151cf 100644 --- a/pandas/plotting/_matplotlib/style.py +++ b/pandas/plotting/_matplotlib/style.py @@ -3,11 +3,13 @@ from collections.abc import ( Collection, Iterator, + Sequence, ) import itertools from typing import ( TYPE_CHECKING, cast, + overload, ) import warnings @@ -26,12 +28,46 @@ from matplotlib.colors import Colormap +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color], +) -> dict[str, Color]: + ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: Color | Sequence[Color] | None = ..., +) -> list[Color]: + ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color] | Color | Sequence[Color] | None = ..., +) -> dict[str, Color] | list[Color]: + ... + + def get_standard_colors( num_colors: int, colormap: Colormap | None = None, color_type: str = "default", - color: dict[str, Color] | Color | Collection[Color] | None = None, -): + *, + color: dict[str, Color] | Color | Sequence[Color] | None = None, +) -> dict[str, Color] | list[Color]: """ Get standard colors based on `colormap`, `color_type` or `color` inputs. diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index 898b5b25e7b01..89a8a7cf79719 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -19,10 +19,7 @@ ) if TYPE_CHECKING: - from collections.abc import ( - Iterable, - Sequence, - ) + from collections.abc import Iterable from matplotlib.axes import Axes from matplotlib.axis import Axis @@ -442,7 +439,7 @@ def handle_shared_axes( _remove_labels_from_axis(ax.yaxis) -def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray: +def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray: if not is_list_like(axes): return np.array([axes]) elif isinstance(axes, (np.ndarray, ABCIndex)): @@ -451,7 +448,7 @@ def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray: def set_ticks_props( - axes: Axes | Sequence[Axes], + axes: Axes | Iterable[Axes], xlabelsize: int | None = None, xrot=None, ylabelsize: int | None = None, diff --git a/pyright_reportGeneralTypeIssues.json b/pyright_reportGeneralTypeIssues.json index a38343d6198ae..da27906e041cf 100644 --- a/pyright_reportGeneralTypeIssues.json +++ b/pyright_reportGeneralTypeIssues.json @@ -99,11 +99,11 @@ "pandas/io/parsers/base_parser.py", "pandas/io/parsers/c_parser_wrapper.py", "pandas/io/pytables.py", - "pandas/io/sas/sas_xport.py", "pandas/io/sql.py", "pandas/io/stata.py", "pandas/plotting/_matplotlib/boxplot.py", "pandas/plotting/_matplotlib/core.py", + "pandas/plotting/_matplotlib/misc.py", "pandas/plotting/_matplotlib/timeseries.py", "pandas/plotting/_matplotlib/tools.py", "pandas/tseries/frequencies.py", From 5744df23089eb154bf84c9ac2bdfe186a012bcf3 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 28 Dec 2023 01:02:39 +0100 Subject: [PATCH 30/34] BUG: Series.to_numpy raising for arrow floats to numpy floats (#56644) * BUG: Series.to_numpy raising for arrow floats to numpy floats * Fixup --- pandas/core/arrays/arrow/array.py | 11 ++++++++++- pandas/tests/extension/test_arrow.py | 8 ++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 23b5448029dd9..de1ed9ecfdaf1 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -37,6 +37,7 @@ CategoricalDtype, is_array_like, is_bool_dtype, + is_float_dtype, is_integer, is_list_like, is_numeric_dtype, @@ -1320,6 +1321,7 @@ def to_numpy( copy: bool = False, na_value: object = lib.no_default, ) -> np.ndarray: + original_na_value = na_value dtype, na_value = to_numpy_dtype_inference(self, dtype, na_value, self._hasna) pa_type = self._pa_array.type if not self._hasna or isna(na_value) or pa.types.is_null(pa_type): @@ -1345,7 +1347,14 @@ def to_numpy( if dtype is not None and isna(na_value): na_value = None result = np.full(len(data), fill_value=na_value, dtype=dtype) - elif not data._hasna or (pa.types.is_floating(pa_type) and na_value is np.nan): + elif not data._hasna or ( + pa.types.is_floating(pa_type) + and ( + na_value is np.nan + or original_na_value is lib.no_default + and is_float_dtype(dtype) + ) + ): result = data._pa_array.to_numpy() if dtype is not None: result = result.astype(dtype, copy=False) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3b03272f18203..5624acfb64764 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -3153,6 +3153,14 @@ def test_string_to_time_parsing_cast(): tm.assert_series_equal(result, expected) +def test_to_numpy_float(): + # GH#56267 + ser = pd.Series([32, 40, None], dtype="float[pyarrow]") + result = ser.astype("float64") + expected = pd.Series([32, 40, np.nan], dtype="float64") + tm.assert_series_equal(result, expected) + + def test_to_numpy_timestamp_to_int(): # GH 55997 ser = pd.Series(["2020-01-01 04:30:00"], dtype="timestamp[ns][pyarrow]") From bc6ba0eb076db71cfff261940ad2121910bc7a8e Mon Sep 17 00:00:00 2001 From: samukweku Date: Thu, 28 Dec 2023 11:29:48 +1100 Subject: [PATCH 31/34] updates based on feedback --- pandas/core/series.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index 586facf757c5c..5f5a1fbe6c62d 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5677,20 +5677,12 @@ def case_when( Examples -------- - >>> df = pd.DataFrame({ - ... "a": [0,0,1,2], - ... "b": [0,3,4,5], - ... "c": [6,7,8,9] - ... }) - >>> df - a b c - 0 0 0 6 - 1 0 3 7 - 2 1 4 8 - 3 2 5 9 - - >>> df.c.case_when(caselist=[(df.a.gt(0), df.a), # condition, replacement - ... (df.b.gt(0), df.b)]) + >>> c = pd.Series([6,7,8,9],name='c') + >>> a = pd.Series([0,0,1,2]) + >>> b = pd.Series([0,3,4,5]) + + >>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement + ... (b.gt(0), b)]) 0 6 1 3 2 1 From a0f479779cb308e28f72ac0da9a072b604f01998 Mon Sep 17 00:00:00 2001 From: samukweku Date: Thu, 28 Dec 2023 11:36:28 +1100 Subject: [PATCH 32/34] add to API reference --- doc/source/reference/series.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/reference/series.rst b/doc/source/reference/series.rst index af262f9e6c336..a4ea0ec396ceb 100644 --- a/doc/source/reference/series.rst +++ b/doc/source/reference/series.rst @@ -177,6 +177,7 @@ Reindexing / selection / label manipulation :toctree: api/ Series.align + Series.case_when Series.drop Series.droplevel Series.drop_duplicates From cb7d6e36ae7ab4b70f67ccc9ba5749ed360978d2 Mon Sep 17 00:00:00 2001 From: samukweku Date: Thu, 28 Dec 2023 12:36:07 +1100 Subject: [PATCH 33/34] fix whitespace --- pandas/core/series.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index 5f5a1fbe6c62d..a5e58a6d1c6ac 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5677,9 +5677,9 @@ def case_when( Examples -------- - >>> c = pd.Series([6,7,8,9],name='c') - >>> a = pd.Series([0,0,1,2]) - >>> b = pd.Series([0,3,4,5]) + >>> c = pd.Series([6, 7, 8, 9], name='c') + >>> a = pd.Series([0, 0, 1, 2]) + >>> b = pd.Series([0, 3, 4, 5]) >>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement ... (b.gt(0), b)]) From 9679b9e90d164e9f55a0248529bf1921a190a59b Mon Sep 17 00:00:00 2001 From: Samuel Oranyeli Date: Fri, 5 Jan 2024 19:54:59 +1100 Subject: [PATCH 34/34] Update series.py Fix PR05 error --- pandas/core/series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/series.py b/pandas/core/series.py index a5e58a6d1c6ac..8f9f574dfdc9d 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -5650,7 +5650,7 @@ def case_when( Parameters ---------- - caselist : A list of tuples of conditions and expected replacements. + caselist : A list of tuples of conditions and expected replacements Takes the form: ``(condition0, replacement0)``, ``(condition1, replacement1)``, ... . ``condition`` should be a 1-D boolean array-like object