
Commit 956a733

Add integration tests and refactor
1 parent 813155e commit 956a733

File tree: src/sagemaker/session.py · src/sagemaker/transformer.py · tests/integ/test_transformer.py

3 files changed: +45 −12


src/sagemaker/session.py (+8 −10)

@@ -1295,11 +1295,8 @@ def logs_for_job(  # noqa: C901 - suppress complexity warning for this method
         client = self.boto_session.client("logs", config=config)
         log_group = "/aws/sagemaker/TrainingJobs"
 
-        job_already_completed = (
-            True if status == "Completed" or status == "Failed" or status == "Stopped" else False
-        )
+        state = _get_initial_job_state(description, 'TrainingJobStatus', wait)
 
-        state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
         dot = False
 
         color_wrap = sagemaker.logs.ColorWrap()
@@ -1375,7 +1372,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10):  # noqa: C901 -
 
         description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
         instance_count = description['TransformResources']['InstanceCount']
-        status = description['TransformJobStatus']
 
         stream_names = []  # The list of log streams
         positions = {}  # The current position in each stream, map of stream name -> position
@@ -1386,9 +1382,8 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10):  # noqa: C901 -
         client = self.boto_session.client('logs', config=config)
         log_group = '/aws/sagemaker/TransformJobs'
 
-        job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
+        state = _get_initial_job_state(description, 'TransformJobStatus', wait)
 
-        state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
         dot = False
 
         color_wrap = sagemaker.logs.ColorWrap()
@@ -1436,9 +1431,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10):  # noqa: C901 -
         self._check_job_status(job_name, description, 'TransformJobStatus')
         if dot:
             print()
-        # Customers are not billed for hardware provisioning, so billable time is less than total time
-        billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count
-        print('Billable seconds:', int(billable_time.total_seconds()) + 1)
 
 
 def container_def(image, model_data_url=None, env=None):
@@ -1833,6 +1825,12 @@ def _vpc_config_from_training_job(
     return vpc_utils.sanitize(vpc_config_override)
 
 
+def _get_initial_job_state(description, status_key, wait):
+    status = description[status_key]
+    job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
+    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
+
+
 def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
     if len(stream_names) < instance_count:
         # Log streams are created whenever a container starts writing to stdout/err, so this list
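
Both logs_for_job and logs_for_transform_job now derive their starting LogState from the extracted _get_initial_job_state helper instead of computing it inline. A minimal sketch of the helper's behavior, assuming the post-commit sagemaker.session module; the description dicts below are hypothetical stand-ins for describe_training_job / describe_transform_job responses:

    # Sketch only: illustrates the private helper added in this commit, not a supported public API.
    from sagemaker.session import LogState, _get_initial_job_state

    running = {'TrainingJobStatus': 'InProgress'}   # hypothetical in-progress training job description
    finished = {'TransformJobStatus': 'Completed'}  # hypothetical terminal transform job description

    assert _get_initial_job_state(running, 'TrainingJobStatus', wait=True) == LogState.TAILING
    assert _get_initial_job_state(running, 'TrainingJobStatus', wait=False) == LogState.COMPLETE
    assert _get_initial_job_state(finished, 'TransformJobStatus', wait=True) == LogState.COMPLETE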

src/sagemaker/transformer.py (+2 −2)

@@ -133,9 +133,9 @@ def transform(
                 meaning the entire input record will be joined to the inference result.
                 You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
                 Valid values: Input, None.
-            wait (bool): Whether the call should wait until the job completes (default: True).
+            wait (bool): Whether the call should wait until the job completes (default: False).
             logs (bool): Whether to show the logs produced by the job.
-                Only meaningful when wait is True (default: True).
+                Only meaningful when wait is True (default: False).
         """
         local_mode = self.sagemaker_session.local_mode
         if not local_mode and not data.startswith("s3://"):
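
With the documented defaults changed to False, transform() no longer implies blocking or log streaming, so callers that want the previous behavior must opt in explicitly. A small usage sketch, assuming an existing Transformer instance; the S3 URI and content type below are placeholders:

    # Hypothetical call site; bucket, key, and content type are illustrative only.
    transformer.transform('s3://my-bucket/transform-input/data.csv',
                          content_type='text/csv',
                          wait=True,   # block until the transform job finishes
                          logs=True)   # stream the job's CloudWatch logs while waiting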

tests/integ/test_transformer.py (+35 −0)

@@ -301,13 +301,45 @@ def test_transform_byo_estimator(sagemaker_session):
     assert tags == model_tags
 
 
+def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
+    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
+    script_path = os.path.join(data_path, 'mnist.py')
+
+    mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
+               train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
+               framework_version=mxnet_full_version)
+
+    train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
+                                                   key_prefix='integ-test-data/mxnet_mnist/train')
+    test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
+                                                  key_prefix='integ-test-data/mxnet_mnist/test')
+    job_name = unique_name_from_base('test-mxnet-transform')
+
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
+
+    transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
+    transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
+    transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
+                                                       key_prefix=transform_input_key_prefix)
+
+    with timeout(minutes=45):
+        transformer = _create_transformer_and_transform_job(mx, transform_input, wait=True, logs=True)
+
+    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
+                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+        transformer.wait()
+
+
 def _create_transformer_and_transform_job(
     estimator,
     transform_input,
     volume_kms_key=None,
     input_filter=None,
     output_filter=None,
     join_source=None,
+    wait=False,
+    logs=False,
 ):
     transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key)
     transformer.transform(
@@ -316,5 +348,8 @@ def _create_transformer_and_transform_job(
         input_filter=input_filter,
         output_filter=output_filter,
         join_source=join_source,
+        wait=wait,
+        logs=logs,
     )
     return transformer
+
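
The new integration test drives the wait/logs path end to end through fit, transform, and transformer.wait(). A plausible way to run just that test locally; the invocation below is an assumption, not part of this commit, and presumes AWS credentials, a default region, and the SageMakerRole IAM role referenced by the test are already configured:

    # Assumed local runner for the new test; requires pytest and the repo's integ-test setup.
    import pytest

    pytest.main(['tests/integ/test_transformer.py::test_transform_mxnet_logs', '-v'])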
