Skip to content

Commit 65c4732

Browse files
nargokul and pintaoz-aws
authored and committed
Integration tests for Model Builder Handshake (#1610)
* Integration tests for Model Builder Handshake
* Codestyle checks
1 parent c96ed13 commit 65c4732

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,184 @@
1+
from __future__ import absolute_import
2+
3+
import unittest
4+
import os
5+
import uuid
6+
7+
import numpy as np
8+
import pandas as pd
9+
from sagemaker_core.main.resources import TrainingJob
10+
from xgboost import XGBClassifier
11+
12+
from sagemaker.serve import ModelBuilder, SchemaBuilder
13+
from sagemaker.serve.spec.inference_spec import InferenceSpec
14+
from sagemaker_core.main.shapes import OutputDataConfig, StoppingCondition, Channel, DataSource, \
15+
S3DataSource, AlgorithmSpecification, ResourceConfig
16+
from sklearn.datasets import load_iris
17+
from sklearn.model_selection import train_test_split
18+
19+
from sagemaker import Session, get_execution_role, image_uris
20+
from sagemaker.modules.train import ModelTrainer
21+
22+
# S3 key prefix under which the test data and training output are stored.
prefix = "DEMO-scikit-iris"
# File names of the train/test CSV splits written into DATA_DIRECTORY.
TRAIN_DATA = "train.csv"
TEST_DATA = "test.csv"
# Local directory (relative to the working directory) holding the CSV files;
# also used as the last component of the S3 key prefix for the upload.
DATA_DIRECTORY = "data"
26+
27+
28+
class XGBoostSpec(InferenceSpec):
    """Inference spec that loads an XGBoost classifier and predicts class indices."""

    def load(self, model_dir: str):
        """Load the trained classifier from ``model_dir``.

        Args:
            model_dir: Directory that contains the ``xgboost-model`` artifact.

        Returns:
            An ``XGBClassifier`` with the stored model loaded into it.
        """
        print(model_dir)
        model = XGBClassifier()
        # Build the artifact path with os.path.join rather than string
        # concatenation so a trailing separator in model_dir is harmless.
        model.load_model(os.path.join(model_dir, "xgboost-model"))
        return model

    def invoke(self, input_object: object, model: object):
        """Return the most probable class index for each row of ``input_object``.

        Args:
            input_object: Feature rows passed to ``model.predict_proba``.
            model: The object returned by :meth:`load`.

        Returns:
            The argmax over the per-class probabilities, taken along axis 1.
        """
        prediction_probabilities = model.predict_proba(input_object)
        return np.argmax(prediction_probabilities, axis=1)
39+
40+
41+
class TestModelBuilderHandshake(unittest.TestCase):
    """Integration tests for handing training results directly to ``ModelBuilder``.

    Each test trains an XGBoost model on the iris dataset and then checks that
    the model built by ``ModelBuilder`` points at the S3 artifacts produced by
    that training job. The tests call live SageMaker/S3 APIs.
    """

    def setUp(self):
        """Create the session, resolve role/region/bucket and stage the data."""
        self.sagemaker_session = Session()
        self.role = get_execution_role()
        self.region = self.sagemaker_session.boto_region_name
        self.bucket = self.sagemaker_session.default_bucket()
        self.setup_data()

    def setup_data(self):
        """Write the iris train/test CSVs locally, upload them and derive S3 paths.

        Also resolves the XGBoost training image and builds the schema builder
        shared by both tests.
        """
        self.iris = load_iris()
        self.iris_df = pd.DataFrame(self.iris.data, columns=self.iris.feature_names)
        self.iris_df['target'] = self.iris.target

        os.makedirs(DATA_DIRECTORY, exist_ok=True)

        # Move the label to the first column and write the CSVs without a
        # header row (the layout SageMaker's built-in XGBoost expects for CSV).
        feature_columns = [col for col in self.iris_df.columns if col != 'target']
        iris_df = self.iris_df[['target'] + feature_columns]

        self.train_data, self.test_data = train_test_split(
            iris_df, test_size=0.2, random_state=42
        )

        # Use the module constants so the files written here, the directory
        # uploaded below and the S3 URIs derived from them cannot drift apart.
        self.train_data.to_csv(
            os.path.join(DATA_DIRECTORY, TRAIN_DATA), index=False, header=False
        )
        self.test_data.to_csv(
            os.path.join(DATA_DIRECTORY, TEST_DATA), index=False, header=False
        )

        # Remove the target column from the testing data. We will use this to
        # call invoke_endpoint later
        self.test_data_no_target = self.test_data.drop('target', axis=1)

        self.train_input = self.sagemaker_session.upload_data(
            DATA_DIRECTORY,
            bucket=self.bucket,
            key_prefix="{}/{}".format(prefix, DATA_DIRECTORY),
        )

        self.s3_input_path = "s3://{}/{}/{}/{}".format(
            self.bucket, prefix, DATA_DIRECTORY, TRAIN_DATA
        )
        self.s3_output_path = "s3://{}/{}/output".format(self.bucket, prefix)
        self.s3_test_path = "s3://{}/{}/{}/{}".format(
            self.bucket, prefix, DATA_DIRECTORY, TEST_DATA
        )
        # Resolve the image in the session's own region instead of a
        # hard-coded "us-west-2", so the training job and its image match
        # wherever the tests run.
        self.xgboost_image = image_uris.retrieve(
            framework="xgboost", region=self.region, image_scope="training"
        )
        data = {
            'Name': ['Alice', 'Bob', 'Charlie']
        }
        df = pd.DataFrame(data)
        self.schema_builder = SchemaBuilder(sample_input=df, sample_output=df)

    @staticmethod
    def _hyperparameters():
        """Return the XGBoost hyperparameters shared by both training runs."""
        return {
            'objective': 'multi:softmax',
            'num_class': '3',
            'num_round': '10',
            'eval_metric': 'merror',
        }

    def _train_channel(self):
        """Return the ``train`` input channel pointing at the uploaded CSV."""
        return Channel(
            channel_name='train',
            content_type='csv',
            compression_type='None',
            record_wrapper_type='None',
            data_source=DataSource(
                s3_data_source=S3DataSource(
                    s3_data_type='S3Prefix',
                    s3_uri=self.s3_input_path,
                    s3_data_distribution_type='FullyReplicated',
                )
            ),
        )

    def test_model_trainer_handshake(self):
        """A ``ModelTrainer`` passed as ``model`` yields its job's artifacts."""
        model_trainer = ModelTrainer(
            base_job_name='test-mb-handshake',
            hyperparameters=self._hyperparameters(),
            training_image=self.xgboost_image,
            training_input_mode='File',
            role=self.role,
            output_data_config=OutputDataConfig(
                s3_output_path=self.s3_output_path
            ),
            stopping_condition=StoppingCondition(
                max_runtime_in_seconds=600
            ),
        )

        model_trainer.train(input_data_config=[self._train_channel()])

        model_builder = ModelBuilder(
            model=model_trainer,  # ModelTrainer object passed onto ModelBuilder directly
            role_arn=self.role,
            image_uri=self.xgboost_image,
            inference_spec=XGBoostSpec(),
            schema_builder=self.schema_builder,
            instance_type="ml.c6i.xlarge",
        )
        model = model_builder.build()
        self.assertEqual(
            model.model_data,
            model_trainer._latest_training_job.model_artifacts.s3_model_artifacts,
        )

    def test_sagemaker_core_handshake(self):
        """A sagemaker-core ``TrainingJob`` passed as ``model`` yields its artifacts."""
        training_job_name = str(uuid.uuid4())
        training_job = TrainingJob.create(
            training_job_name=training_job_name,
            hyper_parameters=self._hyperparameters(),
            algorithm_specification=AlgorithmSpecification(
                training_image=self.xgboost_image,
                training_input_mode='File'
            ),
            role_arn=self.role,
            input_data_config=[self._train_channel()],
            output_data_config=OutputDataConfig(
                s3_output_path=self.s3_output_path
            ),
            resource_config=ResourceConfig(
                instance_type='ml.m4.xlarge',
                instance_count=1,
                volume_size_in_gb=30
            ),
            stopping_condition=StoppingCondition(
                max_runtime_in_seconds=600
            ),
        )
        # Wait for the job before building the model from it.
        training_job.wait()

        model_builder = ModelBuilder(
            model=training_job,
            role_arn=self.role,
            inference_spec=XGBoostSpec(),
            image_uri=self.xgboost_image,
            schema_builder=self.schema_builder,
            instance_type="ml.c6i.xlarge",
        )
        model = model_builder.build()

        self.assertEqual(
            model.model_data, training_job.model_artifacts.s3_model_artifacts
        )

0 commit comments

Comments
 (0)