Skip to content

Commit 30cfd19

Browse files
BANG-225: add more tests; refactor the code
1 parent a2025ba commit 30cfd19

File tree

3 files changed

+51
-29
lines changed

3 files changed

+51
-29
lines changed

flake8_variables_names/ast_helpers.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,44 @@
11
import ast
2+
import functools
23
from typing import List, Tuple, Union
34

45
from flake8_variables_names.list_helpers import flat
56

67

7-
def extract_names_from_node(node: Union[ast.expr, ast.stmt]) -> List[ast.Name]:
8-
if isinstance(node, ast.Name):
9-
return [node]
10-
if isinstance(node, ast.Assign):
11-
nodes = []
12-
for target in node.targets:
13-
nodes.extend(extract_names_from_node(target))
14-
return nodes
15-
if isinstance(node, ast.AnnAssign):
16-
return extract_names_from_node(node.target)
17-
if isinstance(node, ast.Starred):
18-
return extract_names_from_node(node.value)
19-
if isinstance(node, ast.Tuple):
20-
nodes = []
21-
for elt in node.elts:
22-
nodes.extend(extract_names_from_node(elt))
23-
return nodes
8+
@functools.singledispatch
9+
def extract_names_from_node(node) -> List[ast.Name]:
2410
return []
2511

2612

13+
@extract_names_from_node.register
14+
def _extract_names_from_name_node(node: ast.Name):
15+
return [node]
16+
17+
18+
@extract_names_from_node.register
19+
def _extract_names_from_assign_node(node: ast.Assign):
20+
return flat([extract_names_from_node(target) for target in node.targets])
21+
22+
23+
# in some versions of Python, singledispatch does not support `Union` in type annotations
24+
@extract_names_from_node.register(ast.AnnAssign)
25+
@extract_names_from_node.register(ast.For)
26+
def _extract_names_from_annassign_node(node):
27+
return extract_names_from_node(node.target)
28+
29+
30+
@extract_names_from_node.register
31+
def _extract_names_from_starred_node(node: ast.Starred):
32+
return extract_names_from_node(node.value)
33+
34+
35+
@extract_names_from_node.register
36+
def _extract_names_from_tuple_node(node: ast.Tuple):
37+
return flat([extract_names_from_node(elt) for elt in node.elts])
38+
39+
2740
def get_var_names_from_assignment(
28-
assignment_node: Union[ast.Assign, ast.AnnAssign],
41+
assignment_node: Union[ast.Assign, ast.AnnAssign, ast.For],
2942
) -> List[Tuple[str, ast.AST]]:
3043
return [(n.id, n) for n in extract_names_from_node(assignment_node)]
3144

@@ -39,20 +52,13 @@ def get_var_names_from_funcdef(funcdef_node: ast.FunctionDef) -> List[Tuple[str,
3952
return vars_info
4053

4154

42-
def get_var_names_from_for(for_node: ast.For) -> List[Tuple[str, ast.AST]]:
43-
if isinstance(for_node.target, ast.Name):
44-
return [(for_node.target.id, for_node.target)]
45-
elif isinstance(for_node.target, ast.Tuple):
46-
return [(n.id, n) for n in for_node.target.elts if isinstance(n, ast.Name)]
47-
return []
48-
49-
5055
def extract_all_variable_names(ast_tree: ast.AST) -> List[Tuple[str, ast.AST]]:
5156
var_info: List[Tuple[str, ast.AST]] = []
52-
assignments = [n for n in ast.walk(ast_tree) if isinstance(n, (ast.Assign, ast.AnnAssign))]
57+
assignments = [
58+
n for n in ast.walk(ast_tree)
59+
if isinstance(n, (ast.Assign, ast.AnnAssign, ast.For))
60+
]
5361
var_info += flat([get_var_names_from_assignment(a) for a in assignments])
5462
funcdefs = [n for n in ast.walk(ast_tree) if isinstance(n, ast.FunctionDef)]
5563
var_info += flat([get_var_names_from_funcdef(f) for f in funcdefs])
56-
fors = [n for n in ast.walk(ast_tree) if isinstance(n, ast.For)]
57-
var_info += flat([get_var_names_from_for(f) for f in fors])
5864
return var_info

tests/test_files/loops_names.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
for result in get_some_data():
2+
pass
3+
4+
for a, *b in zip(data, goosebumps):
5+
print(a + b)

tests/test_variables_names.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def test_ok_for_short_names_file():
2222
)
2323

2424

25+
def test_ok_for_names_in_loops_file():
26+
errors = run_validator_for_test_file('loops_names.py', use_strict_mode=True)
27+
assert len(errors) == 3
28+
errors = run_validator_for_test_file('loops_names.py', use_strict_mode=False)
29+
assert len(errors) == 2
30+
assert (
31+
get_error_message(errors[0])
32+
== "VNE001 single letter variable names like 'a' are not allowed"
33+
)
34+
35+
2536
def test_ok_for_commented_names_file():
2637
errors = run_validator_for_test_file('commented_names.py', use_strict_mode=True)
2738
assert not errors

0 commit comments

Comments
 (0)