We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8dbe5ca commit 62e453dCopy full SHA for 62e453d
tests/link/pytorch/test_extra_ops.py
@@ -43,18 +43,29 @@ def test_pytorch_CumOp(axis, dtype):
43
compare_pytorch_and_py(fgraph, [test_value])
44
45
46
-@pytest.mark.parametrize("axis", [0, 1])
47
-def test_pytorch_Repeat(axis):
+@pytest.mark.parametrize(
+ "axis, repeats",
48
+ [
49
+ (0, (1, 2, 3)),
50
+ (1, (3, 3)),
51
+ pytest.param(
52
+ None,
53
+ 3,
54
+ marks=pytest.mark.xfail(reason="Reshape not implemented"),
55
+ ),
56
+ ],
57
+)
58
+def test_pytorch_Repeat(axis, repeats):
59
a = pt.matrix("a", dtype="float64")
60
61
test_value = np.arange(6, dtype="float64").reshape((3, 2))
62
- out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
63
+ out = pt.repeat(a, repeats, axis=axis)
64
fgraph = FunctionGraph([a], [out])
65
66
67
68
+@pytest.mark.parametrize("axis", [None, 0, 1])
69
def test_pytorch_Unique_axis(axis):
70
71
0 commit comments