@@ -4,6 +4,7 @@
 import numpy as np

 from pytensor import Variable
+from pytensor.compile import optdb
 from pytensor.graph import Constant, FunctionGraph, node_rewriter
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
@@ -37,16 +38,18 @@
 )
 from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
+    AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
+    _non_consecutive_adv_indexing,
     as_index_literal,
     get_canonical_form_slice,
     get_constant_idx,
     get_idx_list,
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorVariable

@@ -769,3 +772,79 @@ def local_subtensor_shape_constant(fgraph, node):
         return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
     elif shape_parts:
         return [as_tensor(1, dtype=np.int64)]
+
+
+@node_rewriter([Subtensor])
+def local_subtensor_of_adv_subtensor(fgraph, node):
+    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.
+
+    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
+    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]
+
+    Restricted to a single advanced indexing dimension.
+
+    An alternative approach could have fused the basic and advanced indices,
+    so it is not clear this rewrite should be canonical or a specialization.
+    Users must include it manually if it fits their use case.
+    """
+    adv_subtensor, *idxs = node.inputs
+
+    if not (
+        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
+    ):
+        return None
+
+    if len(fgraph.clients[adv_subtensor]) > 1:
+        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
+        return None
+
+    x, *adv_idxs = adv_subtensor.owner.inputs
+
+    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
+    if any(
+        (
+            isinstance(adv_idx.type, NoneTypeT)
+            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
+            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
+        )
+        for adv_idx in adv_idxs
+    ) or _non_consecutive_adv_indexing(adv_idxs):
+        return None
+
+    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
+        # We already made sure there were only None slices besides integer indexes
+        if isinstance(adv_idx.type, TensorType):
+            break
+    else:  # no-break
+        # Not sure if this should ever happen, but better safe than sorry
+        return None
+
+    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
+    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
+    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
+        first_adv_idx_dim:
+    ]
+
+    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
+        # All basic indices happen to the right of the advanced indices
+        return None
+
+    [basic_subtensor] = node.outputs
+    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)
+
+    x_indexed = x[basic_idxs_lifted]
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
+
+    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
+    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
+
+    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
+    return [new_out]
+
+
+# Rewrite will only be included if tagged by name
+r = local_subtensor_of_adv_subtensor
+optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False)
+optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False)
+del r
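
Because the registration passes `use_db_name_as_tag=False`, this rewrite does not run as part of the regular `canonicalize` or `specialize` passes; it is only applied when requested by name. Below is a minimal sketch of how that opt-in could look, assuming the usual `get_default_mode().including(...)` mechanism for enabling rewrites by tag; the variable names, shapes, and index values are illustrative only.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.tensor3("x")
vec_idx = pt.lvector("vec_idx")

# Basic integer indices (0, 1) sit to the left of the advanced vector index,
# which is the pattern this rewrite targets: x[:, :, vec_idx][0, 1] -> x[0, 1][vec_idx]
out = x[:, :, vec_idx][0, 1]

# Opt in to the rewrite by name; it is not included in the default passes
mode = get_default_mode().including("local_subtensor_of_adv_subtensor")
fn = pytensor.function([x, vec_idx], out, mode=mode)

rng = np.random.default_rng(0)
x_val = rng.normal(size=(4, 5, 6)).astype(x.dtype)
idx_val = np.array([0, 2, 3])
np.testing.assert_allclose(fn(x_val, idx_val), x_val[:, :, idx_val][0, 1])
```

With the rewrite enabled, the compiled graph should apply the scalar indices to `x` before the advanced vector index, rather than advanced-indexing the full tensor and then slicing the copied result.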