Skip to content

Commit b74cf3f

Browse files
committed
Simplify string representation of Elemwise and CAReduce
1 parent 5841c30 commit b74cf3f

File tree

9 files changed

+126
-156
lines changed

9 files changed

+126
-156
lines changed

pytensor/graph/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def __eq__(self, other):
234234

235235
dct["__eq__"] = __eq__
236236

237+
# FIXME: This overrides __str__ inheritance when props are provided
237238
if "__str__" not in dct:
238239
if len(props) == 0:
239240

pytensor/tensor/elemwise.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.link.c.params_type import ParamsType
1515
from pytensor.misc.frozendict import frozendict
1616
from pytensor.misc.safe_asarray import _asarray
17-
from pytensor.printing import FunctionPrinter, Printer, pprint
17+
from pytensor.printing import Printer, pprint
1818
from pytensor.scalar import get_scalar_type
1919
from pytensor.scalar.basic import bool as scalar_bool
2020
from pytensor.scalar.basic import identity as scalar_identity
@@ -498,15 +498,9 @@ def make_node(self, *inputs):
498498
return Apply(self, inputs, outputs)
499499

500500
def __str__(self):
501-
if self.name is None:
502-
if self.inplace_pattern:
503-
items = list(self.inplace_pattern.items())
504-
items.sort()
505-
return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
506-
else:
507-
return f"{type(self).__name__}{{{self.scalar_op}}}"
508-
else:
501+
if self.name:
509502
return self.name
503+
return str(self.scalar_op).capitalize()
510504

511505
def R_op(self, inputs, eval_points):
512506
outs = self(*inputs, return_list=True)
@@ -1477,23 +1471,17 @@ def clone(
14771471

14781472
return res
14791473

1480-
def __str__(self):
1481-
prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
1482-
extra_params = []
1483-
1484-
if self.axis is not None:
1485-
axis = ", ".join(str(x) for x in self.axis)
1486-
extra_params.append(f"axis=[{axis}]")
1487-
1488-
if self.acc_dtype:
1489-
extra_params.append(f"acc_dtype={self.acc_dtype}")
1490-
1491-
extra_params_str = ", ".join(extra_params)
1492-
1493-
if extra_params_str:
1494-
return f"{prefix}{{{extra_params_str}}}"
1474+
def _axis_str(self):
1475+
axis = self.axis
1476+
if axis is None:
1477+
return "axes=None"
1478+
elif len(axis) == 1:
1479+
return f"axis={axis[0]}"
14951480
else:
1496-
return f"{prefix}"
1481+
return f"axes={list(axis)}"
1482+
1483+
def __str__(self):
1484+
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
14971485

14981486
def perform(self, node, inp, out):
14991487
(input,) = inp
@@ -1737,21 +1725,17 @@ def construct(symbol):
17371725
symbolname = symbolname or symbol.__name__
17381726

17391727
if symbolname.endswith("_inplace"):
1740-
elemwise_name = f"Elemwise{{{symbolname},inplace}}"
1741-
scalar_op = getattr(scalar, symbolname[: -len("_inplace")])
1728+
base_symbol_name = symbolname[: -len("_inplace")]
1729+
scalar_op = getattr(scalar, base_symbol_name)
17421730
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
17431731
rval = Elemwise(
17441732
inplace_scalar_op,
17451733
{0: 0},
1746-
name=elemwise_name,
17471734
nfunc_spec=(nfunc and (nfunc, nin, nout)),
17481735
)
17491736
else:
1750-
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
17511737
scalar_op = getattr(scalar, symbolname)
1752-
rval = Elemwise(
1753-
scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
1754-
)
1738+
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
17551739

17561740
if getattr(symbol, "__doc__"):
17571741
rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__
@@ -1761,8 +1745,6 @@ def construct(symbol):
17611745
rval.__epydoc_asRoutine = symbol
17621746
rval.__module__ = symbol.__module__
17631747

1764-
pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
1765-
17661748
return rval
17671749

17681750
if symbol:

pytensor/tensor/math.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
583583
return [out, argout]
584584

585585

586-
class NonZeroCAReduce(CAReduce):
586+
class FixedOpCAReduce(CAReduce):
587+
def __str__(self):
588+
return f"{type(self).__name__}{{{self._axis_str()}}}"
589+
590+
591+
class NonZeroDimsCAReduce(FixedOpCAReduce):
587592
def _c_all(self, node, name, inames, onames, sub):
588593
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)
589594

@@ -614,7 +619,7 @@ def _c_all(self, node, name, inames, onames, sub):
614619
return decl, checks, alloc, loop, end
615620

616621

617-
class Max(NonZeroCAReduce):
622+
class Max(NonZeroDimsCAReduce):
618623
nfunc_spec = ("max", 1, 1)
619624

620625
def __init__(self, axis):
@@ -625,7 +630,7 @@ def clone(self, **kwargs):
625630
return type(self)(axis=axis)
626631

627632

628-
class Min(NonZeroCAReduce):
633+
class Min(NonZeroDimsCAReduce):
629634
nfunc_spec = ("min", 1, 1)
630635

631636
def __init__(self, axis):
@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
14961501
"""Return complex-valued tensor from polar coordinate specification."""
14971502

14981503

1499-
class Mean(CAReduce):
1504+
class Mean(FixedOpCAReduce):
15001505
__props__ = ("axis",)
15011506
nfunc_spec = ("mean", 1, 1)
15021507

@@ -2356,7 +2361,7 @@ def outer(x, y):
23562361
return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
23572362

23582363

2359-
class All(CAReduce):
2364+
class All(FixedOpCAReduce):
23602365
"""Applies `logical and` to all the values of a tensor along the
23612366
specified axis(es).
23622367
@@ -2370,12 +2375,6 @@ def __init__(self, axis=None):
23702375
def _output_dtype(self, idtype):
23712376
return "bool"
23722377

2373-
def __str__(self):
2374-
if self.axis is None:
2375-
return "All"
2376-
else:
2377-
return "All{%s}" % ", ".join(map(str, self.axis))
2378-
23792378
def make_node(self, input):
23802379
input = as_tensor_variable(input)
23812380
if input.dtype != "bool":
@@ -2392,7 +2391,7 @@ def clone(self, **kwargs):
23922391
return type(self)(axis=axis)
23932392

23942393

2395-
class Any(CAReduce):
2394+
class Any(FixedOpCAReduce):
23962395
"""Applies `bitwise or` to all the values of a tensor along the
23972396
specified axis(es).
23982397
@@ -2406,12 +2405,6 @@ def __init__(self, axis=None):
24062405
def _output_dtype(self, idtype):
24072406
return "bool"
24082407

2409-
def __str__(self):
2410-
if self.axis is None:
2411-
return "Any"
2412-
else:
2413-
return "Any{%s}" % ", ".join(map(str, self.axis))
2414-
24152408
def make_node(self, input):
24162409
input = as_tensor_variable(input)
24172410
if input.dtype != "bool":
@@ -2428,7 +2421,7 @@ def clone(self, **kwargs):
24282421
return type(self)(axis=axis)
24292422

24302423

2431-
class Sum(CAReduce):
2424+
class Sum(FixedOpCAReduce):
24322425
"""
24332426
Sums all the values of a tensor along the specified axis(es).
24342427
@@ -2449,14 +2442,6 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None):
24492442
upcast_discrete_output=True,
24502443
)
24512444

2452-
def __str__(self):
2453-
name = self.__class__.__name__
2454-
axis = ""
2455-
if self.axis is not None:
2456-
axis = ", ".join(str(x) for x in self.axis)
2457-
axis = f"axis=[{axis}], "
2458-
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
2459-
24602445
def L_op(self, inp, out, grads):
24612446
(x,) = inp
24622447

@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
25262511
pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
25272512

25282513

2529-
class Prod(CAReduce):
2514+
class Prod(FixedOpCAReduce):
25302515
"""
25312516
Multiplies all the values of a tensor along the specified axis(es).
25322517
@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
25372522
"""
25382523

25392524
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input")
2540-
25412525
nfunc_spec = ("prod", 1, 1)
25422526

25432527
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
@@ -2683,6 +2667,14 @@ def clone(self, **kwargs):
26832667
no_zeros_in_input=no_zeros_in_input,
26842668
)
26852669

2670+
def __str__(self):
2671+
if self.no_zeros_in_input:
2672+
return f"{super().__str__()[:-1]}, no_zeros_in_input}})"
2673+
return super().__str__()
2674+
2675+
def __repr__(self):
2676+
return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})"
2677+
26862678

26872679
def prod(
26882680
input,
@@ -2751,7 +2743,7 @@ def c_code_cache_version(self):
27512743
mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros")
27522744

27532745

2754-
class ProdWithoutZeros(CAReduce):
2746+
class ProdWithoutZeros(FixedOpCAReduce):
27552747
def __init__(self, axis=None, dtype=None, acc_dtype=None):
27562748
super().__init__(
27572749
mul_without_zeros,

pytensor/tensor/rewriting/math.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
All,
4343
Any,
4444
Dot,
45-
NonZeroCAReduce,
45+
FixedOpCAReduce,
46+
NonZeroDimsCAReduce,
4647
Prod,
4748
ProdWithoutZeros,
4849
Sum,
@@ -1671,7 +1672,8 @@ def local_op_of_op(fgraph, node):
16711672
ProdWithoutZeros,
16721673
]
16731674
+ CAReduce.__subclasses__()
1674-
+ NonZeroCAReduce.__subclasses__()
1675+
+ FixedOpCAReduce.__subclasses__()
1676+
+ NonZeroDimsCAReduce.__subclasses__()
16751677
)
16761678

16771679

tests/compile/test_builders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,9 @@ def test_debugprint():
579579
Inner graphs:
580580
581581
OpFromGraph{inline=False} [id A]
582-
Elemwise{add,no_inplace} [id E]
582+
Add [id E]
583583
├─ *0-<TensorType(float64, (?, ?))> [id F]
584-
└─ Elemwise{mul,no_inplace} [id G]
584+
└─ Mul [id G]
585585
├─ *1-<TensorType(float64, (?, ?))> [id H]
586586
└─ *2-<TensorType(float64, (?, ?))> [id I]
587587
"""

tests/link/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def func(*args, op=op):
156156

157157
assert (
158158
"""
159-
# Elemwise{add,no_inplace}(Test
159+
# Add(Test
160160
# Op().0, Test
161161
# Op().1)
162162
"""

0 commit comments

Comments
 (0)