Skip to content

Commit 56637af

Browse files
ricardoV94 and lucianopaz
authored and committed
Implement vectorize_node dispatch for some forms of Join
1 parent caa580b commit 56637af

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

pytensor/tensor/basic.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pytensor import scalar as ps
2424
from pytensor.gradient import DisconnectedType, grad_undefined
2525
from pytensor.graph import RewriteDatabaseQuery
26-
from pytensor.graph.basic import Apply, Constant, Variable
26+
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
2727
from pytensor.graph.fg import FunctionGraph
2828
from pytensor.graph.op import Op
2929
from pytensor.graph.replace import _vectorize_node
@@ -42,7 +42,7 @@
4242
as_tensor_variable,
4343
get_vector_length,
4444
)
45-
from pytensor.tensor.blockwise import Blockwise
45+
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
4646
from pytensor.tensor.elemwise import (
4747
DimShuffle,
4848
Elemwise,
@@ -2662,6 +2662,36 @@ def join(axis, *tensors_list):
26622662
return join_(axis, *tensors_list)
26632663

26642664

2665+
@_vectorize_node.register(Join)
2666+
def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
2667+
original_axis, *old_inputs = node.inputs
2668+
# We can vectorize join as a shifted axis on the batch inputs if:
2669+
# 1. The batch axis is a constant and has not changed
2670+
# 2. All inputs are batched with the same broadcastable pattern
2671+
if (
2672+
original_axis.type.ndim == 0
2673+
and isinstance(original_axis, Constant)
2674+
and equal_computations([original_axis], [batch_axis])
2675+
):
2676+
batch_ndims = {
2677+
batch_input.type.ndim - old_input.type.ndim
2678+
for batch_input, old_input in zip(batch_inputs, old_inputs)
2679+
}
2680+
if len(batch_ndims) == 1:
2681+
[batch_ndim] = batch_ndims
2682+
batch_bcast = batch_inputs[0].type.broadcastable[:batch_ndim]
2683+
if all(
2684+
batch_input.type.broadcastable[:batch_ndim] == batch_bcast
2685+
for batch_input in batch_inputs[1:]
2686+
):
2687+
original_ndim = node.outputs[0].type.ndim
2688+
original_axis = normalize_axis_index(original_axis.data, original_ndim)
2689+
batch_axis = original_axis + batch_ndim
2690+
return op.make_node(batch_axis, *batch_inputs)
2691+
2692+
return vectorize_node_fallback(op, node, batch_axis, *batch_inputs)
2693+
2694+
26652695
def roll(x, shift, axis=None):
26662696
"""
26672697
Convenience function to roll TensorTypes along the given axis.

tests/tensor/test_basic.py

+35
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytensor.tensor.basic as ptb
1111
import pytensor.tensor.math as ptm
1212
from pytensor import compile, config, function, shared
13+
from pytensor.compile import SharedVariable
1314
from pytensor.compile.io import In, Out
1415
from pytensor.compile.mode import Mode, get_default_mode
1516
from pytensor.compile.ops import DeepCopyOp
@@ -4565,3 +4566,37 @@ def core_np(x):
45654566
vectorize_pt(x_test),
45664567
vectorize_np(x_test),
45674568
)
4569+
4570+
4571+
@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
4572+
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
4573+
@config.change_flags(cxx="") # C code not needed
4574+
def test_vectorize_join(axis, broadcasting_y):
4575+
# Signature for join along intermediate axis
4576+
signature = "(a,b1,c),(a,b2,c)->(a,b,c)"
4577+
4578+
def core_pt(x, y):
4579+
return join(axis, x, y)
4580+
4581+
def core_np(x, y):
4582+
return np.concatenate([x, y], axis=axis.eval())
4583+
4584+
x = tensor(shape=(4, 2, 3, 5))
4585+
y_shape = {"none": (4, 2, 3, 5), "implicit": (2, 3, 5), "explicit": (1, 2, 3, 5)}
4586+
y = tensor(shape=y_shape[broadcasting_y])
4587+
4588+
vectorize_pt = function([x, y], vectorize(core_pt, signature=signature)(x, y))
4589+
4590+
blockwise_needed = isinstance(axis, SharedVariable) or broadcasting_y != "none"
4591+
has_blockwise = any(
4592+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
4593+
)
4594+
assert has_blockwise == blockwise_needed
4595+
4596+
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
4597+
y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
4598+
vectorize_np = np.vectorize(core_np, signature=signature)
4599+
np.testing.assert_allclose(
4600+
vectorize_pt(x_test, y_test),
4601+
vectorize_np(x_test, y_test),
4602+
)

0 commit comments

Comments
 (0)