change: Make repack step output path align with model repack path #3257

Merged
merged 1 commit on Jul 27, 2022
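The tests added in this PR exercise the pattern sketched below. This is a minimal, hedged illustration (the bucket, role, artifact path, and entry point are placeholders, not taken from the PR): a framework model is created through a PipelineSession with code_location set, and the repack step that ModelStep appends at runtime now writes its output under that same S3 prefix instead of the session's default bucket root.

from sagemaker.tensorflow import TensorFlowModel
from sagemaker.workflow.model_step import ModelStep
from sagemaker.workflow.pipeline_context import PipelineSession

pipeline_session = PipelineSession()

tf_model = TensorFlowModel(
    model_data="s3://my-bucket/artifacts/model.tar.gz",  # placeholder artifact
    framework_version="2.8",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    sagemaker_session=pipeline_session,
    entry_point="inference.py",  # entry point triggers a runtime repack step
    # With this change, the appended repack step's OutputDataConfig points
    # under this prefix rather than s3://<default-bucket>/<model-name>.
    code_location=f"s3://{pipeline_session.default_bucket()}/model-code",
)

step_args = tf_model.create(instance_type="ml.m5.xlarge")
model_step = ModelStep(name="MyModelStep", step_args=step_args)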
10 changes: 8 additions & 2 deletions src/sagemaker/image_uris.py
@@ -32,7 +32,6 @@
HUGGING_FACE_FRAMEWORK = "huggingface"


# TODO: we should remove this decorator later
@override_pipeline_parameter_var
def retrieve(
framework,
@@ -117,7 +116,11 @@ def retrieve(
args = dict(locals())
for name, val in args.items():
if is_pipeline_variable(val):
raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
raise ValueError(
"When retrieving the image_uri, the argument %s should not be a pipeline variable "
"(%s) since pipeline variables are only interpreted in the pipeline execution time."
% (name, type(val))
)

if is_jumpstart_model_input(model_id, model_version):
return artifacts._retrieve_image_uri(
@@ -487,6 +490,9 @@ def get_training_image_uri(
if image_uri:
return image_uri

logger.info(
"image_uri is not presented, retrieving image_uri based on instance_type, framework etc."
)
base_framework_version: Optional[str] = None

if tensorflow_version is not None or pytorch_version is not None:
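For illustration only (this mirrors the updated unit test in tests/unit/sagemaker/image_uris/test_retrieve.py further down; the region and version strings are placeholders): passing a pipeline variable such as a Join to image_uris.retrieve() fails fast with the more descriptive message added above.

from sagemaker import image_uris
from sagemaker.workflow.functions import Join

try:
    image_uris.retrieve(
        framework="tensorflow",
        region="us-west-2",
        version="2.8",
        py_version="py39",
        image_scope="inference",
        instance_type=Join(on="", values=["ml.", "m5.xlarge"]),  # pipeline variable
    )
except ValueError as err:
    # e.g. "When retrieving the image_uri, the argument instance_type should not
    # be a pipeline variable (<class 'sagemaker.workflow.functions.Join'>) ..."
    print(err)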
6 changes: 4 additions & 2 deletions src/sagemaker/model.py
@@ -527,10 +527,10 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
artifact should be repackaged into a new S3 object. (default: False).
"""
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
bucket = self.bucket or self.sagemaker_session.default_bucket()
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
self.uploaded_code = None
elif not repack:
bucket = self.bucket or self.sagemaker_session.default_bucket()
self.uploaded_code = fw_utils.tar_and_upload_dir(
session=self.sagemaker_session.boto_session,
bucket=bucket,
@@ -557,6 +557,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
)
return
self.sagemaker_session.context.need_runtime_repack.add(id(self))
self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format(
bucket, key_prefix
)
# Add the uploaded_code and repacked_model_data to update the container env
self.repacked_model_data = self.model_data
self.uploaded_code = fw_utils.UploadedCode(
@@ -567,7 +570,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
if local_code and self.model_data.startswith("file://"):
repacked_model_data = self.model_data
else:
bucket = self.bucket or self.sagemaker_session.default_bucket()
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
self.uploaded_code = fw_utils.UploadedCode(
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
5 changes: 5 additions & 0 deletions src/sagemaker/rl/estimator.py
@@ -280,6 +280,11 @@ def training_image_uri(self):
"""
if self.image_uri:
return self.image_uri

logger.info(
"image_uri is not presented, retrieving image_uri based on instance_type, "
"framework etc."
)
return image_uris.retrieve(
self._image_framework(),
self.sagemaker_session.boto_region_name,
16 changes: 11 additions & 5 deletions src/sagemaker/tensorflow/model.py
@@ -24,6 +24,8 @@
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.pipeline_context import PipelineSession

logger = logging.getLogger(__name__)


class TensorFlowPredictor(Predictor):
"""A ``Predictor`` implementation for inference against TensorFlow Serving endpoints."""
@@ -363,13 +365,10 @@ def prepare_container_def(
instance_type, accelerator_type, serverless_inference_config=serverless_inference_config
)
env = self._get_container_env()
key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image_uri)
bucket = self.bucket or self.sagemaker_session.default_bucket()

if self.entry_point and not is_pipeline_variable(self.model_data):
key_prefix = sagemaker.fw_utils.model_code_key_prefix(
self.key_prefix, self.name, image_uri
)

bucket = self.bucket or self.sagemaker_session.default_bucket()
model_data = s3.s3_path_join("s3://", bucket, key_prefix, "model.tar.gz")

sagemaker.utils.repack_model(
@@ -385,6 +384,9 @@
# model is not yet there, defer repacking to later during pipeline execution
if isinstance(self.sagemaker_session, PipelineSession):
self.sagemaker_session.context.need_runtime_repack.add(id(self))
self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format(
bucket, key_prefix
)
else:
logging.warning(
"The model_data is a Pipeline variable of type %s, "
@@ -426,6 +428,10 @@ def _get_image_uri(
if self.image_uri:
return self.image_uri

logger.info(
"image_uri is not presented, retrieving image_uri based on instance_type, "
"framework etc."
)
return image_uris.retrieve(
self._framework_name,
region_name or self.sagemaker_session.boto_region_name,
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/model_step.py
@@ -23,8 +23,6 @@
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, CreateModelStep

NEED_RUNTIME_REPACK = "need_runtime_repack"

_CREATE_MODEL_RETRY_POLICIES = "create_model_retry_policies"
_REGISTER_MODEL_RETRY_POLICIES = "register_model_retry_policies"
_REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies"
@@ -155,6 +153,7 @@ def __init__(
self._create_model_args = self.step_args.create_model_request
self._register_model_args = self.step_args.create_model_package_request
self._need_runtime_repack = self.step_args.need_runtime_repack
self._runtime_repack_output_prefix = self.step_args.runtime_repack_output_prefix
self._assign_and_validate_retry_policies(retry_policies)

if self._need_runtime_repack:
@@ -268,6 +267,7 @@ def _append_repack_model_step(self):
),
depends_on=self.depends_on,
retry_policies=self._repack_model_retry_policies,
output_path=self._runtime_repack_output_prefix,
)
self.steps.append(repack_model_step)

7 changes: 4 additions & 3 deletions src/sagemaker/workflow/pipeline_context.py
@@ -67,6 +67,7 @@ def __init__(self, model):
self.create_model_package_request = None
self.create_model_request = None
self.need_runtime_repack = set()
self.runtime_repack_output_prefix = None


class PipelineSession(Session):
@@ -139,14 +140,14 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None
else:
self.context = _JobStepArguments(func_name, request)

def init_step_arguments(self, model):
def init_model_step_arguments(self, model):
"""Create a `_ModelStepArguments` (if not exist) as pipeline context

Args:
model (Model or PipelineModel): A `sagemaker.model.Model`
or `sagemaker.pipeline.PipelineModel` instance
"""
if not self._context or not isinstance(self._context, _ModelStepArguments):
if not isinstance(self._context, _ModelStepArguments):
self._context = _ModelStepArguments(model)


@@ -197,7 +198,7 @@ def wrapper(*args, **kwargs):
UserWarning,
)
if run_func.__name__ in ["register", "create"]:
self_instance.sagemaker_session.init_step_arguments(self_instance)
self_instance.sagemaker_session.init_model_step_arguments(self_instance)
run_func(*args, **kwargs)
context = self_instance.sagemaker_session.context
self_instance.sagemaker_session.context = None
12 changes: 7 additions & 5 deletions src/sagemaker/workflow/utilities.py
@@ -29,6 +29,8 @@
RequestType,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection

@@ -173,26 +175,26 @@ def override_pipeline_parameter_var(func):
We should remove this decorator after the grace period.
"""
warning_msg_template = (
"%s should not be a pipeline variable (%s). "
"The default_value of this Parameter object will be used to override it. "
"Please remove this pipeline variable and use python primitives instead."
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
"The default_value of this Parameter object will be used to override it."
)

@wraps(func)
def wrapper(*args, **kwargs):
func_name = "{}.{}".format(func.__module__, func.__name__)
params = inspect.signature(func).parameters
args = list(args)
for i, (arg_name, _) in enumerate(params.items()):
if i >= len(args):
break
if isinstance(args[i], Parameter):
logging.warning(warning_msg_template, arg_name, type(args[i]))
logger.warning(warning_msg_template, arg_name, func_name, type(args[i]))
args[i] = args[i].default_value
args = tuple(args)

for arg_name, value in kwargs.items():
if isinstance(value, Parameter):
logging.warning(warning_msg_template, arg_name, type(value))
logger.warning(warning_msg_template, arg_name, func_name, type(value))
kwargs[arg_name] = value.default_value
return func(*args, **kwargs)

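A hypothetical snippet (not part of the PR) showing the behavior the reworded warning describes: when a function decorated with override_pipeline_parameter_var receives a pipeline Parameter, the decorator logs the argument name, the fully qualified function name, and the variable type, then substitutes the Parameter's default_value.

import logging

from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.utilities import override_pipeline_parameter_var

logging.basicConfig(level=logging.WARNING)


@override_pipeline_parameter_var
def describe(instance_type=None):
    # Stand-in for a decorated SDK function such as image_uris.retrieve.
    return instance_type


param = ParameterString(name="TrainingInstanceType", default_value="ml.m5.xlarge")
# Logs the warning above and falls back to the Parameter's default_value.
print(describe(instance_type=param))  # -> "ml.m5.xlarge"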
1 change: 1 addition & 0 deletions tests/integ/sagemaker/workflow/test_model_steps.py
@@ -836,6 +836,7 @@ def test_tensorflow_model_register_and_deploy_with_runtime_repack(
sagemaker_session=pipeline_session,
entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"),
dependencies=[os.path.join(_TENSORFLOW_PATH, "dependency.py")],
code_location=f"s3://{pipeline_session.default_bucket()}/model-code",
)
step_args = tf_model.register(
content_types=["application/json"],
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/image_uris/test_retrieve.py
@@ -754,7 +754,7 @@ def test_retrieve_with_pipeline_variable():
kwargs["instance_type"] = Join(on="", values=["a", "b"])
with pytest.raises(Exception) as error:
image_uris.retrieve(**kwargs)
assert "instance_type should not be a pipeline variable" in str(error.value)
assert "the argument instance_type should not be a pipeline variable" in str(error.value)

# instance_type (ParameterString) is given as args rather than kwargs
# which should not break anything
43 changes: 36 additions & 7 deletions tests/unit/sagemaker/workflow/test_model_step.py
@@ -67,6 +67,8 @@
_DIR_NAME = "/opt/ml/model/code"
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
_TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies")
_REPACK_OUTPUT_KEY_PREFIX = "code-output"
_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"


@pytest.fixture
@@ -688,6 +690,7 @@ def test_conditional_model_create_and_regis(
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
role=_ROLE,
enable_network_isolation=True,
code_location=_MODEL_CODE_LOCATION,
),
2,
),
@@ -711,6 +714,7 @@ def test_conditional_model_create_and_regis(
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
role=_ROLE,
framework_version="1.5.0",
code_location=_MODEL_CODE_LOCATION,
),
2,
),
@@ -742,6 +746,7 @@ def test_conditional_model_create_and_regis(
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
role=_ROLE,
code_location=_MODEL_CODE_LOCATION,
),
2,
),
@@ -758,21 +763,45 @@ def test_conditional_model_create_and_regis(
],
)
def test_create_model_among_different_model_types(test_input, pipeline_session, model_data_param):
def assert_test_result(steps: list):
# If expected_step_num is 2, it means a runtime repack step is appended
# If expected_step_num is 1, it means no runtime repack is needed
assert len(steps) == expected_step_num
if expected_step_num == 2:
assert steps[0]["Type"] == "Training"
if model.key_prefix == _REPACK_OUTPUT_KEY_PREFIX:
assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
f"{_MODEL_CODE_LOCATION}/{model.name}"
)
else:
assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
f"s3://{_BUCKET}/{model.name}"
)

model, expected_step_num = test_input
model.sagemaker_session = pipeline_session
model.model_data = model_data_param
step_args = model.create(
create_model_step_args = model.create(
instance_type="c4.4xlarge",
)
model_steps = ModelStep(
create_model_steps = ModelStep(
name="MyModelStep",
step_args=step_args,
step_args=create_model_step_args,
)
steps = model_steps.request_dicts()
assert_test_result(create_model_steps.request_dicts())

# If expected_step_num is 2, it means a runtime repack step is appended
# If expected_step_num is 1, it means no runtime repack is needed
assert len(steps) == expected_step_num
register_model_step_args = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
transform_instances=["ml.m5.xlarge"],
model_package_group_name="MyModelPackageGroup",
)
register_model_steps = ModelStep(
name="MyModelStep",
step_args=register_model_step_args,
)
assert_test_result(register_model_steps.request_dicts())


@pytest.mark.parametrize(