Skip to content

Commit 263c12f

Browse files
committed
Add integration tests and refactor
1 parent 9409168 commit 263c12f

File tree

3 files changed

+43
-14
lines changed

3 files changed

+43
-14
lines changed

src/sagemaker/session.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,6 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
11241124
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
11251125
print(secondary_training_status_message(description, None), end='')
11261126
instance_count = description['ResourceConfig']['InstanceCount']
1127-
status = description['TrainingJobStatus']
11281127

11291128
stream_names = [] # The list of log streams
11301129
positions = {} # The current position in each stream, map of stream name -> position
@@ -1135,9 +1134,8 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
11351134
client = self.boto_session.client('logs', config=config)
11361135
log_group = '/aws/sagemaker/TrainingJobs'
11371136

1138-
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1137+
state = _get_initial_job_state(description, 'TrainingJobStatus', wait)
11391138

1140-
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
11411139
dot = False
11421140

11431141
color_wrap = sagemaker.logs.ColorWrap()
@@ -1211,7 +1209,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
12111209

12121210
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
12131211
instance_count = description['TransformResources']['InstanceCount']
1214-
status = description['TransformJobStatus']
12151212

12161213
stream_names = [] # The list of log streams
12171214
positions = {} # The current position in each stream, map of stream name -> position
@@ -1222,9 +1219,8 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
12221219
client = self.boto_session.client('logs', config=config)
12231220
log_group = '/aws/sagemaker/TransformJobs'
12241221

1225-
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1222+
state = _get_initial_job_state(description, 'TransformJobStatus', wait)
12261223

1227-
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
12281224
dot = False
12291225

12301226
color_wrap = sagemaker.logs.ColorWrap()
@@ -1272,9 +1268,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
12721268
self._check_job_status(job_name, description, 'TransformJobStatus')
12731269
if dot:
12741270
print()
1275-
# Customers are not billed for hardware provisioning, so billable time is less than total time
1276-
billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count
1277-
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
12781271

12791272

12801273
def container_def(image, model_data_url=None, env=None):
@@ -1644,6 +1637,12 @@ def _vpc_config_from_training_job(training_job_desc, vpc_config_override=vpc_uti
16441637
return vpc_utils.sanitize(vpc_config_override)
16451638

16461639

1640+
def _get_initial_job_state(description, status_key, wait):
    """Choose the initial log-streaming state for a job.

    Args:
        description (dict): Job description returned by the describe call for the job
            (for example ``describe_training_job`` or ``describe_transform_job``).
        status_key (str): Key in ``description`` holding the job status,
            for example 'TrainingJobStatus' or 'TransformJobStatus'.
        wait (bool): Whether the caller wants to keep tailing logs until the job finishes.

    Returns:
        ``LogState.TAILING`` when ``wait`` is truthy and the status is not one of
        'Completed', 'Failed' or 'Stopped'; ``LogState.COMPLETE`` otherwise.

    Raises:
        KeyError: If ``status_key`` is missing from ``description``.
    """
    status = description[status_key]
    # A job in any of these statuses will not produce further log output to wait for.
    job_already_completed = status in ('Completed', 'Failed', 'Stopped')
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1644+
1645+
16471646
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
16481647
if len(stream_names) < instance_count:
16491648
# Log streams are created whenever a container starts writing to stdout/err, so this list

src/sagemaker/transformer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
7979
self.sagemaker_session = sagemaker_session or Session()
8080

8181
def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
82-
job_name=None, wait=True, logs=True):
82+
job_name=None, wait=False, logs=False):
8383
"""Start a new transform job.
8484
8585
Args:
@@ -97,9 +97,9 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
9797
split_type (str): The record delimiter for the input object (default: 'None').
9898
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
9999
job_name (str): job name (default: None). If not specified, one will be generated.
100-
wait (bool): Whether the call should wait until the job completes (default: True).
100+
wait (bool): Whether the call should wait until the job completes (default: False).
101101
logs (bool): Whether to show the logs produced by the job.
102-
Only meaningful when wait is True (default: True).
102+
Only meaningful when wait is True (default: False).
103103
"""
104104
local_mode = self.sagemaker_session.local_mode
105105
if not local_mode and not data.startswith('s3://'):

tests/integ/test_transformer.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,37 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
148148
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
149149

150150

151-
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
151+
def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
    """Train an MXNet MNIST estimator, then run a transform job with ``wait=True, logs=True``.

    The transform job is started through ``_create_transformer_and_transform_job`` with
    waiting and log streaming switched on, and ``transformer.wait()`` is then called
    inside ``timeout_and_delete_model_with_transformer``.
    """
    mnist_dir = os.path.join(DATA_DIR, 'mxnet_mnist')
    entry_script = os.path.join(mnist_dir, 'mnist.py')

    estimator = MXNet(
        entry_point=entry_script,
        role='SageMakerRole',
        train_instance_count=1,
        train_instance_type='ml.c4.xlarge',
        sagemaker_session=sagemaker_session,
        framework_version=mxnet_full_version,
    )

    # Upload the training channels first; the job name is generated afterwards, as before.
    train_s3 = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'train'),
        key_prefix='integ-test-data/mxnet_mnist/train',
    )
    test_s3 = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'test'),
        key_prefix='integ-test-data/mxnet_mnist/test',
    )
    training_job_name = unique_name_from_base('test-mxnet-transform')

    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        estimator.fit({'train': train_s3, 'test': test_s3}, job_name=training_job_name)

    local_transform_file = os.path.join(mnist_dir, 'transform', 'data.csv')
    transform_prefix = 'integ-test-data/mxnet_mnist/transform'
    transform_s3 = estimator.sagemaker_session.upload_data(
        path=local_transform_file,
        key_prefix=transform_prefix,
    )

    # wait=True and logs=True exercise the transform-job log streaming path.
    with timeout(minutes=45):
        transformer = _create_transformer_and_transform_job(
            estimator, transform_s3, wait=True, logs=True)

    with timeout_and_delete_model_with_transformer(
            transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
        transformer.wait()
179+
180+
181+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
                                          wait=False, logs=False):
    """Create a transformer from ``estimator`` and start one transform job on ``transform_input``.

    Args:
        estimator: Object exposing ``transformer(instance_count, instance_type, volume_kms_key=...)``.
        transform_input (str): Input location passed straight to ``transformer.transform``.
        volume_kms_key (str): Forwarded to ``estimator.transformer`` (default: None).
        wait (bool): Forwarded to ``transform`` (default: False).
        logs (bool): Forwarded to ``transform`` (default: False).

    Returns:
        The transformer returned by ``estimator.transformer``, after ``transform`` has been called on it.
    """
    batch_transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
    # Start the job exactly once, honouring the caller's wait/logs choice.
    batch_transformer.transform(transform_input, content_type='text/csv', wait=wait, logs=logs)
    return batch_transformer

0 commit comments

Comments
 (0)