diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 220465a1d2847..4379220c33687 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -339,6 +339,11 @@ repos: language: python entry: python scripts/validate_unwanted_patterns.py --validation-type="strings_with_wrong_placed_whitespace" types_or: [python, cython] + - id: unwanted-patterns-nodefault-used-not-only-for-typing + name: Check that `pandas._libs.lib.NoDefault` is used only for typing + language: python + entry: python scripts/validate_unwanted_patterns.py --validation-type="nodefault_used_not_only_for_typing" + types: [python] - id: use-pd_array-in-core name: Import pandas.array as pd_array in core language: python diff --git a/scripts/tests/test_validate_unwanted_patterns.py b/scripts/tests/test_validate_unwanted_patterns.py index 90eca13b21628..b4423197e2573 100644 --- a/scripts/tests/test_validate_unwanted_patterns.py +++ b/scripts/tests/test_validate_unwanted_patterns.py @@ -375,3 +375,72 @@ def test_strings_with_wrong_placed_whitespace_raises(self, data, expected): validate_unwanted_patterns.strings_with_wrong_placed_whitespace(fd) ) assert result == expected + + +class TestNoDefaultUsedNotOnlyForTyping: + @pytest.mark.parametrize( + "data", + [ + ( + """ +def f( + a: int | NoDefault, + b: float | lib.NoDefault = 0.1, + c: pandas._libs.lib.NoDefault = lib.no_default, +) -> lib.NoDefault | None: + pass +""" + ), + ( + """ +# var = lib.NoDefault +# the above is incorrect +a: NoDefault | int +b: lib.NoDefault = lib.no_default +""" + ), + ], + ) + def test_nodefault_used_not_only_for_typing(self, data): + fd = io.StringIO(data.strip()) + result = list(validate_unwanted_patterns.nodefault_used_not_only_for_typing(fd)) + assert result == [] + + @pytest.mark.parametrize( + "data, expected", + [ + ( + ( + """ +def f( + a = lib.NoDefault, + b: Any + = pandas._libs.lib.NoDefault, +): + pass +""" + ), + [ + (2, "NoDefault is used not only for typing"), + (4, "NoDefault is used not only for typing"), + ], + ), + ( + ( + """ +a: Any = lib.NoDefault +if a is NoDefault: + pass +""" + ), + [ + (1, "NoDefault is used not only for typing"), + (2, "NoDefault is used not only for typing"), + ], + ), + ], + ) + def test_nodefault_used_not_only_for_typing_raises(self, data, expected): + fd = io.StringIO(data.strip()) + result = list(validate_unwanted_patterns.nodefault_used_not_only_for_typing(fd)) + assert result == expected diff --git a/scripts/validate_unwanted_patterns.py b/scripts/validate_unwanted_patterns.py index e171d1825ac48..cffae7d18bee1 100755 --- a/scripts/validate_unwanted_patterns.py +++ b/scripts/validate_unwanted_patterns.py @@ -353,6 +353,52 @@ def has_wrong_whitespace(first_line: str, second_line: str) -> bool: ) +def nodefault_used_not_only_for_typing(file_obj: IO[str]) -> Iterable[Tuple[int, str]]: + """Test case where pandas._libs.lib.NoDefault is not used for typing. + + Parameters + ---------- + file_obj : IO + File-like object containing the Python code to validate. + + Yields + ------ + line_number : int + Line number of misused lib.NoDefault. + msg : str + Explanation of the error. + """ + contents = file_obj.read() + tree = ast.parse(contents) + in_annotation = False + nodes: List[tuple[bool, ast.AST]] = [(in_annotation, tree)] + + while nodes: + in_annotation, node = nodes.pop() + if not in_annotation and ( + isinstance(node, ast.Name) # Case `NoDefault` + and node.id == "NoDefault" + or isinstance(node, ast.Attribute) # Cases e.g. `lib.NoDefault` + and node.attr == "NoDefault" + ): + yield (node.lineno, "NoDefault is used not only for typing") + + # This part is adapted from + # https://github.com/asottile/pyupgrade/blob/5495a248f2165941c5d3b82ac3226ba7ad1fa59d/pyupgrade/_data.py#L70-L113 + for name in reversed(node._fields): + value = getattr(node, name) + if name in {"annotation", "returns"}: + next_in_annotation = True + else: + next_in_annotation = in_annotation + if isinstance(value, ast.AST): + nodes.append((next_in_annotation, value)) + elif isinstance(value, list): + for value in reversed(value): + if isinstance(value, ast.AST): + nodes.append((next_in_annotation, value)) + + def main( function: Callable[[IO[str]], Iterable[Tuple[int, str]]], source_path: str, @@ -405,6 +451,7 @@ def main( "private_function_across_module", "private_import_across_module", "strings_with_wrong_placed_whitespace", + "nodefault_used_not_only_for_typing", ] parser = argparse.ArgumentParser(description="Unwanted patterns checker.") @@ -413,7 +460,7 @@ def main( parser.add_argument( "--format", "-f", - default="{source_path}:{line_number}:{msg}", + default="{source_path}:{line_number}: {msg}", help="Output format of the error message.", ) parser.add_argument(