Hello there!
First of all, thank you for this great package!
Maybe I am mistaken, but when looking into the implementation of the Dice loss, I noticed that it computes the loss across the whole batch rather than per image. At least for me, this led to extremely bouncy gradients when running loss.backward(). Going over the literature, I also consistently found that the loss should be computed per image and then reduced (in most cases simply by taking the mean). In fact, computing it over the whole batch vs. per image and then averaging gives very different results.
Luckily, the fix is very easy: in this script, within the forward function of class DiceLoss, change line 72 to dims = (2) instead of dims = (0, 2), so the reduction runs over the spatial dimension only rather than over the batch dimension as well.
At least this gives the desired outputs for multiclass problems regardless of batch size (I tested batch sizes 1, 3, and 32, and all worked fine). I didn't test other cases, though.
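To make the difference concrete, here is a minimal NumPy sketch of the two reduction strategies. It assumes a soft Dice formulation over tensors of shape (batch, class, pixels), similar in spirit to the package's DiceLoss, but the function names, the toy data, and the eps smoothing term are my own illustration, not the actual library code:

```python
import numpy as np

def dice_per_batch(pred, target, eps=1e-7):
    # Reduce over batch AND spatial dims together (dims=(0, 2)):
    # one pooled Dice score per class across all images.
    inter = (pred * target).sum(axis=(0, 2))
    card = pred.sum(axis=(0, 2)) + target.sum(axis=(0, 2))
    return float(np.mean(2 * inter / (card + eps)))

def dice_per_image(pred, target, eps=1e-7):
    # Reduce over the spatial dim only (dims=(2)): one Dice score
    # per image and class, then average over both.
    inter = (pred * target).sum(axis=2)
    card = pred.sum(axis=2) + target.sum(axis=2)
    return float(np.mean(2 * inter / (card + eps)))

# Hypothetical toy batch: 2 images, 1 class, 4 pixels each.
pred = np.array([[[1.0, 1.0, 1.0, 1.0]],
                 [[1.0, 0.0, 0.0, 0.0]]])
target = np.array([[[1.0, 1.0, 1.0, 1.0]],
                   [[0.0, 1.0, 0.0, 0.0]]])

print(dice_per_batch(pred, target))  # pooled over the batch: 0.8
print(dice_per_image(pred, target))  # averaged per image: 0.5
```

The first image is predicted perfectly (Dice 1.0) and the second completely missed (Dice 0.0), so the per-image average is 0.5; pooling over the batch lets the easy image dominate and reports 0.8, which is exactly the discrepancy described above.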
Best,
Lasse