@@ -2,12 +2,10 @@
 
 import codecs
 import numpy as np
-import tensorflow as tf
 import re
 import sys
 
 from six.moves import range
-from functools import reduce
 
 class Alphabet(object):
     def __init__(self, config_file):
@@ -56,74 +54,42 @@ def size(self):
     def config_file(self):
         return self._config_file
 
+
 def text_to_char_array(original, alphabet):
     r"""
     Given a Python string ``original``, remove unsupported characters, map characters
     to integers and return a numpy array representing the processed string.
     """
     return np.asarray([alphabet.label_from_string(c) for c in original])
 
-def sparse_tuple_from(sequences, dtype=np.int32):
-    r"""Creates a sparse representention of ``sequences``.
-    Args:
-        * sequences: a list of lists of type dtype where each element is a sequence
-
-    Returns a tuple with (indices, values, shape)
-    """
-    indices = []
-    values = []
-
-    for n, seq in enumerate(sequences):
-        indices.extend(zip([n]*len(seq), range(len(seq))))
-        values.extend(seq)
 
-    indices = np.asarray(indices, dtype=np.int64)
-    values = np.asarray(values, dtype=dtype)
-    shape = np.asarray([len(sequences), indices.max(0)[1]+1], dtype=np.int64)
-
-    return tf.SparseTensor(indices=indices, values=values, shape=shape)
-
-def sparse_tensor_value_to_texts(value, alphabet):
-    r"""
-    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
-    representing its values.
-    """
-    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
-
-def sparse_tuple_to_texts(tuple, alphabet):
-    indices = tuple[0]
-    values = tuple[1]
-    results = [''] * tuple[2][0]
-    for i in range(len(indices)):
-        index = indices[i][0]
-        results[index] += alphabet.string_from_label(values[i])
-    # List of strings
-    return results
-
-def wer(original, result):
+def wer_cer_batch(originals, results):
     r"""
     The WER is defined as the editing/Levenshtein distance on word level
     divided by the amount of words in the original text.
     In case of the original having more words (N) than the result and both
     being totally different (all N words resulting in 1 edit operation each),
     the WER will always be 1 (N / N = 1).
     """
-    # The WER ist calculated on word (and NOT on character) level.
-    # Therefore we split the strings into words first:
-    original = original.split()
-    result = result.split()
-    return levenshtein(original, result) / float(len(original))
-
-def wers(originals, results):
-    count = len(originals)
-    rates = []
-    mean = 0.0
-    assert count == len(results)
-    for i in range(count):
-        rate = wer(originals[i], results[i])
-        mean = mean + rate
-        rates.append(rate)
-    return rates, mean / float(count)
+    # The WER is calculated on word (and NOT on character) level.
+    # Therefore we split the strings into words first
+    assert len(originals) == len(results)
+
+    total_cer = 0.0
+    total_char_length = 0.0
+
+    total_wer = 0.0
+    total_word_length = 0.0
+
+    for original, result in zip(originals, results):
+        total_cer += levenshtein(original, result)
+        total_char_length += len(original)
+
+        total_wer += levenshtein(original.split(), result.split())
+        total_word_length += len(original.split())
+
+    return total_wer / total_word_length, total_cer / total_char_length
+
 
 # The following code is from: http://hetland.org/coding/python/levenshtein.py
 
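Note: a minimal usage sketch of the new batch metric, not part of the diff; the import path `util.text` and the sample strings are assumptions for illustration.

    # Hypothetical example: two reference transcripts and two decoded results.
    from util.text import wer_cer_batch   # assumed import path

    originals = ["the quick brown fox", "hello world"]
    results = ["the quick brown fx", "hello word"]

    # levenshtein() on word lists counts word-level edits, on raw strings
    # character-level edits, so one call yields both metrics:
    #   WER = 2 word edits / 6 reference words      ~= 0.333
    #   CER = 2 character edits / 30 reference chars ~= 0.067
    wer, cer = wer_cer_batch(originals, results)
    print(wer, cer)
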
@@ -155,55 +121,6 @@ def levenshtein(a,b):
 
     return current[n]
 
-# gather_nd is taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
-#
-# Unfortunately we can't just use tf.gather_nd because it does not have gradients
-# implemented yet, so we need this workaround.
-#
-def gather_nd(params, indices, shape):
-    rank = len(shape)
-    flat_params = tf.reshape(params, [-1])
-    multipliers = [reduce(lambda x, y: x*y, shape[i+1:], 1) for i in range(0, rank)]
-    indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
-    flat_indices = sum([a*b for a,b in zip(multipliers, indices_unpacked)])
-    return tf.gather(flat_params, flat_indices)
-
-# ctc_label_dense_to_sparse is taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
-#
-# The CTC implementation in TensorFlow needs labels in a sparse representation,
-# but sparse data and queues don't mix well, so we store padded tensors in the
-# queue and convert to a sparse representation after dequeuing a batch.
-#
-def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
-    # The second dimension of labels must be equal to the longest label length in the batch
-    correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
-    with tf.control_dependencies([correct_shape_assert]):
-        labels = tf.identity(labels)
-
-    label_shape = tf.shape(labels)
-    num_batches_tns = tf.stack([label_shape[0]])
-    max_num_labels_tns = tf.stack([label_shape[1]])
-    def range_less_than(previous_state, current_input):
-        return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
-
-    init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
-    init = tf.expand_dims(init, 0)
-    dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
-    dense_mask = dense_mask[:, 0, :]
-
-    label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
-                             label_shape)
-    label_ind = tf.boolean_mask(label_array, dense_mask)
-
-    batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), tf.reverse(label_shape, [0])))
-    batch_ind = tf.boolean_mask(batch_array, dense_mask)
-
-    indices = tf.transpose(tf.reshape(tf.concat([batch_ind, label_ind], 0), [2, -1]))
-    shape = [batch_size, tf.reduce_max(label_lengths)]
-    vals_sparse = gather_nd(labels, indices, shape)
-
-    return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
-
 # Validate and normalize transcriptions. Returns a cleaned version of the label
 # or None if it's invalid.
 def validate_label(label):
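Note: `validate_label` appears only as context here; based on its comment (a cleaned label, or None if invalid), a hedged filtering sketch follows. The sample data and surrounding loop are hypothetical, not from the diff.

    # Hypothetical pass that drops samples whose transcript fails validation.
    samples = [("a.wav", "Hello, world!"), ("b.wav", "ok, fine")]
    cleaned = []
    for wav_path, transcript in samples:
        label = validate_label(transcript)
        if label is not None:   # None signals an invalid transcription
            cleaned.append((wav_path, label))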