Skip to content

fix: EstimatorTransformer entry point error #2905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 61 additions & 25 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import attr

import sagemaker
from sagemaker.estimator import EstimatorBase
from sagemaker.inputs import CreateModelInput, TransformInput
from sagemaker.model import Model
from sagemaker import PipelineModel
from sagemaker.predictor import Predictor
Expand Down Expand Up @@ -244,16 +246,17 @@ class EstimatorTransformer(StepCollection):
def __init__(
self,
name: str,
estimator: EstimatorBase,
model_data,
model_inputs,
instance_count,
instance_type,
transform_inputs,
model_data: str,
model_inputs: CreateModelInput,
instance_count: int,
instance_type: str,
transform_inputs: TransformInput,
image_uri: str,
sagemaker_session: sagemaker.session.Session = None,
role: str = None,
description: str = None,
display_name: str = None,
# model arguments
image_uri=None,
predictor_cls=None,
env=None,
# transformer arguments
Expand Down Expand Up @@ -290,9 +293,29 @@ def __init__(

Args:
name (str): The name of the Transform Step.
estimator: The estimator instance.
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file.
model_inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
instance_count (int): The number of EC2 instances to use.
instance_type (str): The type of EC2 instance to use.
transform_inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
image_uri (str): A Docker image URI.
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions (default: None). If not
specified, one is created using the default AWS configuration
chain.
role (str): An AWS IAM role (either name or full ARN) (default:
None).
description (str): A description of all steps constructed in the
step collection (default: None).
display_name (str): The display name of all steps constructed in the
step collection (default: None).
predictor_cls (callable[string, sagemaker.session.Session]): A
function to call to create a predictor (default: None). If not
None, ``deploy`` will return the result of invoking this
function on the created endpoint name.
env (dict[str, str]): The Environment variables to be set for use during the
transform job (default: None).
strategy (str): The strategy used to decide how to batch records in
a single request (default: None). Valid values: 'MultiRecord'
and 'SingleRecord'.
Expand All @@ -305,35 +328,48 @@ def __init__(
accept (str): The accept header passed by the client to
the inference endpoint. If it is supported by the endpoint,
it will be the format of the batch transform output.
env (dict): The Environment variables to be set for use during the
transform job (default: None).
max_concurrent_transforms (int): The maximum number of HTTP requests
to be made to each individual transform container at one time.
max_payload (int): Maximum size of the payload in a single HTTP
request to the container in MB.
tags (list[dict]): List of tags for labeling a transform job
(default: None). For more, see the SageMaker API documentation for
`Tag <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
attached to the ML compute instance (default: None).
depends_on (List[str] or List[Step]): The list of step names or step instances
the first step in the collection depends on
the first step in the collection depends on.
repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies
for the repack model step
for the repack model step.
model_step_retry_policies (List[RetryPolicy]): The list of retry policies for
model step
model step.
transform_step_retry_policies (List[RetryPolicy]): The list of retry policies for
transform step
transform step.
**kwargs: Extra keyword arguments for `_RepackModelStep` or `Model`.
"""
steps = []
if "entry_point" in kwargs:
entry_point = kwargs["entry_point"]
source_dir = kwargs.get("source_dir")
dependencies = kwargs.get("dependencies")
entry_point = kwargs.pop("entry_point", None)
source_dir = kwargs.pop("source_dir", None)
dependencies = kwargs.pop("dependencies", None)
code_location = kwargs.pop("code_location", None)
subnets = kwargs.pop("subnets", None)
security_group_ids = kwargs.pop("security_group_ids", None)

repack_model_step = _RepackModelStep(
name=f"{name}RepackModel",
depends_on=depends_on,
retry_policies=repack_model_step_retry_policies,
sagemaker_session=estimator.sagemaker_session,
role=estimator.sagemaker_session,
sagemaker_session=sagemaker_session,
role=role,
model_data=model_data,
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
code_location=code_location,
tags=tags,
subnets=estimator.subnets,
security_group_ids=estimator.security_group_ids,
subnets=subnets,
security_group_ids=security_group_ids,
description=description,
display_name=display_name,
)
Expand All @@ -346,12 +382,12 @@ def predict_wrapper(endpoint, session):
predictor_cls = predictor_cls or predict_wrapper

model = Model(
image_uri=image_uri or estimator.training_image_uri(),
image_uri=image_uri,
model_data=model_data,
predictor_cls=predictor_cls,
vpc_config=None,
sagemaker_session=estimator.sagemaker_session,
role=estimator.role,
sagemaker_session=sagemaker_session,
role=role,
**kwargs,
)
model_step = CreateModelStep(
Expand Down Expand Up @@ -382,7 +418,7 @@ def predict_wrapper(endpoint, session):
tags=tags,
base_transform_job_name=name,
volume_kms_key=volume_kms_key,
sagemaker_session=estimator.sagemaker_session,
sagemaker_session=sagemaker_session,
)
transform_step = TransformStep(
name=f"{name}TransformStep",
Expand Down
127 changes: 125 additions & 2 deletions tests/unit/sagemaker/workflow/test_step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
raise Exception("A step exists in the collection of an invalid type.")


def test_estimator_transformer(estimator):
def test_estimator_transformer(sagemaker_session):
model_data = f"s3://{BUCKET}/model.tar.gz"
model_inputs = CreateModelInput(
instance_type="c4.4xlarge",
Expand All @@ -814,7 +814,6 @@ def test_estimator_transformer(estimator):
transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
estimator_transformer = EstimatorTransformer(
name="EstimatorTransformerStep",
estimator=estimator,
model_data=model_data,
model_inputs=model_inputs,
instance_count=1,
Expand All @@ -824,6 +823,9 @@ def test_estimator_transformer(estimator):
model_step_retry_policies=[service_fault_retry_policy],
transform_step_retry_policies=[service_fault_retry_policy],
repack_model_step_retry_policies=[service_fault_retry_policy],
image_uri=IMAGE_URI,
sagemaker_session=sagemaker_session,
role=ROLE,
)
request_dicts = estimator_transformer.request_dicts()
assert len(request_dicts) == 2
Expand Down Expand Up @@ -865,3 +867,124 @@ def test_estimator_transformer(estimator):
}
else:
raise Exception("A step exists in the collection of an invalid type.")


def test_estimator_transformer_with_model_repack(sagemaker_session):
model_data = f"s3://{BUCKET}/model.tar.gz"
model_inputs = CreateModelInput(
instance_type="c4.4xlarge",
accelerator_type="ml.eia1.medium",
)
service_fault_retry_policy = StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10
)
transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
estimator_transformer = EstimatorTransformer(
name="EstimatorTransformerStep",
model_data=model_data,
model_inputs=model_inputs,
instance_count=1,
instance_type="ml.c4.4xlarge",
transform_inputs=transform_inputs,
depends_on=["TestStep"],
model_step_retry_policies=[service_fault_retry_policy],
transform_step_retry_policies=[service_fault_retry_policy],
repack_model_step_retry_policies=[service_fault_retry_policy],
image_uri=IMAGE_URI,
sagemaker_session=sagemaker_session,
role=ROLE,
entry_point=f"{DATA_DIR}/dummy_script.py",
dependencies=[dummy_requirements],
)
request_dicts = estimator_transformer.request_dicts()
assert len(request_dicts) == 3

for request_dict in request_dicts:
if request_dict["Type"] == "Training":
assert request_dict["Name"] == "EstimatorTransformerStepRepackModel"
assert len(request_dict["DependsOn"]) == 1
assert request_dict["DependsOn"][0] == "TestStep"
arguments = request_dict["Arguments"]
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
assert ordered(arguments) == ordered(
{
"AlgorithmSpecification": {
"TrainingImage": MODEL_REPACKING_IMAGE_URI,
"TrainingInputMode": "File",
},
"DebugHookConfig": {
"CollectionConfigurations": [],
"S3OutputPath": f"s3://{BUCKET}/",
},
"HyperParameters": {
"inference_script": '"dummy_script.py"',
"dependencies": f'"{dummy_requirements}"',
"model_archive": '"model.tar.gz"',
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
BUCKET, repacker_job_name.replace('"', "")
),
"sagemaker_program": '"_repack_model.py"',
"sagemaker_container_log_level": "20",
"sagemaker_job_name": repacker_job_name,
"sagemaker_region": f'"{REGION}"',
"source_dir": "null",
},
"InputDataConfig": [
{
"ChannelName": "training",
"DataSource": {
"S3DataSource": {
"S3DataDistributionType": "FullyReplicated",
"S3DataType": "S3Prefix",
"S3Uri": f"s3://{BUCKET}",
}
},
}
],
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
"ResourceConfig": {
"InstanceCount": 1,
"InstanceType": "ml.m5.large",
"VolumeSizeInGB": 30,
},
"RoleArn": ROLE,
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
}
)

elif request_dict["Type"] == "Model":
assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep"
assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()]
arguments = request_dict["Arguments"]
assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties)
arguments["PrimaryContainer"].pop("ModelDataUrl")
assert arguments == {
"ExecutionRoleArn": "DummyRole",
"PrimaryContainer": {
"Environment": {},
"Image": "fakeimage",
},
}

elif request_dict["Type"] == "Transform":
assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()]
arguments = request_dict["Arguments"]
assert isinstance(arguments["ModelName"], Properties)
arguments.pop("ModelName")
assert "DependsOn" not in request_dict
assert arguments == {
"TransformInput": {
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3Uri": f"s3://{BUCKET}/transform_manifest",
}
}
},
"TransformOutput": {"S3OutputPath": None},
"TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
}
else:
raise Exception("A step exists in the collection of an invalid type.")