Commit 9fc9b0e

Update loss_functions.py
1 parent a21247e commit 9fc9b0e

File tree

1 file changed: +4 -13 lines

machine_learning/loss_functions.py (+4 -13)
@@ -629,15 +629,13 @@ def smooth_l1_loss(y_true: np.ndarray, y_pred: np.ndarray, beta: float = 1.0) ->
     return np.mean(loss)
 
 
-def kullback_leibler_divergence(
-    y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-10
-) -> float:
+def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """
     Calculate the Kullback-Leibler divergence (KL divergence) loss between true labels
     and predicted probabilities.
 
-    KL divergence loss quantifies the dissimilarity between true labels and predicted
-    probabilities. It is often used in training generative models.
+    KL divergence loss quantifies dissimilarity between true labels and predicted
+    probabilities. It's often used in training generative models.
 
     KL = Σ(y_true * ln(y_true / y_pred))
@@ -651,7 +649,6 @@ def kullback_leibler_divergence(
     >>> predicted_probs = np.array([0.3, 0.3, 0.4])
     >>> float(kullback_leibler_divergence(true_labels, predicted_probs))
     0.030478754035472025
-
     >>> true_labels = np.array([0.2, 0.3, 0.5])
     >>> predicted_probs = np.array([0.3, 0.3, 0.4, 0.5])
     >>> kullback_leibler_divergence(true_labels, predicted_probs)
@@ -662,13 +659,7 @@ def kullback_leibler_divergence(
     if len(y_true) != len(y_pred):
         raise ValueError("Input arrays must have the same length.")
 
-    # negligible epsilon to avoid issues with log(0) or division by zero
-    epsilon = 1e-10
-    y_pred = np.clip(y_pred, epsilon, None)
-
-    # calculate KL divergence only where y_true is not zero
-    kl_loss = np.where(y_true != 0, y_true * np.log(y_true / y_pred), 0.0)
-
+    kl_loss = y_true * np.log(y_true / y_pred)
     return np.sum(kl_loss)
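For quick reference, here is a minimal, self-contained sketch of the function as it reads after this commit, with the doctest values from the diff above used as a usage check. The numpy import, the __main__ guard, and the worked-arithmetic comments are additions for runnability, not part of the commit; note that with the epsilon clipping removed, zeros in the inputs now flow straight into np.log.

import numpy as np


def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    # KL = Σ(y_true * ln(y_true / y_pred)), summed over matching entries
    if len(y_true) != len(y_pred):
        raise ValueError("Input arrays must have the same length.")
    # No clipping here: a 0 in y_pred yields inf, and a 0 in y_true
    # yields nan (0 * log(0)), since the old epsilon guard was removed.
    kl_loss = y_true * np.log(y_true / y_pred)
    return np.sum(kl_loss)


if __name__ == "__main__":
    true_labels = np.array([0.2, 0.3, 0.5])
    predicted_probs = np.array([0.3, 0.3, 0.4])
    print(float(kullback_leibler_divergence(true_labels, predicted_probs)))
    # Prints 0.030478754035472025, matching the doctest:
    # 0.2*ln(0.2/0.3) + 0.3*ln(0.3/0.3) + 0.5*ln(0.5/0.4)
    # ≈ -0.0811 + 0.0 + 0.1116 ≈ 0.0305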
