Skip to content

Commit ff8f23f

Browse files
author
Dan Choi
committed
Add more comments
1 parent 0930094 commit ff8f23f

File tree

2 files changed

+135
-107
lines changed

2 files changed

+135
-107
lines changed

advanced_functionality/tensorflow_bring_your_own/tensorflow_bring_your_own.ipynb

Lines changed: 96 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@
256256
"source": [
257257
"### Building and registering the container\n",
258258
"\n",
259-
"The following shell code shows how to build the container image using `docker build` and push the container image to ECR using `docker push`. This code is also available as the shell script `container/build-and-push.sh`, which you can run as `build-and-push.sh tensorflow_cifar10_example` to build the image `tensorflow_cifar10_example`. \n",
259+
"The following shell code shows how to build the container image using `docker build` and push the container image to ECR using `docker push`. This code is also available as the shell script `container/build-and-push.sh`, which you can run as `build-and-push.sh tensorflow-cifar10-example` to build the image `tensorflow-cifar10-example`. \n",
260260
"\n",
261261
"This code looks for an ECR repository in the account you're using and the current default region (if you're using a SageMaker notebook instance, this is the region where the notebook instance was created). If the repository doesn't exist, the script will create it."
262262
]
@@ -270,7 +270,7 @@
270270
"%%sh\n",
271271
"\n",
272272
"# The name of our algorithm\n",
273-
"algorithm_name=tensorflow_cifar10_example\n",
273+
"algorithm_name=tensorflow-cifar10-example\n",
274274
"\n",
275275
"cd container\n",
276276
"\n",
@@ -321,7 +321,7 @@
321321
"source": [
322322
"## Download the CIFAR-10 dataset\n",
323323
"Our training algorithm is expecting our training data to be in the file format of [TFRecords](https://www.tensorflow.org/guide/datasets), which is a simple record-oriented binary format that many TensorFlow applications use for training data.\n",
324-
"Below is a python script from the official TensorFlow CIFAR10 example, which downloads the CIFAR-10 dataset and converts them into TFRecords."
324+
"Below is a Python script adapted from the [official TensorFlow CIFAR-10 example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator), which downloads the CIFAR-10 dataset and converts them into TFRecords."
325325
]
326326
},
327327
{
@@ -335,11 +335,20 @@
335335
},
336336
{
337337
"cell_type": "code",
338-
"execution_count": null,
338+
"execution_count": 2,
339339
"metadata": {},
340-
"outputs": [],
340+
"outputs": [
341+
{
342+
"name": "stdout",
343+
"output_type": "stream",
344+
"text": [
345+
"eval.tfrecords\ttrain.tfrecords validation.tfrecords\r\n"
346+
]
347+
}
348+
],
341349
"source": [
342-
"! ls /tmp/cifar-10-data "
350+
"# There should be three tfrecords. (eval, train, validation)\n",
351+
"! ls /tmp/cifar-10-data"
343352
]
344353
},
345354
{
@@ -359,7 +368,7 @@
359368
},
360369
{
361370
"cell_type": "code",
362-
"execution_count": 5,
371+
"execution_count": 3,
363372
"metadata": {},
364373
"outputs": [],
365374
"source": [
@@ -387,10 +396,20 @@
387396
},
388397
{
389398
"cell_type": "code",
390-
"execution_count": null,
399+
"execution_count": 4,
391400
"metadata": {},
392-
"outputs": [],
401+
"outputs": [
402+
{
403+
"name": "stdout",
404+
"output_type": "stream",
405+
"text": [
406+
"SageMaker instance route table setup is ok. We are good to go.\r\n",
407+
"SageMaker instance routing for Docker is ok. We are good to go!\r\n"
408+
]
409+
}
410+
],
393411
"source": [
412+
"# Lets set up our SageMaker notebook instance for local mode.\n",
394413
"!/bin/bash ./utils/setup.sh"
395414
]
396415
},
@@ -428,34 +447,46 @@
428447
},
429448
{
430449
"cell_type": "code",
431-
"execution_count": 11,
450+
"execution_count": null,
451+
"metadata": {},
452+
"outputs": [],
453+
"source": [
454+
"! pip install opencv-python"
455+
]
456+
},
457+
{
458+
"cell_type": "code",
459+
"execution_count": 6,
432460
"metadata": {},
433461
"outputs": [
434462
{
435463
"name": "stdout",
436464
"output_type": "stream",
437465
"text": [
438-
"Collecting opencv-python\n",
439-
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/53/e0/21c8964fa8ef50842ebefaa7346a3cf0e37b56c8ecd97ed6bd2dbe577705/opencv_python-3.4.2.17-cp36-cp36m-manylinux1_x86_64.whl (25.0MB)\n",
440-
"\u001b[K 100% |████████████████████████████████| 25.0MB 2.1MB/s eta 0:00:01\n",
441-
"\u001b[?25hRequirement already satisfied: numpy>=1.11.3 in /home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages (from opencv-python) (1.14.5)\n",
442-
"\u001b[31mdistributed 1.21.8 requires msgpack, which is not installed.\u001b[0m\n",
443-
"Installing collected packages: opencv-python\n",
444-
"Successfully installed opencv-python-3.4.2.17\n",
445-
"\u001b[33mYou are using pip version 10.0.1, however version 18.0 is available.\n",
446-
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n"
466+
"\u001b[36malgo-1-L58J2_1 |\u001b[0m 172.18.0.1 - - [03/Aug/2018:22:32:52 +0000] \"POST /invocations HTTP/1.1\" 200 229 \"-\" \"-\"\r\n"
447467
]
468+
},
469+
{
470+
"data": {
471+
"text/plain": [
472+
"{'predictions': [{'probabilities': [2.29861e-05,\n",
473+
" 0.0104983,\n",
474+
" 0.147974,\n",
475+
" 0.01538,\n",
476+
" 0.0478089,\n",
477+
" 0.00164997,\n",
478+
" 0.758483,\n",
479+
" 0.0164191,\n",
480+
" 0.00125304,\n",
481+
" 0.000510801],\n",
482+
" 'classes': 6}]}"
483+
]
484+
},
485+
"execution_count": 6,
486+
"metadata": {},
487+
"output_type": "execute_result"
448488
}
449489
],
450-
"source": [
451-
"! pip install opencv-python"
452-
]
453-
},
454-
{
455-
"cell_type": "code",
456-
"execution_count": null,
457-
"metadata": {},
458-
"outputs": [],
459490
"source": [
460491
"import cv2\n",
461492
"import numpy\n",
@@ -484,28 +515,9 @@
484515
},
485516
{
486517
"cell_type": "code",
487-
"execution_count": 25,
518+
"execution_count": null,
488519
"metadata": {},
489-
"outputs": [
490-
{
491-
"name": "stderr",
492-
"output_type": "stream",
493-
"text": [
494-
"INFO:sagemaker:Deleting endpoint with name: tensorflow_cifar10_example-2018-08-03-18-06-55-168\n"
495-
]
496-
},
497-
{
498-
"name": "stdout",
499-
"output_type": "stream",
500-
"text": [
501-
"Gracefully stopping... (press Ctrl+C again to force)\n",
502-
"Stopping tmp3n0u5hj2_algo-1-HCRIC_1 ... \r\n",
503-
"\u001b[1A\u001b[2K\r",
504-
"Stopping tmp3n0u5hj2_algo-1-HCRIC_1 ... \u001b[32mdone\u001b[0m\r",
505-
"\u001b[1B"
506-
]
507-
}
508-
],
520+
"outputs": [],
509521
"source": [
510522
"predictor.delete_endpoint()"
511523
]
@@ -523,7 +535,7 @@
523535
},
524536
{
525537
"cell_type": "code",
526-
"execution_count": 26,
538+
"execution_count": null,
527539
"metadata": {},
528540
"outputs": [],
529541
"source": [
@@ -542,7 +554,7 @@
542554
},
543555
{
544556
"cell_type": "code",
545-
"execution_count": 27,
557+
"execution_count": null,
546558
"metadata": {},
547559
"outputs": [],
548560
"source": [
@@ -564,7 +576,7 @@
564576
},
565577
{
566578
"cell_type": "code",
567-
"execution_count": 28,
579+
"execution_count": null,
568580
"metadata": {},
569581
"outputs": [],
570582
"source": [
@@ -580,9 +592,32 @@
580592
"## Training on SageMaker\n",
581593
"Training a model on SageMaker with the Python SDK is done in a way that is similar to the way we trained it locally. This is done by changing our train_instance_type from `local` to one of our [supported EC2 instance types](https://aws.amazon.com/sagemaker/pricing/instance-types/).\n",
582594
"\n",
583-
"In addition, we must now specify the ECR image URL, which we just pushed above. Be sure to replace the string within the Estimator parameter, image_name.\n",
595+
"In addition, we must now specify the ECR image URL, which we just pushed above.\n",
596+
"\n",
597+
"Finally, our local training dataset has to be in Amazon S3 and the S3 URL to our dataset is passed into the `fit()` call.\n",
598+
"\n",
599+
"Let's first fetch our ECR image url that corresponds to the image we just built and pushed."
600+
]
601+
},
602+
{
603+
"cell_type": "code",
604+
"execution_count": null,
605+
"metadata": {},
606+
"outputs": [],
607+
"source": [
608+
"import boto3\n",
609+
"\n",
610+
"client = boto3.client('sts')\n",
611+
"account = client.get_caller_identity()['Account']\n",
584612
"\n",
585-
"Finally, our local training dataset has to be in Amazon S3 and the S3 URL to our dataset is passed into the `fit()` call."
613+
"my_session = boto3.session.Session()\n",
614+
"region = my_session.region_name\n",
615+
"\n",
616+
"algorithm_name = 'tensorflow-cifar10-example'\n",
617+
"\n",
618+
"ecr_image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, algorithm_name)\n",
619+
"\n",
620+
"print(ecr_image)"
586621
]
587622
},
588623
{
@@ -600,7 +635,7 @@
600635
"estimator = Estimator(role=role,\n",
601636
" train_instance_count=1,\n",
602637
" train_instance_type=instance_type,\n",
603-
" image_name='ecr-image',\n",
638+
" image_name=ecr_image,\n",
604639
" hyperparameters=hyperparameters)\n",
605640
"\n",
606641
"estimator.fit(data_location)\n",
@@ -610,30 +645,9 @@
610645
},
611646
{
612647
"cell_type": "code",
613-
"execution_count": 33,
648+
"execution_count": null,
614649
"metadata": {},
615-
"outputs": [
616-
{
617-
"data": {
618-
"text/plain": [
619-
"{'predictions': [{'probabilities': [0.115806,\n",
620-
" 0.119459,\n",
621-
" 0.028497,\n",
622-
" 0.348986,\n",
623-
" 0.102692,\n",
624-
" 0.0354596,\n",
625-
" 0.0917221,\n",
626-
" 0.00540253,\n",
627-
" 0.121872,\n",
628-
" 0.0301034],\n",
629-
" 'classes': 3}]}"
630-
]
631-
},
632-
"execution_count": 33,
633-
"metadata": {},
634-
"output_type": "execute_result"
635-
}
636-
],
650+
"outputs": [],
637651
"source": [
638652
"image = cv2.imread(\"data/cat.png\", 1)\n",
639653
"\n",
@@ -662,17 +676,9 @@
662676
},
663677
{
664678
"cell_type": "code",
665-
"execution_count": 25,
679+
"execution_count": null,
666680
"metadata": {},
667-
"outputs": [
668-
{
669-
"name": "stdout",
670-
"output_type": "stream",
671-
"text": [
672-
"b'{\\n \"predictions\": [\\n {\\n \"classes\": 3,\\n \"probabilities\": [0.122724, 0.0958609, 0.0519071, 0.272535, 0.097384, 0.0535893, 0.0905842, 0.0250508, 0.123435, 0.0669298]\\n }\\n ]\\n}'\n"
673-
]
674-
}
675-
],
681+
"outputs": [],
676682
"source": [
677683
"import json\n",
678684
"\n",
@@ -681,9 +687,9 @@
681687
"endpoint_name = predictor.endpoint\n",
682688
"\n",
683689
"response = client.invoke_endpoint(EndpointName=endpoint_name, Body=json.dumps(data))\n",
684-
"response_body = response['Body'].decode('utf-8')\n",
690+
"response_body = response['Body']\n",
685691
"\n",
686-
"print(response_body.read())"
692+
"print(response_body.read().decode('utf-8'))"
687693
]
688694
},
689695
{

advanced_functionality/tensorflow_bring_your_own/utils/generate_cifar10_tfrecords.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# http://aws.amazon.com/apache2.0/
88
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
# ==============================================================================
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
1513
"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
1614
1715
Generates tf.train.Example protos and writes them to TFRecord files from the
@@ -29,20 +27,44 @@
2927
import tarfile
3028

3129
from six.moves import cPickle as pickle
30+
from six.moves import urllib
3231
from six.moves import xrange # pylint: disable=redefined-builtin
32+
from ipywidgets import FloatProgress
33+
from IPython.display import display
3334
import tensorflow as tf
3435

3536
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
3637
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
3738
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
3839

3940

40-
def download_and_extract(data_dir):
41-
# download CIFAR-10 if not already downloaded.
42-
tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir,
43-
CIFAR_DOWNLOAD_URL)
44-
tarfile.open(os.path.join(data_dir, CIFAR_FILENAME),
45-
'r:gz').extractall(data_dir)
41+
def download_and_extract(data_dir, print_progress=True):
42+
"""Download and extract the tarball from Alex's website."""
43+
if not os.path.exists(data_dir):
44+
os.makedirs(data_dir)
45+
46+
if os.path.exists(os.path.join(data_dir, 'cifar-10-batches-bin')):
47+
print('cifar dataset already downloaded')
48+
return
49+
50+
filename = CIFAR_DOWNLOAD_URL.split('/')[-1]
51+
filepath = os.path.join(data_dir, filename)
52+
53+
if not os.path.exists(filepath):
54+
f = FloatProgress(min=0, max=100)
55+
display(f)
56+
sys.stdout.write('\r>> Downloading %s ' % (filename))
57+
58+
def _progress(count, block_size, total_size):
59+
if print_progress:
60+
f.value = 100.0 * count * block_size / total_size
61+
62+
filepath, _ = urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath, _progress)
63+
print()
64+
statinfo = os.stat(filepath)
65+
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
66+
67+
tarfile.open(filepath, 'r:gz').extractall(data_dir)
4668

4769

4870
def _int64_feature(value):

0 commit comments

Comments
 (0)