From 95859d3a0cc3e368cdc19c48459788dd163a520c Mon Sep 17 00:00:00 2001 From: AtomicVar Date: Sat, 13 Jan 2024 22:58:01 +0800 Subject: [PATCH] Add connectionist temporal classification (CTC) loss algorithm --- machine_learning/loss_functions.py | 104 +++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/machine_learning/loss_functions.py b/machine_learning/loss_functions.py index 36a760326f3d..327cea06175f 100644 --- a/machine_learning/loss_functions.py +++ b/machine_learning/loss_functions.py @@ -471,6 +471,110 @@ def perplexity_loss( return np.mean(perp_losses) +def connectionist_temporal_classification_loss( + y_true: np.ndarray, y_pred: np.ndarray, blank: int = 0 +): + """ + Calculate the connectionist temporal classification (CTC) loss between the given + log probabilities and targets. + + CTC loss is used in speech recognition, handwriting recognition and other sequence + problems. It's used to get around not knowing the alignment between the input and + the output. + + References: + - https://en.wikipedia.org/wiki/Connectionist_temporal_classification + - https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html + + Parameters: + - y_true: True labels (containing class indices). + - y_pred: Predicted class probabilities for each input timestep. + - blank: Index of the blank token, default: 0. + + Returns: + - CTC loss between y_true and y_pred. + + >>> y_true = np.array([1, 2, 3]) + >>> y_pred = np.array([[0.1, 0.6, 0.1, 0.2], + ... [0.2, 0.1, 0.5, 0.2], + ... [0.2, 0.1, 0.5, 0.2]]) + >>> connectionist_temporal_classification_loss(y_true, y_pred) + 2.8134107167600364 + + >>> y_true = np.array([1, 2, 3, 1]) + >>> y_pred = np.random.rand(3, 4) + >>> connectionist_temporal_classification_loss(y_true, y_pred) + Traceback (most recent call last): + ... + ValueError: y_true cannot be longer than y_pred. + + >>> y_true = np.array([[1, 2, 3]]) + >>> y_pred = np.random.rand(3, 4) + >>> connectionist_temporal_classification_loss(y_true, y_pred) + Traceback (most recent call last): + ... + ValueError: y_true should be an 1D array. + + >>> y_true = np.array([1, 2, 3]) + >>> y_pred = np.array([0.1, 0.6, 0.1, 0.2]) + >>> connectionist_temporal_classification_loss(y_true, y_pred) + Traceback (most recent call last): + ... + ValueError: y_pred should be a 2D array. + + >>> y_true = np.array([1, 2, 3]) + >>> y_pred = np.array([[0.1, 0.6, 0.1], [0.2, 0.1, 0.5], [0.2, 0.1, 0.5]]) + >>> connectionist_temporal_classification_loss(y_true, y_pred) + Traceback (most recent call last): + ... + ValueError: Class indices in y_true should be less than y_pred.shape[1]. + """ + if len(y_true) > len(y_pred): + raise ValueError("y_true cannot be longer than y_pred.") + + if y_true.ndim != 1: + raise ValueError("y_true should be an 1D array.") + + if y_pred.ndim != 2: + raise ValueError("y_pred should be a 2D array.") + + if np.max(y_true) >= y_pred.shape[1]: + raise ValueError("Class indices in y_true should be less than y_pred.shape[1].") + + log_probs = np.log(y_pred) + input_len = log_probs.shape[0] # Input sequence length + target_len = len(y_true) # Target sequence length + target_len_extended = 2 * target_len + 1 # Target sequence length with blanks + + # Initialize blank and target sequences + extended_targets = np.full(target_len_extended, blank) + extended_targets[1::2] = y_true + + # Initialize alpha (forward variable) + alpha = np.full((input_len, target_len_extended), -np.inf) + alpha[0, 0] = log_probs[0, blank] # Starting with blank + if target_len_extended > 1: + alpha[0, 1] = log_probs[0, extended_targets[1]] + + # Dynamic programming to calculate alpha + for t in range(1, input_len): + for s in range(target_len_extended): + current_label = extended_targets[s] + alpha[t, s] = alpha[t - 1, s] + if s > 0: + alpha[t, s] = np.logaddexp(alpha[t, s], alpha[t - 1, s - 1]) + if s > 1 and current_label != extended_targets[s - 2]: + alpha[t, s] = np.logaddexp(alpha[t, s], alpha[t - 1, s - 2]) + alpha[t, s] += log_probs[t, current_label] + + # CTC loss is the negative log probability of the target sequence + loss = -np.logaddexp( + alpha[input_len - 1, target_len_extended - 1], + alpha[input_len - 1, target_len_extended - 2], + ) + return loss + + if __name__ == "__main__": import doctest