@@ -130,44 +130,15 @@ def test_prod(x, data):
130
130
out = xp .prod (x , ** kw )
131
131
132
132
dtype = kw .get ("dtype" , None )
133
- if dtype is None :
134
- if dh .is_int_dtype (x .dtype ):
135
- if x .dtype in dh .uint_dtypes :
136
- default_dtype = dh .default_uint
137
- else :
138
- default_dtype = dh .default_int
139
- if default_dtype is None :
140
- _dtype = None
141
- else :
142
- m , M = dh .dtype_ranges [x .dtype ]
143
- d_m , d_M = dh .dtype_ranges [default_dtype ]
144
- if m < d_m or M > d_M :
145
- _dtype = x .dtype
146
- else :
147
- _dtype = default_dtype
148
- elif dh .is_float_dtype (x .dtype , include_complex = False ):
149
- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
150
- _dtype = x .dtype
151
- else :
152
- _dtype = dh .default_float
153
- elif api_version > "2021.12" :
154
- # Complex dtype
155
- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_complex ]:
156
- _dtype = x .dtype
157
- else :
158
- _dtype = dh .default_complex
159
- else :
160
- raise RuntimeError ("Unexpected dtype. This indicates a bug in the test suite." )
161
- else :
162
- _dtype = dtype
163
- if _dtype is None :
133
+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
134
+ if expected_dtype is None :
164
135
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
165
136
# uint32 or uint64), we skip testing the output dtype.
166
137
# See https://github.com/data-apis/array-api-tests/issues/106
167
138
if x .dtype in dh .uint_dtypes :
168
139
assert dh .is_int_dtype (out .dtype ) # sanity check
169
140
else :
170
- ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
141
+ ph .assert_dtype ("prod" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
171
142
_axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
172
143
ph .assert_keepdimable_shape (
173
144
"prod" , in_shape = x .shape , out_shape = out .shape , axes = _axes , keepdims = keepdims , kw = kw
@@ -246,44 +217,15 @@ def test_sum(x, data):
246
217
out = xp .sum (x , ** kw )
247
218
248
219
dtype = kw .get ("dtype" , None )
249
- if dtype is None :
250
- if dh .is_int_dtype (x .dtype ):
251
- if x .dtype in dh .uint_dtypes :
252
- default_dtype = dh .default_uint
253
- else :
254
- default_dtype = dh .default_int
255
- if default_dtype is None :
256
- _dtype = None
257
- else :
258
- m , M = dh .dtype_ranges [x .dtype ]
259
- d_m , d_M = dh .dtype_ranges [default_dtype ]
260
- if m < d_m or M > d_M :
261
- _dtype = x .dtype
262
- else :
263
- _dtype = default_dtype
264
- elif dh .is_float_dtype (x .dtype , include_complex = False ):
265
- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_float ]:
266
- _dtype = x .dtype
267
- else :
268
- _dtype = dh .default_float
269
- elif api_version > "2021.12" :
270
- # Complex dtype
271
- if dh .dtype_nbits [x .dtype ] > dh .dtype_nbits [dh .default_complex ]:
272
- _dtype = x .dtype
273
- else :
274
- _dtype = dh .default_complex
275
- else :
276
- raise RuntimeError ("Unexpected dtype. This indicates a bug in the test suite." )
277
- else :
278
- _dtype = dtype
279
- if _dtype is None :
220
+ expected_dtype = dh .accumulation_result_dtype (x .dtype , dtype )
221
+ if expected_dtype is None :
280
222
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
281
223
# uint32 or uint64), we skip testing the output dtype.
282
224
# See https://github.com/data-apis/array-api-tests/issues/160
283
225
if x .dtype in dh .uint_dtypes :
284
226
assert dh .is_int_dtype (out .dtype ) # sanity check
285
227
else :
286
- ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = _dtype )
228
+ ph .assert_dtype ("sum" , in_dtype = x .dtype , out_dtype = out .dtype , expected = expected_dtype )
287
229
_axes = sh .normalise_axis (kw .get ("axis" , None ), x .ndim )
288
230
ph .assert_keepdimable_shape (
289
231
"sum" , in_shape = x .shape , out_shape = out .shape , axes = _axes , keepdims = keepdims , kw = kw
0 commit comments