Skip to content

Commit 5f98ffc

Browse files
committed
add mask entries where y_true is 0
1 parent af2a2aa commit 5f98ffc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

machine_learning/loss_functions.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,10 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
658658
"""
659659
if len(y_true) != len(y_pred):
660660
raise ValueError("Input arrays must have the same length.")
661-
662-
kl_loss = y_true * np.log(y_true / y_pred)
661+
662+
# Mask y_true is 0 to avoid invalid log calculation
663+
mask = y_true != 0
664+
kl_loss = y_true[mask] * np.log(y_true[mask] / y_pred[mask])
663665
return np.sum(kl_loss)
664666

665667

0 commit comments

Comments
 (0)