Skip to content

Commit a062c6a

Browse files
authored
change: infer base name from job name in estimator.attach() (aws#1648)
1 parent 2dfce2f commit a062c6a

File tree

12 files changed

+53
-27
lines changed

12 files changed

+53
-27
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,10 @@ class constructor
274274
init_params["image_name"] = image_name
275275
return init_params
276276

277-
training_job_name = init_params["base_job_name"]
278-
279277
if framework != cls.__framework_name__:
280278
raise ValueError(
281279
"Training job: {} didn't use image for requested framework".format(
282-
training_job_name
280+
job_details["TrainingJobName"]
283281
)
284282
)
285283
return init_params

src/sagemaker/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from sagemaker.session import Session
5757
from sagemaker.session import s3_input
5858
from sagemaker.transformer import Transformer
59-
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
59+
from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, get_config_value
6060
from sagemaker import vpc_utils
6161

6262

@@ -616,7 +616,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
616616

617617
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
618618
estimator.latest_training_job = _TrainingJob(
619-
sagemaker_session=sagemaker_session, job_name=init_params["base_job_name"]
619+
sagemaker_session=sagemaker_session, job_name=training_job_name
620620
)
621621
estimator._current_job_name = estimator.latest_training_job.name
622622
estimator.latest_training_job.wait()
@@ -776,7 +776,7 @@ class constructor
776776
init_params["train_volume_size"] = job_details["ResourceConfig"]["VolumeSizeInGB"]
777777
init_params["train_max_run"] = job_details["StoppingCondition"]["MaxRuntimeInSeconds"]
778778
init_params["input_mode"] = job_details["AlgorithmSpecification"]["TrainingInputMode"]
779-
init_params["base_job_name"] = job_details["TrainingJobName"]
779+
init_params["base_job_name"] = base_from_name(job_details["TrainingJobName"])
780780
init_params["output_path"] = job_details["OutputDataConfig"]["S3OutputPath"]
781781
init_params["output_kms_key"] = job_details["OutputDataConfig"]["KmsKeyId"]
782782
if "EnableNetworkIsolation" in job_details:

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,10 @@ class constructor
277277
init_params["image_name"] = image_name
278278
return init_params
279279

280-
training_job_name = init_params["base_job_name"]
281-
282280
if framework != cls.__framework_name__:
283281
raise ValueError(
284282
"Training job: {} didn't use image for requested framework".format(
285-
training_job_name
283+
job_details["TrainingJobName"]
286284
)
287285
)
288286

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,10 @@ class constructor
223223
init_params["image_name"] = image_name
224224
return init_params
225225

226-
training_job_name = init_params["base_job_name"]
227-
228226
if framework != cls.__framework_name__:
229227
raise ValueError(
230228
"Training job: {} didn't use image for requested framework".format(
231-
training_job_name
229+
job_details["TrainingJobName"]
232230
)
233231
)
234232

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,9 @@ class constructor
314314
toolkit, toolkit_version = cls._toolkit_and_version_from_tag(tag)
315315

316316
if not cls._is_combination_supported(toolkit, toolkit_version, framework):
317-
training_job_name = init_params["base_job_name"]
318317
raise ValueError(
319318
"Training job: {} didn't use image for requested framework".format(
320-
training_job_name
319+
job_details["TrainingJobName"]
321320
)
322321
)
323322

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,10 @@ class constructor
246246
init_params["image_name"] = image_name
247247
return init_params
248248

249-
training_job_name = init_params["base_job_name"]
250-
251249
if framework and framework != cls.__framework_name__:
252250
raise ValueError(
253251
"Training job: {} didn't use image for requested framework".format(
254-
training_job_name
252+
job_details["TrainingJobName"]
255253
)
256254
)
257255

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
221221
if not script_mode:
222222
init_params["image_name"] = image_name
223223

224-
training_job_name = init_params["base_job_name"]
225224
if framework != cls.__framework_name__:
226225
raise ValueError(
227226
"Training job: {} didn't use image for requested framework".format(
228-
training_job_name
227+
job_details["TrainingJobName"]
229228
)
230229
)
231230

src/sagemaker/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ def base_name_from_image(image):
104104
return algo_name
105105

106106

107+
def base_from_name(name):
108+
"""Extract the base name of the resource name (for use with future resource name generation).
109+
110+
This function looks for timestamps that match the ones produced by
111+
:func:`~sagemaker.utils.name_from_base`.
112+
113+
Args:
114+
name (str): The resource name.
115+
116+
Returns:
117+
str: The base name, as extracted from the resource name.
118+
"""
119+
m = re.match(r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", name)
120+
return m.group(1) if m else name
121+
122+
107123
def sagemaker_timestamp():
108124
"""Return a timestamp with millisecond precision."""
109125
moment = time.time()

src/sagemaker/xgboost/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
238238

239239
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
240240
estimator.latest_training_job = _TrainingJob(
241-
sagemaker_session=sagemaker_session, job_name=init_params["base_job_name"]
241+
sagemaker_session=sagemaker_session, job_name=training_job_name
242242
)
243243
estimator._current_job_name = estimator.latest_training_job.name
244244
estimator.latest_training_job.wait()
@@ -268,10 +268,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
268268
init_params["py_version"] = py_version
269269

270270
if framework and framework != cls.__framework_name__:
271-
training_job_name = init_params["base_job_name"]
272271
raise ValueError(
273272
"Training job: {} didn't use image for requested framework".format(
274-
training_job_name
273+
job_details["TrainingJobName"]
275274
)
276275
)
277276
init_params["framework_version"] = framework_version_from_tag(tag)

tests/unit/test_estimator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytest
2222
from mock import ANY, MagicMock, Mock, patch
2323

24-
from sagemaker import vpc_utils
24+
from sagemaker import utils, vpc_utils
2525
from sagemaker.amazon.amazon_estimator import registry
2626
from sagemaker.algorithm import AlgorithmEstimator
2727
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
@@ -796,6 +796,19 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker
796796
assert framework_estimator.encrypt_inter_container_traffic is True
797797

798798

799+
def test_attach_framework_base_from_generated_name(sagemaker_session):
800+
sagemaker_session.sagemaker_client.describe_training_job = Mock(
801+
name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION
802+
)
803+
804+
base_job_name = "neo"
805+
framework_estimator = DummyFramework.attach(
806+
training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session
807+
)
808+
809+
assert framework_estimator.base_job_name == base_job_name
810+
811+
799812
@patch("time.strftime", return_value=TIMESTAMP)
800813
def test_fit_verify_job_name(strftime, sagemaker_session):
801814
fw = DummyFramework(

tests/unit/test_tuner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def test_deploy_default(tuner):
787787

788788
tuner.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
789789
name="describe_hyper_parameter_tuning_job",
790-
return_value={"BestTrainingJob": {"TrainingJobName": JOB_NAME}},
790+
return_value={"BestTrainingJob": {"TrainingJobName": TRAINING_JOB_NAME}},
791791
)
792792

793793
tuner.sagemaker_session.sagemaker_client.list_tags = Mock(
@@ -807,7 +807,7 @@ def test_deploy_default(tuner):
807807
assert args[2]["ModelDataUrl"] == MODEL_DATA
808808

809809
assert isinstance(predictor, Predictor)
810-
assert predictor.endpoint_name.startswith(JOB_NAME)
810+
assert predictor.endpoint_name.startswith(TRAINING_JOB_NAME)
811811
assert predictor.sagemaker_session == tuner.sagemaker_session
812812

813813

@@ -823,7 +823,7 @@ def test_deploy_estimator_dict(tuner):
823823
name="describe_hyper_parameter_tuning_job",
824824
return_value={
825825
"BestTrainingJob": {
826-
"TrainingJobName": JOB_NAME,
826+
"TrainingJobName": TRAINING_JOB_NAME,
827827
"TrainingJobDefinitionName": ESTIMATOR_NAME,
828828
}
829829
},
@@ -846,7 +846,7 @@ def test_deploy_estimator_dict(tuner):
846846
assert args[2]["ModelDataUrl"] == MODEL_DATA
847847

848848
assert isinstance(predictor, Predictor)
849-
assert predictor.endpoint_name.startswith(JOB_NAME)
849+
assert predictor.endpoint_name.startswith(TRAINING_JOB_NAME)
850850
assert predictor.sagemaker_session == tuner.sagemaker_session
851851

852852

tests/unit/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ def test_name_from_training_arn():
115115
assert name == "resnet-sgd-tuningjob-11-22-38-46-002-2927640b"
116116

117117

118+
def test_base_from_name():
119+
name = "mxnet-training-2020-06-29-15-19-25-475"
120+
assert "mxnet-training" == sagemaker.utils.base_from_name(name)
121+
122+
name = "sagemaker-pytorch-200629-1611"
123+
assert "sagemaker-pytorch" == sagemaker.utils.base_from_name(name)
124+
125+
118126
MESSAGE = "message"
119127
STATUS = "status"
120128
TRAINING_JOB_DESCRIPTION_1 = {

0 commit comments

Comments
 (0)