Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backend #1346

Merged: 5 commits, Apr 29, 2025
pytensor/link/jax/dispatch/subtensor.py (3 additions, 0 deletions)
@@ -67,6 +67,9 @@
if len(indices) == 1:
indices = indices[0]

if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)

[Codecov: added line 71 in pytensor/link/jax/dispatch/subtensor.py is not covered by tests.]

return jax_fn(x, indices, y)

return incsubtensor
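
Note: the new guard calls the op's `_check_runtime_broadcasting` before handing off to the JAX function. As a rough illustration only (the helper below is hypothetical, not PyTensor's actual implementation), the kind of check involved rejects a value whose runtime size is 1 along a dimension that would otherwise be broadcast silently against the indexed rows:

```python
import numpy as np

def check_runtime_broadcast_sketch(x, y, indices):
    # Hypothetical sketch: the target region has one row per index plus x's
    # trailing dimensions.  A y dimension that is 1 at runtime while the target
    # dimension is larger would be broadcast implicitly; the real check rejects
    # this unless the dimension is already declared broadcastable in y's static type.
    target_shape = (len(indices), *x.shape[1:])
    for y_dim, t_dim in zip(y.shape[::-1], target_shape[::-1]):
        if y_dim == 1 and t_dim != 1:
            raise ValueError("Runtime broadcasting not allowed; broadcast y explicitly.")

# check_runtime_broadcast_sketch(np.zeros((5, 3)), np.ones((1, 3)), [0, 2, 4])  # raises
```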
pytensor/link/numba/dispatch/slinalg.py (1 addition, 1 deletion)
@@ -83,7 +83,7 @@ def cholesky(a):
@numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
inverse = op.inverse
dtype = node.inputs[0].dtype
dtype = node.outputs[0].dtype
Member commented: What a dumb mistake, I wonder which idiot wrote this code

@numba_njit
def numba_pivot_to_permutation(piv):
pytensor/link/numba/dispatch/subtensor.py (4 additions, 4 deletions)
@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
# TODO: Add runtime_broadcast check

if set_instead_of_inc:
if broadcast:
if broadcast_with_index:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
@@ -318,7 +318,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
x[idx] = val
return x
else:
if broadcast:
if broadcast_with_index:

@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
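
Note: the rename from `broadcast` to `broadcast_with_index` clarifies which case each generated kernel handles: either a single value is applied to every indexed row, or `vals` carries one row per index. Roughly, in NumPy terms (a sketch of the two cases, not the generated Numba code):

```python
import numpy as np

x = np.zeros((5, 3))
idxs = np.array([0, 2, 4])

# broadcast_with_index is True: one value (here a single row) is written to
# every index, e.g. when vals.ndim < x.ndim or its leading dim has length 1.
val = np.array([1.0, 2.0, 3.0])
for idx in idxs:
    x[idx] = val

# broadcast_with_index is False: vals provides one row per index.
vals = np.arange(9.0).reshape(3, 3)
for idx, v in zip(idxs, vals):
    x[idx] = v
```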
pytensor/link/pytorch/dispatch/subtensor.py (4 additions, 0 deletions)
@@ -109,6 +109,8 @@

def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)

[Codecov: added line 113 in pytensor/link/pytorch/dispatch/subtensor.py is not covered by tests.]
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
@@ -120,6 +122,8 @@

def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)

[Codecov: added line 126 in pytensor/link/pytorch/dispatch/subtensor.py is not covered by tests.]
if not inplace:
x = x.clone()
x[indices] += y.type_as(x)
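
Note: the same runtime-broadcasting guard is added on the PyTorch side. Without it, torch's advanced indexing silently broadcasts a value with a leading length-1 dimension across all indexed rows, which is exactly the behaviour the check is meant to reject when it was not declared statically (small demonstration, assuming standard torch semantics):

```python
import torch

x = torch.zeros(5, 3)
idx = torch.tensor([0, 2, 4])
y = torch.ones(1, 3)   # leading dim of 1 broadcasts across the 3 indexed rows
x[idx] += y            # succeeds silently in PyTorch; PyTensor wants this broadcast declared
```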
pytensor/tensor/basic.py (14 additions, 3 deletions)
@@ -1634,6 +1634,14 @@ def _check_runtime_broadcast(node, value, shape):
if v_static_dim is None and value_dim == 1 and out_dim != 1:
raise ValueError(Alloc._runtime_broadcast_error_msg)

@staticmethod
def value_is_scalar_zero(x: TensorVariable) -> bool:
return (
all(x.type.broadcastable)
and isinstance(x, Constant)
and (x.unique_value == 0)
)

def perform(self, node, inputs, out_):
(out,) = out_
v = inputs[0]
@@ -1659,6 +1667,7 @@ def c_code(self, node, name, inp, out, sub):
o_static_shape = node.outputs[0].type.shape
v_ndim = len(v_static_shape)
o_ndim = len(o_static_shape)
is_zero = self.value_is_scalar_zero(node.inputs[0])
assert o_ndim == len(inp[1:])

# Declare variables
@@ -1699,16 +1708,18 @@
{fail}
}}
}}

if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
PyArray_FILLWBYTE({zz}, 0);
}}
// This function takes care of broadcasting
if (PyArray_CopyInto({zz}, {vv}) == -1)
else if (PyArray_CopyInto({zz}, {vv}) == -1)
{fail}
"""

return code

def c_code_cache_version(self):
return (4,)
return (5,)

def infer_shape(self, fgraph, node, input_shapes):
return [node.inputs[1:]]
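
Note: the new fast path in `Alloc.c_code` recognizes a constant scalar zero and, when the output buffer is contiguous, zero-fills it in one pass with `PyArray_FILLWBYTE` instead of running a broadcasting copy. A rough NumPy-level analogue of the branch (simplified: the real `value_is_scalar_zero` inspects the symbolic `Constant`, not a runtime array):

```python
import numpy as np

def alloc_sketch(value, shape):
    out = np.empty(shape, dtype=value.dtype)
    value_is_scalar_zero = value.size == 1 and value.item() == 0
    if value_is_scalar_zero and (out.flags["C_CONTIGUOUS"] or out.flags["F_CONTIGUOUS"]):
        # Fast path: fill the whole buffer with zeros (PyArray_FILLWBYTE(zz, 0) in the C code).
        out.fill(0)
    else:
        # General path: broadcasting copy (PyArray_CopyInto(zz, vv) in the C code).
        np.copyto(out, value)
    return out
```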
pytensor/tensor/rewriting/subtensor.py (20 additions, 6 deletions)
@@ -1295,12 +1295,26 @@

@node_rewriter([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
new_op = node.op.clone_inplace()
new_node = new_op(*node.inputs)
copy_stack_trace(node.outputs, new_node)
return [new_node]
return False
if node.op.inplace:
return

x, y, idx = node.inputs
if fgraph.has_destroyers([x]):
# In this case we can't operate inplace, but if x is just an alloc of zeros
# We're better off duplicating it and then acting on it inplace.
if (
x.owner is not None
and isinstance(x.owner.op, Alloc)
and x.owner.op.value_is_scalar_zero(x.owner.inputs[0])
):
x = x.owner.clone().outputs[0]

[Codecov: added line 1310 in pytensor/tensor/rewriting/subtensor.py is not covered by tests.]
else:
return None # Inplace isn't valid

[Codecov: added line 1312 in pytensor/tensor/rewriting/subtensor.py is not covered by tests.]

new_op = node.op.clone_inplace()
new_node = new_op(x, y, idx)
copy_stack_trace(node.outputs, new_node)
return [new_node]
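
Note: the rewrite now also covers the case where `x` already has a destroyer. If `x` is just an `Alloc` of a scalar zero, duplicating the allocation is cheap, so the increment can still run destructively on its own buffer. A sketch of the kind of graph that benefits (assuming the usual ordering of the inplace rewrites):

```python
import pytensor
import pytensor.tensor as pt

idx = pt.ivector("idx")
y = pt.matrix("y")
x = pt.zeros((5, 3))    # Alloc of a constant scalar zero

# Both increments consume the same zeros buffer, so only one of them could
# normally be made inplace; the other would have to copy x first.  Because x
# is an Alloc of zeros, the rewrite clones the allocation instead, and each
# AdvancedIncSubtensor1 can then work destructively on its own buffer.
out1 = pt.inc_subtensor(x[idx], y)
out2 = pt.inc_subtensor(x[idx], 2 * y)

fn = pytensor.function([idx, y], [out1, out2])
```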


compile.optdb.register(
pytensor/tensor/slinalg.py (2 additions, 2 deletions)
@@ -604,7 +604,7 @@ def make_node(self, pivots):

def perform(self, node, inputs, outputs):
[pivots] = inputs
p_inv = np.arange(len(pivots), dtype=pivots.dtype)
p_inv = np.arange(len(pivots), dtype="int64")
Member commented: We always have to be careful about single vs double precision for floats (use floatX), but not for ints -- why is that?

Member (author) replied: float32 was very much geared towards GPU concerns; integers are not as problematic [citation needed].
for i in range(len(pivots)):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
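
Note: for reference, the `perform` loop above replays LAPACK-style pivot swaps ("row i was swapped with row `pivots[i]`") onto an `arange` to obtain the permutation, e.g.:

```python
import numpy as np

pivots = np.array([2, 2, 2], dtype="int32")   # e.g. pivots from an LU factorization of a 3x3 matrix
p_inv = np.arange(len(pivots), dtype="int64")
for i in range(len(pivots)):
    p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
print(p_inv)   # [2 0 1]
```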
@@ -639,7 +639,7 @@ def make_node(self, A):
)

LU = matrix(shape=A.type.shape, dtype=A.type.dtype)
pivots = vector(shape=(A.type.shape[0],), dtype="int64")
pivots = vector(shape=(A.type.shape[0],), dtype="int32")
Member commented: Then why int64 above but int32 here?

Member (author) replied: This is what the scipy function returns. I went with int64 above because np.argsort returns int64 regardless of the input dtype, so the outputs keep the same type regardless of return_inverse (or whatever that property on the other Op is called).
return Apply(self, [A], [LU, pivots])
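
Note: as the author says, the int32 here mirrors what scipy returns for the pivot vector. A quick check (assuming a standard scipy/LAPACK build):

```python
import numpy as np
from scipy.linalg import lu_factor

A = np.random.default_rng(0).normal(size=(4, 4))
LU, piv = lu_factor(A)
print(piv.dtype)   # typically int32, matching the dtype declared for the pivots output
```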
