Skip to content

Commit c5dafc5

Browse files
author
Chuyang Deng
committed
make fixture for training_job_description
1 parent 7acdbc4 commit c5dafc5

File tree

2 files changed

+38
-64
lines changed

2 files changed

+38
-64
lines changed

src/sagemaker/estimator.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
589589
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
590590
Endpoint and return a ``Predictor``.
591591
592-
If the training job is in progress, attach will block and display log
593-
messages from the training job, until the training job completes.
592+
If the training job is in progress, attach will block until the training job
593+
completes, but logs of the training job will not display. To see the logs
594+
content, please call ``logs()``
594595
595596
Examples:
596597
>>> my_estimator.fit(wait=False)
597598
>>> training_job_name = my_estimator.latest_training_job.name
598599
Later on:
599600
>>> attached_estimator = Estimator.attach(training_job_name)
601+
>>> attached_estimator.logs()
600602
>>> attached_estimator.deploy()
601603
602604
Args:
@@ -634,8 +636,10 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
634636
return estimator
635637

636638
def logs(self):
637-
"""Display the logs for Estimator's training job. If the output is a tty or a Jupyter
638-
cell, it will be color-coded based on which instance the log entry is from.
639+
"""Display the logs for Estimator's training job.
640+
641+
If the output is a tty or a Jupyter cell, it will be color-coded based
642+
on which instance the log entry is from.
639643
"""
640644
self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True)
641645

@@ -1837,14 +1841,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
18371841
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
18381842
Endpoint and return a ``Predictor``.
18391843
1840-
If the training job is in progress, attach will block and display log
1841-
messages from the training job, until the training job completes.
1844+
If the training job is in progress, attach will block until the training job
1845+
completes, but logs of the training job will not display. To see the logs
1846+
content, please call ``logs()``
18421847
18431848
Examples:
18441849
>>> my_estimator.fit(wait=False)
18451850
>>> training_job_name = my_estimator.latest_training_job.name
18461851
Later on:
18471852
>>> attached_estimator = Estimator.attach(training_job_name)
1853+
>>> attached_estimator.logs()
18481854
>>> attached_estimator.deploy()
18491855
18501856
Args:

tests/unit/test_estimator.py

+26-58
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,16 @@ def sagemaker_session():
191191
return sms
192192

193193

194+
@pytest.fixture()
195+
def training_job_description(sagemaker_session):
196+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
197+
mock_describe_training_job = Mock(
198+
name="describe_training_job", return_value=returned_job_description
199+
)
200+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
201+
return returned_job_description
202+
203+
194204
def test_framework_all_init_args(sagemaker_session):
195205
f = DummyFramework(
196206
"my_script.py",
@@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
651661
assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"]
652662

653663

654-
def test_attach_framework(sagemaker_session):
655-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
656-
returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
657-
returned_job_description["EnableNetworkIsolation"] = True
658-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
659-
name="describe_training_job", return_value=returned_job_description
660-
)
664+
def test_attach_framework(sagemaker_session, training_job_description):
665+
training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
666+
training_job_description["EnableNetworkIsolation"] = True
661667

662668
framework_estimator = DummyFramework.attach(
663669
training_job_name="neo", sagemaker_session=sagemaker_session
@@ -681,50 +687,25 @@ def test_attach_framework(sagemaker_session):
681687
assert framework_estimator.enable_network_isolation() is True
682688

683689

684-
def test_attach_no_logs(sagemaker_session):
685-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
686-
mock_describe_training_job = Mock(
687-
name="describe_training_job", return_value=returned_job_description
688-
)
689-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
690+
def test_attach_no_logs(sagemaker_session, training_job_description):
690691
Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
691692
sagemaker_session.logs_for_job.assert_not_called()
692693

693694

694-
def test_logs(sagemaker_session):
695-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
696-
mock_describe_training_job = Mock(
697-
name="describe_training_job", return_value=returned_job_description
698-
)
699-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
695+
def test_logs(sagemaker_session, training_job_description):
700696
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
701697
estimator.logs()
702698
sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True)
703699

704700

705-
def test_attach_without_hyperparameters(sagemaker_session):
706-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
707-
del returned_job_description["HyperParameters"]
708-
709-
mock_describe_training_job = Mock(
710-
name="describe_training_job", return_value=returned_job_description
711-
)
712-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
713-
701+
def test_attach_without_hyperparameters(sagemaker_session, training_job_description):
702+
del training_job_description["HyperParameters"]
714703
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
715-
716704
assert estimator.hyperparameters() == {}
717705

718706

719-
def test_attach_framework_with_tuning(sagemaker_session):
720-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
721-
returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
722-
723-
mock_describe_training_job = Mock(
724-
name="describe_training_job", return_value=returned_job_description
725-
)
726-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
727-
707+
def test_attach_framework_with_tuning(sagemaker_session, training_job_description):
708+
training_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
728709
framework_estimator = DummyFramework.attach(
729710
training_job_name="neo", sagemaker_session=sagemaker_session
730711
)
@@ -744,48 +725,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
744725
assert framework_estimator.encrypt_inter_container_traffic is False
745726

746727

747-
def test_attach_framework_with_model_channel(sagemaker_session):
728+
def test_attach_framework_with_model_channel(sagemaker_session, training_job_description):
748729
s3_uri = "s3://some/s3/path/model.tar.gz"
749-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
750-
returned_job_description["InputDataConfig"] = [
730+
training_job_description["InputDataConfig"] = [
751731
{
752732
"ChannelName": "model",
753733
"InputMode": "File",
754734
"DataSource": {"S3DataSource": {"S3Uri": s3_uri}},
755735
}
756736
]
757737

758-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
759-
name="describe_training_job", return_value=returned_job_description
760-
)
761-
762738
framework_estimator = DummyFramework.attach(
763739
training_job_name="neo", sagemaker_session=sagemaker_session
764740
)
765741
assert framework_estimator.model_uri is s3_uri
766742
assert framework_estimator.encrypt_inter_container_traffic is False
767743

768744

769-
def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session):
770-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
771-
returned_job_description["EnableInterContainerTrafficEncryption"] = True
772-
773-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
774-
name="describe_training_job", return_value=returned_job_description
775-
)
776-
745+
def test_attach_framework_with_inter_container_traffic_encryption_flag(
746+
sagemaker_session, training_job_description
747+
):
748+
training_job_description["EnableInterContainerTrafficEncryption"] = True
777749
framework_estimator = DummyFramework.attach(
778750
training_job_name="neo", sagemaker_session=sagemaker_session
779751
)
780752

781753
assert framework_estimator.encrypt_inter_container_traffic is True
782754

783755

784-
def test_attach_framework_base_from_generated_name(sagemaker_session):
785-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
786-
name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION
787-
)
788-
756+
def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description):
789757
base_job_name = "neo"
790758
framework_estimator = DummyFramework.attach(
791759
training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session

0 commit comments

Comments
 (0)