Skip to content

Commit ff26090

Browse files
committed
change: add unit tests for de/ser overrides
1 parent 028b844 commit ff26090

14 files changed

+320
-1
lines changed

tests/unit/test_chainer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,29 @@ def test_model(sagemaker_session, chainer_version, chainer_py_version):
383383
assert isinstance(predictor, ChainerPredictor)
384384

385385

386+
@patch("sagemaker.utils.create_tar_file", MagicMock())
387+
def test_model_custom_serialization(sagemaker_session, chainer_version, chainer_py_version):
388+
model = ChainerModel(
389+
"s3://some/data.tar.gz",
390+
role=ROLE,
391+
entry_point=SCRIPT_PATH,
392+
sagemaker_session=sagemaker_session,
393+
framework_version=chainer_version,
394+
py_version=chainer_py_version,
395+
)
396+
custom_serializer = Mock()
397+
custom_deserializer = Mock()
398+
predictor = model.deploy(
399+
1,
400+
CPU,
401+
serializer=custom_serializer,
402+
deserializer=custom_deserializer,
403+
)
404+
assert isinstance(predictor, ChainerPredictor)
405+
assert predictor.serializer is custom_serializer
406+
assert predictor.deserializer is custom_deserializer
407+
408+
386409
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
387410
def test_model_prepare_container_def_accelerator_error(
388411
sagemaker_session, chainer_version, chainer_py_version

tests/unit/test_fm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,27 @@ def test_predictor_type(sagemaker_session):
330330
predictor = model.deploy(1, INSTANCE_TYPE)
331331

332332
assert isinstance(predictor, FactorizationMachinesPredictor)
333+
334+
335+
def test_predictor_custom_serialization(sagemaker_session):
336+
fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
337+
data = RecordSet(
338+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
339+
num_records=1,
340+
feature_dim=FEATURE_DIM,
341+
channel="train",
342+
)
343+
fm.fit(data, MINI_BATCH_SIZE)
344+
model = fm.create_model()
345+
custom_serializer = Mock()
346+
custom_deserializer = Mock()
347+
predictor = model.deploy(
348+
1,
349+
INSTANCE_TYPE,
350+
serializer=custom_serializer,
351+
deserializer=custom_deserializer,
352+
)
353+
354+
assert isinstance(predictor, FactorizationMachinesPredictor)
355+
assert predictor.serializer is custom_serializer
356+
assert predictor.deserializer is custom_deserializer

tests/unit/test_ipinsights.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,27 @@ def test_predictor_type(sagemaker_session):
305305
predictor = model.deploy(1, INSTANCE_TYPE)
306306

307307
assert isinstance(predictor, IPInsightsPredictor)
308+
309+
310+
def test_predictor_custom_serialization(sagemaker_session):
311+
ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
312+
data = RecordSet(
313+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
314+
num_records=1,
315+
feature_dim=FEATURE_DIM,
316+
channel="train",
317+
)
318+
ipinsights.fit(data, MINI_BATCH_SIZE)
319+
model = ipinsights.create_model()
320+
custom_serializer = Mock()
321+
custom_deserializer = Mock()
322+
predictor = model.deploy(
323+
1,
324+
INSTANCE_TYPE,
325+
serializer=custom_serializer,
326+
deserializer=custom_deserializer,
327+
)
328+
329+
assert isinstance(predictor, IPInsightsPredictor)
330+
assert predictor.serializer is custom_serializer
331+
assert predictor.deserializer is custom_deserializer

tests/unit/test_kmeans.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,27 @@ def test_predictor_type(sagemaker_session):
272272
predictor = model.deploy(1, INSTANCE_TYPE)
273273

274274
assert isinstance(predictor, KMeansPredictor)
275+
276+
277+
def test_predictor_custom_serialization(sagemaker_session):
278+
kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
279+
data = RecordSet(
280+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
281+
num_records=1,
282+
feature_dim=FEATURE_DIM,
283+
channel="train",
284+
)
285+
kmeans.fit(data, MINI_BATCH_SIZE)
286+
model = kmeans.create_model()
287+
custom_serializer = Mock()
288+
custom_deserializer = Mock()
289+
predictor = model.deploy(
290+
1,
291+
INSTANCE_TYPE,
292+
serializer=custom_serializer,
293+
deserializer=custom_deserializer,
294+
)
295+
296+
assert isinstance(predictor, KMeansPredictor)
297+
assert predictor.serializer is custom_serializer
298+
assert predictor.deserializer is custom_deserializer

tests/unit/test_knn.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,27 @@ def test_predictor_type(sagemaker_session):
296296
predictor = model.deploy(1, INSTANCE_TYPE)
297297

298298
assert isinstance(predictor, KNNPredictor)
299+
300+
301+
def test_predictor_custom_serialization(sagemaker_session):
302+
knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
303+
data = RecordSet(
304+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
305+
num_records=1,
306+
feature_dim=FEATURE_DIM,
307+
channel="train",
308+
)
309+
knn.fit(data, MINI_BATCH_SIZE)
310+
model = knn.create_model()
311+
custom_serializer = Mock()
312+
custom_deserializer = Mock()
313+
predictor = model.deploy(
314+
1,
315+
INSTANCE_TYPE,
316+
serializer=custom_serializer,
317+
deserializer=custom_deserializer,
318+
)
319+
320+
assert isinstance(predictor, KNNPredictor)
321+
assert predictor.serializer is custom_serializer
322+
assert predictor.deserializer is custom_deserializer

tests/unit/test_lda.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,27 @@ def test_predictor_type(sagemaker_session):
232232
predictor = model.deploy(1, INSTANCE_TYPE)
233233

234234
assert isinstance(predictor, LDAPredictor)
235+
236+
237+
def test_predictor_custom_serialization(sagemaker_session):
238+
lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
239+
data = RecordSet(
240+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
241+
num_records=1,
242+
feature_dim=FEATURE_DIM,
243+
channel="train",
244+
)
245+
lda.fit(data, MINI_BATCH_SZIE)
246+
model = lda.create_model()
247+
custom_serializer = Mock()
248+
custom_deserializer = Mock()
249+
predictor = model.deploy(
250+
1,
251+
INSTANCE_TYPE,
252+
serializer=custom_serializer,
253+
deserializer=custom_deserializer,
254+
)
255+
256+
assert isinstance(predictor, LDAPredictor)
257+
assert predictor.serializer is custom_serializer
258+
assert predictor.deserializer is custom_deserializer

tests/unit/test_linear_learner.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,27 @@ def test_predictor_type(sagemaker_session):
433433
predictor = model.deploy(1, INSTANCE_TYPE)
434434

435435
assert isinstance(predictor, LinearLearnerPredictor)
436+
437+
438+
def test_predictor_custom_serialization(sagemaker_session):
439+
lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
440+
data = RecordSet(
441+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
442+
num_records=1,
443+
feature_dim=FEATURE_DIM,
444+
channel="train",
445+
)
446+
lr.fit(data)
447+
model = lr.create_model()
448+
custom_serializer = Mock()
449+
custom_deserializer = Mock()
450+
predictor = model.deploy(
451+
1,
452+
INSTANCE_TYPE,
453+
serializer=custom_serializer,
454+
deserializer=custom_deserializer,
455+
)
456+
457+
assert isinstance(predictor, LinearLearnerPredictor)
458+
assert predictor.serializer is custom_serializer
459+
assert predictor.deserializer is custom_deserializer

tests/unit/test_mxnet.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,31 @@ def test_model(
427427
assert isinstance(predictor, MXNetPredictor)
428428

429429

430+
@patch("sagemaker.utils.create_tar_file", MagicMock())
431+
def test_model_custom_serialization(
432+
sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version
433+
):
434+
model = MXNetModel(
435+
MODEL_DATA,
436+
role=ROLE,
437+
entry_point=SCRIPT_PATH,
438+
framework_version=mxnet_inference_version,
439+
py_version=mxnet_inference_py_version,
440+
sagemaker_session=sagemaker_session,
441+
)
442+
custom_serializer = Mock()
443+
custom_deserializer = Mock()
444+
predictor = model.deploy(
445+
1,
446+
CPU,
447+
serializer=custom_serializer,
448+
deserializer=custom_deserializer,
449+
)
450+
assert isinstance(predictor, MXNetPredictor)
451+
assert predictor.serializer is custom_serializer
452+
assert predictor.deserializer is custom_deserializer
453+
454+
430455
@patch("sagemaker.utils.repack_model")
431456
def test_model_mms_version(
432457
repack_model,

tests/unit/test_ntm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,27 @@ def test_predictor_type(sagemaker_session):
301301
predictor = model.deploy(1, INSTANCE_TYPE)
302302

303303
assert isinstance(predictor, NTMPredictor)
304+
305+
306+
def test_predictor_custom_serialization(sagemaker_session):
307+
ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
308+
data = RecordSet(
309+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
310+
num_records=1,
311+
feature_dim=FEATURE_DIM,
312+
channel="train",
313+
)
314+
ntm.fit(data, MINI_BATCH_SIZE)
315+
model = ntm.create_model()
316+
custom_serializer = Mock()
317+
custom_deserializer = Mock()
318+
predictor = model.deploy(
319+
1,
320+
INSTANCE_TYPE,
321+
serializer=custom_serializer,
322+
deserializer=custom_deserializer,
323+
)
324+
325+
assert isinstance(predictor, NTMPredictor)
326+
assert predictor.serializer is custom_serializer
327+
assert predictor.deserializer is custom_deserializer

tests/unit/test_pca.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,27 @@ def test_predictor_type(sagemaker_session):
252252
predictor = model.deploy(1, INSTANCE_TYPE)
253253

254254
assert isinstance(predictor, PCAPredictor)
255+
256+
257+
def test_predictor_custom_serialization(sagemaker_session):
258+
pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
259+
data = RecordSet(
260+
"s3://{}/{}".format(BUCKET_NAME, PREFIX),
261+
num_records=1,
262+
feature_dim=FEATURE_DIM,
263+
channel="train",
264+
)
265+
pca.fit(data, MINI_BATCH_SIZE)
266+
model = pca.create_model()
267+
custom_serializer = Mock()
268+
custom_deserializer = Mock()
269+
predictor = model.deploy(
270+
1,
271+
INSTANCE_TYPE,
272+
serializer=custom_serializer,
273+
deserializer=custom_deserializer,
274+
)
275+
276+
assert isinstance(predictor, PCAPredictor)
277+
assert predictor.serializer is custom_serializer
278+
assert predictor.deserializer is custom_deserializer

tests/unit/test_pytorch.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,34 @@ def test_model_image_accelerator(sagemaker_session):
408408
assert "Unsupported Python version: py2." in str(error)
409409

410410

411+
@patch("sagemaker.utils.create_tar_file", MagicMock())
412+
@patch("sagemaker.utils.repack_model", MagicMock())
413+
def test_model_custom_serialization(
414+
sagemaker_session,
415+
pytorch_inference_version,
416+
pytorch_inference_py_version,
417+
):
418+
model = PyTorchModel(
419+
MODEL_DATA,
420+
role=ROLE,
421+
entry_point=SCRIPT_PATH,
422+
framework_version=pytorch_inference_version,
423+
py_version=pytorch_inference_py_version,
424+
sagemaker_session=sagemaker_session,
425+
)
426+
custom_serializer = Mock()
427+
custom_deserializer = Mock()
428+
predictor = model.deploy(
429+
1,
430+
GPU,
431+
serializer=custom_serializer,
432+
deserializer=custom_deserializer,
433+
)
434+
assert isinstance(predictor, PyTorchPredictor)
435+
assert predictor.serializer is custom_serializer
436+
assert predictor.deserializer is custom_deserializer
437+
438+
411439
def test_model_prepare_container_def_no_instance_type_or_image():
412440
model = PyTorchModel(
413441
MODEL_DATA,

tests/unit/test_sklearn.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from mock import Mock
2121
from mock import patch
2222

23-
from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor
2423
from sagemaker.fw_utils import UploadedCode
24+
from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor
2525

2626
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
2727
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -408,6 +408,27 @@ def test_model(sagemaker_session, sklearn_version):
408408
assert isinstance(predictor, SKLearnPredictor)
409409

410410

411+
def test_model_custom_serialization(sagemaker_session, sklearn_version):
412+
model = SKLearnModel(
413+
"s3://some/data.tar.gz",
414+
role=ROLE,
415+
entry_point=SCRIPT_PATH,
416+
framework_version=sklearn_version,
417+
sagemaker_session=sagemaker_session,
418+
)
419+
custom_serializer = Mock()
420+
custom_deserializer = Mock()
421+
predictor = model.deploy(
422+
1,
423+
CPU,
424+
serializer=custom_serializer,
425+
deserializer=custom_deserializer,
426+
)
427+
assert isinstance(predictor, SKLearnPredictor)
428+
assert predictor.serializer is custom_serializer
429+
assert predictor.deserializer is custom_deserializer
430+
431+
411432
def test_attach(sagemaker_session, sklearn_version):
412433
training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}".format(
413434
sklearn_version, PYTHON_VERSION

tests/unit/test_sparkml_serving.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,12 @@ def test_predictor_type(sagemaker_session):
5757
predictor = sparkml.deploy(1, TRAIN_INSTANCE_TYPE)
5858

5959
assert isinstance(predictor, SparkMLPredictor)
60+
61+
62+
def test_predictor_custom_serialization(sagemaker_session):
63+
sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE)
64+
custom_serializer = Mock()
65+
predictor = sparkml.deploy(1, TRAIN_INSTANCE_TYPE, serializer=custom_serializer)
66+
67+
assert isinstance(predictor, SparkMLPredictor)
68+
assert predictor.serializer is custom_serializer

tests/unit/test_xgboost.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,27 @@ def test_model(sagemaker_session, xgboost_framework_version):
429429
assert isinstance(predictor, XGBoostPredictor)
430430

431431

432+
def test_model_custom_serialization(sagemaker_session, xgboost_framework_version):
433+
model = XGBoostModel(
434+
"s3://some/data.tar.gz",
435+
role=ROLE,
436+
framework_version=xgboost_framework_version,
437+
entry_point=SCRIPT_PATH,
438+
sagemaker_session=sagemaker_session,
439+
)
440+
custom_serializer = Mock()
441+
custom_deserializer = Mock()
442+
predictor = model.deploy(
443+
1,
444+
CPU,
445+
serializer=custom_serializer,
446+
deserializer=custom_deserializer,
447+
)
448+
assert isinstance(predictor, XGBoostPredictor)
449+
assert predictor.serializer is custom_serializer
450+
assert predictor.deserializer is custom_deserializer
451+
452+
432453
def test_training_image_uri(sagemaker_session, xgboost_framework_version):
433454
xgboost = XGBoost(
434455
entry_point=SCRIPT_PATH,

0 commit comments

Comments
 (0)