Skip to content

Commit 7333063

Browse files
committed
black check
1 parent 103a06d commit 7333063

File tree

4 files changed

+60
-53
lines changed

4 files changed

+60
-53
lines changed

src/sagemaker/session.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1062,15 +1062,17 @@ def stop_transform_job(self, name):
10621062
ClientError: If an error occurs while trying to stop the batch transform job.
10631063
"""
10641064
try:
1065-
LOGGER.info('Stopping transform job: {}'.format(name))
1065+
LOGGER.info("Stopping transform job: {}".format(name))
10661066
self.sagemaker_client.stop_transform_job(TransformJobName=name)
10671067
except ClientError as e:
1068-
error_code = e.response['Error']['Code']
1068+
error_code = e.response["Error"]["Code"]
10691069
# allow to pass if the job already stopped
1070-
if error_code == 'ValidationException':
1071-
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
1070+
if error_code == "ValidationException":
1071+
LOGGER.info("Transform job: {} is already stopped or not running.".format(name))
10721072
else:
1073-
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
1073+
LOGGER.error(
1074+
"Error occurred while attempting to stop transform job: {}.".format(name)
1075+
)
10741076
raise
10751077

10761078
def _check_job_status(self, job, desc, status_key_name):

src/sagemaker/transformer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,13 @@ def wait(self):
228228
self._ensure_last_transform_job()
229229
self.latest_transform_job.wait()
230230

231-
def stop_transform_job(self):
231+
def stop_transform_job(self, wait=True):
232232
"""Stop latest running batch transform job.
233233
"""
234234
self._ensure_last_transform_job()
235235
self.latest_transform_job.stop()
236+
if wait:
237+
self.latest_transform_job.wait()
236238

237239
def _ensure_last_transform_job(self):
238240
"""Placeholder docstring"""
@@ -352,6 +354,7 @@ def wait(self):
352354
self.sagemaker_session.wait_for_transform_job(self.job_name)
353355

354356
def stop(self):
357+
"""Placeholder docstring"""
355358
self.sagemaker_session.stop_transform_job(name=self.job_name)
356359

357360
@staticmethod

tests/integ/test_transformer.py

+47-45
Original file line numberDiff line numberDiff line change
@@ -351,66 +351,68 @@ def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version)
351351
)
352352

353353

354-
def _create_transformer_and_transform_job(
355-
estimator,
356-
transform_input,
357-
volume_kms_key=None,
358-
input_filter=None,
359-
output_filter=None,
360-
join_source=None,
361-
):
362-
transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key)
363-
transformer.transform(
364-
transform_input,
365-
content_type="text/csv",
366-
input_filter=input_filter,
367-
output_filter=output_filter,
368-
join_source=join_source,
369-
)
370-
371-
372354
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
373-
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
374-
script_path = os.path.join(data_path, 'mnist.py')
375-
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
355+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
356+
script_path = os.path.join(data_path, "mnist.py")
357+
tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
376358

377-
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
378-
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
379-
framework_version=mxnet_full_version)
359+
mx = MXNet(
360+
entry_point=script_path,
361+
role="SageMakerRole",
362+
train_instance_count=1,
363+
train_instance_type="ml.c4.xlarge",
364+
sagemaker_session=sagemaker_session,
365+
framework_version=mxnet_full_version,
366+
)
380367

381-
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
382-
key_prefix='integ-test-data/mxnet_mnist/train')
383-
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
384-
key_prefix='integ-test-data/mxnet_mnist/test')
385-
job_name = unique_name_from_base('test-mxnet-transform')
368+
train_input = mx.sagemaker_session.upload_data(
369+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
370+
)
371+
test_input = mx.sagemaker_session.upload_data(
372+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
373+
)
374+
job_name = unique_name_from_base("test-mxnet-transform")
386375

387376
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
388-
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
377+
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
389378

390-
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
391-
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
392-
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
393-
key_prefix=transform_input_key_prefix)
379+
transform_input_path = os.path.join(data_path, "transform", "data.csv")
380+
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
381+
transform_input = mx.sagemaker_session.upload_data(
382+
path=transform_input_path, key_prefix=transform_input_key_prefix
383+
)
394384

395-
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
396-
transformer.transform(transform_input, content_type='text/csv')
385+
transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags)
386+
transformer.transform(transform_input, content_type="text/csv")
397387

398388
time.sleep(15)
399389

400390
latest_transform_job_name = transformer.latest_transform_job.name
401391

402-
print('Attempting to stop {}'.format(latest_transform_job_name))
392+
print("Attempting to stop {}".format(latest_transform_job_name))
403393

404394
transformer.stop_transform_job()
405395

406-
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \
407-
.describe_transform_job(TransformJobName=latest_transform_job_name)
408-
assert desc['TransformJobStatus'] == 'Stopping'
396+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job(
397+
TransformJobName=latest_transform_job_name
398+
)
399+
assert desc["TransformJobStatus"] == "Stopping"
409400

410401

411-
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
412-
input_filter=None, output_filter=None, join_source=None):
413-
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
414-
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
415-
output_filter=output_filter, join_source=join_source)
402+
def _create_transformer_and_transform_job(
403+
estimator,
404+
transform_input,
405+
volume_kms_key=None,
406+
input_filter=None,
407+
output_filter=None,
408+
join_source=None,
409+
):
410+
transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key)
411+
transformer.transform(
412+
transform_input,
413+
content_type="text/csv",
414+
input_filter=input_filter,
415+
output_filter=output_filter,
416+
join_source=join_source,
417+
)
416418
return transformer

tests/unit/test_transformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session):
452452

453453

454454
def test_stop_transform_job(sagemaker_session, transformer):
455-
sagemaker_session.stop_transform_job = Mock(name='stop_transform_job')
455+
sagemaker_session.stop_transform_job = Mock(name="stop_transform_job")
456456
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
457457

458458
transformer.stop_transform_job()
@@ -463,4 +463,4 @@ def test_stop_transform_job(sagemaker_session, transformer):
463463
def test_stop_transform_job_no_transform_job(transformer):
464464
with pytest.raises(ValueError) as e:
465465
transformer.stop_transform_job()
466-
assert 'No transform job available' in str(e)
466+
assert "No transform job available" in str(e)

0 commit comments

Comments
 (0)