From 7acdbc49cfdb30eafa470abe7fa3cadc17635db0 Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 14 Jul 2020 10:24:23 -0700 Subject: [PATCH 1/2] change: seperate logs() from attach() --- src/sagemaker/estimator.py | 8 +++++++- tests/unit/test_estimator.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7645e395fd..6ab6b13d17 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -630,9 +630,15 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m sagemaker_session=sagemaker_session, job_name=training_job_name ) estimator._current_job_name = estimator.latest_training_job.name - estimator.latest_training_job.wait() + estimator.latest_training_job.wait(logs="None") return estimator + def logs(self): + """Display the logs for Estimator's training job. If the output is a tty or a Jupyter + cell, it will be color-coded based on which instance the log entry is from. + """ + self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True) + def deploy( self, initial_instance_count, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index fa99ba1315..3eaa5d042f 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -681,6 +681,27 @@ def test_attach_framework(sagemaker_session): assert framework_estimator.enable_network_isolation() is True +def test_attach_no_logs(sagemaker_session): + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + mock_describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job + Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) + sagemaker_session.logs_for_job.assert_not_called() + + +def test_logs(sagemaker_session): + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + mock_describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job + estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) + estimator.logs() + sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True) + + def test_attach_without_hyperparameters(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() del returned_job_description["HyperParameters"] From c5dafc58fd7052369d24a7bfc6372c1a4882a38b Mon Sep 17 00:00:00 2001 From: Chuyang Deng Date: Tue, 14 Jul 2020 18:06:14 -0700 Subject: [PATCH 2/2] make fixture for training_job_description --- src/sagemaker/estimator.py | 18 +++++--- tests/unit/test_estimator.py | 84 +++++++++++------------------------- 2 files changed, 38 insertions(+), 64 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6ab6b13d17..b7b44a507a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -589,14 +589,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m has a Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``. - If the training job is in progress, attach will block and display log - messages from the training job, until the training job completes. + If the training job is in progress, attach will block until the training job + completes, but logs of the training job will not display. To see the logs + content, please call ``logs()`` Examples: >>> my_estimator.fit(wait=False) >>> training_job_name = my_estimator.latest_training_job.name Later on: >>> attached_estimator = Estimator.attach(training_job_name) + >>> attached_estimator.logs() >>> attached_estimator.deploy() Args: @@ -634,8 +636,10 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m return estimator def logs(self): - """Display the logs for Estimator's training job. If the output is a tty or a Jupyter - cell, it will be color-coded based on which instance the log entry is from. + """Display the logs for Estimator's training job. + + If the output is a tty or a Jupyter cell, it will be color-coded based + on which instance the log entry is from. """ self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True) @@ -1837,14 +1841,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m has a Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``. - If the training job is in progress, attach will block and display log - messages from the training job, until the training job completes. + If the training job is in progress, attach will block until the training job + completes, but logs of the training job will not display. To see the logs + content, please call ``logs()`` Examples: >>> my_estimator.fit(wait=False) >>> training_job_name = my_estimator.latest_training_job.name Later on: >>> attached_estimator = Estimator.attach(training_job_name) + >>> attached_estimator.logs() >>> attached_estimator.deploy() Args: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 3eaa5d042f..3bce3d58b6 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -191,6 +191,16 @@ def sagemaker_session(): return sms +@pytest.fixture() +def training_job_description(sagemaker_session): + returned_job_description = RETURNED_JOB_DESCRIPTION.copy() + mock_describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job + return returned_job_description + + def test_framework_all_init_args(sagemaker_session): f = DummyFramework( "my_script.py", @@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session): assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"] -def test_attach_framework(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} - returned_job_description["EnableNetworkIsolation"] = True - sagemaker_session.sagemaker_client.describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) +def test_attach_framework(sagemaker_session, training_job_description): + training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + training_job_description["EnableNetworkIsolation"] = True framework_estimator = DummyFramework.attach( training_job_name="neo", sagemaker_session=sagemaker_session @@ -681,50 +687,25 @@ def test_attach_framework(sagemaker_session): assert framework_estimator.enable_network_isolation() is True -def test_attach_no_logs(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - mock_describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job +def test_attach_no_logs(sagemaker_session, training_job_description): Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) sagemaker_session.logs_for_job.assert_not_called() -def test_logs(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - mock_describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job +def test_logs(sagemaker_session, training_job_description): estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) estimator.logs() sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True) -def test_attach_without_hyperparameters(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - del returned_job_description["HyperParameters"] - - mock_describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job - +def test_attach_without_hyperparameters(sagemaker_session, training_job_description): + del training_job_description["HyperParameters"] estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) - assert estimator.hyperparameters() == {} -def test_attach_framework_with_tuning(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy" - - mock_describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job - +def test_attach_framework_with_tuning(sagemaker_session, training_job_description): + training_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy" framework_estimator = DummyFramework.attach( training_job_name="neo", sagemaker_session=sagemaker_session ) @@ -744,10 +725,9 @@ def test_attach_framework_with_tuning(sagemaker_session): assert framework_estimator.encrypt_inter_container_traffic is False -def test_attach_framework_with_model_channel(sagemaker_session): +def test_attach_framework_with_model_channel(sagemaker_session, training_job_description): s3_uri = "s3://some/s3/path/model.tar.gz" - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description["InputDataConfig"] = [ + training_job_description["InputDataConfig"] = [ { "ChannelName": "model", "InputMode": "File", @@ -755,10 +735,6 @@ def test_attach_framework_with_model_channel(sagemaker_session): } ] - sagemaker_session.sagemaker_client.describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - framework_estimator = DummyFramework.attach( training_job_name="neo", sagemaker_session=sagemaker_session ) @@ -766,14 +742,10 @@ def test_attach_framework_with_model_channel(sagemaker_session): assert framework_estimator.encrypt_inter_container_traffic is False -def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session): - returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description["EnableInterContainerTrafficEncryption"] = True - - sagemaker_session.sagemaker_client.describe_training_job = Mock( - name="describe_training_job", return_value=returned_job_description - ) - +def test_attach_framework_with_inter_container_traffic_encryption_flag( + sagemaker_session, training_job_description +): + training_job_description["EnableInterContainerTrafficEncryption"] = True framework_estimator = DummyFramework.attach( training_job_name="neo", sagemaker_session=sagemaker_session ) @@ -781,11 +753,7 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker assert framework_estimator.encrypt_inter_container_traffic is True -def test_attach_framework_base_from_generated_name(sagemaker_session): - sagemaker_session.sagemaker_client.describe_training_job = Mock( - name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION - ) - +def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): base_job_name = "neo" framework_estimator = DummyFramework.attach( training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session