Skip to content

Commit fef3a07

Browse files
committed
MAINT: un-xfail TestTypes
1 parent bd0ac5e commit fef3a07

File tree

2 files changed

+23
-209
lines changed

2 files changed

+23
-209
lines changed

torch_np/_ndarray.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,24 @@ def wrapped(x, *args, **kwds):
494494

495495

496496
def can_cast(from_, to, casting="safe"):
497-
from_ = from_.dtype if isinstance(from_, ndarray) else _dtypes.dtype(from_)
498-
to_ = to.dtype if isinstance(to, ndarray) else _dtypes.dtype(to)
497+
from_ = _extract_dtype(from_)
498+
to_ = extract_dtype(to_)
499499

500500
return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting)
501501

502+
'''
503+
# XXX: merge with _dtypes.can_cast. The Q is who converts from ndarray, if needed.
504+
try:
505+
from_dtype = asarray(from_).dtype
506+
except (TypeError, RuntimeError):
507+
# not an array_like; try convering to a dtype
508+
from_dtype = _dtypes.dtype(from_)
509+
510+
try:
511+
to_dtype = asarray(to).dtype
512+
except (TypeError, RuntimeError):
513+
to_dtype = _dtypes.dtype(to)
514+
'''
502515

503516
def _extract_dtype(entry):
504517
try:

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 8 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,6 @@ def test_warnings(self):
743743
assert_("underflow" in str(w[-1].message))
744744

745745

746-
@pytest.mark.xfail(reason="TODO")
747746
class TestTypes:
748747
def check_promotion_cases(self, promote_func):
749748
# tests that the scalars get coerced correctly.
@@ -765,12 +764,9 @@ def check_promotion_cases(self, promote_func):
765764
assert_equal(promote_func(b, u8), np.dtype(np.uint8))
766765
assert_equal(promote_func(i8, u8), np.dtype(np.int16))
767766
assert_equal(promote_func(u8, i32), np.dtype(np.int32))
768-
assert_equal(promote_func(i64, u32), np.dtype(np.int64))
769-
assert_equal(promote_func(u64, i32), np.dtype(np.float64))
770767
assert_equal(promote_func(i32, f32), np.dtype(np.float64))
771768
assert_equal(promote_func(i64, f32), np.dtype(np.float64))
772769
assert_equal(promote_func(f32, i16), np.dtype(np.float32))
773-
assert_equal(promote_func(f32, u32), np.dtype(np.float64))
774770
assert_equal(promote_func(f32, c64), np.dtype(np.complex64))
775771
assert_equal(promote_func(c128, f32), np.dtype(np.complex128))
776772

@@ -779,14 +775,7 @@ def check_promotion_cases(self, promote_func):
779775
assert_equal(promote_func(np.array([b]), u8), np.dtype(np.uint8))
780776
assert_equal(promote_func(np.array([b]), i32), np.dtype(np.int32))
781777
assert_equal(promote_func(np.array([i8]), i64), np.dtype(np.int8))
782-
assert_equal(promote_func(u64, np.array([i32])), np.dtype(np.int32))
783-
assert_equal(promote_func(np.int32(-1), np.array([u64])),
784-
np.dtype(np.float64))
785778
assert_equal(promote_func(f64, np.array([f32])), np.dtype(np.float32))
786-
assert_equal(promote_func(fld, np.array([f32])), np.dtype(np.float32))
787-
assert_equal(promote_func(np.array([f64]), fld), np.dtype(np.float64))
788-
assert_equal(promote_func(fld, np.array([c64])),
789-
np.dtype(np.complex64))
790779
assert_equal(promote_func(c64, np.array([f64])),
791780
np.dtype(np.complex128))
792781
assert_equal(promote_func(np.complex64(3j), np.array([f64])),
@@ -797,7 +786,6 @@ def check_promotion_cases(self, promote_func):
797786
assert_equal(promote_func(np.array([b]), f64), np.dtype(np.float64))
798787
assert_equal(promote_func(np.array([b]), i64), np.dtype(np.int64))
799788
assert_equal(promote_func(np.array([i8]), f64), np.dtype(np.float64))
800-
assert_equal(promote_func(np.array([u16]), f64), np.dtype(np.float64))
801789

802790
# float and complex are treated as the same "kind" for
803791
# the purposes of array-scalar promotion, so that you can do
@@ -856,169 +844,29 @@ def res_type(a, b):
856844

857845
def test_result_type(self):
858846
self.check_promotion_cases(np.result_type)
847+
848+
@pytest.mark.skip(reason='array(None) not supported')
849+
def test_tesult_type_2(self):
859850
assert_(np.result_type(None) == np.dtype(None))
860851

852+
@pytest.mark.skip(reason='no endianness in dtypes')
861853
def test_promote_types_endian(self):
862854
# promote_types should always return native-endian types
863855
assert_equal(np.promote_types('<i8', '<i8'), np.dtype('i8'))
864856
assert_equal(np.promote_types('>i8', '>i8'), np.dtype('i8'))
865857

866-
assert_equal(np.promote_types('>i8', '>U16'), np.dtype('U21'))
867-
assert_equal(np.promote_types('<i8', '<U16'), np.dtype('U21'))
868-
assert_equal(np.promote_types('>U16', '>i8'), np.dtype('U21'))
869-
assert_equal(np.promote_types('<U16', '<i8'), np.dtype('U21'))
870-
871-
def test_can_cast_and_promote_usertypes(self):
872-
# The rational type defines safe casting for signed integers,
873-
# boolean. Rational itself *does* cast safely to double.
874-
# (rational does not actually cast to all signed integers, e.g.
875-
# int64 can be both long and longlong and it registers only the first)
876-
valid_types = ["int8", "int16", "int32", "int64", "bool"]
877-
invalid_types = "BHILQP" + "FDG" + "mM" + "f" + "V"
878-
879-
rational_dt = np.dtype(rational)
880-
for numpy_dtype in valid_types:
881-
numpy_dtype = np.dtype(numpy_dtype)
882-
assert np.can_cast(numpy_dtype, rational_dt)
883-
assert np.promote_types(numpy_dtype, rational_dt) is rational_dt
884-
885-
for numpy_dtype in invalid_types:
886-
numpy_dtype = np.dtype(numpy_dtype)
887-
assert not np.can_cast(numpy_dtype, rational_dt)
888-
with pytest.raises(TypeError):
889-
np.promote_types(numpy_dtype, rational_dt)
890-
891-
double_dt = np.dtype("double")
892-
assert np.can_cast(rational_dt, double_dt)
893-
assert np.promote_types(double_dt, rational_dt) is double_dt
894-
895-
@pytest.mark.parametrize("swap", ["", "swap"])
896-
@pytest.mark.parametrize("string_dtype", ["U", "S"])
897-
def test_promote_types_strings(self, swap, string_dtype):
898-
if swap == "swap":
899-
promote_types = lambda a, b: np.promote_types(b, a)
900-
else:
901-
promote_types = np.promote_types
902-
903-
S = string_dtype
904-
905-
# Promote numeric with unsized string:
906-
assert_equal(promote_types('bool', S), np.dtype(S+'5'))
907-
assert_equal(promote_types('b', S), np.dtype(S+'4'))
908-
assert_equal(promote_types('u1', S), np.dtype(S+'3'))
909-
assert_equal(promote_types('u2', S), np.dtype(S+'5'))
910-
assert_equal(promote_types('u4', S), np.dtype(S+'10'))
911-
assert_equal(promote_types('u8', S), np.dtype(S+'20'))
912-
assert_equal(promote_types('i1', S), np.dtype(S+'4'))
913-
assert_equal(promote_types('i2', S), np.dtype(S+'6'))
914-
assert_equal(promote_types('i4', S), np.dtype(S+'11'))
915-
assert_equal(promote_types('i8', S), np.dtype(S+'21'))
916-
# Promote numeric with sized string:
917-
assert_equal(promote_types('bool', S+'1'), np.dtype(S+'5'))
918-
assert_equal(promote_types('bool', S+'30'), np.dtype(S+'30'))
919-
assert_equal(promote_types('b', S+'1'), np.dtype(S+'4'))
920-
assert_equal(promote_types('b', S+'30'), np.dtype(S+'30'))
921-
assert_equal(promote_types('u1', S+'1'), np.dtype(S+'3'))
922-
assert_equal(promote_types('u1', S+'30'), np.dtype(S+'30'))
923-
assert_equal(promote_types('u2', S+'1'), np.dtype(S+'5'))
924-
assert_equal(promote_types('u2', S+'30'), np.dtype(S+'30'))
925-
assert_equal(promote_types('u4', S+'1'), np.dtype(S+'10'))
926-
assert_equal(promote_types('u4', S+'30'), np.dtype(S+'30'))
927-
assert_equal(promote_types('u8', S+'1'), np.dtype(S+'20'))
928-
assert_equal(promote_types('u8', S+'30'), np.dtype(S+'30'))
929-
930-
931-
@pytest.mark.parametrize("dtype",
932-
list(np.typecodes["All"]) +
933-
["i,i", "10i", "S3", "S100", "U3", "U100", rational])
934-
def test_promote_identical_types_metadata(self, dtype):
935-
# The same type passed in twice to promote types always
936-
# preserves metadata
937-
metadata = {1: 1}
938-
dtype = np.dtype(dtype, metadata=metadata)
939-
940-
res = np.promote_types(dtype, dtype)
941-
assert res.metadata == dtype.metadata
942-
943-
# byte-swapping preserves and makes the dtype native:
944-
dtype = dtype.newbyteorder()
945-
if dtype.isnative:
946-
# The type does not have byte swapping
947-
return
948-
949-
res = np.promote_types(dtype, dtype)
950-
951-
# Metadata is (currently) generally lost on byte-swapping (except for
952-
# unicode.
953-
if dtype.char != "U":
954-
assert res.metadata is None
955-
else:
956-
assert res.metadata == metadata
957-
assert res.isnative
958-
959-
@pytest.mark.slow
960-
@pytest.mark.filterwarnings('ignore:Promotion of numbers:FutureWarning')
961-
@pytest.mark.parametrize(["dtype1", "dtype2"],
962-
itertools.product(
963-
list(np.typecodes["All"]) +
964-
["i,i", "S3", "S100", "U3", "U100", rational],
965-
repeat=2))
966-
def test_promote_types_metadata(self, dtype1, dtype2):
967-
"""Metadata handling in promotion does not appear formalized
968-
right now in NumPy. This test should thus be considered to
969-
document behaviour, rather than test the correct definition of it.
970-
971-
This test is very ugly, it was useful for rewriting part of the
972-
promotion, but probably should eventually be replaced/deleted
973-
(i.e. when metadata handling in promotion is better defined).
974-
"""
975-
metadata1 = {1: 1}
976-
metadata2 = {2: 2}
977-
dtype1 = np.dtype(dtype1, metadata=metadata1)
978-
dtype2 = np.dtype(dtype2, metadata=metadata2)
979-
980-
try:
981-
res = np.promote_types(dtype1, dtype2)
982-
except TypeError:
983-
# Promotion failed, this test only checks metadata
984-
return
985-
986-
if res.char not in "USV" or res.names is not None or res.shape != ():
987-
# All except string dtypes (and unstructured void) lose metadata
988-
# on promotion (unless both dtypes are identical).
989-
# At some point structured ones did not, but were restrictive.
990-
assert res.metadata is None
991-
elif res == dtype1:
992-
# If one result is the result, it is usually returned unchanged:
993-
assert res is dtype1
994-
elif res == dtype2:
995-
# dtype1 may have been cast to the same type/kind as dtype2.
996-
# If the resulting dtype is identical we currently pick the cast
997-
# version of dtype1, which lost the metadata:
998-
if np.promote_types(dtype1, dtype2.kind) == dtype2:
999-
res.metadata is None
1000-
else:
1001-
res.metadata == metadata2
1002-
else:
1003-
assert res.metadata is None
1004-
1005-
# Try again for byteswapped version
1006-
dtype1 = dtype1.newbyteorder()
1007-
assert dtype1.metadata == metadata1
1008-
res_bs = np.promote_types(dtype1, dtype2)
1009-
assert res_bs == res
1010-
assert res_bs.metadata == res.metadata
1011-
1012858
def test_can_cast(self):
1013859
assert_(np.can_cast(np.int32, np.int64))
1014860
assert_(np.can_cast(np.float64, complex))
1015861
assert_(not np.can_cast(complex, float))
1016862

1017863
assert_(np.can_cast('i8', 'f8'))
1018864
assert_(not np.can_cast('i8', 'f4'))
1019-
assert_(np.can_cast('i4', 'S11'))
1020865

1021866
assert_(np.can_cast('i8', 'i8', 'no'))
867+
868+
@pytest.mark.skip(reason="no endianness in dtypes")
869+
def test_can_cast_2(self):
1022870
assert_(not np.can_cast('<i8', '>i8', 'no'))
1023871

1024872
assert_(np.can_cast('<i8', '>i8', 'equiv'))
@@ -1032,60 +880,13 @@ def test_can_cast(self):
1032880

1033881
assert_(np.can_cast('<i8', '>u4', 'unsafe'))
1034882

1035-
assert_(np.can_cast('bool', 'S5'))
1036-
assert_(not np.can_cast('bool', 'S4'))
1037-
1038-
assert_(np.can_cast('b', 'S4'))
1039-
assert_(not np.can_cast('b', 'S3'))
1040-
1041-
assert_(np.can_cast('u1', 'S3'))
1042-
assert_(not np.can_cast('u1', 'S2'))
1043-
assert_(np.can_cast('u2', 'S5'))
1044-
assert_(not np.can_cast('u2', 'S4'))
1045-
assert_(np.can_cast('u4', 'S10'))
1046-
assert_(not np.can_cast('u4', 'S9'))
1047-
assert_(np.can_cast('u8', 'S20'))
1048-
assert_(not np.can_cast('u8', 'S19'))
1049-
1050-
assert_(np.can_cast('i1', 'S4'))
1051-
assert_(not np.can_cast('i1', 'S3'))
1052-
assert_(np.can_cast('i2', 'S6'))
1053-
assert_(not np.can_cast('i2', 'S5'))
1054-
assert_(np.can_cast('i4', 'S11'))
1055-
assert_(not np.can_cast('i4', 'S10'))
1056-
assert_(np.can_cast('i8', 'S21'))
1057-
assert_(not np.can_cast('i8', 'S20'))
1058-
1059-
assert_(np.can_cast('bool', 'S5'))
1060-
assert_(not np.can_cast('bool', 'S4'))
1061-
1062-
assert_(np.can_cast('b', 'U4'))
1063-
assert_(not np.can_cast('b', 'U3'))
1064-
1065-
assert_(np.can_cast('u1', 'U3'))
1066-
assert_(not np.can_cast('u1', 'U2'))
1067-
assert_(np.can_cast('u2', 'U5'))
1068-
assert_(not np.can_cast('u2', 'U4'))
1069-
assert_(np.can_cast('u4', 'U10'))
1070-
assert_(not np.can_cast('u4', 'U9'))
1071-
assert_(np.can_cast('u8', 'U20'))
1072-
assert_(not np.can_cast('u8', 'U19'))
1073-
1074-
assert_(np.can_cast('i1', 'U4'))
1075-
assert_(not np.can_cast('i1', 'U3'))
1076-
assert_(np.can_cast('i2', 'U6'))
1077-
assert_(not np.can_cast('i2', 'U5'))
1078-
assert_(np.can_cast('i4', 'U11'))
1079-
assert_(not np.can_cast('i4', 'U10'))
1080-
assert_(np.can_cast('i8', 'U21'))
1081-
assert_(not np.can_cast('i8', 'U20'))
1082-
1083883
assert_raises(TypeError, np.can_cast, 'i4', None)
1084884
assert_raises(TypeError, np.can_cast, None, 'i4')
1085885

1086886
# Also test keyword arguments
1087887
assert_(np.can_cast(from_=np.int32, to=np.int64))
1088888

889+
@pytest.mark.xfail(reason='value-based casting?')
1089890
def test_can_cast_values(self):
1090891
# gh-5917
1091892
for dt in np.sctypes['int'] + np.sctypes['uint']:

0 commit comments

Comments
 (0)