|
22 | 22 | from sagemaker import KMeans
|
23 | 23 | from sagemaker.mxnet import MXNet
|
24 | 24 | from sagemaker.transformer import Transformer
|
| 25 | +from sagemaker.estimator import Estimator |
25 | 26 | from sagemaker.utils import unique_name_from_base
|
26 | 27 | from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
|
27 | 28 | from tests.integ.kms_utils import get_or_create_kms_key
|
@@ -148,7 +149,7 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
|
148 | 149 | assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
|
149 | 150 |
|
150 | 151 |
|
151 |
| -def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): |
| 152 | +def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version): |
152 | 153 | data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
|
153 | 154 | script_path = os.path.join(data_path, 'mnist.py')
|
154 | 155 | tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
|
@@ -182,6 +183,55 @@ def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
|
182 | 183 | assert tags == model_tags
|
183 | 184 |
|
184 | 185 |
|
| 186 | +def test_transform_byo_estimator(sagemaker_session): |
| 187 | + data_path = os.path.join(DATA_DIR, 'one_p_mnist') |
| 188 | + pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} |
| 189 | + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] |
| 190 | + |
| 191 | + # Load the data into memory as numpy arrays |
| 192 | + train_set_path = os.path.join(data_path, 'mnist.pkl.gz') |
| 193 | + with gzip.open(train_set_path, 'rb') as f: |
| 194 | + train_set, _, _ = pickle.load(f, **pickle_args) |
| 195 | + |
| 196 | + kmeans = KMeans(role='SageMakerRole', train_instance_count=1, |
| 197 | + train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session, |
| 198 | + output_path='s3://{}/'.format(sagemaker_session.default_bucket())) |
| 199 | + |
| 200 | + # set kmeans specific hp |
| 201 | + kmeans.init_method = 'random' |
| 202 | + kmeans.max_iterators = 1 |
| 203 | + kmeans.tol = 1 |
| 204 | + kmeans.num_trials = 1 |
| 205 | + kmeans.local_init_method = 'kmeans++' |
| 206 | + kmeans.half_life_time_size = 1 |
| 207 | + kmeans.epochs = 1 |
| 208 | + |
| 209 | + records = kmeans.record_set(train_set[0][:100]) |
| 210 | + |
| 211 | + job_name = unique_name_from_base('test-kmeans-attach') |
| 212 | + |
| 213 | + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): |
| 214 | + kmeans.fit(records, job_name=job_name) |
| 215 | + |
| 216 | + transform_input_path = os.path.join(data_path, 'transform_input.csv') |
| 217 | + transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform' |
| 218 | + transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path, |
| 219 | + key_prefix=transform_input_key_prefix) |
| 220 | + |
| 221 | + estimator = Estimator.attach(training_job_name=job_name, |
| 222 | + sagemaker_session=sagemaker_session) |
| 223 | + |
| 224 | + transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags) |
| 225 | + transformer.transform(transform_input, content_type='text/csv') |
| 226 | + |
| 227 | + with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, |
| 228 | + minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): |
| 229 | + transformer.wait() |
| 230 | + model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name) |
| 231 | + model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags'] |
| 232 | + assert tags == model_tags |
| 233 | + |
| 234 | + |
185 | 235 | def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
|
186 | 236 | transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
|
187 | 237 | transformer.transform(transform_input, content_type='text/csv')
|
|
0 commit comments