1 | 1 | {
2 | 2 | "cells": [
3 |    | - {
4 |    | - "cell_type": "markdown",
5 |    | - "metadata": {},
6 |    | - "source": [
7 |    | - "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
8 |    | - "\n",
9 |    | - "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
10 |   | - "You may not use this file except in compliance with the License.\n",
11 |   | - "A copy of the License is located at\n",
12 |   | - " \n",
13 |   | - " http://aws.amazon.com/apache2.0/\n",
14 |   | - "\n",
15 |   | - "or in the \"license\" file accompanying this file. This file is distributed\n",
16 |   | - "on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either\n",
17 |   | - "express or implied. See the License for the specific language governing\n",
18 |   | - "permissions and limitations under the License."
19 |   | - ]
20 |   | - },
21 | 3 | {
22 | 4 | "cell_type": "markdown",
23 | 5 | "metadata": {},
24 | 6 | "source": [
25 | 7 | "# SageMaker PySpark K-Means Clustering MNIST Example\n",
26 | 8 | "\n",
27 | 9 | "1. [Introduction](#Introduction)\n",
28 |   | - "2. [Loading the Data](#Loading-the-Data)\n",
29 |   | - "3. [Training and Hosting a Model](#Training-and-Hosting-a-Model)\n",
30 |   | - "4. [Inference](#Inference)\n",
31 |   | - "5. [More on SageMaker Spark](#More-on-SageMaker-Spark)\n"
   | 10 | + "2. [Setup](#Setup)\n",
   | 11 | + "3. [Loading the Data](#Loading-the-Data)\n",
   | 12 | + "4. [Training and Hosting a Model](#Training-and-Hosting-a-Model)\n",
   | 13 | + "5. [Inference](#Inference)\n",
   | 14 | + "6. [More on SageMaker Spark](#More-on-SageMaker-Spark)\n"
32 | 15 | ]
33 | 16 | },
34 | 17 | {
38 | 21 | "## Introduction\n",
|
39 | 22 | "This notebook will show how to classify handwritten digits using the K-Means clustering algorithm through the SageMaker PySpark library. We will train on Amazon SageMaker using K-Means clustering on the MNIST dataset, host the trained model on Amazon SageMaker, and then make predictions against that hosted model.\n",
|
40 | 23 | "\n",
|
41 |
| - "You can visit SageMaker Spark's Github repository at https://github.com/aws/sagemaker-spark for more about SageMaker Spark.\n", |
| 24 | + "Unlike the other notebooks that demonstrate K-Means clustering on Amazon SageMaker, this notebook uses a SparkSession to manipulate data, and uses the SageMaker Spark library to interact with SageMaker with Spark Estimators and Transformers.\n", |
42 | 25 | "\n",
43 |   | - "We will train on Amazon SageMaker using the KMeans Clustering on the MNIST dataset, host the trained model on Amazon SageMaker, and then make predictions against that hosted model.\n",
   | 26 | + "You can visit SageMaker Spark's GitHub repository at https://github.com/aws/sagemaker-spark to learn more about SageMaker Spark.\n",
44 | 27 | "\n",
45 |   | - "First, we load the MNIST dataset into a Spark Dataframe, which dataset is available in LibSVM format at\n",
46 |   | - "\n",
47 |   | - "`s3://sagemaker-sample-data-[region]/spark/mnist/train/`\n",
   | 28 | + "This notebook was created and tested on an ml.m4.xlarge notebook instance."
   | 29 | + ]
   | 30 | + },
   | 31 | + {
   | 32 | + "cell_type": "markdown",
   | 33 | + "metadata": {},
   | 34 | + "source": [
   | 35 | + "## Setup\n",
48 | 36 | "\n",
49 |   | - "where `[region]` is replaced with a supported AWS region, such as us-east-1"
   | 37 | + "First, we import the necessary modules and create the SparkSession with the SageMaker Spark dependencies."
50 | 38 | ]
51 | 39 | },
52 | 40 | {
63 | 51 | "from pyspark.sql import SparkSession\n",
64 | 52 | "\n",
65 | 53 | "import sagemaker\n",
66 |   | - "import sagemaker_pyspark\n",
67 | 54 | "from sagemaker import get_execution_role\n",
68 |   | - "\n",
69 |   | - "sagemaker_session = sagemaker.Session()\n",
   | 55 | + "import sagemaker_pyspark\n",
70 | 56 | "\n",
71 | 57 | "role = get_execution_role()\n",
72 | 58 | "\n",
81 | 67 | " .master(\"local[*]\").getOrCreate()"
82 | 68 | ]
83 | 69 | },
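The diff elides the lines (old 59–66) that wire the SageMaker Spark jars into the SparkSession. For context, here is a minimal sketch of that setup, assuming `sagemaker_pyspark`'s `classpath_jars()` helper, which returns the paths of the jars bundled with the library:

```python
from pyspark.sql import SparkSession

import sagemaker_pyspark

# Build the classpath from the jars bundled with sagemaker_pyspark,
# so Spark can load the SageMaker Estimators and Transformers.
classpath = ":".join(sagemaker_pyspark.classpath_jars())

# In local[*] mode the driver classpath is what matters, since the
# executors run inside the driver JVM.
spark = SparkSession.builder \
    .config("spark.driver.extraClassPath", classpath) \
    .master("local[*]").getOrCreate()
```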
84 |    | - {
85 |    | - "cell_type": "code",
86 |    | - "execution_count": null,
87 |    | - "metadata": {},
88 |    | - "outputs": [],
89 |    | - "source": [
90 |    | - "import boto3\n",
91 |    | - "\n",
92 |    | - "# use the region-specific sample data bucket\n",
93 |    | - "region = boto3.Session().region_name\n",
94 |    | - "trainingData = spark.read.format('libsvm')\\\n",
95 |    | - " .option('numFeatures', '784')\\\n",
96 |    | - " .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))\n",
97 |    | - "\n",
98 |    | - "testData = spark.read.format('libsvm')\\\n",
99 |    | - " .option('numFeatures', '784')\\\n",
100 |   | - " .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))"
101 |   | - ]
102 |   | - },
103 | 70 | {
104 | 71 | "cell_type": "markdown",
105 | 72 | "metadata": {},
106 | 73 | "source": [
107 | 74 | "## Loading the Data\n",
108 | 75 | "\n",
| 76 | + "Now, we load the MNIST dataset into a Spark Dataframe, which dataset is available in LibSVM format at\n", |
| 77 | + "\n", |
| 78 | + "`s3://sagemaker-sample-data-[region]/spark/mnist/train/`\n", |
| 79 | + "\n", |
| 80 | + "where `[region]` is replaced with a supported AWS region, such as us-east-1.\n", |
| 81 | + "\n", |
109 | 82 | "In order to train and make inferences our input DataFrame must have a column of Doubles (named \"label\" by default) and a column of Vectors of Doubles (named \"features\" by default).\n",
|
110 | 83 | "\n",
|
111 |
| - "Spark's LibSVM DataFrameReader loads a DataFrame already suitable for training and inference." |
| 84 | + "Spark's LibSVM DataFrameReader loads a DataFrame already suitable for training and inference.\n", |
| 85 | + "\n", |
| 86 | + "Here, we load into a DataFrame in the SparkSession running on the local Notebook Instance, but you can connect your Notebook Instance to a remote Spark cluster for heavier workloads. Starting from EMR 5.11.0, SageMaker Spark is pre-installed on EMR Spark clusters. For more on connecting your SageMaker Notebook Instance to a remote EMR cluster, please see [this blog post](https://aws.amazon.com/blogs/machine-learning/build-amazon-sagemaker-notebooks-backed-by-spark-in-amazon-emr/)." |
112 | 87 | ]
|
113 | 88 | },
114 | 89 | {
119 | 94 | },
120 | 95 | "outputs": [],
121 | 96 | "source": [
    | 97 | + "import boto3\n",
    | 98 | + "\n",
    | 99 | + "region = boto3.Session().region_name\n",
    | 100 | + "\n",
    | 101 | + "trainingData = spark.read.format('libsvm')\\\n",
    | 102 | + " .option('numFeatures', '784')\\\n",
    | 103 | + " .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))\n",
    | 104 | + "\n",
    | 105 | + "testData = spark.read.format('libsvm')\\\n",
    | 106 | + " .option('numFeatures', '784')\\\n",
    | 107 | + " .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))\n",
    | 108 | + "\n",
122 | 109 | "trainingData.show()"
123 | 110 | ]
124 | 111 | },
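As a quick sanity check on the loaded data, the schema and label distribution can be inspected. A hedged sketch, assuming the `trainingData` DataFrame created in the cell above:

```python
# Verify the loaded DataFrame has the expected shape: a "label"
# column of Doubles and a "features" column of 784-element vectors.
trainingData.printSchema()

# MNIST's training split has 60,000 images with labels 0-9.
print(trainingData.count())
trainingData.groupBy("label").count().orderBy("label").show()
```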
151 | 138 | "\n",
152 | 139 | "kmeans_estimator = KMeansSageMakerEstimator(\n",
153 | 140 | " sagemakerRole=IAMRole(role),\n",
154 |    | - " trainingInstanceType='ml.p2.xlarge',\n",
    | 141 | + " trainingInstanceType='ml.m4.xlarge',\n",
155 | 142 | " trainingInstanceCount=1,\n",
156 |    | - " endpointInstanceType='ml.c4.xlarge',\n",
    | 143 | + " endpointInstanceType='ml.m4.xlarge',\n",
157 | 144 | " endpointInitialInstanceCount=1)\n",
158 | 145 | "\n",
159 | 146 | "kmeans_estimator.setK(10)\n",
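The hunk above stops at `setK(10)`; the rest of the cell is elided from the diff. A hedged sketch of how such a cell typically continues with the sagemaker_pyspark API (the `setFeatureDim` setter and the `fit()` call are assumptions based on that API, not shown in this diff):

```python
# Match the feature dimension to MNIST's 28 x 28 = 784 pixels.
kmeans_estimator.setFeatureDim(784)

# fit() launches the SageMaker training job and deploys the trained
# model to the endpoint configured above, returning a SageMakerModel
# that acts as a Spark Transformer.
initialModel = kmeans_estimator.fit(trainingData)
```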
170 | 157 | "## Inference\n",
171 | 158 | "\n",
172 | 159 | "Now we transform our DataFrame.\n",
173 |    | - "To do this, we serialize each row's \"features\" Vector of Doubles into a Protobuf format for inference against the Amazon SageMaker Endpoint. We deserialize the Protobuf responses back into our DataFrame:"
    | 160 | + "To do this, we serialize each row's \"features\" Vector of Doubles into a Protobuf format for inference against the Amazon SageMaker Endpoint. We deserialize the Protobuf responses back into our DataFrame. This serialization and deserialization is handled automatically by the `transform()` method:"
174 | 161 | ]
175 | 162 | },
176 | 163 | {
186 | 173 | "transformedData.show()"
187 | 174 | ]
188 | 175 | },
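The transform call itself sits in the elided portion of this cell; a minimal sketch of what it looks like, assuming the `initialModel` returned by `fit()` in the previous step:

```python
# Each row's "features" vector is serialized to protobuf, sent to the
# hosted endpoint, and the response is deserialized back into new
# DataFrame columns (for K-Means, the assigned cluster per row).
transformedData = initialModel.transform(testData)

transformedData.show()
```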
    | 176 | + {
    | 177 | + "cell_type": "markdown",
    | 178 | + "metadata": {},
    | 179 | + "source": [
    | 180 | + "How well did the algorithm perform? Let us display the digits from each of the clusters and manually inspect the results:"
    | 181 | + ]
    | 182 | + },
189 | 183 | {
190 | 184 | "cell_type": "code",
191 | 185 | "execution_count": null,
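The body of this inspection cell is elided from the diff. A hedged sketch of one way to eyeball the clusters, assuming matplotlib is available on the notebook instance and that `transformedData` carries a `closest_cluster` column (a column name assumed from the K-Means response deserializer, not shown in this diff):

```python
import matplotlib.pyplot as plt

# For each of the 10 clusters, plot a handful of member digits so the
# cluster assignments can be inspected by eye.
for cluster in range(10):
    samples = (transformedData
               .filter(transformedData.closest_cluster == float(cluster))
               .select("features")
               .take(5))
    for i, row in enumerate(samples):
        plt.subplot(1, 5, i + 1)
        plt.imshow(row.features.toArray().reshape(28, 28), cmap="gray")
        plt.axis("off")
    plt.suptitle("cluster {}".format(cluster))
    plt.show()
```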
278 | 272 | "nbconvert_exporter": "python",
279 | 273 | "pygments_lexer": "ipython3",
280 | 274 | "version": "3.6.2"
281 |    | - }
    | 275 | + },
    | 276 | + "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
282 | 277 | },
283 | 278 | "nbformat": 4,
284 | 279 | "nbformat_minor": 2