@@ -310,10 +310,37 @@ def test_add(ctx, data):
310
310
311
311
assert_binary_param_dtype (ctx , left , right , res )
312
312
assert_binary_param_shape (ctx , left , right , res )
313
- if not ctx .right_is_scalar :
314
- # add is commutative
315
- expected = ctx .func (right , left )
316
- ah .assert_exactly_equal (res , expected )
313
+ m , M = dh .dtype_ranges [res .dtype ]
314
+ scalar_type = dh .get_scalar_type (res .dtype )
315
+ if ctx .right_is_scalar :
316
+ for idx in sh .ndindex (res .shape ):
317
+ scalar_l = scalar_type (left [idx ])
318
+ expected = scalar_l + right
319
+ if not math .isfinite (expected ) or expected <= m or expected >= M :
320
+ continue
321
+ scalar_o = scalar_type (res [idx ])
322
+ f_l = sh .fmt_idx (ctx .left_sym , idx )
323
+ f_o = sh .fmt_idx (ctx .res_name , idx )
324
+ assert isclose (scalar_o , expected ), (
325
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } + { right } )={ expected } "
326
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
327
+ )
328
+ else :
329
+ ph .assert_array (ctx .func_name , res , ctx .func (right , left )) # cumulative
330
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
331
+ scalar_l = scalar_type (left [l_idx ])
332
+ scalar_r = scalar_type (right [r_idx ])
333
+ expected = scalar_l + scalar_r
334
+ if not math .isfinite (expected ) or expected <= m or expected >= M :
335
+ continue
336
+ scalar_o = scalar_type (res [o_idx ])
337
+ f_l = sh .fmt_idx (ctx .left_sym , l_idx )
338
+ f_r = sh .fmt_idx (ctx .right_sym , r_idx )
339
+ f_o = sh .fmt_idx (ctx .res_name , o_idx )
340
+ assert isclose (scalar_o , expected ), (
341
+ f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } + { f_r } )={ expected } "
342
+ f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
343
+ )
317
344
318
345
319
346
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -1487,9 +1514,9 @@ def test_sign(x):
1487
1514
expr = f"({ f_x } / |{ f_x } |)={ expected } "
1488
1515
scalar_o = scalar_type (out [idx ])
1489
1516
f_o = sh .fmt_idx ("out" , idx )
1490
- assert scalar_o == expected , (
1491
- f" { f_o } = { scalar_o } , but should be { expr } [sign()] \n { f_x } = { scalar_x } "
1492
- )
1517
+ assert (
1518
+ scalar_o == expected
1519
+ ), f" { f_o } = { scalar_o } , but should be { expr } [sign()] \n { f_x } = { scalar_x } "
1493
1520
1494
1521
1495
1522
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
0 commit comments