Skip to content

Commit 8aa8819

Browse files
fix: Fix processing image uri param (#3158)
1 parent 9369a87 commit 8aa8819

File tree

12 files changed

+156
-18
lines changed

12 files changed

+156
-18
lines changed

src/sagemaker/estimator.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
105105
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
106106
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
107107
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
108+
JOB_CLASS_NAME = "training-job"
108109

109110
def __init__(
110111
self,
@@ -594,7 +595,9 @@ def _ensure_base_job_name(self):
594595
self.base_job_name = (
595596
self.base_job_name
596597
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
597-
or base_name_from_image(self.training_image_uri())
598+
or base_name_from_image(
599+
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
600+
)
598601
)
599602

600603
def _get_or_create_name(self, name=None):
@@ -1007,7 +1010,9 @@ def fit(
10071010

10081011
def _compilation_job_name(self):
10091012
"""Generate a compilation job name from the base job name or the training image URI."""
1010-
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
1013+
base_name = self.base_job_name or base_name_from_image(
1014+
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
1015+
)
10111016
return name_from_base("compilation-" + base_name)
10121017

10131018
def compile_model(

src/sagemaker/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
679679
self._base_name = (
680680
self._base_name
681681
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
682-
or utils.base_name_from_image(image_uri)
682+
or utils.base_name_from_image(image_uri, default_base_name=Model.__name__)
683683
)
684684

685685
def _set_model_name_if_needed(self):

src/sagemaker/processing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
class Processor(object):
4848
"""Handles Amazon SageMaker Processing tasks."""
4949

50+
JOB_CLASS_NAME = "processing-job"
51+
5052
def __init__(
5153
self,
5254
role: str,
@@ -282,7 +284,9 @@ def _generate_current_job_name(self, job_name=None):
282284
if self.base_job_name:
283285
base_name = self.base_job_name
284286
else:
285-
base_name = base_name_from_image(self.image_uri)
287+
base_name = base_name_from_image(
288+
self.image_uri, default_base_name=Processor.JOB_CLASS_NAME
289+
)
286290

287291
return name_from_base(base_name)
288292

src/sagemaker/transformer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
class Transformer(object):
2929
"""A class for handling creating and interacting with Amazon SageMaker transform jobs."""
3030

31+
JOB_CLASS_NAME = "transform-job"
32+
3133
def __init__(
3234
self,
3335
model_name: Union[str, PipelineVariable],
@@ -243,7 +245,7 @@ def _retrieve_base_name(self):
243245
image_uri = self._retrieve_image_uri()
244246

245247
if image_uri:
246-
return base_name_from_image(image_uri)
248+
return base_name_from_image(image_uri, default_base_name=Transformer.JOB_CLASS_NAME)
247249

248250
return self.model_name
249251

src/sagemaker/tuner.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ def _prepare_job_name_for_tuning(self, job_name=None):
373373
estimator = (
374374
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
375375
)
376-
base_name = base_name_from_image(estimator.training_image_uri())
376+
base_name = base_name_from_image(
377+
estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
378+
)
377379

378380
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
379381
getattr(estimator, "source_dir", None),

src/sagemaker/utils.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from sagemaker import deprecations
3535
from sagemaker.session_settings import SessionSettings
36+
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
3637

3738

3839
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
@@ -90,18 +91,27 @@ def unique_name_from_base(base, max_length=63):
9091
return "{}-{}-{}".format(trimmed, ts, unique)
9192

9293

93-
def base_name_from_image(image):
94+
def base_name_from_image(image, default_base_name=None):
9495
"""Extract the base name of the image to use as the 'algorithm name' for the job.
9596
9697
Args:
9798
image (str or PipelineVariable): Image name, or a pipeline variable standing in for it.
99+
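default_base_name (str): Fallback base name returned when ``image`` is a pipeline variable with no usable default value (default: None, which yields "base_name").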
98100
99101
Returns:
100102
str: Base name extracted from the image name, or the fallback base name when the image is an unresolved pipeline variable.
101103
"""
102-
m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image)
103-
algo_name = m.group(2) if m else image
104-
return algo_name
104+
if is_pipeline_variable(image):
105+
if is_pipeline_parameter_string(image) and image.default_value:
106+
image_str = image.default_value
107+
else:
108+
return default_base_name if default_base_name else "base_name"
109+
else:
110+
image_str = image
111+
112+
m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
113+
base_name = m.group(2) if m else image_str
114+
return base_name
105115

106116

107117
def base_from_name(name):

src/sagemaker/workflow/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.workflow.entities import Expression
17+
from sagemaker.workflow.parameters import ParameterString
1718

1819

1920
def is_pipeline_variable(var: object) -> bool:
@@ -29,3 +30,14 @@ def is_pipeline_variable(var: object) -> bool:
2930
# as well as PipelineExperimentConfigProperty and PropertyFile
3031
# TODO: We should deprecate the Expression and replace it with PipelineVariable
3132
return isinstance(var, Expression)
33+
34+
35+
def is_pipeline_parameter_string(var: object) -> bool:
36+
"""Check if the variable is a pipeline ``ParameterString``.
37+
38+
Args:
39+
var (object): The variable to be verified.
40+
Returns:
41+
bool: True if it is, False otherwise.
42+
"""
43+
return isinstance(var, ParameterString)

src/sagemaker/workflow/airflow.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sagemaker import fw_utils, job, utils, s3, session, vpc_utils
2121
from sagemaker.amazon import amazon_estimator
2222
from sagemaker.tensorflow import TensorFlow
23+
from sagemaker.estimator import EstimatorBase
24+
from sagemaker.processing import Processor
2325

2426

2527
def prepare_framework(estimator, s3_operations):
@@ -151,7 +153,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
151153
estimator._current_job_name = job_name
152154
else:
153155
base_name = estimator.base_job_name or utils.base_name_from_image(
154-
estimator.training_image_uri()
156+
estimator.training_image_uri(),
157+
default_base_name=EstimatorBase.JOB_CLASS_NAME,
155158
)
156159
estimator._current_job_name = utils.name_from_base(base_name)
157160

@@ -1138,7 +1141,7 @@ def processing_config(
11381141
processor._current_job_name = (
11391142
utils.name_from_base(base_name)
11401143
if base_name is not None
1141-
else utils.base_name_from_image(processor.image_uri)
1144+
else utils.base_name_from_image(processor.image_uri, Processor.JOB_CLASS_NAME)
11421145
)
11431146

11441147
config = {

tests/unit/sagemaker/model/test_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test_create_sagemaker_model_generates_model_name(
287287
)
288288
model._create_sagemaker_model(INSTANCE_TYPE)
289289

290-
base_name_from_image.assert_called_with(MODEL_IMAGE)
290+
base_name_from_image.assert_called_with(MODEL_IMAGE, default_base_name="Model")
291291
name_from_base.assert_called_with(base_name_from_image.return_value)
292292

293293
sagemaker_session.create_model.assert_called_with(
@@ -317,7 +317,7 @@ def test_create_sagemaker_model_generates_model_name_each_time(
317317
model._create_sagemaker_model(INSTANCE_TYPE)
318318
model._create_sagemaker_model(INSTANCE_TYPE)
319319

320-
base_name_from_image.assert_called_once_with(MODEL_IMAGE)
320+
base_name_from_image.assert_called_once_with(MODEL_IMAGE, default_base_name="Model")
321321
name_from_base.assert_called_with(base_name_from_image.return_value)
322322
assert 2 == name_from_base.call_count
323323

tests/unit/sagemaker/workflow/test_pipeline_session.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717
from mock import Mock, PropertyMock
1818

1919
from sagemaker import Model
20-
from sagemaker.workflow.parameters import ParameterString
2120
from sagemaker.workflow.pipeline_context import PipelineSession
21+
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
22+
from sagemaker.workflow.parameters import (
23+
ParameterString,
24+
ParameterInteger,
25+
ParameterBoolean,
26+
ParameterFloat,
27+
)
28+
from sagemaker.workflow.functions import Join, JsonGet
29+
from tests.unit.sagemaker.workflow.helpers import CustomStep
2230

2331
from botocore.config import Config
2432

@@ -130,6 +138,46 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
130138
assert len(register_step_args.need_runtime_repack) == 0
131139

132140

141+
@pytest.mark.parametrize(
142+
"item",
143+
[
144+
(ParameterString(name="my-str"), True),
145+
(ParameterBoolean(name="my-bool"), True),
146+
(ParameterFloat(name="my-float"), True),
147+
(ParameterInteger(name="my-int"), True),
148+
(Join(on="/", values=["my", "value"]), True),
149+
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), True),
150+
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, True),
151+
("my-str", False),
152+
(1, False),
153+
(CustomStep(name="my-ste"), False),
154+
],
155+
)
156+
def test_is_pipeline_variable(item):
157+
var, assertion = item
158+
assert is_pipeline_variable(var) == assertion
159+
160+
161+
@pytest.mark.parametrize(
162+
"item",
163+
[
164+
(ParameterString(name="my-str"), True),
165+
(ParameterBoolean(name="my-bool"), False),
166+
(ParameterFloat(name="my-float"), False),
167+
(ParameterInteger(name="my-int"), False),
168+
(Join(on="/", values=["my", "value"]), False),
169+
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), False),
170+
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, False),
171+
("my-str", False),
172+
(1, False),
173+
(CustomStep(name="my-ste"), False),
174+
],
175+
)
176+
def test_is_pipeline_parameter_string(item):
177+
var, assertion = item
178+
assert is_pipeline_parameter_string(var) == assertion
179+
180+
133181
def test_pipeline_session_context_for_model_step_without_instance_types(
134182
pipeline_session_mock,
135183
):

tests/unit/sagemaker/workflow/test_processing_step.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -336,17 +336,27 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
336336
)
337337

338338

339-
def test_processing_step_with_processor_and_step_args(pipeline_session, processing_input):
339+
@pytest.mark.parametrize(
340+
"image_uri",
341+
[
342+
IMAGE_URI,
343+
ParameterString(name="MyImage"),
344+
ParameterString(name="MyImage", default_value="my-image-uri"),
345+
Join(on="/", values=["docker", "my-fake-image"]),
346+
],
347+
)
348+
def test_processing_step_with_processor_and_step_args(
349+
pipeline_session, processing_input, image_uri
350+
):
340351
processor = Processor(
341-
image_uri=IMAGE_URI,
352+
image_uri=image_uri,
342353
role=ROLE,
343354
instance_count=1,
344355
instance_type=INSTANCE_TYPE,
345356
sagemaker_session=pipeline_session,
346357
)
347358

348359
step_args = processor.run(inputs=processing_input)
349-
350360
try:
351361
ProcessingStep(
352362
name="MyProcessingStep",

tests/unit/test_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
import sagemaker
3131
from sagemaker.session_settings import SessionSettings
32+
from tests.unit.sagemaker.workflow.helpers import CustomStep
33+
from sagemaker.workflow.parameters import ParameterString
3234

3335
BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
3436

@@ -82,6 +84,46 @@ def test_name_from_image(base_name_from_image, name_from_base):
8284
name_from_base.assert_called_with(base_name_from_image.return_value, max_length=max_length)
8385

8486

87+
@pytest.mark.parametrize(
88+
"inputs",
89+
[
90+
(
91+
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
92+
None,
93+
"base_name",
94+
),
95+
(
96+
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
97+
"whatever",
98+
"whatever",
99+
),
100+
(ParameterString(name="image_uri"), None, "base_name"),
101+
(ParameterString(name="image_uri"), "whatever", "whatever"),
102+
(
103+
ParameterString(
104+
name="image_uri",
105+
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
106+
),
107+
None,
108+
"analyzer",
109+
),
110+
(
111+
ParameterString(
112+
name="image_uri",
113+
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
114+
),
115+
"whatever",
116+
"analyzer",
117+
),
118+
],
119+
)
120+
def test_base_name_from_image_with_pipeline_param(inputs):
121+
image, default_base_name, expected = inputs
122+
assert expected == sagemaker.utils.base_name_from_image(
123+
image=image, default_base_name=default_base_name
124+
)
125+
126+
85127
@patch("sagemaker.utils.sagemaker_timestamp")
86128
def test_name_from_base(sagemaker_timestamp):
87129
sagemaker.utils.name_from_base(NAME, short=False)

0 commit comments

Comments
 (0)