Skip to content

Commit cf915c8

Browse files
BasilBeiroutiBasil Beirouti
authored andcommitted
test: Vspecinteg2 (aws#3249)
Co-authored-by: Basil Beirouti <[email protected]>
1 parent 67ff491 commit cf915c8

File tree

8 files changed

+284
-3
lines changed

8 files changed

+284
-3
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
FROM public.ecr.aws/ubuntu/ubuntu:18.04
2+
3+
# Specify encoding
4+
ENV LC_ALL=C.UTF-8
5+
ENV LANG=C.UTF-8
6+
7+
# Install python-pip
8+
RUN apt-get update \
9+
&& apt-get install -y python3.6 python3-pip \
10+
&& ln -s /usr/bin/python3.6 /usr/bin/python \
11+
&& ln -s /usr/bin/pip3 /usr/bin/pip;
12+
13+
# Install flask server
14+
RUN pip install -U flask gunicorn joblib sklearn;
15+
16+
#Copy scoring logic and model artifacts into the docker image
17+
COPY scoring_logic.py /scoring_logic.py
18+
COPY wsgi.py /wsgi.py
19+
COPY model-artifacts.joblib /opt/ml/model/model-artifacts.joblib
20+
COPY serve /opt/program/serve
21+
22+
RUN chmod 755 /opt/program/serve
23+
ENV PATH=/opt/program:${PATH}
Binary file not shown.
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from enum import IntEnum
2+
import json
3+
import logging
4+
import re
5+
from flask import Flask
6+
from flask import request
7+
from joblib import dump, load
8+
import numpy as np
9+
import os
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class IrisLabel(IntEnum):
15+
setosa = 0
16+
versicolor = 1
17+
virginica = 2
18+
19+
20+
class IrisModel:
21+
LABELS = IrisLabel
22+
NUM_FEATURES = 4
23+
24+
def __init__(self, model_path):
25+
self.model_path = model_path
26+
self._model = None
27+
28+
# Cache the model to prevent repeatedly loading it for every request
29+
@property
30+
def model(self):
31+
if self._model is None:
32+
self._model = load(self.model_path)
33+
return self._model
34+
35+
def predict_from_csv(self, lines, **kwargs):
36+
data = np.genfromtxt(lines.split("\n"), delimiter=",")
37+
return self.predict(data, **kwargs)
38+
39+
def predict_from_json(self, obj, **kwargs):
40+
req = json.loads(obj)
41+
instances = req["instances"]
42+
x = np.array([instance["features"] for instance in instances])
43+
return self.predict(x, **kwargs)
44+
45+
def predict_from_jsonlines(self, obj, **kwargs):
46+
x = np.array([json.loads(line)["features"] for line in obj.split("\n")])
47+
return self.predict(x, **kwargs)
48+
49+
def predict(self, x, return_names=True):
50+
label_codes = self.model.predict(x.reshape(-1, IrisModel.NUM_FEATURES))
51+
52+
if return_names:
53+
predictions = [IrisModel.LABELS(code).name for code in label_codes]
54+
else:
55+
predictions = label_codes.tolist()
56+
57+
return predictions
58+
59+
60+
SUPPORTED_REQUEST_MIMETYPES = ["text/csv", "application/json", "application/jsonlines"]
61+
SUPPORTED_RESPONSE_MIMETYPES = ["application/json", "application/jsonlines", "text/csv"]
62+
63+
app = Flask(__name__)
64+
model = IrisModel(model_path="/opt/ml/model/model-artifacts.joblib")
65+
66+
# Create a path for health checks
67+
@app.route("/ping")
68+
def endpoint_ping():
69+
return ""
70+
71+
72+
# Create a path for inference
73+
@app.route("/invocations", methods=["POST"])
74+
def endpoint_invocations():
75+
try:
76+
logger.info(f"Processing request: {request.headers}")
77+
logger.debug(f"Payload: {request.headers}")
78+
79+
if request.content_type not in SUPPORTED_REQUEST_MIMETYPES:
80+
logger.error(f"Unsupported Content-Type specified: {request.content_type}")
81+
return f"Invalid Content-Type. Supported Content-Types: {', '.join(SUPPORTED_REQUEST_MIMETYPES)}"
82+
elif request.content_type == "text/csv":
83+
# Step 1: Decode payload into input format expected by model
84+
data = request.get_data().decode("utf8")
85+
# Step 2: Perform inference with the loaded model
86+
predictions = model.predict_from_csv(data)
87+
elif request.content_type == "application/json":
88+
data = request.get_data().decode("utf8")
89+
predictions = model.predict_from_json(data)
90+
elif request.content_type == "application/jsonlines":
91+
data = request.get_data().decode("utf8")
92+
predictions = model.predict_from_jsonlines(data)
93+
94+
# Step 3: Process predictions into the specified response type (if specified)
95+
response_mimetype = request.accept_mimetypes.best_match(
96+
SUPPORTED_RESPONSE_MIMETYPES, default="application/json"
97+
)
98+
99+
if response_mimetype == "text/csv":
100+
response = "\n".join(predictions)
101+
elif response_mimetype == "application/jsonlines":
102+
response = "\n".join([json.dumps({"class": pred}) for pred in predictions])
103+
elif response_mimetype == "application/json":
104+
response = json.dumps({"predictions": [{"class": pred} for pred in predictions]})
105+
106+
return response
107+
except Exception as e:
108+
return f"Error during model invocation: {str(e)} for input: {request.get_data()}"

tests/data/marketplace/iris/serve

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
3+
ls -lah /opt/ml/model
4+
5+
# Run gunicorn server on port 8080 for SageMaker
6+
gunicorn --worker-tmp-dir /dev/shm --bind 0.0.0.0:8080 wsgi:app

tests/data/marketplace/iris/wsgi.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from scoring_logic import app
2+
3+
if __name__ == "__main__":
4+
app.run()

tests/integ/test_marketplace.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,31 @@
1515
import itertools
1616
import os
1717
import time
18+
import requests
1819

1920
import pandas
2021
import pytest
22+
import docker
2123

2224
import sagemaker
2325
import tests.integ
24-
from sagemaker import AlgorithmEstimator, ModelPackage
26+
from sagemaker import AlgorithmEstimator, ModelPackage, Model
2527
from sagemaker.serializers import CSVSerializer
2628
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
2729
from sagemaker.utils import sagemaker_timestamp, _aws_partition, unique_name_from_base
2830
from tests.integ import DATA_DIR
2931
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
3032
from tests.integ.marketplace_utils import REGION_ACCOUNT_MAP
33+
from tests.integ.test_multidatamodel import (
34+
_ecr_image_uri,
35+
_ecr_login,
36+
_create_repository,
37+
_delete_repository,
38+
)
39+
from tests.integ.retry import retries
40+
import logging
3141

42+
logger = logging.getLogger(__name__)
3243

3344
# All these tests require a manual 1 time subscription to the following Marketplace items:
3445
# Algorithm: Scikit Decision Trees
@@ -186,6 +197,135 @@ def predict_wrapper(endpoint, session):
186197
print(predictor.predict(test_x.values).decode("utf-8"))
187198

188199

200+
@pytest.fixture(scope="module")
201+
def iris_image(sagemaker_session):
202+
algorithm_name = unique_name_from_base("iris-classifier")
203+
ecr_image = _ecr_image_uri(sagemaker_session, algorithm_name)
204+
ecr_client = sagemaker_session.boto_session.client("ecr")
205+
username, password = _ecr_login(ecr_client)
206+
207+
docker_client = docker.from_env()
208+
209+
# Build and tag docker image locally
210+
path = os.path.join(DATA_DIR, "marketplace", "iris")
211+
image, build_logs = docker_client.images.build(
212+
path=path,
213+
tag=algorithm_name,
214+
rm=True,
215+
)
216+
image.tag(ecr_image, tag="latest")
217+
_create_repository(ecr_client, algorithm_name)
218+
219+
# Retry docker image push
220+
for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10):
221+
try:
222+
docker_client.images.push(
223+
ecr_image, auth_config={"username": username, "password": password}
224+
)
225+
break
226+
except requests.exceptions.ConnectionError:
227+
# This can happen when we try to create multiple repositories in parallel, so we retry
228+
pass
229+
230+
yield ecr_image
231+
232+
# Delete repository after the marketplace integration tests complete
233+
_delete_repository(ecr_client, algorithm_name)
234+
235+
236+
def test_create_model_package(sagemaker_session, boto_session, iris_image):
237+
MODEL_NAME = "iris-classifier-mp"
238+
# Prepare
239+
s3_bucket = sagemaker_session.default_bucket()
240+
241+
model_name = unique_name_from_base(MODEL_NAME)
242+
model_description = "This model accepts petal length, petal width, sepal length, sepal width and predicts whether \
243+
flower is of type setosa, versicolor, or virginica"
244+
245+
supported_realtime_inference_instance_types = supported_batch_transform_instance_types = [
246+
"ml.m4.xlarge"
247+
]
248+
supported_content_types = ["text/csv", "application/json", "application/jsonlines"]
249+
supported_response_MIME_types = ["application/json", "text/csv", "application/jsonlines"]
250+
251+
validation_input_path = "s3://" + s3_bucket + "/validation-input-csv/"
252+
validation_output_path = "s3://" + s3_bucket + "/validation-output-csv/"
253+
254+
iam = boto_session.resource("iam")
255+
role = iam.Role("SageMakerRole").arn
256+
sm_client = boto_session.client("sagemaker")
257+
s3_client = boto_session.client("s3")
258+
s3_client.put_object(
259+
Bucket=s3_bucket, Key="validation-input-csv/input.csv", Body="5.1, 3.5, 1.4, 0.2"
260+
)
261+
262+
ValidationSpecification = {
263+
"ValidationRole": role,
264+
"ValidationProfiles": [
265+
{
266+
"ProfileName": "Validation-test",
267+
"TransformJobDefinition": {
268+
"BatchStrategy": "SingleRecord",
269+
"TransformInput": {
270+
"DataSource": {
271+
"S3DataSource": {
272+
"S3DataType": "S3Prefix",
273+
"S3Uri": validation_input_path,
274+
}
275+
},
276+
"ContentType": supported_content_types[0],
277+
},
278+
"TransformOutput": {
279+
"S3OutputPath": validation_output_path,
280+
},
281+
"TransformResources": {
282+
"InstanceType": supported_batch_transform_instance_types[0],
283+
"InstanceCount": 1,
284+
},
285+
},
286+
},
287+
],
288+
}
289+
290+
# get pre-existing model artifact stored in ECR
291+
model = Model(
292+
image_uri=iris_image,
293+
model_data=validation_input_path + "input.csv",
294+
role=role,
295+
sagemaker_session=sagemaker_session,
296+
enable_network_isolation=False,
297+
)
298+
299+
# Call model.register() - the method under test - to create a model package
300+
model.register(
301+
supported_content_types,
302+
supported_response_MIME_types,
303+
supported_realtime_inference_instance_types,
304+
supported_batch_transform_instance_types,
305+
marketplace_cert=True,
306+
description=model_description,
307+
model_package_name=model_name,
308+
validation_specification=ValidationSpecification,
309+
)
310+
311+
# wait for model execution to complete
312+
time.sleep(60 * 3)
313+
314+
# query for all model packages with the name <MODEL_NAME>
315+
response = sm_client.list_model_packages(
316+
MaxResults=10,
317+
NameContains=MODEL_NAME,
318+
SortBy="CreationTime",
319+
SortOrder="Descending",
320+
)
321+
322+
if len(response["ModelPackageSummaryList"]) > 0:
323+
sm_client.delete_model_package(ModelPackageName=model_name)
324+
325+
# assert that response is non-empty
326+
assert len(response["ModelPackageSummaryList"]) > 0
327+
328+
189329
@pytest.mark.skipif(
190330
tests.integ.test_region() in tests.integ.NO_MARKET_PLACE_REGIONS,
191331
reason="Marketplace is not available in {}".format(tests.integ.test_region()),

tests/integ/test_multidatamodel.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os
1717
import requests
1818

19-
import botocore
2019
import docker
2120
import numpy
2221
import pytest
@@ -116,7 +115,7 @@ def _delete_repository(ecr_client, repository_name):
116115
try:
117116
ecr_client.describe_repositories(repositoryNames=[repository_name])
118117
ecr_client.delete_repository(repositoryName=repository_name, force=True)
119-
except botocore.errorfactory.ResourceNotFoundException:
118+
except ecr_client.exceptions.RepositoryNotFoundException:
120119
pass
121120

122121

tests/unit/sagemaker/model/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,7 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
742742

743743
@patch("sagemaker.get_model_package_args")
744744
def test_register_calls_model_package_args(get_model_package_args, sagemaker_session):
745+
"""model.register() should pass the ValidationSpecification to get_model_package_args()"""
745746

746747
source_dir = "s3://blah/blah/blah"
747748
t = Model(

0 commit comments

Comments
 (0)