Skip to content

Commit e8a9270

Browse files
committed
Better values testing for bitwise op/elwise tests
1 parent 1e228bc commit e8a9270

File tree

1 file changed

+202
-109
lines changed

1 file changed

+202
-109
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 202 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -437,31 +437,52 @@ def test_bitwise_and(
437437
res = func(left, right)
438438

439439
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
440-
if not right_is_scalar:
441-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
442-
shape = broadcast_shapes(left.shape, right.shape)
443-
ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape")
444-
_left = xp.broadcast_to(left, shape)
445-
_right = xp.broadcast_to(right, shape)
446-
447-
# Compare against the Python & operator.
448-
if res.dtype == xp.bool:
449-
for idx in sh.ndindex(res.shape):
450-
s_left = bool(_left[idx])
451-
s_right = bool(_right[idx])
452-
s_res = bool(res[idx])
453-
assert (s_left and s_right) == s_res
454-
else:
455-
for idx in sh.ndindex(res.shape):
456-
s_left = int(_left[idx])
457-
s_right = int(_right[idx])
458-
s_res = int(res[idx])
459-
s_and = ah.int_to_dtype(
460-
s_left & s_right,
440+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
441+
scalar_type = dh.get_scalar_type(res.dtype)
442+
if right_is_scalar:
443+
for idx in sh.ndindex(res.shape):
444+
scalar_l = scalar_type(left[idx])
445+
if res.dtype == xp.bool:
446+
expected = scalar_l and right
447+
else:
448+
# for mypy
449+
assert isinstance(scalar_l, int)
450+
assert isinstance(right, int)
451+
expected = ah.int_to_dtype(
452+
scalar_l & right,
453+
dh.dtype_nbits[res.dtype],
454+
dh.dtype_signed[res.dtype],
455+
)
456+
scalar_o = scalar_type(res[idx])
457+
f_l = sh.fmt_idx(left_sym, idx)
458+
f_o = sh.fmt_idx(res_name, idx)
459+
assert scalar_o == expected, (
460+
f"{f_o}={scalar_o}, but should be ({f_l} & {right})={expected} "
461+
f"[{func_name}()]\n{f_l}={scalar_l}"
462+
)
463+
else:
464+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
465+
scalar_l = scalar_type(left[l_idx])
466+
scalar_r = scalar_type(right[r_idx])
467+
if res.dtype == xp.bool:
468+
expected = scalar_l and scalar_r
469+
else:
470+
# for mypy
471+
assert isinstance(scalar_l, int)
472+
assert isinstance(scalar_r, int)
473+
expected = ah.int_to_dtype(
474+
scalar_l & scalar_r,
461475
dh.dtype_nbits[res.dtype],
462476
dh.dtype_signed[res.dtype],
463477
)
464-
assert s_and == s_res
478+
scalar_o = scalar_type(res[o_idx])
479+
f_l = sh.fmt_idx(left_sym, l_idx)
480+
f_r = sh.fmt_idx(right_sym, r_idx)
481+
f_o = sh.fmt_idx(res_name, o_idx)
482+
assert scalar_o == expected, (
483+
f"{f_o}={scalar_o}, but should be ({f_l} & {f_r})={expected} "
484+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
485+
)
465486

466487

467488
@pytest.mark.parametrize(
@@ -489,25 +510,41 @@ def test_bitwise_left_shift(
489510
res = func(left, right)
490511

491512
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
492-
if not right_is_scalar:
493-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
494-
shape = broadcast_shapes(left.shape, right.shape)
495-
ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape")
496-
_left = xp.broadcast_to(left, shape)
497-
_right = xp.broadcast_to(right, shape)
498-
499-
# Compare against the Python << operator.
513+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
514+
if right_is_scalar:
500515
for idx in sh.ndindex(res.shape):
501-
s_left = int(_left[idx])
502-
s_right = int(_right[idx])
503-
s_res = int(res[idx])
504-
s_shift = ah.int_to_dtype(
516+
scalar_l = int(left[idx])
517+
expected = ah.int_to_dtype(
505518
# We avoid shifting very large ints
506-
s_left << s_right if s_right < dh.dtype_nbits[res.dtype] else 0,
519+
scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0,
507520
dh.dtype_nbits[res.dtype],
508521
dh.dtype_signed[res.dtype],
509522
)
510-
assert s_shift == s_res
523+
scalar_o = int(res[idx])
524+
f_l = sh.fmt_idx(left_sym, idx)
525+
f_o = sh.fmt_idx(res_name, idx)
526+
assert scalar_o == expected, (
527+
f"{f_o}={scalar_o}, but should be ({f_l} << {right})={expected} "
528+
f"[{func_name}()]\n{f_l}={scalar_l}"
529+
)
530+
else:
531+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
532+
scalar_l = int(left[l_idx])
533+
scalar_r = int(right[r_idx])
534+
expected = ah.int_to_dtype(
535+
# We avoid shifting very large ints
536+
scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0,
537+
dh.dtype_nbits[res.dtype],
538+
dh.dtype_signed[res.dtype],
539+
)
540+
scalar_o = int(res[o_idx])
541+
f_l = sh.fmt_idx(left_sym, l_idx)
542+
f_r = sh.fmt_idx(right_sym, r_idx)
543+
f_o = sh.fmt_idx(res_name, o_idx)
544+
assert scalar_o == expected, (
545+
f"{f_o}={scalar_o}, but should be ({f_l} << {f_r})={expected} "
546+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
547+
)
511548

512549

513550
@pytest.mark.parametrize(
@@ -522,20 +559,23 @@ def test_bitwise_invert(func_name, func, strat, data):
522559

523560
ph.assert_dtype(func_name, x.dtype, out.dtype)
524561
ph.assert_shape(func_name, out.shape, x.shape)
525-
# Compare against the Python ~ operator.
526-
if out.dtype == xp.bool:
527-
for idx in sh.ndindex(out.shape):
528-
s_x = bool(x[idx])
529-
s_out = bool(out[idx])
530-
assert (not s_x) == s_out
531-
else:
532-
for idx in sh.ndindex(out.shape):
533-
s_x = int(x[idx])
534-
s_out = int(out[idx])
535-
s_invert = ah.int_to_dtype(
536-
~s_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]
562+
for idx in sh.ndindex(out.shape):
563+
if out.dtype == xp.bool:
564+
scalar_x = bool(x[idx])
565+
scalar_o = bool(out[idx])
566+
expected = not scalar_x
567+
else:
568+
scalar_x = int(x[idx])
569+
scalar_o = int(out[idx])
570+
expected = ah.int_to_dtype(
571+
~scalar_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]
537572
)
538-
assert s_invert == s_out
573+
f_x = sh.fmt_idx("x", idx)
574+
f_o = sh.fmt_idx("out", idx)
575+
assert scalar_o == expected, (
576+
f"{f_o}={scalar_o}, but should be ~{f_x}={scalar_x} "
577+
f"[{func_name}()]\n{f_x}={scalar_x}"
578+
)
539579

540580

541581
@pytest.mark.parametrize(
@@ -559,31 +599,50 @@ def test_bitwise_or(
559599
res = func(left, right)
560600

561601
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
562-
if not right_is_scalar:
563-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
564-
shape = broadcast_shapes(left.shape, right.shape)
565-
ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape")
566-
_left = xp.broadcast_to(left, shape)
567-
_right = xp.broadcast_to(right, shape)
568-
569-
# Compare against the Python | operator.
570-
if res.dtype == xp.bool:
571-
for idx in sh.ndindex(res.shape):
572-
s_left = bool(_left[idx])
573-
s_right = bool(_right[idx])
574-
s_res = bool(res[idx])
575-
assert (s_left or s_right) == s_res
576-
else:
577-
for idx in sh.ndindex(res.shape):
578-
s_left = int(_left[idx])
579-
s_right = int(_right[idx])
580-
s_res = int(res[idx])
581-
s_or = ah.int_to_dtype(
582-
s_left | s_right,
602+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
603+
if right_is_scalar:
604+
for idx in sh.ndindex(res.shape):
605+
if res.dtype == xp.bool:
606+
scalar_l = bool(left[idx])
607+
scalar_o = bool(res[idx])
608+
expected = scalar_l or right
609+
else:
610+
scalar_l = int(left[idx])
611+
scalar_o = int(res[idx])
612+
expected = ah.int_to_dtype(
613+
scalar_l | right,
614+
dh.dtype_nbits[res.dtype],
615+
dh.dtype_signed[res.dtype],
616+
)
617+
f_l = sh.fmt_idx(left_sym, idx)
618+
f_o = sh.fmt_idx(res_name, idx)
619+
assert scalar_o == expected, (
620+
f"{f_o}={scalar_o}, but should be ({f_l} | {right})={expected} "
621+
f"[{func_name}()]\n{f_l}={scalar_l}"
622+
)
623+
else:
624+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
625+
if res.dtype == xp.bool:
626+
scalar_l = bool(left[l_idx])
627+
scalar_r = bool(right[r_idx])
628+
scalar_o = bool(res[o_idx])
629+
expected = scalar_l or scalar_r
630+
else:
631+
scalar_l = int(left[l_idx])
632+
scalar_r = int(right[r_idx])
633+
scalar_o = int(res[o_idx])
634+
expected = ah.int_to_dtype(
635+
scalar_l | scalar_r,
583636
dh.dtype_nbits[res.dtype],
584637
dh.dtype_signed[res.dtype],
585638
)
586-
assert s_or == s_res
639+
f_l = sh.fmt_idx(left_sym, l_idx)
640+
f_r = sh.fmt_idx(right_sym, r_idx)
641+
f_o = sh.fmt_idx(res_name, o_idx)
642+
assert scalar_o == expected, (
643+
f"{f_o}={scalar_o}, but should be ({f_l} | {f_r})={expected} "
644+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
645+
)
587646

588647

589648
@pytest.mark.parametrize(
@@ -611,24 +670,39 @@ def test_bitwise_right_shift(
611670
res = func(left, right)
612671

613672
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
614-
if not right_is_scalar:
615-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
616-
shape = broadcast_shapes(left.shape, right.shape)
617-
ph.assert_shape(
618-
"bitwise_right_shift", res.shape, shape, repr_name=f"{res_name}.shape"
619-
)
620-
_left = xp.broadcast_to(left, shape)
621-
_right = xp.broadcast_to(right, shape)
622-
623-
# Compare against the Python >> operator.
673+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
674+
if right_is_scalar:
624675
for idx in sh.ndindex(res.shape):
625-
s_left = int(_left[idx])
626-
s_right = int(_right[idx])
627-
s_res = int(res[idx])
628-
s_shift = ah.int_to_dtype(
629-
s_left >> s_right, dh.dtype_nbits[res.dtype], dh.dtype_signed[res.dtype]
676+
scalar_l = int(left[idx])
677+
expected = ah.int_to_dtype(
678+
scalar_l >> right,
679+
dh.dtype_nbits[res.dtype],
680+
dh.dtype_signed[res.dtype],
681+
)
682+
scalar_o = int(res[idx])
683+
f_l = sh.fmt_idx(left_sym, idx)
684+
f_o = sh.fmt_idx(res_name, idx)
685+
assert scalar_o == expected, (
686+
f"{f_o}={scalar_o}, but should be ({f_l} >> {right})={expected} "
687+
f"[{func_name}()]\n{f_l}={scalar_l}"
688+
)
689+
else:
690+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
691+
scalar_l = int(left[l_idx])
692+
scalar_r = int(right[r_idx])
693+
expected = ah.int_to_dtype(
694+
scalar_l >> scalar_r,
695+
dh.dtype_nbits[res.dtype],
696+
dh.dtype_signed[res.dtype],
697+
)
698+
scalar_o = int(res[o_idx])
699+
f_l = sh.fmt_idx(left_sym, l_idx)
700+
f_r = sh.fmt_idx(right_sym, r_idx)
701+
f_o = sh.fmt_idx(res_name, o_idx)
702+
assert scalar_o == expected, (
703+
f"{f_o}={scalar_o}, but should be ({f_l} >> {f_r})={expected} "
704+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
630705
)
631-
assert s_shift == s_res
632706

633707

634708
@pytest.mark.parametrize(
@@ -652,31 +726,50 @@ def test_bitwise_xor(
652726
res = func(left, right)
653727

654728
assert_binary_param_dtype(func_name, left, right, right_is_scalar, res, res_name)
655-
if not right_is_scalar:
656-
# TODO: generate indices without broadcasting arrays (see test_equal comment)
657-
shape = broadcast_shapes(left.shape, right.shape)
658-
ph.assert_shape(func_name, res.shape, shape, repr_name=f"{res_name}.shape")
659-
_left = xp.broadcast_to(left, shape)
660-
_right = xp.broadcast_to(right, shape)
661-
662-
# Compare against the Python ^ operator.
663-
if res.dtype == xp.bool:
664-
for idx in sh.ndindex(res.shape):
665-
s_left = bool(_left[idx])
666-
s_right = bool(_right[idx])
667-
s_res = bool(res[idx])
668-
assert (s_left ^ s_right) == s_res
669-
else:
670-
for idx in sh.ndindex(res.shape):
671-
s_left = int(_left[idx])
672-
s_right = int(_right[idx])
673-
s_res = int(res[idx])
674-
s_xor = ah.int_to_dtype(
675-
s_left ^ s_right,
729+
assert_binary_param_shape(func_name, left, right, right_is_scalar, res, res_name)
730+
if right_is_scalar:
731+
for idx in sh.ndindex(res.shape):
732+
if res.dtype == xp.bool:
733+
scalar_l = bool(left[idx])
734+
scalar_o = bool(res[idx])
735+
expected = scalar_l ^ right
736+
else:
737+
scalar_l = int(left[idx])
738+
scalar_o = int(res[idx])
739+
expected = ah.int_to_dtype(
740+
scalar_l ^ right,
676741
dh.dtype_nbits[res.dtype],
677742
dh.dtype_signed[res.dtype],
678743
)
679-
assert s_xor == s_res
744+
f_l = sh.fmt_idx(left_sym, idx)
745+
f_o = sh.fmt_idx(res_name, idx)
746+
assert scalar_o == expected, (
747+
f"{f_o}={scalar_o}, but should be ({f_l} ^ {right})={expected} "
748+
f"[{func_name}()]\n{f_l}={scalar_l}"
749+
)
750+
else:
751+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
752+
if res.dtype == xp.bool:
753+
scalar_l = bool(left[l_idx])
754+
scalar_r = bool(right[r_idx])
755+
scalar_o = bool(res[o_idx])
756+
expected = scalar_l ^ scalar_r
757+
else:
758+
scalar_l = int(left[l_idx])
759+
scalar_r = int(right[r_idx])
760+
scalar_o = int(res[o_idx])
761+
expected = ah.int_to_dtype(
762+
scalar_l ^ scalar_r,
763+
dh.dtype_nbits[res.dtype],
764+
dh.dtype_signed[res.dtype],
765+
)
766+
f_l = sh.fmt_idx(left_sym, l_idx)
767+
f_r = sh.fmt_idx(right_sym, r_idx)
768+
f_o = sh.fmt_idx(res_name, o_idx)
769+
assert scalar_o == expected, (
770+
f"{f_o}={scalar_o}, but should be ({f_l} ^ {f_r})={expected} "
771+
f"[{func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
772+
)
680773

681774

682775
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)