
Commit 7a14bcc

Clean up and split TensorFlow deps of text.py
1 parent 3378008 commit 7a14bcc


4 files changed: +86 -119 lines changed


evaluate.py

+10 -14

@@ -19,10 +19,11 @@
 from six.moves import zip, range
 from util.audio import audiofile_to_input_vector
 from util.config import Config, initialize_globals
+from util.ctc import ctc_label_dense_to_sparse
 from util.flags import create_flags, FLAGS
 from util.logging import log_error
 from util.preprocess import pmap, preprocess
-from util.text import Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
+from util.text import Alphabet, wer_cer_batch, levenshtein
 
 
 def split_data(dataset, batch_size):
@@ -47,15 +48,14 @@ def pad_to_dense(jagged):
 
 def process_decode_result(item):
     label, decoding, distance, loss = item
-    sample_wer = wer(label, decoding)
+    word_distance = levenshtein(label.split(), decoding.split())
+    word_length = float(len(label.split()))
     return AttrDict({
         'src': label,
         'res': decoding,
         'loss': loss,
         'distance': distance,
-        'wer': sample_wer,
-        'levenshtein': levenshtein(label.split(), decoding.split()),
-        'label_length': float(len(label.split())),
+        'wer': word_distance / word_length,
     })
 
 
@@ -67,19 +67,16 @@ def calculate_report(labels, decodings, distances, losses):
     '''
     samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))
 
-    total_levenshtein = sum(s.levenshtein for s in samples)
-    total_label_length = sum(s.label_length for s in samples)
-
-    # Getting the WER from the accumulated levenshteins and lengths
-    samples_wer = total_levenshtein / total_label_length
+    # Getting the WER and CER from the accumulated edit distances and lengths
+    samples_wer, samples_cer = wer_cer_batch(labels, decodings)
 
     # Order the remaining items by their loss (lowest loss on top)
     samples.sort(key=lambda s: s.loss)
 
     # Then order by WER (highest WER on top)
     samples.sort(key=lambda s: s.wer, reverse=True)
 
-    return samples_wer, samples
+    return samples_wer, samples_cer, samples
 
 
 def evaluate(test_data, inference_graph):
@@ -183,15 +180,14 @@ def create_windows(features):
 
     distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
 
-    wer, samples = calculate_report(ground_truths, predictions, distances, losses)
-    mean_edit_distance = np.mean(distances)
+    wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
     mean_loss = np.mean(losses)
 
     # Take only the first report_count items
    report_samples = itertools.islice(samples, FLAGS.report_count)
 
     print('Test - WER: %f, CER: %f, loss: %f' %
-          (wer, mean_edit_distance, mean_loss))
+          (wer, cer, mean_loss))
     print('-' * 80)
     for sample in report_samples:
         print('WER: %f, CER: %f, loss: %f' %

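Note: the per-sample WER stored by process_decode_result is now simply the word-level Levenshtein distance divided by the number of words in the reference. A minimal sketch of that computation (plain Python, using the repo's util.text.levenshtein; the example strings are made up):

    from util.text import levenshtein

    label = 'the quick brown fox'
    decoding = 'the quick brown box'

    word_distance = levenshtein(label.split(), decoding.split())  # 1 substitution
    word_length = float(len(label.split()))                       # 4 words
    print(word_distance / word_length)                            # per-sample WER: 0.25
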
util/ctc.py

+57

@@ -0,0 +1,57 @@
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+
+from functools import reduce
+from six.moves import range
+
+
+# gather_nd is taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
+#
+# Unfortunately we can't just use tf.gather_nd because it does not have gradients
+# implemented yet, so we need this workaround.
+#
+def gather_nd(params, indices, shape):
+    rank = len(shape)
+    flat_params = tf.reshape(params, [-1])
+    multipliers = [reduce(lambda x, y: x*y, shape[i+1:], 1) for i in range(0, rank)]
+    indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
+    flat_indices = sum([a*b for a,b in zip(multipliers, indices_unpacked)])
+    return tf.gather(flat_params, flat_indices)
+
+
+# ctc_label_dense_to_sparse is taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
+#
+# The CTC implementation in TensorFlow needs labels in a sparse representation,
+# but sparse data and queues don't mix well, so we store padded tensors in the
+# queue and convert to a sparse representation after dequeuing a batch.
+#
+def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
+    # The second dimension of labels must be equal to the longest label length in the batch
+    correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
+    with tf.control_dependencies([correct_shape_assert]):
+        labels = tf.identity(labels)
+
+    label_shape = tf.shape(labels)
+    num_batches_tns = tf.stack([label_shape[0]])
+    max_num_labels_tns = tf.stack([label_shape[1]])
+    def range_less_than(previous_state, current_input):
+        return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
+
+    init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
+    init = tf.expand_dims(init, 0)
+    dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
+    dense_mask = dense_mask[:, 0, :]
+
+    label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
+                             label_shape)
+    label_ind = tf.boolean_mask(label_array, dense_mask)
+
+    batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), tf.reverse(label_shape, [0])))
+    batch_ind = tf.boolean_mask(batch_array, dense_mask)
+
+    indices = tf.transpose(tf.reshape(tf.concat([batch_ind, label_ind], 0), [2, -1]))
+    shape = [batch_size, tf.reduce_max(label_lengths)]
+    vals_sparse = gather_nd(labels, indices, shape)
+
+    return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))

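The new module keeps these TensorFlow-dependent helpers in one place. A minimal usage sketch for the dense-to-sparse conversion (TF 1.x API, to match the code above; the label values are made up):

    import tensorflow as tf

    from util.ctc import ctc_label_dense_to_sparse

    # Two padded transcriptions: lengths 3 and 2, padded to the batch maximum of 3.
    labels = tf.constant([[1, 2, 3],
                          [4, 5, 0]], dtype=tf.int32)
    label_lengths = tf.constant([3, 2], dtype=tf.int32)

    sparse_labels = ctc_label_dense_to_sparse(labels, label_lengths, batch_size=2)

    with tf.Session() as session:
        result = session.run(sparse_labels)
        print(result.indices.tolist())  # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
        print(result.values.tolist())   # [1, 2, 3, 4, 5] -- the padding is dropped

The resulting tf.SparseTensor is the form that tf.nn.ctc_loss expects for its labels argument.
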
util/feeding.py

+1 -1

@@ -4,8 +4,8 @@
 from math import ceil
 from six.moves import range
 from threading import Thread
+from util.ctc import ctc_label_dense_to_sparse
 from util.gpu import get_available_gpus
-from util.text import ctc_label_dense_to_sparse
 
 
 class ModelFeeder(object):

util/text.py

+18 -104

@@ -2,12 +2,10 @@
 
 import codecs
 import numpy as np
-import tensorflow as tf
 import re
 import sys
 
 from six.moves import range
-from functools import reduce
 
 class Alphabet(object):
     def __init__(self, config_file):
@@ -56,74 +54,39 @@ def size(self):
     def config_file(self):
         return self._config_file
 
+
 def text_to_char_array(original, alphabet):
     r"""
     Given a Python string ``original``, remove unsupported characters, map characters
     to integers and return a numpy array representing the processed string.
     """
     return np.asarray([alphabet.label_from_string(c) for c in original])
 
-def sparse_tuple_from(sequences, dtype=np.int32):
-    r"""Creates a sparse representention of ``sequences``.
-    Args:
-        * sequences: a list of lists of type dtype where each element is a sequence
-
-    Returns a tuple with (indices, values, shape)
-    """
-    indices = []
-    values = []
-
-    for n, seq in enumerate(sequences):
-        indices.extend(zip([n]*len(seq), range(len(seq))))
-        values.extend(seq)
 
-    indices = np.asarray(indices, dtype=np.int64)
-    values = np.asarray(values, dtype=dtype)
-    shape = np.asarray([len(sequences), indices.max(0)[1]+1], dtype=np.int64)
-
-    return tf.SparseTensor(indices=indices, values=values, shape=shape)
-
-def sparse_tensor_value_to_texts(value, alphabet):
-    r"""
-    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
-    representing its values.
-    """
-    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
-
-def sparse_tuple_to_texts(tuple, alphabet):
-    indices = tuple[0]
-    values = tuple[1]
-    results = [''] * tuple[2][0]
-    for i in range(len(indices)):
-        index = indices[i][0]
-        results[index] += alphabet.string_from_label(values[i])
-    # List of strings
-    return results
-
-def wer(original, result):
+def wer_cer_batch(originals, results):
     r"""
     The WER is defined as the editing/Levenshtein distance on word level
     divided by the amount of words in the original text.
     In case of the original having more words (N) than the result and both
     being totally different (all N words resulting in 1 edit operation each),
     the WER will always be 1 (N / N = 1).
     """
-    # The WER ist calculated on word (and NOT on character) level.
-    # Therefore we split the strings into words first:
-    original = original.split()
-    result = result.split()
-    return levenshtein(original, result) / float(len(original))
-
-def wers(originals, results):
-    count = len(originals)
-    rates = []
-    mean = 0.0
-    assert count == len(results)
-    for i in range(count):
-        rate = wer(originals[i], results[i])
-        mean = mean + rate
-        rates.append(rate)
-    return rates, mean / float(count)
+    # The WER is calculated on word (and NOT on character) level.
+    # Therefore we split the strings into words first
+    assert len(originals) == len(results)
+
+    total_cer = 0.0
+
+    total_wer = 0.0
+    total_word_length = 0.0
+
+    for original, result in zip(originals, results):
+        total_cer += levenshtein(original, result)
+
+        total_wer += levenshtein(original.split(), result.split())
+        total_word_length += len(original.split())
+
+    return total_wer / total_word_length, total_cer / len(originals)
 
 # The following code is from: http://hetland.org/coding/python/levenshtein.py
 
@@ -155,55 +118,6 @@ def levenshtein(a,b):
 
     return current[n]
 
-# gather_nd is taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
-#
-# Unfortunately we can't just use tf.gather_nd because it does not have gradients
-# implemented yet, so we need this workaround.
-#
-def gather_nd(params, indices, shape):
-    rank = len(shape)
-    flat_params = tf.reshape(params, [-1])
-    multipliers = [reduce(lambda x, y: x*y, shape[i+1:], 1) for i in range(0, rank)]
-    indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
-    flat_indices = sum([a*b for a,b in zip(multipliers, indices_unpacked)])
-    return tf.gather(flat_params, flat_indices)
-
-# ctc_label_dense_to_sparse is taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
-#
-# The CTC implementation in TensorFlow needs labels in a sparse representation,
-# but sparse data and queues don't mix well, so we store padded tensors in the
-# queue and convert to a sparse representation after dequeuing a batch.
-#
-def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
-    # The second dimension of labels must be equal to the longest label length in the batch
-    correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
-    with tf.control_dependencies([correct_shape_assert]):
-        labels = tf.identity(labels)
-
-    label_shape = tf.shape(labels)
-    num_batches_tns = tf.stack([label_shape[0]])
-    max_num_labels_tns = tf.stack([label_shape[1]])
-    def range_less_than(previous_state, current_input):
-        return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
-
-    init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
-    init = tf.expand_dims(init, 0)
-    dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
-    dense_mask = dense_mask[:, 0, :]
-
-    label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
-                             label_shape)
-    label_ind = tf.boolean_mask(label_array, dense_mask)
-
-    batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), tf.reverse(label_shape, [0])))
-    batch_ind = tf.boolean_mask(batch_array, dense_mask)
-
-    indices = tf.transpose(tf.reshape(tf.concat([batch_ind, label_ind], 0), [2, -1]))
-    shape = [batch_size, tf.reduce_max(label_lengths)]
-    vals_sparse = gather_nd(labels, indices, shape)
-
-    return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
-
 # Validate and normalize transcriptions. Returns a cleaned version of the label
 # or None if it's invalid.
 def validate_label(label):

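A minimal sketch of the new batch metrics (plain Python; the example strings are made up). Note that the returned CER is the character-level edit distance averaged over the number of samples, not normalized by transcript length:

    from util.text import wer_cer_batch

    originals = ['hello world', 'good morning']
    results = ['hello word', 'good morning']

    wer, cer = wer_cer_batch(originals, results)
    # Word level: 1 edit ('world' -> 'word') over 4 reference words -> WER 0.25
    # Character level: 1 edit, averaged over 2 samples -> CER 0.5
    print(wer, cer)  # 0.25 0.5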