Skip to content

Commit ff0b9f4

Browse files
authored
infra: add Model unit tests for all transformer() params (#1415)
1 parent 78eddfd commit ff0b9f4

File tree

1 file changed

+79
-16
lines changed

1 file changed

+79
-16
lines changed

tests/unit/test_model.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import subprocess
1818

1919
import sagemaker
20-
from sagemaker.model import FrameworkModel, ModelPackage
20+
from sagemaker.model import FrameworkModel, Model, ModelPackage
2121
from sagemaker.predictor import RealTimePredictor
2222

2323
import pytest
@@ -417,37 +417,100 @@ def test_model_enable_network_isolation(sagemaker_session):
417417
assert model.enable_network_isolation() is False
418418

419419

420-
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
421-
def test_model_create_transformer(sagemaker_session):
422-
sagemaker_session.sagemaker_client.describe_model_package = Mock(
423-
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
424-
)
420+
@patch("sagemaker.model.Model._create_sagemaker_model")
421+
def test_model_create_transformer(create_sagemaker_model, sagemaker_session):
422+
model_name = "auto-generated-model"
423+
model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session)
424+
425+
instance_type = "ml.m4.xlarge"
426+
transformer = model.transformer(instance_count=1, instance_type=instance_type)
427+
428+
create_sagemaker_model.assert_called_with(instance_type, tags=None)
429+
430+
assert isinstance(transformer, sagemaker.transformer.Transformer)
431+
assert transformer.model_name == model_name
432+
assert transformer.instance_type == instance_type
433+
assert transformer.instance_count == 1
434+
assert transformer.sagemaker_session == sagemaker_session
435+
assert transformer.base_transform_job_name == model_name
436+
437+
assert transformer.strategy is None
438+
assert transformer.env is None
439+
assert transformer.output_path is None
440+
assert transformer.output_kms_key is None
441+
assert transformer.accept is None
442+
assert transformer.assemble_with is None
443+
assert transformer.volume_kms_key is None
444+
assert transformer.max_concurrent_transforms is None
445+
assert transformer.max_payload is None
446+
assert transformer.tags is None
447+
448+
449+
@patch("sagemaker.model.Model._create_sagemaker_model")
450+
def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session):
451+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
425452

426-
tags = [{"Key": "k", "Value": "v"}]
427-
model = DummyFrameworkModel(sagemaker_session=sagemaker_session)
428453
instance_type = "ml.m4.xlarge"
429-
model.name = "auto-generated-model"
454+
strategy = "MultiRecord"
455+
assemble_with = "Line"
456+
output_path = "s3://bucket/path"
457+
kms_key = "key"
458+
accept = "text/csv"
459+
env = {"test": True}
460+
max_concurrent_transforms = 1
461+
max_payload = 6
462+
tags = [{"Key": "k", "Value": "v"}]
463+
430464
transformer = model.transformer(
431-
instance_count=1, instance_type=instance_type, env={"test": True}, tags=tags
465+
instance_count=1,
466+
instance_type=instance_type,
467+
strategy=strategy,
468+
assemble_with=assemble_with,
469+
output_path=output_path,
470+
output_kms_key=kms_key,
471+
accept=accept,
472+
env=env,
473+
max_concurrent_transforms=max_concurrent_transforms,
474+
max_payload=max_payload,
475+
tags=tags,
476+
volume_kms_key=kms_key,
432477
)
478+
479+
create_sagemaker_model.assert_called_with(instance_type, tags=tags)
480+
433481
assert isinstance(transformer, sagemaker.transformer.Transformer)
434-
assert transformer.model_name == "auto-generated-model"
435-
assert transformer.instance_type == "ml.m4.xlarge"
436-
assert transformer.env == {"test": True}
482+
assert transformer.strategy == strategy
483+
assert transformer.assemble_with == assemble_with
484+
assert transformer.output_path == output_path
485+
assert transformer.output_kms_key == kms_key
486+
assert transformer.accept == accept
487+
assert transformer.max_concurrent_transforms == max_concurrent_transforms
488+
assert transformer.max_payload == max_payload
489+
assert transformer.env == env
490+
assert transformer.tags == tags
491+
assert transformer.volume_kms_key == kms_key
492+
493+
494+
@patch("sagemaker.model.Model._create_sagemaker_model")
495+
def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session):
496+
model = Model(
497+
MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True
498+
)
437499

438-
sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags)
500+
transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"})
501+
assert transformer.env is None
439502

440503

441504
@patch("sagemaker.session.Session")
442505
@patch("sagemaker.local.LocalSession")
443506
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
444507
def test_transformer_creates_correct_session(local_session, session):
445-
model = DummyFrameworkModel(sagemaker_session=None)
508+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
446509
transformer = model.transformer(instance_count=1, instance_type="local")
447510
assert model.sagemaker_session == local_session.return_value
448511
assert transformer.sagemaker_session == local_session.return_value
449512

450-
model = DummyFrameworkModel(sagemaker_session=None)
513+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
451514
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
452515
assert model.sagemaker_session == session.return_value
453516
assert transformer.sagemaker_session == session.return_value

0 commit comments

Comments
 (0)