|
29 | 29 | "from sagemaker import get_execution_role\n",
|
30 | 30 | "from sagemaker.session import Session\n",
|
31 | 31 | "\n",
|
| 32 | + "sagemaker_session = Session()\n", |
| 33 | + "region = sagemaker_session.boto_session.region_name\n", |
| 34 | + "sample_data_bucket = 'sagemaker-sample-data-{}'.format(region)\n", |
| 35 | + "\n", |
32 | 36 | "# S3 bucket for saving files. Feel free to redefine this variable to the bucket of your choice.\n",
|
33 |
| - "bucket = Session().default_bucket()\n", |
| 37 | + "bucket = sagemaker_session.default_bucket()\n", |
34 | 38 | "\n",
|
35 | 39 | "# Bucket location where your custom code will be saved in the tar.gz format.\n",
|
36 |
| - "custom_code_upload_location = 's3://{}/customcode/mxnet'.format(bucket)\n", |
| 40 | + "custom_code_upload_location = 's3://{}/mxnet-mnist-example/code'.format(bucket)\n", |
37 | 41 | "\n",
|
38 | 42 | "# Bucket location where results of model training are saved.\n",
|
39 |
| - "model_artifacts_location = 's3://{}/artifacts'.format(bucket)\n", |
| 43 | + "model_artifacts_location = 's3://{}/mxnet-mnist-example/artifacts'.format(bucket)\n", |
40 | 44 | "\n",
|
41 | 45 | "# IAM execution role that gives SageMaker access to resources in your AWS account.\n",
|
42 | 46 | "# We can use the SageMaker Python SDK to get the role from our notebook environment. \n",
|
|
111 | 115 | "outputs": [],
|
112 | 116 | "source": [
|
113 | 117 | "%%time\n",
|
114 |
| - "import boto3\n", |
115 | 118 | "\n",
|
116 |
| - "region = boto3.Session().region_name\n", |
117 |
| - "train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region)\n", |
118 |
| - "test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region)\n", |
| 119 | + "train_data_location = 's3://{}/mxnet/mnist/train'.format(sample_data_bucket)\n", |
| 120 | + "test_data_location = 's3://{}/mxnet/mnist/test'.format(sample_data_bucket)\n", |
119 | 121 | "\n",
|
120 | 122 | "mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})"
|
121 | 123 | ]
|
|
126 | 128 | "source": [
|
127 | 129 | "### SageMaker's transformer class\n",
|
128 | 130 | "\n",
|
129 |
| - "After training, we use our MXNet estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments for configuring our options with the batch transform job; these do not need to be the same values as the one we used for the training job.\n", |
| 131 | + "After training, we use our MXNet estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments for configuring our options with the batch transform job; these do not need to be the same values as the one we used for the training job. The method also creates a SageMaker Model to be used for the batch transform jobs.\n", |
130 | 132 | "\n",
|
131 | 133 | "The `Transformer` class is responsible for running batch transform jobs, which will deploy the trained model to an endpoint and send requests for performing inference."
|
132 | 134 | ]
|
|
148 | 150 | "\n",
|
149 | 151 | "Now we can perform some inference with the model we've trained by running a batch transform job. The request handling behavior of the Endpoint deployed during the transform job is determined by the `mnist.py` script.\n",
|
150 | 152 | "\n",
|
151 |
| - "For demonstration purposes, we will be using an image of a '7' that's already saved in S3:" |
152 |
| - ] |
153 |
| - }, |
154 |
| - { |
155 |
| - "cell_type": "code", |
156 |
| - "execution_count": null, |
157 |
| - "metadata": {}, |
158 |
| - "outputs": [], |
159 |
| - "source": [ |
160 |
| - "transform_data_location = 's3://sagemaker-sample-data-{}/batch-transform/mnist'.format(region)" |
161 |
| - ] |
162 |
| - }, |
163 |
| - { |
164 |
| - "cell_type": "markdown", |
165 |
| - "metadata": {}, |
166 |
| - "source": [ |
167 |
| - "Just for fun, we can print out what the image looks like. First we'll create a temporary directory:" |
| 153 | + "For demonstration purposes, we're going to use input data that contains 1000 MNIST images, located in the public SageMaker sample data S3 bucket. To create the batch transform job, we simply call `transform()` on our transformer with information about the input data." |
168 | 154 | ]
|
169 | 155 | },
|
170 | 156 | {
|
|
173 | 159 | "metadata": {},
|
174 | 160 | "outputs": [],
|
175 | 161 | "source": [
|
176 |
| - "import os\n", |
| 162 | + "input_file_path = 'batch-transform/mnist-1000-samples'\n", |
177 | 163 | "\n",
|
178 |
| - "tmp_dir = '/tmp/data'\n", |
179 |
| - "\n", |
180 |
| - "if not os.path.exists(tmp_dir):\n", |
181 |
| - " os.makedirs(tmp_dir)" |
| 164 | + "transformer.transform('s3://{}/{}'.format(sample_data_bucket, input_file_path), content_type='text/csv')" |
182 | 165 | ]
|
183 | 166 | },
|
184 | 167 | {
|
185 | 168 | "cell_type": "markdown",
|
186 | 169 | "metadata": {},
|
187 | 170 | "source": [
|
188 |
| - "And now we'll print out the image:" |
| 171 | + "Now we wait for the batch transform job to complete. We have a convenience method, `wait()`, that will block until the batch transform job has completed. We can call that here to see if the batch transform job is still running; the cell will finish running when the batch transform job has completed." |
189 | 172 | ]
|
190 | 173 | },
|
191 | 174 | {
|
|
194 | 177 | "metadata": {},
|
195 | 178 | "outputs": [],
|
196 | 179 | "source": [
|
197 |
| - "from numpy import genfromtxt\n", |
198 |
| - "import matplotlib.pyplot as plt\n", |
199 |
| - "\n", |
200 |
| - "plt.rcParams[\"figure.figsize\"] = (2,10)\n", |
201 |
| - " \n", |
202 |
| - "def show_digit(img, caption='', subplot=None):\n", |
203 |
| - " if subplot==None:\n", |
204 |
| - " _,(subplot)=plt.subplots(1,1)\n", |
205 |
| - " imgr=img.reshape((28,28))\n", |
206 |
| - " subplot.axis('off')\n", |
207 |
| - " subplot.imshow(imgr, cmap='gray')\n", |
208 |
| - " plt.title(caption)\n", |
209 |
| - " \n", |
210 |
| - "input_data_file = '/tmp/data/mnist_data.csv'\n", |
211 |
| - "\n", |
212 |
| - "s3 = boto3.resource('s3')\n", |
213 |
| - "s3.Bucket('sagemaker-sample-data-{}'.format(region)).download_file('batch-transform/mnist/data.csv', input_data_file)\n", |
214 |
| - "\n", |
215 |
| - "input_data = genfromtxt(input_data_file, delimiter=',')\n", |
216 |
| - "show_digit(input_data)" |
| 180 | + "transformer.wait()" |
217 | 181 | ]
|
218 | 182 | },
|
219 | 183 | {
|
220 | 184 | "cell_type": "markdown",
|
221 | 185 | "metadata": {},
|
222 | 186 | "source": [
|
223 |
| - "Now we can use the Transformer to classify the handwritten digit:" |
| 187 | + "### Downloading the results\n", |
| 188 | + "\n", |
| 189 | + "The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:" |
224 | 190 | ]
|
225 | 191 | },
|
226 | 192 | {
|
|
229 | 195 | "metadata": {},
|
230 | 196 | "outputs": [],
|
231 | 197 | "source": [
|
232 |
| - "transformer.transform(transform_data_location, content_type='text/csv')" |
| 198 | + "print(transformer.output_path)" |
233 | 199 | ]
|
234 | 200 | },
|
235 | 201 | {
|
236 | 202 | "cell_type": "markdown",
|
237 | 203 | "metadata": {},
|
238 | 204 | "source": [
|
239 |
| - "Now we wait for the batch transform job to complete. We have a convenience method, `wait()`, that will block until the batch transform job has completed. We can call that here to see if the batch transform job is still running; the cell will finish running when the batch transform job has completed." |
| 205 | + "The output here will be a list of predictions, where each prediction is a list of probabilities, one for each possible label. Since we read the output as a string, we use `ast.literal_eval()` to turn it into a list and find the maximum element of the list gives us the predicted label. Here we define a convenience method to take the output and produce the predicted label." |
240 | 206 | ]
|
241 | 207 | },
|
242 | 208 | {
|
|
245 | 211 | "metadata": {},
|
246 | 212 | "outputs": [],
|
247 | 213 | "source": [
|
248 |
| - "transformer.wait()" |
| 214 | + "import ast\n", |
| 215 | + "\n", |
| 216 | + "def predicted_label(transform_output):\n", |
| 217 | + " output = ast.literal_eval(transform_output)\n", |
| 218 | + " probabilities = output[0]\n", |
| 219 | + " return probabilities.index(max(probabilities))" |
249 | 220 | ]
|
250 | 221 | },
|
251 | 222 | {
|
252 | 223 | "cell_type": "markdown",
|
253 | 224 | "metadata": {},
|
254 | 225 | "source": [
|
255 |
| - "### Downloading the results\n", |
256 |
| - "\n", |
257 |
| - "The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:" |
| 226 | + "Now let's download the first ten results from S3:" |
258 | 227 | ]
|
259 | 228 | },
|
260 | 229 | {
|
|
263 | 232 | "metadata": {},
|
264 | 233 | "outputs": [],
|
265 | 234 | "source": [
|
266 |
| - "print(transformer.output_path)" |
| 235 | + "import json\n", |
| 236 | + "from urllib.parse import urlparse\n", |
| 237 | + "\n", |
| 238 | + "import boto3\n", |
| 239 | + "\n", |
| 240 | + "parsed_url = urlparse(transformer.output_path)\n", |
| 241 | + "bucket_name = parsed_url.netloc\n", |
| 242 | + "prefix = parsed_url.path[1:]\n", |
| 243 | + "\n", |
| 244 | + "s3 = boto3.resource('s3')\n", |
| 245 | + "\n", |
| 246 | + "predictions = []\n", |
| 247 | + "for i in range(10):\n", |
| 248 | + " file_key = '{}/data-{}.csv.out'.format(prefix, i)\n", |
| 249 | + "\n", |
| 250 | + " output_obj = s3.Object(bucket_name, file_key)\n", |
| 251 | + " output = output_obj.get()[\"Body\"].read().decode('utf-8')\n", |
| 252 | + " \n", |
| 253 | + " predictions.append(predicted_label(output))" |
267 | 254 | ]
|
268 | 255 | },
|
269 | 256 | {
|
270 | 257 | "cell_type": "markdown",
|
271 | 258 | "metadata": {},
|
272 | 259 | "source": [
|
273 |
| - "We use that to download the results from S3:" |
| 260 | + "For demonstration purposes, we're also going to download the corresponding original input data so that we can see how the model did with its predictions." |
274 | 261 | ]
|
275 | 262 | },
|
276 | 263 | {
|
277 | 264 | "cell_type": "code",
|
278 | 265 | "execution_count": null,
|
279 |
| - "metadata": { |
280 |
| - "scrolled": true |
281 |
| - }, |
| 266 | + "metadata": {}, |
282 | 267 | "outputs": [],
|
283 | 268 | "source": [
|
284 |
| - "import json\n", |
285 |
| - "from urllib.parse import urlparse\n", |
286 |
| - " \n", |
287 |
| - "parsed_url = urlparse(transformer.output_path)\n", |
288 |
| - "bucket_name = parsed_url.netloc\n", |
289 |
| - "file_key = '{}/data.csv.out'.format(parsed_url.path[1:])\n", |
290 |
| - " \n", |
291 |
| - "s3 = boto3.resource('s3')\n", |
292 |
| - "output_obj = s3.Object(bucket_name, file_key)\n", |
293 |
| - "output = output_obj.get()[\"Body\"].read().decode('utf-8')" |
| 269 | + "import os\n", |
| 270 | + "\n", |
| 271 | + "tmp_dir = '/tmp/data'\n", |
| 272 | + "\n", |
| 273 | + "if not os.path.exists(tmp_dir):\n", |
| 274 | + " os.makedirs(tmp_dir)" |
294 | 275 | ]
|
295 | 276 | },
|
296 | 277 | {
|
297 | 278 | "cell_type": "markdown",
|
298 | 279 | "metadata": {},
|
299 | 280 | "source": [
|
300 |
| - "The output here is a list of predictions, where each prediction is a list of probabilities, one for each possible label. Since we read the output as a string, we use `ast.literal_eval()` to turn it into a list:" |
| 281 | + "And now we'll print out the images:" |
301 | 282 | ]
|
302 | 283 | },
|
303 | 284 | {
|
|
306 | 287 | "metadata": {},
|
307 | 288 | "outputs": [],
|
308 | 289 | "source": [
|
309 |
| - "import ast\n", |
| 290 | + "from numpy import genfromtxt\n", |
| 291 | + "import matplotlib.pyplot as plt\n", |
| 292 | + "\n", |
| 293 | + "plt.rcParams['figure.figsize'] = (2,10)\n", |
310 | 294 | "\n",
|
311 |
| - "output = ast.literal_eval(output)\n", |
312 |
| - "probabilities = output[0]" |
| 295 | + "def show_digit(img, caption='', subplot=None):\n", |
| 296 | + " if subplot == None:\n", |
| 297 | + " _,(subplot) = plt.subplots(1,1)\n", |
| 298 | + " imgr = img.reshape((28,28))\n", |
| 299 | + " subplot.axis('off')\n", |
| 300 | + " subplot.imshow(imgr, cmap='gray')\n", |
| 301 | + " plt.title(caption)\n", |
| 302 | + "\n", |
| 303 | + "for i in range(10):\n", |
| 304 | + " input_file_name = 'data-{}.csv'.format(i)\n", |
| 305 | + " input_file_key = '{}/{}'.format(input_file_path, input_file_name)\n", |
| 306 | + " \n", |
| 307 | + " s3.Bucket(sample_data_bucket).download_file(input_file_key, os.path.join(tmp_dir, input_file_name))\n", |
| 308 | + " input_data = genfromtxt(os.path.join(tmp_dir, input_file_name), delimiter=',')\n", |
| 309 | + "\n", |
| 310 | + " show_digit(input_data)" |
313 | 311 | ]
|
314 | 312 | },
|
315 | 313 | {
|
316 | 314 | "cell_type": "markdown",
|
317 |
| - "metadata": {}, |
| 315 | + "metadata": { |
| 316 | + "scrolled": true |
| 317 | + }, |
318 | 318 | "source": [
|
319 |
| - "Now that we have the list of probabilities, finding the maximum element of the list gives us the predicted label:" |
| 319 | + "Here, we can see the original labels are:\n", |
| 320 | + "\n", |
| 321 | + "```\n", |
| 322 | + "7, 2, 1, 0, 4, 1, 4, 9, 5, 9\n", |
| 323 | + "```\n", |
| 324 | + "\n", |
| 325 | + "Now let's print out the predictions to compare:" |
320 | 326 | ]
|
321 | 327 | },
|
322 | 328 | {
|
|
325 | 331 | "metadata": {},
|
326 | 332 | "outputs": [],
|
327 | 333 | "source": [
|
328 |
| - "prediction = probabilities.index(max(probabilities))\n", |
329 |
| - "print('Prediction is {}'.format(prediction))" |
| 334 | + "print(predictions)" |
330 | 335 | ]
|
331 | 336 | }
|
332 | 337 | ],
|
|
0 commit comments