@@ -1353,17 +1353,16 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
1353
1353
"""Mean Tweedie deviance regression loss."""
1354
1354
xp , _ , device_ = get_namespace_and_device (y_true , y_pred )
1355
1355
p = power
1356
- zero = xp .asarray (0 , dtype = y_true .dtype , device = device_ )
1357
1356
if p < 0 :
1358
1357
# 'Extreme stable', y any real number, y_pred > 0
1359
1358
dev = 2 * (
1360
1359
xp .pow (
1361
- xp .where (y_true > 0 , y_true , zero ),
1362
- xp . asarray (2 - p , device = device_ ),
1360
+ xp .where (y_true > 0 , y_true , 0 ),
1361
+ (2 - p ),
1363
1362
)
1364
1363
/ ((1 - p ) * (2 - p ))
1365
- - y_true * xp .pow (y_pred , xp . asarray ( 1 - p , device = device_ ) ) / (1 - p )
1366
- + xp .pow (y_pred , xp . asarray ( 2 - p , device = device_ ) ) / (2 - p )
1364
+ - y_true * xp .pow (y_pred , 1 - p ) / (1 - p )
1365
+ + xp .pow (y_pred , 2 - p ) / (2 - p )
1367
1366
)
1368
1367
elif p == 0 :
1369
1368
# Normal distribution, y and y_pred any real number
@@ -1376,9 +1375,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
1376
1375
dev = 2 * (xp .log (y_pred / y_true ) + y_true / y_pred - 1 )
1377
1376
else :
1378
1377
dev = 2 * (
1379
- xp .pow (y_true , xp . asarray ( 2 - p , device = device_ ) ) / ((1 - p ) * (2 - p ))
1380
- - y_true * xp .pow (y_pred , xp . asarray ( 1 - p , device = device_ ) ) / (1 - p )
1381
- + xp .pow (y_pred , xp . asarray ( 2 - p , device = device_ ) ) / (2 - p )
1378
+ xp .pow (y_true , 2 - p ) / ((1 - p ) * (2 - p ))
1379
+ - y_true * xp .pow (y_pred , 1 - p ) / (1 - p )
1380
+ + xp .pow (y_pred , 2 - p ) / (2 - p )
1382
1381
)
1383
1382
return float (_average (dev , weights = sample_weight ))
1384
1383
0 commit comments