Skip to content

Commit b79d232

Browse files
Implement ufunc_outer like add.outer for binary Elemwise operations
1 parent 35ae5db commit b79d232

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

pytensor/tensor/elemwise.py

+10
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,16 @@ def c_code_cache_version_apply(self, node):
11701170
else:
11711171
return ()
11721172

1173+
def outer(self, x, y):
1174+
from pytensor.tensor.basic import expand_dims
1175+
1176+
if self.scalar_op.nin not in (-1, 2):
1177+
raise NotImplementedError("outer is only available for binary operators")
1178+
1179+
x_ = expand_dims(x, tuple(range(-y.ndim, 0)))
1180+
y_ = expand_dims(y, tuple(range(x.ndim)))
1181+
return self(x_, y_)
1182+
11731183

11741184
class CAReduce(COp):
11751185
"""Reduces a scalar operation along specified axes.

tests/tensor/test_elemwise.py

+21
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
import pytensor
1010
import pytensor.scalar as ps
11+
import pytensor.tensor as pt
1112
import tests.unittest_tools as utt
13+
from pytensor.compile.function import function
1214
from pytensor.compile.mode import Mode
1315
from pytensor.configdefaults import config
1416
from pytensor.graph.basic import Apply, Variable
@@ -893,6 +895,25 @@ def test_invalid_static_shape(self):
893895
):
894896
x + y
895897

898+
@pytest.mark.parametrize(
899+
"shape_x, shape_y, op, np_op",
900+
[
901+
((3, 5), (7, 1, 3), pt.add, np.add),
902+
((2, 3), (1, 4), pt.mul, np.multiply),
903+
],
904+
)
905+
def test_outer(self, shape_x, shape_y, op, np_op):
906+
x = tensor(dtype=np.float64, shape=shape_x)
907+
y = tensor(dtype=np.float64, shape=shape_y)
908+
909+
z = op.outer(x, y)
910+
911+
f = function([x, y], z)
912+
x1 = np.ones(shape_x)
913+
y1 = np.ones(shape_y)
914+
915+
np.testing.assert_array_equal(f(x1, y1), np_op.outer(x1, y1))
916+
896917

897918
def test_not_implemented_elemwise_grad():
898919
# Regression test for unimplemented gradient in an Elemwise Op.

0 commit comments

Comments
 (0)