Skip to content

Commit 27a67e0

Browse files
committed
fix: fixed failing UT's
1 parent 4e6dd93 commit 27a67e0

File tree

5 files changed

+155
-15
lines changed

5 files changed

+155
-15
lines changed

src/sagemaker/estimator.py

+15
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,10 @@ def register(
13031303
domain=None,
13041304
sample_payload_url=None,
13051305
task=None,
1306+
framework=None,
1307+
framework_version=None,
1308+
nearest_model_name=None,
1309+
data_input_configuration=None,
13061310
**kwargs,
13071311
):
13081312
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1341,6 +1345,13 @@ def register(
13411345
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
13421346
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
13431347
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
1348+
framework (str): Machine learning framework of the model package container image
1349+
(default: None).
1350+
framework_version (str): Framework version of the Model Package Container Image
1351+
(default: None).
1352+
nearest_model_name (str): Name of a pre-trained machine learning benchmarked by
1353+
Amazon SageMaker Inference Recommender (default: None).
1354+
data_input_configuration (str): Input object for the model (default: None).
13441355
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
13451356
``create_model()`` to accept ``**kwargs`` to customize model creation during
13461357
deploy. For more, see the implementation docs.
@@ -1380,6 +1391,10 @@ def register(
13801391
domain=domain,
13811392
sample_payload_url=sample_payload_url,
13821393
task=task,
1394+
framework=framework,
1395+
framework_version=framework_version,
1396+
nearest_model_name=nearest_model_name,
1397+
data_input_configuration=data_input_configuration,
13831398
)
13841399

13851400
@property

src/sagemaker/session.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4314,11 +4314,11 @@ def get_model_package_args(
43144314
container = {
43154315
"Image": image_uri,
43164316
"ModelDataUrl": model_data,
4317-
"Framework": container_def_list[0]["Framework"],
4318-
"FrameworkVersion": container_def_list[0]["FrameworkVersion"],
4319-
"NearestModelName": container_def_list[0]["NearestModelName"],
4317+
"Framework": None,
4318+
"FrameworkVersion": None,
4319+
"NearestModelName": None,
43204320
"ModelInput": {
4321-
"DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"],
4321+
"DataInputConfig": None,
43224322
},
43234323
}
43244324
containers = [container]

src/sagemaker/workflow/step_collections.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -246,17 +246,17 @@ def __init__(
246246
inference_instances[0] if inference_instances else None
247247
)
248248
]
249-
self.container_def_list[0].update(
250-
{
251-
"Framework": framework,
252-
"FrameworkVersion": framework_version,
253-
"NearestModelName": nearest_model_name,
254-
"ModelInput": {
255-
"DataInputConfig": data_input_configuration,
256-
},
257-
}
258-
)
259-
249+
for container_obj in self.container_def_list:
250+
container_obj.update(
251+
{
252+
"Framework": framework,
253+
"FrameworkVersion": framework_version,
254+
"NearestModelName": nearest_model_name,
255+
"ModelInput": {
256+
"DataInputConfig": data_input_configuration,
257+
},
258+
}
259+
)
260260
register_model_step = _RegisterModelStep(
261261
name=name,
262262
estimator=estimator,

tests/unit/sagemaker/workflow/test_step_collections.py

+65
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
368368
display_name="RegisterModelStep",
369369
depends_on=["TestStep"],
370370
tags=[{"Key": "myKey", "Value": "myValue"}],
371+
sample_payload_url="s3://test-bucket/model",
372+
task="IMAGE_CLASSIFICATION",
373+
framework="TENSORFLOW",
374+
framework_version="2.9",
375+
nearest_model_name="resnet50",
376+
data_input_configuration='{"input_1":[1,224,224,3]}',
371377
)
372378
assert ordered(register_model.request_dicts()) == ordered(
373379
[
@@ -383,6 +389,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
383389
{
384390
"Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
385391
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
392+
"Framework": None,
393+
"FrameworkVersion": None,
394+
"NearestModelName": None,
395+
"ModelInput": {
396+
"DataInputConfig": None,
397+
},
386398
}
387399
],
388400
"SupportedContentTypes": ["content_type"],
@@ -412,6 +424,8 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
412424
"ModelPackageDescription": "description",
413425
"ModelPackageGroupName": "mpg",
414426
"Tags": [{"Key": "myKey", "Value": "myValue"}],
427+
"SamplePayloadUrl": "s3://test-bucket/model",
428+
"Task": "IMAGE_CLASSIFICATION",
415429
},
416430
},
417431
]
@@ -433,6 +447,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
433447
drift_check_baselines=drift_check_baselines,
434448
approval_status="Approved",
435449
description="description",
450+
sample_payload_url="s3://test-bucket/model",
451+
task="IMAGE_CLASSIFICATION",
452+
framework="TENSORFLOW",
453+
framework_version="2.9",
454+
nearest_model_name="resnet50",
455+
data_input_configuration='{"input_1":[1,224,224,3]}',
436456
)
437457
assert ordered(register_model.request_dicts()) == ordered(
438458
[
@@ -446,6 +466,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
446466
{
447467
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
448468
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
469+
"Framework": None,
470+
"FrameworkVersion": None,
471+
"NearestModelName": None,
472+
"ModelInput": {
473+
"DataInputConfig": None,
474+
},
449475
}
450476
],
451477
"SupportedContentTypes": ["content_type"],
@@ -474,6 +500,8 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
474500
},
475501
"ModelPackageDescription": "description",
476502
"ModelPackageGroupName": "mpg",
503+
"SamplePayloadUrl": "s3://test-bucket/model",
504+
"Task": "IMAGE_CLASSIFICATION",
477505
},
478506
},
479507
]
@@ -502,6 +530,12 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
502530
description="description",
503531
model=pipeline_model,
504532
depends_on=["TestStep"],
533+
sample_payload_url="s3://test-bucket/model",
534+
task="IMAGE_CLASSIFICATION",
535+
framework="TENSORFLOW",
536+
framework_version="2.9",
537+
nearest_model_name="resnet50",
538+
data_input_configuration='{"input_1":[1,224,224,3]}',
505539
)
506540
assert ordered(register_model.request_dicts()) == ordered(
507541
[
@@ -517,11 +551,23 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
517551
"Image": "fakeimage1",
518552
"ModelDataUrl": "Url1",
519553
"Environment": [{"k1": "v1"}, {"k2": "v2"}],
554+
"Framework": "TENSORFLOW",
555+
"FrameworkVersion": "2.9",
556+
"NearestModelName": "resnet50",
557+
"ModelInput": {
558+
"DataInputConfig": '{"input_1":[1,224,224,3]}',
559+
},
520560
},
521561
{
522562
"Image": "fakeimage2",
523563
"ModelDataUrl": "Url2",
524564
"Environment": [{"k3": "v3"}, {"k4": "v4"}],
565+
"Framework": "TENSORFLOW",
566+
"FrameworkVersion": "2.9",
567+
"NearestModelName": "resnet50",
568+
"ModelInput": {
569+
"DataInputConfig": '{"input_1":[1,224,224,3]}',
570+
},
525571
},
526572
],
527573
"SupportedContentTypes": ["content_type"],
@@ -550,6 +596,8 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
550596
},
551597
"ModelPackageDescription": "description",
552598
"ModelPackageGroupName": "mpg",
599+
"SamplePayloadUrl": "s3://test-bucket/model",
600+
"Task": "IMAGE_CLASSIFICATION",
553601
},
554602
},
555603
]
@@ -578,6 +626,12 @@ def test_register_model_with_model_repack_with_estimator(
578626
dependencies=[dummy_requirements],
579627
depends_on=["TestStep"],
580628
tags=[{"Key": "myKey", "Value": "myValue"}],
629+
sample_payload_url="s3://test-bucket/model",
630+
task="IMAGE_CLASSIFICATION",
631+
framework="TENSORFLOW",
632+
framework_version="2.9",
633+
nearest_model_name="resnet50",
634+
data_input_configuration='{"input_1":[1,224,224,3]}',
581635
)
582636

583637
request_dicts = register_model.request_dicts()
@@ -649,6 +703,15 @@ def test_register_model_with_model_repack_with_estimator(
649703
assert isinstance(
650704
arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties
651705
)
706+
assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] == None
707+
assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] == None
708+
assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] == None
709+
assert (
710+
arguments["InferenceSpecification"]["Containers"][0]["ModelInput"][
711+
"DataInputConfig"
712+
]
713+
== None
714+
)
652715
del arguments["InferenceSpecification"]["Containers"]
653716
assert ordered(arguments) == ordered(
654717
{
@@ -680,6 +743,8 @@ def test_register_model_with_model_repack_with_estimator(
680743
"ModelPackageDescription": "description",
681744
"ModelPackageGroupName": "mpg",
682745
"Tags": [{"Key": "myKey", "Value": "myValue"}],
746+
"SamplePayloadUrl": "s3://test-bucket/model",
747+
"Task": "IMAGE_CLASSIFICATION",
683748
}
684749
)
685750
else:

tests/unit/test_estimator.py

+60
Original file line numberDiff line numberDiff line change
@@ -3109,13 +3109,25 @@ def test_register_default_image(sagemaker_session):
31093109
response_types = ["application/json"]
31103110
inference_instances = ["ml.m4.xlarge"]
31113111
transform_instances = ["ml.m4.xlarget"]
3112+
sample_payload_url = "s3://test-bucket/model"
3113+
task = "IMAGE_CLASSIFICATION"
3114+
framework = "TENSORFLOW"
3115+
framework_version = "2.9"
3116+
nearest_model_name = "resnet50"
3117+
data_input_config = '{"input_1":[1,224,224,3]}'
31123118

31133119
estimator.register(
31143120
content_types=content_types,
31153121
response_types=response_types,
31163122
inference_instances=inference_instances,
31173123
transform_instances=transform_instances,
31183124
model_package_name=model_package_name,
3125+
sample_payload_url=sample_payload_url,
3126+
task=task,
3127+
framework=framework,
3128+
framework_version=framework_version,
3129+
nearest_model_name=nearest_model_name,
3130+
data_input_configuration=data_input_config,
31193131
)
31203132
sagemaker_session.create_model.assert_not_called()
31213133

@@ -3124,6 +3136,12 @@ def test_register_default_image(sagemaker_session):
31243136
{
31253137
"Image": estimator.image_uri,
31263138
"ModelDataUrl": estimator.model_data,
3139+
"Framework": framework,
3140+
"FrameworkVersion": framework_version,
3141+
"NearestModelName": nearest_model_name,
3142+
"ModelInput": {
3143+
"DataInputConfig": data_input_config,
3144+
},
31273145
}
31283146
],
31293147
"content_types": content_types,
@@ -3132,6 +3150,8 @@ def test_register_default_image(sagemaker_session):
31323150
"transform_instances": transform_instances,
31333151
"model_package_name": model_package_name,
31343152
"marketplace_cert": False,
3153+
"sample_payload_url": sample_payload_url,
3154+
"task": task,
31353155
}
31363156
sagemaker_session.create_model_package_from_containers.assert_called_with(
31373157
**expected_create_model_package_request
@@ -3153,11 +3173,23 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
31533173
model_package_name = "test-estimator-register-model"
31543174
content_types = ["application/json"]
31553175
response_types = ["application/json"]
3176+
sample_payload_url = "s3://test-bucket/model"
3177+
task = "IMAGE_CLASSIFICATION"
3178+
framework = "TENSORFLOW"
3179+
framework_version = "2.9"
3180+
nearest_model_name = "resnet50"
3181+
data_input_config = '{"input_1":[1,224,224,3]}'
31563182

31573183
estimator.register(
31583184
content_types=content_types,
31593185
response_types=response_types,
31603186
model_package_name=model_package_name,
3187+
sample_payload_url=sample_payload_url,
3188+
task=task,
3189+
framework=framework,
3190+
framework_version=framework_version,
3191+
nearest_model_name=nearest_model_name,
3192+
data_input_configuration=data_input_config,
31613193
)
31623194
sagemaker_session.create_model.assert_not_called()
31633195

@@ -3166,6 +3198,12 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
31663198
{
31673199
"Image": estimator.image_uri,
31683200
"ModelDataUrl": estimator.model_data,
3201+
"Framework": framework,
3202+
"FrameworkVersion": framework_version,
3203+
"NearestModelName": nearest_model_name,
3204+
"ModelInput": {
3205+
"DataInputConfig": data_input_config,
3206+
},
31693207
}
31703208
],
31713209
"content_types": content_types,
@@ -3174,6 +3212,8 @@ def test_register_default_image_without_instance_type_args(sagemaker_session):
31743212
"transform_instances": None,
31753213
"model_package_name": model_package_name,
31763214
"marketplace_cert": False,
3215+
"sample_payload_url": sample_payload_url,
3216+
"task": task,
31773217
}
31783218
sagemaker_session.create_model_package_from_containers.assert_called_with(
31793219
**expected_create_model_package_request
@@ -3198,14 +3238,26 @@ def test_register_inference_image(sagemaker_session):
31983238
inference_instances = ["ml.m4.xlarge"]
31993239
transform_instances = ["ml.m4.xlarget"]
32003240
inference_image = "fake-inference-image"
3241+
sample_payload_url = "s3://test-bucket/model"
3242+
task = "IMAGE_CLASSIFICATION"
3243+
framework = "TENSORFLOW"
3244+
framework_version = "2.9"
3245+
nearest_model_name = "resnet50"
3246+
data_input_config = '{"input_1":[1,224,224,3]}'
32013247

32023248
estimator.register(
32033249
content_types=content_types,
32043250
response_types=response_types,
32053251
inference_instances=inference_instances,
32063252
transform_instances=transform_instances,
32073253
model_package_name=model_package_name,
3254+
sample_payload_url=sample_payload_url,
3255+
task=task,
32083256
image_uri=inference_image,
3257+
framework=framework,
3258+
framework_version=framework_version,
3259+
nearest_model_name=nearest_model_name,
3260+
data_input_configuration=data_input_config,
32093261
)
32103262
sagemaker_session.create_model.assert_not_called()
32113263

@@ -3214,6 +3266,12 @@ def test_register_inference_image(sagemaker_session):
32143266
{
32153267
"Image": inference_image,
32163268
"ModelDataUrl": estimator.model_data,
3269+
"Framework": framework,
3270+
"FrameworkVersion": framework_version,
3271+
"NearestModelName": nearest_model_name,
3272+
"ModelInput": {
3273+
"DataInputConfig": data_input_config,
3274+
},
32173275
}
32183276
],
32193277
"content_types": content_types,
@@ -3222,6 +3280,8 @@ def test_register_inference_image(sagemaker_session):
32223280
"transform_instances": transform_instances,
32233281
"model_package_name": model_package_name,
32243282
"marketplace_cert": False,
3283+
"sample_payload_url": sample_payload_url,
3284+
"task": task,
32253285
}
32263286
sagemaker_session.create_model_package_from_containers.assert_called_with(
32273287
**expected_create_model_package_request

0 commit comments

Comments
 (0)