Skip to content

Commit e4eda14

Browse files
authored
Add perplexity loss algorithm (#11028)
1 parent 34eb9c5 commit e4eda14

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

Diff for: machine_learning/loss_functions.py

+92
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,98 @@ def mean_absolute_percentage_error(
379379
return np.mean(absolute_percentage_diff)
380380

381381

382+
def perplexity_loss(
    y_true: np.ndarray, y_pred: np.ndarray, epsilon: float = 1e-7
) -> float:
    """
    Calculate the perplexity for the y_true and y_pred.

    Compute the perplexity, which is useful for evaluating language model
    accuracy in Natural Language Processing (NLP).
    Perplexity is a measure of how certain the model is in its predictions.

    Perplexity Loss = exp(-1/N (Σ ln(p(x)))

    Reference:
    https://en.wikipedia.org/wiki/Perplexity

    Args:
        y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
        y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)
        epsilon: Small floating point number to avoid getting inf for log(0)

    Returns:
        Perplexity loss between y_true and y_pred.

    Raises:
        ValueError: If batch sizes differ, sentence lengths differ, or any
            label index is outside the range [0, vocab_size).

    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    5.0247347775367945
    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Sentence length of y_true and y_pred must be equal.
    >>> y_true = np.array([[1, 4], [2, 11]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Label value must not be greater than vocabulary size.
    >>> y_true = np.array([[1, 4]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Batch size of y_true and y_pred must be equal.
    """

    vocab_size = y_pred.shape[2]

    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError("Batch size of y_true and y_pred must be equal.")
    if y_true.shape[1] != y_pred.shape[1]:
        raise ValueError("Sentence length of y_true and y_pred must be equal.")
    # Valid label indices are 0..vocab_size-1, so a label equal to vocab_size
    # is also invalid (the original `>` check let it through and crashed with
    # an opaque IndexError during one-hot construction).
    if np.max(y_true) >= vocab_size:
        raise ValueError("Label value must not be greater than vocabulary size.")

    # One-hot mask selecting, per position, only the true-class column.
    # Integer-array indexing into the identity matrix vectorizes the build:
    # shape (batch_size, sentence_length, vocab_size).
    filter_matrix = np.eye(vocab_size)[y_true]

    # Probability assigned to the true class at each position, clipped to
    # [epsilon, 1] so log(0) cannot produce -inf.
    true_class_pred = np.sum(y_pred * filter_matrix, axis=2).clip(epsilon, 1)

    # Per-sentence perplexity: exp of the negative mean log-probability.
    perp_losses = np.exp(np.negative(np.mean(np.log(true_class_pred), axis=1)))

    # Average over the batch.
    return np.mean(perp_losses)
472+
473+
382474
if __name__ == "__main__":
383475
import doctest
384476

0 commit comments

Comments
 (0)