Skip to content

Commit 66db4ce

Browse files
committed
Fix: change the default naming for base_name_from_image
1 parent 528c9de commit 66db4ce

File tree

9 files changed

+72
-13
lines changed

9 files changed

+72
-13
lines changed

src/sagemaker/estimator.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
102102
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
103103
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
104104
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
105+
JOB_CLASS_NAME = "training-job"
105106

106107
def __init__(
107108
self,
@@ -576,7 +577,9 @@ def _ensure_base_job_name(self):
576577
self.base_job_name = (
577578
self.base_job_name
578579
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
579-
or base_name_from_image(self.training_image_uri())
580+
or base_name_from_image(
581+
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
582+
)
580583
)
581584

582585
def _get_or_create_name(self, name=None):
@@ -982,7 +985,9 @@ def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_conf
982985

983986
def _compilation_job_name(self):
984987
"""Placeholder docstring"""
985-
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
988+
base_name = self.base_job_name or base_name_from_image(
989+
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
990+
)
986991
return name_from_base("compilation-" + base_name)
987992

988993
def compile_model(

src/sagemaker/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
635635
self._base_name = (
636636
self._base_name
637637
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
638-
or utils.base_name_from_image(image_uri)
638+
or utils.base_name_from_image(image_uri, default_base_name=Model.__name__)
639639
)
640640

641641
def _set_model_name_if_needed(self):

src/sagemaker/processing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
class Processor(object):
4646
"""Handles Amazon SageMaker Processing tasks."""
4747

48+
JOB_CLASS_NAME = "processing-job"
49+
4850
def __init__(
4951
self,
5052
role,
@@ -280,7 +282,9 @@ def _generate_current_job_name(self, job_name=None):
280282
if self.base_job_name:
281283
base_name = self.base_job_name
282284
else:
283-
base_name = base_name_from_image(self.image_uri)
285+
base_name = base_name_from_image(
286+
self.image_uri, default_base_name=Processor.JOB_CLASS_NAME
287+
)
284288

285289
return name_from_base(base_name)
286290

src/sagemaker/transformer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
class Transformer(object):
2626
"""A class for handling creating and interacting with Amazon SageMaker transform jobs."""
2727

28+
JOB_CLASS_NAME = "transform-job"
29+
2830
def __init__(
2931
self,
3032
model_name,
@@ -240,7 +242,7 @@ def _retrieve_base_name(self):
240242
image_uri = self._retrieve_image_uri()
241243

242244
if image_uri:
243-
return base_name_from_image(image_uri)
245+
return base_name_from_image(image_uri, default_base_name=Transformer.JOB_CLASS_NAME)
244246

245247
return self.model_name
246248

src/sagemaker/tuner.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
3030
from sagemaker.analytics import HyperparameterTuningJobAnalytics
3131
from sagemaker.deprecations import removed_function
32-
from sagemaker.estimator import Framework
32+
from sagemaker.estimator import Framework, EstimatorBase
3333
from sagemaker.inputs import TrainingInput
3434
from sagemaker.job import _Job
3535
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
@@ -367,7 +367,9 @@ def _prepare_job_name_for_tuning(self, job_name=None):
367367
estimator = (
368368
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
369369
)
370-
base_name = base_name_from_image(estimator.training_image_uri())
370+
base_name = base_name_from_image(
371+
estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
372+
)
371373

372374
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
373375
getattr(estimator, "source_dir", None),

src/sagemaker/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ def unique_name_from_base(base, max_length=63):
9191
return "{}-{}-{}".format(trimmed, ts, unique)
9292

9393

94-
def base_name_from_image(image):
94+
def base_name_from_image(image, default_base_name=None):
9595
"""Extract the base name of the image to use as the 'algorithm name' for the job.
9696
9797
Args:
9898
image (str): Image name.
99+
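default_base_name (str): Fallback name returned when ``image`` is a pipeline variable with no default value to parse; if omitted or empty, "base_name" is returned instead (default: None).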
99100
100101
Returns:
101102
str: Algorithm name, as extracted from the image name.
@@ -104,7 +105,7 @@ def base_name_from_image(image):
104105
if is_pipeline_parameter_string(image) and image.default_value:
105106
image_str = image.default_value
106107
else:
107-
return "base_name"
108+
return default_base_name if default_base_name else "base_name"
108109
else:
109110
image_str = image
110111

src/sagemaker/workflow/airflow.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sagemaker import fw_utils, job, utils, s3, session, vpc_utils
2121
from sagemaker.amazon import amazon_estimator
2222
from sagemaker.tensorflow import TensorFlow
23+
from sagemaker.estimator import EstimatorBase
24+
from sagemaker.processing import Processor
2325

2426

2527
def prepare_framework(estimator, s3_operations):
@@ -151,7 +153,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
151153
estimator._current_job_name = job_name
152154
else:
153155
base_name = estimator.base_job_name or utils.base_name_from_image(
154-
estimator.training_image_uri()
156+
estimator.training_image_uri(),
157+
default_base_name=EstimatorBase.JOB_CLASS_NAME,
155158
)
156159
estimator._current_job_name = utils.name_from_base(base_name)
157160

@@ -1138,7 +1141,7 @@ def processing_config(
11381141
processor._current_job_name = (
11391142
utils.name_from_base(base_name)
11401143
if base_name is not None
1141-
else utils.base_name_from_image(processor.image_uri)
1144+
else utils.base_name_from_image(processor.image_uri, Processor.JOB_CLASS_NAME)
11421145
)
11431146

11441147
config = {

tests/unit/sagemaker/model/test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test_create_sagemaker_model_generates_model_name(
287287
)
288288
model._create_sagemaker_model(INSTANCE_TYPE)
289289

290-
base_name_from_image.assert_called_with(MODEL_IMAGE)
290+
base_name_from_image.assert_called_with(MODEL_IMAGE, default_base_name="Model")
291291
name_from_base.assert_called_with(base_name_from_image.return_value)
292292

293293
sagemaker_session.create_model.assert_called_with(
@@ -317,7 +317,7 @@ def test_create_sagemaker_model_generates_model_name_each_time(
317317
model._create_sagemaker_model(INSTANCE_TYPE)
318318
model._create_sagemaker_model(INSTANCE_TYPE)
319319

320-
base_name_from_image.assert_called_once_with(MODEL_IMAGE)
320+
base_name_from_image.assert_called_once_with(MODEL_IMAGE, default_base_name="Model")
321321
name_from_base.assert_called_with(base_name_from_image.return_value)
322322
assert 2 == name_from_base.call_count
323323

tests/unit/test_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
import sagemaker
3131
from sagemaker.session_settings import SessionSettings
32+
from tests.unit.sagemaker.workflow.helpers import CustomStep
33+
from sagemaker.workflow.parameters import ParameterString
3234

3335
BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
3436

@@ -82,6 +84,46 @@ def test_name_from_image(base_name_from_image, name_from_base):
8284
name_from_base.assert_called_with(base_name_from_image.return_value, max_length=max_length)
8385

8486

87+
@pytest.mark.parametrize(
88+
"inputs",
89+
[
90+
(
91+
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
92+
None,
93+
"base_name",
94+
),
95+
(
96+
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
97+
"whatever",
98+
"whatever",
99+
),
100+
(ParameterString(name="image_uri"), None, "base_name"),
101+
(ParameterString(name="image_uri"), "whatever", "whatever"),
102+
(
103+
ParameterString(
104+
name="image_uri",
105+
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
106+
),
107+
None,
108+
"analyzer",
109+
),
110+
(
111+
ParameterString(
112+
name="image_uri",
113+
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
114+
),
115+
"whatever",
116+
"analyzer",
117+
),
118+
],
119+
)
120+
def test_base_name_from_image_with_pipeline_param(inputs):
121+
image, default_base_name, expected = inputs
122+
assert expected == sagemaker.utils.base_name_from_image(
123+
image=image, default_base_name=default_base_name
124+
)
125+
126+
85127
@patch("sagemaker.utils.sagemaker_timestamp")
86128
def test_name_from_base(sagemaker_timestamp):
87129
sagemaker.utils.name_from_base(NAME, short=False)

0 commit comments

Comments
 (0)