diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8ab28c53f3..c83cbdc51a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) print(secondary_training_status_message(description, None), end="") - instance_count = description["ResourceConfig"]["InstanceCount"] - status = description["TrainingJobStatus"] - - stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position - - # Increase retries allowed (from default of 4), as we don't want waiting for a training job - # to be interrupted by a transient exception. - config = botocore.config.Config(retries={"max_attempts": 15}) - client = self.boto_session.client("logs", config=config) - log_group = "/aws/sagemaker/TrainingJobs" - job_already_completed = status in ("Completed", "Failed", "Stopped") - - state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE - dot = False + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( + self, description, job="Training" + ) - color_wrap = sagemaker.logs.ColorWrap() + state = _get_initial_job_state(description, "TrainingJobStatus", wait) # The loop below implements a state machine that alternates between checking the job status # and reading whatever is available in the logs at this point. Note, that if we were @@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method last_describe_job_call = time.time() last_description = description while True: - if len(stream_names) < instance_count: - # Log streams are created whenever a container starts writing to stdout/err, so - # this list # may be dynamic until we have a stream for every instance. 
- try: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=instance_count, - ) - stream_names = [s["logStreamName"] for s in streams["logStreams"]] - positions.update( - [ - (s, sagemaker.logs.Position(timestamp=0, skip=0)) - for s in stream_names - if s not in positions - ] - ) - except ClientError as e: - # On the very first training job run on an account, there's no log group until - # the container starts logging, so ignore any errors thrown about that - err = e.response.get("Error", {}) - if err.get("Code", None) != "ResourceNotFoundException": - raise - - if len(stream_names) > 0: - if dot: - print("") - dot = False - for idx, event in sagemaker.logs.multi_stream_iter( - client, log_group, stream_names, positions - ): - color_wrap(idx, event["message"]) - ts, count = positions[stream_names[idx]] - if event["timestamp"] == ts: - positions[stream_names[idx]] = sagemaker.logs.Position( - timestamp=ts, skip=count + 1 - ) - else: - positions[stream_names[idx]] = sagemaker.logs.Position( - timestamp=event["timestamp"], skip=1 - ) - else: - dot = True - print(".", end="") - sys.stdout.flush() + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) if state == LogState.COMPLETE: break @@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method saving = (1 - float(billable_time) / training_time) * 100 print("Managed Spot Training savings: {:.1f}%".format(saving)) + def logs_for_transform_job(self, job_name, wait=False, poll=10): + """Display the logs for a given transform job, optionally tailing them until the + job is complete. If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + job_name (str): Name of the transform job to display the logs for. 
+ wait (bool): Whether to keep looking for new log entries until the job completes + (default: False). + poll (int): The interval in seconds between polling for new log entries and job + completion (default: 10). + + Raises: + ValueError: If the transform job fails. + """ + + description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) + + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( + self, description, job="Transform" + ) + + state = _get_initial_job_state(description, "TransformJobStatus", wait) + + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. 
+ last_describe_job_call = time.time() + while True: + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) + if state == LogState.COMPLETE: + break + + time.sleep(poll) + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.time() - last_describe_job_call >= 30: + description = self.sagemaker_client.describe_transform_job( + TransformJobName=job_name + ) + last_describe_job_call = time.time() + + status = description["TransformJobStatus"] + + if status in ("Completed", "Failed", "Stopped"): + print() + state = LogState.JOB_COMPLETE + + if wait: + self._check_job_status(job_name, description, "TransformJobStatus") + if dot: + print() + def container_def(image, model_data_url=None, env=None): """Create a definition for executing a container as part of a SageMaker model. @@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job( if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT: return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY) return vpc_utils.sanitize(vpc_config_override) + + +def _get_initial_job_state(description, status_key, wait): + """Placeholder docstring""" + status = description[status_key] + job_already_completed = status in ("Completed", "Failed", "Stopped") + return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE + + +def _logs_init(sagemaker_session, description, job): + """Placeholder docstring""" + if job == "Training": + instance_count = description["ResourceConfig"]["InstanceCount"] + elif job == "Transform": + instance_count = description["TransformResources"]["InstanceCount"] + + stream_names = [] # The list of log streams + positions = {} # The current position in each stream, map of stream name -> position + + # Increase retries allowed (from default of 4), as we don't want waiting for a training job + # to be interrupted by a transient exception. 
+ config = botocore.config.Config(retries={"max_attempts": 15}) + client = sagemaker_session.boto_session.client("logs", config=config) + log_group = "/aws/sagemaker/" + job + "Jobs" + + dot = False + + color_wrap = sagemaker.logs.ColorWrap() + + return instance_count, stream_names, positions, client, log_group, dot, color_wrap + + +def _flush_log_streams( + stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap +): + """Placeholder docstring""" + if len(stream_names) < instance_count: + # Log streams are created whenever a container starts writing to stdout/err, so this list + # may be dynamic until we have a stream for every instance. + try: + streams = client.describe_log_streams( + logGroupName=log_group, + logStreamNamePrefix=job_name + "/", + orderBy="LogStreamName", + limit=instance_count, + ) + stream_names = [s["logStreamName"] for s in streams["logStreams"]] + positions.update( + [ + (s, sagemaker.logs.Position(timestamp=0, skip=0)) + for s in stream_names + if s not in positions + ] + ) + except ClientError as e: + # On the very first training job run on an account, there's no log group until + # the container starts logging, so ignore any errors thrown about that + err = e.response.get("Error", {}) + if err.get("Code", None) != "ResourceNotFoundException": + raise + + if len(stream_names) > 0: + if dot: + print("") + dot = False + for idx, event in sagemaker.logs.multi_stream_iter( + client, log_group, stream_names, positions + ): + color_wrap(idx, event["message"]) + ts, count = positions[stream_names[idx]] + if event["timestamp"] == ts: + positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1) + else: + positions[stream_names[idx]] = sagemaker.logs.Position( + timestamp=event["timestamp"], skip=1 + ) + else: + dot = True + print(".", end="") + sys.stdout.flush() diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index d5efe87cdf..0ce37906ff 100644 --- 
a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -119,6 +119,8 @@ def transform( input_filter=None, output_filter=None, join_source=None, + wait=False, + logs=False, ): """Start a new transform job. @@ -154,6 +156,10 @@ def transform( will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. + wait (bool): Whether the call should wait until the job completes + (default: False). + logs (bool): Whether to show the logs produced by the job. + Only meaningful when wait is True (default: False). """ + local_mode = self.sagemaker_session.local_mode + if not local_mode and not data.startswith("s3://"): @@ -187,6 +193,9 @@ def transform( join_source, ) + + if wait: + self.latest_transform_job.wait(logs=logs) + def delete_model(self): """Delete the corresponding SageMaker model for this Transformer.""" self.sagemaker_session.delete_model(self.model_name) @@ -224,10 +233,10 @@ def _retrieve_image_name(self): "Local instance types require locally created models." % self.model_name ) - def wait(self): + def wait(self, logs=True): """Placeholder docstring""" self._ensure_last_transform_job() - self.latest_transform_job.wait() + self.latest_transform_job.wait(logs=logs) def stop_transform_job(self, wait=True): """Stop latest running batch transform job. 
@@ -351,8 +360,11 @@ def start_new( return cls(transformer.sagemaker_session, transformer._current_job_name) - def wait(self): - self.sagemaker_session.wait_for_transform_job(self.job_name) + def wait(self, logs=True): + if logs: + self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True) + else: + self.sagemaker_session.wait_for_transform_job(self.job_name) def stop(self): """Placeholder docstring""" diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 456c980d7a..e6a9001aa1 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -398,6 +398,47 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version, cpu_instance_ assert desc["TransformJobStatus"] == "Stopped" +def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type): + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + job_name = unique_name_from_base("test-mxnet-transform") + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({"train": train_input, "test": test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, "transform", "data.csv") + transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) + + with timeout(minutes=45): + transformer = 
_create_transformer_and_transform_job( + mx, transform_input, cpu_instance_type, wait=True, logs=True + ) + + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): + transformer.wait() + + def _create_transformer_and_transform_job( estimator, transform_input, @@ -406,6 +447,8 @@ def _create_transformer_and_transform_job( input_filter=None, output_filter=None, join_source=None, + wait=False, + logs=False, ): transformer = estimator.transformer(1, instance_type, volume_kms_key=volume_kms_key) transformer.transform( @@ -414,5 +457,7 @@ def _create_transformer_and_transform_job( input_filter=input_filter, output_filter=output_filter, join_source=join_source, + wait=wait, + logs=logs, ) return transformer diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1615df36b9..274c177618 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -349,6 +349,28 @@ def test_s3_input_all_arguments(): IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS) IN_PROGRESS_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "InProgress"}) +COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = { + "TransformJobStatus": "Completed", + "ModelName": "some-model", + "TransformJobName": JOB_NAME, + "TransformResources": {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE}, + "TransformEndTime": datetime.datetime(2018, 2, 17, 7, 19, 34, 953000), + "TransformStartTime": datetime.datetime(2018, 2, 17, 7, 15, 0, 103000), + "TransformOutput": {"AssembleWith": "None", "KmsKeyId": "", "S3OutputPath": S3_OUTPUT}, + "TransformInput": { + "CompressionType": "None", + "ContentType": "text/csv", + "DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}, + "SplitType": "Line", + }, +} + +STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) +STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "Stopped"}) + 
+IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) +IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "InProgress"}) + @pytest.fixture() def sagemaker_session(): @@ -852,6 +874,9 @@ def sagemaker_session_complete(): boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT + ims.sagemaker_client.describe_transform_job.return_value = ( + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + ) return ims @@ -862,6 +887,7 @@ def sagemaker_session_stopped(): boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT + ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT return ims @@ -876,6 +902,11 @@ def sagemaker_session_ready_lifecycle(): IN_PROGRESS_DESCRIBE_JOB_RESULT, COMPLETED_DESCRIBE_JOB_RESULT, ] + ims.sagemaker_client.describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] return ims @@ -890,6 +921,11 @@ def sagemaker_session_full_lifecycle(): IN_PROGRESS_DESCRIBE_JOB_RESULT, COMPLETED_DESCRIBE_JOB_RESULT, ] + ims.sagemaker_client.describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] return ims @@ -956,6 +992,69 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle) ] +@patch("sagemaker.logs.ColorWrap") +def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete): + ims = sagemaker_session_complete + ims.logs_for_transform_job(JOB_NAME) + 
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) + cw().assert_called_with(0, "hi there #1") + + +@patch("sagemaker.logs.ColorWrap") +def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped): + ims = sagemaker_session_stopped + ims.logs_for_transform_job(JOB_NAME) + ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) + cw().assert_called_with(0, "hi there #1") + + +@patch("sagemaker.logs.ColorWrap") +def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete): + ims = sagemaker_session_complete + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") + + +@patch("sagemaker.logs.ColorWrap") +def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped): + ims = sagemaker_session_stopped + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") + + +@patch("sagemaker.logs.ColorWrap") +def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle): + ims = sagemaker_session_ready_lifecycle + ims.logs_for_transform_job(JOB_NAME) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") + + +@patch("sagemaker.logs.ColorWrap") +@patch("time.time", side_effect=[0, 30, 60, 90, 120, 150, 180]) +def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle): + ims = sagemaker_session_full_lifecycle + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ( + ims.sagemaker_client.describe_transform_job.call_args_list + == [call(TransformJobName=JOB_NAME)] * 3 + 
) + assert cw().call_args_list == [ + call(0, "hi there #1"), + call(0, "hi there #2"), + call(0, "hi there #2a"), + call(0, "hi there #3"), + ] + + MODEL_NAME = "some-model" PRIMARY_CONTAINER = { "Environment": {},