Skip to content

Commit f3613da

Browse files
committed
Use tf.contrib.layers.dense_to_sparse instead of util/ctc.py
1 parent 7a14bcc commit f3613da

File tree

3 files changed

+17
-64
lines changed

3 files changed

+17
-64
lines changed

evaluate.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from six.moves import zip, range
2020
from util.audio import audiofile_to_input_vector
2121
from util.config import Config, initialize_globals
22-
from util.ctc import ctc_label_dense_to_sparse
2322
from util.flags import create_flags, FLAGS
2423
from util.logging import log_error
2524
from util.preprocess import pmap, preprocess
@@ -111,7 +110,14 @@ def create_windows(features):
111110
labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels")
112111
label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths")
113112

114-
sparse_labels = tf.cast(ctc_label_dense_to_sparse(labels_ph, label_lengths_ph, FLAGS.test_batch_size), tf.int32)
113+
# We add 1 to all elements of the transcript to avoid any zero values
114+
# since we use that as an end-of-sequence token for converting the batch
115+
# into a SparseTensor. So here we convert the placeholder back into a
116+
# SparseTensor and subtract ones to get the real labels.
117+
sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph)
118+
neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
119+
sparse_labels = tf.sparse_add(sparse_labels, neg_ones)
120+
115121
loss = tf.nn.ctc_loss(labels=sparse_labels,
116122
inputs=layers['raw_logits'],
117123
sequence_length=inputs['input_lengths'])
@@ -143,7 +149,7 @@ def create_windows(features):
143149

144150
features = pad_to_dense(batch['features'].values)
145151
features_len = batch['features_len'].values
146-
labels = pad_to_dense(batch['transcript'].values)
152+
labels = pad_to_dense(batch['transcript'].values + 1)
147153
label_lengths = batch['transcript_len'].values
148154

149155
logits, loss_ = session.run([transposed, loss], feed_dict={

util/ctc.py

-57
This file was deleted.

util/feeding.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from math import ceil
55
from six.moves import range
66
from threading import Thread
7-
from util.ctc import ctc_label_dense_to_sparse
87
from util.gpu import get_available_gpus
98

109

@@ -143,11 +142,14 @@ def _populate_batch_queue(self, session, coord):
143142
(features.strides[0], features.strides[0], features.strides[1]),
144143
writeable=False)
145144

145+
# We add 1 to all elements of the transcript here to avoid any zero
146+
# values since we use that as an end-of-sequence token for converting
147+
# the batch into a SparseTensor.
146148
try:
147149
session.run(self._enqueue_op, feed_dict={
148150
self._model_feeder.ph_x: features,
149151
self._model_feeder.ph_x_length: num_strides,
150-
self._model_feeder.ph_y: transcript,
152+
self._model_feeder.ph_y: transcript + 1,
151153
self._model_feeder.ph_y_length: transcript_len
152154
})
153155
except tf.errors.CancelledError:
@@ -173,8 +175,10 @@ def next_batch(self):
173175
Draw the next batch from the combined switchable queue.
174176
'''
175177
source, source_lengths, target, target_lengths = self._queue.dequeue_many(self._model_feeder.ph_batch_size)
176-
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._model_feeder.ph_batch_size)
177-
return source, source_lengths, sparse_labels
178+
# Back to sparse, then subtract one to get the real labels
179+
sparse_labels = tf.contrib.layers.dense_to_sparse(target)
180+
neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
181+
return source, source_lengths, tf.sparse_add(sparse_labels, neg_ones)
178182

179183
def start_queue_threads(self, session, coord):
180184
'''

0 commit comments

Comments
 (0)