Skip to content

Commit 5e4356a

Browse files
committed
Fix
1 parent cebf42d commit 5e4356a

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

sklearn/metrics/_regression.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -1351,19 +1351,19 @@ def max_error(y_true, y_pred):
13511351

13521352
def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
13531353
"""Mean Tweedie deviance regression loss."""
1354-
xp, _ = get_namespace(y_true, y_pred)
1354+
xp, _, device_ = get_namespace_and_device(y_true, y_pred)
13551355
p = power
1356-
zero = xp.asarray(0, dtype=y_true.dtype, device=y_true.device)
1356+
zero = xp.asarray(0, dtype=y_true.dtype, device=device_)
13571357
if p < 0:
13581358
# 'Extreme stable', y any real number, y_pred > 0
13591359
dev = 2 * (
13601360
xp.pow(
13611361
xp.where(y_true > 0, y_true, zero),
1362-
xp.asarray(2 - p, device=y_true.device),
1362+
xp.asarray(2 - p, device=device_),
13631363
)
13641364
/ ((1 - p) * (2 - p))
1365-
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=y_true.device)) / (1 - p)
1366-
+ xp.pow(y_pred, xp.asarray(2 - p, device=y_true.device)) / (2 - p)
1365+
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=device_)) / (1 - p)
1366+
+ xp.pow(y_pred, xp.asarray(2 - p, device=device_)) / (2 - p)
13671367
)
13681368
elif p == 0:
13691369
# Normal distribution, y and y_pred any real number
@@ -1376,10 +1376,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
13761376
dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
13771377
else:
13781378
dev = 2 * (
1379-
xp.pow(y_true, xp.asarray(2 - p, device=y_true.device))
1380-
/ ((1 - p) * (2 - p))
1381-
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=y_true.device)) / (1 - p)
1382-
+ xp.pow(y_pred, xp.asarray(2 - p, device=y_true.device)) / (2 - p)
1379+
xp.pow(y_true, xp.asarray(2 - p, device=device_)) / ((1 - p) * (2 - p))
1380+
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=device_)) / (1 - p)
1381+
+ xp.pow(y_pred, xp.asarray(2 - p, device=device_)) / (2 - p)
13831382
)
13841383
return float(_average(dev, weights=sample_weight))
13851384

0 commit comments

Comments
 (0)