Skip to content

Commit 3ada36b

Browse files
authored
Auto generate visit_source_order (#17180)
## Summary Part of: #15655 I tried generating the source order function using code generation. I tried a simple approach, but it is not enough to generate all of them this way. The good news is that most of the implementations are fine with this; we only have a few that are not. So one benefit of this PR could be that it eliminates a lot of the code, hence changing the AST structure will only leave a few places to be fixed. The `source_order` field determines if a node requires a source order implementation. If it’s empty, it means source order does not visit anything. Initially I didn’t want to repeat the field names. But I found two things: - `ExprIf`, unlike the other nodes, does not have its fields defined in source order. In addition, some fields do not need to be included in the visit. So we just need a way to determine order, and determine presence. - Relying on the fields sounds more complicated to me. Maybe another solution is to add a new attribute `order` to each field? I'm open to suggestions. But anyway, except for `ExprIf`, we don't need to write the field names in order. Just knowing which fields must be visited is enough. Some nodes had a more complex visitor: `ExprCompare` required zipping two fields. `ExprBoolOp` required a match over the fields. `FStringValue` required a match; I created a new `walk_` function that does the match and used it in code generation. I don’t think this provides real value, because I mostly moved the code from one file to another. I tried it as an option. I prefer to leave it in the code as before. Some visitors visit a slice of items. Others visit a single element. I put a check on this in code generation to see if the field requires a for loop or not. I think a better approach is to have a consistent style, so that by default we loop over any field that is a sequence. For the field types `StringLiteralValue` and `BytesLiteralValue`, the types are not a sequence in the TOML definition. 
But they implement `iter`, so they are iterated over, and the code generation does not properly identify this. So in the code I'm checking for their types. ## Test Plan All the tests should pass without any changes. I checked the generated code to make sure it's the same as the old code. I'm not sure if there's a test for the source order visitor.
1 parent bd89838 commit 3ada36b

File tree

5 files changed

+1062
-900
lines changed

5 files changed

+1062
-900
lines changed

crates/ruff_python_ast/ast.toml

+27-6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
# derives:
3838
# List of derives to add to the syntax node struct. Clone, Debug, PartialEq are added by default.
3939
#
40+
# custom_source_order:
41+
# A boolean that specifies if this node has a custom source order visitor implementation.
42+
# generation of visit_source_order will be skipped for this node.
43+
#
4044
# fields:
4145
# List of fields in the syntax node struct. Each field is a table with the
4246
# following keys:
@@ -48,6 +52,10 @@
4852
# * `Expr*` - A vector of Expr.
4953
# * `&Expr*` - A boxed slice of Expr.
5054
# These properties cannot be nested, for example we cannot create a vector of option types.
55+
# * is_annotation - If this field is a type annotation.
56+
#
57+
# source_order:
58+
# Defines in what order the fields appear in source
5159
#
5260
# variant:
5361
# The name of the enum variant for this syntax node. Defaults to the node
@@ -57,9 +65,13 @@
5765
anynode_is_label = "module"
5866
doc = "See also [mod](https://docs.python.org/3/library/ast.html#ast.mod)"
5967

60-
[Mod.nodes]
61-
ModModule = {}
62-
ModExpression = {}
68+
[Mod.nodes.ModModule]
69+
doc = "See also [Module](https://docs.python.org/3/library/ast.html#ast.Module)"
70+
fields = [{ name = "body", type = "Stmt*" }]
71+
72+
[Mod.nodes.ModExpression]
73+
doc = "See also [Expression](https://docs.python.org/3/library/ast.html#ast.Expression)"
74+
fields = [{ name = "body", type = "Box<Expr>" }]
6375

6476
[Stmt]
6577
add_suffix_to_is_methods = true
@@ -77,8 +89,7 @@ fields = [
7789
{ name = "name", type = "Identifier" },
7890
{ name = "type_params", type = "Box<crate::TypeParams>?" },
7991
{ name = "parameters", type = "Box<crate::Parameters>" },
80-
81-
{ name = "returns", type = "Expr?" },
92+
{ name = "returns", type = "Expr?", is_annotation = true },
8293
{ name = "body", type = "Stmt*" },
8394
]
8495

@@ -127,7 +138,7 @@ fields = [
127138
doc = "See also [AnnAssign](https://docs.python.org/3/library/ast.html#ast.AnnAssign)"
128139
fields = [
129140
{ name = "target", type = "Expr" },
130-
{ name = "annotation", type = "Expr" },
141+
{ name = "annotation", type = "Expr", is_annotation = true },
131142
{ name = "value", type = "Expr?" },
132143
{ name = "simple", type = "bool" },
133144
]
@@ -305,6 +316,7 @@ doc = "See also [expr](https://docs.python.org/3/library/ast.html#ast.expr)"
305316
[Expr.nodes.ExprBoolOp]
306317
doc = "See also [BoolOp](https://docs.python.org/3/library/ast.html#ast.BoolOp)"
307318
fields = [{ name = "op", type = "BoolOp" }, { name = "values", type = "Expr*" }]
319+
custom_source_order = true
308320

309321
[Expr.nodes.ExprNamed]
310322
doc = "See also [NamedExpr](https://docs.python.org/3/library/ast.html#ast.NamedExpr)"
@@ -339,10 +351,12 @@ fields = [
339351
{ name = "body", type = "Expr" },
340352
{ name = "orelse", type = "Expr" },
341353
]
354+
source_order = ["body", "test", "orelse"]
342355

343356
[Expr.nodes.ExprDict]
344357
doc = "See also [Dict](https://docs.python.org/3/library/ast.html#ast.Dict)"
345358
fields = [{ name = "items", type = "DictItem*" }]
359+
custom_source_order = true
346360

347361
[Expr.nodes.ExprSet]
348362
doc = "See also [Set](https://docs.python.org/3/library/ast.html#ast.Set)"
@@ -397,6 +411,8 @@ fields = [
397411
{ name = "ops", type = "&CmpOp*" },
398412
{ name = "comparators", type = "&Expr*" },
399413
]
414+
# The fields must be visited simultaneously
415+
custom_source_order = true
400416

401417
[Expr.nodes.ExprCall]
402418
doc = "See also [Call](https://docs.python.org/3/library/ast.html#ast.Call)"
@@ -415,16 +431,21 @@ it keeps them separate and provide various methods to access the parts.
415431
416432
See also [JoinedStr](https://docs.python.org/3/library/ast.html#ast.JoinedStr)"""
417433
fields = [{ name = "value", type = "FStringValue" }]
434+
custom_source_order = true
418435

419436
[Expr.nodes.ExprStringLiteral]
420437
doc = """An AST node that represents either a single-part string literal
421438
or an implicitly concatenated string literal."""
422439
fields = [{ name = "value", type = "StringLiteralValue" }]
440+
# Because StringLiteralValue type is an iterator and it's not clear from the type
441+
custom_source_order = true
423442

424443
[Expr.nodes.ExprBytesLiteral]
425444
doc = """An AST node that represents either a single-part bytestring literal
426445
or an implicitly concatenated bytestring literal."""
427446
fields = [{ name = "value", type = "BytesLiteralValue" }]
447+
# Because BytesLiteralValue type is an iterator and it's not clear from the type
448+
custom_source_order = true
428449

429450
[Expr.nodes.ExprNumberLiteral]
430451
fields = [{ name = "value", type = "Number" }]

crates/ruff_python_ast/generate.py

+155-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import tomllib
1616

1717
# Types that require `crate::`. We can slowly remove these types as we move them to generate scripts.
18-
types_requiring_create_prefix = [
18+
types_requiring_create_prefix = {
1919
"IpyEscapeKind",
2020
"ExprContext",
2121
"Identifier",
@@ -33,12 +33,11 @@
3333
"Decorator",
3434
"TypeParams",
3535
"Parameters",
36-
"Arguments",
3736
"ElifElseClause",
3837
"WithItem",
3938
"MatchCase",
4039
"Alias",
41-
]
40+
}
4241

4342

4443
def rustfmt(code: str) -> str:
@@ -124,6 +123,8 @@ class Node:
124123
doc: str | None
125124
fields: list[Field] | None
126125
derives: list[str]
126+
custom_source_order: bool
127+
source_order: list[str] | None
127128

128129
def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
129130
self.name = node_name
@@ -133,33 +134,90 @@ def __init__(self, group: Group, node_name: str, node: dict[str, Any]) -> None:
133134
fields = node.get("fields")
134135
if fields is not None:
135136
self.fields = [Field(f) for f in fields]
137+
self.custom_source_order = node.get("custom_source_order", False)
136138
self.derives = node.get("derives", [])
137139
self.doc = node.get("doc")
140+
self.source_order = node.get("source_order")
141+
142+
def fields_in_source_order(self) -> list[Field]:
    """Return the fields of this node that are visited, in source order.

    Fields for which `skip_source_order()` is true are never returned. When
    the node declares no explicit `source_order`, the remaining fields are
    returned in declaration order; otherwise they are returned in the order
    given by `source_order`.

    Raises:
        ValueError: If `source_order` names something that is not a visitable
            field of this node. Previously such a name silently resolved to
            the last declared field, which produced a wrong visitor.
    """
    if self.fields is None:
        return []

    visitable = [field for field in self.fields if not field.skip_source_order()]
    if self.source_order is None:
        return visitable

    # Index the visitable fields by name once. `setdefault` keeps the first
    # field for a given name, matching a first-match linear search.
    by_name: dict[str, Field] = {}
    for field in visitable:
        by_name.setdefault(field.name, field)

    ordered = []
    for field_name in self.source_order:
        field = by_name.get(field_name)
        if field is None:
            raise ValueError(
                f"`source_order` of node `{self.name}` names `{field_name}`, "
                "which is not a visitable field of that node"
            )
        ordered.append(field)
    return ordered
138159

139160

140161
@dataclass
class Field:
    """A single entry of a syntax node's `fields` list.

    Built from the field table of the node definition: `name` and `type` are
    required keys, `skip_visit` and `is_annotation` are optional booleans that
    default to false.
    """

    name: str
    ty: str
    _skip_visit: bool
    is_annotation: bool
    parsed_ty: FieldType

    def __init__(self, field: dict[str, Any]) -> None:
        self.name = field["name"]
        self.ty = field["type"]
        # Optional flags; an absent key means "not set".
        self._skip_visit = field.get("skip_visit", False)
        self.is_annotation = field.get("is_annotation", False)
        self.parsed_ty = FieldType(self.ty)

    def skip_source_order(self) -> bool:
        """Return whether this field is left out of the source order visit.

        A field is left out when it is explicitly flagged with `skip_visit`
        or when its inner type is one of the types listed below.
        """
        never_visited = (
            "str",
            "ExprContext",
            "Name",
            "u32",
            "bool",
            "Number",
            "IpyEscapeKind",
        )
        return self._skip_visit or self.parsed_ty.inner in never_visited
186+
187+
188+
def extract_type_argument(rust_type_str: str) -> str:
    """Extract the type argument from a Rust type in AST field type syntax.

    The `*`, `?` and `&` markers are removed first. If the remaining text has
    a `<...>` part, the text between the first `<` and the last `>` is
    returned with surrounding whitespace stripped; otherwise the remaining
    text is returned as is. Nested types are not supported.

    Examples:
        Box<str>   -> str
        Box<Expr?> -> Expr

    Raises:
        ValueError: If a `<` is present without a `>` that follows it.
    """
    stripped = rust_type_str.translate(str.maketrans("", "", "*?&"))

    start = stripped.find("<")
    if start < 0:
        return stripped

    # A missing `>` yields -1, which is also caught by this comparison
    # because `start` is non-negative here.
    end = stripped.rfind(">")
    if end <= start:
        raise ValueError(f"Brackets are not balanced for type {stripped}")

    return stripped[start + 1 : end].strip()
150206

151207

152208
@dataclass
153209
class FieldType:
154210
rule: str
155211
name: str
212+
inner: str
156213
seq: bool = False
157214
optional: bool = False
158215
slice_: bool = False
159216

160217
def __init__(self, rule: str) -> None:
161218
self.rule = rule
162219
self.name = ""
220+
self.inner = extract_type_argument(rule)
163221

164222
# The following cases are the limitations of this parser(and not used in the ast.toml):
165223
# * Rules that involve declaring a sequence with optional items e.g. Vec<Option<...>>
@@ -201,6 +259,7 @@ def write_preamble(out: list[str]) -> None:
201259
// Run `crates/ruff_python_ast/generate.py` to re-generate the file.
202260
203261
use crate::name::Name;
262+
use crate::visitor::source_order::SourceOrderVisitor;
204263
""")
205264

206265

@@ -703,6 +762,98 @@ def write_node(out: list[str], ast: Ast) -> None:
703762
out.append("")
704763

705764

765+
# ------------------------------------------------------------------------------
766+
# Source order visitor
767+
768+
769+
@dataclass
class VisitorInfo:
    """Describes the visitor method that is emitted for a field type."""

    # Name of the method called on the visitor, e.g. `visit_expr`.
    name: str
    # True when the method takes the whole sequence field in one call, so the
    # generator must not wrap the call in a `for` loop.
    accepts_sequence: bool = False
773+
774+
775+
# Lookup table from a field's inner type name to the visitor method that is
# generated for it. Entries with `accepts_sequence=True` are called once with
# the whole field even when the field is a sequence.
type_to_visitor_function: dict[str, VisitorInfo] = {
    "Decorator": VisitorInfo(name="visit_decorator"),
    "Identifier": VisitorInfo(name="visit_identifier"),
    "crate::TypeParams": VisitorInfo(name="visit_type_params", accepts_sequence=True),
    "crate::Parameters": VisitorInfo(name="visit_parameters", accepts_sequence=True),
    "Expr": VisitorInfo(name="visit_expr"),
    "Stmt": VisitorInfo(name="visit_body", accepts_sequence=True),
    "Arguments": VisitorInfo(name="visit_arguments", accepts_sequence=True),
    "crate::Arguments": VisitorInfo(name="visit_arguments", accepts_sequence=True),
    "Operator": VisitorInfo(name="visit_operator"),
    "ElifElseClause": VisitorInfo(name="visit_elif_else_clause"),
    "WithItem": VisitorInfo(name="visit_with_item"),
    "MatchCase": VisitorInfo(name="visit_match_case"),
    "ExceptHandler": VisitorInfo(name="visit_except_handler"),
    "Alias": VisitorInfo(name="visit_alias"),
    "UnaryOp": VisitorInfo(name="visit_unary_op"),
    "DictItem": VisitorInfo(name="visit_dict_item"),
    "Comprehension": VisitorInfo(name="visit_comprehension"),
    "CmpOp": VisitorInfo(name="visit_cmp_op"),
    "FStringValue": VisitorInfo(name="visit_f_string_value"),
    "StringLiteralValue": VisitorInfo(name="visit_string_literal"),
    "BytesLiteralValue": VisitorInfo(name="visit_bytes_literal"),
}

# Visitor used for fields flagged with `is_annotation`, regardless of their type.
annotation_visitor_function = VisitorInfo(name="visit_annotation")
800+
801+
802+
def write_source_order(out: list[str], ast: Ast) -> None:
    """Append a generated `visit_source_order` impl for each eligible node.

    Nodes without a `fields` definition and nodes marked with
    `custom_source_order` are skipped. For every other node, one Rust `impl`
    block is appended to `out`: it destructures `self` and calls the matching
    visitor method for each field returned by `fields_in_source_order()`.

    Raises:
        KeyError: If a visited, non-annotation field has an inner type with
            no entry in `type_to_visitor_function`.
    """
    for group in ast.groups:
        for node in group.nodes:
            if node.fields is None or node.custom_source_order:
                continue
            name = node.name

            # Computed once; it drives the destructuring pattern, the body
            # and the name of the visitor argument.
            visited_fields = node.fields_in_source_order()
            visited_names = {field.name for field in visited_fields}

            # Bind only the fields that are actually visited so the generated
            # Rust has no unused bindings; `range` is never visited.
            patterns = [
                f"{field.name},\n"
                if field.name in visited_names
                else f"{field.name}: _,\n"
                for field in node.fields
            ]
            patterns.append("range: _,\n")
            fields_list = "".join(patterns)

            statements = []
            for field in visited_fields:
                if field.is_annotation:
                    # Annotations always use the dedicated visitor, so their
                    # type does not need an entry in the lookup table.
                    visitor = annotation_visitor_function
                else:
                    inner = field.parsed_ty.inner
                    visitor = type_to_visitor_function.get(inner)
                    if visitor is None:
                        raise KeyError(
                            f"No source order visitor is registered for type "
                            f"`{inner}` (field `{field.name}` of node `{name}`)"
                        )

                if field.parsed_ty.optional:
                    statements.append(f"""
                        if let Some({field.name}) = {field.name} {{
                            visitor.{visitor.name}({field.name});
                        }}\n
                    """)
                elif not visitor.accepts_sequence and field.parsed_ty.seq:
                    statements.append(f"""
                        for elm in {field.name} {{
                            visitor.{visitor.name}(elm);
                        }}
                    """)
                else:
                    statements.append(f"visitor.{visitor.name}({field.name});\n")
            body = "".join(statements)

            # An unused parameter must be spelled `_` in the generated Rust.
            visitor_arg_name = "visitor" if visited_fields else "_"

            out.append(f"""
impl {name} {{
    pub(crate) fn visit_source_order<'a, V>(&'a self, {visitor_arg_name}: &mut V)
    where
        V: SourceOrderVisitor<'a> + ?Sized,
    {{
        let {name} {{
            {fields_list}
        }} = self;
        {body}
    }}
}}
""")
855+
856+
706857
# ------------------------------------------------------------------------------
707858
# Format and write output
708859

@@ -715,6 +866,7 @@ def generate(ast: Ast) -> list[str]:
715866
write_anynoderef(out, ast)
716867
write_nodekind(out, ast)
717868
write_node(out, ast)
869+
write_source_order(out, ast)
718870
return out
719871

720872

0 commit comments

Comments
 (0)