From d05cc8a0e9df3f5e75a54efcc6a655fb7bc40416 Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Sun, 6 Feb 2022 21:06:23 +0000 Subject: [PATCH 1/5] fix: fix EstimatorTransformer entry point error --- src/sagemaker/workflow/step_collections.py | 43 +++++++++++++--------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index f4606488b2..eb24cda505 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -18,6 +18,7 @@ import attr from sagemaker.estimator import EstimatorBase +from sagemaker.inputs import CreateModelInput, TransformInput from sagemaker.model import Model from sagemaker import PipelineModel from sagemaker.predictor import Predictor @@ -244,16 +245,17 @@ class EstimatorTransformer(StepCollection): def __init__( self, name: str, - estimator: EstimatorBase, - model_data, - model_inputs, - instance_count, - instance_type, - transform_inputs, + model_data: str, + model_inputs: CreateModelInput, + instance_count: int, + instance_type: str, + transform_inputs: TransformInput, + image_uri: str, + sagemaker_session: str, + role: str, description: str = None, display_name: str = None, # model arguments - image_uri=None, predictor_cls=None, env=None, # transformer arguments @@ -318,22 +320,27 @@ def __init__( """ steps = [] if "entry_point" in kwargs: - entry_point = kwargs["entry_point"] - source_dir = kwargs.get("source_dir") - dependencies = kwargs.get("dependencies") + entry_point = kwargs.pop("entry_point", None) + source_dir = kwargs.pop("source_dir", None) + dependencies = kwargs.pop("dependencies", None) + code_location = kwargs.pop("code_location", None) + subnets = kwargs.pop("subnets", None) + security_group_ids = kwargs.pop("security_group_ids", None) + repack_model_step = _RepackModelStep( name=f"{name}RepackModel", depends_on=depends_on, retry_policies=repack_model_step_retry_policies, - 
sagemaker_session=estimator.sagemaker_session, - role=estimator.sagemaker_session, + sagemaker_session=sagemaker_session, + role=role, model_data=model_data, entry_point=entry_point, source_dir=source_dir, dependencies=dependencies, + code_location=code_location, tags=tags, - subnets=estimator.subnets, - security_group_ids=estimator.security_group_ids, + subnets=subnets, + security_group_ids=security_group_ids, description=description, display_name=display_name, ) @@ -346,12 +353,12 @@ def predict_wrapper(endpoint, session): predictor_cls = predictor_cls or predict_wrapper model = Model( - image_uri=image_uri or estimator.training_image_uri(), + image_uri=image_uri, model_data=model_data, predictor_cls=predictor_cls, vpc_config=None, - sagemaker_session=estimator.sagemaker_session, - role=estimator.role, + sagemaker_session=sagemaker_session, + role=role, **kwargs, ) model_step = CreateModelStep( @@ -382,7 +389,7 @@ def predict_wrapper(endpoint, session): tags=tags, base_transform_job_name=name, volume_kms_key=volume_kms_key, - sagemaker_session=estimator.sagemaker_session, + sagemaker_session=sagemaker_session, ) transform_step = TransformStep( name=f"{name}TransformStep", From 9e460f2009278c21a2c21d58bca114d7932a21e4 Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Mon, 7 Feb 2022 21:28:30 +0000 Subject: [PATCH 2/5] fix: test_estimator_transformer --- tests/unit/sagemaker/workflow/test_step_collections.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 6c78412b22..49279d1dca 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -802,7 +802,7 @@ def test_register_model_with_model_repack_with_pipeline_model( raise Exception("A step exists in the collection of an invalid type.") -def test_estimator_transformer(estimator): +def 
test_estimator_transformer(sagemaker_session): model_data = f"s3://{BUCKET}/model.tar.gz" model_inputs = CreateModelInput( instance_type="c4.4xlarge", @@ -814,7 +814,6 @@ def test_estimator_transformer(estimator): transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest") estimator_transformer = EstimatorTransformer( name="EstimatorTransformerStep", - estimator=estimator, model_data=model_data, model_inputs=model_inputs, instance_count=1, @@ -824,6 +823,9 @@ def test_estimator_transformer(estimator): model_step_retry_policies=[service_fault_retry_policy], transform_step_retry_policies=[service_fault_retry_policy], repack_model_step_retry_policies=[service_fault_retry_policy], + image_uri=IMAGE_URI, + sagemaker_session=sagemaker_session, + role=ROLE ) request_dicts = estimator_transformer.request_dicts() assert len(request_dicts) == 2 From af84f593aeaf33bd84434b5652b3007b874c7ec9 Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Mon, 7 Feb 2022 22:00:36 +0000 Subject: [PATCH 3/5] fix: add estimator transformer with model repack test --- .../workflow/test_step_collections.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 49279d1dca..b17f577e6a 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -867,3 +867,123 @@ def test_estimator_transformer(sagemaker_session): } else: raise Exception("A step exists in the collection of an invalid type.") + +def test_estimator_transformer_with_model_repack(sagemaker_session): + model_data = f"s3://{BUCKET}/model.tar.gz" + model_inputs = CreateModelInput( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + service_fault_retry_policy = StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10 + ) + transform_inputs = 
TransformInput(data=f"s3://{BUCKET}/transform_manifest") + dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt" + estimator_transformer = EstimatorTransformer( + name="EstimatorTransformerStep", + model_data=model_data, + model_inputs=model_inputs, + instance_count=1, + instance_type="ml.c4.4xlarge", + transform_inputs=transform_inputs, + depends_on=["TestStep"], + model_step_retry_policies=[service_fault_retry_policy], + transform_step_retry_policies=[service_fault_retry_policy], + repack_model_step_retry_policies=[service_fault_retry_policy], + image_uri=IMAGE_URI, + sagemaker_session=sagemaker_session, + role=ROLE, + entry_point=f"{DATA_DIR}/dummy_script.py", + dependencies=[dummy_requirements] + ) + request_dicts = estimator_transformer.request_dicts() + assert len(request_dicts) == 3 + + for request_dict in request_dicts: + if request_dict["Type"] == "Training": + assert request_dict["Name"] == "EstimatorTransformerStepRepackModel" + assert len(request_dict["DependsOn"]) == 1 + assert request_dict["DependsOn"][0] == "TestStep" + arguments = request_dict["Arguments"] + repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"] + assert ordered(arguments) == ordered( + { + "AlgorithmSpecification": { + "TrainingImage": MODEL_REPACKING_IMAGE_URI, + "TrainingInputMode": "File", + }, + "DebugHookConfig": { + "CollectionConfigurations": [], + "S3OutputPath": f"s3://{BUCKET}/", + }, + "HyperParameters": { + "inference_script": '"dummy_script.py"', + "dependencies": f'"{dummy_requirements}"', + "model_archive": '"model.tar.gz"', + "sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format( + BUCKET, repacker_job_name.replace('"', "") + ), + "sagemaker_program": '"_repack_model.py"', + "sagemaker_container_log_level": "20", + "sagemaker_job_name": repacker_job_name, + "sagemaker_region": f'"{REGION}"', + "source_dir": "null", + }, + "InputDataConfig": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + 
"S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{BUCKET}", + } + }, + } + ], + "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 30, + }, + "RoleArn": ROLE, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + } + ) + + elif request_dict["Type"] == "Model": + assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties) + arguments["PrimaryContainer"].pop("ModelDataUrl") + assert arguments == { + "ExecutionRoleArn": "DummyRole", + "PrimaryContainer": { + "Environment": {}, + "Image": "fakeimage", + } + } + + elif request_dict["Type"] == "Transform": + assert request_dict["Name"] == "EstimatorTransformerStepTransformStep" + assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()] + arguments = request_dict["Arguments"] + assert isinstance(arguments["ModelName"], Properties) + arguments.pop("ModelName") + assert "DependsOn" not in request_dict + assert arguments == { + "TransformInput": { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": f"s3://{BUCKET}/transform_manifest", + } + } + }, + "TransformOutput": {"S3OutputPath": None}, + "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"}, + } + else: + raise Exception("A step exists in the collection of an invalid type.") From 8eb215337d02fc9aea342b93305adc26464cec67 Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Mon, 7 Feb 2022 22:06:39 +0000 Subject: [PATCH 4/5] fix: format with black --- .../sagemaker/workflow/test_step_collections.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py 
b/tests/unit/sagemaker/workflow/test_step_collections.py index b17f577e6a..9c4d196778 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -825,7 +825,7 @@ def test_estimator_transformer(sagemaker_session): repack_model_step_retry_policies=[service_fault_retry_policy], image_uri=IMAGE_URI, sagemaker_session=sagemaker_session, - role=ROLE + role=ROLE, ) request_dicts = estimator_transformer.request_dicts() assert len(request_dicts) == 2 @@ -868,6 +868,7 @@ def test_estimator_transformer(sagemaker_session): else: raise Exception("A step exists in the collection of an invalid type.") + def test_estimator_transformer_with_model_repack(sagemaker_session): model_data = f"s3://{BUCKET}/model.tar.gz" model_inputs = CreateModelInput( @@ -894,7 +895,7 @@ def test_estimator_transformer_with_model_repack(sagemaker_session): sagemaker_session=sagemaker_session, role=ROLE, entry_point=f"{DATA_DIR}/dummy_script.py", - dependencies=[dummy_requirements] + dependencies=[dummy_requirements], ) request_dicts = estimator_transformer.request_dicts() assert len(request_dicts) == 3 @@ -959,11 +960,11 @@ def test_estimator_transformer_with_model_repack(sagemaker_session): assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties) arguments["PrimaryContainer"].pop("ModelDataUrl") assert arguments == { - "ExecutionRoleArn": "DummyRole", - "PrimaryContainer": { - "Environment": {}, - "Image": "fakeimage", - } + "ExecutionRoleArn": "DummyRole", + "PrimaryContainer": { + "Environment": {}, + "Image": "fakeimage", + }, } elif request_dict["Type"] == "Transform": From a8966ec1d25439985d262f14bac53732b7d85197 Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Tue, 8 Feb 2022 13:54:32 +0000 Subject: [PATCH 5/5] fix: update docstrings --- src/sagemaker/workflow/step_collections.py | 47 +++++++++++++++++----- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git 
a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index eb24cda505..9f19a99901 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -17,6 +17,7 @@ import attr +import sagemaker from sagemaker.estimator import EstimatorBase from sagemaker.inputs import CreateModelInput, TransformInput from sagemaker.model import Model @@ -251,8 +252,8 @@ def __init__( instance_type: str, transform_inputs: TransformInput, image_uri: str, - sagemaker_session: str, - role: str, + sagemaker_session: sagemaker.session.Session = None, + role: str = None, description: str = None, display_name: str = None, # model arguments @@ -292,9 +293,29 @@ def __init__( Args: name (str): The name of the Transform Step. - estimator: The estimator instance. + model_data (str): The S3 location of a SageMaker model data + ``.tar.gz`` file. + model_inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance. instance_count (int): The number of EC2 instances to use. instance_type (str): The type of EC2 instance to use. + transform_inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance. + image_uri (str): A Docker image URI. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (default: None). If not + specified, one is created using the default AWS configuration + chain. + role (str): An AWS IAM role (either name or full ARN) (default: + None). + description (str): A description of all steps constructed in the + step collection (default: None). + display_name (str): The display name of all steps constructed in the + step collection (default: None). + predictor_cls (callable[string, sagemaker.session.Session]): A + function to call to create a predictor (default: None). If not + None, ``deploy`` will return the result of invoking this + function on the created endpoint name. 
+ env (dict[str, str]): The Environment variables to be set for use during the + transform job (default: None). strategy (str): The strategy used to decide how to batch records in a single request (default: None). Valid values: 'MultiRecord' and 'SingleRecord'. @@ -307,16 +328,24 @@ def __init__( accept (str): The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output. - env (dict): The Environment variables to be set for use during the - transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests + to be made to each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP + request to the container in MB. + tags (list[dict]): List of tags for labeling a transform job + (default: None). For more, see the SageMaker API documentation for + `Tag <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_. + volume_kms_key (str): Optional. KMS key ID for encrypting the volume + attached to the ML compute instance (default: None). depends_on (List[str] or List[Step]): The list of step names or step instances - the first step in the collection depends on + the first step in the collection depends on. repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies - for the repack model step + for the repack model step. model_step_retry_policies (List[RetryPolicy]): The list of retry policies for - model step + model step. transform_step_retry_policies (List[RetryPolicy]): The list of retry policies for - transform step + transform step. + **kwargs: Extra keyword arguments for `_RepackModelStep` or `Model`. """ steps = [] if "entry_point" in kwargs: