|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 | 16 | import logging
|
| 17 | +from typing import Optional, Union, List, Dict |
17 | 18 |
|
18 | 19 | import sagemaker
|
19 |
| -from sagemaker import image_uris |
| 20 | +from sagemaker import image_uris, ModelMetrics |
20 | 21 | from sagemaker.deserializers import JSONDeserializer
|
| 22 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
21 | 23 | from sagemaker.fw_utils import (
|
22 | 24 | model_code_key_prefix,
|
23 | 25 | validate_version_or_image_args,
|
24 | 26 | )
|
| 27 | +from sagemaker.metadata_properties import MetadataProperties |
25 | 28 | from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
|
26 | 29 | from sagemaker.predictor import Predictor
|
27 | 30 | from sagemaker.serializers import JSONSerializer
|
28 | 31 | from sagemaker.session import Session
|
| 32 | +from sagemaker.utils import to_string |
| 33 | +from sagemaker.workflow.entities import PipelineVariable |
29 | 34 |
|
30 | 35 | logger = logging.getLogger("sagemaker")
|
31 | 36 |
|
@@ -92,16 +97,16 @@ class HuggingFaceModel(FrameworkModel):
|
92 | 97 |
|
93 | 98 | def __init__(
|
94 | 99 | self,
|
95 |
| - role, |
96 |
| - model_data=None, |
97 |
| - entry_point=None, |
98 |
| - transformers_version=None, |
99 |
| - tensorflow_version=None, |
100 |
| - pytorch_version=None, |
101 |
| - py_version=None, |
102 |
| - image_uri=None, |
103 |
| - predictor_cls=HuggingFacePredictor, |
104 |
| - model_server_workers=None, |
| 100 | + role: str, |
| 101 | + model_data: Optional[Union[str, PipelineVariable]] = None, |
| 102 | + entry_point: Optional[str] = None, |
| 103 | + transformers_version: Optional[str] = None, |
| 104 | + tensorflow_version: Optional[str] = None, |
| 105 | + pytorch_version: Optional[str] = None, |
| 106 | + py_version: Optional[str] = None, |
| 107 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 108 | + predictor_cls: callable = HuggingFacePredictor, |
| 109 | + model_server_workers: Optional[Union[int, PipelineVariable]] = None, |
105 | 110 | **kwargs,
|
106 | 111 | ):
|
107 | 112 | """Initialize a HuggingFaceModel.
|
@@ -291,21 +296,21 @@ def deploy(
|
291 | 296 |
|
292 | 297 | def register(
|
293 | 298 | self,
|
294 |
| - content_types, |
295 |
| - response_types, |
296 |
| - inference_instances=None, |
297 |
| - transform_instances=None, |
298 |
| - model_package_name=None, |
299 |
| - model_package_group_name=None, |
300 |
| - image_uri=None, |
301 |
| - model_metrics=None, |
302 |
| - metadata_properties=None, |
303 |
| - marketplace_cert=False, |
304 |
| - approval_status=None, |
305 |
| - description=None, |
306 |
| - drift_check_baselines=None, |
307 |
| - customer_metadata_properties=None, |
308 |
| - domain=None, |
| 299 | + content_types: List[Union[str, PipelineVariable]], |
| 300 | + response_types: List[Union[str, PipelineVariable]], |
| 301 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 302 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 303 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 304 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 305 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 306 | + model_metrics: Optional[ModelMetrics] = None, |
| 307 | + metadata_properties: Optional[MetadataProperties] = None, |
| 308 | + marketplace_cert: bool = False, |
| 309 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 310 | + description: Optional[str] = None, |
| 311 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 312 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 313 | + domain: Optional[Union[str, PipelineVariable]] = None, |
309 | 314 | ):
|
310 | 315 | """Creates a model package for creating SageMaker models or listing on Marketplace.
|
311 | 316 |
|
@@ -409,7 +414,9 @@ def prepare_container_def(
|
409 | 414 | deploy_env.update(self._script_mode_env_vars())
|
410 | 415 |
|
411 | 416 | if self.model_server_workers:
|
412 |
| - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) |
| 417 | + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( |
| 418 | + self.model_server_workers |
| 419 | + ) |
413 | 420 | return sagemaker.container_def(
|
414 | 421 | deploy_image, self.repacked_model_data or self.model_data, deploy_env
|
415 | 422 | )
|
|
0 commit comments