Skip to content

Commit 1f67070

Browse files
authored
[ViewOpGraph] Improve GraphViz output (llvm#125509)
This patch improves the GraphViz output of ViewOpGraph (--view-op-graph). - Switch to rectangular record-based nodes, inspired by a similar visualization in [Glow](https://github.com/pytorch/glow). Rectangles make more efficient use of space when printing text. - Add input and output ports for each operand and result, and remove edge labels. - Switch to a muted color palette to reduce eye strain.
1 parent 1611059 commit 1f67070

File tree

4 files changed

+207
-109
lines changed

4 files changed

+207
-109
lines changed

mlir/lib/Transforms/ViewOpGraph.cpp

Lines changed: 137 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/Operation.h"
1515
#include "mlir/Pass/Pass.h"
1616
#include "mlir/Support/IndentedOstream.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/Support/Format.h"
1819
#include "llvm/Support/GraphWriter.h"
1920
#include <map>
@@ -29,7 +30,7 @@ using namespace mlir;
2930

3031
static const StringRef kLineStyleControlFlow = "dashed";
3132
static const StringRef kLineStyleDataFlow = "solid";
32-
static const StringRef kShapeNode = "ellipse";
33+
static const StringRef kShapeNode = "Mrecord";
3334
static const StringRef kShapeNone = "plain";
3435

3536
/// Return the size limits for eliding large attributes.
@@ -49,16 +50,25 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
4950
return buf;
5051
}
5152

52-
/// Escape special characters such as '\n' and quotation marks.
53-
static std::string escapeString(std::string str) {
54-
return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
55-
}
56-
5753
/// Put quotation marks around a given string.
5854
static std::string quoteString(const std::string &str) {
5955
return "\"" + str + "\"";
6056
}
6157

58+
/// For Graphviz record nodes:
59+
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
60+
/// character if you wish them to appear as a literal character "
61+
std::string escapeLabelString(const std::string &str) {
62+
std::string buf;
63+
llvm::raw_string_ostream os(buf);
64+
for (char c : str) {
65+
if (llvm::is_contained({'{', '|', '<', '}', '>', '\n', '"'}, c))
66+
os << '\\';
67+
os << c;
68+
}
69+
return buf;
70+
}
71+
6272
using AttributeMap = std::map<std::string, std::string>;
6373

6474
namespace {
@@ -79,6 +89,12 @@ struct Node {
7989
std::optional<int> clusterId;
8090
};
8191

92+
struct DataFlowEdge {
93+
Value value;
94+
Node node;
95+
std::string port;
96+
};
97+
8298
/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
8399
/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84100
/// about the Graphviz DOT language.
@@ -107,7 +123,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
107123
private:
108124
/// Generate a color mapping that will color every operation with the same
109125
/// name the same way. It'll interpolate the hue in the HSV color-space,
110-
/// attempting to keep the contrast suitable for black text.
126+
/// using muted colors that provide good contrast for black text.
111127
template <typename T>
112128
void initColorMapping(T &irEntity) {
113129
backgroundColors.clear();
@@ -120,17 +136,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
120136
});
121137
for (auto indexedOps : llvm::enumerate(ops)) {
122138
double hue = ((double)indexedOps.index()) / ops.size();
139+
// Use lower saturation (0.3) and higher value (0.95) for better
140+
// readability
123141
backgroundColors[indexedOps.value()->getName()].second =
124-
std::to_string(hue) + " 1.0 1.0";
142+
std::to_string(hue) + " 0.3 0.95";
125143
}
126144
}
127145

128146
/// Emit all edges. This function should be called after all nodes have been
129147
/// emitted.
130148
void emitAllEdgeStmts() {
131149
if (printDataFlowEdges) {
132-
for (const auto &[value, node, label] : dataFlowEdges) {
133-
emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
150+
for (const auto &e : dataFlowEdges) {
151+
emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
134152
}
135153
}
136154

@@ -147,8 +165,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
147165
os.indent();
148166
// Emit invisible anchor node from/to which arrows can be drawn.
149167
Node anchorNode = emitNodeStmt(" ", kShapeNone);
150-
os << attrStmt("label", quoteString(escapeString(std::move(label))))
151-
<< ";\n";
168+
os << attrStmt("label", quoteString(label)) << ";\n";
152169
builder();
153170
os.unindent();
154171
os << "}\n";
@@ -176,16 +193,17 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
176193

177194
// Always emit splat attributes.
178195
if (isa<SplatElementsAttr>(attr)) {
179-
attr.print(os);
196+
os << escapeLabelString(
197+
strFromOs([&](raw_ostream &os) { attr.print(os); }));
180198
return;
181199
}
182200

183201
// Elide "big" elements attributes.
184202
auto elements = dyn_cast<ElementsAttr>(attr);
185203
if (elements && elements.getNumElements() > largeAttrLimit) {
186204
os << std::string(elements.getShapedType().getRank(), '[') << "..."
187-
<< std::string(elements.getShapedType().getRank(), ']') << " : "
188-
<< elements.getType();
205+
<< std::string(elements.getShapedType().getRank(), ']') << " : ";
206+
emitMlirType(os, elements.getType());
189207
return;
190208
}
191209

@@ -199,27 +217,43 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
199217
std::string buf;
200218
llvm::raw_string_ostream ss(buf);
201219
attr.print(ss);
202-
os << truncateString(buf);
220+
os << escapeLabelString(truncateString(buf));
221+
}
222+
223+
// Print a truncated and escaped MLIR type to `os`.
224+
void emitMlirType(raw_ostream &os, Type type) {
225+
std::string buf;
226+
llvm::raw_string_ostream ss(buf);
227+
type.print(ss);
228+
os << escapeLabelString(truncateString(buf));
229+
}
230+
231+
// Print a truncated and escaped MLIR operand to `os`.
232+
void emitMlirOperand(raw_ostream &os, Value operand) {
233+
operand.printAsOperand(os, OpPrintingFlags());
203234
}
204235

205236
/// Append an edge to the list of edges.
206237
/// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
207-
void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
238+
void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
208239
AttributeMap attrs;
209240
attrs["style"] = style.str();
210-
// Do not label edges that start/end at a cluster boundary. Such edges are
211-
// clipped at the boundary, but labels are not. This can lead to labels
212-
// floating around without any edge next to them.
213-
if (!n1.clusterId && !n2.clusterId)
214-
attrs["label"] = quoteString(escapeString(std::move(label)));
215241
// Use `ltail` and `lhead` to draw edges between clusters.
216242
if (n1.clusterId)
217243
attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
218244
if (n2.clusterId)
219245
attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
220246

221247
edges.push_back(strFromOs([&](raw_ostream &os) {
222-
os << llvm::format("v%i -> v%i ", n1.id, n2.id);
248+
os << "v" << n1.id;
249+
if (!port.empty() && !n1.clusterId)
250+
// Attach edge to south compass point of the result
251+
os << ":res" << port << ":s";
252+
os << " -> ";
253+
os << "v" << n2.id;
254+
if (!port.empty() && !n2.clusterId)
255+
// Attach edge to north compass point of the operand
256+
os << ":arg" << port << ":n";
223257
emitAttrList(os, attrs);
224258
}));
225259
}
@@ -240,20 +274,30 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
240274
StringRef background = "") {
241275
int nodeId = ++counter;
242276
AttributeMap attrs;
243-
attrs["label"] = quoteString(escapeString(std::move(label)));
277+
attrs["label"] = quoteString(label);
244278
attrs["shape"] = shape.str();
245279
if (!background.empty()) {
246280
attrs["style"] = "filled";
247-
attrs["fillcolor"] = ("\"" + background + "\"").str();
281+
attrs["fillcolor"] = quoteString(background.str());
248282
}
249283
os << llvm::format("v%i ", nodeId);
250284
emitAttrList(os, attrs);
251285
os << ";\n";
252286
return Node(nodeId);
253287
}
254288

255-
/// Generate a label for an operation.
256-
std::string getLabel(Operation *op) {
289+
std::string getValuePortName(Value operand) {
290+
// Print value as an operand and omit the leading '%' character.
291+
auto str = strFromOs([&](raw_ostream &os) {
292+
operand.printAsOperand(os, OpPrintingFlags());
293+
});
294+
// Replace % and # with _
295+
std::replace(str.begin(), str.end(), '%', '_');
296+
std::replace(str.begin(), str.end(), '#', '_');
297+
return str;
298+
}
299+
300+
std::string getClusterLabel(Operation *op) {
257301
return strFromOs([&](raw_ostream &os) {
258302
// Print operation name and type.
259303
os << op->getName();
@@ -267,18 +311,73 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
267311

268312
// Print attributes.
269313
if (printAttrs) {
270-
os << "\n";
314+
os << "\\l";
315+
for (const NamedAttribute &attr : op->getAttrs()) {
316+
os << escapeLabelString(attr.getName().getValue().str()) << ": ";
317+
emitMlirAttr(os, attr.getValue());
318+
os << "\\l";
319+
}
320+
}
321+
});
322+
}
323+
324+
/// Generate a label for an operation.
325+
std::string getRecordLabel(Operation *op) {
326+
return strFromOs([&](raw_ostream &os) {
327+
os << "{";
328+
329+
// Print operation inputs.
330+
if (op->getNumOperands() > 0) {
331+
os << "{";
332+
auto operandToPort = [&](Value operand) {
333+
os << "<arg" << getValuePortName(operand) << "> ";
334+
emitMlirOperand(os, operand);
335+
};
336+
interleave(op->getOperands(), os, operandToPort, "|");
337+
os << "}|";
338+
}
339+
// Print operation name and type.
340+
os << op->getName() << "\\l";
341+
342+
// Print attributes.
343+
if (printAttrs && !op->getAttrs().empty()) {
344+
// Extra line break to separate attributes from the operation name.
345+
os << "\\l";
271346
for (const NamedAttribute &attr : op->getAttrs()) {
272-
os << '\n' << attr.getName().getValue() << ": ";
347+
os << attr.getName().getValue() << ": ";
273348
emitMlirAttr(os, attr.getValue());
349+
os << "\\l";
274350
}
275351
}
352+
353+
if (op->getNumResults() > 0) {
354+
os << "|{";
355+
auto resultToPort = [&](Value result) {
356+
os << "<res" << getValuePortName(result) << "> ";
357+
emitMlirOperand(os, result);
358+
if (printResultTypes) {
359+
os << " ";
360+
emitMlirType(os, result.getType());
361+
}
362+
};
363+
interleave(op->getResults(), os, resultToPort, "|");
364+
os << "}";
365+
}
366+
367+
os << "}";
276368
});
277369
}
278370

279371
/// Generate a label for a block argument.
280372
std::string getLabel(BlockArgument arg) {
281-
return "arg" + std::to_string(arg.getArgNumber());
373+
return strFromOs([&](raw_ostream &os) {
374+
os << "<res" << getValuePortName(arg) << "> ";
375+
arg.printAsOperand(os, OpPrintingFlags());
376+
if (printResultTypes) {
377+
os << " ";
378+
emitMlirType(os, arg.getType());
379+
}
380+
});
282381
}
283382

284383
/// Process a block. Emit a cluster and one node per block argument and
@@ -287,14 +386,12 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
287386
emitClusterStmt([&]() {
288387
for (BlockArgument &blockArg : block.getArguments())
289388
valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
290-
291389
// Emit a node for each operation.
292390
std::optional<Node> prevNode;
293391
for (Operation &op : block) {
294392
Node nextNode = processOperation(&op);
295393
if (printControlFlowEdges && prevNode)
296-
emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
297-
kLineStyleControlFlow);
394+
emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
298395
prevNode = nextNode;
299396
}
300397
});
@@ -311,18 +408,19 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
311408
for (Region &region : op->getRegions())
312409
processRegion(region);
313410
},
314-
getLabel(op));
411+
getClusterLabel(op));
315412
} else {
316-
node = emitNodeStmt(getLabel(op), kShapeNode,
413+
node = emitNodeStmt(getRecordLabel(op), kShapeNode,
317414
backgroundColors[op->getName()].second);
318415
}
319416

320417
// Insert data flow edges originating from each operand.
321418
if (printDataFlowEdges) {
322419
unsigned numOperands = op->getNumOperands();
323-
for (unsigned i = 0; i < numOperands; i++)
324-
dataFlowEdges.push_back({op->getOperand(i), node,
325-
numOperands == 1 ? "" : std::to_string(i)});
420+
for (unsigned i = 0; i < numOperands; i++) {
421+
auto operand = op->getOperand(i);
422+
dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
423+
}
326424
}
327425

328426
for (Value result : op->getResults())
@@ -352,7 +450,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
352450
/// Mapping of SSA values to Graphviz nodes/clusters.
353451
DenseMap<Value, Node> valueToNode;
354452
/// Output for data flow edges is delayed until the end to handle cycles
355-
std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
453+
std::vector<DataFlowEdge> dataFlowEdges;
356454
/// Counter for generating unique node/subgraph identifiers.
357455
int counter = 0;
358456

mlir/test/Transforms/print-op-graph-back-edges.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
// RUN: mlir-opt -view-op-graph %s -o %t 2>&1 | FileCheck -check-prefix=DFG %s
22

33
// DFG-LABEL: digraph G {
4-
// DFG: compound = true;
5-
// DFG: subgraph cluster_1 {
6-
// DFG: v2 [label = " ", shape = plain];
7-
// DFG: label = "builtin.module : ()\n";
8-
// DFG: subgraph cluster_3 {
9-
// DFG: v4 [label = " ", shape = plain];
10-
// DFG: label = "";
11-
// DFG: v5 [fillcolor = "0.000000 1.0 1.0", label = "arith.addi : (index)\n\noverflowFlags: #arith.overflow<none...", shape = ellipse, style = filled];
12-
// DFG: v6 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 0 : index", shape = ellipse, style = filled];
13-
// DFG: v7 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 1 : index", shape = ellipse, style = filled];
14-
// DFG: }
15-
// DFG: }
16-
// DFG: v6 -> v5 [label = "0", style = solid];
17-
// DFG: v7 -> v5 [label = "1", style = solid];
18-
// DFG: }
4+
// DFG-NEXT: compound = true;
5+
// DFG-NEXT: subgraph cluster_1 {
6+
// DFG-NEXT: v2 [label = " ", shape = plain];
7+
// DFG-NEXT: label = "builtin.module : ()\l";
8+
// DFG-NEXT: subgraph cluster_3 {
9+
// DFG-NEXT: v4 [label = " ", shape = plain];
10+
// DFG-NEXT: label = "";
11+
// DFG-NEXT: v5 [fillcolor = "0.000000 0.3 0.95", label = "{{\{\{}}<arg_c0> %c0|<arg_c1> %c1}|arith.addi\l\loverflowFlags: #arith.overflow\<none...\l|{<res_0> %0 index}}", shape = Mrecord, style = filled];
12+
// DFG-NEXT: v6 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 0 : index\l|{<res_c0> %c0 index}}", shape = Mrecord, style = filled];
13+
// DFG-NEXT: v7 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 1 : index\l|{<res_c1> %c1 index}}", shape = Mrecord, style = filled];
14+
// DFG-NEXT: }
15+
// DFG-NEXT: }
16+
// DFG-NEXT: v6:res_c0:s -> v5:arg_c0:n[style = solid];
17+
// DFG-NEXT: v7:res_c1:s -> v5:arg_c1:n[style = solid];
18+
// DFG-NEXT: }
1919

2020
module {
2121
%add = arith.addi %c0, %c1 : index

0 commit comments

Comments
 (0)