@@ -1351,19 +1351,19 @@ def max_error(y_true, y_pred):
1351
1351
1352
1352
def _mean_tweedie_deviance (y_true , y_pred , sample_weight , power ):
1353
1353
"""Mean Tweedie deviance regression loss."""
1354
- xp , _ = get_namespace (y_true , y_pred )
1354
+ xp , _ , device_ = get_namespace_and_device (y_true , y_pred )
1355
1355
p = power
1356
- zero = xp .asarray (0 , dtype = y_true .dtype , device = y_true . device )
1356
+ zero = xp .asarray (0 , dtype = y_true .dtype , device = device_ )
1357
1357
if p < 0 :
1358
1358
# 'Extreme stable', y any real number, y_pred > 0
1359
1359
dev = 2 * (
1360
1360
xp .pow (
1361
1361
xp .where (y_true > 0 , y_true , zero ),
1362
- xp .asarray (2 - p , device = y_true . device ),
1362
+ xp .asarray (2 - p , device = device_ ),
1363
1363
)
1364
1364
/ ((1 - p ) * (2 - p ))
1365
- - y_true * xp .pow (y_pred , xp .asarray (1 - p , device = y_true . device )) / (1 - p )
1366
- + xp .pow (y_pred , xp .asarray (2 - p , device = y_true . device )) / (2 - p )
1365
+ - y_true * xp .pow (y_pred , xp .asarray (1 - p , device = device_ )) / (1 - p )
1366
+ + xp .pow (y_pred , xp .asarray (2 - p , device = device_ )) / (2 - p )
1367
1367
)
1368
1368
elif p == 0 :
1369
1369
# Normal distribution, y and y_pred any real number
@@ -1376,10 +1376,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
1376
1376
dev = 2 * (xp .log (y_pred / y_true ) + y_true / y_pred - 1 )
1377
1377
else :
1378
1378
dev = 2 * (
1379
- xp .pow (y_true , xp .asarray (2 - p , device = y_true .device ))
1380
- / ((1 - p ) * (2 - p ))
1381
- - y_true * xp .pow (y_pred , xp .asarray (1 - p , device = y_true .device )) / (1 - p )
1382
- + xp .pow (y_pred , xp .asarray (2 - p , device = y_true .device )) / (2 - p )
1379
+ xp .pow (y_true , xp .asarray (2 - p , device = device_ )) / ((1 - p ) * (2 - p ))
1380
+ - y_true * xp .pow (y_pred , xp .asarray (1 - p , device = device_ )) / (1 - p )
1381
+ + xp .pow (y_pred , xp .asarray (2 - p , device = device_ )) / (2 - p )
1383
1382
)
1384
1383
return float (_average (dev , weights = sample_weight ))
1385
1384
0 commit comments