Skip to content

Commit 03d7a05

Browse files
committed
feature: Estimator.fit like logs for transformer
1 parent d1e2ab7 commit 03d7a05

File tree

3 files changed

+219
-36
lines changed

3 files changed

+219
-36
lines changed

src/sagemaker/session.py

+116-31
Original file line numberDiff line numberDiff line change
@@ -1164,37 +1164,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
11641164
last_describe_job_call = time.time()
11651165
last_description = description
11661166
while True:
1167-
if len(stream_names) < instance_count:
1168-
# Log streams are created whenever a container starts writing to stdout/err, so this list
1169-
# may be dynamic until we have a stream for every instance.
1170-
try:
1171-
streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
1172-
orderBy='LogStreamName', limit=instance_count)
1173-
stream_names = [s['logStreamName'] for s in streams['logStreams']]
1174-
positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0))
1175-
for s in stream_names if s not in positions])
1176-
except ClientError as e:
1177-
# On the very first training job run on an account, there's no log group until
1178-
# the container starts logging, so ignore any errors thrown about that
1179-
err = e.response.get('Error', {})
1180-
if err.get('Code', None) != 'ResourceNotFoundException':
1181-
raise
1182-
1183-
if len(stream_names) > 0:
1184-
if dot:
1185-
print('')
1186-
dot = False
1187-
for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
1188-
color_wrap(idx, event['message'])
1189-
ts, count = positions[stream_names[idx]]
1190-
if event['timestamp'] == ts:
1191-
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
1192-
else:
1193-
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
1194-
else:
1195-
dot = True
1196-
print('.', end='')
1197-
sys.stdout.flush()
1167+
_flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap)
11981168
if state == LogState.COMPLETE:
11991169
break
12001170

@@ -1225,6 +1195,87 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
12251195
billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
12261196
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
12271197

1198+
def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning
1199+
"""Display the logs for a given transform job, optionally tailing them until the
1200+
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
1201+
based on which instance the log entry is from.
1202+
1203+
Args:
1204+
job_name (str): Name of the transform job to display the logs for.
1205+
wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
1206+
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1207+
1208+
Raises:
1209+
ValueError: If the transform job fails.
1210+
"""
1211+
1212+
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1213+
instance_count = description['TransformResources']['InstanceCount']
1214+
status = description['TransformJobStatus']
1215+
1216+
stream_names = [] # The list of log streams
1217+
positions = {} # The current position in each stream, map of stream name -> position
1218+
1219+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1220+
# to be interrupted by a transient exception.
1221+
config = botocore.config.Config(retries={'max_attempts': 15})
1222+
client = self.boto_session.client('logs', config=config)
1223+
log_group = '/aws/sagemaker/TransformJobs'
1224+
1225+
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1226+
1227+
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1228+
dot = False
1229+
1230+
color_wrap = sagemaker.logs.ColorWrap()
1231+
1232+
# The loop below implements a state machine that alternates between checking the job status and
1233+
# reading whatever is available in the logs at this point. Note, that if we were called with
1234+
# wait == False, we never check the job status.
1235+
#
1236+
# If wait == TRUE and job is not completed, the initial state is TAILING
1237+
# If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
1238+
#
1239+
# The state table:
1240+
#
1241+
# STATE ACTIONS CONDITION NEW STATE
1242+
# ---------------- ---------------- ----------------- ----------------
1243+
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
1244+
# Else TAILING
1245+
# JOB_COMPLETE Read logs, Pause Any COMPLETE
1246+
# COMPLETE Read logs, Exit N/A
1247+
#
1248+
# Notes:
1249+
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
1250+
# the job was marked complete.
1251+
last_describe_job_call = time.time()
1252+
while True:
1253+
_flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap)
1254+
if state == LogState.COMPLETE:
1255+
break
1256+
1257+
time.sleep(poll)
1258+
1259+
if state == LogState.JOB_COMPLETE:
1260+
state = LogState.COMPLETE
1261+
elif time.time() - last_describe_job_call >= 30:
1262+
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1263+
last_describe_job_call = time.time()
1264+
1265+
status = description['TransformJobStatus']
1266+
1267+
if status == 'Completed' or status == 'Failed' or status == 'Stopped':
1268+
print()
1269+
state = LogState.JOB_COMPLETE
1270+
1271+
if wait:
1272+
self._check_job_status(job_name, description, 'TransformJobStatus')
1273+
if dot:
1274+
print()
1275+
# Customers are not billed for hardware provisioning, so billable time is less than total time
1276+
billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count
1277+
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
1278+
12281279

12291280
def container_def(image, model_data_url=None, env=None):
12301281
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1591,3 +1642,37 @@ def _vpc_config_from_training_job(training_job_desc, vpc_config_override=vpc_uti
15911642
return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
15921643
else:
15931644
return vpc_utils.sanitize(vpc_config_override)
1645+
1646+
1647+
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
1648+
if len(stream_names) < instance_count:
1649+
# Log streams are created whenever a container starts writing to stdout/err, so this list
1650+
# may be dynamic until we have a stream for every instance.
1651+
try:
1652+
streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
1653+
orderBy='LogStreamName', limit=instance_count)
1654+
stream_names = [s['logStreamName'] for s in streams['logStreams']]
1655+
positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0))
1656+
for s in stream_names if s not in positions])
1657+
except ClientError as e:
1658+
# On the very first training job run on an account, there's no log group until
1659+
# the container starts logging, so ignore any errors thrown about that
1660+
err = e.response.get('Error', {})
1661+
if err.get('Code', None) != 'ResourceNotFoundException':
1662+
raise
1663+
1664+
if len(stream_names) > 0:
1665+
if dot:
1666+
print('')
1667+
dot = False
1668+
for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
1669+
color_wrap(idx, event['message'])
1670+
ts, count = positions[stream_names[idx]]
1671+
if event['timestamp'] == ts:
1672+
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
1673+
else:
1674+
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
1675+
else:
1676+
dot = True
1677+
print('.', end='')
1678+
sys.stdout.flush()

src/sagemaker/transformer.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
7979
self.sagemaker_session = sagemaker_session or Session()
8080

8181
def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
82-
job_name=None):
82+
job_name=None, wait=True, logs=True):
8383
"""Start a new transform job.
8484
8585
Args:
@@ -97,6 +97,9 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
9797
split_type (str): The record delimiter for the input object (default: 'None').
9898
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
9999
job_name (str): job name (default: None). If not specified, one will be generated.
100+
wait (bool): Whether the call should wait until the job completes (default: True).
101+
logs (bool): Whether to show the logs produced by the job.
102+
Only meaningful when wait is True (default: True).
100103
"""
101104
local_mode = self.sagemaker_session.local_mode
102105
if not local_mode and not data.startswith('s3://'):
@@ -113,6 +116,8 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
113116

114117
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
115118
split_type)
119+
if wait:
120+
self.latest_transform_job.wait(logs=logs)
116121

117122
def delete_model(self):
118123
"""Delete the corresponding SageMaker model for this Transformer.
@@ -130,9 +135,9 @@ def _retrieve_image_name(self):
130135
'Local instance types require locally created models.'
131136
% self.model_name)
132137

133-
def wait(self):
138+
def wait(self, logs=True):
134139
self._ensure_last_transform_job()
135-
self.latest_transform_job.wait()
140+
self.latest_transform_job.wait(logs=logs)
136141

137142
def _ensure_last_transform_job(self):
138143
if self.latest_transform_job is None:
@@ -205,8 +210,11 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type,
205210

206211
return cls(transformer.sagemaker_session, transformer._current_job_name)
207212

208-
def wait(self):
209-
self.sagemaker_session.wait_for_transform_job(self.job_name)
213+
def wait(self, logs=True):
214+
if logs:
215+
self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True)
216+
else:
217+
self.sagemaker_session.wait_for_transform_job(self.job_name)
210218

211219
@staticmethod
212220
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):

tests/unit/test_session.py

+90
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,38 @@ def test_s3_input_all_arguments():
308308
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
309309
IN_PROGRESS_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'InProgress'})
310310

311+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = {
312+
'TransformJobStatus': 'Completed',
313+
'ModelName': 'some-model',
314+
'TransformJobName': JOB_NAME,
315+
'TransformResources': {
316+
'InstanceCount': INSTANCE_COUNT,
317+
'InstanceType': INSTANCE_TYPE
318+
},
319+
'TransformEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000),
320+
'TransformStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000),
321+
'TransformOutput': {
322+
'AssembleWith': 'None',
323+
'KmsKeyId': '',
324+
'S3OutputPath': S3_OUTPUT
325+
},
326+
'TransformInput': {
327+
'CompressionType': 'None',
328+
'ContentType': 'text/csv',
329+
'DataSource': {
330+
'S3DataType': 'S3Prefix',
331+
'S3Uri': S3_INPUT_URI
332+
},
333+
'SplitType': 'Line'
334+
}
335+
}
336+
337+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
338+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'Stopped'})
339+
340+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
341+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'InProgress'})
342+
311343

312344
@pytest.fixture()
313345
def sagemaker_session():
@@ -653,6 +685,7 @@ def sagemaker_session_complete():
653685
boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
654686
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
655687
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
688+
ims.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT
656689
return ims
657690

658691

@@ -663,6 +696,7 @@ def sagemaker_session_stopped():
663696
boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
664697
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
665698
ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
699+
ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT
666700
return ims
667701

668702

@@ -675,6 +709,9 @@ def sagemaker_session_ready_lifecycle():
675709
ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
676710
IN_PROGRESS_DESCRIBE_JOB_RESULT,
677711
COMPLETED_DESCRIBE_JOB_RESULT]
712+
ims.sagemaker_client.describe_transform_job.side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
713+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
714+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT]
678715
return ims
679716

680717

@@ -687,6 +724,9 @@ def sagemaker_session_full_lifecycle():
687724
ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
688725
IN_PROGRESS_DESCRIBE_JOB_RESULT,
689726
COMPLETED_DESCRIBE_JOB_RESULT]
727+
ims.sagemaker_client.describe_transform_job.side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
728+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
729+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT]
690730
return ims
691731

692732

@@ -740,6 +780,56 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
740780
call(0, 'hi there #2a'), call(0, 'hi there #3')]
741781

742782

783+
@patch('sagemaker.logs.ColorWrap')
784+
def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete):
785+
ims = sagemaker_session_complete
786+
ims.logs_for_transform_job(JOB_NAME)
787+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
788+
cw().assert_called_with(0, 'hi there #1')
789+
790+
791+
@patch('sagemaker.logs.ColorWrap')
792+
def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
793+
ims = sagemaker_session_stopped
794+
ims.logs_for_transform_job(JOB_NAME)
795+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
796+
cw().assert_called_with(0, 'hi there #1')
797+
798+
799+
@patch('sagemaker.logs.ColorWrap')
800+
def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete):
801+
ims = sagemaker_session_complete
802+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
803+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
804+
cw().assert_called_with(0, 'hi there #1')
805+
806+
807+
@patch('sagemaker.logs.ColorWrap')
808+
def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped):
809+
ims = sagemaker_session_stopped
810+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
811+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
812+
cw().assert_called_with(0, 'hi there #1')
813+
814+
815+
@patch('sagemaker.logs.ColorWrap')
816+
def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
817+
ims = sagemaker_session_ready_lifecycle
818+
ims.logs_for_transform_job(JOB_NAME)
819+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
820+
cw().assert_called_with(0, 'hi there #1')
821+
822+
823+
@patch('sagemaker.logs.ColorWrap')
824+
@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180])
825+
def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle):
826+
ims = sagemaker_session_full_lifecycle
827+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
828+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] * 3
829+
assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'),
830+
call(0, 'hi there #2a'), call(0, 'hi there #3')]
831+
832+
743833
MODEL_NAME = 'some-model'
744834
PRIMARY_CONTAINER = {
745835
'Environment': {},

0 commit comments

Comments
 (0)