Skip to content

Commit cdd8cb5

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
change: Make repack step output path align with model repack path (#3257)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 6a4fd6a commit cdd8cb5

File tree

10 files changed

+79
-27
lines changed

10 files changed

+79
-27
lines changed

src/sagemaker/image_uris.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
HUGGING_FACE_FRAMEWORK = "huggingface"
3333

3434

35-
# TODO: we should remove this decorator later
3635
@override_pipeline_parameter_var
3736
def retrieve(
3837
framework,
@@ -117,7 +116,11 @@ def retrieve(
117116
args = dict(locals())
118117
for name, val in args.items():
119118
if is_pipeline_variable(val):
120-
raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
119+
raise ValueError(
120+
"When retrieving the image_uri, the argument %s should not be a pipeline variable "
121+
"(%s) since pipeline variables are only interpreted in the pipeline execution time."
122+
% (name, type(val))
123+
)
121124

122125
if is_jumpstart_model_input(model_id, model_version):
123126
return artifacts._retrieve_image_uri(
@@ -487,6 +490,9 @@ def get_training_image_uri(
487490
if image_uri:
488491
return image_uri
489492

493+
logger.info(
494+
"image_uri is not presented, retrieving image_uri based on instance_type, framework etc."
495+
)
490496
base_framework_version: Optional[str] = None
491497

492498
if tensorflow_version is not None or pytorch_version is not None:

src/sagemaker/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,10 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
527527
artifact should be repackaged into a new S3 object. (default: False).
528528
"""
529529
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
530+
bucket = self.bucket or self.sagemaker_session.default_bucket()
530531
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
531532
self.uploaded_code = None
532533
elif not repack:
533-
bucket = self.bucket or self.sagemaker_session.default_bucket()
534534
self.uploaded_code = fw_utils.tar_and_upload_dir(
535535
session=self.sagemaker_session.boto_session,
536536
bucket=bucket,
@@ -557,6 +557,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
557557
)
558558
return
559559
self.sagemaker_session.context.need_runtime_repack.add(id(self))
560+
self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format(
561+
bucket, key_prefix
562+
)
560563
# Add the uploaded_code and repacked_model_data to update the container env
561564
self.repacked_model_data = self.model_data
562565
self.uploaded_code = fw_utils.UploadedCode(
@@ -567,7 +570,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
567570
if local_code and self.model_data.startswith("file://"):
568571
repacked_model_data = self.model_data
569572
else:
570-
bucket = self.bucket or self.sagemaker_session.default_bucket()
571573
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
572574
self.uploaded_code = fw_utils.UploadedCode(
573575
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)

src/sagemaker/rl/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ def training_image_uri(self):
282282
"""
283283
if self.image_uri:
284284
return self.image_uri
285+
286+
logger.info(
287+
"image_uri is not presented, retrieving image_uri based on instance_type, "
288+
"framework etc."
289+
)
285290
return image_uris.retrieve(
286291
self._image_framework(),
287292
self.sagemaker_session.boto_region_name,

src/sagemaker/tensorflow/model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from sagemaker.workflow import is_pipeline_variable
2525
from sagemaker.workflow.pipeline_context import PipelineSession
2626

27+
logger = logging.getLogger(__name__)
28+
2729

2830
class TensorFlowPredictor(Predictor):
2931
"""A ``Predictor`` implementation for inference against TensorFlow Serving endpoints."""
@@ -363,13 +365,10 @@ def prepare_container_def(
363365
instance_type, accelerator_type, serverless_inference_config=serverless_inference_config
364366
)
365367
env = self._get_container_env()
368+
key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image_uri)
369+
bucket = self.bucket or self.sagemaker_session.default_bucket()
366370

367371
if self.entry_point and not is_pipeline_variable(self.model_data):
368-
key_prefix = sagemaker.fw_utils.model_code_key_prefix(
369-
self.key_prefix, self.name, image_uri
370-
)
371-
372-
bucket = self.bucket or self.sagemaker_session.default_bucket()
373372
model_data = s3.s3_path_join("s3://", bucket, key_prefix, "model.tar.gz")
374373

375374
sagemaker.utils.repack_model(
@@ -385,6 +384,9 @@ def prepare_container_def(
385384
# model is not yet there, defer repacking to later during pipeline execution
386385
if isinstance(self.sagemaker_session, PipelineSession):
387386
self.sagemaker_session.context.need_runtime_repack.add(id(self))
387+
self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format(
388+
bucket, key_prefix
389+
)
388390
else:
389391
logging.warning(
390392
"The model_data is a Pipeline variable of type %s, "
@@ -426,6 +428,10 @@ def _get_image_uri(
426428
if self.image_uri:
427429
return self.image_uri
428430

431+
logger.info(
432+
"image_uri is not presented, retrieving image_uri based on instance_type, "
433+
"framework etc."
434+
)
429435
return image_uris.retrieve(
430436
self._framework_name,
431437
region_name or self.sagemaker_session.boto_region_name,

src/sagemaker/workflow/model_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from sagemaker.workflow.step_collections import StepCollection
2424
from sagemaker.workflow.steps import Step, CreateModelStep
2525

26-
NEED_RUNTIME_REPACK = "need_runtime_repack"
27-
2826
_CREATE_MODEL_RETRY_POLICIES = "create_model_retry_policies"
2927
_REGISTER_MODEL_RETRY_POLICIES = "register_model_retry_policies"
3028
_REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies"
@@ -155,6 +153,7 @@ def __init__(
155153
self._create_model_args = self.step_args.create_model_request
156154
self._register_model_args = self.step_args.create_model_package_request
157155
self._need_runtime_repack = self.step_args.need_runtime_repack
156+
self._runtime_repack_output_prefix = self.step_args.runtime_repack_output_prefix
158157
self._assign_and_validate_retry_policies(retry_policies)
159158

160159
if self._need_runtime_repack:
@@ -268,6 +267,7 @@ def _append_repack_model_step(self):
268267
),
269268
depends_on=self.depends_on,
270269
retry_policies=self._repack_model_retry_policies,
270+
output_path=self._runtime_repack_output_prefix,
271271
)
272272
self.steps.append(repack_model_step)
273273

src/sagemaker/workflow/pipeline_context.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, model):
6767
self.create_model_package_request = None
6868
self.create_model_request = None
6969
self.need_runtime_repack = set()
70+
self.runtime_repack_output_prefix = None
7071

7172

7273
class PipelineSession(Session):
@@ -139,14 +140,14 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None
139140
else:
140141
self.context = _JobStepArguments(func_name, request)
141142

142-
def init_step_arguments(self, model):
143+
def init_model_step_arguments(self, model):
143144
"""Create a `_ModelStepArguments` (if not exist) as pipeline context
144145
145146
Args:
146147
model (Model or PipelineModel): A `sagemaker.model.Model`
147148
or `sagemaker.pipeline.PipelineModel` instance
148149
"""
149-
if not self._context or not isinstance(self._context, _ModelStepArguments):
150+
if not isinstance(self._context, _ModelStepArguments):
150151
self._context = _ModelStepArguments(model)
151152

152153

@@ -197,7 +198,7 @@ def wrapper(*args, **kwargs):
197198
UserWarning,
198199
)
199200
if run_func.__name__ in ["register", "create"]:
200-
self_instance.sagemaker_session.init_step_arguments(self_instance)
201+
self_instance.sagemaker_session.init_model_step_arguments(self_instance)
201202
run_func(*args, **kwargs)
202203
context = self_instance.sagemaker_session.context
203204
self_instance.sagemaker_session.context = None

src/sagemaker/workflow/utilities.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
RequestType,
3030
)
3131

32+
logger = logging.getLogger(__name__)
33+
3234
if TYPE_CHECKING:
3335
from sagemaker.workflow.step_collections import StepCollection
3436

@@ -173,26 +175,26 @@ def override_pipeline_parameter_var(func):
173175
We should remove this decorator after the grace period.
174176
"""
175177
warning_msg_template = (
176-
"%s should not be a pipeline variable (%s). "
177-
"The default_value of this Parameter object will be used to override it. "
178-
"Please remove this pipeline variable and use python primitives instead."
178+
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
179+
"The default_value of this Parameter object will be used to override it."
179180
)
180181

181182
@wraps(func)
182183
def wrapper(*args, **kwargs):
184+
func_name = "{}.{}".format(func.__module__, func.__name__)
183185
params = inspect.signature(func).parameters
184186
args = list(args)
185187
for i, (arg_name, _) in enumerate(params.items()):
186188
if i >= len(args):
187189
break
188190
if isinstance(args[i], Parameter):
189-
logging.warning(warning_msg_template, arg_name, type(args[i]))
191+
logger.warning(warning_msg_template, arg_name, func_name, type(args[i]))
190192
args[i] = args[i].default_value
191193
args = tuple(args)
192194

193195
for arg_name, value in kwargs.items():
194196
if isinstance(value, Parameter):
195-
logging.warning(warning_msg_template, arg_name, type(value))
197+
logger.warning(warning_msg_template, arg_name, func_name, type(value))
196198
kwargs[arg_name] = value.default_value
197199
return func(*args, **kwargs)
198200

tests/integ/sagemaker/workflow/test_model_steps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,7 @@ def test_tensorflow_model_register_and_deploy_with_runtime_repack(
836836
sagemaker_session=pipeline_session,
837837
entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"),
838838
dependencies=[os.path.join(_TENSORFLOW_PATH, "dependency.py")],
839+
code_location=f"s3://{pipeline_session.default_bucket()}/model-code",
839840
)
840841
step_args = tf_model.register(
841842
content_types=["application/json"],

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def test_retrieve_with_pipeline_variable():
754754
kwargs["instance_type"] = Join(on="", values=["a", "b"])
755755
with pytest.raises(Exception) as error:
756756
image_uris.retrieve(**kwargs)
757-
assert "instance_type should not be a pipeline variable" in str(error.value)
757+
assert "the argument instance_type should not be a pipeline variable" in str(error.value)
758758

759759
# instance_type (ParameterString) is given as args rather than kwargs
760760
# which should not break anything

tests/unit/sagemaker/workflow/test_model_step.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
_DIR_NAME = "/opt/ml/model/code"
6868
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
6969
_TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies")
70+
_REPACK_OUTPUT_KEY_PREFIX = "code-output"
71+
_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"
7072

7173

7274
@pytest.fixture
@@ -688,6 +690,7 @@ def test_conditional_model_create_and_regis(
688690
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
689691
role=_ROLE,
690692
enable_network_isolation=True,
693+
code_location=_MODEL_CODE_LOCATION,
691694
),
692695
2,
693696
),
@@ -711,6 +714,7 @@ def test_conditional_model_create_and_regis(
711714
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
712715
role=_ROLE,
713716
framework_version="1.5.0",
717+
code_location=_MODEL_CODE_LOCATION,
714718
),
715719
2,
716720
),
@@ -742,6 +746,7 @@ def test_conditional_model_create_and_regis(
742746
image_uri=_IMAGE_URI,
743747
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
744748
role=_ROLE,
749+
code_location=_MODEL_CODE_LOCATION,
745750
),
746751
2,
747752
),
@@ -758,21 +763,45 @@ def test_conditional_model_create_and_regis(
758763
],
759764
)
760765
def test_create_model_among_different_model_types(test_input, pipeline_session, model_data_param):
766+
def assert_test_result(steps: list):
767+
# If expected_step_num is 2, it means a runtime repack step is appended
768+
# If expected_step_num is 1, it means no runtime repack is needed
769+
assert len(steps) == expected_step_num
770+
if expected_step_num == 2:
771+
assert steps[0]["Type"] == "Training"
772+
if model.key_prefix == _REPACK_OUTPUT_KEY_PREFIX:
773+
assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
774+
f"{_MODEL_CODE_LOCATION}/{model.name}"
775+
)
776+
else:
777+
assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
778+
f"s3://{_BUCKET}/{model.name}"
779+
)
780+
761781
model, expected_step_num = test_input
762782
model.sagemaker_session = pipeline_session
763783
model.model_data = model_data_param
764-
step_args = model.create(
784+
create_model_step_args = model.create(
765785
instance_type="c4.4xlarge",
766786
)
767-
model_steps = ModelStep(
787+
create_model_steps = ModelStep(
768788
name="MyModelStep",
769-
step_args=step_args,
789+
step_args=create_model_step_args,
770790
)
771-
steps = model_steps.request_dicts()
791+
assert_test_result(create_model_steps.request_dicts())
772792

773-
# If expected_step_num is 2, it means a runtime repack step is appended
774-
# If expected_step_num is 1, it means no runtime repack is needed
775-
assert len(steps) == expected_step_num
793+
register_model_step_args = model.register(
794+
content_types=["text/csv"],
795+
response_types=["text/csv"],
796+
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
797+
transform_instances=["ml.m5.xlarge"],
798+
model_package_group_name="MyModelPackageGroup",
799+
)
800+
register_model_steps = ModelStep(
801+
name="MyModelStep",
802+
step_args=register_model_step_args,
803+
)
804+
assert_test_result(register_model_steps.request_dicts())
776805

777806

778807
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)