Skip to content

fix: Fix processing image uri param #3158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options"
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
JOB_CLASS_NAME = "training-job"

def __init__(
self,
Expand Down Expand Up @@ -594,7 +595,9 @@ def _ensure_base_job_name(self):
self.base_job_name = (
self.base_job_name
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
or base_name_from_image(self.training_image_uri())
or base_name_from_image(
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
)
)

def _get_or_create_name(self, name=None):
Expand Down Expand Up @@ -1007,7 +1010,9 @@ def fit(

def _compilation_job_name(self):
"""Placeholder docstring"""
base_name = self.base_job_name or base_name_from_image(self.training_image_uri())
base_name = self.base_job_name or base_name_from_image(
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
)
return name_from_base("compilation-" + base_name)

def compile_model(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
self._base_name = (
self._base_name
or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri)
or utils.base_name_from_image(image_uri)
or utils.base_name_from_image(image_uri, default_base_name=Model.__name__)
)

def _set_model_name_if_needed(self):
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
class Processor(object):
"""Handles Amazon SageMaker Processing tasks."""

JOB_CLASS_NAME = "processing-job"

def __init__(
self,
role: str,
Expand Down Expand Up @@ -282,7 +284,9 @@ def _generate_current_job_name(self, job_name=None):
if self.base_job_name:
base_name = self.base_job_name
else:
base_name = base_name_from_image(self.image_uri)
base_name = base_name_from_image(
self.image_uri, default_base_name=Processor.JOB_CLASS_NAME
)

return name_from_base(base_name)

Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
class Transformer(object):
"""A class for handling creating and interacting with Amazon SageMaker transform jobs."""

JOB_CLASS_NAME = "transform-job"

def __init__(
self,
model_name: Union[str, PipelineVariable],
Expand Down Expand Up @@ -243,7 +245,7 @@ def _retrieve_base_name(self):
image_uri = self._retrieve_image_uri()

if image_uri:
return base_name_from_image(image_uri)
return base_name_from_image(image_uri, default_base_name=Transformer.JOB_CLASS_NAME)

return self.model_name

Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ def _prepare_job_name_for_tuning(self, job_name=None):
estimator = (
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
)
base_name = base_name_from_image(estimator.training_image_uri())
base_name = base_name_from_image(
estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
)

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
getattr(estimator, "source_dir", None),
Expand Down
18 changes: 14 additions & 4 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from sagemaker import deprecations
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string


ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
Expand Down Expand Up @@ -90,18 +91,27 @@ def unique_name_from_base(base, max_length=63):
return "{}-{}-{}".format(trimmed, ts, unique)


def base_name_from_image(image):
def base_name_from_image(image, default_base_name=None):
"""Extract the base name of the image to use as the 'algorithm name' for the job.

Args:
image (str): Image name.
default_base_name (str): The default base name

Returns:
str: Algorithm name, as extracted from the image name.
"""
m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image)
algo_name = m.group(2) if m else image
return algo_name
if is_pipeline_variable(image):
if is_pipeline_parameter_string(image) and image.default_value:
image_str = image.default_value
else:
return default_base_name if default_base_name else "base_name"
else:
image_str = image

m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str)
base_name = m.group(2) if m else image_str
return base_name


def base_from_name(name):
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from sagemaker.workflow.entities import Expression
from sagemaker.workflow.parameters import ParameterString


def is_pipeline_variable(var: object) -> bool:
Expand All @@ -29,3 +30,14 @@ def is_pipeline_variable(var: object) -> bool:
# as well as PipelineExperimentConfigProperty and PropertyFile
# TODO: We should deprecate the Expression and replace it with PipelineVariable
return isinstance(var, Expression)


def is_pipeline_parameter_string(var: object) -> bool:
"""Check if the variable is a pipeline parameter string

Args:
var (object): The variable to be verified.
Returns:
bool: True if it is, False otherwise.
"""
return isinstance(var, ParameterString)
7 changes: 5 additions & 2 deletions src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from sagemaker import fw_utils, job, utils, s3, session, vpc_utils
from sagemaker.amazon import amazon_estimator
from sagemaker.tensorflow import TensorFlow
from sagemaker.estimator import EstimatorBase
from sagemaker.processing import Processor


def prepare_framework(estimator, s3_operations):
Expand Down Expand Up @@ -151,7 +153,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
estimator._current_job_name = job_name
else:
base_name = estimator.base_job_name or utils.base_name_from_image(
estimator.training_image_uri()
estimator.training_image_uri(),
default_base_name=EstimatorBase.JOB_CLASS_NAME,
)
estimator._current_job_name = utils.name_from_base(base_name)

Expand Down Expand Up @@ -1138,7 +1141,7 @@ def processing_config(
processor._current_job_name = (
utils.name_from_base(base_name)
if base_name is not None
else utils.base_name_from_image(processor.image_uri)
else utils.base_name_from_image(processor.image_uri, Processor.JOB_CLASS_NAME)
)

config = {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_create_sagemaker_model_generates_model_name(
)
model._create_sagemaker_model(INSTANCE_TYPE)

base_name_from_image.assert_called_with(MODEL_IMAGE)
base_name_from_image.assert_called_with(MODEL_IMAGE, default_base_name="Model")
name_from_base.assert_called_with(base_name_from_image.return_value)

sagemaker_session.create_model.assert_called_with(
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_create_sagemaker_model_generates_model_name_each_time(
model._create_sagemaker_model(INSTANCE_TYPE)
model._create_sagemaker_model(INSTANCE_TYPE)

base_name_from_image.assert_called_once_with(MODEL_IMAGE)
base_name_from_image.assert_called_once_with(MODEL_IMAGE, default_base_name="Model")
name_from_base.assert_called_with(base_name_from_image.return_value)
assert 2 == name_from_base.call_count

Expand Down
50 changes: 49 additions & 1 deletion tests/unit/sagemaker/workflow/test_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
from mock import Mock, PropertyMock

from sagemaker import Model
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.parameters import (
ParameterString,
ParameterInteger,
ParameterBoolean,
ParameterFloat,
)
from sagemaker.workflow.functions import Join, JsonGet
from tests.unit.sagemaker.workflow.helpers import CustomStep

from botocore.config import Config

Expand Down Expand Up @@ -130,6 +138,46 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
assert len(register_step_args.need_runtime_repack) == 0


@pytest.mark.parametrize(
"item",
[
(ParameterString(name="my-str"), True),
(ParameterBoolean(name="my-bool"), True),
(ParameterFloat(name="my-float"), True),
(ParameterInteger(name="my-int"), True),
(Join(on="/", values=["my", "value"]), True),
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), True),
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, True),
("my-str", False),
(1, False),
(CustomStep(name="my-ste"), False),
],
)
def test_is_pipeline_variable(item):
var, assertion = item
assert is_pipeline_variable(var) == assertion


@pytest.mark.parametrize(
"item",
[
(ParameterString(name="my-str"), True),
(ParameterBoolean(name="my-bool"), False),
(ParameterFloat(name="my-float"), False),
(ParameterInteger(name="my-int"), False),
(Join(on="/", values=["my", "value"]), False),
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), False),
(CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, False),
("my-str", False),
(1, False),
(CustomStep(name="my-ste"), False),
],
)
def test_is_pipeline_parameter_string(item):
var, assertion = item
assert is_pipeline_parameter_string(var) == assertion


def test_pipeline_session_context_for_model_step_without_instance_types(
pipeline_session_mock,
):
Expand Down
16 changes: 13 additions & 3 deletions tests/unit/sagemaker/workflow/test_processing_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,27 @@ def test_processing_step_with_processor(pipeline_session, processing_input):
)


def test_processing_step_with_processor_and_step_args(pipeline_session, processing_input):
@pytest.mark.parametrize(
"image_uri",
[
IMAGE_URI,
ParameterString(name="MyImage"),
ParameterString(name="MyImage", default_value="my-image-uri"),
Join(on="/", values=["docker", "my-fake-image"]),
],
)
def test_processing_step_with_processor_and_step_args(
pipeline_session, processing_input, image_uri
):
processor = Processor(
image_uri=IMAGE_URI,
image_uri=image_uri,
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
sagemaker_session=pipeline_session,
)

step_args = processor.run(inputs=processing_input)

try:
ProcessingStep(
name="MyProcessingStep",
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

import sagemaker
from sagemaker.session_settings import SessionSettings
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString

BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"

Expand Down Expand Up @@ -82,6 +84,46 @@ def test_name_from_image(base_name_from_image, name_from_base):
name_from_base.assert_called_with(base_name_from_image.return_value, max_length=max_length)


@pytest.mark.parametrize(
"inputs",
[
(
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
None,
"base_name",
),
(
CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath,
"whatever",
"whatever",
),
(ParameterString(name="image_uri"), None, "base_name"),
(ParameterString(name="image_uri"), "whatever", "whatever"),
(
ParameterString(
name="image_uri",
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
),
None,
"analyzer",
),
(
ParameterString(
name="image_uri",
default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer",
),
"whatever",
"analyzer",
),
],
)
def test_base_name_from_image_with_pipeline_param(inputs):
image, default_base_name, expected = inputs
assert expected == sagemaker.utils.base_name_from_image(
image=image, default_base_name=default_base_name
)


@patch("sagemaker.utils.sagemaker_timestamp")
def test_name_from_base(sagemaker_timestamp):
sagemaker.utils.name_from_base(NAME, short=False)
Expand Down