@@ -94,6 +94,13 @@ def test_conditional_pytorch_training_model_registration(
94
94
good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
95
95
in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
96
96
97
+ task = "IMAGE_CLASSIFICATION"
98
+ sample_payload_url = "s3://test-bucket/model"
99
+ framework = "TENSORFLOW"
100
+ framework_version = "2.9"
101
+ nearest_model_name = "resnet50"
102
+ data_input_configuration = '{"input_1":[1,224,224,3]}'
103
+
97
104
# If image_uri is not provided, the instance_type should not be a pipeline variable
98
105
# since instance_type is used to retrieve image_uri in compile time (PySDK)
99
106
pytorch_estimator = PyTorch (
@@ -120,6 +127,12 @@ def test_conditional_pytorch_training_model_registration(
120
127
inference_instances = ["*" ],
121
128
transform_instances = ["*" ],
122
129
description = "test-description" ,
130
+ sample_payload_url = sample_payload_url ,
131
+ task = task ,
132
+ framework = framework ,
133
+ framework_version = framework_version ,
134
+ nearest_model_name = nearest_model_name ,
135
+ data_input_configuration = data_input_configuration ,
123
136
)
124
137
125
138
model = Model (
@@ -201,6 +214,13 @@ def test_mxnet_model_registration(
201
214
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
202
215
instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
203
216
217
+ task = "IMAGE_CLASSIFICATION"
218
+ sample_payload_url = "s3://test-bucket/model"
219
+ framework = "TENSORFLOW"
220
+ framework_version = "2.9"
221
+ nearest_model_name = "resnet50"
222
+ data_input_configuration = '{"input_1":[1,224,224,3]}'
223
+
204
224
model = MXNetModel (
205
225
entry_point = entry_point ,
206
226
source_dir = source_dir ,
@@ -219,6 +239,12 @@ def test_mxnet_model_registration(
219
239
inference_instances = ["ml.m5.xlarge" ],
220
240
transform_instances = ["*" ],
221
241
description = "test-description" ,
242
+ sample_payload_url = sample_payload_url ,
243
+ task = task ,
244
+ framework = framework ,
245
+ framework_version = framework_version ,
246
+ nearest_model_name = nearest_model_name ,
247
+ data_input_configuration = data_input_configuration ,
222
248
)
223
249
224
250
pipeline = Pipeline (
@@ -262,6 +288,13 @@ def test_sklearn_xgboost_sip_model_registration(
262
288
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
263
289
instance_type = "ml.m5.xlarge"
264
290
291
+ task = "IMAGE_CLASSIFICATION"
292
+ sample_payload_url = "s3://test-bucket/model"
293
+ framework = "TENSORFLOW"
294
+ framework_version = "2.9"
295
+ nearest_model_name = "resnet50"
296
+ data_input_configuration = '{"input_1":[1,224,224,3]}'
297
+
265
298
# The instance_type should not be a pipeline variable
266
299
# since it is used to retrieve image_uri in compile time (PySDK)
267
300
sklearn_processor = SKLearnProcessor (
@@ -412,6 +445,12 @@ def test_sklearn_xgboost_sip_model_registration(
412
445
inference_instances = ["ml.t2.medium" , "ml.m5.xlarge" ],
413
446
transform_instances = ["ml.m5.xlarge" ],
414
447
model_package_group_name = "windturbine" ,
448
+ sample_payload_url = sample_payload_url ,
449
+ task = task ,
450
+ framework = framework ,
451
+ framework_version = framework_version ,
452
+ nearest_model_name = nearest_model_name ,
453
+ data_input_configuration = data_input_configuration ,
415
454
)
416
455
417
456
pipeline = Pipeline (
@@ -575,27 +614,8 @@ def test_model_registration_with_drift_check_baselines(
575
614
role = role ,
576
615
)
577
616
578
- base_dir = os .path .join (DATA_DIR , "mxnet_mnist" )
579
- source_dir = os .path .join (base_dir , "code" )
580
- entry_point = os .path .join (source_dir , "inference.py" )
581
- mx_mnist_model_data = os .path .join (base_dir , "model.tar.gz" )
582
-
583
- instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
584
- instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
585
-
586
- model = MXNetModel (
587
- entry_point = entry_point ,
588
- source_dir = source_dir ,
589
- role = role ,
590
- model_data = mx_mnist_model_data ,
591
- framework_version = "1.4.0" ,
592
- py_version = "py3" ,
593
- sagemaker_session = sagemaker_session ,
594
- )
595
-
596
617
step_register = RegisterModel (
597
618
name = "MyRegisterModelStep" ,
598
- model = model ,
599
619
estimator = estimator ,
600
620
model_data = model_uri_param ,
601
621
content_types = ["application/json" ],
@@ -686,19 +706,6 @@ def test_model_registration_with_drift_check_baselines(
686
706
assert response ["Domain" ] == domain
687
707
assert response ["Task" ] == task
688
708
assert response ["SamplePayloadUrl" ] == sample_payload_url
689
- assert response ["InferenceSpecification" ]["Containers" ][0 ]["Framework" ] == framework
690
- assert (
691
- response ["InferenceSpecification" ]["Containers" ][0 ]["FrameworkVersion" ]
692
- == framework_version
693
- )
694
- assert (
695
- response ["InferenceSpecification" ]["Containers" ][0 ]["NearestModelName" ]
696
- == nearest_model_name
697
- )
698
- assert (
699
- response ["InferenceSpecification" ]["Containers" ][0 ]["ModelInput" ]["DataInputConfig" ]
700
- == data_input_configuration
701
- )
702
709
break
703
710
finally :
704
711
try :
0 commit comments