From 14e52e29e141a5c9a7692dfdaa91d8f71d5fa2cb Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Tue, 7 May 2019 14:22:42 +0000 Subject: [PATCH 1/6] feature: Estimator.fit like logs for transformer --- src/sagemaker/session.py | 162 +++++++++++++++++++++++++---------- src/sagemaker/transformer.py | 20 ++++- tests/unit/test_session.py | 98 +++++++++++++++++++++ 3 files changed, 230 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index e64086d953..5b2823e84e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1466,52 +1466,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method last_describe_job_call = time.time() last_description = description while True: - if len(stream_names) < instance_count: - # Log streams are created whenever a container starts writing to stdout/err, so - # this list # may be dynamic until we have a stream for every instance. - try: - streams = client.describe_log_streams( - logGroupName=log_group, - logStreamNamePrefix=job_name + "/", - orderBy="LogStreamName", - limit=instance_count, - ) - stream_names = [s["logStreamName"] for s in streams["logStreams"]] - positions.update( - [ - (s, sagemaker.logs.Position(timestamp=0, skip=0)) - for s in stream_names - if s not in positions - ] - ) - except ClientError as e: - # On the very first training job run on an account, there's no log group until - # the container starts logging, so ignore any errors thrown about that - err = e.response.get("Error", {}) - if err.get("Code", None) != "ResourceNotFoundException": - raise - - if len(stream_names) > 0: - if dot: - print("") - dot = False - for idx, event in sagemaker.logs.multi_stream_iter( - client, log_group, stream_names, positions - ): - color_wrap(idx, event["message"]) - ts, count = positions[stream_names[idx]] - if event["timestamp"] == ts: - positions[stream_names[idx]] = sagemaker.logs.Position( - timestamp=ts, skip=count + 1 - ) - else: - positions[stream_names[idx]] = sagemaker.logs.Position( - timestamp=event["timestamp"], skip=1 - ) - else: - dot = True - print(".", end="") - sys.stdout.flush() + _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap) if state == LogState.COMPLETE: break @@ -1550,6 +1505,87 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method saving = (1 - float(billable_time) / training_time) * 100 print("Managed Spot Training savings: {:.1f}%".format(saving)) + def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning + """Display the logs for a given transform job, optionally tailing them until the + job is complete. If the output is a tty or a Jupyter cell, it will be color-coded + based on which instance the log entry is from. + + Args: + job_name (str): Name of the transform job to display the logs for. + wait (bool): Whether to keep looking for new log entries until the job completes (default: False). + poll (int): The interval in seconds between polling for new log entries and job completion (default: 5). + + Raises: + ValueError: If the transform job fails. + """ + + description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) + instance_count = description['TransformResources']['InstanceCount'] + status = description['TransformJobStatus'] + + stream_names = [] # The list of log streams + positions = {} # The current position in each stream, map of stream name -> position + + # Increase retries allowed (from default of 4), as we don't want waiting for a training job + # to be interrupted by a transient exception. + config = botocore.config.Config(retries={'max_attempts': 15}) + client = self.boto_session.client('logs', config=config) + log_group = '/aws/sagemaker/TransformJobs' + + job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False + + state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE + dot = False + + color_wrap = sagemaker.logs.ColorWrap() + + # The loop below implements a state machine that alternates between checking the job status and + # reading whatever is available in the logs at this point. Note, that if we were called with + # wait == False, we never check the job status. + # + # If wait == TRUE and job is not completed, the initial state is TAILING + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete). + # + # The state table: + # + # STATE ACTIONS CONDITION NEW STATE + # ---------------- ---------------- ----------------- ---------------- + # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE + # Else TAILING + # JOB_COMPLETE Read logs, Pause Any COMPLETE + # COMPLETE Read logs, Exit N/A + # + # Notes: + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after + # the job was marked complete. + last_describe_job_call = time.time() + while True: + _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap) + if state == LogState.COMPLETE: + break + + time.sleep(poll) + + if state == LogState.JOB_COMPLETE: + state = LogState.COMPLETE + elif time.time() - last_describe_job_call >= 30: + description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) + last_describe_job_call = time.time() + + status = description['TransformJobStatus'] + + if status == 'Completed' or status == 'Failed' or status == 'Stopped': + print() + state = LogState.JOB_COMPLETE + + if wait: + self._check_job_status(job_name, description, 'TransformJobStatus') + if dot: + print() + # Customers are not billed for hardware provisioning, so billable time is less than total time + billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count + print('Billable seconds:', int(billable_time.total_seconds()) + 1) + def container_def(image, model_data_url=None, env=None): """Create a definition for executing a container as part of a SageMaker model. @@ -1888,3 +1924,37 @@ def _vpc_config_from_training_job( if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT: return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY) return vpc_utils.sanitize(vpc_config_override) + + +def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap): + if len(stream_names) < instance_count: + # Log streams are created whenever a container starts writing to stdout/err, so this list + # may be dynamic until we have a stream for every instance. + try: + streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/', + orderBy='LogStreamName', limit=instance_count) + stream_names = [s['logStreamName'] for s in streams['logStreams']] + positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0)) + for s in stream_names if s not in positions]) + except ClientError as e: + # On the very first training job run on an account, there's no log group until + # the container starts logging, so ignore any errors thrown about that + err = e.response.get('Error', {}) + if err.get('Code', None) != 'ResourceNotFoundException': + raise + + if len(stream_names) > 0: + if dot: + print('') + dot = False + for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions): + color_wrap(idx, event['message']) + ts, count = positions[stream_names[idx]] + if event['timestamp'] == ts: + positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1) + else: + positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1) + else: + dot = True + print('.', end='') + sys.stdout.flush() diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index d5efe87cdf..1be976c56f 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -119,6 +119,8 @@ def transform( input_filter=None, output_filter=None, join_source=None, + wait=False, + logs=False ): """Start a new transform job. @@ -154,6 +156,10 @@ def transform( will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. + wait (bool): Whether the call should wait until the job completes + (default: True). + logs (bool): Whether to show the logs produced by the job. + Only meaningful when wait is True (default: True). """ local_mode = self.sagemaker_session.local_mode if not local_mode and not data.startswith("s3://"): @@ -187,6 +193,9 @@ def transform( join_source, ) + if wait: + self.latest_transform_job.wait(logs=logs) + def delete_model(self): """Delete the corresponding SageMaker model for this Transformer.""" self.sagemaker_session.delete_model(self.model_name) @@ -224,10 +233,10 @@ def _retrieve_image_name(self): "Local instance types require locally created models." % self.model_name ) - def wait(self): + def wait(self, logs=True): """Placeholder docstring""" self._ensure_last_transform_job() - self.latest_transform_job.wait() + self.latest_transform_job.wait(logs=logs) def stop_transform_job(self, wait=True): """Stop latest running batch transform job. @@ -351,8 +360,11 @@ def start_new( return cls(transformer.sagemaker_session, transformer._current_job_name) - def wait(self): - self.sagemaker_session.wait_for_transform_job(self.job_name) + def wait(self, logs=True): + if logs: + self.sagemaker_session.logs_for_transform_job(self.job_name, wait=True) + else: + self.sagemaker_session.wait_for_transform_job(self.job_name) def stop(self): """Placeholder docstring""" diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6e2b2c0f36..9b7ce5b6aa 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -340,6 +340,38 @@ def test_s3_input_all_arguments(): IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS) IN_PROGRESS_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "InProgress"}) +COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = { + 'TransformJobStatus': 'Completed', + 'ModelName': 'some-model', + 'TransformJobName': JOB_NAME, + 'TransformResources': { + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE + }, + 'TransformEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000), + 'TransformStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000), + 'TransformOutput': { + 'AssembleWith': 'None', + 'KmsKeyId': '', + 'S3OutputPath': S3_OUTPUT + }, + 'TransformInput': { + 'CompressionType': 'None', + 'ContentType': 'text/csv', + 'DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': S3_INPUT_URI + }, + 'SplitType': 'Line' + } +} + +STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) +STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'Stopped'}) + +IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) +IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'InProgress'}) + @pytest.fixture() def sagemaker_session(): @@ -841,6 +873,7 @@ def sagemaker_session_complete(): boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT + ims.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT return ims @@ -851,6 +884,7 @@ def sagemaker_session_stopped(): boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT + ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT return ims @@ -865,6 +899,11 @@ def sagemaker_session_ready_lifecycle(): IN_PROGRESS_DESCRIBE_JOB_RESULT, COMPLETED_DESCRIBE_JOB_RESULT, ] + ims.sagemaker_client.describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] return ims @@ -879,6 +918,11 @@ def sagemaker_session_full_lifecycle(): IN_PROGRESS_DESCRIBE_JOB_RESULT, COMPLETED_DESCRIBE_JOB_RESULT, ] + ims.sagemaker_client.describe_transform_job.side_effect = [ + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT, + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT, + ] return ims @@ -946,6 +990,60 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle) MODEL_NAME = "some-model" + + +@patch('sagemaker.logs.ColorWrap') +def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete): + ims = sagemaker_session_complete + ims.logs_for_transform_job(JOB_NAME) + ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) + cw().assert_called_with(0, 'hi there #1') + + +@patch('sagemaker.logs.ColorWrap') +def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped): + ims = sagemaker_session_stopped + ims.logs_for_transform_job(JOB_NAME) + ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) + cw().assert_called_with(0, 'hi there #1') + + +@patch('sagemaker.logs.ColorWrap') +def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete): + ims = sagemaker_session_complete + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] + cw().assert_called_with(0, 'hi there #1') + + +@patch('sagemaker.logs.ColorWrap') +def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped): + ims = sagemaker_session_stopped + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] + cw().assert_called_with(0, 'hi there #1') + + +@patch('sagemaker.logs.ColorWrap') +def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle): + ims = sagemaker_session_ready_lifecycle + ims.logs_for_transform_job(JOB_NAME) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] + cw().assert_called_with(0, 'hi there #1') + + +@patch('sagemaker.logs.ColorWrap') +@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180]) +def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle): + ims = sagemaker_session_full_lifecycle + ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) + assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] * 3 + assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'), + call(0, 'hi there #2a'), call(0, 'hi there #3')] + + +MODEL_NAME = 'some-model' +>>>>>>> feature: Estimator.fit like logs for transformer PRIMARY_CONTAINER = { "Environment": {}, "Image": IMAGE, From 3b72dd3baec99a42ecb35dd6ec3ba05071a6e8ca Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Fri, 24 May 2019 17:28:09 +0000 Subject: [PATCH 2/6] Add integrations tests and refactor --- src/sagemaker/session.py | 16 +++++++-------- src/sagemaker/transformer.py | 2 +- tests/integ/test_transformer.py | 35 +++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5b2823e84e..116ad01904 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1436,9 +1436,8 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method client = self.boto_session.client("logs", config=config) log_group = "/aws/sagemaker/TrainingJobs" - job_already_completed = status in ("Completed", "Failed", "Stopped") + state = _get_initial_job_state(description, 'TrainingJobStatus', wait) - state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE dot = False color_wrap = sagemaker.logs.ColorWrap() @@ -1521,7 +1520,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) instance_count = description['TransformResources']['InstanceCount'] - status = description['TransformJobStatus'] stream_names = [] # The list of log streams positions = {} # The current position in each stream, map of stream name -> position @@ -1532,9 +1530,8 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - client = self.boto_session.client('logs', config=config) log_group = '/aws/sagemaker/TransformJobs' - job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False + state = _get_initial_job_state(description, 'TransformJobStatus', wait) - state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE dot = False color_wrap = sagemaker.logs.ColorWrap() @@ -1582,9 +1579,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - self._check_job_status(job_name, description, 'TransformJobStatus') if dot: print() - # Customers are not billed for hardware provisioning, so billable time is less than total time - billable_time = (description['TransformEndTime'] - description['TransformStartTime']) * instance_count - print('Billable seconds:', int(billable_time.total_seconds()) + 1) def container_def(image, model_data_url=None, env=None): @@ -1926,6 +1920,12 @@ def _vpc_config_from_training_job( return vpc_utils.sanitize(vpc_config_override) +def _get_initial_job_state(description, status_key, wait): + status = description[status_key] + job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False + return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE + + def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap): if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 1be976c56f..b9b1c57588 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -159,7 +159,7 @@ def transform( wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. - Only meaningful when wait is True (default: True). + Only meaningful when wait is True (default: False). """ local_mode = self.sagemaker_session.local_mode if not local_mode and not data.startswith("s3://"): diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 93a4dd2f34..680c674e84 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -398,6 +398,36 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version): assert desc["TransformJobStatus"] == "Stopped" +def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): + data_path = os.path.join(DATA_DIR, 'mxnet_mnist') + script_path = os.path.join(data_path, 'mnist.py') + + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version) + + train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), + key_prefix='integ-test-data/mxnet_mnist/train') + test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), + key_prefix='integ-test-data/mxnet_mnist/test') + job_name = unique_name_from_base('test-mxnet-transform') + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, 'transform', 'data.csv') + transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' + transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, + key_prefix=transform_input_key_prefix) + + with timeout(minutes=45): + transformer = _create_transformer_and_transform_job(mx, transform_input, wait=True, logs=True) + + with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, + minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): + transformer.wait() + + def _create_transformer_and_transform_job( estimator, transform_input, @@ -406,6 +436,8 @@ def _create_transformer_and_transform_job( input_filter=None, output_filter=None, join_source=None, + wait=False, + logs=False, ): transformer = estimator.transformer(1, instance_type, volume_kms_key=volume_kms_key) transformer.transform( @@ -414,5 +446,8 @@ def _create_transformer_and_transform_job( input_filter=input_filter, output_filter=output_filter, join_source=join_source, + wait=wait, + logs=logs, ) return transformer + From 3b13e7dc5beb1c1ec6272e6e329df4d6b6050cc5 Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Sat, 8 Jun 2019 07:28:39 +0000 Subject: [PATCH 3/6] Refactor code --- src/sagemaker/session.py | 55 ++++++++++++++++----------------- tests/integ/test_transformer.py | 1 - tests/unit/test_session.py | 4 --- 3 files changed, 27 insertions(+), 33 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 116ad01904..05608e092c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1423,25 +1423,13 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method """ description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) - print(secondary_training_status_message(description, None), end="") - instance_count = description["ResourceConfig"]["InstanceCount"] - status = description["TrainingJobStatus"] + print(secondary_training_status_message(description, None), end='') - stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position - - # Increase retries allowed (from default of 4), as we don't want waiting for a training job - # to be interrupted by a transient exception. - config = botocore.config.Config(retries={"max_attempts": 15}) - client = self.boto_session.client("logs", config=config) - log_group = "/aws/sagemaker/TrainingJobs" + instance_count, stream_names, positions, client, log_group, dot, color_wrap = \ + _logs_initializer(self, description, job='Training') state = _get_initial_job_state(description, 'TrainingJobStatus', wait) - dot = False - - color_wrap = sagemaker.logs.ColorWrap() - # The loop below implements a state machine that alternates between checking the job status # and reading whatever is available in the logs at this point. Note, that if we were # called with wait == False, we never check the job status. @@ -1519,23 +1507,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - """ description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) - instance_count = description['TransformResources']['InstanceCount'] - stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position - - # Increase retries allowed (from default of 4), as we don't want waiting for a training job - # to be interrupted by a transient exception. - config = botocore.config.Config(retries={'max_attempts': 15}) - client = self.boto_session.client('logs', config=config) - log_group = '/aws/sagemaker/TransformJobs' + instance_count, stream_names, positions, client, log_group, dot, color_wrap = \ + _logs_initializer(self, description, job='Transform') state = _get_initial_job_state(description, 'TransformJobStatus', wait) - dot = False - - color_wrap = sagemaker.logs.ColorWrap() - # The loop below implements a state machine that alternates between checking the job status and # reading whatever is available in the logs at this point. Note, that if we were called with # wait == False, we never check the job status. @@ -1926,6 +1903,28 @@ def _get_initial_job_state(description, status_key, wait): return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE +def _logs_initializer(sagemaker_session, description, job): + if job == 'Training': + instance_count = description['ResourceConfig']['InstanceCount'] + elif job == 'Transform': + instance_count = description['TransformResources']['InstanceCount'] + + stream_names = [] # The list of log streams + positions = {} # The current position in each stream, map of stream name -> position + + # Increase retries allowed (from default of 4), as we don't want waiting for a training job + # to be interrupted by a transient exception. + config = botocore.config.Config(retries={'max_attempts': 15}) + client = sagemaker_session.boto_session.client('logs', config=config) + log_group = '/aws/sagemaker/' + job + 'Jobs' + + dot = False + + color_wrap = sagemaker.logs.ColorWrap() + + return instance_count, stream_names, positions, client, log_group, dot, color_wrap + + def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap): if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 680c674e84..1fea946059 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -450,4 +450,3 @@ def _create_transformer_and_transform_job( logs=logs, ) return transformer - diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9b7ce5b6aa..31d4e3695d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -989,9 +989,6 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle) ] -MODEL_NAME = "some-model" - - @patch('sagemaker.logs.ColorWrap') def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete): ims = sagemaker_session_complete @@ -1043,7 +1040,6 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ MODEL_NAME = 'some-model' ->>>>>>> feature: Estimator.fit like logs for transformer PRIMARY_CONTAINER = { "Environment": {}, "Image": IMAGE, From 30fb37cf8468d84ed488db06b71e53be2b3c9dce Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Wed, 3 Jul 2019 13:05:49 +0000 Subject: [PATCH 4/6] black check --- src/sagemaker/session.py | 115 ++++++++++++++++++++++---------- src/sagemaker/transformer.py | 2 +- tests/integ/test_transformer.py | 47 ++++++++----- tests/unit/test_session.py | 93 ++++++++++++++------------ 4 files changed, 157 insertions(+), 100 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 05608e092c..c9ddf75572 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1423,12 +1423,13 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method """ description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) - print(secondary_training_status_message(description, None), end='') + print(secondary_training_status_message(description, None), end="") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = \ - _logs_initializer(self, description, job='Training') + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_initializer( + self, description, job="Training" + ) - state = _get_initial_job_state(description, 'TrainingJobStatus', wait) + state = _get_initial_job_state(description, "TrainingJobStatus", wait) # The loop below implements a state machine that alternates between checking the job status # and reading whatever is available in the logs at this point. Note, that if we were @@ -1453,7 +1454,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method last_describe_job_call = time.time() last_description = description while True: - _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap) + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) if state == LogState.COMPLETE: break @@ -1492,7 +1502,9 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method saving = (1 - float(billable_time) / training_time) * 100 print("Managed Spot Training savings: {:.1f}%".format(saving)) - def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning + def logs_for_transform_job( + self, job_name, wait=False, poll=10 + ): # noqa: C901 - suppress complexity warning """Display the logs for a given transform job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. @@ -1508,10 +1520,11 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) - instance_count, stream_names, positions, client, log_group, dot, color_wrap = \ - _logs_initializer(self, description, job='Transform') + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_initializer( + self, description, job="Transform" + ) - state = _get_initial_job_state(description, 'TransformJobStatus', wait) + state = _get_initial_job_state(description, "TransformJobStatus", wait) # The loop below implements a state machine that alternates between checking the job status and # reading whatever is available in the logs at this point. Note, that if we were called with @@ -1534,7 +1547,16 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - # the job was marked complete. last_describe_job_call = time.time() while True: - _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap) + _flush_log_streams( + stream_names, + instance_count, + client, + log_group, + job_name, + positions, + dot, + color_wrap, + ) if state == LogState.COMPLETE: break @@ -1543,17 +1565,19 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - if state == LogState.JOB_COMPLETE: state = LogState.COMPLETE elif time.time() - last_describe_job_call >= 30: - description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) + description = self.sagemaker_client.describe_transform_job( + TransformJobName=job_name + ) last_describe_job_call = time.time() - status = description['TransformJobStatus'] + status = description["TransformJobStatus"] - if status == 'Completed' or status == 'Failed' or status == 'Stopped': + if status == "Completed" or status == "Failed" or status == "Stopped": print() state = LogState.JOB_COMPLETE if wait: - self._check_job_status(job_name, description, 'TransformJobStatus') + self._check_job_status(job_name, description, "TransformJobStatus") if dot: print() @@ -1899,24 +1923,26 @@ def _vpc_config_from_training_job( def _get_initial_job_state(description, status_key, wait): status = description[status_key] - job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False + job_already_completed = ( + True if status == "Completed" or status == "Failed" or status == "Stopped" else False + ) return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE def _logs_initializer(sagemaker_session, description, job): - if job == 'Training': - instance_count = description['ResourceConfig']['InstanceCount'] - elif job == 'Transform': - instance_count = description['TransformResources']['InstanceCount'] + if job == "Training": + instance_count = description["ResourceConfig"]["InstanceCount"] + elif job == "Transform": + instance_count = description["TransformResources"]["InstanceCount"] stream_names = [] # The list of log streams - positions = {} # The current position in each stream, map of stream name -> position + positions = {} # The current position in each stream, map of stream name -> position # Increase retries allowed (from default of 4), as we don't want waiting for a training job # to be interrupted by a transient exception. - config = botocore.config.Config(retries={'max_attempts': 15}) - client = sagemaker_session.boto_session.client('logs', config=config) - log_group = '/aws/sagemaker/' + job + 'Jobs' + config = botocore.config.Config(retries={"max_attempts": 15}) + client = sagemaker_session.boto_session.client("logs", config=config) + log_group = "/aws/sagemaker/" + job + "Jobs" dot = False @@ -1925,35 +1951,50 @@ def _logs_initializer(sagemaker_session, description, job): return instance_count, stream_names, positions, client, log_group, dot, color_wrap -def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap): +def _flush_log_streams( + stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap +): if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list # may be dynamic until we have a stream for every instance. try: - streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/', - orderBy='LogStreamName', limit=instance_count) - stream_names = [s['logStreamName'] for s in streams['logStreams']] - positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0)) - for s in stream_names if s not in positions]) + streams = client.describe_log_streams( + logGroupName=log_group, + logStreamNamePrefix=job_name + "/", + orderBy="LogStreamName", + limit=instance_count, + ) + stream_names = [s["logStreamName"] for s in streams["logStreams"]] + positions.update( + [ + (s, sagemaker.logs.Position(timestamp=0, skip=0)) + for s in stream_names + if s not in positions + ] + ) except ClientError as e: # On the very first training job run on an account, there's no log group until # the container starts logging, so ignore any errors thrown about that - err = e.response.get('Error', {}) - if err.get('Code', None) != 'ResourceNotFoundException': + err = e.response.get("Error", {}) + if err.get("Code", None) != "ResourceNotFoundException": raise if len(stream_names) > 0: if dot: - print('') + print("") dot = False - for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions): - color_wrap(idx, event['message']) + for idx, event in sagemaker.logs.multi_stream_iter( + client, log_group, stream_names, positions + ): + color_wrap(idx, event["message"]) ts, count = positions[stream_names[idx]] - if event['timestamp'] == ts: + if event["timestamp"] == ts: positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1) else: - positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1) + positions[stream_names[idx]] = sagemaker.logs.Position( + timestamp=event["timestamp"], skip=1 + ) else: dot = True - print('.', end='') + print(".", end="") sys.stdout.flush() diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index b9b1c57588..5b30c8a4c3 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -120,7 +120,7 @@ def transform( output_filter=None, join_source=None, wait=False, - logs=False + logs=False, ): """Start a new transform job. diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 1fea946059..579a0b9535 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -399,32 +399,43 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version): def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - script_path = os.path.join(data_path, 'mnist.py') + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") - mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, - framework_version=mxnet_full_version) + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version, + ) - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') - job_name = unique_name_from_base('test-mxnet-transform') + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + job_name = unique_name_from_base("test-mxnet-transform") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) + mx.fit({"train": train_input, "test": test_input}, job_name=job_name) - transform_input_path = os.path.join(data_path, 'transform', 'data.csv') - transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' - transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, - key_prefix=transform_input_key_prefix) + transform_input_path = os.path.join(data_path, "transform", "data.csv") + transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) with timeout(minutes=45): - transformer = _create_transformer_and_transform_job(mx, transform_input, wait=True, logs=True) + transformer = _create_transformer_and_transform_job( + mx, transform_input, wait=True, logs=True + ) - with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, - minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): transformer.wait() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 31d4e3695d..d15915167d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -341,36 +341,26 @@ def test_s3_input_all_arguments(): IN_PROGRESS_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "InProgress"}) COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = { - 'TransformJobStatus': 'Completed', - 'ModelName': 'some-model', - 'TransformJobName': JOB_NAME, - 'TransformResources': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE - }, - 'TransformEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000), - 'TransformStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000), - 'TransformOutput': { - 'AssembleWith': 'None', - 'KmsKeyId': '', - 'S3OutputPath': S3_OUTPUT + "TransformJobStatus": "Completed", + "ModelName": "some-model", + "TransformJobName": JOB_NAME, + "TransformResources": {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE}, + "TransformEndTime": datetime.datetime(2018, 2, 17, 7, 19, 34, 953000), + "TransformStartTime": datetime.datetime(2018, 2, 17, 7, 15, 0, 103000), + "TransformOutput": {"AssembleWith": "None", "KmsKeyId": "", "S3OutputPath": S3_OUTPUT}, + "TransformInput": { + "CompressionType": "None", + "ContentType": "text/csv", + "DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}, + "SplitType": "Line", }, - 'TransformInput': { - 'CompressionType': 'None', - 'ContentType': 'text/csv', - 'DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': S3_INPUT_URI - }, - 'SplitType': 'Line' - } } STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) -STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'Stopped'}) +STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "Stopped"}) IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT) -IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'InProgress'}) +IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({"TransformJobStatus": "InProgress"}) @pytest.fixture() @@ -873,7 +863,9 @@ def sagemaker_session_complete(): boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT - ims.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + ims.sagemaker_client.describe_transform_job.return_value = ( + COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT + ) return ims @@ -989,57 +981,70 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle) ] -@patch('sagemaker.logs.ColorWrap') +@patch("sagemaker.logs.ColorWrap") def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete): ims = sagemaker_session_complete ims.logs_for_transform_job(JOB_NAME) ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) - cw().assert_called_with(0, 'hi there #1') + cw().assert_called_with(0, "hi there #1") -@patch('sagemaker.logs.ColorWrap') +@patch("sagemaker.logs.ColorWrap") def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped): ims = sagemaker_session_stopped ims.logs_for_transform_job(JOB_NAME) ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME) - cw().assert_called_with(0, 'hi there #1') + cw().assert_called_with(0, "hi there #1") -@patch('sagemaker.logs.ColorWrap') +@patch("sagemaker.logs.ColorWrap") def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete): ims = sagemaker_session_complete ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) - assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] - cw().assert_called_with(0, 'hi there #1') + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") -@patch('sagemaker.logs.ColorWrap') +@patch("sagemaker.logs.ColorWrap") def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped): ims = sagemaker_session_stopped ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) - assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] - cw().assert_called_with(0, 'hi there #1') + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") -@patch('sagemaker.logs.ColorWrap') +@patch("sagemaker.logs.ColorWrap") def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle): ims = sagemaker_session_ready_lifecycle ims.logs_for_transform_job(JOB_NAME) - assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] - cw().assert_called_with(0, 'hi there #1') + assert ims.sagemaker_client.describe_transform_job.call_args_list == [ + call(TransformJobName=JOB_NAME) + ] + cw().assert_called_with(0, "hi there #1") -@patch('sagemaker.logs.ColorWrap') -@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180]) +@patch("sagemaker.logs.ColorWrap") +@patch("time.time", side_effect=[0, 30, 60, 90, 120, 150, 180]) def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle): ims = sagemaker_session_full_lifecycle ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0) - assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] * 3 - assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'), - call(0, 'hi there #2a'), call(0, 'hi there #3')] + assert ( + ims.sagemaker_client.describe_transform_job.call_args_list + == [call(TransformJobName=JOB_NAME)] * 3 + ) + assert cw().call_args_list == [ + call(0, "hi there #1"), + call(0, "hi there #2"), + call(0, "hi there #2a"), + call(0, "hi there #3"), + ] -MODEL_NAME = 'some-model' +MODEL_NAME = "some-model" PRIMARY_CONTAINER = { "Environment": {}, "Image": IMAGE, From d178174f74ea60185dc0326f67094f16a2a7f19c Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Thu, 1 Aug 2019 17:47:44 +0000 Subject: [PATCH 5/6] linting fixes --- src/sagemaker/session.py | 42 ++++++++++++++++++------------------ src/sagemaker/transformer.py | 2 +- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c9ddf75572..8b30fdcbab 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1404,9 +1404,7 @@ def get_caller_identity_arn(self): return role - def logs_for_job( # noqa: C901 - suppress complexity warning for this method - self, job_name, wait=False, poll=10 - ): + def logs_for_job(self, job_name, wait=False, poll=10): """Display the logs for a given training job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. @@ -1425,7 +1423,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) print(secondary_training_status_message(description, None), end="") - instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_initializer( + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( self, description, job="Training" ) @@ -1502,17 +1500,17 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method saving = (1 - float(billable_time) / training_time) * 100 print("Managed Spot Training savings: {:.1f}%".format(saving)) - def logs_for_transform_job( - self, job_name, wait=False, poll=10 - ): # noqa: C901 - suppress complexity warning + def logs_for_transform_job(self, job_name, wait=False, poll=10): """Display the logs for a given transform job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. Args: job_name (str): Name of the transform job to display the logs for. - wait (bool): Whether to keep looking for new log entries until the job completes (default: False). - poll (int): The interval in seconds between polling for new log entries and job completion (default: 5). + wait (bool): Whether to keep looking for new log entries until the job completes + (default: False). + poll (int): The interval in seconds between polling for new log entries and job + completion (default: 5). Raises: ValueError: If the transform job fails. @@ -1520,18 +1518,19 @@ def logs_for_transform_job( description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name) - instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_initializer( + instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( self, description, job="Transform" ) state = _get_initial_job_state(description, "TransformJobStatus", wait) - # The loop below implements a state machine that alternates between checking the job status and - # reading whatever is available in the logs at this point. Note, that if we were called with - # wait == False, we never check the job status. + # The loop below implements a state machine that alternates between checking the job status + # and reading whatever is available in the logs at this point. Note, that if we were + # called with wait == False, we never check the job status. # # If wait == TRUE and job is not completed, the initial state is TAILING - # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete). + # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is + # complete). # # The state table: # @@ -1543,8 +1542,8 @@ def logs_for_transform_job( # COMPLETE Read logs, Exit N/A # # Notes: - # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after - # the job was marked complete. + # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to + # Cloudwatch after the job was marked complete. last_describe_job_call = time.time() while True: _flush_log_streams( @@ -1572,7 +1571,7 @@ def logs_for_transform_job( status = description["TransformJobStatus"] - if status == "Completed" or status == "Failed" or status == "Stopped": + if status in ("Completed", "Failed", "Stopped"): print() state = LogState.JOB_COMPLETE @@ -1922,14 +1921,14 @@ def _vpc_config_from_training_job( def _get_initial_job_state(description, status_key, wait): + """Placeholder docstring""" status = description[status_key] - job_already_completed = ( - True if status == "Completed" or status == "Failed" or status == "Stopped" else False - ) + job_already_completed = status in ("Completed", "Failed", "Stopped") return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE -def _logs_initializer(sagemaker_session, description, job): +def _logs_init(sagemaker_session, description, job): + """Placeholder docstring""" if job == "Training": instance_count = description["ResourceConfig"]["InstanceCount"] elif job == "Transform": @@ -1954,6 +1953,7 @@ def _logs_initializer(sagemaker_session, description, job): def _flush_log_streams( stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap ): + """Placeholder docstring""" if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list # may be dynamic until we have a stream for every instance. diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 5b30c8a4c3..0ce37906ff 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -156,7 +156,7 @@ def transform( will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. - wait (bool): Whether the call should wait until the job completes + wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: False). From c18a62da93257566da8c3454ad578c7c0bada9de Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Thu, 5 Sep 2019 07:42:31 +0000 Subject: [PATCH 6/6] new --- src/sagemaker/session.py | 4 +++- tests/integ/test_transformer.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8b30fdcbab..e5a493e270 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1404,7 +1404,9 @@ def get_caller_identity_arn(self): return role - def logs_for_job(self, job_name, wait=False, poll=10): + def logs_for_job( # noqa: C901 - suppress complexity warning for this method + self, job_name, wait=False, poll=10 + ): """Display the logs for a given training job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 579a0b9535..844a72a0f1 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -398,7 +398,7 @@ def test_stop_transform_job(sagemaker_session, mxnet_full_version): assert desc["TransformJobStatus"] == "Stopped" -def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): +def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version, cpu_instance_type): data_path = os.path.join(DATA_DIR, "mxnet_mnist") script_path = os.path.join(data_path, "mnist.py") @@ -406,7 +406,7 @@ def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): entry_point=script_path, role="SageMakerRole", train_instance_count=1, - train_instance_type="ml.c4.xlarge", + train_instance_type=cpu_instance_type, sagemaker_session=sagemaker_session, framework_version=mxnet_full_version, ) @@ -430,7 +430,7 @@ def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version): with timeout(minutes=45): transformer = _create_transformer_and_transform_job( - mx, transform_input, wait=True, logs=True + mx, transform_input, cpu_instance_type, wait=True, logs=True ) with timeout_and_delete_model_with_transformer(