From 6982a6c1e1574ec9646581fad3cfe00f429495ac Mon Sep 17 00:00:00 2001 From: Baek-Donghyeon <71208448+nistring@users.noreply.github.com> Date: Wed, 14 May 2025 07:34:51 +0900 Subject: [PATCH] Update dice.py The dice loss varies with batch sizes, and this issue remains unresolved. https://github.com/qubvel-org/segmentation_models.pytorch/issues/712#issue-1550779281 --- segmentation_models_pytorch/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py index b8baae98..38a074f2 100644 --- a/segmentation_models_pytorch/losses/dice.py +++ b/segmentation_models_pytorch/losses/dice.py @@ -70,7 +70,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: bs = y_true.size(0) num_classes = y_pred.size(1) - dims = (0, 2) + dims = (2) if self.mode == BINARY_MODE: y_true = y_true.view(bs, 1, -1)