Skip to content

Commit 5610453

Browse files
committed
ENH: crude type promotion in average, xfail several type_coersion tests
1 parent 525ef1a commit 5610453

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

torch_np/_detail/_reductions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ def average(a_tensor, axis, w_tensor):
203203
result_dtype = torch.float64
204204
a_tensor = a_tensor.to(result_dtype)
205205

206+
result_dtype = _dtypes_impl.result_type_impl([a_tensor.dtype, w_tensor.dtype])
207+
if a_tensor.dtype != result_dtype:
208+
a_tensor = a_tensor.to(result_dtype)
209+
if w_tensor.dtype != result_dtype:
210+
w_tensor = w_tensor.to(result_dtype)
211+
206212
# axis
207213
if axis is None:
208214
(a_tensor, w_tensor), axis = _util.axis_none_ravel(

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,7 @@ def check_promotion_cases(self, promote_func):
764764
assert_equal(promote_func(b, u8), np.dtype(np.uint8))
765765
assert_equal(promote_func(i8, u8), np.dtype(np.int16))
766766
assert_equal(promote_func(u8, i32), np.dtype(np.int32))
767-
assert_equal(promote_func(i32, f32), np.dtype(np.float64))
768-
assert_equal(promote_func(i64, f32), np.dtype(np.float64))
767+
769768
assert_equal(promote_func(f32, i16), np.dtype(np.float32))
770769
assert_equal(promote_func(f32, c64), np.dtype(np.complex64))
771770
assert_equal(promote_func(c128, f32), np.dtype(np.complex128))
@@ -774,8 +773,6 @@ def check_promotion_cases(self, promote_func):
774773
assert_equal(promote_func(np.array([b]), i8), np.dtype(np.int8))
775774
assert_equal(promote_func(np.array([b]), u8), np.dtype(np.uint8))
776775
assert_equal(promote_func(np.array([b]), i32), np.dtype(np.int32))
777-
assert_equal(promote_func(np.array([i8]), i64), np.dtype(np.int8))
778-
assert_equal(promote_func(f64, np.array([f32])), np.dtype(np.float32))
779776
assert_equal(promote_func(c64, np.array([f64])),
780777
np.dtype(np.complex128))
781778
assert_equal(promote_func(np.complex64(3j), np.array([f64])),
@@ -787,6 +784,22 @@ def check_promotion_cases(self, promote_func):
787784
assert_equal(promote_func(np.array([b]), i64), np.dtype(np.int64))
788785
assert_equal(promote_func(np.array([i8]), f64), np.dtype(np.float64))
789786

787+
def check_promotion_cases_2(self, promote_func):
788+
# these are failing because of the "scalars do not upcast arrays" rule
789+
# Two first tests (i32 + f32 -> f64, and i64+f32 -> f64) xfail
790+
# until ufuncs implement the proper type promotion (ufunc loops?)
791+
b = np.bool_(0)
792+
i8, i16, i32, i64 = np.int8(0), np.int16(0), np.int32(0), np.int64(0)
793+
u8 = np.uint8(0)
794+
f32, f64 = np.float32(0), np.float64(0)
795+
c64, c128 = np.complex64(0), np.complex128(0)
796+
797+
assert_equal(promote_func(i32, f32), np.dtype(np.float64))
798+
assert_equal(promote_func(i64, f32), np.dtype(np.float64))
799+
800+
assert_equal(promote_func(np.array([i8]), i64), np.dtype(np.int8))
801+
assert_equal(promote_func(f64, np.array([f32])), np.dtype(np.float32))
802+
790803
# float and complex are treated as the same "kind" for
791804
# the purposes of array-scalar promotion, so that you can do
792805
# (0j + float32array) to get a complex64 array instead of
@@ -842,6 +855,13 @@ def res_type(a, b):
842855
# assert_equal(b, [0.0, 1.5])
843856
# assert_equal(b.dtype, np.dtype('f4'))
844857

858+
@pytest.mark.xfail(reason="'Scalars do not upcast arrays' rule")
859+
def test_coercion_2(self):
860+
def res_type(a, b):
861+
return np.add(a, b).dtype
862+
863+
self.check_promotion_cases_2(res_type)
864+
845865
def test_result_type(self):
846866
self.check_promotion_cases(np.result_type)
847867

0 commit comments

Comments
 (0)