Skip to content

fix: Fix processing image uri param #3158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from sagemaker import deprecations
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string


ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
Expand Down Expand Up @@ -99,9 +100,17 @@ def base_name_from_image(image):
Returns:
str: Algorithm name, as extracted from the image name.
"""
m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image)
algo_name = m.group(2) if m else image
return algo_name
if is_pipeline_variable(image):
if is_pipeline_parameter_string(image) and image.default_value:
image_str = image.default_value
else:
return "base_name"
else:
image_str = image

m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
base_name = m.group(2) if m else image_str
return base_name


def base_from_name(name):
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.parameters import ParameterString


def is_pipeline_variable(var: object) -> bool:
Expand All @@ -29,3 +30,14 @@ def is_pipeline_variable(var: object) -> bool:
# as well as PipelineExperimentConfigProperty and PropertyFile
# TODO: We should deprecate the Expression and replace it with PipelineVariable
return isinstance(var, Expression)


def is_pipeline_parameter_string(var: object) -> bool:
"""Check if the variable is a pipeline parameter string

Args:
var (object): The variable to be verified.
Returns:
bool: True if it is, False otherwise.
"""
return isinstance(var, ParameterString)
50 changes: 49 additions & 1 deletion tests/unit/sagemaker/workflow/test_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
from mock import Mock, PropertyMock

from sagemaker import Model
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.parameters import (
ParameterString,
ParameterInteger,
ParameterBoolean,
ParameterFloat,
)
from sagemaker.workflow.functions import Join, JsonGet
from tests.unit.sagemaker.workflow.helpers import CustomStep

from botocore.config import Config

Expand Down Expand Up @@ -122,3 +130,43 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
assert not register_step_args.create_model_request
assert register_step_args.create_model_package_request
assert len(register_step_args.need_runtime_repack) == 0


@pytest.mark.parametrize(
"item",
[
(ParameterString(name="my-str"), True),
(ParameterBoolean(name="my-bool"), True),
(ParameterFloat(name="my-float"), True),
(ParameterInteger(name="my-int"), True),
(Join(on="/", values=["my", "value"]), True),
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), True),
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, True),
("my-str", False),
(1, False),
(CustomStep(name="my-ste"), False),
],
)
def test_is_pipeline_variable(item):
var, assertion = item
assert is_pipeline_variable(var) == assertion


@pytest.mark.parametrize(
"item",
[
(ParameterString(name="my-str"), True),
(ParameterBoolean(name="my-bool"), False),
(ParameterFloat(name="my-float"), False),
(ParameterInteger(name="my-int"), False),
(Join(on="/", values=["my", "value"]), False),
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), False),
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, False),
("my-str", False),
(1, False),
(CustomStep(name="my-ste"), False),
],
)
def test_is_pipeline_parameter_string(item):
var, assertion = item
assert is_pipeline_parameter_string(var) == assertion
16 changes: 13 additions & 3 deletions tests/unit/sagemaker/workflow/test_processing_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,27 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
}


def test_processing_step_with_processor_and_step_args(pipeline_session, processing_input):
@pytest.mark.parametrize(
"image_uri",
[
IMAGE_URI,
ParameterString(name="MyImage"),
ParameterString(name="MyImage", default_value="my-image-uri"),
Join(on="/", values=["docker", "my-fake-image"]),
],
)
def test_processing_step_with_processor_and_step_args(
pipeline_session, processing_input, image_uri
):
processor = Processor(
image_uri=IMAGE_URI,
image_uri=image_uri,
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
sagemaker_session=pipeline_session,
)

step_args = processor.run(inputs=processing_input)

try:
ProcessingStep(
name="MyProcessingStep",
Expand Down