|
10 | 10 | "2. [Prerequisites and Preprocessing](#Prequisites-and-Preprocessing)\n",
|
11 | 11 | " 1. [Permissions and environment variables](#Permissions-and-environment-variables)\n",
|
12 | 12 | "3. [Training the ResNet model](#Training-the-ResNet-model)\n",
|
13 |
| - "4. [Set up hosting for the model](#Set-up-hosting-for-the-model)\n", |
14 |
| - " 1. [Import model into hosting](#Import-model-into-hosting)\n", |
15 |
| - " 2. [Create endpoint configuration](#Create-endpoint-configuration)\n", |
16 |
| - " 3. [Create endpoint](#Create-endpoint)\n", |
17 |
| - "5. [Validate the model for use](#Validate-the-model-for-use)\n" |
| 13 | + "4. [Deploy The Model](#Deploy-the-model)\n", |
| 14 | + " 1. [Create model](#Create-model)\n", |
| 15 | + " 2. [Batch transform](#Batch-transform)\n", |
| 16 | + " 3. [Realtime inference](#Realtime-inference)\n", |
| 17 | + " 1. [Create endpoint configuration](#Create-endpoint-configuration) \n", |
| 18 | + " 2. [Create endpoint](#Create-endpoint) \n", |
| 19 | + " 3. [Perform inference](#Perform-inference) \n", |
| 20 | + " 4. [Clean up](#Clean-up)\n" |
18 | 21 | ]
|
19 | 22 | },
|
20 | 23 | {
|
|
98 | 101 | "\n",
|
99 | 102 | "\n",
|
100 | 103 | "# caltech-256\n",
|
| 104 | + "s3_train_key = \"image-classification-full-training/train\"\n", |
| 105 | + "s3_validation_key = \"image-classification-full-training/validation\"\n", |
| 106 | + "s3_train = 's3://{}/{}/'.format(bucket, s3_train_key)\n", |
| 107 | + "s3_validation = 's3://{}/{}/'.format(bucket, s3_validation_key)\n", |
| 108 | + "\n", |
101 | 109 | "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')\n",
|
102 |
| - "upload_to_s3('train', 'caltech-256-60-train.rec')\n", |
| 110 | + "upload_to_s3(s3_train_key, 'caltech-256-60-train.rec')\n", |
103 | 111 | "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')\n",
|
104 |
| - "upload_to_s3('validation', 'caltech-256-60-val.rec')" |
| 112 | + "upload_to_s3(s3_validation_key, 'caltech-256-60-val.rec')" |
105 | 113 | ]
|
106 | 114 | },
|
107 | 115 | {
|
|
131 | 139 | "* **mini_batch_size**: The number of training samples used for each mini batch. In distributed training, the number of training samples used per batch will be N * mini_batch_size, where N is the number of hosts on which the training is run."
|
132 | 140 | ]
|
133 | 141 | },
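The effective global batch size in distributed training follows directly from the parameter description above; a quick sanity check with hypothetical values:

```python
# In distributed training, each of N hosts consumes mini_batch_size
# samples per step, so the effective global batch is their product.
num_hosts = 2          # hypothetical: two training instances
mini_batch_size = 128  # hypothetical per-host mini batch
effective_batch = num_hosts * mini_batch_size
print(effective_batch)  # 256
```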
|
| 142 | + { |
| 143 | + "cell_type": "markdown", |
| 144 | + "metadata": {}, |
| 145 | + "source": [ |
| 146 | + "After setting the training parameters, we kick off training and poll for status until training completes. In this example, each epoch takes 10 to 12 minutes on a p2.xlarge instance, and the network typically converges after 10 epochs. To save training time, we set the number of epochs to 2; keep in mind that this may not be sufficient to produce a good model." |
| 147 | + ] |
| 148 | + }, |
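The kick-off-and-poll flow described above can be sketched as follows; this is a minimal illustration, assuming a boto3 SageMaker client and a `job_name` matching the one passed to the notebook's `create_training_job` call (the function name here is illustrative):

```python
import time

def wait_for_training(job_name, poll_seconds=60):
    """Poll DescribeTrainingJob until the job reaches a terminal status."""
    import boto3  # imported here so the sketch stays self-contained
    client = boto3.client('sagemaker')
    while True:
        status = client.describe_training_job(
            TrainingJobName=job_name)['TrainingJobStatus']
        if status in ('Completed', 'Failed', 'Stopped'):
            return status
        time.sleep(poll_seconds)
```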
134 | 149 | {
|
135 | 150 | "cell_type": "code",
|
136 | 151 | "execution_count": null,
|
|
223 | 238 | " \"DataSource\": {\n",
|
224 | 239 | " \"S3DataSource\": {\n",
|
225 | 240 | " \"S3DataType\": \"S3Prefix\",\n",
|
226 |
| - " \"S3Uri\": 's3://{}/train/'.format(bucket),\n", |
| 241 | + " \"S3Uri\": s3_train,\n", |
227 | 242 | " \"S3DataDistributionType\": \"FullyReplicated\"\n",
|
228 | 243 | " }\n",
|
229 | 244 | " },\n",
|
|
235 | 250 | " \"DataSource\": {\n",
|
236 | 251 | " \"S3DataSource\": {\n",
|
237 | 252 | " \"S3DataType\": \"S3Prefix\",\n",
|
238 |
| - " \"S3Uri\": 's3://{}/validation/'.format(bucket),\n", |
| 253 | + " \"S3Uri\": s3_validation,\n", |
239 | 254 | " \"S3DataDistributionType\": \"FullyReplicated\"\n",
|
240 | 255 | " }\n",
|
241 | 256 | " },\n",
|
|
307 | 322 | "cell_type": "markdown",
|
308 | 323 | "metadata": {},
|
309 | 324 | "source": [
|
310 |
| - "# Inference\n", |
| 325 | + "# Deploy The Model\n", |
311 | 326 | "\n",
|
312 | 327 | "***\n",
|
313 | 328 | "\n",
|
|
316 | 331 | "This section involves several steps,\n",
|
317 | 332 | "\n",
|
318 | 333 | "1. [Create Model](#CreateModel) - Create model for the training output\n",
|
319 |
| - "1. [Create Endpoint Configuration](#CreateEndpointConfiguration) - Create a configuration defining an endpoint.\n", |
320 |
| - "1. [Create Endpoint](#CreateEndpoint) - Use the configuration to create an inference endpoint.\n", |
321 |
| - "1. [Perform Inference](#Perform Inference) - Perform inference on some input data using the endpoint." |
| 334 | + "1. [Batch Transform](#BatchTransform) - Create a transform job to perform batch inference.\n", |
| 335 | + "1. [Host the model for realtime inference](#HostTheModel) - Create an inference endpoint and perform realtime inference." |
322 | 336 | ]
|
323 | 337 | },
|
324 | 338 | {
|
|
327 | 341 | "source": [
|
328 | 342 | "## Create Model\n",
|
329 | 343 | "\n",
|
330 |
| - "We now create a SageMaker Model from the training output. Using the model we can create an Endpoint Configuration." |
| 344 | + "We now create a SageMaker Model from the training output. Using the model we can create a Batch Transform Job or an Endpoint Configuration." |
331 | 345 | ]
|
332 | 346 | },
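The CreateModel request built in the following cell ties the training image to the trained artifact; a sketch of its shape, with placeholder values since the notebook derives the image URI, artifact path, and role from the training job and environment:

```python
# Placeholder values; the notebook takes these from the training job output.
training_image = '<algorithm-image-uri>'          # ECR image used for training
model_data = 's3://<bucket>/<path>/model.tar.gz'  # trained model artifact
role = '<sagemaker-execution-role-arn>'

create_model_request = {
    'ModelName': 'image-classification-model',
    'ExecutionRoleArn': role,
    'PrimaryContainer': {
        'Image': training_image,
        'ModelDataUrl': model_data,
    },
}
# A real call would be:
# boto3.client('sagemaker').create_model(**create_model_request)
```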
|
333 | 347 | {
|
|
369 | 383 | "cell_type": "markdown",
|
370 | 384 | "metadata": {},
|
371 | 385 | "source": [
|
372 |
| - "### Create Endpoint Configuration\n", |
| 386 | + "### Batch transform\n", |
| 387 | + "\n", |
| 388 | + "We now create a SageMaker Batch Transform job using the model created above to perform batch prediction." |
| 389 | + ] |
| 390 | + }, |
| 391 | + { |
| 392 | + "cell_type": "markdown", |
| 393 | + "metadata": {}, |
| 394 | + "source": [ |
| 395 | + "#### Download test data" |
| 396 | + ] |
| 397 | + }, |
| 398 | + { |
| 399 | + "cell_type": "code", |
| 400 | + "execution_count": null, |
| 401 | + "metadata": {}, |
| 402 | + "outputs": [], |
| 403 | + "source": [ |
| 404 | + "# Download images under /008.bathtub\n", |
| 405 | + "!wget -r -np -nH --cut-dirs=2 -P /tmp/ -R \"index.html*\" http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/\n" |
| 406 | + ] |
| 407 | + }, |
| 408 | + { |
| 409 | + "cell_type": "code", |
| 410 | + "execution_count": null, |
| 411 | + "metadata": {}, |
| 412 | + "outputs": [], |
| 413 | + "source": [ |
| 414 | + "batch_input = 's3://{}/image-classification-full-training/test/'.format(bucket)\n", |
| 415 | + "test_images = '/tmp/images/008.bathtub'\n", |
| 416 | + "\n", |
| 417 | + "!aws s3 cp $test_images $batch_input --recursive --quiet " |
| 418 | + ] |
| 419 | + }, |
| 420 | + { |
| 421 | + "cell_type": "code", |
| 422 | + "execution_count": null, |
| 423 | + "metadata": {}, |
| 424 | + "outputs": [], |
| 425 | + "source": [ |
| 426 | + "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n", |
| 427 | + "batch_job_name = \"image-classification-model\" + timestamp\n", |
| 428 | + "request = \\\n", |
| 429 | + "{\n", |
| 430 | + " \"TransformJobName\": batch_job_name,\n", |
| 431 | + " \"ModelName\": model_name,\n", |
| 432 | + " \"MaxConcurrentTransforms\": 16,\n", |
| 433 | + " \"MaxPayloadInMB\": 6,\n", |
| 434 | + " \"BatchStrategy\": \"SingleRecord\",\n", |
| 435 | + " \"TransformOutput\": {\n", |
| 436 | + " \"S3OutputPath\": 's3://{}/{}/output'.format(bucket, batch_job_name)\n", |
| 437 | + " },\n", |
| 438 | + " \"TransformInput\": {\n", |
| 439 | + " \"DataSource\": {\n", |
| 440 | + " \"S3DataSource\": {\n", |
| 441 | + " \"S3DataType\": \"S3Prefix\",\n", |
| 442 | + " \"S3Uri\": batch_input\n", |
| 443 | + " }\n", |
| 444 | + " },\n", |
| 445 | + " \"ContentType\": \"application/x-image\",\n", |
| 446 | + " \"SplitType\": \"None\",\n", |
| 447 | + " \"CompressionType\": \"None\"\n", |
| 448 | + " },\n", |
| 449 | + " \"TransformResources\": {\n", |
| 450 | + " \"InstanceType\": \"ml.p2.xlarge\",\n", |
| 451 | + " \"InstanceCount\": 1\n", |
| 452 | + " }\n", |
| 453 | + "}\n", |
| 454 | + "\n", |
| 455 | + "print('Transform job name: {}'.format(batch_job_name))\n", |
| 456 | + "print('\\nInput Data Location: {}'.format(s3_validation))" |
| 457 | + ] |
| 458 | + }, |
| 459 | + { |
| 460 | + "cell_type": "code", |
| 461 | + "execution_count": null, |
| 462 | + "metadata": {}, |
| 463 | + "outputs": [], |
| 464 | + "source": [ |
| 465 | + "sagemaker = boto3.client('sagemaker')\n", |
| 466 | + "sagemaker.create_transform_job(**request)\n", |
| 467 | + "\n", |
| 468 | + "print(\"Created Transform job with name: \", batch_job_name)\n", |
| 469 | + "\n", |
| 470 | + "while True:\n", |
| 471 | + " response = sagemaker.describe_transform_job(TransformJobName=batch_job_name)\n", |
| 472 | + " status = response['TransformJobStatus']\n", |
| 473 | + " if status == 'Completed':\n", |
| 474 | + " print(\"Transform job ended with status: \" + status)\n", |
| 475 | + " break\n", |
| 476 | + " if status == 'Failed':\n", |
| 477 | + " message = response['FailureReason']\n", |
| 478 | + " print('Transform failed with the following error: {}'.format(message))\n", |
| 479 | + " raise Exception('Transform job failed') \n", |
| 480 | + " time.sleep(30) " |
| 481 | + ] |
| 482 | + }, |
| 483 | + { |
| 484 | + "cell_type": "markdown", |
| 485 | + "metadata": {}, |
| 486 | + "source": [ |
| 487 | + "After the job completes, let's inspect the prediction results. The accuracy may not be high because we trained for only 2 epochs, which may not be sufficient to produce a good model." |
| 488 | + ] |
| 489 | + }, |
| 490 | + { |
| 491 | + "cell_type": "code", |
| 492 | + "execution_count": null, |
| 493 | + "metadata": {}, |
| 494 | + "outputs": [], |
| 495 | + "source": [ |
| 496 | + "from urllib.parse import urlparse\n", |
| 497 | + "import json\n", |
| 498 | + "import numpy as np\n", |
| 499 | + "\n", |
| 500 | + "s3_client = boto3.client('s3')\n", |
| 501 | + "object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 
'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']\n", |
| 502 | + "\n", |
| 503 | + "def list_objects(s3_client, bucket, prefix):\n", |
| 504 | + " response = s3_client.list_objects(Bucket=bucket, Prefix=prefix)\n", |
| 505 | + " objects = [content['Key'] for content in response['Contents']]\n", |
| 506 | + " return objects\n", |
| 507 | + "\n", |
| 508 | + "def get_label(s3_client, bucket, prefix):\n", |
| 509 | + " filename = prefix.split('/')[-1]\n", |
| 510 | + " s3_client.download_file(bucket, prefix, filename)\n", |
| 511 | + " with open(filename) as f:\n", |
| 512 | + " data = json.load(f)\n", |
| 513 | + " index = np.argmax(data['prediction'])\n", |
| 514 | + " probability = data['prediction'][index]\n", |
| 515 | + " print(\"Result: label - \" + object_categories[index] + \", probability - \" + str(probability))\n", |
| 516 | + " return object_categories[index], probability\n", |
| 517 | + "\n", |
| 518 | + "inputs = list_objects(s3_client, bucket, urlparse(batch_input).path.lstrip('/'))\n", |
| 519 | + "print(\"Sample inputs: \" + str(inputs[:2]))\n", |
| 520 | + "\n", |
| 521 | + "outputs = list_objects(s3_client, bucket, batch_job_name + \"/output\")\n", |
| 522 | + "print(\"Sample output: \" + str(outputs[:2]))\n", |
| 523 | + "\n", |
| 524 | + "# Check the prediction results of the first 10 outputs\n", |
| 525 | + "[get_label(s3_client, bucket, prefix) for prefix in outputs[0:10]]" |
| 526 | + ] |
| 527 | + }, |
| 528 | + { |
| 529 | + "cell_type": "markdown", |
| 530 | + "metadata": {}, |
| 531 | + "source": [ |
| 532 | + "### Realtime inference\n", |
| 533 | + "\n", |
| 534 | + "We now host the model with an endpoint and perform realtime inference.\n", |
| 535 | + "\n", |
| 536 | + "This section involves several steps,\n", |
| 537 | + "1. [Create endpoint configuration](#CreateEndpointConfiguration) - Create a configuration defining an endpoint.\n", |
| 538 | + "1. [Create endpoint](#CreateEndpoint) - Use the configuration to create an inference endpoint.\n", |
| 539 | + "1. [Perform inference](#PerformInference) - Perform inference on some input data using the endpoint.\n", |
| 540 | + "1. [Clean up](#CleanUp) - Delete the endpoint and model" |
| 541 | + ] |
| 542 | + }, |
| 543 | + { |
| 544 | + "cell_type": "markdown", |
| 545 | + "metadata": {}, |
| 546 | + "source": [ |
| 547 | + "#### Create Endpoint Configuration\n", |
373 | 548 | "At launch, we will support configuring REST endpoints in hosting with multiple models, e.g. for A/B testing purposes. To support this, customers create an endpoint configuration that describes the distribution of traffic across the models, whether split, shadowed, or sampled in some way.\n",
|
374 | 549 | "\n",
|
375 | 550 | "In addition, the endpoint configuration describes the instance type required for model deployment, and at launch will describe the autoscaling configuration."
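The traffic-splitting idea described above maps onto the `ProductionVariants` list of the endpoint configuration; a minimal single-variant request might look like this (the names are illustrative, not the notebook's):

```python
endpoint_config_request = {
    'EndpointConfigName': 'image-classification-epc',
    'ProductionVariants': [{
        'VariantName': 'AllTraffic',
        'ModelName': 'image-classification-model',
        'InstanceType': 'ml.p2.xlarge',
        'InitialInstanceCount': 1,
        'InitialVariantWeight': 1.0,  # traffic share relative to other variants
    }],
}
# A real call would be:
# boto3.client('sagemaker').create_endpoint_config(**endpoint_config_request)
```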
|
|
403 | 578 | "cell_type": "markdown",
|
404 | 579 | "metadata": {},
|
405 | 580 | "source": [
|
406 |
| - "### Create Endpoint\n", |
407 |
| - "Lastly, the customer creates the endpoint that serves up the model, through specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 9-11 minutes to complete." |
| 581 | + "#### Create Endpoint\n", |
| 582 | + "Next, the customer creates the endpoint that serves the model by specifying the name and the configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 9-11 minutes to complete." |
408 | 583 | ]
|
409 | 584 | },
|
410 | 585 | {
|
|
434 | 609 | "cell_type": "markdown",
|
435 | 610 | "metadata": {},
|
436 | 611 | "source": [
|
437 |
| 612 | + "Now the endpoint can be created. This may take some time to complete." |
438 | 613 | ]
|
439 | 614 | },
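The create-and-wait pattern the notebook uses for the endpoint can be sketched as a small helper; this is an illustration only, assuming a boto3 SageMaker client and the endpoint/configuration names created earlier:

```python
import time

def create_and_wait(endpoint_name, endpoint_config_name, poll_seconds=60):
    """Create the endpoint, then poll DescribeEndpoint until it is ready."""
    import boto3  # imported here so the sketch stays self-contained
    client = boto3.client('sagemaker')
    client.create_endpoint(EndpointName=endpoint_name,
                           EndpointConfigName=endpoint_config_name)
    while True:
        status = client.describe_endpoint(
            EndpointName=endpoint_name)['EndpointStatus']
        if status in ('InService', 'Failed'):
            return status
        time.sleep(poll_seconds)
```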
|
440 | 615 | {
|
|
481 | 656 | "cell_type": "markdown",
|
482 | 657 | "metadata": {},
|
483 | 658 | "source": [
|
484 |
| - "## Perform Inference\n", |
| 659 | + "#### Perform Inference\n", |
485 | 660 | "Finally, the customer can now validate the model for use. They can obtain the endpoint from the client library using the result from previous operations, and generate classifications from the trained model using that endpoint.\n"
|
486 | 661 | ]
|
487 | 662 | },
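Generating a classification from the endpoint amounts to posting raw image bytes with the `application/x-image` content type and taking the arg-max of the returned probabilities; a sketch under those assumptions (function names are illustrative):

```python
import json

def top_class(probs, categories):
    """Return (label, probability) for the highest-probability class."""
    index = max(range(len(probs)), key=probs.__getitem__)
    return categories[index], probs[index]

def classify_image(endpoint_name, image_path, categories):
    """Send raw image bytes to the endpoint and return the top class."""
    import boto3  # imported here so the sketch stays self-contained
    runtime = boto3.client('runtime.sagemaker')
    with open(image_path, 'rb') as f:
        payload = f.read()
    response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                       ContentType='application/x-image',
                                       Body=payload)
    probs = json.loads(response['Body'].read())
    return top_class(probs, categories)
```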
|
|
501 | 676 | "cell_type": "markdown",
|
502 | 677 | "metadata": {},
|
503 | 678 | "source": [
|
504 |
| - "### Download test image" |
| 679 | + "##### Download test image" |
505 | 680 | ]
|
506 | 681 | },
|
507 | 682 | {
|
|
523 | 698 | "cell_type": "markdown",
|
524 | 699 | "metadata": {},
|
525 | 700 | "source": [
|
526 |
| 701 | + "##### Evaluation\n", |
527 | 702 | "\n",
|
528 | 703 | "Evaluate the image through the network for inference. The network outputs class probabilities; typically, one selects the class with the maximum probability as the final class output.\n",
|
529 | 704 | "\n",
|
|
561 | 736 | "cell_type": "markdown",
|
562 | 737 | "metadata": {},
|
563 | 738 | "source": [
|
564 |
| - "### Clean up\n", |
| 739 | + "#### Clean up\n", |
565 | 740 | "\n",
|
566 | 741 | "When we're done with the endpoint, we can just delete it and the backing instances will be released. Run the following cell to delete the endpoint."
|
567 | 742 | ]
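Releasing the backing instances, as described above, is a single delete call; a sketch that also removes the configuration and model when their names are supplied (the helper name is illustrative):

```python
def clean_up(endpoint_name, endpoint_config_name=None, model_name=None):
    """Delete the endpoint first; optionally remove the config and model too."""
    import boto3  # imported here so the sketch stays self-contained
    client = boto3.client('sagemaker')
    client.delete_endpoint(EndpointName=endpoint_name)
    if endpoint_config_name:
        client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    if model_name:
        client.delete_model(ModelName=model_name)
```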
|
|