
feature: Estimator.fit like logs for transformer #782
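This change adds Estimator.fit-style log streaming for batch transform jobs: Transformer.transform gains wait and logs flags, Transformer.wait accepts a logs flag, and Session.logs_for_transform_job tails the job's CloudWatch log streams, color-coding output per instance. A rough usage sketch of the new flow (the model name and S3 paths are hypothetical):

from sagemaker.transformer import Transformer

# Hypothetical model name and S3 locations; substitute your own resources.
transformer = Transformer(
    model_name="my-mnist-model",
    instance_count=1,
    instance_type="ml.m4.xlarge",
    output_path="s3://my-bucket/transform-output",
)

# New in this PR: block until the job finishes and stream its logs,
# much like Estimator.fit(logs=True).
transformer.transform(
    "s3://my-bucket/transform-input/data.csv",
    content_type="text/csv",
    wait=True,
    logs=True,
)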


Merged
merged 9 commits into from Sep 6, 2019
236 changes: 174 additions & 62 deletions src/sagemaker/session.py
@@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method

description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
print(secondary_training_status_message(description, None), end="")
instance_count = description["ResourceConfig"]["InstanceCount"]
status = description["TrainingJobStatus"]

stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position

# Increase retries allowed (from default of 4), as we don't want waiting for a training job
# to be interrupted by a transient exception.
config = botocore.config.Config(retries={"max_attempts": 15})
client = self.boto_session.client("logs", config=config)
log_group = "/aws/sagemaker/TrainingJobs"

job_already_completed = status in ("Completed", "Failed", "Stopped")

state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
dot = False
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
self, description, job="Training"
)

color_wrap = sagemaker.logs.ColorWrap()
state = _get_initial_job_state(description, "TrainingJobStatus", wait)

# The loop below implements a state machine that alternates between checking the job status
# and reading whatever is available in the logs at this point. Note that if we were
@@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
last_describe_job_call = time.time()
last_description = description
while True:
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so
# this list may be dynamic until we have a stream for every instance.
try:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=instance_count,
)
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
positions.update(
[
(s, sagemaker.logs.Position(timestamp=0, skip=0))
for s in stream_names
if s not in positions
]
)
except ClientError as e:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
err = e.response.get("Error", {})
if err.get("Code", None) != "ResourceNotFoundException":
raise

if len(stream_names) > 0:
if dot:
print("")
dot = False
for idx, event in sagemaker.logs.multi_stream_iter(
client, log_group, stream_names, positions
):
color_wrap(idx, event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=ts, skip=count + 1
)
else:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=event["timestamp"], skip=1
)
else:
dot = True
print(".", end="")
sys.stdout.flush()
_flush_log_streams(
stream_names,
instance_count,
client,
log_group,
job_name,
positions,
dot,
color_wrap,
)
if state == LogState.COMPLETE:
break

@@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
saving = (1 - float(billable_time) / training_time) * 100
print("Managed Spot Training savings: {:.1f}%".format(saving))

def logs_for_transform_job(self, job_name, wait=False, poll=10):
"""Display the logs for a given transform job, optionally tailing them until the
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
based on which instance the log entry is from.

Args:
job_name (str): Name of the transform job to display the logs for.
wait (bool): Whether to keep looking for new log entries until the job completes
(default: False).
poll (int): The interval in seconds between polling for new log entries and job
completion (default: 10).

Raises:
ValueError: If the transform job fails.
"""

description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
self, description, job="Transform"
)

state = _get_initial_job_state(description, "TransformJobStatus", wait)

# The loop below implements a state machine that alternates between checking the job status
# and reading whatever is available in the logs at this point. Note that if we were
# called with wait == False, we never check the job status.
#
# If wait == TRUE and job is not completed, the initial state is TAILING
# If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
# complete).
#
# The state table:
#
# STATE ACTIONS CONDITION NEW STATE
# ---------------- ---------------- ----------------- ----------------
# TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
# Else TAILING
# JOB_COMPLETE Read logs, Pause Any COMPLETE
# COMPLETE Read logs, Exit N/A
#
# Notes:
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
# Cloudwatch after the job was marked complete.
last_describe_job_call = time.time()
while True:
_flush_log_streams(
stream_names,
instance_count,
client,
log_group,
job_name,
positions,
dot,
color_wrap,
)
if state == LogState.COMPLETE:
break

time.sleep(poll)

if state == LogState.JOB_COMPLETE:
state = LogState.COMPLETE
elif time.time() - last_describe_job_call >= 30:
description = self.sagemaker_client.describe_transform_job(
TransformJobName=job_name
)
last_describe_job_call = time.time()

status = description["TransformJobStatus"]

if status in ("Completed", "Failed", "Stopped"):
print()
state = LogState.JOB_COMPLETE

if wait:
self._check_job_status(job_name, description, "TransformJobStatus")
if dot:
print()


def container_def(image, model_data_url=None, env=None):
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job(
if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
return vpc_utils.sanitize(vpc_config_override)


def _get_initial_job_state(description, status_key, wait):
"""Placeholder docstring"""
status = description[status_key]
job_already_completed = status in ("Completed", "Failed", "Stopped")
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


def _logs_init(sagemaker_session, description, job):
"""Placeholder docstring"""
if job == "Training":
instance_count = description["ResourceConfig"]["InstanceCount"]
elif job == "Transform":
instance_count = description["TransformResources"]["InstanceCount"]

stream_names = [] # The list of log streams
positions = {} # The current position in each stream, map of stream name -> position

# Increase retries allowed (from default of 4), as we don't want waiting for a training or
# transform job to be interrupted by a transient exception.
config = botocore.config.Config(retries={"max_attempts": 15})
client = sagemaker_session.boto_session.client("logs", config=config)
log_group = "/aws/sagemaker/" + job + "Jobs"

dot = False

color_wrap = sagemaker.logs.ColorWrap()

return instance_count, stream_names, positions, client, log_group, dot, color_wrap


def _flush_log_streams(
stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
):
"""Placeholder docstring"""
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
# may be dynamic until we have a stream for every instance.
try:
streams = client.describe_log_streams(
logGroupName=log_group,
logStreamNamePrefix=job_name + "/",
orderBy="LogStreamName",
limit=instance_count,
)
stream_names = [s["logStreamName"] for s in streams["logStreams"]]
positions.update(
[
(s, sagemaker.logs.Position(timestamp=0, skip=0))
for s in stream_names
if s not in positions
]
)
except ClientError as e:
# On the very first job of a given type run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
err = e.response.get("Error", {})
if err.get("Code", None) != "ResourceNotFoundException":
raise

if len(stream_names) > 0:
if dot:
print("")
dot = False
for idx, event in sagemaker.logs.multi_stream_iter(
client, log_group, stream_names, positions
):
color_wrap(idx, event["message"])
ts, count = positions[stream_names[idx]]
if event["timestamp"] == ts:
positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
else:
positions[stream_names[idx]] = sagemaker.logs.Position(
timestamp=event["timestamp"], skip=1
)
else:
dot = True
print(".", end="")
sys.stdout.flush()
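The position bookkeeping in _flush_log_streams is the subtle part: sagemaker.logs.Position acts as a (timestamp, skip) pair, and after printing an event the code bumps skip when the timestamp repeats, so a later read of the same stream does not re-print events it has already shown. A minimal standalone sketch of that rule, using a local namedtuple instead of the real sagemaker.logs.Position:

from collections import namedtuple

# Stand-in for sagemaker.logs.Position, which is used the same way in session.py.
Position = namedtuple("Position", ["timestamp", "skip"])

def advance(position, event_timestamp):
    # After printing one event, compute where the next read should resume.
    if event_timestamp == position.timestamp:
        # Another event at the same millisecond: skip one more of them next time.
        return Position(timestamp=position.timestamp, skip=position.skip + 1)
    # A newer timestamp: resume there, skipping only the event just printed.
    return Position(timestamp=event_timestamp, skip=1)

pos = Position(timestamp=0, skip=0)
for ts in (100, 100, 250):  # three events, two sharing a timestamp
    pos = advance(pos, ts)
print(pos)  # Position(timestamp=250, skip=1)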
20 changes: 16 additions & 4 deletions src/sagemaker/transformer.py
@@ -119,6 +119,8 @@ def transform(
input_filter=None,
output_filter=None,
join_source=None,
wait=False,
logs=False,
):
"""Start a new transform job.

@@ -154,6 +156,10 @@ def transform(
will be joined to the inference result. You can use OutputFilter
to select the useful portion before uploading to S3. (default:
None). Valid values: Input, None.
wait (bool): Whether the call should wait until the job completes
(default: False).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when wait is True (default: False).
"""
local_mode = self.sagemaker_session.local_mode
if not local_mode and not data.startswith("s3://"):
@@ -187,6 +193,9 @@ def transform(
join_source,
)

if wait:
self.latest_transform_job.wait(logs=logs)

def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer."""
self.sagemaker_session.delete_model(self.model_name)
@@ -224,10 +233,10 @@ def _retrieve_image_name(self):
"Local instance types require locally created models." % self.model_name
)

def wait(self):
def wait(self, logs=True):
"""Placeholder docstring"""
self._ensure_last_transform_job()
self.latest_transform_job.wait()
self.latest_transform_job.wait(logs=logs)

def stop_transform_job(self, wait=True):
"""Stop latest running batch transform job.
@@ -351,8 +360,11 @@ def start_new(

return cls(transformer.sagemaker_session, transformer._current_job_name)

def wait(self):
self.sagemaker_session.wait_for_transform_job(self.job_name)
def wait(self, logs=True):
if logs:
self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True)
else:
self.sagemaker_session.wait_for_transform_job(self.job_name)

def stop(self):
"""Placeholder docstring"""
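With the new Session.logs_for_transform_job, logs for a transform job that was started elsewhere (or with wait=False) can also be tailed directly from the session; the job name below is hypothetical:

import sagemaker

sess = sagemaker.Session()

# Print whatever the job has logged so far and return immediately.
sess.logs_for_transform_job("my-transform-job", wait=False)

# Or keep tailing until the job reaches Completed/Failed/Stopped,
# polling for new events and job status every 10 seconds.
sess.logs_for_transform_job("my-transform-job", wait=True, poll=10)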
45 changes: 45 additions & 0 deletions tests/integ/test_transformer.py
@@ -398,6 +398,47 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_type):
assert desc["TransformJobStatus"] == "Stopped"


def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type):
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
script_path = os.path.join(data_path, "mnist.py")

mx = MXNet(
entry_point=script_path,
role="SageMakerRole",
train_instance_count=1,
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version,
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
)
job_name = unique_name_from_base("test-mxnet-transform")

with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)

transform_input_path = os.path.join(data_path, "transform", "data.csv")
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
transform_input = mx.sagemaker_session.upload_data(
path=transform_input_path, key_prefix=transform_input_key_prefix
)

with timeout(minutes=45):
transformer = _create_transformer_and_transform_job(
mx, transform_input, cpu_instance_type, wait=True, logs=True
)

with timeout_and_delete_model_with_transformer(
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
):
transformer.wait()


def _create_transformer_and_transform_job(
estimator,
transform_input,
@@ -406,6 +447,8 @@ def _create_transformer_and_transform_job(
input_filter=None,
output_filter=None,
join_source=None,
wait=False,
logs=False,
):
transformer = estimator.transformer(1, instance_type, volume_kms_key=volume_kms_key)
transformer.transform(
@@ -414,5 +457,7 @@
input_filter=input_filter,
output_filter=output_filter,
join_source=join_source,
wait=wait,
logs=logs,
)
return transformer