@@ -62,6 +62,19 @@
     "bucket='<bucket-name>'"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "output_location = 's3://{}/kmeans_highlevel_example/output'.format(bucket)\n",
+    "data_location = 's3://{}/kmeans_highlevel_example/data'.format(bucket)\n",
+    "\n",
+    "print('training data will be uploaded to: {}'.format(data_location))\n",
+    "print('training artifacts will be uploaded to: {}'.format(output_location))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
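The added cell builds the two S3 prefixes with plain `str.format`. A minimal standalone sketch of what it produces, using a placeholder bucket name (in the notebook, `bucket` is set in an earlier cell):

```python
# Placeholder bucket; in the notebook, `bucket` comes from an earlier setup cell.
bucket = 'example-bucket'

# Same formatting as the added cell: one prefix for training data, one for
# trained model artifacts.
output_location = 's3://{}/kmeans_highlevel_example/output'.format(bucket)
data_location = 's3://{}/kmeans_highlevel_example/data'.format(bucket)

print('training data will be uploaded to: {}'.format(data_location))
print('training artifacts will be uploaded to: {}'.format(output_location))
```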
|
|
@@ -117,69 +130,13 @@
     "show_digit(train_set[0][30], 'This is a {}'.format(train_set[1][30]))"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Data conversion\n",
-    "\n",
-    "Since algorithms have particular input and output requirements, converting the dataset is also part of the process that a data scientist goes through prior to initiating training. In this particular case, the hosted implementation of k-means takes recordio-wrapped protobuf, where the data we have today is a pickle-ized numpy array on disk.\n",
-    "\n",
-    "Some of the effort involved in the protobuf format conversion is hidden in a library that is imported, below. This library will be folded into the SDK for algorithm authors to make it easier for algorithm authors to support multiple formats. This doesn't __prevent__ algorithm authors from requiring non-standard formats, but it encourages them to support the standard ones.\n",
-    "\n",
-    "For this dataset, conversion takes approximately one minute."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%%time\n",
-    "import io\n",
-    "import sagemaker.kmeans\n",
-    "\n",
-    "vectors = [t.tolist() for t in train_set[0]]\n",
-    "labels = [t.tolist() for t in train_set[1]]\n",
-    "\n",
-    "buf = io.BytesIO()\n",
-    "sagemaker.kmeans.write_data_as_pb_recordio(vectors, labels, buf)\n",
-    "buf.seek(0)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "## Upload training data"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%%time\n",
-    "import boto3\n",
-    "\n",
-    "key = 'MNIST-1P-Test/recordio-pb-data'\n",
-    "boto3.resource('s3').Bucket(bucket).Object(key).upload_fileobj(buf)\n",
-    "s3_train_data = 's3://{}/{}'.format(bucket, key)\n",
-    "print('uploaded training data location: {}'.format(s3_train_data))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "output_location = 's3://{}/kmeansoutput'.format(bucket)\n",
-    "print('training artifacts will be uploaded to: {}'.format(output_location))"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
|
|
@@ -197,14 +154,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sagemaker.kmeans import KMeans\n",
+    "from sagemaker import KMeans\n",
     "\n",
     "kmeans = KMeans(role=role,\n",
     "                train_instance_count=2,\n",
     "                train_instance_type='ml.c4.8xlarge',\n",
     "                output_path=output_location,\n",
     "                k=10,\n",
-    "                feature_dim=784)"
+    "                data_location=data_location)"
    ]
   },
   {
|
|
@@ -215,7 +172,7 @@
    "source": [
     "%%time\n",
     "\n",
-    "kmeans.fit({'train': s3_train_data})"
+    "kmeans.fit(kmeans.record_set(train_set[0]))"
    ]
   },
   {
|
|
@@ -279,7 +236,7 @@
     "%%time \n",
     "\n",
     "result = kmeans_predictor.predict(valid_set[0][0:100])\n",
-    "clusters = result['labels']"
+    "clusters = [r.label['closest_cluster'].float32_tensor.values[0] for r in result]"
    ]
   },
   {
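The updated prediction cell unpacks protobuf records with a comprehension instead of indexing a dict. A sketch of the attribute path that comprehension walks, using a hypothetical mocked record (`make_record` is illustrative only, not part of the SageMaker SDK):

```python
from types import SimpleNamespace

# Hypothetical stand-in for a deserialized SageMaker record: each record
# exposes label['closest_cluster'].float32_tensor.values, a list of floats.
def make_record(cluster_id):
    tensor = SimpleNamespace(float32_tensor=SimpleNamespace(values=[float(cluster_id)]))
    return SimpleNamespace(label={'closest_cluster': tensor})

# Pretend these came back from kmeans_predictor.predict(...).
result = [make_record(c) for c in (3, 7, 3)]

# The comprehension from the updated cell, applied to the mocked records.
clusters = [r.label['closest_cluster'].float32_tensor.values[0] for r in result]
print(clusters)  # → [3.0, 7.0, 3.0]
```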
|
|
@@ -349,7 +306,6 @@
   }
  ],
  "metadata": {
-  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.",
   "kernelspec": {
    "display_name": "Environment (conda_python3)",
    "language": "python",
@@ -366,7 +322,8 @@
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.6.3"
-  }
+  },
+  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
  },
  "nbformat": 4,
  "nbformat_minor": 2
|