Skip to content

Commit 920932f

Browse files
committed
Remove xp.asarray(..., device=device_) idioms in _mean_tweedie_deviance
1 parent 9b563b2 commit 920932f

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

sklearn/metrics/_regression.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1353,17 +1353,16 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
13531353
"""Mean Tweedie deviance regression loss."""
13541354
xp, _, device_ = get_namespace_and_device(y_true, y_pred)
13551355
p = power
1356-
zero = xp.asarray(0, dtype=y_true.dtype, device=device_)
13571356
if p < 0:
13581357
# 'Extreme stable', y any real number, y_pred > 0
13591358
dev = 2 * (
13601359
xp.pow(
1361-
xp.where(y_true > 0, y_true, zero),
1362-
xp.asarray(2 - p, device=device_),
1360+
xp.where(y_true > 0, y_true, 0),
1361+
(2 - p),
13631362
)
13641363
/ ((1 - p) * (2 - p))
1365-
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=device_)) / (1 - p)
1366-
+ xp.pow(y_pred, xp.asarray(2 - p, device=device_)) / (2 - p)
1364+
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1365+
+ xp.pow(y_pred, 2 - p) / (2 - p)
13671366
)
13681367
elif p == 0:
13691368
# Normal distribution, y and y_pred any real number
@@ -1376,9 +1375,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
13761375
dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
13771376
else:
13781377
dev = 2 * (
1379-
xp.pow(y_true, xp.asarray(2 - p, device=device_)) / ((1 - p) * (2 - p))
1380-
- y_true * xp.pow(y_pred, xp.asarray(1 - p, device=device_)) / (1 - p)
1381-
+ xp.pow(y_pred, xp.asarray(2 - p, device=device_)) / (2 - p)
1378+
xp.pow(y_true, 2 - p) / ((1 - p) * (2 - p))
1379+
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1380+
+ xp.pow(y_pred, 2 - p) / (2 - p)
13821381
)
13831382
return float(_average(dev, weights=sample_weight))
13841383

0 commit comments

Comments
 (0)