
Commit 4356704

Merge pull request #1862 from mozilla/split_tf_deps
Clean up and split TensorFlow deps out of util/text.py
2 parents 1741467 + 12c6275 commit 4356704

File tree

3 files changed: +47 -124 lines

  evaluate.py       +18 -16
  util/feeding.py    +8 -4
  util/text.py      +21 -104

evaluate.py (+18 -16)

@@ -22,7 +22,7 @@
 from util.flags import create_flags, FLAGS
 from util.logging import log_error
 from util.preprocess import pmap, preprocess
-from util.text import Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
+from util.text import Alphabet, wer_cer_batch, levenshtein


 def split_data(dataset, batch_size):
@@ -47,15 +47,14 @@ def pad_to_dense(jagged):

 def process_decode_result(item):
     label, decoding, distance, loss = item
-    sample_wer = wer(label, decoding)
+    word_distance = levenshtein(label.split(), decoding.split())
+    word_length = float(len(label.split()))
     return AttrDict({
         'src': label,
         'res': decoding,
         'loss': loss,
         'distance': distance,
-        'wer': sample_wer,
-        'levenshtein': levenshtein(label.split(), decoding.split()),
-        'label_length': float(len(label.split())),
+        'wer': word_distance / word_length,
     })


@@ -67,19 +66,16 @@ def calculate_report(labels, decodings, distances, losses):
     '''
     samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))

-    total_levenshtein = sum(s.levenshtein for s in samples)
-    total_label_length = sum(s.label_length for s in samples)
-
-    # Getting the WER from the accumulated levenshteins and lengths
-    samples_wer = total_levenshtein / total_label_length
+    # Getting the WER and CER from the accumulated edit distances and lengths
+    samples_wer, samples_cer = wer_cer_batch(labels, decodings)

     # Order the remaining items by their loss (lowest loss on top)
     samples.sort(key=lambda s: s.loss)

     # Then order by WER (highest WER on top)
     samples.sort(key=lambda s: s.wer, reverse=True)

-    return samples_wer, samples
+    return samples_wer, samples_cer, samples


 def evaluate(test_data, inference_graph):
@@ -114,7 +110,14 @@ def create_windows(features):
     labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels")
     label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths")

-    sparse_labels = tf.cast(ctc_label_dense_to_sparse(labels_ph, label_lengths_ph, FLAGS.test_batch_size), tf.int32)
+    # We add 1 to all elements of the transcript to avoid any zero values
+    # since we use that as an end-of-sequence token for converting the batch
+    # into a SparseTensor. So here we convert the placeholder back into a
+    # SparseTensor and subtract ones to get the real labels.
+    sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph)
+    neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
+    sparse_labels = tf.sparse_add(sparse_labels, neg_ones)
+
     loss = tf.nn.ctc_loss(labels=sparse_labels,
                           inputs=layers['raw_logits'],
                           sequence_length=inputs['input_lengths'])
@@ -146,7 +149,7 @@ def create_windows(features):

     features = pad_to_dense(batch['features'].values)
     features_len = batch['features_len'].values
-    labels = pad_to_dense(batch['transcript'].values)
+    labels = pad_to_dense(batch['transcript'].values + 1)
     label_lengths = batch['transcript_len'].values

     logits, loss_ = session.run([transposed, loss], feed_dict={
@@ -183,15 +186,14 @@ def create_windows(features):

     distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

-    wer, samples = calculate_report(ground_truths, predictions, distances, losses)
-    mean_edit_distance = np.mean(distances)
+    wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
     mean_loss = np.mean(losses)

     # Take only the first report_count items
     report_samples = itertools.islice(samples, FLAGS.report_count)

     print('Test - WER: %f, CER: %f, loss: %f' %
-          (wer, mean_edit_distance, mean_loss))
+          (wer, cer, mean_loss))
     print('-' * 80)
     for sample in report_samples:
         print('WER: %f, CER: %f, loss: %f' %
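
The added lines in the @@ -114,7 +110,14 @@ hunk rely on a convention introduced by this change: every transcript value is shifted up by 1 before batching, so 0 is reserved for padding / end-of-sequence. Below is a minimal sketch of that dense-to-sparse round trip, assuming TensorFlow 1.x with tf.contrib available; the example batch is invented:

import tensorflow as tf

# Two transcripts, already shifted by +1 and padded with 0; dense_to_sparse
# treats 0 as the end-of-sequence token and drops those entries.
padded = tf.constant([[3, 5, 2, 0, 0],
                      [7, 1, 0, 0, 0]], dtype=tf.int32)

sparse = tf.contrib.layers.dense_to_sparse(padded)
neg_ones = tf.SparseTensor(sparse.indices,
                           -1 * tf.ones_like(sparse.values),
                           sparse.dense_shape)
labels = tf.sparse_add(sparse, neg_ones)  # undo the +1 shift

with tf.Session() as session:
    print(session.run(labels).values)  # [2 4 1 6 0]

The last value shows why the shift is needed: a genuine label 0 survives the round trip (tf.sparse_add keeps zero-valued entries at its default threshold), whereas passing unshifted labels straight to dense_to_sparse would silently drop it.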

util/feeding.py (+8 -4)

@@ -5,7 +5,6 @@
 from six.moves import range
 from threading import Thread
 from util.gpu import get_available_gpus
-from util.text import ctc_label_dense_to_sparse


 class ModelFeeder(object):
@@ -143,11 +142,14 @@ def _populate_batch_queue(self, session, coord):
                                           (features.strides[0], features.strides[0], features.strides[1]),
                                           writeable=False)

+            # We add 1 to all elements of the transcript here to avoid any zero
+            # values since we use that as an end-of-sequence token for converting
+            # the batch into a SparseTensor.
             try:
                 session.run(self._enqueue_op, feed_dict={
                     self._model_feeder.ph_x: features,
                     self._model_feeder.ph_x_length: num_strides,
-                    self._model_feeder.ph_y: transcript,
+                    self._model_feeder.ph_y: transcript + 1,
                     self._model_feeder.ph_y_length: transcript_len
                 })
             except tf.errors.CancelledError:
@@ -173,8 +175,10 @@ def next_batch(self):
         Draw the next batch from from the combined switchable queue.
         '''
         source, source_lengths, target, target_lengths = self._queue.dequeue_many(self._model_feeder.ph_batch_size)
-        sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._model_feeder.ph_batch_size)
-        return source, source_lengths, sparse_labels
+        # Back to sparse, then subtract one to get the real labels
+        sparse_labels = tf.contrib.layers.dense_to_sparse(target)
+        neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
+        return source, source_lengths, tf.sparse_add(sparse_labels, neg_ones)

     def start_queue_threads(self, session, coord):
         '''
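
The feeding side applies the same convention before the transcript enters the queue: as the comment in the hunk above says, 0 is reserved as the end-of-sequence/padding value, so every real label is moved up by 1 at enqueue time and moved back down in next_batch. A rough NumPy-only illustration of the enqueue-side shift follows; pad_to_batch is a hypothetical helper written for this example, not part of the DeepSpeech code:

import numpy as np

def pad_to_batch(transcripts):
    # Pad a list of label sequences into one dense batch, using 0 for padding
    # and shifting every real label by +1 so it can never collide with 0.
    max_len = max(len(t) for t in transcripts)
    batch = np.zeros((len(transcripts), max_len), dtype=np.int32)
    for i, t in enumerate(transcripts):
        batch[i, :len(t)] = np.asarray(t) + 1
    return batch

print(pad_to_batch([[0, 4, 2], [5, 1]]))
# [[1 5 3]
#  [6 2 0]]

Label 0 in the first transcript becomes 1, so the only zeros left in the batch are padding, which is exactly what next_batch assumes when it converts the dequeued tensor back into a SparseTensor.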

util/text.py (+21 -104)

@@ -2,12 +2,10 @@

 import codecs
 import numpy as np
-import tensorflow as tf
 import re
 import sys

 from six.moves import range
-from functools import reduce

 class Alphabet(object):
     def __init__(self, config_file):
@@ -56,74 +54,42 @@ def size(self):
     def config_file(self):
         return self._config_file

+
 def text_to_char_array(original, alphabet):
     r"""
     Given a Python string ``original``, remove unsupported characters, map characters
     to integers and return a numpy array representing the processed string.
     """
     return np.asarray([alphabet.label_from_string(c) for c in original])

-def sparse_tuple_from(sequences, dtype=np.int32):
-    r"""Creates a sparse representention of ``sequences``.
-    Args:
-        * sequences: a list of lists of type dtype where each element is a sequence
-
-    Returns a tuple with (indices, values, shape)
-    """
-    indices = []
-    values = []
-
-    for n, seq in enumerate(sequences):
-        indices.extend(zip([n]*len(seq), range(len(seq))))
-        values.extend(seq)

-    indices = np.asarray(indices, dtype=np.int64)
-    values = np.asarray(values, dtype=dtype)
-    shape = np.asarray([len(sequences), indices.max(0)[1]+1], dtype=np.int64)
-
-    return tf.SparseTensor(indices=indices, values=values, shape=shape)
-
-def sparse_tensor_value_to_texts(value, alphabet):
-    r"""
-    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
-    representing its values.
-    """
-    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
-
-def sparse_tuple_to_texts(tuple, alphabet):
-    indices = tuple[0]
-    values = tuple[1]
-    results = [''] * tuple[2][0]
-    for i in range(len(indices)):
-        index = indices[i][0]
-        results[index] += alphabet.string_from_label(values[i])
-    # List of strings
-    return results
-
-def wer(original, result):
+def wer_cer_batch(originals, results):
     r"""
     The WER is defined as the editing/Levenshtein distance on word level
     divided by the amount of words in the original text.
     In case of the original having more words (N) than the result and both
     being totally different (all N words resulting in 1 edit operation each),
     the WER will always be 1 (N / N = 1).
     """
-    # The WER ist calculated on word (and NOT on character) level.
-    # Therefore we split the strings into words first:
-    original = original.split()
-    result = result.split()
-    return levenshtein(original, result) / float(len(original))
-
-def wers(originals, results):
-    count = len(originals)
-    rates = []
-    mean = 0.0
-    assert count == len(results)
-    for i in range(count):
-        rate = wer(originals[i], results[i])
-        mean = mean + rate
-        rates.append(rate)
-    return rates, mean / float(count)
+    # The WER is calculated on word (and NOT on character) level.
+    # Therefore we split the strings into words first
+    assert len(originals) == len(results)
+
+    total_cer = 0.0
+    total_char_length = 0.0
+
+    total_wer = 0.0
+    total_word_length = 0.0
+
+    for original, result in zip(originals, results):
+        total_cer += levenshtein(original, result)
+        total_char_length += len(original)
+
+        total_wer += levenshtein(original.split(), result.split())
+        total_word_length += len(original.split())
+
+    return total_wer / total_word_length, total_cer / total_char_length
+

 # The following code is from: http://hetland.org/coding/python/levenshtein.py

@@ -155,55 +121,6 @@ def levenshtein(a,b):

     return current[n]

-# gather_nd is taken from https://github.com/tensorflow/tensorflow/issues/206#issuecomment-229678962
-#
-# Unfortunately we can't just use tf.gather_nd because it does not have gradients
-# implemented yet, so we need this workaround.
-#
-def gather_nd(params, indices, shape):
-    rank = len(shape)
-    flat_params = tf.reshape(params, [-1])
-    multipliers = [reduce(lambda x, y: x*y, shape[i+1:], 1) for i in range(0, rank)]
-    indices_unpacked = tf.unstack(tf.transpose(indices, [rank - 1] + list(range(0, rank - 1))))
-    flat_indices = sum([a*b for a,b in zip(multipliers, indices_unpacked)])
-    return tf.gather(flat_params, flat_indices)
-
-# ctc_label_dense_to_sparse is taken from https://github.com/tensorflow/tensorflow/issues/1742#issuecomment-205291527
-#
-# The CTC implementation in TensorFlow needs labels in a sparse representation,
-# but sparse data and queues don't mix well, so we store padded tensors in the
-# queue and convert to a sparse representation after dequeuing a batch.
-#
-def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
-    # The second dimension of labels must be equal to the longest label length in the batch
-    correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
-    with tf.control_dependencies([correct_shape_assert]):
-        labels = tf.identity(labels)
-
-    label_shape = tf.shape(labels)
-    num_batches_tns = tf.stack([label_shape[0]])
-    max_num_labels_tns = tf.stack([label_shape[1]])
-    def range_less_than(previous_state, current_input):
-        return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
-
-    init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
-    init = tf.expand_dims(init, 0)
-    dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
-    dense_mask = dense_mask[:, 0, :]
-
-    label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
-                             label_shape)
-    label_ind = tf.boolean_mask(label_array, dense_mask)
-
-    batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns), tf.reverse(label_shape, [0])))
-    batch_ind = tf.boolean_mask(batch_array, dense_mask)
-
-    indices = tf.transpose(tf.reshape(tf.concat([batch_ind, label_ind], 0), [2, -1]))
-    shape = [batch_size, tf.reduce_max(label_lengths)]
-    vals_sparse = gather_nd(labels, indices, shape)
-
-    return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
-
 # Validate and normalize transcriptions. Returns a cleaned version of the label
 # or None if it's invalid.
 def validate_label(label):
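
With the TensorFlow-dependent helpers removed, util/text.py no longer imports TensorFlow, so the new metric code can be exercised without a TensorFlow installation. A small usage sketch follows (the transcripts are invented; it assumes the module is importable from the repository root):

from util.text import wer_cer_batch, levenshtein

ground_truths = ['the cat sat on the mat', 'hello world']
predictions   = ['the cat sit on mat', 'hello word']

wer, cer = wer_cer_batch(ground_truths, predictions)
print('WER: %f, CER: %f' % (wer, cer))

# levenshtein works on any sequence: a string gives character-level distance,
# a list of words gives word-level distance.
print(levenshtein('kitten', 'sitting'))             # 3
print(levenshtein('a b c'.split(), 'a c'.split()))  # 1

Note that wer_cer_batch accumulates edit distances and reference lengths over the whole batch before dividing, rather than averaging per-sample rates the way the removed wers() did, so long and short transcripts are weighted by their length.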
