@@ -3474,7 +3474,7 @@ def test_linear_interpolation_formula_0d_inputs(self):
3474
3474
assert nfb ._lerp (a , b , t ) == 2.6
3475
3475
3476
3476
3477
- @pytest .mark .xfail (reason = 'TODO: implement' )
3477
+ # @pytest.mark.xfail(reason='TODO: implement')
3478
3478
class TestMedian :
3479
3479
3480
3480
def test_basic (self ):
@@ -3494,7 +3494,11 @@ def test_basic(self):
3494
3494
assert_equal (a [0 ], np .median (a ))
3495
3495
a = np .array ([0.0444502 , 0.141249 , 0.0463301 ])
3496
3496
assert_equal (a [- 1 ], np .median (a ))
3497
+
3498
+ @pytest .mark .xfail (reason = "median: scalar output vs 0-dim" )
3499
+ def test_basic_2 (self ):
3497
3500
# check array scalar result
3501
+ a = np .array ([0.0444502 , 0.141249 , 0.0463301 ])
3498
3502
assert_equal (np .median (a ).ndim , 0 )
3499
3503
a [1 ] = np .nan
3500
3504
assert_equal (np .median (a ).ndim , 0 )
@@ -3590,7 +3594,7 @@ def test_nan_behavior(self):
3590
3594
3591
3595
# no axis
3592
3596
assert_equal (np .median (a ), np .nan )
3593
- assert_equal (np .median (a ).ndim , 0 )
3597
+ # assert_equal(np.median(a).ndim, 0)
3594
3598
3595
3599
# axis0
3596
3600
b = np .median (np .arange (24 , dtype = float ).reshape (2 , 3 , 4 ), 0 )
@@ -3604,12 +3608,29 @@ def test_nan_behavior(self):
3604
3608
b [1 , 2 ] = np .nan
3605
3609
assert_equal (np .median (a , 1 ), b )
3606
3610
3611
+
3612
+ @pytest .mark .xfail (reason = "median: does not support tuple axes" )
3613
+ def test_nan_behavior_2 (self ):
3614
+ a = np .arange (24 , dtype = float ).reshape (2 , 3 , 4 )
3615
+ a [1 , 2 , 3 ] = np .nan
3616
+ a [1 , 1 , 2 ] = np .nan
3617
+
3607
3618
# axis02
3608
3619
b = np .median (np .arange (24 , dtype = float ).reshape (2 , 3 , 4 ), (0 , 2 ))
3609
3620
b [1 ] = np .nan
3610
3621
b [2 ] = np .nan
3611
3622
assert_equal (np .median (a , (0 , 2 )), b )
3612
3623
3624
+ @pytest .mark .xfail (reason = "median: scalar vs 0-dim" )
3625
+ def test_nan_behavior_3 (self ):
3626
+ a = np .arange (24 , dtype = float ).reshape (2 , 3 , 4 )
3627
+ a [1 , 2 , 3 ] = np .nan
3628
+ a [1 , 1 , 2 ] = np .nan
3629
+
3630
+ # no axis
3631
+ assert_equal (np .median (a ).ndim , 0 )
3632
+
3633
+ @pytest .mark .xfail (reason = "median: torch.quantile does not handle empty tensors" )
3613
3634
@pytest .mark .skipif (IS_WASM , reason = "fp errors don't work correctly" )
3614
3635
def test_empty (self ):
3615
3636
# mean(empty array) emits two warnings: empty slice and divide by 0
@@ -3640,6 +3661,7 @@ def test_empty(self):
3640
3661
assert_equal (np .median (a , axis = 2 ), b )
3641
3662
assert_ (w [0 ].category is RuntimeWarning )
3642
3663
3664
+ @pytest .mark .xfail (reason = "median: tuple axes not implemented" )
3643
3665
def test_extended_axis (self ):
3644
3666
o = np .random .normal (size = (71 , 23 ))
3645
3667
x = np .dstack ([o ] * 10 )
@@ -3682,6 +3704,10 @@ def test_keepdims(self):
3682
3704
d = np .ones ((3 , 5 , 7 , 11 ))
3683
3705
assert_equal (np .median (d , axis = None , keepdims = True ).shape ,
3684
3706
(1 , 1 , 1 , 1 ))
3707
+
3708
+ @pytest .mark .xfail (reason = "median: tuple axis" )
3709
+ def test_keepdims_2 (self ):
3710
+ d = np .ones ((3 , 5 , 7 , 11 ))
3685
3711
assert_equal (np .median (d , axis = (0 , 1 ), keepdims = True ).shape ,
3686
3712
(1 , 1 , 7 , 11 ))
3687
3713
assert_equal (np .median (d , axis = (0 , 3 ), keepdims = True ).shape ,
@@ -3693,6 +3719,7 @@ def test_keepdims(self):
3693
3719
assert_equal (np .median (d , axis = (0 , 1 , 3 ), keepdims = True ).shape ,
3694
3720
(1 , 1 , 7 , 1 ))
3695
3721
3722
+ @pytest .mark .xfail (reason = "median: tuple axis" )
3696
3723
@pytest .mark .parametrize (
3697
3724
argnames = 'axis' ,
3698
3725
argvalues = [
0 commit comments