Skip to content

Commit 6ac8f08

Browse files
nkconnorChoiByungWook
authored andcommitted
PipeModeDataset code block example improvements (#346)
1 parent e3dc3b5 commit 6ac8f08

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

src/sagemaker/tensorflow/README.rst

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,6 @@ In your ``entry_point`` script, you can use ``PipeModeDataset`` like a ``Dataset
779779
780780
from sagemaker_tensorflow import PipeModeDataset
781781
782-
ds = PipeModeDataset(channel='training', record_format='TFRecord')
783-
784782
features = {
785783
'data': tf.FixedLenFeature([], tf.string),
786784
'labels': tf.FixedLenFeature([], tf.int64),
@@ -792,12 +790,13 @@ In your ``entry_point`` script, you can use ``PipeModeDataset`` like a ``Dataset
792790
'data': tf.decode_raw(parsed['data'], tf.float64)
793791
}, parsed['labels'])
794792
795-
ds = PipeModeDataset(channel='training', record_format='TFRecord')
796-
num_epochs = 20
797-
ds = ds.repeat(num_epochs)
798-
ds = ds.prefetch(10)
799-
ds = ds.map(parse, num_parallel_calls=10)
800-
ds = ds.batch(64)
793+
def train_input_fn(training_dir, hyperparameters):
794+
ds = PipeModeDataset(channel='training', record_format='TFRecord')
795+
ds = ds.repeat(20)
796+
ds = ds.prefetch(10)
797+
ds = ds.map(parse, num_parallel_calls=10)
798+
ds = ds.batch(64)
799+
return ds
801800
802801
803802
To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your TensorFlow Estimator:

0 commit comments

Comments
 (0)