Skip to content

Commit 1b653b1

Browse files
authored
STYLE pre-commit check to ensure that test functions name starts with test (#50397)
* add check * put test back, its not a helper * reword * typing * dont exclude funcs starting with underscore Co-authored-by: MarcoGorelli <>
1 parent 2e1206e commit 1b653b1

15 files changed

+235
-81
lines changed

.pre-commit-config.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,13 @@ repos:
333333
additional_dependencies:
334334
- autotyping==22.9.0
335335
- libcst==0.4.7
336+
- id: check-test-naming
337+
name: check that test names start with 'test'
338+
entry: python -m scripts.check_test_naming
339+
types: [python]
340+
files: ^pandas/tests
341+
language: python
342+
exclude: |
343+
(?x)
344+
^pandas/tests/generic/test_generic.py # GH50380
345+
|^pandas/tests/io/json/test_readlines.py # GH50378

pandas/tests/computation/test_eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def test_pow(self, lhs, rhs, engine, parser):
353353
expected = _eval_single_bin(middle, "**", rhs, engine)
354354
tm.assert_almost_equal(result, expected)
355355

356-
def check_single_invert_op(self, lhs, engine, parser):
356+
def test_check_single_invert_op(self, lhs, engine, parser):
357357
# simple
358358
try:
359359
elb = lhs.astype(bool)

pandas/tests/frame/methods/test_dtypes.py

-7
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@
1515
import pandas._testing as tm
1616

1717

18-
def _check_cast(df, v):
19-
"""
20-
Check if all dtypes of df are equal to v
21-
"""
22-
assert all(s.dtype.name == v for _, s in df.items())
23-
24-
2518
class TestDataFrameDataTypes:
2619
def test_empty_frame_dtypes(self):
2720
empty_df = DataFrame()

pandas/tests/frame/methods/test_to_timestamp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_to_timestamp_columns(self):
121121
assert result1.columns.freqstr == "AS-JAN"
122122
assert result2.columns.freqstr == "AS-JAN"
123123

124-
def to_timestamp_invalid_axis(self):
124+
def test_to_timestamp_invalid_axis(self):
125125
index = period_range(freq="A", start="1/1/2001", end="12/1/2009")
126126
obj = DataFrame(np.random.randn(len(index), 5), index=index)
127127

pandas/tests/internals/test_internals.py

-21
Original file line numberDiff line numberDiff line change
@@ -1323,10 +1323,6 @@ def test_period_can_hold_element(self, element):
13231323
elem = element(dti)
13241324
self.check_series_setitem(elem, pi, False)
13251325

1326-
def check_setting(self, elem, index: Index, inplace: bool):
1327-
self.check_series_setitem(elem, index, inplace)
1328-
self.check_frame_setitem(elem, index, inplace)
1329-
13301326
def check_can_hold_element(self, obj, elem, inplace: bool):
13311327
blk = obj._mgr.blocks[0]
13321328
if inplace:
@@ -1350,23 +1346,6 @@ def check_series_setitem(self, elem, index: Index, inplace: bool):
13501346
else:
13511347
assert ser.dtype == object
13521348

1353-
def check_frame_setitem(self, elem, index: Index, inplace: bool):
1354-
arr = index._data.copy()
1355-
df = DataFrame(arr)
1356-
1357-
self.check_can_hold_element(df, elem, inplace)
1358-
1359-
if is_scalar(elem):
1360-
df.iloc[0, 0] = elem
1361-
else:
1362-
df.iloc[: len(elem), 0] = elem
1363-
1364-
if inplace:
1365-
# assertion here implies setting was done inplace
1366-
assert df._mgr.arrays[0] is arr
1367-
else:
1368-
assert df.dtypes[0] == object
1369-
13701349

13711350
class TestShouldStore:
13721351
def test_should_store_categorical(self):

pandas/tests/io/test_feather.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ def test_read_columns(self):
113113
columns = ["col1", "col3"]
114114
self.check_round_trip(df, expected=df[columns], columns=columns)
115115

116-
def read_columns_different_order(self):
116+
def test_read_columns_different_order(self):
117117
# GH 33878
118118
df = pd.DataFrame({"A": [1, 2], "B": ["x", "y"], "C": [True, False]})
119-
self.check_round_trip(df, columns=["B", "A"])
119+
expected = df[["B", "A"]]
120+
self.check_round_trip(df, expected, columns=["B", "A"])
120121

121122
def test_unsupported_other(self):
122123

pandas/tests/reshape/concat/test_append_common.py

-15
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,6 @@ def item(self, request):
5555

5656
item2 = item
5757

58-
def _check_expected_dtype(self, obj, label):
59-
"""
60-
Check whether obj has expected dtype depending on label
61-
considering not-supported dtypes
62-
"""
63-
if isinstance(obj, Index):
64-
assert obj.dtype == label
65-
elif isinstance(obj, Series):
66-
if label.startswith("period"):
67-
assert obj.dtype == "Period[M]"
68-
else:
69-
assert obj.dtype == label
70-
else:
71-
raise ValueError
72-
7358
def test_dtypes(self, item, index_or_series):
7459
# to confirm test case covers intended dtypes
7560
typ, vals = item

pandas/tests/series/methods/test_explode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_invert_array():
7676
@pytest.mark.parametrize(
7777
"s", [pd.Series([1, 2, 3]), pd.Series(pd.date_range("2019", periods=3, tz="UTC"))]
7878
)
79-
def non_object_dtype(s):
79+
def test_non_object_dtype(s):
8080
result = s.explode()
8181
tm.assert_series_equal(result, s)
8282

pandas/tests/strings/test_cat.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
_testing as tm,
1212
concat,
1313
)
14-
from pandas.tests.strings.test_strings import assert_series_or_index_equal
14+
15+
16+
def assert_series_or_index_equal(left, right):
17+
if isinstance(left, Series):
18+
tm.assert_series_equal(left, right)
19+
else: # Index
20+
tm.assert_index_equal(left, right)
1521

1622

1723
@pytest.mark.parametrize("other", [None, Series, Index])

pandas/tests/strings/test_strings.py

-7
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ def test_startswith_endswith_non_str_patterns(pattern):
2626
ser.str.endswith(pattern)
2727

2828

29-
def assert_series_or_index_equal(left, right):
30-
if isinstance(left, Series):
31-
tm.assert_series_equal(left, right)
32-
else: # Index
33-
tm.assert_index_equal(left, right)
34-
35-
3629
# test integer/float dtypes (inferred by constructor) and mixed
3730

3831

pandas/tests/tseries/offsets/test_dst.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,18 @@
3030
YearEnd,
3131
)
3232

33-
from pandas.tests.tseries.offsets.test_offsets import get_utc_offset_hours
3433
from pandas.util.version import Version
3534

3635
# error: Module has no attribute "__version__"
3736
pytz_version = Version(pytz.__version__) # type: ignore[attr-defined]
3837

3938

39+
def get_utc_offset_hours(ts):
40+
# take a Timestamp and compute total hours of utc offset
41+
o = ts.utcoffset()
42+
return (o.days * 24 * 3600 + o.seconds) / 3600.0
43+
44+
4045
class TestDST:
4146

4247
# one microsecond before the DST transition

pandas/tests/tseries/offsets/test_offsets.py

-6
Original file line numberDiff line numberDiff line change
@@ -900,12 +900,6 @@ def test_str_for_named_is_name(self):
900900
assert offset.freqstr == name
901901

902902

903-
def get_utc_offset_hours(ts):
904-
# take a Timestamp and compute total hours of utc offset
905-
o = ts.utcoffset()
906-
return (o.days * 24 * 3600 + o.seconds) / 3600.0
907-
908-
909903
# ---------------------------------------------------------------------
910904

911905

pandas/tests/util/test_assert_frame_equal.py

-18
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,6 @@ def _assert_frame_equal_both(a, b, **kwargs):
3434
tm.assert_frame_equal(b, a, **kwargs)
3535

3636

37-
def _assert_not_frame_equal(a, b, **kwargs):
38-
"""
39-
Check that two DataFrame are not equal.
40-
41-
Parameters
42-
----------
43-
a : DataFrame
44-
The first DataFrame to compare.
45-
b : DataFrame
46-
The second DataFrame to compare.
47-
kwargs : dict
48-
The arguments passed to `tm.assert_frame_equal`.
49-
"""
50-
msg = "The two DataFrames were equal when they shouldn't have been"
51-
with pytest.raises(AssertionError, match=msg):
52-
tm.assert_frame_equal(a, b, **kwargs)
53-
54-
5537
@pytest.mark.parametrize("check_like", [True, False])
5638
def test_frame_equal_row_order_mismatch(check_like, obj_fixture):
5739
df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"])

scripts/check_test_naming.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
Check that test names start with `test`, and that test classes start with `Test`.
3+
4+
This is meant to be run as a pre-commit hook - to run it manually, you can do:
5+
6+
pre-commit run check-test-naming --all-files
7+
8+
NOTE: if this finds a false positive, you can add the comment `# not a test` to the
9+
class or function definition. Though hopefully that shouldn't be necessary.
10+
"""
11+
from __future__ import annotations
12+
13+
import argparse
14+
import ast
15+
import os
16+
from pathlib import Path
17+
import sys
18+
from typing import (
19+
Iterator,
20+
Sequence,
21+
)
22+
23+
PRAGMA = "# not a test"
24+
25+
26+
def _find_names(node: ast.Module) -> Iterator[str]:
27+
for _node in ast.walk(node):
28+
if isinstance(_node, ast.Name):
29+
yield _node.id
30+
elif isinstance(_node, ast.Attribute):
31+
yield _node.attr
32+
33+
34+
def _is_fixture(node: ast.expr) -> bool:
35+
if isinstance(node, ast.Call):
36+
node = node.func
37+
return (
38+
isinstance(node, ast.Attribute)
39+
and node.attr == "fixture"
40+
and isinstance(node.value, ast.Name)
41+
and node.value.id == "pytest"
42+
)
43+
44+
45+
def _is_register_dtype(node):
46+
return isinstance(node, ast.Name) and node.id == "register_extension_dtype"
47+
48+
49+
def is_misnamed_test_func(
50+
node: ast.expr | ast.stmt, names: Sequence[str], line: str
51+
) -> bool:
52+
return (
53+
isinstance(node, ast.FunctionDef)
54+
and not node.name.startswith("test")
55+
and names.count(node.name) == 0
56+
and not any(_is_fixture(decorator) for decorator in node.decorator_list)
57+
and PRAGMA not in line
58+
and node.name
59+
not in ("teardown_method", "setup_method", "teardown_class", "setup_class")
60+
)
61+
62+
63+
def is_misnamed_test_class(
64+
node: ast.expr | ast.stmt, names: Sequence[str], line: str
65+
) -> bool:
66+
return (
67+
isinstance(node, ast.ClassDef)
68+
and not node.name.startswith("Test")
69+
and names.count(node.name) == 0
70+
and not any(_is_register_dtype(decorator) for decorator in node.decorator_list)
71+
and PRAGMA not in line
72+
)
73+
74+
75+
def main(content: str, file: str) -> int:
76+
lines = content.splitlines()
77+
tree = ast.parse(content)
78+
names = list(_find_names(tree))
79+
ret = 0
80+
for node in tree.body:
81+
if is_misnamed_test_func(node, names, lines[node.lineno - 1]):
82+
print(
83+
f"{file}:{node.lineno}:{node.col_offset} "
84+
"found test function which does not start with 'test'"
85+
)
86+
ret = 1
87+
elif is_misnamed_test_class(node, names, lines[node.lineno - 1]):
88+
print(
89+
f"{file}:{node.lineno}:{node.col_offset} "
90+
"found test class which does not start with 'Test'"
91+
)
92+
ret = 1
93+
if (
94+
isinstance(node, ast.ClassDef)
95+
and names.count(node.name) == 0
96+
and not any(
97+
_is_register_dtype(decorator) for decorator in node.decorator_list
98+
)
99+
and PRAGMA not in lines[node.lineno - 1]
100+
):
101+
for _node in node.body:
102+
if is_misnamed_test_func(_node, names, lines[_node.lineno - 1]):
103+
# It could be that this function is used somewhere by the
104+
# parent class. For example, there might be a base class
105+
# with
106+
#
107+
# class Foo:
108+
# def foo(self):
109+
# assert 1+1==2
110+
# def test_foo(self):
111+
# self.foo()
112+
#
113+
# and then some subclass overwrites `foo`. So, we check that
114+
# `self.foo` doesn't appear in any of the test classes.
115+
# Note some false negatives might get through, but that's OK.
116+
# This is good enough that has helped identify several examples
117+
# of tests not being run.
118+
assert isinstance(_node, ast.FunctionDef) # help mypy
119+
should_continue = False
120+
for _file in (Path("pandas") / "tests").rglob("*.py"):
121+
with open(os.path.join(_file)) as fd:
122+
_content = fd.read()
123+
if f"self.{_node.name}" in _content:
124+
should_continue = True
125+
break
126+
if should_continue:
127+
continue
128+
129+
print(
130+
f"{file}:{_node.lineno}:{_node.col_offset} "
131+
"found test function which does not start with 'test'"
132+
)
133+
ret = 1
134+
return ret
135+
136+
137+
if __name__ == "__main__":
138+
parser = argparse.ArgumentParser()
139+
parser.add_argument("paths", nargs="*")
140+
args = parser.parse_args()
141+
142+
ret = 0
143+
144+
for file in args.paths:
145+
filename = os.path.basename(file)
146+
if not (filename.startswith("test") and filename.endswith(".py")):
147+
continue
148+
with open(file, encoding="utf-8") as fd:
149+
content = fd.read()
150+
ret |= main(content, file)
151+
152+
sys.exit(ret)

0 commit comments

Comments
 (0)