@@ -89,6 +89,7 @@ def debugprint(
89
89
| Sequence [Variable | Apply | Function | FunctionGraph ],
90
90
depth : int = - 1 ,
91
91
print_type : bool = False ,
92
+ print_shape : bool = False ,
92
93
file : Literal ["str" ] | TextIO | None = None ,
93
94
id_type : IDTypesType = "CHAR" ,
94
95
stop_on_name : bool = False ,
@@ -98,6 +99,7 @@ def debugprint(
98
99
print_op_info : bool = False ,
99
100
print_destroy_map : bool = False ,
100
101
print_view_map : bool = False ,
102
+ print_memory_map : bool = False ,
101
103
print_fgraph_inputs : bool = False ,
102
104
) -> str | TextIO :
103
105
r"""Print a graph as text.
@@ -123,6 +125,8 @@ def debugprint(
123
125
Print graph to this depth (``-1`` for unlimited).
124
126
print_type
125
127
If ``True``, print the `Type`\s of each `Variable` in the graph.
128
+ print_shape
129
+ If ``True``, print the shape of each `Variable` in the graph.
126
130
file
127
131
When `file` extends `TextIO`, print to it; when `file` is
128
132
equal to ``"str"``, return a string; when `file` is ``None``, print to
@@ -153,6 +157,8 @@ def debugprint(
153
157
Whether to print the `destroy_map`\s of printed objects
154
158
print_view_map
155
159
Whether to print the `view_map`\s of printed objects
160
+ print_memory_map
161
+ Whether to set both `print_destroy_map` and `print_view_map` to ``True``.
156
162
print_fgraph_inputs
157
163
Print the inputs of `FunctionGraph`\s.
158
164
@@ -177,6 +183,10 @@ def debugprint(
177
183
if used_ids is None :
178
184
used_ids = dict ()
179
185
186
+ if print_memory_map :
187
+ print_destroy_map = True
188
+ print_view_map = True
189
+
180
190
inputs_to_print = []
181
191
outputs_to_print = []
182
192
profile_list : list [Any | None ] = []
@@ -265,6 +275,7 @@ def debugprint(
265
275
depth = depth ,
266
276
done = done ,
267
277
print_type = print_type ,
278
+ print_shape = print_shape ,
268
279
file = _file ,
269
280
id_type = id_type ,
270
281
inner_graph_ops = inner_graph_vars ,
@@ -295,6 +306,7 @@ def debugprint(
295
306
depth = depth ,
296
307
done = done ,
297
308
print_type = print_type ,
309
+ print_shape = print_shape ,
298
310
file = _file ,
299
311
topo_order = topo_order ,
300
312
id_type = id_type ,
@@ -365,6 +377,7 @@ def debugprint(
365
377
depth = depth ,
366
378
done = done ,
367
379
print_type = print_type ,
380
+ print_shape = print_shape ,
368
381
file = _file ,
369
382
id_type = id_type ,
370
383
inner_graph_ops = inner_graph_vars ,
@@ -387,6 +400,7 @@ def debugprint(
387
400
depth = depth ,
388
401
done = done ,
389
402
print_type = print_type ,
403
+ print_shape = print_shape ,
390
404
file = _file ,
391
405
id_type = id_type ,
392
406
stop_on_name = stop_on_name ,
@@ -421,6 +435,7 @@ def debugprint(
421
435
depth = depth ,
422
436
done = done ,
423
437
print_type = print_type ,
438
+ print_shape = print_shape ,
424
439
file = _file ,
425
440
id_type = id_type ,
426
441
stop_on_name = stop_on_name ,
@@ -452,6 +467,7 @@ def _debugprint(
452
467
depth : int = - 1 ,
453
468
done : dict [Literal ["output" ] | Variable | Apply , str ] | None = None ,
454
469
print_type : bool = False ,
470
+ print_shape : bool = False ,
455
471
file : TextIO = sys .stdout ,
456
472
print_destroy_map : bool = False ,
457
473
print_view_map : bool = False ,
@@ -484,6 +500,8 @@ def _debugprint(
484
500
See `debugprint`.
485
501
print_type
486
502
See `debugprint`.
503
+ print_shape
504
+ See `debugprint`.
487
505
file
488
506
File-like object to which to print.
489
507
print_destroy_map
@@ -532,6 +550,11 @@ def _debugprint(
532
550
else :
533
551
type_str = ""
534
552
553
+ if print_shape and hasattr (var .type , "shape" ):
554
+ shape_str = f" shape={ str (var .type .shape ).replace ('None' , '?' )} "
555
+ else :
556
+ shape_str = ""
557
+
535
558
if prefix_child is None :
536
559
prefix_child = prefix
537
560
@@ -612,7 +635,7 @@ def get_id_str(
612
635
if is_inner_graph_header :
613
636
var_output = f"{ prefix } { node .op } { id_str } { destroy_map_str } { view_map_str } { o } "
614
637
else :
615
- var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
638
+ var_output = f"{ prefix } { node .op } { output_idx } { id_str } { type_str } { shape_str } { var_name } { destroy_map_str } { view_map_str } { o } { data } "
616
639
617
640
if print_op_info and node not in op_information :
618
641
op_information .update (op_debug_information (node .op , node ))
@@ -662,6 +685,7 @@ def get_id_str(
662
685
depth = depth - 1 ,
663
686
done = _done ,
664
687
print_type = print_type ,
688
+ print_shape = print_shape ,
665
689
file = file ,
666
690
topo_order = topo_order ,
667
691
id_type = id_type ,
@@ -692,7 +716,7 @@ def get_id_str(
692
716
else :
693
717
data = ""
694
718
695
- var_output = f"{ prefix } { var } { id_str } { type_str } { data } "
719
+ var_output = f"{ prefix } { var } { id_str } { type_str } { shape_str } { data } "
696
720
697
721
if print_op_info and var .owner and var .owner not in op_information :
698
722
op_information .update (op_debug_information (var .owner .op , var .owner ))
0 commit comments