Skip to content

Commit d7bfe73

Browse files
authored
Merge pull request #309 from ev-br/irfftn_shapes
switch back testing of irfftn shapes
2 parents 298ba5b + 07bc1c2 commit d7bfe73

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

array_api_tests/test_fft.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,17 @@ def test_irfftn(x, data):
227227
expected=dh.dtype_components[x.dtype],
228228
)
229229

230-
# TODO: assert shape correctly
231-
# _axes = sh.normalize_axis(axes, x.ndim)
232-
# _s = x.shape if s is None else s
233-
# expected = []
234-
# for i in range(x.ndim):
235-
# if i in _axes:
236-
# side = _s[_axes.index(i)]
237-
# else:
238-
# side = x.shape[i]
239-
# expected.append(side)
240-
# last_axis = max(_axes)
241-
# expected[last_axis] = _s[_axes.index(last_axis)] // 2 + 1
242-
# ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
230+
_axes = sh.normalize_axis(axes, x.ndim)
231+
_s = x.shape if s is None else s
232+
expected = []
233+
for i in range(x.ndim):
234+
if i in _axes:
235+
side = _s[_axes.index(i)]
236+
else:
237+
side = x.shape[i]
238+
expected.append(side)
239+
expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1]
240+
ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
243241

244242

245243
@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())

0 commit comments

Comments
 (0)