14
14
#include " mlir/IR/Operation.h"
15
15
#include " mlir/Pass/Pass.h"
16
16
#include " mlir/Support/IndentedOstream.h"
17
+ #include " llvm/ADT/STLExtras.h"
17
18
#include " llvm/Support/Format.h"
18
19
#include " llvm/Support/GraphWriter.h"
19
20
#include < map>
@@ -29,7 +30,7 @@ using namespace mlir;
29
30
30
31
static const StringRef kLineStyleControlFlow = " dashed" ;
31
32
static const StringRef kLineStyleDataFlow = " solid" ;
32
- static const StringRef kShapeNode = " ellipse " ;
33
+ static const StringRef kShapeNode = " Mrecord " ;
33
34
static const StringRef kShapeNone = " plain" ;
34
35
35
36
// / Return the size limits for eliding large attributes.
@@ -49,16 +50,25 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
49
50
return buf;
50
51
}
51
52
52
- // / Escape special characters such as '\n' and quotation marks.
53
- static std::string escapeString (std::string str) {
54
- return strFromOs ([&](raw_ostream &os) { os.write_escaped (str); });
55
- }
56
-
57
53
// / Put quotation marks around a given string.
58
54
static std::string quoteString (const std::string &str) {
59
55
return " \" " + str + " \" " ;
60
56
}
61
57
58
+ // / For Graphviz record nodes:
59
+ // / " Braces, vertical bars and angle brackets must be escaped with a backslash
60
+ // / character if you wish them to appear as a literal character "
61
+ std::string escapeLabelString (const std::string &str) {
62
+ std::string buf;
63
+ llvm::raw_string_ostream os (buf);
64
+ for (char c : str) {
65
+ if (llvm::is_contained ({' {' , ' |' , ' <' , ' }' , ' >' , ' \n ' , ' "' }, c))
66
+ os << ' \\ ' ;
67
+ os << c;
68
+ }
69
+ return buf;
70
+ }
71
+
62
72
using AttributeMap = std::map<std::string, std::string>;
63
73
64
74
namespace {
@@ -79,6 +89,12 @@ struct Node {
79
89
std::optional<int > clusterId;
80
90
};
81
91
92
+ struct DataFlowEdge {
93
+ Value value;
94
+ Node node;
95
+ std::string port;
96
+ };
97
+
82
98
// / This pass generates a Graphviz dataflow visualization of an MLIR operation.
83
99
// / Note: See https://www.graphviz.org/doc/info/lang.html for more information
84
100
// / about the Graphviz DOT language.
@@ -107,7 +123,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
107
123
private:
108
124
// / Generate a color mapping that will color every operation with the same
109
125
// / name the same way. It'll interpolate the hue in the HSV color-space,
110
- // / attempting to keep the contrast suitable for black text.
126
+ // / using muted colors that provide good contrast for black text.
111
127
template <typename T>
112
128
void initColorMapping (T &irEntity) {
113
129
backgroundColors.clear ();
@@ -120,17 +136,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
120
136
});
121
137
for (auto indexedOps : llvm::enumerate (ops)) {
122
138
double hue = ((double )indexedOps.index ()) / ops.size ();
139
+ // Use lower saturation (0.3) and higher value (0.95) for better
140
+ // readability
123
141
backgroundColors[indexedOps.value ()->getName ()].second =
124
- std::to_string (hue) + " 1.0 1.0 " ;
142
+ std::to_string (hue) + " 0.3 0.95 " ;
125
143
}
126
144
}
127
145
128
146
// / Emit all edges. This function should be called after all nodes have been
129
147
// / emitted.
130
148
void emitAllEdgeStmts () {
131
149
if (printDataFlowEdges) {
132
- for (const auto &[value, node, label] : dataFlowEdges) {
133
- emitEdgeStmt (valueToNode[value], node, label , kLineStyleDataFlow );
150
+ for (const auto &e : dataFlowEdges) {
151
+ emitEdgeStmt (valueToNode[e. value ], e. node , e. port , kLineStyleDataFlow );
134
152
}
135
153
}
136
154
@@ -147,8 +165,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
147
165
os.indent ();
148
166
// Emit invisible anchor node from/to which arrows can be drawn.
149
167
Node anchorNode = emitNodeStmt (" " , kShapeNone );
150
- os << attrStmt (" label" , quoteString (escapeString (std::move (label))))
151
- << " ;\n " ;
168
+ os << attrStmt (" label" , quoteString (label)) << " ;\n " ;
152
169
builder ();
153
170
os.unindent ();
154
171
os << " }\n " ;
@@ -176,16 +193,17 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
176
193
177
194
// Always emit splat attributes.
178
195
if (isa<SplatElementsAttr>(attr)) {
179
- attr.print (os);
196
+ os << escapeLabelString (
197
+ strFromOs ([&](raw_ostream &os) { attr.print (os); }));
180
198
return ;
181
199
}
182
200
183
201
// Elide "big" elements attributes.
184
202
auto elements = dyn_cast<ElementsAttr>(attr);
185
203
if (elements && elements.getNumElements () > largeAttrLimit) {
186
204
os << std::string (elements.getShapedType ().getRank (), ' [' ) << " ..."
187
- << std::string (elements.getShapedType ().getRank (), ' ]' ) << " : "
188
- << elements.getType ();
205
+ << std::string (elements.getShapedType ().getRank (), ' ]' ) << " : " ;
206
+ emitMlirType (os, elements.getType () );
189
207
return ;
190
208
}
191
209
@@ -199,27 +217,43 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
199
217
std::string buf;
200
218
llvm::raw_string_ostream ss (buf);
201
219
attr.print (ss);
202
- os << truncateString (buf);
220
+ os << escapeLabelString (truncateString (buf));
221
+ }
222
+
223
+ // Print a truncated and escaped MLIR type to `os`.
224
+ void emitMlirType (raw_ostream &os, Type type) {
225
+ std::string buf;
226
+ llvm::raw_string_ostream ss (buf);
227
+ type.print (ss);
228
+ os << escapeLabelString (truncateString (buf));
229
+ }
230
+
231
+ // Print a truncated and escaped MLIR operand to `os`.
232
+ void emitMlirOperand (raw_ostream &os, Value operand) {
233
+ operand.printAsOperand (os, OpPrintingFlags ());
203
234
}
204
235
205
236
// / Append an edge to the list of edges.
206
237
// / Note: Edges are written to the output stream via `emitAllEdgeStmts`.
207
- void emitEdgeStmt (Node n1, Node n2, std::string label , StringRef style) {
238
+ void emitEdgeStmt (Node n1, Node n2, std::string port , StringRef style) {
208
239
AttributeMap attrs;
209
240
attrs[" style" ] = style.str ();
210
- // Do not label edges that start/end at a cluster boundary. Such edges are
211
- // clipped at the boundary, but labels are not. This can lead to labels
212
- // floating around without any edge next to them.
213
- if (!n1.clusterId && !n2.clusterId )
214
- attrs[" label" ] = quoteString (escapeString (std::move (label)));
215
241
// Use `ltail` and `lhead` to draw edges between clusters.
216
242
if (n1.clusterId )
217
243
attrs[" ltail" ] = " cluster_" + std::to_string (*n1.clusterId );
218
244
if (n2.clusterId )
219
245
attrs[" lhead" ] = " cluster_" + std::to_string (*n2.clusterId );
220
246
221
247
edges.push_back (strFromOs ([&](raw_ostream &os) {
222
- os << llvm::format (" v%i -> v%i " , n1.id , n2.id );
248
+ os << " v" << n1.id ;
249
+ if (!port.empty () && !n1.clusterId )
250
+ // Attach edge to south compass point of the result
251
+ os << " :res" << port << " :s" ;
252
+ os << " -> " ;
253
+ os << " v" << n2.id ;
254
+ if (!port.empty () && !n2.clusterId )
255
+ // Attach edge to north compass point of the operand
256
+ os << " :arg" << port << " :n" ;
223
257
emitAttrList (os, attrs);
224
258
}));
225
259
}
@@ -240,20 +274,30 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
240
274
StringRef background = " " ) {
241
275
int nodeId = ++counter;
242
276
AttributeMap attrs;
243
- attrs[" label" ] = quoteString (escapeString ( std::move ( label)) );
277
+ attrs[" label" ] = quoteString (label);
244
278
attrs[" shape" ] = shape.str ();
245
279
if (!background.empty ()) {
246
280
attrs[" style" ] = " filled" ;
247
- attrs[" fillcolor" ] = ( " \" " + background + " \" " ) .str ();
281
+ attrs[" fillcolor" ] = quoteString ( background.str () );
248
282
}
249
283
os << llvm::format (" v%i " , nodeId);
250
284
emitAttrList (os, attrs);
251
285
os << " ;\n " ;
252
286
return Node (nodeId);
253
287
}
254
288
255
- // / Generate a label for an operation.
256
- std::string getLabel (Operation *op) {
289
+ std::string getValuePortName (Value operand) {
290
+ // Print value as an operand and omit the leading '%' character.
291
+ auto str = strFromOs ([&](raw_ostream &os) {
292
+ operand.printAsOperand (os, OpPrintingFlags ());
293
+ });
294
+ // Replace % and # with _
295
+ std::replace (str.begin (), str.end (), ' %' , ' _' );
296
+ std::replace (str.begin (), str.end (), ' #' , ' _' );
297
+ return str;
298
+ }
299
+
300
+ std::string getClusterLabel (Operation *op) {
257
301
return strFromOs ([&](raw_ostream &os) {
258
302
// Print operation name and type.
259
303
os << op->getName ();
@@ -267,18 +311,73 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
267
311
268
312
// Print attributes.
269
313
if (printAttrs) {
270
- os << " \n " ;
314
+ os << " \\ l" ;
315
+ for (const NamedAttribute &attr : op->getAttrs ()) {
316
+ os << escapeLabelString (attr.getName ().getValue ().str ()) << " : " ;
317
+ emitMlirAttr (os, attr.getValue ());
318
+ os << " \\ l" ;
319
+ }
320
+ }
321
+ });
322
+ }
323
+
324
+ // / Generate a label for an operation.
325
+ std::string getRecordLabel (Operation *op) {
326
+ return strFromOs ([&](raw_ostream &os) {
327
+ os << " {" ;
328
+
329
+ // Print operation inputs.
330
+ if (op->getNumOperands () > 0 ) {
331
+ os << " {" ;
332
+ auto operandToPort = [&](Value operand) {
333
+ os << " <arg" << getValuePortName (operand) << " > " ;
334
+ emitMlirOperand (os, operand);
335
+ };
336
+ interleave (op->getOperands (), os, operandToPort, " |" );
337
+ os << " }|" ;
338
+ }
339
+ // Print operation name and type.
340
+ os << op->getName () << " \\ l" ;
341
+
342
+ // Print attributes.
343
+ if (printAttrs && !op->getAttrs ().empty ()) {
344
+ // Extra line break to separate attributes from the operation name.
345
+ os << " \\ l" ;
271
346
for (const NamedAttribute &attr : op->getAttrs ()) {
272
- os << ' \n ' << attr.getName ().getValue () << " : " ;
347
+ os << attr.getName ().getValue () << " : " ;
273
348
emitMlirAttr (os, attr.getValue ());
349
+ os << " \\ l" ;
274
350
}
275
351
}
352
+
353
+ if (op->getNumResults () > 0 ) {
354
+ os << " |{" ;
355
+ auto resultToPort = [&](Value result) {
356
+ os << " <res" << getValuePortName (result) << " > " ;
357
+ emitMlirOperand (os, result);
358
+ if (printResultTypes) {
359
+ os << " " ;
360
+ emitMlirType (os, result.getType ());
361
+ }
362
+ };
363
+ interleave (op->getResults (), os, resultToPort, " |" );
364
+ os << " }" ;
365
+ }
366
+
367
+ os << " }" ;
276
368
});
277
369
}
278
370
279
371
// / Generate a label for a block argument.
280
372
std::string getLabel (BlockArgument arg) {
281
- return " arg" + std::to_string (arg.getArgNumber ());
373
+ return strFromOs ([&](raw_ostream &os) {
374
+ os << " <res" << getValuePortName (arg) << " > " ;
375
+ arg.printAsOperand (os, OpPrintingFlags ());
376
+ if (printResultTypes) {
377
+ os << " " ;
378
+ emitMlirType (os, arg.getType ());
379
+ }
380
+ });
282
381
}
283
382
284
383
// / Process a block. Emit a cluster and one node per block argument and
@@ -287,14 +386,12 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
287
386
emitClusterStmt ([&]() {
288
387
for (BlockArgument &blockArg : block.getArguments ())
289
388
valueToNode[blockArg] = emitNodeStmt (getLabel (blockArg));
290
-
291
389
// Emit a node for each operation.
292
390
std::optional<Node> prevNode;
293
391
for (Operation &op : block) {
294
392
Node nextNode = processOperation (&op);
295
393
if (printControlFlowEdges && prevNode)
296
- emitEdgeStmt (*prevNode, nextNode, /* label=*/ " " ,
297
- kLineStyleControlFlow );
394
+ emitEdgeStmt (*prevNode, nextNode, /* port=*/ " " , kLineStyleControlFlow );
298
395
prevNode = nextNode;
299
396
}
300
397
});
@@ -311,18 +408,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
311
408
for (Region ®ion : op->getRegions ())
312
409
processRegion (region);
313
410
},
314
- getLabel (op));
411
+ getClusterLabel (op));
315
412
} else {
316
- node = emitNodeStmt (getLabel (op), kShapeNode ,
413
+ node = emitNodeStmt (getRecordLabel (op), kShapeNode ,
317
414
backgroundColors[op->getName ()].second );
318
415
}
319
416
320
417
// Insert data flow edges originating from each operand.
321
418
if (printDataFlowEdges) {
322
419
unsigned numOperands = op->getNumOperands ();
323
- for (unsigned i = 0 ; i < numOperands; i++)
324
- dataFlowEdges.push_back ({op->getOperand (i), node,
325
- numOperands == 1 ? " " : std::to_string (i)});
420
+ for (unsigned i = 0 ; i < numOperands; i++) {
421
+ auto operand = op->getOperand (i);
422
+ dataFlowEdges.push_back ({operand, node, getValuePortName (operand)});
423
+ }
326
424
}
327
425
328
426
for (Value result : op->getResults ())
@@ -352,7 +450,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
352
450
// / Mapping of SSA values to Graphviz nodes/clusters.
353
451
DenseMap<Value, Node> valueToNode;
354
452
// / Output for data flow edges is delayed until the end to handle cycles
355
- std::vector<std::tuple<Value, Node, std::string> > dataFlowEdges;
453
+ std::vector<DataFlowEdge > dataFlowEdges;
356
454
// / Counter for generating unique node/subgraph identifiers.
357
455
int counter = 0 ;
358
456
0 commit comments