|
16 | 16 | import os
|
17 | 17 | import pickle
|
18 | 18 | import sys
|
| 19 | +import time |
19 | 20 |
|
20 | 21 | import pytest
|
21 | 22 |
|
@@ -364,4 +365,50 @@ def _create_transformer_and_transform_job(
|
364 | 365 | output_filter=output_filter,
|
365 | 366 | join_source=join_source,
|
366 | 367 | )
|
| 368 | + |
| 369 | + |
| 370 | +def test_stop_transform_job(sagemaker_session, mxnet_full_version): |
| 371 | + data_path = os.path.join(DATA_DIR, 'mxnet_mnist') |
| 372 | + script_path = os.path.join(data_path, 'mnist.py') |
| 373 | + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] |
| 374 | + |
| 375 | + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, |
| 376 | + train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, |
| 377 | + framework_version=mxnet_full_version) |
| 378 | + |
| 379 | + train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), |
| 380 | + key_prefix='integ-test-data/mxnet_mnist/train') |
| 381 | + test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), |
| 382 | + key_prefix='integ-test-data/mxnet_mnist/test') |
| 383 | + job_name = unique_name_from_base('test-mxnet-transform') |
| 384 | + |
| 385 | + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): |
| 386 | + mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) |
| 387 | + |
| 388 | + transform_input_path = os.path.join(data_path, 'transform', 'data.csv') |
| 389 | + transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' |
| 390 | + transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, |
| 391 | + key_prefix=transform_input_key_prefix) |
| 392 | + |
| 393 | + transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags) |
| 394 | + transformer.transform(transform_input, content_type='text/csv') |
| 395 | + |
| 396 | + time.sleep(15) |
| 397 | + |
| 398 | + latest_transform_job_name = transformer.latest_transform_job.name |
| 399 | + |
| 400 | + print('Attempting to stop {}'.format(latest_transform_job_name)) |
| 401 | + |
| 402 | + transformer.stop_transform_job() |
| 403 | + |
| 404 | + desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \ |
| 405 | + .describe_transform_job(TransformJobName=latest_transform_job_name) |
| 406 | + assert desc['TransformJobStatus'] == 'Stopping' |
| 407 | + |
| 408 | + |
| 409 | +def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, |
| 410 | + input_filter=None, output_filter=None, join_source=None): |
| 411 | + transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) |
| 412 | + transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, |
| 413 | + output_filter=output_filter, join_source=join_source) |
367 | 414 | return transformer
|
0 commit comments