Skip to content

Commit 813155e

Browse files
committed
feature: Estimator.fit like logs for transformer
1 parent fe29f60 commit 813155e

File tree

3 files changed

+230
-51
lines changed

3 files changed

+230
-51
lines changed

src/sagemaker/session.py

+117-47
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12781278
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
12791279
12801280
Raises:
1281-
ValueError: If waiting and the training job fails.
1281+
ValueError: If the training job fails.
12821282
"""
12831283

12841284
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
@@ -1326,52 +1326,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
13261326
last_describe_job_call = time.time()
13271327
last_description = description
13281328
while True:
1329-
if len(stream_names) < instance_count:
1330-
# Log streams are created whenever a container starts writing to stdout/err, so this list
1331-
# may be dynamic until we have a stream for every instance.
1332-
try:
1333-
streams = client.describe_log_streams(
1334-
logGroupName=log_group,
1335-
logStreamNamePrefix=job_name + "/",
1336-
orderBy="LogStreamName",
1337-
limit=instance_count,
1338-
)
1339-
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
1340-
positions.update(
1341-
[
1342-
(s, sagemaker.logs.Position(timestamp=0, skip=0))
1343-
for s in stream_names
1344-
if s not in positions
1345-
]
1346-
)
1347-
except ClientError as e:
1348-
# On the very first training job run on an account, there's no log group until
1349-
# the container starts logging, so ignore any errors thrown about that
1350-
err = e.response.get("Error", {})
1351-
if err.get("Code", None) != "ResourceNotFoundException":
1352-
raise
1353-
1354-
if len(stream_names) > 0:
1355-
if dot:
1356-
print("")
1357-
dot = False
1358-
for idx, event in sagemaker.logs.multi_stream_iter(
1359-
client, log_group, stream_names, positions
1360-
):
1361-
color_wrap(idx, event["message"])
1362-
ts, count = positions[stream_names[idx]]
1363-
if event["timestamp"] == ts:
1364-
positions[stream_names[idx]] = sagemaker.logs.Position(
1365-
timestamp=ts, skip=count + 1
1366-
)
1367-
else:
1368-
positions[stream_names[idx]] = sagemaker.logs.Position(
1369-
timestamp=event["timestamp"], skip=1
1370-
)
1371-
else:
1372-
dot = True
1373-
print(".", end="")
1374-
sys.stdout.flush()
1329+
_flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap)
13751330
if state == LogState.COMPLETE:
13761331
break
13771332

@@ -1404,6 +1359,87 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
14041359
) * instance_count
14051360
print("Billable seconds:", int(billable_time.total_seconds()) + 1)
14061361

1362+
def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning
1363+
"""Display the logs for a given transform job, optionally tailing them until the
1364+
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
1365+
based on which instance the log entry is from.
1366+
1367+
Args:
1368+
job_name (str): Name of the transform job to display the logs for.
1369+
wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
1370+
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1371+
1372+
Raises:
1373+
ValueError: If the transform job fails.
1374+
"""
1375+
1376+
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1377+
instance_count = description['TransformResources']['InstanceCount']
1378+
status = description['TransformJobStatus']
1379+
1380+
stream_names = [] # The list of log streams
1381+
positions = {} # The current position in each stream, map of stream name -> position
1382+
1383+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1384+
# to be interrupted by a transient exception.
1385+
config = botocore.config.Config(retries={'max_attempts': 15})
1386+
client = self.boto_session.client('logs', config=config)
1387+
log_group = '/aws/sagemaker/TransformJobs'
1388+
1389+
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1390+
1391+
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1392+
dot = False
1393+
1394+
color_wrap = sagemaker.logs.ColorWrap()
1395+
1396+
# The loop below implements a state machine that alternates between checking the job status and
1397+
# reading whatever is available in the logs at this point. Note, that if we were called with
1398+
# wait == False, we never check the job status.
1399+
#
1400+
# If wait == TRUE and job is not completed, the initial state is TAILING
1401+
# If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
1402+
#
1403+
# The state table:
1404+
#
1405+
# STATE ACTIONS CONDITION NEW STATE
1406+
# ---------------- ---------------- ----------------- ----------------
1407+
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
1408+
# Else TAILING
1409+
# JOB_COMPLETE Read logs, Pause Any COMPLETE
1410+
# COMPLETE Read logs, Exit N/A
1411+
#
1412+
# Notes:
1413+
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
1414+
# the job was marked complete.
1415+
last_describe_job_call = time.time()
1416+
while True:
1417+
_flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap)
1418+
if state == LogState.COMPLETE:
1419+
break
1420+
1421+
time.sleep(poll)
1422+
1423+
if state == LogState.JOB_COMPLETE:
1424+
state = LogState.COMPLETE
1425+
elif time.time() - last_describe_job_call >= 30:
1426+
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1427+
last_describe_job_call = time.time()
1428+
1429+
status = description['TransformJobStatus']
1430+
1431+
if status == 'Completed' or status == 'Failed' or status == 'Stopped':
1432+
print()
1433+
state = LogState.JOB_COMPLETE
1434+
1435+
if wait:
1436+
self._check_job_status(job_name, description, 'TransformJobStatus')
1437+
if dot:
1438+
print()
1439+
# Customers are not billed for hardware provisioning, so billable time is less than total time
1440+
billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count
1441+
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
1442+
14071443

14081444
def container_def(image, model_data_url=None, env=None):
14091445
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1795,3 +1831,37 @@ def _vpc_config_from_training_job(
17951831
return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
17961832
else:
17971833
return vpc_utils.sanitize(vpc_config_override)
1834+
1835+
1836+
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
1837+
if len(stream_names) < instance_count:
1838+
# Log streams are created whenever a container starts writing to stdout/err, so this list
1839+
# may be dynamic until we have a stream for every instance.
1840+
try:
1841+
streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
1842+
orderBy='LogStreamName', limit=instance_count)
1843+
stream_names = [s['logStreamName'] for s in streams['logStreams']]
1844+
positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0))
1845+
for s in stream_names if s not in positions])
1846+
except ClientError as e:
1847+
# On the very first training job run on an account, there's no log group until
1848+
# the container starts logging, so ignore any errors thrown about that
1849+
err = e.response.get('Error', {})
1850+
if err.get('Code', None) != 'ResourceNotFoundException':
1851+
raise
1852+
1853+
if len(stream_names) > 0:
1854+
if dot:
1855+
print('')
1856+
dot = False
1857+
for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
1858+
color_wrap(idx, event['message'])
1859+
ts, count = positions[stream_names[idx]]
1860+
if event['timestamp'] == ts:
1861+
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
1862+
else:
1863+
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
1864+
else:
1865+
dot = True
1866+
print('.', end='')
1867+
sys.stdout.flush()

src/sagemaker/transformer.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def transform(
104104
input_filter=None,
105105
output_filter=None,
106106
join_source=None,
107+
wait=False,
108+
logs=False
107109
):
108110
"""Start a new transform job.
109111
@@ -131,6 +133,9 @@ def transform(
131133
meaning the entire input record will be joined to the inference result.
132134
You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
133135
Valid values: Input, None.
136+
wait (bool): Whether the call should wait until the job completes (default: True).
137+
logs (bool): Whether to show the logs produced by the job.
138+
Only meaningful when wait is True (default: True).
134139
"""
135140
local_mode = self.sagemaker_session.local_mode
136141
if not local_mode and not data.startswith("s3://"):
@@ -163,6 +168,9 @@ def transform(
163168
join_source,
164169
)
165170

171+
if wait:
172+
self.latest_transform_job.wait(logs=logs)
173+
166174
def delete_model(self):
167175
"""Delete the corresponding SageMaker model for this Transformer.
168176
@@ -200,9 +208,9 @@ def _retrieve_image_name(self):
200208
"Local instance types require locally created models." % self.model_name
201209
)
202210

203-
def wait(self):
211+
def wait(self, logs=True):
204212
self._ensure_last_transform_job()
205-
self.latest_transform_job.wait()
213+
self.latest_transform_job.wait(logs=logs)
206214

207215
def _ensure_last_transform_job(self):
208216
if self.latest_transform_job is None:
@@ -300,8 +308,11 @@ def start_new(
300308

301309
return cls(transformer.sagemaker_session, transformer._current_job_name)
302310

303-
def wait(self):
304-
self.sagemaker_session.wait_for_transform_job(self.job_name)
311+
def wait(self, logs=True):
312+
if logs:
313+
self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True)
314+
else:
315+
self.sagemaker_session.wait_for_transform_job(self.job_name)
305316

306317
@staticmethod
307318
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):

tests/unit/test_session.py

+98
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,38 @@ def test_s3_input_all_arguments():
340340
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
341341
IN_PROGRESS_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "InProgress"})
342342

343+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = {
344+
'TransformJobStatus': 'Completed',
345+
'ModelName': 'some-model',
346+
'TransformJobName': JOB_NAME,
347+
'TransformResources': {
348+
'InstanceCount': INSTANCE_COUNT,
349+
'InstanceType': INSTANCE_TYPE
350+
},
351+
'TransformEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000),
352+
'TransformStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000),
353+
'TransformOutput': {
354+
'AssembleWith': 'None',
355+
'KmsKeyId': '',
356+
'S3OutputPath': S3_OUTPUT
357+
},
358+
'TransformInput': {
359+
'CompressionType': 'None',
360+
'ContentType': 'text/csv',
361+
'DataSource': {
362+
'S3DataType': 'S3Prefix',
363+
'S3Uri': S3_INPUT_URI
364+
},
365+
'SplitType': 'Line'
366+
}
367+
}
368+
369+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
370+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'Stopped'})
371+
372+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
373+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'InProgress'})
374+
343375

344376
@pytest.fixture()
345377
def sagemaker_session():
@@ -787,6 +819,7 @@ def sagemaker_session_complete():
787819
boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS
788820
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
789821
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
822+
ims.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT
790823
return ims
791824

792825

@@ -797,6 +830,7 @@ def sagemaker_session_stopped():
797830
boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS
798831
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
799832
ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
833+
ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT
800834
return ims
801835

802836

@@ -811,6 +845,11 @@ def sagemaker_session_ready_lifecycle():
811845
IN_PROGRESS_DESCRIBE_JOB_RESULT,
812846
COMPLETED_DESCRIBE_JOB_RESULT,
813847
]
848+
ims.sagemaker_client.describe_transform_job.side_effect = [
849+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
850+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
851+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT,
852+
]
814853
return ims
815854

816855

@@ -825,6 +864,11 @@ def sagemaker_session_full_lifecycle():
825864
IN_PROGRESS_DESCRIBE_JOB_RESULT,
826865
COMPLETED_DESCRIBE_JOB_RESULT,
827866
]
867+
ims.sagemaker_client.describe_transform_job.side_effect = [
868+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
869+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
870+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT,
871+
]
828872
return ims
829873

830874

@@ -892,6 +936,60 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
892936

893937

894938
MODEL_NAME = "some-model"
939+
940+
941+
@patch('sagemaker.logs.ColorWrap')
942+
def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete):
943+
ims = sagemaker_session_complete
944+
ims.logs_for_transform_job(JOB_NAME)
945+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
946+
cw().assert_called_with(0, 'hi there #1')
947+
948+
949+
@patch('sagemaker.logs.ColorWrap')
950+
def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
951+
ims = sagemaker_session_stopped
952+
ims.logs_for_transform_job(JOB_NAME)
953+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
954+
cw().assert_called_with(0, 'hi there #1')
955+
956+
957+
@patch('sagemaker.logs.ColorWrap')
958+
def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete):
959+
ims = sagemaker_session_complete
960+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
961+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
962+
cw().assert_called_with(0, 'hi there #1')
963+
964+
965+
@patch('sagemaker.logs.ColorWrap')
966+
def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped):
967+
ims = sagemaker_session_stopped
968+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
969+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
970+
cw().assert_called_with(0, 'hi there #1')
971+
972+
973+
@patch('sagemaker.logs.ColorWrap')
974+
def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
975+
ims = sagemaker_session_ready_lifecycle
976+
ims.logs_for_transform_job(JOB_NAME)
977+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
978+
cw().assert_called_with(0, 'hi there #1')
979+
980+
981+
@patch('sagemaker.logs.ColorWrap')
982+
@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180])
983+
def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle):
984+
ims = sagemaker_session_full_lifecycle
985+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
986+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] * 3
987+
assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'),
988+
call(0, 'hi there #2a'), call(0, 'hi there #3')]
989+
990+
991+
MODEL_NAME = 'some-model'
992+
>>>>>>> feature: Estimator.fit like logs for transformer
895993
PRIMARY_CONTAINER = {
896994
"Environment": {},
897995
"Image": IMAGE,

0 commit comments

Comments
 (0)