
Commit eed127f

Generalize Blockwise inplace logic
1 parent 470ea60 commit eed127f

File tree: 9 files changed, +228 −56 lines changed


pytensor/graph/op.py

Lines changed: 5 additions & 0 deletions
@@ -577,6 +577,11 @@ def make_thunk(
         )
         return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
 
+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        """Try to return a version of self that can inplace on candidate_inputs."""
+        # TODO: Document this in the Create your own Op docs
+        raise NotImplementedError()
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
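Note: `try_inplace_inputs` is the new extension point used by the Blockwise inplace rewrite below: given the indices of inputs that are safe to destroy, an Op should return an inplace version of itself, or raise `NotImplementedError` if it cannot. A minimal sketch of the contract, using a hypothetical unary `AddOne` Op that is not part of this commit:

```python
from pytensor.graph.op import Op


class AddOne(Op):
    """Hypothetical unary Op, shown only to illustrate the contract."""

    __props__ = ("inplace",)

    def __init__(self, inplace: bool = False):
        self.inplace = inplace
        # destroy_map declares that output 0 overwrites input 0
        self.destroy_map = {0: [0]} if inplace else {}

    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
        # Only destroy inputs the rewriter offered as candidates
        if 0 in candidate_inputs:
            return type(self)(inplace=True)
        raise NotImplementedError()

    # make_node/perform omitted for brevity
```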

pytensor/tensor/blockwise.py

Lines changed: 9 additions & 0 deletions
@@ -60,6 +60,7 @@ def __init__(
         signature: Optional[str] = None,
         name: Optional[str] = None,
         gufunc_spec: Optional[tuple[str, int, int]] = None,
+        destroy_map=None,
         **kwargs,
     ):
         """
@@ -94,6 +95,14 @@ def __init__(
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
         self._gufunc = None
+        if destroy_map is not None:
+            # TODO: Check core_op destroy_map is compatible with Blockwise destroy_map
+            self.destroy_map = destroy_map
+        if self.destroy_map != core_op.destroy_map:
+            raise ValueError(
+                "Blockwise destroy_map must be the same as that of the core_op"
+            )
         super().__init__(**kwargs)
 
     def __getstate__(self):
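Note: with the new keyword a `Blockwise` can wrap an inplace core op, provided the `destroy_map` handed to it mirrors the core op's exactly. A sketch (using the inplace `Cholesky` from later in this diff; the explicit signature string is an assumption, since `Blockwise` can also read it off the core op):

```python
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Cholesky

core_op = Cholesky(lower=True, on_error="raise", overwrite_a=True)

# The destroy_map must equal the core op's ({0: [0]} here),
# otherwise __init__ raises ValueError
batched_chol = Blockwise(
    core_op,
    signature="(m,m)->(m,m)",
    destroy_map=core_op.destroy_map,
)
```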

pytensor/tensor/rewriting/basic.py

Lines changed: 0 additions & 12 deletions
@@ -135,18 +135,6 @@ def alloc_like(
     return rval
 
 
-def make_inplace(node, inplace_prop="inplace"):
-    op = getattr(node.op, "core_op", node.op)
-    props = op._props_dict()
-    if props[inplace_prop]:
-        return False
-
-    props[inplace_prop] = True
-    inplace_op = type(op)(**props)
-
-    return inplace_op.make_node(*node.inputs).outputs
-
-
 def register_useless(
     node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
 ):

pytensor/tensor/rewriting/blockwise.py

Lines changed: 77 additions & 2 deletions
@@ -1,9 +1,11 @@
+import itertools
 from typing import Optional
 
+from pytensor.compile import Supervisor
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
-from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot
@@ -57,7 +59,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
     "fast_run",
     "fast_compile",
     "blockwise",
-    position=49,
+    position=99,  # TODO: Check if this makes sense
 )
 
 
@@ -199,3 +201,76 @@ def local_blockwise_alloc(fgraph, node):
     assert new_outs[0].type.broadcastable == old_out_type.broadcastable
     copy_stack_trace(node.outputs, new_outs)
     return new_outs
+
+
+@node_rewriter([Blockwise], inplace=True)
+def node_blockwise_inplace(fgraph, node):
+    # Find inputs that are candidates for inplacing
+    blockwise_op = node.op
+
+    if blockwise_op.destroy_map:
+        # Op already works inplace
+        return False
+
+    core_op = blockwise_op.core_op
+    batch_ndim = blockwise_op.batch_ndim(node)
+    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
+
+    # TODO: Refactor this code, which is also present in the Elemwise inplacer
+    protected_inputs = [
+        f.protected for f in fgraph._features if isinstance(f, Supervisor)
+    ]
+    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
+    protected_inputs.extend(fgraph.outputs)
+
+    candidate_inputs = [
+        idx
+        for idx, inp in enumerate(node.inputs)
+        if (
+            not isinstance(inp, Constant)
+            and inp.type.broadcastable[:batch_ndim] == out_batch_bcast
+            and not fgraph.has_destroyers([inp])
+            and inp not in protected_inputs
+        )
+    ]
+
+    if not candidate_inputs:
+        return None
+
+    try:
+        inplace_core_op = core_op.try_inplace_inputs(candidate_inputs)
+    except NotImplementedError:
+        return False
+
+    core_destroy_map = inplace_core_op.destroy_map
+
+    if not core_destroy_map:
+        return False
+
+    # Check that the core_op is not trying to inplace on non-candidate inputs
+    for destroyed_inputs in core_destroy_map.values():
+        for destroyed_input in destroyed_inputs:
+            if destroyed_input not in candidate_inputs:
+                raise ValueError("core_op did not respect candidate inputs")
+
+    # Recreate the Blockwise with the inplace core_op
+    inplace_blockwise_op = Blockwise(
+        core_op=inplace_core_op,
+        signature=blockwise_op.signature,
+        name=blockwise_op.name,
+        gufunc_spec=blockwise_op.gufunc_spec,
+        destroy_map=core_destroy_map,
+    )
+
+    return inplace_blockwise_op.make_node(*node.inputs).outputs
+
+
+# After the destroy handler (49.5) but before we try to make Elemwise things inplace (75)
+blockwise_inplace = in2out(node_blockwise_inplace, name="blockwise_inplace")
+optdb.register(
+    "blockwise_inplace",
+    blockwise_inplace,
+    "fast_run",
+    "inplace",
+    position=69.0,
+)
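Note: end to end, the rewrite only fires once the destroy handler is in place and the inputs are neither constants nor protected. A sketch of the expected behavior after this commit (shapes and the exact destroy_map are assumptions):

```python
import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.slinalg import solve

A = tensor("A", shape=(5, 3, 3))
b = tensor("b", shape=(5, 3))
x = solve(A, b, b_ndim=1)

# Mutable inputs are not protected by the Supervisor,
# so both become inplace candidates
f = pytensor.function(
    [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], x
)

blockwise_node = f.maker.fgraph.outputs[0].owner
# Expected: the core Solve was rebuilt with overwrite_a/overwrite_b set,
# e.g. destroy_map == {0: [0, 1]}
print(blockwise_node.op.destroy_map)
```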

pytensor/tensor/rewriting/elemwise.py

Lines changed: 1 addition & 2 deletions
@@ -185,9 +185,8 @@ def apply(self, fgraph):
                     for i in range(len(node.inputs))
                     if i not in baseline.values()
                     and not isinstance(node.inputs[i], Constant)
-                    and
                     # the next line should not be costly most of the time.
-                    not fgraph.has_destroyers([node.inputs[i]])
+                    and not fgraph.has_destroyers([node.inputs[i]])
                     and node.inputs[i] not in protected_inputs
                 ]
             else:

pytensor/tensor/rewriting/linalg.py

Lines changed: 1 addition & 21 deletions
@@ -1,16 +1,14 @@
 import logging
 from typing import cast
 
-from pytensor.compile import optdb
-from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
 from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
-    make_inplace,
     register_canonicalize,
     register_specialize,
     register_stabilize,
@@ -312,21 +310,3 @@ def local_log_prod_sqr(fgraph, node):
 
 # TODO: have a reduction like prod and sum that simply
 # returns the sign of the prod multiplication.
-
-
-@node_rewriter([Cholesky], inplace=True)
-def local_inplace_cholesky(fgraph, node):
-    return make_inplace(node, "overwrite_a")
-
-
-# After destroyhandler (49.5) but before we try to make elemwise things
-# inplace (75)
-linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
-optdb.register(
-    "InplaceLinalgOpt",
-    linalg_opt_inplace,
-    "fast_run",
-    "inplace",
-    "linalg_opt_inplace",
-    position=69.0,
-)

pytensor/tensor/slinalg.py

Lines changed: 55 additions & 0 deletions
@@ -141,6 +141,12 @@ def conjugate_solve_triangular(outer, inner):
         else:
             return [grad]
 
+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        if candidate_inputs == [0]:
+            return type(self)(
+                lower=self.lower, overwrite_a=True, on_error=self.on_error
+            )
+        raise NotImplementedError()
 
 def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
     return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))(
@@ -155,6 +161,8 @@ class SolveBase(Op):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_a",
+        "overwrite_b",
     )
 
     def __init__(
@@ -163,6 +171,8 @@ def __init__(
         lower=False,
         check_finite=True,
         b_ndim,
+        overwrite_a=False,
+        overwrite_b=False,
     ):
         self.lower = lower
         self.check_finite = check_finite
@@ -172,6 +182,16 @@ def __init__(
             self.gufunc_signature = "(m,m),(m)->(m)"
         else:
             self.gufunc_signature = "(m,m),(m,n)->(m,n)"
+        self.overwrite_a = overwrite_a
+        self.overwrite_b = overwrite_b
+        destroy_map = {}
+        if self.overwrite_a and self.overwrite_b:
+            destroy_map[0] = [0, 1]
+        elif self.overwrite_a:
+            destroy_map[0] = [0]
+        elif self.overwrite_b:
+            destroy_map[0] = [1]
+        self.destroy_map = destroy_map
 
     def perform(self, node, inputs, outputs):
         pass
@@ -245,7 +265,16 @@ def _default_b_ndim(b, b_ndim):
 
 
 class CholeskySolve(SolveBase):
+    __props__ = (
+        "lower",
+        "check_finite",
+        "b_ndim",
+        "overwrite_b",
+    )
+
     def __init__(self, **kwargs):
+        if kwargs.get("overwrite_a", False):
+            raise ValueError("overwrite_a is not supported for CholeskySolve")
         kwargs.setdefault("lower", True)
         super().__init__(**kwargs)
 
@@ -260,8 +289,15 @@ def perform(self, node, inputs, output_storage):
         output_storage[0][0] = rval
 
     def L_op(self, *args, **kwargs):
+        # TODO: The base implementation should work; let's try it
         raise NotImplementedError()
 
+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        if 1 in candidate_inputs:
+            new_props = self._props_dict()
+            new_props["overwrite_b"] = True
+            return type(self)(**new_props)
+        raise NotImplementedError()
 
 def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
     """Solve the linear equations A x = b, given the Cholesky factorization of A.
@@ -296,9 +332,12 @@ class SolveTriangular(SolveBase):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_b",
     )
 
     def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
+        if kwargs.get("overwrite_a", False):
+            raise ValueError("overwrite_a is not supported for SolveTriangular")
         super().__init__(**kwargs)
         self.trans = trans
         self.unit_diagonal = unit_diagonal
@@ -324,6 +363,12 @@ def L_op(self, inputs, outputs, output_gradients):
 
         return res
 
+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        if 1 in candidate_inputs:
+            new_props = self._props_dict()
+            new_props["overwrite_b"] = True
+            return type(self)(**new_props)
+        raise NotImplementedError()
 
 def solve_triangular(
     a: TensorVariable,
@@ -383,6 +428,8 @@ class Solve(SolveBase):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_a",
+        "overwrite_b",
     )
 
     def __init__(self, *, assume_a="gen", **kwargs):
@@ -402,6 +449,14 @@ def perform(self, node, inputs, outputs):
             assume_a=self.assume_a,
         )
 
+    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
+        new_props = self._props_dict()
+        if 0 in candidate_inputs:
+            new_props["overwrite_a"] = True
+        if 1 in candidate_inputs:
+            new_props["overwrite_b"] = True
+        return type(self)(**new_props)
 
 def solve(
     a,
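Note: the destroy_map bookkeeping added to `SolveBase.__init__` can be sanity-checked directly; a sketch, with keyword defaults as in this diff:

```python
from pytensor.tensor.slinalg import Solve

# destroy_map maps an output index to the input indices it may overwrite
assert Solve(b_ndim=1).destroy_map == {}
assert Solve(b_ndim=1, overwrite_a=True).destroy_map == {0: [0]}
assert Solve(b_ndim=1, overwrite_b=True).destroy_map == {0: [1]}
assert Solve(b_ndim=1, overwrite_a=True, overwrite_b=True).destroy_map == {0: [0, 1]}
```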

tests/tensor/rewriting/test_linalg.py

Lines changed: 22 additions & 17 deletions
@@ -312,33 +312,38 @@ def test_invalid_batched_a(self):
     config.mode == "FAST_COMPILE",
     reason="inplace rewrites disabled when mode is FAST_COMPILE",
 )
-def test_local_inplace_cholesky():
-    X = matrix("X")
+@pytest.mark.parametrize("is_batched", (False, True))
+def test_local_inplace_cholesky(is_batched):
+    shape = (5, None, None) if is_batched else (None, None)
+    X = tensor("X", shape=shape)
     L = cholesky(X, overwrite_a=False, lower=True)
     f = function([pytensor.In(X, mutable=True)], L)
 
     assert not L.owner.op.core_op.overwrite_a
 
-    nodes = f.maker.fgraph.toposort()
-    for node in nodes:
-        if isinstance(node, Cholesky):
-            assert node.overwrite_a
-            break
+    if is_batched:
+        [cholesky_op] = [
+            node.op.core_op
+            for node in f.maker.fgraph.apply_nodes
+            if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky)
+        ]
+    else:
+        [cholesky_op] = [
+            node.op
+            for node in f.maker.fgraph.apply_nodes
+            if isinstance(node.op, Cholesky)
+        ]
+    assert cholesky_op.overwrite_a
 
     X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
     X_val_in = X_val @ X_val.T
+    if is_batched:
+        X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
     X_val_in_copy = X_val_in.copy()
 
     f(X_val_in)
 
     assert_allclose(
-        X_val_in[np.triu_indices_from(X_val_in, k=1)],
-        0.0,
-        atol=1e-4 if config.floatX == "float32" else 1e-8,
-        rtol=1e-4 if config.floatX == "float32" else 1e-8,
-    )
-    assert_allclose(
-        X_val_in @ X_val_in.T,
-        X_val_in_copy,
-        atol=1e-4 if config.floatX == "float32" else 1e-8,
-        rtol=1e-4 if config.floatX == "float32" else 1e-8,
+        X_val_in,
+        np.linalg.cholesky(X_val_in_copy),
     )
