|
48 | 48 | IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
|
49 | 49 |
|
50 | 50 |
|
| 51 | +MODEL_DESCRIPTION = "a description" |
| 52 | + |
| 53 | +SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"] |
| 54 | +SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES = ["ml.m4.xlarge"] |
| 55 | + |
| 56 | +SUPPORTED_CONTENT_TYPES = ["text/csv", "application/json", "application/jsonlines"] |
| 57 | +SUPPORTED_RESPONSE_MIME_TYPES = ["application/json", "text/csv", "application/jsonlines"] |
| 58 | + |
| 59 | +VALIDATION_FILE_NAME = "input.csv" |
| 60 | +VALIDATION_INPUT_PATH = "s3://" + BUCKET_NAME + "/validation-input-csv/" |
| 61 | +VALIDATION_OUTPUT_PATH = "s3://" + BUCKET_NAME + "/validation-output-csv/" |
| 62 | + |
| 63 | +VALIDATION_SPECIFICATION = { |
| 64 | + "ValidationRole": "some_role", |
| 65 | + "ValidationProfiles": [ |
| 66 | + { |
| 67 | + "ProfileName": "Validation-test", |
| 68 | + "TransformJobDefinition": { |
| 69 | + "BatchStrategy": "SingleRecord", |
| 70 | + "TransformInput": { |
| 71 | + "DataSource": { |
| 72 | + "S3DataSource": { |
| 73 | + "S3DataType": "S3Prefix", |
| 74 | + "S3Uri": VALIDATION_INPUT_PATH, |
| 75 | + } |
| 76 | + }, |
| 77 | + "ContentType": SUPPORTED_CONTENT_TYPES[0], |
| 78 | + }, |
| 79 | + "TransformOutput": { |
| 80 | + "S3OutputPath": VALIDATION_OUTPUT_PATH, |
| 81 | + }, |
| 82 | + "TransformResources": { |
| 83 | + "InstanceType": SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES[0], |
| 84 | + "InstanceCount": 1, |
| 85 | + }, |
| 86 | + }, |
| 87 | + }, |
| 88 | + ], |
| 89 | +} |
| 90 | + |
| 91 | + |
51 | 92 | class DummyFrameworkModel(FrameworkModel):
|
52 | 93 | def __init__(self, **kwargs):
|
53 | 94 | super(DummyFrameworkModel, self).__init__(
|
@@ -687,3 +728,40 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
|
687 | 728 | ]
|
688 | 729 | == "/opt/ml/model/code"
|
689 | 730 | )
|
| 731 | + |
| 732 | + |
| 733 | +@patch("sagemaker.get_model_package_args") |
| 734 | +def test_register_calls_model_package_args(get_model_package_args, sagemaker_session): |
| 735 | + |
| 736 | + source_dir = "s3://blah/blah/blah" |
| 737 | + t = Model( |
| 738 | + entry_point=ENTRY_POINT_INFERENCE, |
| 739 | + role=ROLE, |
| 740 | + sagemaker_session=sagemaker_session, |
| 741 | + source_dir=source_dir, |
| 742 | + image_uri=IMAGE_URI, |
| 743 | + model_data=MODEL_DATA, |
| 744 | + ) |
| 745 | + |
| 746 | + t.register( |
| 747 | + SUPPORTED_CONTENT_TYPES, |
| 748 | + SUPPORTED_RESPONSE_MIME_TYPES, |
| 749 | + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, |
| 750 | + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, |
| 751 | + marketplace_cert=True, |
| 752 | + description=MODEL_DESCRIPTION, |
| 753 | + model_package_name=MODEL_NAME, |
| 754 | + validation_specification=VALIDATION_SPECIFICATION, |
| 755 | + ) |
| 756 | + |
| 757 | + # check that the kwarg validation_specification was passed to the internal method 'get_model_package_args' |
| 758 | + assert ( |
| 759 | + "validation_specification" in get_model_package_args.call_args_list[0][1] |
| 760 | + ), "validation_specification kwarg was not passed to get_model_package_args" |
| 761 | + |
| 762 | + # check that the kwarg validation_specification is identical to the one passed into the method 'register' |
| 763 | + assert ( |
| 764 | + VALIDATION_SPECIFICATION |
| 765 | + == get_model_package_args.call_args_list[0][1]["validation_specification"] |
| 766 | + ), """ValidationSpecification from model.register method is not identical to validation_spec from |
| 767 | + get_model_package_args""" |
0 commit comments