diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 01aae569e0..06d023d780 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2120,10 +2120,6 @@ def local_pow_to_nested_squaring(fgraph, node): rval = [rval1] if rval: rval[0] = cast(rval[0], odtype) - # TODO: We can add a specify_broadcastable and/or unbroadcast to make the - # output types compatible. Or work on #408 and let TensorType.filter_variable do it. - if rval[0].type.broadcastable != node.outputs[0].type.broadcastable: - return None return rval diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 49b417fde4..84322989bf 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -29,7 +29,7 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.misc.safe_asarray import _asarray from pytensor.printing import debugprint -from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma +from pytensor.scalar import PolyGamma, Psi, TriGamma from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv @@ -1757,7 +1757,7 @@ def test_local_pow_to_nested_squaring(): utt.assert_allclose(f(val_no0), val_no0 ** (-16)) -def test_local_pow_to_nested_squaring_fails_gracefully(): +def test_local_pow_to_nested_squaring_works_with_static_type(): # Reported in #456 x = vector("x", shape=(1,)) @@ -1771,12 +1771,6 @@ def test_local_pow_to_nested_squaring_fails_gracefully(): fn = function([x], y) - # Check rewrite is not applied (this could change in the future) - assert any( - (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow)) - for node in fn.maker.fgraph.apply_nodes - ) - np.testing.assert_allclose(fn([2.0]), np.array([4.0]))