Skip to content

Commit b12dc30

Browse files
authored
Add print_shape and print_memory_map option to debugprint (#1236)
* Added print_shape option to debugprint and simplify __str__ logic in TensorType * Add print_memory_map option to debugprint to enable destroy and view maps
1 parent 4c27eb9 commit b12dc30

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

pytensor/printing.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def debugprint(
8989
| Sequence[Variable | Apply | Function | FunctionGraph],
9090
depth: int = -1,
9191
print_type: bool = False,
92+
print_shape: bool = False,
9293
file: Literal["str"] | TextIO | None = None,
9394
id_type: IDTypesType = "CHAR",
9495
stop_on_name: bool = False,
@@ -98,6 +99,7 @@ def debugprint(
9899
print_op_info: bool = False,
99100
print_destroy_map: bool = False,
100101
print_view_map: bool = False,
102+
print_memory_map: bool = False,
101103
print_fgraph_inputs: bool = False,
102104
) -> str | TextIO:
103105
r"""Print a graph as text.
@@ -123,6 +125,8 @@ def debugprint(
123125
Print graph to this depth (``-1`` for unlimited).
124126
print_type
125127
If ``True``, print the `Type`\s of each `Variable` in the graph.
128+
print_shape
129+
If ``True``, print the shape of each `Variable` in the graph.
126130
file
127131
When `file` extends `TextIO`, print to it; when `file` is
128132
equal to ``"str"``, return a string; when `file` is ``None``, print to
@@ -153,6 +157,8 @@ def debugprint(
153157
Whether to print the `destroy_map`\s of printed objects
154158
print_view_map
155159
Whether to print the `view_map`\s of printed objects
160+
print_memory_map
161+
Whether to set both `print_destroy_map` and `print_view_map` to ``True``.
156162
print_fgraph_inputs
157163
Print the inputs of `FunctionGraph`\s.
158164
@@ -177,6 +183,10 @@ def debugprint(
177183
if used_ids is None:
178184
used_ids = dict()
179185

186+
if print_memory_map:
187+
print_destroy_map = True
188+
print_view_map = True
189+
180190
inputs_to_print = []
181191
outputs_to_print = []
182192
profile_list: list[Any | None] = []
@@ -265,6 +275,7 @@ def debugprint(
265275
depth=depth,
266276
done=done,
267277
print_type=print_type,
278+
print_shape=print_shape,
268279
file=_file,
269280
id_type=id_type,
270281
inner_graph_ops=inner_graph_vars,
@@ -295,6 +306,7 @@ def debugprint(
295306
depth=depth,
296307
done=done,
297308
print_type=print_type,
309+
print_shape=print_shape,
298310
file=_file,
299311
topo_order=topo_order,
300312
id_type=id_type,
@@ -365,6 +377,7 @@ def debugprint(
365377
depth=depth,
366378
done=done,
367379
print_type=print_type,
380+
print_shape=print_shape,
368381
file=_file,
369382
id_type=id_type,
370383
inner_graph_ops=inner_graph_vars,
@@ -387,6 +400,7 @@ def debugprint(
387400
depth=depth,
388401
done=done,
389402
print_type=print_type,
403+
print_shape=print_shape,
390404
file=_file,
391405
id_type=id_type,
392406
stop_on_name=stop_on_name,
@@ -421,6 +435,7 @@ def debugprint(
421435
depth=depth,
422436
done=done,
423437
print_type=print_type,
438+
print_shape=print_shape,
424439
file=_file,
425440
id_type=id_type,
426441
stop_on_name=stop_on_name,
@@ -452,6 +467,7 @@ def _debugprint(
452467
depth: int = -1,
453468
done: dict[Literal["output"] | Variable | Apply, str] | None = None,
454469
print_type: bool = False,
470+
print_shape: bool = False,
455471
file: TextIO = sys.stdout,
456472
print_destroy_map: bool = False,
457473
print_view_map: bool = False,
@@ -484,6 +500,8 @@ def _debugprint(
484500
See `debugprint`.
485501
print_type
486502
See `debugprint`.
503+
print_shape
504+
See `debugprint`.
487505
file
488506
File-like object to which to print.
489507
print_destroy_map
@@ -532,6 +550,11 @@ def _debugprint(
532550
else:
533551
type_str = ""
534552

553+
if print_shape and hasattr(var.type, "shape"):
554+
shape_str = f" shape={str(var.type.shape).replace('None', '?')}"
555+
else:
556+
shape_str = ""
557+
535558
if prefix_child is None:
536559
prefix_child = prefix
537560

@@ -612,7 +635,7 @@ def get_id_str(
612635
if is_inner_graph_header:
613636
var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
614637
else:
615-
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
638+
var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{shape_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
616639

617640
if print_op_info and node not in op_information:
618641
op_information.update(op_debug_information(node.op, node))
@@ -662,6 +685,7 @@ def get_id_str(
662685
depth=depth - 1,
663686
done=_done,
664687
print_type=print_type,
688+
print_shape=print_shape,
665689
file=file,
666690
topo_order=topo_order,
667691
id_type=id_type,
@@ -692,7 +716,7 @@ def get_id_str(
692716
else:
693717
data = ""
694718

695-
var_output = f"{prefix}{var}{id_str}{type_str}{data}"
719+
var_output = f"{prefix}{var}{id_str}{type_str}{shape_str}{data}"
696720

697721
if print_op_info and var.owner and var.owner not in op_information:
698722
op_information.update(op_debug_information(var.owner.op, var.owner))

pytensor/tensor/type.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -399,22 +399,13 @@ def __str__(self):
399399
else:
400400
shape = self.shape
401401
len_shape = len(shape)
402-
403-
def shape_str(s):
404-
if s is None:
405-
return "?"
406-
else:
407-
return str(s)
408-
409-
formatted_shape = ", ".join(shape_str(s) for s in shape)
410-
if len_shape == 1:
411-
formatted_shape += ","
402+
formatted_shape = str(shape).replace("None", "?")
412403

413404
if len_shape > 2:
414405
name = f"Tensor{len_shape}"
415406
else:
416407
name = ("Scalar", "Vector", "Matrix")[len_shape]
417-
return f"{name}({self.dtype}, shape=({formatted_shape}))"
408+
return f"{name}({self.dtype}, shape={formatted_shape})"
418409

419410
def __repr__(self):
420411
return f"TensorType({self.dtype}, shape={self.shape})"

0 commit comments

Comments
 (0)