Skip to content

Commit 790fe8f

Browse files
committed
feature: Estimator.fit like logs for transformer
1 parent 7529a22 commit 790fe8f

File tree

2 files changed

+123
-5
lines changed

2 files changed

+123
-5
lines changed

src/sagemaker/session.py

+110
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,116 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
12251225
billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
12261226
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
12271227

1228+
def logs_for_transform_job(self, job_name, wait=False, poll=10):
    """Display the logs for a given transform job, optionally tailing them until the
    job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        job_name (str): Name of the transform job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
        poll (int): The interval in seconds between polling for new log entries and job completion (default: 10).

    Raises:
        ValueError: If waiting and the transform job fails.
    """
    # Statuses after which the transform job will make no further progress.
    terminal_statuses = ('Completed', 'Failed', 'Stopped')

    description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
    instance_count = description['TransformResources']['InstanceCount']
    status = description['TransformJobStatus']

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a transform job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={'max_attempts': 15})
    client = self.boto_session.client('logs', config=config)
    log_group = '/aws/sagemaker/TransformJobs'

    job_already_completed = status in terminal_statuses

    state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
    dot = False  # True while the current output line consists of progress dots

    color_wrap = sagemaker.logs.ColorWrap()

    # The loop below implements a state machine that alternates between checking the job status and
    # reading whatever is available in the logs at this point. Note, that if we were called with
    # wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                                      N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
    #   the job was marked complete.
    last_describe_job_call = time.time()
    while True:
        if len(stream_names) < instance_count:
            # Log streams are created whenever a container starts writing to stdout/err, so this list
            # may be dynamic until we have a stream for every instance.
            try:
                streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
                                                      orderBy='LogStreamName', limit=instance_count)
                stream_names = [s['logStreamName'] for s in streams['logStreams']]
                # Start newly discovered streams at the beginning; keep positions of known streams.
                for name in stream_names:
                    if name not in positions:
                        positions[name] = sagemaker.logs.Position(timestamp=0, skip=0)
            except ClientError as e:
                # On the very first transform job run on an account, there's no log group until
                # the container starts logging, so ignore any errors thrown about that
                err = e.response.get('Error', {})
                if err.get('Code', None) != 'ResourceNotFoundException':
                    raise

        if stream_names:
            if dot:
                # Terminate the line of progress dots before printing log output.
                print('')
                dot = False
            for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
                color_wrap(idx, event['message'])
                ts, count = positions[stream_names[idx]]
                # Track how many events sharing the latest timestamp have already been shown, so the
                # next read of this stream can skip them.
                if event['timestamp'] == ts:
                    positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
                else:
                    positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
        else:
            # No log stream exists yet: show progress dots instead.
            dot = True
            print('.', end='')
            sys.stdout.flush()
        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Refresh the job description at most once every 30 seconds.
            description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
            last_describe_job_call = time.time()

            status = description['TransformJobStatus']

            if status in terminal_statuses:
                print()
                state = LogState.JOB_COMPLETE

    if wait:
        self._check_job_status(job_name, description, 'TransformJobStatus')
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than total time
        billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count
        print('Billable seconds:', int(billable_time.total_seconds()) + 1)
1337+
12281338

12291339
def container_def(image, model_data_url=None, env=None):
12301340
"""Create a definition for executing a container as part of a SageMaker model.

src/sagemaker/transformer.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
7979
self.sagemaker_session = sagemaker_session or Session()
8080

8181
def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
82-
job_name=None):
82+
wait=True, logs=True, job_name=None):
8383
"""Start a new transform job.
8484
8585
Args:
@@ -96,6 +96,9 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
9696
Valid values: 'Gzip', None.
9797
split_type (str): The record delimiter for the input object (default: 'None').
9898
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
99+
wait (bool): Whether the call should wait until the job completes (default: True).
100+
logs (bool): Whether to show the logs produced by the job.
101+
Only meaningful when wait is True (default: True).
99102
job_name (str): job name (default: None). If not specified, one will be generated.
100103
"""
101104
local_mode = self.sagemaker_session.local_mode
@@ -113,6 +116,8 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
113116

114117
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
115118
split_type)
119+
if wait:
120+
self.latest_transform_job.wait(logs=logs)
116121

117122
def delete_model(self):
118123
"""Delete the corresponding SageMaker model for this Transformer.
@@ -130,9 +135,9 @@ def _retrieve_image_name(self):
130135
'Local instance types require locally created models.'
131136
% self.model_name)
132137

133-
def wait(self):
138+
def wait(self, logs=True):
134139
self._ensure_last_transform_job()
135-
self.latest_transform_job.wait()
140+
self.latest_transform_job.wait(logs=logs)
136141

137142
def _ensure_last_transform_job(self):
138143
if self.latest_transform_job is None:
@@ -205,8 +210,11 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type,
205210

206211
return cls(transformer.sagemaker_session, transformer._current_job_name)
207212

208-
def wait(self):
209-
self.sagemaker_session.wait_for_transform_job(self.job_name)
213+
def wait(self, logs=True):
    """Wait on this transform job through the SageMaker session.

    Args:
        logs (bool): When truthy, delegate to the session's
            ``logs_for_transform_job`` with ``wait=True``; otherwise
            delegate to ``wait_for_transform_job`` (default: True).
    """
    session = self.sagemaker_session
    if not logs:
        session.wait_for_transform_job(self.job_name)
        return
    session.logs_for_transform_job(self.job_name, wait=True)
210218

211219
@staticmethod
212220
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):

0 commit comments

Comments
 (0)