@@ -246,6 +246,100 @@ def mean_squared_logarithmic_error(y_true: np.ndarray, y_pred: np.ndarray) -> fl
    return np.mean(squared_logarithmic_errors)


+def perplexity_loss(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+    """
+    Calculate the perplexity for the y_true and y_pred.
+
+    Compute the perplexity, which is useful for evaluating language model
+    accuracy in Natural Language Processing (NLP).
+    Perplexity is a measure of how certain the model is in its predictions.
+
+    Perplexity Loss = exp(-1/N * Σ ln(p(x)))
+
+    Reference:
+    https://en.wikipedia.org/wiki/Perplexity
+
+    Args:
+        y_true: Actual label encoded sentences of shape (batch_size, sentence_length)
+        y_pred: Predicted sentences of shape (batch_size, sentence_length, vocab_size)
+
+    Returns:
+        Perplexity loss between y_true and y_pred.
+
+    >>> y_true = np.array([[1, 4], [2, 3]])
+    >>> y_pred = np.array(
+    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
+    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
+    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
+    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
+    ... )
+    >>> perplexity_loss(y_true, y_pred)
+    5.024732177979022
+    >>> y_true = np.array([[1, 4], [2, 3]])
+    >>> y_pred = np.array(
+    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
+    ...       [0.24, 0.19, 0.09, 0.18, 0.27],
+    ...       [0.30, 0.10, 0.20, 0.15, 0.25]],
+    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
+    ...       [0.28, 0.10, 0.33, 0.15, 0.12],
+    ...       [0.30, 0.10, 0.20, 0.15, 0.25]]]
+    ... )
+    >>> perplexity_loss(y_true, y_pred)
+    Traceback (most recent call last):
+    ...
+    ValueError: Sentence length of y_true and y_pred must be equal.
+    >>> y_true = np.array([[1, 4], [2, 11]])
+    >>> y_pred = np.array(
+    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
+    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
+    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
+    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
+    ... )
+    >>> perplexity_loss(y_true, y_pred)
+    Traceback (most recent call last):
+    ...
+    ValueError: Label value must not be greater than vocabulary size.
+    >>> y_true = np.array([[1, 4]])
+    >>> y_pred = np.array(
+    ...     [[[0.28, 0.19, 0.21, 0.15, 0.15],
+    ...       [0.24, 0.19, 0.09, 0.18, 0.27]],
+    ...      [[0.03, 0.26, 0.21, 0.18, 0.30],
+    ...       [0.28, 0.10, 0.33, 0.15, 0.12]]]
+    ... )
+    >>> perplexity_loss(y_true, y_pred)
+    Traceback (most recent call last):
+    ...
+    ValueError: Batch size of y_true and y_pred must be equal.
+    """
+
+    # Add small constant to avoid getting inf for log(0)
+    epsilon = 1e-7
+
+    vocab_size = y_pred.shape[2]
+
+    if y_true.shape[0] != y_pred.shape[0]:
+        raise ValueError("Batch size of y_true and y_pred must be equal.")
+    if y_true.shape[1] != y_pred.shape[1]:
+        raise ValueError("Sentence length of y_true and y_pred must be equal.")
+    if np.max(y_true) > vocab_size:
+        raise ValueError("Label value must not be greater than vocabulary size.")
+
+    # One-hot matrix to select the prediction value only for the true class
+    filter_matrix = np.array(
+        [[list(np.eye(vocab_size)[word]) for word in sentence] for sentence in y_true]
+    )
+
+    # Getting the matrix containing the prediction for only the true class
+    true_class_pred = np.sum(y_pred * filter_matrix, axis=2)
+
+    # Calculating perplexity for each sentence
+    perp_losses = np.exp(
+        np.negative(np.mean(np.log(true_class_pred + epsilon), axis=1))
+    )
+
+    return np.mean(perp_losses)
+
+
if __name__ == "__main__":
    import doctest
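For readers who want to check the formula, below is a minimal sketch (not part of the commit, and assuming the perplexity_loss defined above is in scope) that recomputes the first doctest example from first principles: it picks the probability assigned to each true token via NumPy fancy indexing instead of the one-hot filter matrix, applies exp(-1/N * Σ ln(p(x))) per sentence, and averages over the batch. A uniform-prediction sanity check is included, since the perplexity of a uniform distribution over V classes should be approximately V.

import numpy as np

y_true = np.array([[1, 4], [2, 3]])
y_pred = np.array(
    [[[0.28, 0.19, 0.21, 0.15, 0.15],
      [0.24, 0.19, 0.09, 0.18, 0.27]],
     [[0.03, 0.26, 0.21, 0.18, 0.30],
      [0.28, 0.10, 0.33, 0.15, 0.12]]]
)

# Probability assigned to the true token at every position, shape (batch, sentence).
batch_idx = np.arange(y_true.shape[0])[:, None]
word_idx = np.arange(y_true.shape[1])[None, :]
true_probs = y_pred[batch_idx, word_idx, y_true]

# exp(-1/N * sum(ln p(x))) per sentence, then averaged over the batch.
manual = np.mean(np.exp(-np.mean(np.log(true_probs + 1e-7), axis=1)))
assert np.isclose(manual, perplexity_loss(y_true, y_pred))  # both are ~5.0247

# Sanity check: uniform predictions over a vocabulary of size 5 give perplexity ~5.
uniform_pred = np.full((1, 3, 5), 1.0 / 5)
print(perplexity_loss(np.array([[0, 1, 2]]), uniform_pred))  # ~5.0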