     values_eq_approx_remove_inf_nan,
     values_eq_approx_remove_nan,
 )
 
-from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
+from pytensor.tensor.variable import (
+    TensorConstant,
+    TensorVariable,
+    get_unique_constant_value,
+)
 
 
 def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node):
 
 
 @register_canonicalize
-@node_rewriter([Sum, Prod])
-def local_op_of_op(fgraph, node):
+@node_rewriter([CAReduce])
+def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
     """
-    Prod(Prod()) -> single Prod()
-    or
     Sum(Sum()) -> single Sum()
+    or any CAReduce(Careduce(x)) of the same type
 
     """
-    op_type = Sum if isinstance(node.op, Sum) else Prod
-    (node_inps,) = node.inputs
-    out_dtype = node.op.dtype
-    # This is done to make sure the rewrite doesn't affect other
-    # computations.
-    if len(fgraph.clients[node_inps]) == 1:
-        if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
-            # check to see either the inner or outer prod is doing a
-            # product over all axis, in which case we can remove it
-            if node_inps.owner.op.axis is None or node.op.axis is None:
-                return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
-
-            # figure out which axes were in the original sum
-            newaxis = list(node_inps.owner.op.axis)
-            for i in node.op.axis:
-                new_i = i
-                for ii in node_inps.owner.op.axis:
-                    if new_i >= ii:
-                        new_i += 1
-                assert new_i not in newaxis
-                newaxis.append(new_i)
-
-            assert len(newaxis) == len(
-                list(node_inps.owner.op.axis) + list(node.op.axis)
-            )
+    [inner_reduce] = node.inputs
+    if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
+        return None
+
+    # Don't apply rewrite if inner_reduce is used elsewhere
+    if len(fgraph.clients[inner_reduce]) > 1:
+        return None
+
+    # Check if CAReduces have the same scalar op
+    outer_op: CAReduce = node.op
+    inner_op = inner_reduce.owner.op
+
+    if outer_op.scalar_op != inner_op.scalar_op:
+        return None
 
-            combined = op_type(newaxis, dtype=out_dtype)
-            return [combined(node_inps.owner.inputs[0])]
+    outer_axis = outer_op.axis
+    inner_axis = inner_op.axis
+    [x] = inner_reduce.owner.inputs
+    # check to see either the inner or outer prod is doing a
+    # product over all axis, in which case we can remove it
+    if outer_axis is None or inner_axis is None:
+        return [outer_op.clone(axis=None)(x)]
+
+    # Merge axis
+    newaxis = list(inner_axis)
+    for i in outer_axis:
+        new_i = i
+        for ii in inner_axis:
+            if new_i >= ii:
+                new_i += 1
+        assert new_i not in newaxis
+        newaxis.append(new_i)
+
+    assert len(newaxis) == len(inner_axis) + len(outer_axis)
+    return [outer_op.clone(axis=sorted(newaxis))(x)]
 
 
 @register_canonicalize
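The one non-obvious step in `local_reduce_chain` is the axis merge: `outer_op.axis` is expressed in the coordinates of the already-reduced tensor, so each outer axis has to be shifted past every inner axis that was removed before it can be combined with `inner_axis`. Below is a minimal sketch of that mapping as a standalone helper, plus a check that nested reductions compile down to a single `CAReduce` node. The helper name `merge_reduce_axes`, the shapes, and the tolerance are illustrative only, and the graph check assumes the rewrite is registered in the default canonicalization database used by `pytensor.function`.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.elemwise import CAReduce


def merge_reduce_axes(inner_axis, outer_axis):
    # Mirrors the merging loop in the rewrite: shift each outer axis past
    # every inner axis the inner reduction already removed, so all entries
    # refer to dimensions of the *original* tensor.
    newaxis = list(inner_axis)
    for i in outer_axis:
        new_i = i
        for ii in inner_axis:
            if new_i >= ii:
                new_i += 1
        newaxis.append(new_i)
    return sorted(newaxis)


# The inner reduce removes axes 0 and 2 of a 4d input, so the outer
# reduce's axis 1 refers to what was originally axis 3.
assert merge_reduce_axes((0, 2), (1,)) == [0, 2, 3]

# Nested sums should collapse to a single CAReduce after canonicalization.
x = pt.tensor3("x")
fn = pytensor.function([x], pt.sum(pt.sum(x, axis=0), axis=0))
reduce_nodes = [
    node for node in fn.maker.fgraph.toposort() if isinstance(node.op, CAReduce)
]
assert len(reduce_nodes) == 1

x_val = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(pytensor.config.floatX)
np.testing.assert_allclose(fn(x_val), x_val.sum(axis=(0, 1)), rtol=1e-5)
```

Because the rewrite now matches any `CAReduce` whose `scalar_op` is identical on both nodes, the same collapse applies to chains such as `pt.max(pt.max(x, axis=0), axis=0)`, not just `Sum` and `Prod`.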