Skip to content

Commit 07ad9ad

Browse files
committed
Specify some args for matrix funcs uninspectable in torch
1 parent 4431440 commit 07ad9ad

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

array_api_tests/test_signatures.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,11 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
131131
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
132132
"iinfo": {"type": "xp.int64"},
133133
"finfo": {"type": "xp.float64"},
134-
"logaddexp": {a: "xp.ones((5,), dtype=xp.float64)" for a in ["x1", "x2"]},
134+
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
135+
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
136+
"solve": {
137+
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
138+
},
135139
},
136140
)
137141
# We default most array arguments heuristically. As functions/methods work only
@@ -163,7 +167,7 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
163167
if func_name in casty_names:
164168
shape = ()
165169
elif func_name in matrixy_names:
166-
shape = (2, 2)
170+
shape = (3, 3)
167171
else:
168172
shape = (5,)
169173
fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
@@ -179,9 +183,6 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
179183

180184

181185
def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
182-
if func_name in matrixy_names:
183-
pytest.xfail("TODO")
184-
185186
params = list(stub_sig.parameters.values())
186187

187188
if len(params) == 0:

0 commit comments

Comments
 (0)