Skip to content

Commit 31eb62b

Browse files
committed
moved perplexity loss code
1 parent cfaec83 commit 31eb62b

File tree

2 files changed

+94
-108
lines changed

2 files changed

+94
-108
lines changed

Diff for: machine_learning/loss_functions.py

+94
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,100 @@ def mean_squared_logarithmic_error(y_true: np.ndarray, y_pred: np.ndarray) -> fl
246246
return np.mean(squared_logarithmic_errors)
247247

248248

249+
def perplexity_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Calculate the perplexity for the y_true and y_pred.

    Compute the perplexity, which is useful for evaluating language model
    accuracy in Natural Language Processing (NLP).
    Perplexity is a measure of how certain the model is in its predictions.

    Perplexity Loss = exp(-1/N (Σ ln(p(x)))

    Reference:
    https://en.wikipedia.org/wiki/Perplexity

    Args:
        y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
        y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)

    Returns:
        Perplexity loss between y_true and y_pred.

    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    5.024732177979022
    >>> y_true = np.array([[1, 4], [2, 3]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12],
    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Sentence length of y_true and y_pred must be equal.
    >>> y_true = np.array([[1, 4], [2, 11]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Label value must not be greater than vocabulary size.
    >>> y_true = np.array([[1, 4]])
    >>> y_pred = np.array(
    ...     [[[0.28, 0.19, 0.21 , 0.15, 0.15],
    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
    ... )
    >>> perplexity_loss(y_true, y_pred)
    Traceback (most recent call last):
    ...
    ValueError: Batch size of y_true and y_pred must be equal.
    """

    # Add small constant to avoid getting inf for log(0)
    epsilon = 1e-7

    vocab_size = y_pred.shape[2]

    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError("Batch size of y_true and y_pred must be equal.")
    if y_true.shape[1] != y_pred.shape[1]:
        raise ValueError("Sentence length of y_true and y_pred must be equal.")
    # Labels are 0-indexed, so valid values are 0..vocab_size-1; a label equal
    # to vocab_size would previously slip past a `>` check and crash with an
    # uncaught IndexError during indexing below.
    if np.max(y_true) >= vocab_size:
        raise ValueError("Label value must not be greater than vocabulary size.")

    batch_size, sentence_length = y_true.shape

    # Vectorized advanced indexing: for every (sentence, word) pair pick the
    # predicted probability of the true class.  Equivalent to multiplying by a
    # one-hot filter matrix and summing over the vocab axis, but without
    # materializing the (batch, sentence, vocab) one-hot tensor.
    true_class_pred = y_pred[
        np.arange(batch_size)[:, np.newaxis],
        np.arange(sentence_length),
        y_true,
    ]

    # Per-sentence perplexity: exp of the negative mean log-probability
    # over the words of the sentence.
    perp_losses = np.exp(np.negative(np.mean(np.log(true_class_pred + epsilon), axis=1)))

    # Average the per-sentence perplexities over the batch.
    return np.mean(perp_losses)
341+
342+
249343
if __name__ == "__main__":
250344
import doctest
251345

Diff for: machine_learning/loss_functions/perplexity_loss.py

-108
This file was deleted.

0 commit comments

Comments
 (0)