Skip to content

Commit c364fd1

Browse files
imujjwal96 and laurenyu
authored and committed
feature: Estimator.fit like logs for transformer (#782)
1 parent f54f506 commit c364fd1

File tree

4 files changed

+334
-66
lines changed

4 files changed

+334
-66
lines changed

src/sagemaker/session.py

+174-62
Original file line numberDiff line numberDiff line change
@@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
14281428

14291429
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
14301430
print(secondary_training_status_message(description, None), end="")
1431-
instance_count = description["ResourceConfig"]["InstanceCount"]
1432-
status = description["TrainingJobStatus"]
1433-
1434-
stream_names = [] # The list of log streams
1435-
positions = {} # The current position in each stream, map of stream name -> position
1436-
1437-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1438-
# to be interrupted by a transient exception.
1439-
config = botocore.config.Config(retries={"max_attempts": 15})
1440-
client = self.boto_session.client("logs", config=config)
1441-
log_group = "/aws/sagemaker/TrainingJobs"
14421431

1443-
job_already_completed = status in ("Completed", "Failed", "Stopped")
1444-
1445-
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1446-
dot = False
1432+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
1433+
self, description, job="Training"
1434+
)
14471435

1448-
color_wrap = sagemaker.logs.ColorWrap()
1436+
state = _get_initial_job_state(description, "TrainingJobStatus", wait)
14491437

14501438
# The loop below implements a state machine that alternates between checking the job status
14511439
# and reading whatever is available in the logs at this point. Note, that if we were
@@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
14701458
last_describe_job_call = time.time()
14711459
last_description = description
14721460
while True:
1473-
if len(stream_names) < instance_count:
1474-
# Log streams are created whenever a container starts writing to stdout/err, so
1475-
# this list # may be dynamic until we have a stream for every instance.
1476-
try:
1477-
streams = client.describe_log_streams(
1478-
logGroupName=log_group,
1479-
logStreamNamePrefix=job_name + "/",
1480-
orderBy="LogStreamName",
1481-
limit=instance_count,
1482-
)
1483-
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
1484-
positions.update(
1485-
[
1486-
(s, sagemaker.logs.Position(timestamp=0, skip=0))
1487-
for s in stream_names
1488-
if s not in positions
1489-
]
1490-
)
1491-
except ClientError as e:
1492-
# On the very first training job run on an account, there's no log group until
1493-
# the container starts logging, so ignore any errors thrown about that
1494-
err = e.response.get("Error", {})
1495-
if err.get("Code", None) != "ResourceNotFoundException":
1496-
raise
1497-
1498-
if len(stream_names) > 0:
1499-
if dot:
1500-
print("")
1501-
dot = False
1502-
for idx, event in sagemaker.logs.multi_stream_iter(
1503-
client, log_group, stream_names, positions
1504-
):
1505-
color_wrap(idx, event["message"])
1506-
ts, count = positions[stream_names[idx]]
1507-
if event["timestamp"] == ts:
1508-
positions[stream_names[idx]] = sagemaker.logs.Position(
1509-
timestamp=ts, skip=count + 1
1510-
)
1511-
else:
1512-
positions[stream_names[idx]] = sagemaker.logs.Position(
1513-
timestamp=event["timestamp"], skip=1
1514-
)
1515-
else:
1516-
dot = True
1517-
print(".", end="")
1518-
sys.stdout.flush()
1461+
_flush_log_streams(
1462+
stream_names,
1463+
instance_count,
1464+
client,
1465+
log_group,
1466+
job_name,
1467+
positions,
1468+
dot,
1469+
color_wrap,
1470+
)
15191471
if state == LogState.COMPLETE:
15201472
break
15211473

@@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
15541506
saving = (1 - float(billable_time) / training_time) * 100
15551507
print("Managed Spot Training savings: {:.1f}%".format(saving))
15561508

1509+
def logs_for_transform_job(self, job_name, wait=False, poll=10):
    """Display the logs for a given transform job, optionally tailing them until the
    job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        job_name (str): Name of the transform job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes
            (default: False).
        poll (int): The interval in seconds between polling for new log entries and job
            completion (default: 10).

    Raises:
        ValueError: If the transform job fails.
    """

    description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
        self, description, job="Transform"
    )

    state = _get_initial_job_state(description, "TransformJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note, that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
    # complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION           NEW STATE
    # ----------------    ----------------               -----------------   ----------------
    # TAILING             Read logs, Pause, Get status   Job complete        JOB_COMPLETE
    #                                                    Else                TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                 COMPLETE
    # COMPLETE            Read logs, Exit                N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    while True:
        _flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Throttle DescribeTransformJob calls to at most one every 30 seconds.
            description = self.sagemaker_client.describe_transform_job(
                TransformJobName=job_name
            )
            last_describe_job_call = time.time()

            status = description["TransformJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

    if wait:
        # _check_job_status raises ValueError if the job did not complete successfully.
        self._check_job_status(job_name, description, "TransformJobStatus")
        if dot:
            print()
1588+
15571589

15581590
def container_def(image, model_data_url=None, env=None):
15591591
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job(
18921924
if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
18931925
return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
18941926
return vpc_utils.sanitize(vpc_config_override)
1927+
1928+
1929+
def _get_initial_job_state(description, status_key, wait):
    """Return the starting LogState for the log-tailing state machine.

    Args:
        description (dict): Describe-job response for the job being tailed.
        status_key (str): Key in ``description`` that holds the job status,
            e.g. "TrainingJobStatus" or "TransformJobStatus".
        wait (bool): Whether the caller intends to tail logs until the job completes.

    Returns:
        LogState: ``LogState.TAILING`` if we should keep polling for new log entries,
        ``LogState.COMPLETE`` otherwise (job already finished, or caller isn't waiting).
    """
    status = description[status_key]
    job_already_completed = status in ("Completed", "Failed", "Stopped")
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1934+
1935+
1936+
def _logs_init(sagemaker_session, description, job):
    """Set up the shared state used to tail CloudWatch logs for a SageMaker job.

    Args:
        sagemaker_session (Session): Session whose ``boto_session`` is used to create
            the CloudWatch Logs client.
        description (dict): Describe-job response for the job being tailed.
        job (str): Either "Training" or "Transform"; selects where the instance count
            lives in ``description`` and which log group to read from.

    Returns:
        tuple: (instance_count, stream_names, positions, client, log_group, dot,
        color_wrap).

    Raises:
        ValueError: If ``job`` is not "Training" or "Transform".
    """
    if job == "Training":
        instance_count = description["ResourceConfig"]["InstanceCount"]
    elif job == "Transform":
        instance_count = description["TransformResources"]["InstanceCount"]
    else:
        # Fail loudly instead of letting an unbound instance_count raise a
        # confusing NameError below.
        raise ValueError("Unknown job type: %s. Valid values: 'Training', 'Transform'." % job)

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={"max_attempts": 15})
    client = sagemaker_session.boto_session.client("logs", config=config)
    log_group = "/aws/sagemaker/" + job + "Jobs"

    dot = False

    color_wrap = sagemaker.logs.ColorWrap()

    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
1957+
1958+
1959+
def _flush_log_streams(
1960+
stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
1961+
):
1962+
"""Placeholder docstring"""
1963+
if len(stream_names) < instance_count:
1964+
# Log streams are created whenever a container starts writing to stdout/err, so this list
1965+
# may be dynamic until we have a stream for every instance.
1966+
try:
1967+
streams = client.describe_log_streams(
1968+
logGroupName=log_group,
1969+
logStreamNamePrefix=job_name + "/",
1970+
orderBy="LogStreamName",
1971+
limit=instance_count,
1972+
)
1973+
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
1974+
positions.update(
1975+
[
1976+
(s, sagemaker.logs.Position(timestamp=0, skip=0))
1977+
for s in stream_names
1978+
if s not in positions
1979+
]
1980+
)
1981+
except ClientError as e:
1982+
# On the very first training job run on an account, there's no log group until
1983+
# the container starts logging, so ignore any errors thrown about that
1984+
err = e.response.get("Error", {})
1985+
if err.get("Code", None) != "ResourceNotFoundException":
1986+
raise
1987+
1988+
if len(stream_names) > 0:
1989+
if dot:
1990+
print("")
1991+
dot = False
1992+
for idx, event in sagemaker.logs.multi_stream_iter(
1993+
client, log_group, stream_names, positions
1994+
):
1995+
color_wrap(idx, event["message"])
1996+
ts, count = positions[stream_names[idx]]
1997+
if event["timestamp"] == ts:
1998+
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
1999+
else:
2000+
positions[stream_names[idx]] = sagemaker.logs.Position(
2001+
timestamp=event["timestamp"], skip=1
2002+
)
2003+
else:
2004+
dot = True
2005+
print(".", end="")
2006+
sys.stdout.flush()

src/sagemaker/transformer.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def transform(
119119
input_filter=None,
120120
output_filter=None,
121121
join_source=None,
122+
wait=False,
123+
logs=False,
122124
):
123125
"""Start a new transform job.
124126
@@ -154,6 +156,10 @@ def transform(
154156
will be joined to the inference result. You can use OutputFilter
155157
to select the useful portion before uploading to S3. (default:
156158
None). Valid values: Input, None.
159+
wait (bool): Whether the call should wait until the job completes
160+
(default: True).
161+
logs (bool): Whether to show the logs produced by the job.
162+
Only meaningful when wait is True (default: False).
157163
"""
158164
local_mode = self.sagemaker_session.local_mode
159165
if not local_mode and not data.startswith("s3://"):
@@ -187,6 +193,9 @@ def transform(
187193
join_source,
188194
)
189195

196+
if wait:
197+
self.latest_transform_job.wait(logs=logs)
198+
190199
def delete_model(self):
191200
"""Delete the corresponding SageMaker model for this Transformer."""
192201
self.sagemaker_session.delete_model(self.model_name)
@@ -224,10 +233,10 @@ def _retrieve_image_name(self):
224233
"Local instance types require locally created models." % self.model_name
225234
)
226235

227-
def wait(self):
236+
def wait(self, logs=True):
    """Wait for the latest transform job started by this Transformer to finish.

    Args:
        logs (bool): Whether to print the job's CloudWatch logs while waiting
            (default: True).
    """
    self._ensure_last_transform_job()
    self.latest_transform_job.wait(logs=logs)
231240

232241
def stop_transform_job(self, wait=True):
233242
"""Stop latest running batch transform job.
@@ -351,8 +360,11 @@ def start_new(
351360

352361
return cls(transformer.sagemaker_session, transformer._current_job_name)
353362

354-
def wait(self):
355-
self.sagemaker_session.wait_for_transform_job(self.job_name)
363+
def wait(self, logs=True):
    """Block until the transform job finishes.

    Args:
        logs (bool): If True, tail the job's CloudWatch logs while waiting;
            otherwise just poll the job until it reaches a terminal state
            (default: True).
    """
    if logs:
        self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True)
    else:
        self.sagemaker_session.wait_for_transform_job(self.job_name)
356368

357369
def stop(self):
358370
"""Placeholder docstring"""

tests/integ/test_transformer.py

+45
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,47 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_
398398
assert desc["TransformJobStatus"] == "Stopped"
399399

400400

401+
def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type):
402+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
403+
script_path = os.path.join(data_path, "mnist.py")
404+
405+
mx = MXNet(
406+
entry_point=script_path,
407+
role="SageMakerRole",
408+
train_instance_count=1,
409+
train_instance_type=cpu_instance_type,
410+
sagemaker_session=sagemaker_session,
411+
framework_version=mxnet_full_version,
412+
)
413+
414+
train_input = mx.sagemaker_session.upload_data(
415+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
416+
)
417+
test_input = mx.sagemaker_session.upload_data(
418+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
419+
)
420+
job_name = unique_name_from_base("test-mxnet-transform")
421+
422+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
423+
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
424+
425+
transform_input_path = os.path.join(data_path, "transform", "data.csv")
426+
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
427+
transform_input = mx.sagemaker_session.upload_data(
428+
path=transform_input_path, key_prefix=transform_input_key_prefix
429+
)
430+
431+
with timeout(minutes=45):
432+
transformer = _create_transformer_and_transform_job(
433+
mx, transform_input, cpu_instance_type, wait=True, logs=True
434+
)
435+
436+
with timeout_and_delete_model_with_transformer(
437+
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
438+
):
439+
transformer.wait()
440+
441+
401442
def _create_transformer_and_transform_job(
402443
estimator,
403444
transform_input,
@@ -406,6 +447,8 @@ def _create_transformer_and_transform_job(
406447
input_filter=None,
407448
output_filter=None,
408449
join_source=None,
450+
wait=False,
451+
logs=False,
409452
):
410453
transformer = estimator.transformer(1, instance_type, volume_kms_key=volume_kms_key)
411454
transformer.transform(
@@ -414,5 +457,7 @@ def _create_transformer_and_transform_job(
414457
input_filter=input_filter,
415458
output_filter=output_filter,
416459
join_source=join_source,
460+
wait=wait,
461+
logs=logs,
417462
)
418463
return transformer

0 commit comments

Comments
 (0)