diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 6ce7e41831..f98f170f39 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -34,7 +34,7 @@ from distutils.dir_util import copy_tree -def repack(inference_script, model_archive, dependencies=None, source_dir=None): +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive Args: @@ -95,7 +95,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): copy_tree(src_dir, "/opt/ml/model") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover parser = argparse.ArgumentParser() parser.add_argument("--inference_script", type=str, default="inference.py") parser.add_argument("--dependencies", type=str, default=None) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index d341af211d..fbbb6acba9 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -80,7 +80,7 @@ def __init__( artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. model_data (str): The S3 location of a SageMaker model data - ``.tar.gz`` file (default: None). + ``.tar.gz`` file. entry_point (str): Path (absolute or relative) to the local Python source file which should be executed as the entry point to inference. If ``source_dir`` is specified, then ``entry_point`` diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index a34330d94d..a2597c07f9 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -95,7 +95,7 @@ def properties(self): @attr.s -class JsonGet(Expression): +class JsonGet(Expression): # pragma: no cover """Get JSON properties from PropertyFiles. Attributes: diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 03ac099d18..e0076322de 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -75,8 +75,8 @@ class JsonGet(Expression): @property def expr(self): """The expression dict for a `JsonGet` function.""" - if not isinstance(self.step_name, str): - raise ValueError("Please give step name as a string") + if not isinstance(self.step_name, str) or not self.step_name: + raise ValueError("Please give a valid step name as a string") if isinstance(self.property_file, PropertyFile): name = self.property_file.name diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 0446a0b46c..5240ae60b9 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -161,8 +161,8 @@ def _get_function_arn(self): partition = "aws" if self.lambda_func.function_arn is None: + account_id = self.lambda_func.session.account_id() try: - account_id = self.lambda_func.session.account_id() response = self.lambda_func.create() return response["FunctionArn"] except ValueError as error: diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 27060d928e..1280637006 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -323,15 +323,15 @@ def __init__( """ steps = [] if "entry_point" in kwargs: - entry_point = kwargs["entry_point"] - source_dir = kwargs.get("source_dir") - dependencies = kwargs.get("dependencies") + entry_point = kwargs.get("entry_point", None) + source_dir = kwargs.get("source_dir", None) + dependencies = kwargs.get("dependencies", None) repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, - role=estimator.sagemaker_session, + role=estimator.role, model_data=model_data, entry_point=entry_point, source_dir=source_dir, @@ -357,7 +357,11 @@ def predict_wrapper(endpoint, session): vpc_config=None, sagemaker_session=estimator.sagemaker_session, role=estimator.role, - **kwargs, + env=kwargs.get("env", None), + name=kwargs.get("name", None), + enable_network_isolation=kwargs.get("enable_network_isolation", None), + model_kms_key=kwargs.get("model_kms_key", None), + image_config=kwargs.get("image_config", None), ) model_step = CreateModelStep( name=f"{name}CreateModelStep", diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 14c2cf54b3..dd24149ca4 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -67,7 +67,6 @@ ConditionLessThanOrEqualTo, ) from sagemaker.workflow.condition_step import ConditionStep -from sagemaker.workflow.condition_step import JsonGet as ConditionStepJsonGet from sagemaker.workflow.callback_step import ( CallbackStep, CallbackOutput, @@ -2835,8 +2834,8 @@ def test_end_to_end_pipeline_successful_execution( # define condition step cond_lte = ConditionLessThanOrEqualTo( - left=ConditionStepJsonGet( - step=step_eval, + left=JsonGet( + step_name=step_eval.name, property_file=evaluation_report, json_path="regression_metrics.mse.value", ), diff --git a/tests/integ/test_workflow_with_clarify.py b/tests/integ/test_workflow_with_clarify.py index 0c41b2212a..486abab89b 100644 --- a/tests/integ/test_workflow_with_clarify.py +++ b/tests/integ/test_workflow_with_clarify.py @@ -33,7 +33,8 @@ from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.session import get_execution_role from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo -from sagemaker.workflow.condition_step import ConditionStep, JsonGet +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.functions import JsonGet from sagemaker.workflow.parameters import ( ParameterInteger, ParameterString, @@ -237,7 +238,7 @@ def test_workflow_with_clarify( ) cond_left = JsonGet( - step=step_process, + step_name=step_process.name, property_file="BiasOutput", json_path="post_training_bias_metrics.facets.F1[0].metrics[0].value", ) diff --git a/tests/unit/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py similarity index 100% rename from tests/unit/test_airflow.py rename to tests/unit/sagemaker/workflow/test_airflow.py diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 8e5d6b6d31..9b07a41d09 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.parameters import ( @@ -97,3 +99,23 @@ def test_json_get_expressions(): "Path": "my-json-path", }, } + + +def test_json_get_expressions_with_invalid_step_name(): + with pytest.raises(ValueError) as err: + JsonGet( + step_name="", + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) + + with pytest.raises(ValueError) as err: + JsonGet( + step_name=ParameterString(name="MyString"), + property_file="my-property-file", + json_path="my-json-path", + ).expr + + assert "Please give a valid step name as a string" in str(err.value) diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index 0566e39318..bdaa781b1c 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -22,6 +22,7 @@ from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from sagemaker.lambda_helper import Lambda +from sagemaker.workflow.steps import CacheConfig @pytest.fixture() @@ -38,10 +39,25 @@ def sagemaker_session(): return session_mock +@pytest.fixture() +def sagemaker_session_cn(): + boto_mock = Mock(name="boto_session", region_name="cn-north-1") + session_mock = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="cn-north-1", + config=None, + local_mode=False, + ) + session_mock.account_id.return_value = "234567890123" + return session_mock + + def test_lambda_step(sagemaker_session): param = ParameterInteger(name="MyInt") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.Boolean) + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") lambda_step = LambdaStep( name="MyLambdaStep", depends_on=["TestStep"], @@ -52,10 +68,17 @@ def test_lambda_step(sagemaker_session): display_name="MyLambdaStep", description="MyLambdaStepDescription", inputs={"arg1": "foo", "arg2": 5, "arg3": param}, - outputs=[outputParam1, outputParam2], + outputs=[output_param1, output_param2], + cache_config=cache_config, ) lambda_step.add_depends_on(["SecondTestStep"]) - assert lambda_step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[param], + steps=[lambda_step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyLambdaStep", "Type": "Lambda", "DependsOn": ["TestStep", "SecondTestStep"], @@ -66,7 +89,8 @@ def test_lambda_step(sagemaker_session): {"OutputName": "output1", "OutputType": "String"}, {"OutputName": "output2", "OutputType": "Boolean"}, ], - "Arguments": {"arg1": "foo", "arg2": 5, "arg3": param}, + "Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}}, + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } @@ -95,8 +119,8 @@ def test_lambda_step_output_expr(sagemaker_session): def test_pipeline_interpolates_lambda_outputs(sagemaker_session): parameter = ParameterString("MyStr") - outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) - outputParam2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) + output_param1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) + output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) lambda_step1 = LambdaStep( name="MyLambdaStep1", depends_on=["TestStep"], @@ -105,7 +129,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): session=sagemaker_session, ), inputs={"arg1": "foo"}, - outputs=[outputParam1], + outputs=[output_param1], ) lambda_step2 = LambdaStep( name="MyLambdaStep2", @@ -114,8 +138,8 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, ), - inputs={"arg1": outputParam1}, - outputs=[outputParam2], + inputs={"arg1": output_param1}, + outputs=[output_param2], ) pipeline = Pipeline( @@ -207,3 +231,37 @@ def test_lambda_step_without_function_arn(sagemaker_session): ) lambda_step._get_function_arn() sagemaker_session.account_id.assert_called_once() + + +def test_lambda_step_without_function_arn_and_with_error(sagemaker_session_cn): + lambda_func = MagicMock( + function_arn=None, + function_name="name", + execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role", + zipped_code_dir="", + handler="", + session=sagemaker_session_cn, + ) + # The raised ValueError contains ResourceConflictException + lambda_func.create.side_effect = ValueError("ResourceConflictException") + lambda_step1 = LambdaStep( + name="MyLambdaStep1", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + function_arn = lambda_step1._get_function_arn() + assert function_arn == "arn:aws-cn:lambda:cn-north-1:234567890123:function:name" + + # The raised ValueError does not contain ResourceConflictException + lambda_func.create.side_effect = ValueError() + lambda_step2 = LambdaStep( + name="MyLambdaStep2", + depends_on=["TestStep"], + lambda_func=lambda_func, + inputs={}, + outputs=[], + ) + with pytest.raises(ValueError): + lambda_step2._get_function_arn() diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 6c78412b22..d2f1f07059 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -19,6 +19,7 @@ import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR import sagemaker @@ -206,6 +207,16 @@ def test_step_collection(): ] +def test_step_collection_with_list_to_request(): + step_collection = StepCollection(steps=[CustomStep("MyStep1"), CustomStep("MyStep2")]) + custom_step = CustomStep("MyStep3") + assert list_to_request([step_collection, custom_step]) == [ + {"Name": "MyStep1", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep2", "Type": "Training", "Arguments": dict()}, + {"Name": "MyStep3", "Type": "Training", "Arguments": dict()}, + ] + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel( @@ -216,6 +227,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): response_types=["response_type"], inference_instances=["inference_instance"], transform_instances=["transform_instance"], + image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", model_package_group_name="mpg", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, @@ -236,7 +248,10 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "Arguments": { "InferenceSpecification": { "Containers": [ - {"Image": "fakeimage", "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz"} + { + "Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri", + "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", + } ], "SupportedContentTypes": ["content_type"], "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"], @@ -865,3 +880,117 @@ def test_estimator_transformer(estimator): } else: raise Exception("A step exists in the collection of an invalid type.") + + +def test_estimator_transformer_with_model_repack_with_estimator(estimator): + model_data = f"s3://{BUCKET}/model.tar.gz" + model_inputs = CreateModelInput( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) + transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") + estimator_transformer = EstimatorTransformer( + name="EstimatorTransformerStep", + estimator=estimator, + model_data=model_data, + model_inputs=model_inputs, + instance_count=1, + instance_type="ml.c4.4xlarge", + transform_inputs=transform_inputs, + depends_on=["TestStep"], + model_step_retry_policies=[service_fault_retry_policy], + transform_step_retry_policies=[service_fault_retry_policy], + repack_model_step_retry_policies=[service_fault_retry_policy], + entry_point=f"{DATA_DIR}/dummy_script.py", + ) + request_dicts = estimator_transformer.request_dicts() + assert len(request_dicts) == 3 + + for request_dict in request_dicts: + if request_dict["Type"] == "Training": + assert request_dict["Name"] == "EstimatorTransformerStepRepackModel" + assert request_dict["DependsOn"] == ["TestStep"] + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + # pop out the dynamic generated fields + arguments["HyperParameters"].pop("sagemaker_submit_directory") + arguments["HyperParameters"].pop("sagemaker_job_name") + assert arguments == { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" + + "sagemaker-scikit-learn:0.23-1-cpu-py3", + }, + "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 30, + }, + "RoleArn": "DummyRole", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://my-bucket", + "S3DataDistributionType": "FullyReplicated", + } + }, + "ChannelName": "training", + } + ], + "HyperParameters": { + "inference_script": '"dummy_script.py"', + "model_archive": '"model.tar.gz"', + "dependencies": "null", + "source_dir": "null", + "sagemaker_program": '"_repack_model.py"', + "sagemaker_container_log_level": "20", + "sagemaker_region": '"us-west-2"', + }, + "VpcConfig": {"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]}, + "DebugHookConfig": { + "S3OutputPath": "s3://my-bucket/", + "CollectionConfigurations": [], + }, + } + elif request_dict["Type"] == "Model": + assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties) + arguments["PrimaryContainer"].pop("ModelDataUrl") + assert "DependsOn" not in request_dict + assert arguments == { + "ExecutionRoleArn": "DummyRole", + "PrimaryContainer": { + "Environment": {}, + "Image": "fakeimage", + }, + } + elif request_dict["Type"] == "Transform": + assert request_dict["Name"] == "EstimatorTransformerStepTransformStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["ModelName"], Properties) + arguments.pop("ModelName") + assert "DependsOn" not in request_dict + assert arguments == { + "TransformInput": { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{BUCKET}/transform_manifest", + } + } + }, + "TransformOutput": {"S3OutputPath": None}, + "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"}, + } + else: + raise Exception("A step exists in the collection of an invalid type.") diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index e3dc10e23e..674c715617 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -13,6 +13,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json + import pytest import sagemaker import os @@ -43,7 +45,8 @@ ) from sagemaker.network import NetworkConfig from sagemaker.transformer import Transformer -from sagemaker.workflow.properties import Properties +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.properties import Properties, PropertyFile from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.retry import ( StepRetryPolicy, @@ -535,6 +538,9 @@ def test_processing_step(sagemaker_session): ) ] cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + evaluation_report = PropertyFile( + name="EvaluationReport", output_name="evaluation", path="evaluation.json" + ) step = ProcessingStep( name="MyProcessingStep", description="ProcessingStep description", @@ -544,9 +550,20 @@ def test_processing_step(sagemaker_session): inputs=inputs, outputs=[], cache_config=cache_config, + property_files=[evaluation_report], ) step.add_depends_on(["ThirdTestStep"]) - assert step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[ + processing_input_data_uri_parameter, + instance_type_parameter, + instance_count_parameter, + ], + steps=[step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyProcessingStep", "Description": "ProcessingStep description", "DisplayName": "MyProcessingStep", @@ -564,20 +581,27 @@ def test_processing_step(sagemaker_session): "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", "S3InputMode": "File", - "S3Uri": processing_input_data_uri_parameter, + "S3Uri": {"Get": "Parameters.ProcessingInputDataUri"}, }, } ], "ProcessingResources": { "ClusterConfig": { - "InstanceCount": instance_count_parameter, - "InstanceType": instance_type_parameter, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, + "InstanceType": {"Get": "Parameters.InstanceType"}, "VolumeSizeInGB": 30, } }, "RoleArn": "DummyRole", }, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + "PropertyFiles": [ + { + "FilePath": "evaluation.json", + "OutputName": "evaluation", + "PropertyFileName": "EvaluationReport", + } + ], } assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 5a2a9497f8..e534aa531e 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -26,6 +26,7 @@ ) from sagemaker.estimator import Estimator +from sagemaker.workflow import Properties from sagemaker.workflow._utils import _RepackModelStep from tests.unit import DATA_DIR @@ -156,7 +157,7 @@ def test_repack_model_step(estimator): def test_repack_model_step_with_source_dir(estimator, source_dir): - model_data = f"s3://{BUCKET}/model.tar.gz" + model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" step = _RepackModelStep( name="MyRepackModelStep", @@ -189,7 +190,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir): "S3DataSource": { "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", - "S3Uri": f"s3://{BUCKET}", + "S3Uri": model_data, } }, }