|
58 | 58 | mock_schema_builder = MagicMock()
|
59 | 59 | mock_schema_builder.sample_input = mock_sample_input
|
60 | 60 | mock_schema_builder.sample_output = mock_sample_output
|
| 61 | +MOCK_IMAGE_CONFIG = ( |
| 62 | + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" |
| 63 | + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" |
| 64 | +) |
61 | 65 |
|
62 | 66 |
|
63 | 67 | class TestTransformersBuilder(unittest.TestCase):
|
@@ -100,3 +104,43 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
|
100 | 104 |
|
101 | 105 | with self.assertRaises(ValueError) as _:
|
102 | 106 | model.deploy(mode=Mode.IN_PROCESS)
|
| 107 | + |
| 108 | + @patch( |
| 109 | + "sagemaker.serve.builder.transformers_builder._get_nb_instance", |
| 110 | + return_value="ml.g5.24xlarge", |
| 111 | + ) |
| 112 | + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) |
| 113 | + def test_image_uri( |
| 114 | + self, |
| 115 | + mock_get_nb_instance, |
| 116 | + mock_telemetry, |
| 117 | + ): |
| 118 | + builder = ModelBuilder( |
| 119 | + model=mock_model_id, |
| 120 | + schema_builder=mock_schema_builder, |
| 121 | + mode=Mode.LOCAL_CONTAINER, |
| 122 | + image_uri=MOCK_IMAGE_CONFIG, |
| 123 | + ) |
| 124 | + |
| 125 | + builder._prepare_for_mode = MagicMock() |
| 126 | + builder._prepare_for_mode.side_effect = None |
| 127 | + |
| 128 | + model = builder.build() |
| 129 | + builder.serve_settings.telemetry_opt_out = True |
| 130 | + |
| 131 | + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() |
| 132 | + predictor = model.deploy(model_data_download_timeout=1800) |
| 133 | + |
| 134 | + assert builder.image_uri == MOCK_IMAGE_CONFIG |
| 135 | + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" |
| 136 | + assert isinstance(predictor, TransformersLocalModePredictor) |
| 137 | + |
| 138 | + assert builder.nb_instance_type == "ml.g5.24xlarge" |
| 139 | + |
| 140 | + builder._original_deploy = MagicMock() |
| 141 | + builder._prepare_for_mode.return_value = (None, {}) |
| 142 | + predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") |
| 143 | + assert "HF_MODEL_ID" in model.env |
| 144 | + |
| 145 | + with self.assertRaises(ValueError) as _: |
| 146 | + model.deploy(mode=Mode.IN_PROCESS) |
0 commit comments