File tree 2 files changed +14
-8
lines changed
2 files changed +14
-8
lines changed Original file line number Diff line number Diff line change @@ -36,12 +36,22 @@ def count_nonzero(a, axis=None):
36
36
37
37
def argmax (tensor , axis = None ):
38
38
axis = _util .allow_only_single_axis (axis )
39
+
40
+ if tensor .dtype == torch .bool :
41
+ # RuntimeError: "argmax_cpu" not implemented for 'Bool'
42
+ tensor = tensor .view (torch .int8 )
43
+
39
44
tensor = torch .argmax (tensor , axis )
40
45
return tensor
41
46
42
47
43
48
def argmin (tensor , axis = None ):
44
49
axis = _util .allow_only_single_axis (axis )
50
+
51
+ if tensor .dtype == torch .bool :
52
+ # RuntimeError: "argmin_cpu" not implemented for 'Bool'
53
+ tensor = tensor .view (torch .int8 )
54
+
45
55
tensor = torch .argmin (tensor , axis )
46
56
return tensor
47
57
Original file line number Diff line number Diff line change @@ -404,10 +404,8 @@ def test_combinations(self, data):
404
404
# with suppress_warnings() as sup:
405
405
# sup.filter(RuntimeWarning,
406
406
# "invalid value encountered in reduce")
407
- if np .asarray (arr ).dtype .kind in "cb" :
408
- pytest .xfail (
409
- reason = "'max_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
410
- )
407
+ if np .asarray (arr ).dtype .kind in "c" :
408
+ pytest .xfail (reason = "'max_values_cpu' not implemented for 'ComplexDouble'" )
411
409
412
410
val = np .max (arr )
413
411
@@ -508,10 +506,8 @@ class TestArgmin:
508
506
def test_combinations (self , data ):
509
507
arr , pos = data
510
508
511
- if np .asarray (arr ).dtype .kind in "cb" :
512
- pytest .xfail (
513
- reason = "'min_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
514
- )
509
+ if np .asarray (arr ).dtype .kind in "c" :
510
+ pytest .xfail (reason = "'min_values_cpu' not implemented for 'ComplexDouble'" )
515
511
516
512
# with suppress_warnings() as sup:
517
513
# sup.filter(RuntimeWarning, "invalid value encountered in reduce")
You can’t perform that action at this time.
0 commit comments