diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py
index 6dd0e8211b..56a3e2c9b2 100644
--- a/pytensor/link/numba/dispatch/__init__.py
+++ b/pytensor/link/numba/dispatch/__init__.py
@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
 
 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic
 
 # isort: on
diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py
new file mode 100644
index 0000000000..131788e843
--- /dev/null
+++ b/pytensor/link/numba/dispatch/blockwise.py
@@ -0,0 +1,92 @@
+import sys
+from typing import cast
+
+from numba.core.extending import overload
+from numba.np.unsafe.ndarray import to_fixed_tuple
+
+from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
+from pytensor.link.numba.dispatch.vectorize_codegen import (
+    _jit_options,
+    _vectorized,
+    encode_literals,
+    store_core_outputs,
+)
+from pytensor.link.utils import compile_function_src
+from pytensor.tensor import TensorVariable, get_vector_length
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+
+
+@numba_funcify.register
+def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
+    [blockwise_node] = op.fgraph.apply_nodes
+    blockwise_op: Blockwise = blockwise_node.op
+    core_op = blockwise_op.core_op
+    nin = len(blockwise_node.inputs)
+    nout = len(blockwise_node.outputs)
+    core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
+
+    core_node = blockwise_op._create_dummy_core_node(
+        cast(tuple[TensorVariable], blockwise_node.inputs)
+    )
+    core_op_fn = numba_funcify(
+        core_op,
+        node=core_node,
+        parent_node=node,
+        fastmath=_jit_options["fastmath"],
+        **kwargs,
+    )
+    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
+
+    batch_ndim = blockwise_op.batch_ndim(node)
+
+    # numba doesn't support nested literals right now...
+    input_bc_patterns = encode_literals(
+        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs[:nin])
+    )
+    output_bc_patterns = encode_literals(
+        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
+    )
+    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
+    inplace_pattern = encode_literals(())
+
+    # Numba does not allow a tuple generator in the Jitted function so we have to compile a helper to convert core_shapes into tuples
+    # Alternatively, add an Op that converts shape vectors into tuples, like we did for JAX
+    src = "def to_tuple(core_shapes): return ("
+    for i in range(nout):
+        src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
+    src += ")"
+
+    to_tuple = numba_njit(
+        compile_function_src(
+            src,
+            "to_tuple",
+            global_env={"to_fixed_tuple": to_fixed_tuple},
+        ),
+        # cache=True leads to a numba.cloudpickle dump failure in Python 3.10
+        # May be fine in Python 3.11, but I didn't test. It was fine in 3.12
+        cache=sys.version_info >= (3, 12),
+    )
+
+    def blockwise_wrapper(*inputs_and_core_shapes):
+        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
+        tuple_core_shapes = to_tuple(core_shapes)
+        return _vectorized(
+            core_op_fn,
+            input_bc_patterns,
+            output_bc_patterns,
+            output_dtypes,
+            inplace_pattern,
+            (),  # constant_inputs
+            inputs,
+            tuple_core_shapes,
+            None,  # size
+        )
+
+    def blockwise(*inputs_and_core_shapes):
+        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
+
+    @overload(blockwise, jit_options=_jit_options)
+    def ov_blockwise(*inputs_and_core_shapes):
+        return blockwise_wrapper
+
+    return blockwise
diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 29584daa5f..04181e8335 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
         return rng, draws
 
     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")
 
     @overload(random, jit_options=_jit_options)
    def ov_random(core_shape, rng, size, *dist_params):
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 1c3a221642..b3366f21af 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -6,7 +6,8 @@
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Constant
+from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import Apply, Constant, ancestors
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.replace import (
@@ -185,15 +186,40 @@ def infer_shape(
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
 
+        # Try to extract the core shapes from the core_op
+        core_op_infer_shape = getattr(self.core_op, "infer_shape", None)
+        if core_op_infer_shape is not None:
+            dummy_core_node = self._create_dummy_core_node(node.inputs)
+            dummy_core_inputs = dummy_core_node.inputs
+            dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
+            core_input_shapes = [
+                input_shape[batch_ndims:] for input_shape in input_shapes
+            ]
+            core_output_shapes = core_op_infer_shape(
+                dummy_fgraph, dummy_core_node, core_input_shapes
+            )
+
         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig, strict=True):
+        for o, (output, sig) in enumerate(
+            zip(node.outputs, self.outputs_sig, strict=True)
+        ):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
                 if dim_name in core_dims:
                     core_out_shape.append(core_dims[dim_name])
                 else:
-                    # TODO: We could try to make use of infer_shape of core_op
+                    if core_op_infer_shape is not None:
+                        # If the input values are needed to compute the dimension length, we can't use the infer_shape
+                        # of the core_node as the value is not constant across batch dims of the Blockwise
+                        core_out_dim = core_output_shapes[o][i]
+                        if not (
+                            set(dummy_core_inputs) & set(ancestors([core_out_dim]))
+                        ):
+                            core_out_shape.append(core_out_dim)
+                            continue
+
+                    # Fallback shape requires evaluating the Blockwise Op
                     core_out_shape.append(Shape_i(batch_ndims + i)(output))
             out_shapes.append((*batch_shape, *core_out_shape))
 
         return out_shapes
@@ -416,3 +442,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"
diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py
index fe170f4718..b6dcf3b5e8 100644
--- a/pytensor/tensor/random/rewriting/numba.py
+++ b/pytensor/tensor/random/rewriting/numba.py
@@ -15,7 +15,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
     This core_shape is used by the numba backend to pre-allocate the output array.
 
     If available, the core shape is extracted from the shape feature of the graph,
-    which has a higher change of having been simplified, optimized, constant-folded.
+    which has a higher chance of having been simplified, optimized, constant-folded.
     If missing, we fall back to the op._supp_shape_from_params method.
 
     This rewrite is required for the numba backend implementation of RandomVariable.
diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py
index fc5c528f2d..4e75140ceb 100644
--- a/pytensor/tensor/rewriting/__init__.py
+++ b/pytensor/tensor/rewriting/__init__.py
@@ -9,6 +9,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py
new file mode 100644
index 0000000000..91ab131424
--- /dev/null
+++ b/pytensor/tensor/rewriting/numba.py
@@ -0,0 +1,108 @@
+from pytensor.compile import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.basic import applys_between
+from pytensor.graph.rewriting.basic import out2in
+from pytensor.tensor.basic import as_tensor, constant
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+from pytensor.tensor.rewriting.shape import ShapeFeature
+
+
+@node_rewriter([Blockwise])
+def introduce_explicit_core_shape_blockwise(fgraph, node):
+    """Introduce the core shape of a Blockwise.
+
+    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
+    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
+    This core_shape is used by the numba backend to pre-allocate the output array.
+
+    If available, the core shape is extracted from the shape feature of the graph,
+    which has a higher chance of having been simplified, optimized, constant-folded.
+    If missing, we fall back to the op's infer_shape method.
+
+    This rewrite is required for the numba backend implementation of Blockwise.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(5, None, None))
+        outs = pt.linalg.svd(x, compute_uv=True)
+        pytensor.dprint(outs)
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
+        #  └─ x [id B]
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
+        #  └─ ···
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
+        #  └─ ···
+
+        # After the rewrite, note the new 3 core shape inputs
+        fn = pytensor.function([x], outs, mode="NUMBA")
+        fn.dprint(print_type=False)
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
+        #  ├─ x [id B]
+        #  ├─ MakeVector{dtype='int64'} [id C] 5
+        #  │  ├─ Shape_i{1} [id D] 2
+        #  │  │  └─ x [id B]
+        #  │  └─ Shape_i{1} [id D] 2
+        #  │     └─ ···
+        #  ├─ MakeVector{dtype='int64'} [id E] 4
+        #  │  └─ Minimum [id F] 3
+        #  │     ├─ Shape_i{1} [id D] 2
+        #  │     │  └─ ···
+        #  │     └─ Shape_i{2} [id G] 0
+        #  │        └─ x [id B]
+        #  └─ MakeVector{dtype='int64'} [id H] 1
+        #     ├─ Shape_i{2} [id G] 0
+        #     │  └─ ···
+        #     └─ Shape_i{2} [id G] 0
+        #        └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
+        #  └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
+        #  └─ ···
+    """
+    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
+    batch_ndim = op.batch_ndim(node)
+
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    if shape_feature:
+        core_shapes = [
+            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
+            for out in node.outputs
+        ]
+    else:
+        input_shapes = [tuple(inp.shape) for inp in node.inputs]
+        core_shapes = [
+            out_shape[batch_ndim:]
+            for out_shape in op.infer_shape(None, node, input_shapes)
+        ]
+
+    core_shapes = [
+        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
+        for core_shape in core_shapes
+    ]
+
+    if any(
+        isinstance(node.op, Blockwise)
+        for node in applys_between(node.inputs, core_shapes)
+    ):
+        # If Blockwise shows up in the shape graph we can't introduce the core shape
+        return None
+
+    return BlockwiseWithCoreShape(
+        [*node.inputs, *core_shapes],
+        node.outputs,
+        destroy_map=op.destroy_map,
+    )(*node.inputs, *core_shapes, return_list=True)
+
+
+optdb.register(
+    introduce_explicit_core_shape_blockwise.__name__,
+    out2in(introduce_explicit_core_shape_blockwise),
+    "numba",
+    position=100,
+)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index ec88b0fd50..0086b15a80 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -244,7 +244,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn
diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py
new file mode 100644
index 0000000000..ced4185e14
--- /dev/null
+++ b/tests/link/numba/test_blockwise.py
@@ -0,0 +1,59 @@
+import numpy as np
+import pytest
+
+from pytensor import function
+from pytensor.tensor import tensor
+from pytensor.tensor.basic import ARange
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.nlinalg import SVD, Det
+from pytensor.tensor.slinalg import Cholesky, cholesky
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
+
+
+# Fails if object mode warning is issued when not expected
+pytestmark = pytest.mark.filterwarnings("error")
+
+
+@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
+@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
+def test_blockwise(core_op, shape_opt):
+    x = tensor(shape=(5, None, None))
+    outs = Blockwise(core_op=core_op)(x, return_list=True)
+
+    mode = (
+        numba_mode.including("ShapeOpt")
+        if shape_opt
+        else numba_mode.excluding("ShapeOpt")
+    )
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    compare_numba_and_py(
+        ([x], outs),
+        [x_test],
+        numba_mode=mode,
+        eval_obj_mode=False,
+    )
+
+
+def test_non_square_blockwise():
+    """Test that an Op that cannot always be blockwised at runtime fails gracefully."""
+    x = tensor(shape=(3,), dtype="int64")
+    out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)
+
+    with pytest.warns(UserWarning, match="Numba will use object mode"):
+        fn = function([x], out, mode="NUMBA")
+
+    np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))
+
+    with pytest.raises(ValueError):
+        fn([3, 4, 5])
+
+
+def test_blockwise_benchmark(benchmark):
+    x = tensor(shape=(5, 3, 3))
+    out = cholesky(x)
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = function([x], out, mode="NUMBA")
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    fn(x_test)  # JIT compile
+    benchmark(fn, x_test)
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index b342c576bd..8ce40d48ef 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -259,6 +259,58 @@ def test_blockwise_shape():
     assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4)
 
 
+def test_blockwise_infer_core_shape():
+    class TestOpWithInferShape(Op):
+        def make_node(self, a, b):
+            assert a.type.ndim == 1
+            assert b.type.ndim == 1
+            c = tensor(shape=(None,))
+            d = tensor(shape=(None,))
+            return Apply(self, [a, b], [c, d])
+
+        def perform(self, node, inputs, outputs):
+            a, b = inputs
+            c, d = outputs
+            c[0] = np.arange(a.size + b.size)
+            d[0] = np.arange(a.sum() + b.sum())
+
+        def infer_shape(self, fgraph, node, input_shapes):
+            # First output shape depends only on input_shapes
+            # Second output shape depends on input values
+            x, y = node.inputs
+            [(x_shape,), (y_shape,)] = input_shapes
+            return (x_shape + y_shape,), (x.sum() + y.sum(),)
+
+    blockwise_op = Blockwise(
+        core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
+    )
+
+    a = tensor("a", shape=(5, 3))
+    b = tensor("b", shape=(1, 4))
+    c, d = blockwise_op(a, b)
+    assert c.type.shape == (5, None)
+    assert d.type.shape == (5, None)
+
+    c_shape_fn = pytensor.function([a, b], c.shape)
+    # c_shape can be computed from the input shapes alone
+    assert not any(
+        isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
+        for n in c_shape_fn.maker.fgraph.apply_nodes
+    )
+
+    d_shape_fn = pytensor.function([a, b], d.shape)
+    # d_shape cannot be computed from the input shapes alone
+    assert any(
+        isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
+        for n in d_shape_fn.maker.fgraph.apply_nodes
+    )
+
+    a_test = np.zeros(a.type.shape, dtype=a.type.dtype)
+    b_test = np.zeros(b.type.shape, dtype=b.type.dtype)
+    assert tuple(c_shape_fn(a_test, b_test)) == (5, 7)
+    assert tuple(d_shape_fn(a_test, b_test)) == (5, 0)
+
+
 class BlockwiseOpTester:
     """Base class to test Blockwise works for specific Ops"""
 
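
For reviewers who want to exercise the new code path locally, here is a minimal usage sketch (not part of the patch). It assumes a checkout of this branch with numba installed, and simply mirrors test_blockwise_benchmark above: a batched Cholesky whose Blockwise node is rewritten into a BlockwiseWithCoreShape and dispatched through the new numba implementation when compiled with mode="NUMBA".

    import numpy as np

    import pytensor
    from pytensor.tensor import tensor
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import cholesky

    # Batched Cholesky: the (3, 3) core Op is wrapped in a Blockwise over the leading dim
    x = tensor("x", shape=(5, 3, 3))
    out = cholesky(x)
    assert isinstance(out.owner.op, Blockwise)

    # With mode="NUMBA", the introduce_explicit_core_shape_blockwise rewrite adds the
    # explicit core shape inputs and the numba dispatch JIT-compiles the core Op
    fn = pytensor.function([x], out, mode="NUMBA")

    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    print(fn(x_test).shape)  # (5, 3, 3)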