Skip to content

Commit 10d0741

Browse files
committed
Fix scalar case in XElemwise
1 parent 5a7b23c commit 10d0741

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pytensor/xtensor/vectorization.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ def make_node(self, *inputs):
3636
# Keep the non-None shape
3737
dims_and_shape[dim] = dim_length
3838

39-
output_dims, output_shape = zip(*dims_and_shape.items())
39+
if dims_and_shape:
40+
output_dims, output_shape = zip(*dims_and_shape.items())
41+
else:
42+
output_dims, output_shape = (), ()
4043

4144
dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
4245
output_dtypes = [

tests/xtensor/test_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@
1414
from tests.xtensor.util import xr_assert_allclose, xr_function
1515

1616

17+
def test_scalar_case():
18+
x = xtensor("x", dims=(), shape=())
19+
y = xtensor("y", dims=(), shape=())
20+
out = add(x, y)
21+
22+
fn = function([x, y], out)
23+
24+
x_test = DataArray(2.0, dims=())
25+
y_test = DataArray(3.0, dims=())
26+
np.testing.assert_allclose(fn(x_test.values, y_test.values), 5.0)
27+
28+
1729
def test_dimension_alignment():
1830
x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4))
1931
y = xtensor(

0 commit comments

Comments
 (0)