Skip to content

Commit bf50423

Browse files
committed
Implements shape and MakeVector Ops in PyTorch
- Shape - Shape_i - Reshape - SpecifyShape - Unbroadcast - MakeVector
1 parent a6b9585 commit bf50423

File tree

6 files changed

+149
-12
lines changed

6 files changed

+149
-12
lines changed

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
import pytensor.link.pytorch.dispatch.elemwise
77
import pytensor.link.pytorch.dispatch.extra_ops
88
import pytensor.link.pytorch.dispatch.sort
9+
import pytensor.link.pytorch.dispatch.shape
910
# isort: on

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.link.utils import fgraph_to_python
88
from pytensor.raise_op import CheckAndRaise
9-
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
9+
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
1010

1111

1212
@singledispatch
1313
def pytorch_typify(data, dtype=None, **kwargs):
1414
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
15-
return torch.as_tensor(data, dtype=dtype)
15+
if data is not None:
16+
return torch.as_tensor(data, dtype=dtype)
17+
return None
1618

1719

1820
@singledispatch
@@ -116,3 +118,13 @@ def eye(N, M, k):
116118
return zeros
117119

118120
return eye
121+
122+
123+
@pytorch_funcify.register(MakeVector)
124+
def pytorch_funcify_MakeVector(op, **kwargs):
125+
torch_dtype = getattr(torch, op.dtype)
126+
127+
def makevector(*x):
128+
return torch.tensor(x, dtype=torch_dtype)
129+
130+
return makevector
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
5+
6+
7+
@pytorch_funcify.register(Reshape)
8+
def pytorch_funcify_Reshape(op, node, **kwargs):
9+
shape = node.inputs[1]
10+
11+
def reshape(x, shape=shape):
12+
return torch.reshape(x, tuple(shape))
13+
14+
return reshape
15+
16+
17+
@pytorch_funcify.register(Shape)
18+
def pytorch_funcify_Shape(op, **kwargs):
19+
def shape(x):
20+
return x.shape
21+
22+
return shape
23+
24+
25+
@pytorch_funcify.register(Shape_i)
26+
def pytorch_funcify_Shape_i(op, **kwargs):
27+
i = op.i
28+
29+
def shape_i(x):
30+
return x.shape[i]
31+
32+
return shape_i
33+
34+
35+
@pytorch_funcify.register(SpecifyShape)
36+
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
37+
def specifyshape(x, *shape):
38+
assert x.ndim == len(shape)
39+
for actual, expected in zip(x.shape, shape):
40+
if expected is None:
41+
continue
42+
if actual != expected:
43+
raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
44+
return x
45+
46+
return specifyshape
47+
48+
49+
@pytorch_funcify.register(Unbroadcast)
50+
def pytorch_funcify_Unbroadcast(op, **kwargs):
51+
def unbroadcast(x):
52+
return x
53+
54+
return unbroadcast

tests/link/pytorch/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,10 @@ def test_eye(dtype):
294294
for _M in range(1, 6):
295295
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
296296
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
297+
298+
299+
def test_pytorch_MakeVector():
300+
x = ptb.make_vector(1, 2, 3)
301+
x_fg = FunctionGraph([], [x])
302+
303+
compare_pytorch_and_py(x_fg, [])

tests/link/pytorch/test_shape.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import numpy as np
2+
3+
import pytensor.tensor as pt
4+
from pytensor.compile.ops import DeepCopyOp, ViewOp
5+
from pytensor.configdefaults import config
6+
from pytensor.graph.fg import FunctionGraph
7+
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
8+
from pytensor.tensor.type import iscalar, vector
9+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
10+
11+
12+
def test_pytorch_shape_ops():
13+
x_np = np.zeros((20, 3))
14+
x = Shape()(pt.as_tensor_variable(x_np))
15+
x_fg = FunctionGraph([], [x])
16+
17+
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
18+
19+
x = Shape_i(1)(pt.as_tensor_variable(x_np))
20+
x_fg = FunctionGraph([], [x])
21+
22+
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
23+
24+
25+
def test_pytorch_specify_shape():
26+
in_pt = pt.matrix("in")
27+
x = pt.specify_shape(in_pt, (4, None))
28+
x_fg = FunctionGraph([in_pt], [x])
29+
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
30+
31+
# When used to assert two arrays have similar shapes
32+
in_pt = pt.matrix("in")
33+
shape_pt = pt.matrix("shape")
34+
x = pt.specify_shape(in_pt, shape_pt.shape)
35+
x_fg = FunctionGraph([in_pt, shape_pt], [x])
36+
compare_pytorch_and_py(
37+
x_fg,
38+
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
39+
)
40+
41+
42+
def test_pytorch_Reshape_constant():
43+
a = vector("a")
44+
x = reshape(a, (2, 2))
45+
x_fg = FunctionGraph([a], [x])
46+
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
47+
48+
49+
def test_pytorch_Reshape_shape_graph_input():
50+
a = vector("a")
51+
shape_pt = iscalar("b")
52+
x = reshape(a, (shape_pt, shape_pt))
53+
x_fg = FunctionGraph([a, shape_pt], [x])
54+
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
55+
56+
57+
def test_pytorch_compile_ops():
58+
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
59+
x_fg = FunctionGraph([], [x])
60+
61+
compare_pytorch_and_py(x_fg, [])
62+
63+
x_np = np.zeros((20, 1, 1))
64+
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
65+
x_fg = FunctionGraph([], [x])
66+
67+
compare_pytorch_and_py(x_fg, [])
68+
69+
x = ViewOp()(pt.as_tensor_variable(x_np))
70+
x_fg = FunctionGraph([], [x])
71+
72+
compare_pytorch_and_py(x_fg, [])

tests/link/pytorch/test_sort.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,7 @@
88

99

1010
@pytest.mark.parametrize("func", (sort, argsort))
11-
@pytest.mark.parametrize(
12-
"axis",
13-
[
14-
pytest.param(0),
15-
pytest.param(1),
16-
pytest.param(
17-
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
18-
),
19-
],
20-
)
11+
@pytest.mark.parametrize("axis", [0, 1, None])
2112
def test_sort(func, axis):
2213
x = matrix("x", shape=(2, 2), dtype="float64")
2314
out = func(x, axis=axis)

0 commit comments

Comments
 (0)