Skip to content

Commit dea01f5

Browse files
rafaelubalmweric-k256
authored andcommitted
New features and bug fix in MLIR test generation tool
- Option `--variable_names <names>` allows the user to pass names for FileCheck regexps representing variables. Variable names are separated by commas, and empty names can be used to generate specific variable names automatically. For example, `--variable-names arg_0,arg_1,,,result` will produce regexp names `ARG_0`, `ARG_1`, `VAR_0`, `VAR_1`, `RESULT`, `VAR_2`, `VAR_3`, ... - Option '--attribute_names <names>' can be used to generate global regexp names to represent attributes. Useful for affine maps. Same behavior as '--variable_names'. - Bug fixed for scope detection of SSA variables in ops with nested regions that return SSA values (e.g., 'linalg.generic'). Originally, returned SSA values were inserted in the nested scope. This version of the tool has been used to generate unit tests for the following patch: https://reviews.llvm.org/D153291 For example, the main body of the test named 'test_select_2d_one_dynamic' was generated using the following command: ``` $ mlir-opt -pass-pipeline='builtin.module(func.func(tosa-to-linalg))' test_select_2d_one_dynamic.tosa.mlir | generate-test-checks.py --attribute_names map0,map1,map2 --variable_names arg0,arg1,arg2,const1,arg0_dim1,arg1_dim1,,arg2_dim1,max_dim1,,,arg0_broadcast,,,,,,,arg1_broadcast,,,,,,,arg2_broadcast,,,,,,result ``` Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D154458
1 parent 2b0ceae commit dea01f5

File tree

1 file changed

+138
-11
lines changed

1 file changed

+138
-11
lines changed

mlir/utils/generate-test-checks.py

Lines changed: 138 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,60 @@
4545
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
4646
SSA_RE = re.compile(SSA_RE_STR)
4747

48+
# Regex matching the left-hand side of an assignment
49+
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
50+
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
51+
52+
# Regex matching attributes
53+
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
54+
ATTR_RE = re.compile(ATTR_RE_STR)
55+
56+
# Regex matching the left-hand side of an attribute definition
57+
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
58+
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)
59+
4860

4961
# Class used to generate and manage string substitution blocks for SSA value
5062
# names.
51-
class SSAVariableNamer:
52-
def __init__(self):
63+
class VariableNamer:
64+
def __init__(self, variable_names):
5365
self.scopes = []
5466
self.name_counter = 0
5567

68+
# Number of variable names to still generate in parent scope
69+
self.generate_in_parent_scope_left = 0
70+
71+
# Parse variable names
72+
self.variable_names = [name.upper() for name in variable_names.split(',')]
73+
self.used_variable_names = set()
74+
75+
# Generate the following 'n' variable names in the parent scope.
76+
def generate_in_parent_scope(self, n):
77+
self.generate_in_parent_scope_left = n
78+
5679
# Generate a substitution name for the given ssa value name.
57-
def generate_name(self, ssa_name):
58-
variable = "VAL_" + str(self.name_counter)
59-
self.name_counter += 1
60-
self.scopes[-1][ssa_name] = variable
61-
return variable
80+
def generate_name(self, source_variable_name):
81+
82+
# Compute variable name
83+
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
84+
if variable_name == '':
85+
variable_name = "VAL_" + str(self.name_counter)
86+
self.name_counter += 1
87+
88+
# Scope where variable name is saved
89+
scope = len(self.scopes) - 1
90+
if self.generate_in_parent_scope_left > 0:
91+
self.generate_in_parent_scope_left -= 1
92+
scope = len(self.scopes) - 2
93+
assert(scope >= 0)
94+
95+
# Save variable
96+
if variable_name in self.used_variable_names:
97+
raise RuntimeError(variable_name + ': duplicate variable name')
98+
self.scopes[scope][source_variable_name] = variable_name
99+
self.used_variable_names.add(variable_name)
100+
101+
return variable_name
62102

63103
# Push a new variable name scope.
64104
def push_name_scope(self):
@@ -76,6 +116,46 @@ def num_scopes(self):
76116
def clear_counter(self):
77117
self.name_counter = 0
78118

119+
class AttributeNamer:
120+
121+
def __init__(self, attribute_names):
122+
self.name_counter = 0
123+
self.attribute_names = [name.upper() for name in attribute_names.split(',')]
124+
self.map = {}
125+
self.used_attribute_names = set()
126+
127+
# Generate a substitution name for the given attribute name.
128+
def generate_name(self, source_attribute_name):
129+
130+
# Compute FileCheck name
131+
attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else ''
132+
if attribute_name == '':
133+
attribute_name = "ATTR_" + str(self.name_counter)
134+
self.name_counter += 1
135+
136+
# Prepend global symbol
137+
attribute_name = '$' + attribute_name
138+
139+
# Save attribute
140+
if attribute_name in self.used_attribute_names:
141+
raise RuntimeError(attribute_name + ': duplicate attribute name')
142+
self.map[source_attribute_name] = attribute_name
143+
self.used_attribute_names.add(attribute_name)
144+
return attribute_name
145+
146+
# Get the saved substitution name for the given attribute name. If no name
147+
# has been generated for the given attribute yet, the source attribute name
148+
# itself is returned.
149+
def get_name(self, source_attribute_name):
150+
return self.map[source_attribute_name] if source_attribute_name in self.map else '?'
151+
152+
# Return the number of SSA results in a line of type
153+
# %0, %1, ... = ...
154+
# The function returns 0 if there are no results.
155+
def get_num_ssa_results(input_line):
156+
m = SSA_RESULTS_RE.match(input_line)
157+
return m.group().count('%') if m else 0
158+
79159

80160
# Process a line of input that has been split at each SSA identifier '%'.
81161
def process_line(line_chunks, variable_namer):
@@ -84,7 +164,7 @@ def process_line(line_chunks, variable_namer):
84164
# Process the rest that contained an SSA value name.
85165
for chunk in line_chunks:
86166
m = SSA_RE.match(chunk)
87-
ssa_name = m.group(0)
167+
ssa_name = m.group(0) if m is not None else ''
88168

89169
# Check if an existing variable exists for this name.
90170
variable = None
@@ -126,6 +206,25 @@ def process_source_lines(source_lines, note, args):
126206
source_segments[-1].append(line + "\n")
127207
return source_segments
128208

209+
def process_attribute_definition(line, attribute_namer, output):
210+
m = ATTR_DEF_RE.match(line)
211+
if m:
212+
attribute_name = attribute_namer.generate_name(m.group(1))
213+
line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n'
214+
output.write(line)
215+
216+
def process_attribute_references(line, attribute_namer):
217+
218+
output_line = ''
219+
components = ATTR_RE.split(line)
220+
for component in components:
221+
m = ATTR_RE.match(component)
222+
if m:
223+
output_line += '#[[' + attribute_namer.get_name(m.group(1)) + ']]'
224+
output_line += component[len(m.group()):]
225+
else:
226+
output_line += component
227+
return output_line
129228

130229
# Pre-process a line of input to remove any character sequences that will be
131230
# problematic with FileCheck.
@@ -171,6 +270,20 @@ def main():
171270
'it omits "module {"',
172271
)
173272
parser.add_argument("-i", "--inplace", action="store_true", default=False)
273+
parser.add_argument(
274+
"--variable_names",
275+
type=str,
276+
default='',
277+
help="Names to be used in FileCheck regular expression to represent SSA "
278+
"variables in the order they are encountered. Separate names with commas, "
279+
"and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
280+
parser.add_argument(
281+
"--attribute_names",
282+
type=str,
283+
default='',
284+
help="Names to be used in FileCheck regular expression to represent "
285+
"attributes in the order they are defined. Separate names with commas,"
286+
"commas, and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")
174287

175288
args = parser.parse_args()
176289

@@ -197,15 +310,22 @@ def main():
197310
output = args.output
198311

199312
output_segments = [[]]
200-
# A map containing data used for naming SSA value names.
201-
variable_namer = SSAVariableNamer()
313+
314+
# Namers
315+
variable_namer = VariableNamer(args.variable_names)
316+
attribute_namer = AttributeNamer(args.attribute_names)
317+
318+
# Process lines
202319
for input_line in input_lines:
203320
if not input_line:
204321
continue
205-
lstripped_input_line = input_line.lstrip()
322+
323+
# Check if this is an attribute definition and process it
324+
process_attribute_definition(input_line, attribute_namer, output)
206325

207326
# Lines with blocks begin with a ^. These lines have a trailing comment
208327
# that needs to be stripped.
328+
lstripped_input_line = input_line.lstrip()
209329
is_block = lstripped_input_line[0] == "^"
210330
if is_block:
211331
input_line = input_line.rsplit("//", 1)[0].rstrip()
@@ -222,6 +342,10 @@ def main():
222342
variable_namer.push_name_scope()
223343
if cur_level == args.starts_from_scope:
224344
output_segments.append([])
345+
346+
# Result SSA values must still be pushed to parent scope
347+
num_ssa_results = get_num_ssa_results(input_line)
348+
variable_namer.generate_in_parent_scope(num_ssa_results)
225349

226350
# Omit lines at the near top level e.g. "module {".
227351
if cur_level < args.starts_from_scope:
@@ -234,6 +358,9 @@ def main():
234358
# FileCheck.
235359
input_line = preprocess_line(input_line)
236360

361+
# Process uses of attributes in this line
362+
input_line = process_attribute_references(input_line, attribute_namer)
363+
237364
# Split the line at the each SSA value name.
238365
ssa_split = input_line.split("%")
239366

0 commit comments

Comments
 (0)