Skip to content

Commit d34760d

Browse files
committed
Add dprint method to Nodes
1 parent 1117ea5 commit d34760d

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

pytensor/graph/basic.py

+12
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def get_parents(self):
7575
"""
7676
raise NotImplementedError()
7777

78+
def dprint(self, **kwargs):
79+
"""Debug print itself
80+
81+
Parameters
82+
----------
83+
kwargs:
84+
Optional keyword arguments to pass to debugprint function.
85+
"""
86+
from pytensor.printing import debugprint
87+
88+
return debugprint(self, **kwargs)
89+
7890

7991
class Apply(Node, Generic[OpType]):
8092
"""A `Node` representing the application of an operation to inputs.

tests/graph/test_basic.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from pytensor.graph.op import Op
3333
from pytensor.graph.type import Type
34+
from pytensor.printing import debugprint
3435
from pytensor.tensor import constant
3536
from pytensor.tensor.math import max_and_argmax
3637
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
@@ -869,3 +870,10 @@ def test_single_pass_per_node(self, mocker):
869870
assert len(inspect.call_args_list) == len(
870871
{a for ((a, b), kw) in inspect.call_args_list}
871872
)
873+
874+
875+
def test_dprint():
876+
r1, r2 = MyVariable(1), MyVariable(2)
877+
o1 = MyOp(r1, r2)
878+
assert o1.dprint(file="str") == debugprint(o1, file="str")
879+
assert o1.owner.dprint(file="str") == debugprint(o1.owner, file="str")

0 commit comments

Comments
 (0)