diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index ec1fec2d20..01ed5f1d99 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -32,7 +32,6 @@ HUGGING_FACE_FRAMEWORK = "huggingface" -# TODO: we should remove this decorator later @override_pipeline_parameter_var def retrieve( framework, @@ -117,7 +116,11 @@ def retrieve( args = dict(locals()) for name, val in args.items(): if is_pipeline_variable(val): - raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val))) + raise ValueError( + "When retrieving the image_uri, the argument %s should not be a pipeline variable " + "(%s) since pipeline variables are only interpreted in the pipeline execution time." + % (name, type(val)) + ) if is_jumpstart_model_input(model_id, model_version): return artifacts._retrieve_image_uri( @@ -487,6 +490,9 @@ def get_training_image_uri( if image_uri: return image_uri + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, framework etc." + ) base_framework_version: Optional[str] = None if tensorflow_version is not None or pytorch_version is not None: diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 8f128fe3f4..aaed24ac05 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -527,10 +527,10 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: artifact should be repackaged into a new S3 object. (default: False). """ local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) + bucket = self.bucket or self.sagemaker_session.default_bucket() if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: self.uploaded_code = None elif not repack: - bucket = self.bucket or self.sagemaker_session.default_bucket() self.uploaded_code = fw_utils.tar_and_upload_dir( session=self.sagemaker_session.boto_session, bucket=bucket, @@ -557,6 +557,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: ) return self.sagemaker_session.context.need_runtime_repack.add(id(self)) + self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format( + bucket, key_prefix + ) # Add the uploaded_code and repacked_model_data to update the container env self.repacked_model_data = self.model_data self.uploaded_code = fw_utils.UploadedCode( @@ -567,7 +570,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: if local_code and self.model_data.startswith("file://"): repacked_model_data = self.model_data else: - bucket = self.bucket or self.sagemaker_session.default_bucket() repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) self.uploaded_code = fw_utils.UploadedCode( s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 60307a7868..1957903587 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -280,6 +280,11 @@ def training_image_uri(self): """ if self.image_uri: return self.image_uri + + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, " + "framework etc." + ) return image_uris.retrieve( self._image_framework(), self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index e5e6798a63..4841563cf7 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -24,6 +24,8 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import PipelineSession +logger = logging.getLogger(__name__) + class TensorFlowPredictor(Predictor): """A ``Predictor`` implementation for inference against TensorFlow Serving endpoints.""" @@ -363,13 +365,10 @@ def prepare_container_def( instance_type, accelerator_type, serverless_inference_config=serverless_inference_config ) env = self._get_container_env() + key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image_uri) + bucket = self.bucket or self.sagemaker_session.default_bucket() if self.entry_point and not is_pipeline_variable(self.model_data): - key_prefix = sagemaker.fw_utils.model_code_key_prefix( - self.key_prefix, self.name, image_uri - ) - - bucket = self.bucket or self.sagemaker_session.default_bucket() model_data = s3.s3_path_join("s3://", bucket, key_prefix, "model.tar.gz") sagemaker.utils.repack_model( @@ -385,6 +384,9 @@ def prepare_container_def( # model is not yet there, defer repacking to later during pipeline execution if isinstance(self.sagemaker_session, PipelineSession): self.sagemaker_session.context.need_runtime_repack.add(id(self)) + self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format( + bucket, key_prefix + ) else: logging.warning( "The model_data is a Pipeline variable of type %s, " @@ -426,6 +428,10 @@ def _get_image_uri( if self.image_uri: return self.image_uri + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, " + "framework etc." + ) return image_uris.retrieve( self._framework_name, region_name or self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/workflow/model_step.py b/src/sagemaker/workflow/model_step.py index e46fd71a84..6c261d1bdc 100644 --- a/src/sagemaker/workflow/model_step.py +++ b/src/sagemaker/workflow/model_step.py @@ -23,8 +23,6 @@ from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, CreateModelStep -NEED_RUNTIME_REPACK = "need_runtime_repack" - _CREATE_MODEL_RETRY_POLICIES = "create_model_retry_policies" _REGISTER_MODEL_RETRY_POLICIES = "register_model_retry_policies" _REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies" @@ -155,6 +153,7 @@ def __init__( self._create_model_args = self.step_args.create_model_request self._register_model_args = self.step_args.create_model_package_request self._need_runtime_repack = self.step_args.need_runtime_repack + self._runtime_repack_output_prefix = self.step_args.runtime_repack_output_prefix self._assign_and_validate_retry_policies(retry_policies) if self._need_runtime_repack: @@ -268,6 +267,7 @@ def _append_repack_model_step(self): ), depends_on=self.depends_on, retry_policies=self._repack_model_retry_policies, + output_path=self._runtime_repack_output_prefix, ) self.steps.append(repack_model_step) diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 95fbd9371c..341e123be0 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -67,6 +67,7 @@ def __init__(self, model): self.create_model_package_request = None self.create_model_request = None self.need_runtime_repack = set() + self.runtime_repack_output_prefix = None class PipelineSession(Session): @@ -139,14 +140,14 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None else: self.context = _JobStepArguments(func_name, request) - def init_step_arguments(self, model): + def init_model_step_arguments(self, model): """Create a `_ModelStepArguments` (if not exist) as pipeline context Args: model (Model or PipelineModel): A `sagemaker.model.Model` or `sagemaker.pipeline.PipelineModel` instance """ - if not self._context or not isinstance(self._context, _ModelStepArguments): + if not isinstance(self._context, _ModelStepArguments): self._context = _ModelStepArguments(model) @@ -197,7 +198,7 @@ def wrapper(*args, **kwargs): UserWarning, ) if run_func.__name__ in ["register", "create"]: - self_instance.sagemaker_session.init_step_arguments(self_instance) + self_instance.sagemaker_session.init_model_step_arguments(self_instance) run_func(*args, **kwargs) context = self_instance.sagemaker_session.context self_instance.sagemaker_session.context = None diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index afe1e4eae1..a30ddd4dee 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -29,6 +29,8 @@ RequestType, ) +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from sagemaker.workflow.step_collections import StepCollection @@ -173,26 +175,26 @@ def override_pipeline_parameter_var(func): We should remove this decorator after the grace period. """ warning_msg_template = ( - "%s should not be a pipeline variable (%s). " - "The default_value of this Parameter object will be used to override it. " - "Please remove this pipeline variable and use python primitives instead." + "The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. " + "The default_value of this Parameter object will be used to override it." ) @wraps(func) def wrapper(*args, **kwargs): + func_name = "{}.{}".format(func.__module__, func.__name__) params = inspect.signature(func).parameters args = list(args) for i, (arg_name, _) in enumerate(params.items()): if i >= len(args): break if isinstance(args[i], Parameter): - logging.warning(warning_msg_template, arg_name, type(args[i])) + logger.warning(warning_msg_template, arg_name, func_name, type(args[i])) args[i] = args[i].default_value args = tuple(args) for arg_name, value in kwargs.items(): if isinstance(value, Parameter): - logging.warning(warning_msg_template, arg_name, type(value)) + logger.warning(warning_msg_template, arg_name, func_name, type(value)) kwargs[arg_name] = value.default_value return func(*args, **kwargs) diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index d5e21be1bf..31c518b100 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -836,6 +836,7 @@ def test_tensorflow_model_register_and_deploy_with_runtime_repack( sagemaker_session=pipeline_session, entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"), dependencies=[os.path.join(_TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{pipeline_session.default_bucket()}/model-code", ) step_args = tf_model.register( content_types=["application/json"], diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index c167da6f47..ae37395b92 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -754,7 +754,7 @@ def test_retrieve_with_pipeline_variable(): kwargs["instance_type"] = Join(on="", values=["a", "b"]) with pytest.raises(Exception) as error: image_uris.retrieve(**kwargs) - assert "instance_type should not be a pipeline variable" in str(error.value) + assert "the argument instance_type should not be a pipeline variable" in str(error.value) # instance_type (ParameterString) is given as args rather than kwargs # which should not break anything diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 68961b355c..cfeb8d5a03 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -67,6 +67,8 @@ _DIR_NAME = "/opt/ml/model/code" _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") +_REPACK_OUTPUT_KEY_PREFIX = "code-output" +_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" @pytest.fixture @@ -688,6 +690,7 @@ def test_conditional_model_create_and_regis( entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, enable_network_isolation=True, + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -711,6 +714,7 @@ def test_conditional_model_create_and_regis( entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, framework_version="1.5.0", + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -742,6 +746,7 @@ def test_conditional_model_create_and_regis( image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -758,21 +763,45 @@ def test_conditional_model_create_and_regis( ], ) def test_create_model_among_different_model_types(test_input, pipeline_session, model_data_param): + def assert_test_result(steps: list): + # If expected_step_num is 2, it means a runtime repack step is appended + # If expected_step_num is 1, it means no runtime repack is needed + assert len(steps) == expected_step_num + if expected_step_num == 2: + assert steps[0]["Type"] == "Training" + if model.key_prefix == _REPACK_OUTPUT_KEY_PREFIX: + assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( + f"{_MODEL_CODE_LOCATION}/{model.name}" + ) + else: + assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( + f"s3://{_BUCKET}/{model.name}" + ) + model, expected_step_num = test_input model.sagemaker_session = pipeline_session model.model_data = model_data_param - step_args = model.create( + create_model_step_args = model.create( instance_type="c4.4xlarge", ) - model_steps = ModelStep( + create_model_steps = ModelStep( name="MyModelStep", - step_args=step_args, + step_args=create_model_step_args, ) - steps = model_steps.request_dicts() + assert_test_result(create_model_steps.request_dicts()) - # If expected_step_num is 2, it means a runtime repack step is appended - # If expected_step_num is 1, it means no runtime repack is needed - assert len(steps) == expected_step_num + register_model_step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="MyModelPackageGroup", + ) + register_model_steps = ModelStep( + name="MyModelStep", + step_args=register_model_step_args, + ) + assert_test_result(register_model_steps.request_dicts()) @pytest.mark.parametrize(