Skip to content

Commit b4e11f1

Browse files
committed
TYPING: Added types for tests files
Working around a strange typing issue. See pandas-dev#28394 (comment) for more, but the types on these were being inferred incorrectly by mypy with just the addition of the `allows_duplicate_labels` kwarg.
1 parent da1401b commit b4e11f1

File tree

10 files changed

+140
-148
lines changed

10 files changed

+140
-148
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,4 @@ doc/build/html/index.html
118118
doc/tmp.sv
119119
env/
120120
doc/source/savefig/
121+
.dmypy.json

pandas/_typing.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pandas.core.indexes.base import Index # noqa: F401
2424
from pandas.core.series import Series # noqa: F401
2525
from pandas.core.generic import NDFrame # noqa: F401
26+
from pandas.core.base import IndexOpsMixin # noqa: F401
2627

2728

2829
AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray)
@@ -32,6 +33,7 @@
3233
FilePathOrBuffer = Union[str, Path, IO[AnyStr]]
3334

3435
FrameOrSeries = TypeVar("FrameOrSeries", bound="NDFrame")
36+
IndexOrSeries = TypeVar("IndexOrSeries", bound="IndexOpsMixin")
3537
Scalar = Union[str, int, float, bool]
3638
Axis = Union[str, int]
3739
Ordered = Optional[bool]

pandas/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pandas as pd
1616
from pandas import DataFrame
17+
from pandas._typing import IndexOrSeries
1718
from pandas.core import ops
1819
import pandas.util.testing as tm
1920

@@ -790,6 +791,17 @@ def tick_classes(request):
790791
return request.param
791792

792793

794+
index_or_series_params = [pd.Index, pd.Series] # type: IndexOrSeries
795+
796+
797+
@pytest.fixture(params=index_or_series_params, ids=["series", "index"])
798+
def index_or_series(request) -> IndexOrSeries:
799+
"""
800+
Parametrized fixture providing the Index or Series class.
801+
"""
802+
return request.param
803+
804+
793805
# ----------------------------------------------------------------
794806
# Global setup for tests using Hypothesis
795807

pandas/tests/arithmetic/test_numeric.py

+11-40
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from decimal import Decimal
66
from itertools import combinations
77
import operator
8+
from typing import List, Type, Union
89

910
import numpy as np
1011
import pytest
@@ -74,33 +75,22 @@ def test_compare_invalid(self):
7475

7576
# ------------------------------------------------------------------
7677
# Numeric dtypes Arithmetic with Datetime/Timedelta Scalar
78+
index_or_series_params = [
79+
pd.Series,
80+
pd.Index,
81+
] # type: List[Union[Type[pd.Index], Type[pd.RangeIndex], Type[pd.Series]]]
82+
left = [pd.RangeIndex(10, 40, 10)] # type: List[Union[Index, Series]]
83+
for cls in index_or_series_params:
84+
for dtype in ["i1", "i2", "i4", "i8", "u1", "u2", "u4", "u8", "f2", "f4", "f8"]:
85+
left.append(cls([10, 20, 30], dtype=dtype))
7786

7887

7988
class TestNumericArraylikeArithmeticWithDatetimeLike:
8089

8190
# TODO: also check name retentention
8291
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
8392
@pytest.mark.parametrize(
84-
"left",
85-
[pd.RangeIndex(10, 40, 10)]
86-
+ [
87-
cls([10, 20, 30], dtype=dtype)
88-
for dtype in [
89-
"i1",
90-
"i2",
91-
"i4",
92-
"i8",
93-
"u1",
94-
"u2",
95-
"u4",
96-
"u8",
97-
"f2",
98-
"f4",
99-
"f8",
100-
]
101-
for cls in [pd.Series, pd.Index]
102-
],
103-
ids=lambda x: type(x).__name__ + str(x.dtype),
93+
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
10494
)
10595
def test_mul_td64arr(self, left, box_cls):
10696
# GH#22390
@@ -120,26 +110,7 @@ def test_mul_td64arr(self, left, box_cls):
120110
# TODO: also check name retentention
121111
@pytest.mark.parametrize("box_cls", [np.array, pd.Index, pd.Series])
122112
@pytest.mark.parametrize(
123-
"left",
124-
[pd.RangeIndex(10, 40, 10)]
125-
+ [
126-
cls([10, 20, 30], dtype=dtype)
127-
for dtype in [
128-
"i1",
129-
"i2",
130-
"i4",
131-
"i8",
132-
"u1",
133-
"u2",
134-
"u4",
135-
"u8",
136-
"f2",
137-
"f4",
138-
"f8",
139-
]
140-
for cls in [pd.Series, pd.Index]
141-
],
142-
ids=lambda x: type(x).__name__ + str(x.dtype),
113+
"left", left, ids=lambda x: type(x).__name__ + str(x.dtype)
143114
)
144115
def test_div_td64arr(self, left, box_cls):
145116
# GH#22390

pandas/tests/arrays/test_array.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,8 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
272272
return super()._from_sequence(scalars, dtype=dtype, copy=copy)
273273

274274

275-
@pytest.mark.parametrize("box", [pd.Series, pd.Index])
276-
def test_array_unboxes(box):
277-
data = box([decimal.Decimal("1"), decimal.Decimal("2")])
275+
def test_array_unboxes(index_or_series):
276+
data = index_or_series([decimal.Decimal("1"), decimal.Decimal("2")])
278277
# make sure it works
279278
with pytest.raises(TypeError):
280279
DecimalArray2._from_sequence(data)

pandas/tests/dtypes/test_concat.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pandas.core.dtypes.concat as _concat
44

5-
from pandas import DatetimeIndex, Index, Period, PeriodIndex, Series, TimedeltaIndex
5+
from pandas import DatetimeIndex, Period, PeriodIndex, Series, TimedeltaIndex
66

77

88
@pytest.mark.parametrize(
@@ -40,9 +40,8 @@
4040
),
4141
],
4242
)
43-
@pytest.mark.parametrize("klass", [Index, Series])
44-
def test_get_dtype_kinds(klass, to_concat, expected):
45-
to_concat_klass = [klass(c) for c in to_concat]
43+
def test_get_dtype_kinds(index_or_series, to_concat, expected):
44+
to_concat_klass = [index_or_series(c) for c in to_concat]
4645
result = _concat.get_dtype_kinds(to_concat_klass)
4746
assert result == set(expected)
4847

pandas/tests/indexing/test_coercion.py

+37-44
Original file line numberDiff line numberDiff line change
@@ -515,55 +515,52 @@ def _assert_where_conversion(
515515
res = target.where(cond, values)
516516
self._assert(res, expected, expected_dtype)
517517

518-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
519518
@pytest.mark.parametrize(
520519
"fill_val,exp_dtype",
521520
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
522521
)
523-
def test_where_object(self, klass, fill_val, exp_dtype):
524-
obj = klass(list("abcd"))
522+
def test_where_object(self, index_or_series, fill_val, exp_dtype):
523+
obj = index_or_series(list("abcd"))
525524
assert obj.dtype == np.object
526-
cond = klass([True, False, True, False])
525+
cond = index_or_series([True, False, True, False])
527526

528-
if fill_val is True and klass is pd.Series:
527+
if fill_val is True and index_or_series is pd.Series:
529528
ret_val = 1
530529
else:
531530
ret_val = fill_val
532531

533-
exp = klass(["a", ret_val, "c", ret_val])
532+
exp = index_or_series(["a", ret_val, "c", ret_val])
534533
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
535534

536535
if fill_val is True:
537-
values = klass([True, False, True, True])
536+
values = index_or_series([True, False, True, True])
538537
else:
539-
values = klass(fill_val * x for x in [5, 6, 7, 8])
538+
values = index_or_series(fill_val * x for x in [5, 6, 7, 8])
540539

541-
exp = klass(["a", values[1], "c", values[3]])
540+
exp = index_or_series(["a", values[1], "c", values[3]])
542541
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
543542

544-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
545543
@pytest.mark.parametrize(
546544
"fill_val,exp_dtype",
547545
[(1, np.int64), (1.1, np.float64), (1 + 1j, np.complex128), (True, np.object)],
548546
)
549-
def test_where_int64(self, klass, fill_val, exp_dtype):
550-
if klass is pd.Index and exp_dtype is np.complex128:
547+
def test_where_int64(self, index_or_series, fill_val, exp_dtype):
548+
if index_or_series is pd.Index and exp_dtype is np.complex128:
551549
pytest.skip("Complex Index not supported")
552-
obj = klass([1, 2, 3, 4])
550+
obj = index_or_series([1, 2, 3, 4])
553551
assert obj.dtype == np.int64
554-
cond = klass([True, False, True, False])
552+
cond = index_or_series([True, False, True, False])
555553

556-
exp = klass([1, fill_val, 3, fill_val])
554+
exp = index_or_series([1, fill_val, 3, fill_val])
557555
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
558556

559557
if fill_val is True:
560-
values = klass([True, False, True, True])
558+
values = index_or_series([True, False, True, True])
561559
else:
562-
values = klass(x * fill_val for x in [5, 6, 7, 8])
563-
exp = klass([1, values[1], 3, values[3]])
560+
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
561+
exp = index_or_series([1, values[1], 3, values[3]])
564562
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
565563

566-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
567564
@pytest.mark.parametrize(
568565
"fill_val, exp_dtype",
569566
[
@@ -573,21 +570,21 @@ def test_where_int64(self, klass, fill_val, exp_dtype):
573570
(True, np.object),
574571
],
575572
)
576-
def test_where_float64(self, klass, fill_val, exp_dtype):
577-
if klass is pd.Index and exp_dtype is np.complex128:
573+
def test_where_float64(self, index_or_series, fill_val, exp_dtype):
574+
if index_or_series is pd.Index and exp_dtype is np.complex128:
578575
pytest.skip("Complex Index not supported")
579-
obj = klass([1.1, 2.2, 3.3, 4.4])
576+
obj = index_or_series([1.1, 2.2, 3.3, 4.4])
580577
assert obj.dtype == np.float64
581-
cond = klass([True, False, True, False])
578+
cond = index_or_series([True, False, True, False])
582579

583-
exp = klass([1.1, fill_val, 3.3, fill_val])
580+
exp = index_or_series([1.1, fill_val, 3.3, fill_val])
584581
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
585582

586583
if fill_val is True:
587-
values = klass([True, False, True, True])
584+
values = index_or_series([True, False, True, True])
588585
else:
589-
values = klass(x * fill_val for x in [5, 6, 7, 8])
590-
exp = klass([1.1, values[1], 3.3, values[3]])
586+
values = index_or_series(x * fill_val for x in [5, 6, 7, 8])
587+
exp = index_or_series([1.1, values[1], 3.3, values[3]])
591588
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
592589

593590
@pytest.mark.parametrize(
@@ -783,19 +780,17 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
783780
res = target.fillna(value)
784781
self._assert(res, expected, expected_dtype)
785782

786-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
787783
@pytest.mark.parametrize(
788784
"fill_val, fill_dtype",
789785
[(1, np.object), (1.1, np.object), (1 + 1j, np.object), (True, np.object)],
790786
)
791-
def test_fillna_object(self, klass, fill_val, fill_dtype):
792-
obj = klass(["a", np.nan, "c", "d"])
787+
def test_fillna_object(self, index_or_series, fill_val, fill_dtype):
788+
obj = index_or_series(["a", np.nan, "c", "d"])
793789
assert obj.dtype == np.object
794790

795-
exp = klass(["a", fill_val, "c", "d"])
791+
exp = index_or_series(["a", fill_val, "c", "d"])
796792
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
797793

798-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
799794
@pytest.mark.parametrize(
800795
"fill_val,fill_dtype",
801796
[
@@ -805,15 +800,15 @@ def test_fillna_object(self, klass, fill_val, fill_dtype):
805800
(True, np.object),
806801
],
807802
)
808-
def test_fillna_float64(self, klass, fill_val, fill_dtype):
809-
obj = klass([1.1, np.nan, 3.3, 4.4])
803+
def test_fillna_float64(self, index_or_series, fill_val, fill_dtype):
804+
obj = index_or_series([1.1, np.nan, 3.3, 4.4])
810805
assert obj.dtype == np.float64
811806

812-
exp = klass([1.1, fill_val, 3.3, 4.4])
807+
exp = index_or_series([1.1, fill_val, 3.3, 4.4])
813808
# float + complex -> we don't support a complex Index
814809
# complex for Series,
815810
# object for Index
816-
if fill_dtype == np.complex128 and klass == pd.Index:
811+
if fill_dtype == np.complex128 and index_or_series == pd.Index:
817812
fill_dtype = np.object
818813
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
819814

@@ -833,7 +828,6 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
833828
exp = pd.Series([1 + 1j, fill_val, 3 + 3j, 4 + 4j])
834829
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
835830

836-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index], ids=["series", "index"])
837831
@pytest.mark.parametrize(
838832
"fill_val,fill_dtype",
839833
[
@@ -844,8 +838,8 @@ def test_fillna_series_complex128(self, fill_val, fill_dtype):
844838
],
845839
ids=["datetime64", "datetime64tz", "object", "object"],
846840
)
847-
def test_fillna_datetime(self, klass, fill_val, fill_dtype):
848-
obj = klass(
841+
def test_fillna_datetime(self, index_or_series, fill_val, fill_dtype):
842+
obj = index_or_series(
849843
[
850844
pd.Timestamp("2011-01-01"),
851845
pd.NaT,
@@ -855,7 +849,7 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
855849
)
856850
assert obj.dtype == "datetime64[ns]"
857851

858-
exp = klass(
852+
exp = index_or_series(
859853
[
860854
pd.Timestamp("2011-01-01"),
861855
fill_val,
@@ -865,7 +859,6 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
865859
)
866860
self._assert_fillna_conversion(obj, fill_val, exp, fill_dtype)
867861

868-
@pytest.mark.parametrize("klass", [pd.Series, pd.Index])
869862
@pytest.mark.parametrize(
870863
"fill_val,fill_dtype",
871864
[
@@ -876,10 +869,10 @@ def test_fillna_datetime(self, klass, fill_val, fill_dtype):
876869
("x", np.object),
877870
],
878871
)
879-
def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
872+
def test_fillna_datetime64tz(self, index_or_series, fill_val, fill_dtype):
880873
tz = "US/Eastern"
881874

882-
obj = klass(
875+
obj = index_or_series(
883876
[
884877
pd.Timestamp("2011-01-01", tz=tz),
885878
pd.NaT,
@@ -889,7 +882,7 @@ def test_fillna_datetime64tz(self, klass, fill_val, fill_dtype):
889882
)
890883
assert obj.dtype == "datetime64[ns, US/Eastern]"
891884

892-
exp = klass(
885+
exp = index_or_series(
893886
[
894887
pd.Timestamp("2011-01-01", tz=tz),
895888
fill_val,

pandas/tests/io/json/test_json_table_schema.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -431,17 +431,15 @@ def test_date_format_raises(self):
431431
self.df.to_json(orient="table", date_format="iso")
432432
self.df.to_json(orient="table")
433433

434-
@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
435-
def test_convert_pandas_type_to_json_field_int(self, kind):
434+
def test_convert_pandas_type_to_json_field_int(self, index_or_series):
436435
data = [1, 2, 3]
437-
result = convert_pandas_type_to_json_field(kind(data, name="name"))
436+
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
438437
expected = {"name": "name", "type": "integer"}
439438
assert result == expected
440439

441-
@pytest.mark.parametrize("kind", [pd.Series, pd.Index])
442-
def test_convert_pandas_type_to_json_field_float(self, kind):
440+
def test_convert_pandas_type_to_json_field_float(self, index_or_series):
443441
data = [1.0, 2.0, 3.0]
444-
result = convert_pandas_type_to_json_field(kind(data, name="name"))
442+
result = convert_pandas_type_to_json_field(index_or_series(data, name="name"))
445443
expected = {"name": "name", "type": "number"}
446444
assert result == expected
447445

0 commit comments

Comments
 (0)