|
17 | 17 | import subprocess
|
18 | 18 |
|
19 | 19 | import sagemaker
|
20 |
| -from sagemaker.model import FrameworkModel, ModelPackage |
| 20 | +from sagemaker.model import FrameworkModel, Model, ModelPackage |
21 | 21 | from sagemaker.predictor import RealTimePredictor
|
22 | 22 |
|
23 | 23 | import pytest
|
@@ -417,37 +417,100 @@ def test_model_enable_network_isolation(sagemaker_session):
|
417 | 417 | assert model.enable_network_isolation() is False
|
418 | 418 |
|
419 | 419 |
|
420 |
| -@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) |
421 |
| -def test_model_create_transformer(sagemaker_session): |
422 |
| - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
423 |
| - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE |
424 |
| - ) |
| 420 | +@patch("sagemaker.model.Model._create_sagemaker_model") |
| 421 | +def test_model_create_transformer(create_sagemaker_model, sagemaker_session): |
| 422 | + model_name = "auto-generated-model" |
| 423 | + model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) |
| 424 | + |
| 425 | + instance_type = "ml.m4.xlarge" |
| 426 | + transformer = model.transformer(instance_count=1, instance_type=instance_type) |
| 427 | + |
| 428 | + create_sagemaker_model.assert_called_with(instance_type, tags=None) |
| 429 | + |
| 430 | + assert isinstance(transformer, sagemaker.transformer.Transformer) |
| 431 | + assert transformer.model_name == model_name |
| 432 | + assert transformer.instance_type == instance_type |
| 433 | + assert transformer.instance_count == 1 |
| 434 | + assert transformer.sagemaker_session == sagemaker_session |
| 435 | + assert transformer.base_transform_job_name == model_name |
| 436 | + |
| 437 | + assert transformer.strategy is None |
| 438 | + assert transformer.env is None |
| 439 | + assert transformer.output_path is None |
| 440 | + assert transformer.output_kms_key is None |
| 441 | + assert transformer.accept is None |
| 442 | + assert transformer.assemble_with is None |
| 443 | + assert transformer.volume_kms_key is None |
| 444 | + assert transformer.max_concurrent_transforms is None |
| 445 | + assert transformer.max_payload is None |
| 446 | + assert transformer.tags is None |
| 447 | + |
| 448 | + |
| 449 | +@patch("sagemaker.model.Model._create_sagemaker_model") |
| 450 | +def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): |
| 451 | + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) |
425 | 452 |
|
426 |
| - tags = [{"Key": "k", "Value": "v"}] |
427 |
| - model = DummyFrameworkModel(sagemaker_session=sagemaker_session) |
428 | 453 | instance_type = "ml.m4.xlarge"
|
429 |
| - model.name = "auto-generated-model" |
| 454 | + strategy = "MultiRecord" |
| 455 | + assemble_with = "Line" |
| 456 | + output_path = "s3://bucket/path" |
| 457 | + kms_key = "key" |
| 458 | + accept = "text/csv" |
| 459 | + env = {"test": True} |
| 460 | + max_concurrent_transforms = 1 |
| 461 | + max_payload = 6 |
| 462 | + tags = [{"Key": "k", "Value": "v"}] |
| 463 | + |
430 | 464 | transformer = model.transformer(
|
431 |
| - instance_count=1, instance_type=instance_type, env={"test": True}, tags=tags |
| 465 | + instance_count=1, |
| 466 | + instance_type=instance_type, |
| 467 | + strategy=strategy, |
| 468 | + assemble_with=assemble_with, |
| 469 | + output_path=output_path, |
| 470 | + output_kms_key=kms_key, |
| 471 | + accept=accept, |
| 472 | + env=env, |
| 473 | + max_concurrent_transforms=max_concurrent_transforms, |
| 474 | + max_payload=max_payload, |
| 475 | + tags=tags, |
| 476 | + volume_kms_key=kms_key, |
432 | 477 | )
|
| 478 | + |
| 479 | + create_sagemaker_model.assert_called_with(instance_type, tags=tags) |
| 480 | + |
433 | 481 | assert isinstance(transformer, sagemaker.transformer.Transformer)
|
434 |
| - assert transformer.model_name == "auto-generated-model" |
435 |
| - assert transformer.instance_type == "ml.m4.xlarge" |
436 |
| - assert transformer.env == {"test": True} |
| 482 | + assert transformer.strategy == strategy |
| 483 | + assert transformer.assemble_with == assemble_with |
| 484 | + assert transformer.output_path == output_path |
| 485 | + assert transformer.output_kms_key == kms_key |
| 486 | + assert transformer.accept == accept |
| 487 | + assert transformer.max_concurrent_transforms == max_concurrent_transforms |
| 488 | + assert transformer.max_payload == max_payload |
| 489 | + assert transformer.env == env |
| 490 | + assert transformer.tags == tags |
| 491 | + assert transformer.volume_kms_key == kms_key |
| 492 | + |
| 493 | + |
| 494 | +@patch("sagemaker.model.Model._create_sagemaker_model") |
| 495 | +def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): |
| 496 | + model = Model( |
| 497 | + MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True |
| 498 | + ) |
437 | 499 |
|
438 |
| - sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags) |
| 500 | + transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) |
| 501 | + assert transformer.env is None |
439 | 502 |
|
440 | 503 |
|
441 | 504 | @patch("sagemaker.session.Session")
|
442 | 505 | @patch("sagemaker.local.LocalSession")
|
443 | 506 | @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
|
444 | 507 | def test_transformer_creates_correct_session(local_session, session):
|
445 |
| - model = DummyFrameworkModel(sagemaker_session=None) |
| 508 | + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
446 | 509 | transformer = model.transformer(instance_count=1, instance_type="local")
|
447 | 510 | assert model.sagemaker_session == local_session.return_value
|
448 | 511 | assert transformer.sagemaker_session == local_session.return_value
|
449 | 512 |
|
450 |
| - model = DummyFrameworkModel(sagemaker_session=None) |
| 513 | + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
451 | 514 | transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
|
452 | 515 | assert model.sagemaker_session == session.return_value
|
453 | 516 | assert transformer.sagemaker_session == session.return_value
|
|
0 commit comments