Skip to content

Commit f0feaff

Browse files
authored
Merge pull request #1613 from henryiii/henryiii/fix/mainif
fix(setup.py): look inside if name == main block
2 parents 80a54b0 + f34ae77 commit f0feaff

File tree

2 files changed

+114
-2
lines changed

2 files changed

+114
-2
lines changed

cibuildwheel/projectfiles.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,43 @@
88
from ._compat import tomllib
99

1010

11+
def get_parent(node: ast.AST | None, depth: int = 1) -> ast.AST | None:
12+
for _ in range(depth):
13+
node = getattr(node, "parent", None)
14+
return node
15+
16+
17+
def is_main(parent: ast.AST | None) -> bool:
18+
if parent is None:
19+
return False
20+
21+
# This would be much nicer with 3.10's pattern matching!
22+
if not isinstance(parent, ast.If):
23+
return False
24+
if not isinstance(parent.test, ast.Compare):
25+
return False
26+
27+
try:
28+
(op,) = parent.test.ops
29+
(comp,) = parent.test.comparators
30+
except ValueError:
31+
return False
32+
33+
if not isinstance(op, ast.Eq):
34+
return False
35+
36+
values = {comp, parent.test.left}
37+
38+
mains = {x for x in values if isinstance(x, ast.Constant) and x.value == "__main__"}
39+
if len(mains) != 1:
40+
return False
41+
consts = {x for x in values if isinstance(x, ast.Name) and x.id == "__name__"}
42+
if len(consts) != 1:
43+
return False
44+
45+
return True
46+
47+
1148
class Analyzer(ast.NodeVisitor):
1249
def __init__(self) -> None:
1350
self.requires_python: str | None = None
@@ -19,13 +56,22 @@ def visit(self, node: ast.AST) -> None:
1956
super().visit(node)
2057

2158
def visit_keyword(self, node: ast.keyword) -> None:
59+
# Must not be nested except for if __name__ == "__main__"
60+
2261
self.generic_visit(node)
23-
# Must not be nested in an if or other structure
2462
# This will be Module -> Expr -> Call -> keyword
63+
parent = get_parent(node, 4)
64+
unnested = parent is None
65+
66+
# This will be Module -> If -> Expr -> Call -> keyword
67+
name_main_unnested = (
68+
parent is not None and get_parent(parent) is None and is_main(get_parent(node, 3))
69+
)
70+
2571
if (
2672
node.arg == "python_requires"
27-
and not hasattr(node.parent.parent.parent, "parent") # type: ignore[attr-defined]
2873
and isinstance(node.value, ast.Constant)
74+
and (unnested or name_main_unnested)
2975
):
3076
self.requires_python = node.value.value
3177

unit_test/projectfiles_test.py

+66
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,72 @@ def test_read_setup_py_simple(tmp_path):
2626
assert get_requires_python_str(tmp_path) == "1.23"
2727

2828

29+
def test_read_setup_py_if_main(tmp_path):
30+
with open(tmp_path / "setup.py", "w") as f:
31+
f.write(
32+
dedent(
33+
"""
34+
from setuptools import setup
35+
36+
if __name__ == "__main__":
37+
setup(
38+
name = "hello",
39+
other = 23,
40+
example = ["item", "other"],
41+
python_requires = "1.23",
42+
)
43+
"""
44+
)
45+
)
46+
47+
assert setup_py_python_requires(tmp_path.joinpath("setup.py").read_text()) == "1.23"
48+
assert get_requires_python_str(tmp_path) == "1.23"
49+
50+
51+
def test_read_setup_py_if_main_reversed(tmp_path):
52+
with open(tmp_path / "setup.py", "w") as f:
53+
f.write(
54+
dedent(
55+
"""
56+
from setuptools import setup
57+
58+
if "__main__" == __name__:
59+
setup(
60+
name = "hello",
61+
other = 23,
62+
example = ["item", "other"],
63+
python_requires = "1.23",
64+
)
65+
"""
66+
)
67+
)
68+
69+
assert setup_py_python_requires(tmp_path.joinpath("setup.py").read_text()) == "1.23"
70+
assert get_requires_python_str(tmp_path) == "1.23"
71+
72+
73+
def test_read_setup_py_if_invalid(tmp_path):
74+
with open(tmp_path / "setup.py", "w") as f:
75+
f.write(
76+
dedent(
77+
"""
78+
from setuptools import setup
79+
80+
if True:
81+
setup(
82+
name = "hello",
83+
other = 23,
84+
example = ["item", "other"],
85+
python_requires = "1.23",
86+
)
87+
"""
88+
)
89+
)
90+
91+
assert not setup_py_python_requires(tmp_path.joinpath("setup.py").read_text())
92+
assert not get_requires_python_str(tmp_path)
93+
94+
2995
def test_read_setup_py_full(tmp_path):
3096
with open(tmp_path / "setup.py", "w", encoding="utf8") as f:
3197
f.write(

0 commit comments

Comments
 (0)