Skip to content

Commit ee6afcf

Browse files
fix: Remove entry_point before calling Model on EstimatorTransformer
While writing the unit test for EstimatorTransformer with a repacked model and a custom output path, I discovered that passing an entry_point to EstimatorTransformer raised an "unexpected keyword argument 'entry_point'" error in Model.__init__. Using the code from RegisterModel as a base, I now remove entry_point and the other repack variables from kwargs before they reach Model. I also added unit tests for this case.
1 parent 7a1f4f8 commit ee6afcf

File tree

2 files changed

+128
-6
lines changed

2 files changed

+128
-6
lines changed

src/sagemaker/workflow/step_collections.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -284,8 +284,8 @@ def __init__(
284284
285285
An estimator-centric step collection. It models what happens in workflows
286286
when invoking the `transform()` method on an estimator instance:
287-
First, if custom
288-
model artifacts are required, a `_RepackModelStep` is included.
287+
First, if a custom
288+
entry point script is required, a `_RepackModelStep` is included.
289289
Second, a
290290
`CreateModelStep` with the model data passed in from a training step or other
291291
training job output.
@@ -327,10 +327,13 @@ def __init__(
327327
transform step
328328
"""
329329
steps = []
330+
repack_model = False
331+
330332
if "entry_point" in kwargs:
331-
entry_point = kwargs["entry_point"]
332-
source_dir = kwargs.get("source_dir")
333-
dependencies = kwargs.get("dependencies")
333+
repack_model = True
334+
entry_point = kwargs.pop("entry_point", None)
335+
source_dir = kwargs.pop("source_dir", None)
336+
dependencies = kwargs.pop("dependencies", None)
334337
repack_model_step = _RepackModelStep(
335338
name=f"{name}RepackModel",
336339
depends_on=depends_on,
@@ -347,6 +350,7 @@ def __init__(
347350
description=description,
348351
display_name=display_name,
349352
repack_output_path=repack_output_path,
353+
**kwargs,
350354
)
351355
steps.append(repack_model_step)
352356
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -373,7 +377,7 @@ def predict_wrapper(endpoint, session):
373377
display_name=display_name,
374378
retry_policies=model_step_retry_policies,
375379
)
376-
if "entry_point" not in kwargs and depends_on:
380+
if not repack_model and depends_on:
377381
# if the CreateModelStep is the first step in the collection
378382
model_step.add_depends_on(depends_on)
379383
steps.append(model_step)

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 118 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -865,3 +865,121 @@ def test_estimator_transformer(estimator):
865865
}
866866
else:
867867
raise Exception("A step exists in the collection of an invalid type.")
868+
869+
870+
def test_estimator_transformer_with_model_repack(estimator):
871+
model_data = f"s3://{BUCKET}/model.tar.gz"
872+
dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
873+
model_inputs = CreateModelInput(
874+
instance_type="c4.4xlarge",
875+
accelerator_type="ml.eia1.medium",
876+
)
877+
transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
878+
estimator_transformer = EstimatorTransformer(
879+
name="EstimatorTransformerStep",
880+
estimator=estimator,
881+
model_data=model_data,
882+
model_inputs=model_inputs,
883+
instance_count=1,
884+
instance_type="ml.c4.4xlarge",
885+
transform_inputs=transform_inputs,
886+
depends_on=["TestStep"],
887+
entry_point=f"{DATA_DIR}/dummy_script.py",
888+
dependencies=[dummy_requirements],
889+
)
890+
request_dicts = estimator_transformer.request_dicts()
891+
assert len(request_dicts) == 3
892+
893+
for request_dict in request_dicts:
894+
if request_dict["Type"] == "Training":
895+
assert request_dict["Name"] == "EstimatorTransformerStepRepackModel"
896+
assert len(request_dict["DependsOn"]) == 1
897+
assert request_dict["DependsOn"][0] == "TestStep"
898+
arguments = request_dict["Arguments"]
899+
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
900+
assert ordered(arguments) == ordered(
901+
{
902+
"AlgorithmSpecification": {
903+
"TrainingImage": MODEL_REPACKING_IMAGE_URI,
904+
"TrainingInputMode": "File",
905+
},
906+
"DebugHookConfig": {
907+
"CollectionConfigurations": [],
908+
"S3OutputPath": f"s3://{BUCKET}/",
909+
},
910+
"HyperParameters": {
911+
"inference_script": '"dummy_script.py"',
912+
"dependencies": f'"{dummy_requirements}"',
913+
"model_archive": '"model.tar.gz"',
914+
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
915+
BUCKET, repacker_job_name.replace('"', "")
916+
),
917+
"sagemaker_program": '"_repack_model.py"',
918+
"sagemaker_container_log_level": "20",
919+
"sagemaker_job_name": repacker_job_name,
920+
"sagemaker_region": f'"{REGION}"',
921+
"source_dir": "null",
922+
},
923+
"InputDataConfig": [
924+
{
925+
"ChannelName": "training",
926+
"DataSource": {
927+
"S3DataSource": {
928+
"S3DataDistributionType": "FullyReplicated",
929+
"S3DataType": "S3Prefix",
930+
"S3Uri": f"s3://{BUCKET}",
931+
}
932+
},
933+
}
934+
],
935+
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
936+
"ResourceConfig": {
937+
"InstanceCount": 1,
938+
"InstanceType": "ml.m5.large",
939+
"VolumeSizeInGB": 30,
940+
},
941+
"RoleArn": ROLE,
942+
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
943+
"VpcConfig": [
944+
("SecurityGroupIds", ["123", "456"]),
945+
("Subnets", ["abc", "def"]),
946+
],
947+
}
948+
)
949+
elif request_dict["Type"] == "Model":
950+
assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep"
951+
assert "DependsOn" not in request_dict
952+
arguments = request_dict["Arguments"]
953+
assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties)
954+
del arguments["PrimaryContainer"]["ModelDataUrl"]
955+
assert ordered(arguments) == ordered(
956+
{
957+
"ExecutionRoleArn": "DummyRole",
958+
"PrimaryContainer": {
959+
"Environment": {},
960+
"Image": "fakeimage",
961+
},
962+
}
963+
)
964+
elif request_dict["Type"] == "Transform":
965+
assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
966+
assert "DependsOn" not in request_dict
967+
arguments = request_dict["Arguments"]
968+
assert isinstance(arguments["ModelName"], Properties)
969+
arguments.pop("ModelName")
970+
assert ordered(arguments) == ordered(
971+
{
972+
"TransformInput": {
973+
"DataSource": {
974+
"S3DataSource": {
975+
"S3DataType": "S3Prefix",
976+
"S3Uri": f"s3://{BUCKET}/transform_manifest",
977+
}
978+
}
979+
},
980+
"TransformOutput": {"S3OutputPath": None},
981+
"TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
982+
}
983+
)
984+
else:
985+
raise Exception("A step exists in the collection of an invalid type.")

0 commit comments

Comments
 (0)