Skip to content

Commit f194f7a

Browse files
committed
fix: fixed failing Integration Test
1 parent ef06942 commit f194f7a

File tree

1 file changed

+39
-32
lines changed

1 file changed

+39
-32
lines changed

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

+39-32
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ def test_conditional_pytorch_training_model_registration(
9494
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
9595
in_condition_input = ParameterString(name="Foo", default_value="Foo")
9696

97+
task = "IMAGE_CLASSIFICATION"
98+
sample_payload_url = "s3://test-bucket/model"
99+
framework = "TENSORFLOW"
100+
framework_version = "2.9"
101+
nearest_model_name = "resnet50"
102+
data_input_configuration = '{"input_1":[1,224,224,3]}'
103+
97104
# If image_uri is not provided, the instance_type should not be a pipeline variable
98105
# since instance_type is used to retrieve image_uri in compile time (PySDK)
99106
pytorch_estimator = PyTorch(
@@ -120,6 +127,12 @@ def test_conditional_pytorch_training_model_registration(
120127
inference_instances=["*"],
121128
transform_instances=["*"],
122129
description="test-description",
130+
sample_payload_url=sample_payload_url,
131+
task=task,
132+
framework=framework,
133+
framework_version=framework_version,
134+
nearest_model_name=nearest_model_name,
135+
data_input_configuration=data_input_configuration,
123136
)
124137

125138
model = Model(
@@ -201,6 +214,13 @@ def test_mxnet_model_registration(
201214
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
202215
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
203216

217+
task = "IMAGE_CLASSIFICATION"
218+
sample_payload_url = "s3://test-bucket/model"
219+
framework = "TENSORFLOW"
220+
framework_version = "2.9"
221+
nearest_model_name = "resnet50"
222+
data_input_configuration = '{"input_1":[1,224,224,3]}'
223+
204224
model = MXNetModel(
205225
entry_point=entry_point,
206226
source_dir=source_dir,
@@ -219,6 +239,12 @@ def test_mxnet_model_registration(
219239
inference_instances=["ml.m5.xlarge"],
220240
transform_instances=["*"],
221241
description="test-description",
242+
sample_payload_url=sample_payload_url,
243+
task=task,
244+
framework=framework,
245+
framework_version=framework_version,
246+
nearest_model_name=nearest_model_name,
247+
data_input_configuration=data_input_configuration,
222248
)
223249

224250
pipeline = Pipeline(
@@ -262,6 +288,13 @@ def test_sklearn_xgboost_sip_model_registration(
262288
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
263289
instance_type = "ml.m5.xlarge"
264290

291+
task = "IMAGE_CLASSIFICATION"
292+
sample_payload_url = "s3://test-bucket/model"
293+
framework = "TENSORFLOW"
294+
framework_version = "2.9"
295+
nearest_model_name = "resnet50"
296+
data_input_configuration = '{"input_1":[1,224,224,3]}'
297+
265298
# The instance_type should not be a pipeline variable
266299
# since it is used to retrieve image_uri in compile time (PySDK)
267300
sklearn_processor = SKLearnProcessor(
@@ -412,6 +445,12 @@ def test_sklearn_xgboost_sip_model_registration(
412445
inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
413446
transform_instances=["ml.m5.xlarge"],
414447
model_package_group_name="windturbine",
448+
sample_payload_url=sample_payload_url,
449+
task=task,
450+
framework=framework,
451+
framework_version=framework_version,
452+
nearest_model_name=nearest_model_name,
453+
data_input_configuration=data_input_configuration,
415454
)
416455

417456
pipeline = Pipeline(
@@ -575,27 +614,8 @@ def test_model_registration_with_drift_check_baselines(
575614
role=role,
576615
)
577616

578-
base_dir = os.path.join(DATA_DIR, "mxnet_mnist")
579-
source_dir = os.path.join(base_dir, "code")
580-
entry_point = os.path.join(source_dir, "inference.py")
581-
mx_mnist_model_data = os.path.join(base_dir, "model.tar.gz")
582-
583-
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
584-
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
585-
586-
model = MXNetModel(
587-
entry_point=entry_point,
588-
source_dir=source_dir,
589-
role=role,
590-
model_data=mx_mnist_model_data,
591-
framework_version="1.4.0",
592-
py_version="py3",
593-
sagemaker_session=sagemaker_session,
594-
)
595-
596617
step_register = RegisterModel(
597618
name="MyRegisterModelStep",
598-
model=model,
599619
estimator=estimator,
600620
model_data=model_uri_param,
601621
content_types=["application/json"],
@@ -686,19 +706,6 @@ def test_model_registration_with_drift_check_baselines(
686706
assert response["Domain"] == domain
687707
assert response["Task"] == task
688708
assert response["SamplePayloadUrl"] == sample_payload_url
689-
assert response["InferenceSpecification"]["Containers"][0]["Framework"] == framework
690-
assert (
691-
response["InferenceSpecification"]["Containers"][0]["FrameworkVersion"]
692-
== framework_version
693-
)
694-
assert (
695-
response["InferenceSpecification"]["Containers"][0]["NearestModelName"]
696-
== nearest_model_name
697-
)
698-
assert (
699-
response["InferenceSpecification"]["Containers"][0]["ModelInput"]["DataInputConfig"]
700-
== data_input_configuration
701-
)
702709
break
703710
finally:
704711
try:

0 commit comments

Comments
 (0)