
Commit aa1a813

ricardoV94 authored and Ch0ronomato committed

Generalize and simplify local_reduce_join

1 parent f449291 · commit aa1a813

File tree

pytensor/tensor/rewriting/math.py
tests/tensor/rewriting/test_math.py

2 files changed: +91 −71 lines

pytensor/tensor/rewriting/math.py

Lines changed: 39 additions & 51 deletions
@@ -91,6 +91,7 @@
     register_uncanonicalize,
     register_useless,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
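The newly imported apply_local_dimshuffle_lift is used in the rewrite below to eagerly clean up the squeezed Join inputs. A minimal sketch of that cleanup, not part of the commit and assuming the current pytensor.tensor API:

import pytensor.tensor as pt
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift

x = pt.matrix("x")
expanded = pt.expand_dims(x, 0)  # how stack() prepares inputs for Join
squeezed = expanded.squeeze(0)   # squeeze the join axis away again
# The lift should collapse the chained DimShuffles back to plain `x`.
simplified = apply_local_dimshuffle_lift(None, squeezed)
print(simplified)  # expected to print the original variable `x`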
@@ -1628,68 +1629,55 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
 
-    Notes
-    -----
-    Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
-    all cases.
-
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    When a, b have a dim length of 1 along the join axis
 
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
 
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
 
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
+    if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
 
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
+    if not isinstance(join_axis_tensor, Constant):
+        return None
+    join_axis = join_axis_tensor.data
 
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
+    # Check whether reduction happens on joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
 
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inputs = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Most times inputs to join have an expand_dims, we eagerly clean up those here
+        new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
+        new_inputs.append(new_input)
 
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
+    ret = Elemwise(node.op.scalar_op)(*new_inputs)
 
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction do something about the dtype.
+        return None
 
-        return [ret]
+    return [ret]
 
 
 @register_infer_shape
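End to end, the rewrite turns a reduction over a Join into a single Elemwise. A small usage sketch, mine rather than the commit's, assuming the rewrite is active in PyTensor's default compilation mode:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

# stack() joins expand_dims(x) and expand_dims(y) along a new axis 0, so
# sum(..., axis=0) over that Join matches the CAReduce(Join) pattern.
out = pt.sum(pt.stack([x, y]), axis=0)
f = pytensor.function([x, y], out)

# After rewriting, the compiled graph should contain a single Elemwise add
# of x and y, with no Join or CAReduce left; dprint shows the final graph.
pytensor.dprint(f)

xv = np.array([1.0, 2.0], dtype=pytensor.config.floatX)
yv = np.array([3.0, 4.0], dtype=pytensor.config.floatX)
np.testing.assert_allclose(f(xv, yv), xv + yv)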

tests/tensor/rewriting/test_math.py

Lines changed: 52 additions & 20 deletions
@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )
 
     def test_local_reduce_broadcast_all_0(self):
@@ -3304,62 +3304,94 @@ def test_local_reduce_broadcast_some_1(self):
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )
 
-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including(
+            "canonicalize", "specialize", "uncanonicalize"
+        )
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out
 
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)
+
+    def test_type(self):
         # Test different axis for the join and the reduction
         # We must force the dtype, of otherwise, this tests will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
 
         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)
 
         # Test a case that was bugged in a old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
 
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )
 
 
 def test_local_useless_adds():
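The new negative tests encode the rewrite's core requirement: every joined input must have length 1 along the join axis. A plain-NumPy sketch of the underlying identity, illustrative only:

import numpy as np

a = np.array([[1.0, 0.0, 1.0]])  # shape (1, 3): length 1 along the join axis
b = np.array([[4.0, 0.0, 1.0]])  # shape (1, 3)

# When every input contributes exactly one slice along the joined axis,
# reducing that axis is just the elementwise op applied across the inputs.
joined = np.concatenate([a, b], axis=0)  # shape (2, 3)
assert np.array_equal(joined.sum(axis=0), a.squeeze(0) + b.squeeze(0))
assert np.array_equal(joined.max(axis=0), np.maximum(a, b).squeeze(0))

# With more than one slice per input (e.g. shape (2, 3) joined below a
# (1, 3) input), the inputs cannot be squeezed, so no single elementwise
# op over the originals exists and the rewrite must bail out.
c = np.array([[2.0, 1.0, 1.0], [3.0, 5.0, 0.0]])  # shape (2, 3)
joined = np.concatenate([a, c], axis=0)           # shape (3, 3)
# joined.sum(axis=0) mixes three slices; it is not "a op c" elementwise.

Since np.maximum, like PyTensor's scalar Maximum, is binary, joins of more than two inputs are only rewritten for Add and Mul, matching the n_joined_inputs > 2 check in the rewrite above.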
