|
19 | 19 | import os
|
20 | 20 | import re
|
21 | 21 | import copy
|
22 |
| -from typing import List, Dict |
| 22 | +from typing import List, Dict, Optional, Union |
23 | 23 |
|
24 | 24 | import sagemaker
|
25 | 25 | from sagemaker import (
|
|
30 | 30 | utils,
|
31 | 31 | git_utils,
|
32 | 32 | )
|
| 33 | +from sagemaker.session import Session |
| 34 | +from sagemaker.model_metrics import ModelMetrics |
33 | 35 | from sagemaker.deprecations import removed_kwargs
|
| 36 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
| 37 | +from sagemaker.metadata_properties import MetadataProperties |
34 | 38 | from sagemaker.predictor import PredictorBase
|
35 | 39 | from sagemaker.serverless import ServerlessInferenceConfig
|
36 | 40 | from sagemaker.transformer import Transformer
|
37 | 41 | from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
|
38 |
| -from sagemaker.utils import unique_name_from_base |
| 42 | +from sagemaker.utils import unique_name_from_base, to_string |
39 | 43 | from sagemaker.async_inference import AsyncInferenceConfig
|
40 | 44 | from sagemaker.predictor_async import AsyncPredictor
|
41 | 45 | from sagemaker.workflow import is_pipeline_variable
|
| 46 | +from sagemaker.workflow.entities import PipelineVariable |
42 | 47 | from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
|
43 | 48 |
|
44 | 49 | LOGGER = logging.getLogger("sagemaker")
|
@@ -78,23 +83,23 @@ class Model(ModelBase):
|
78 | 83 |
|
79 | 84 | def __init__(
|
80 | 85 | self,
|
81 |
| - image_uri, |
82 |
| - model_data=None, |
83 |
| - role=None, |
84 |
| - predictor_cls=None, |
85 |
| - env=None, |
86 |
| - name=None, |
87 |
| - vpc_config=None, |
88 |
| - sagemaker_session=None, |
89 |
| - enable_network_isolation=False, |
90 |
| - model_kms_key=None, |
91 |
| - image_config=None, |
92 |
| - source_dir=None, |
93 |
| - code_location=None, |
94 |
| - entry_point=None, |
95 |
| - container_log_level=logging.INFO, |
96 |
| - dependencies=None, |
97 |
| - git_config=None, |
| 86 | + image_uri: Union[str, PipelineVariable], |
| 87 | + model_data: Optional[Union[str, PipelineVariable]] = None, |
| 88 | + role: Optional[str] = None, |
| 89 | + predictor_cls: Optional[callable] = None, |
| 90 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 91 | + name: Optional[str] = None, |
| 92 | + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, |
| 93 | + sagemaker_session: Optional[Session] = None, |
| 94 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 95 | + model_kms_key: Optional[str] = None, |
| 96 | + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 97 | + source_dir: Optional[str] = None, |
| 98 | + code_location: Optional[str] = None, |
| 99 | + entry_point: Optional[str] = None, |
| 100 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 101 | + dependencies: Optional[List[str]] = None, |
| 102 | + git_config: Optional[Dict[str, str]] = None, |
98 | 103 | ):
|
99 | 104 | """Initialize an SageMaker ``Model``.
|
100 | 105 |
|
@@ -294,22 +299,22 @@ def __init__(
|
294 | 299 | @runnable_by_pipeline
|
295 | 300 | def register(
|
296 | 301 | self,
|
297 |
| - content_types, |
298 |
| - response_types, |
299 |
| - inference_instances=None, |
300 |
| - transform_instances=None, |
301 |
| - model_package_name=None, |
302 |
| - model_package_group_name=None, |
303 |
| - image_uri=None, |
304 |
| - model_metrics=None, |
305 |
| - metadata_properties=None, |
306 |
| - marketplace_cert=False, |
307 |
| - approval_status=None, |
308 |
| - description=None, |
309 |
| - drift_check_baselines=None, |
310 |
| - customer_metadata_properties=None, |
311 |
| - validation_specification=None, |
312 |
| - domain=None, |
| 302 | + content_types: List[Union[str, PipelineVariable]], |
| 303 | + response_types: List[Union[str, PipelineVariable]], |
| 304 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 305 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 306 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 307 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 308 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 309 | + model_metrics: Optional[ModelMetrics] = None, |
| 310 | + metadata_properties: Optional[MetadataProperties] = None, |
| 311 | + marketplace_cert: bool = False, |
| 312 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 313 | + description: Optional[str] = None, |
| 314 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 315 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 316 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 317 | + domain: Optional[Union[str, PipelineVariable]] = None, |
313 | 318 | ):
|
314 | 319 | """Creates a model package for creating SageMaker models or listing on Marketplace.
|
315 | 320 |
|
@@ -385,10 +390,10 @@ def register(
|
385 | 390 | @runnable_by_pipeline
|
386 | 391 | def create(
|
387 | 392 | self,
|
388 |
| - instance_type: str = None, |
389 |
| - accelerator_type: str = None, |
390 |
| - serverless_inference_config: ServerlessInferenceConfig = None, |
391 |
| - tags: List[Dict[str, str]] = None, |
| 393 | + instance_type: Optional[str] = None, |
| 394 | + accelerator_type: Optional[str] = None, |
| 395 | + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, |
| 396 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
392 | 397 | ):
|
393 | 398 | """Create a SageMaker Model Entity
|
394 | 399 |
|
@@ -570,7 +575,7 @@ def _script_mode_env_vars(self):
|
570 | 575 | return {
|
571 | 576 | SCRIPT_PARAM_NAME.upper(): script_name or str(),
|
572 | 577 | DIR_PARAM_NAME.upper(): dir_name or str(),
|
573 |
| - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), |
| 578 | + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), |
574 | 579 | SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
|
575 | 580 | }
|
576 | 581 |
|
@@ -1239,19 +1244,19 @@ class FrameworkModel(Model):
|
1239 | 1244 |
|
1240 | 1245 | def __init__(
|
1241 | 1246 | self,
|
1242 |
| - model_data, |
1243 |
| - image_uri, |
1244 |
| - role, |
1245 |
| - entry_point, |
1246 |
| - source_dir=None, |
1247 |
| - predictor_cls=None, |
1248 |
| - env=None, |
1249 |
| - name=None, |
1250 |
| - container_log_level=logging.INFO, |
1251 |
| - code_location=None, |
1252 |
| - sagemaker_session=None, |
1253 |
| - dependencies=None, |
1254 |
| - git_config=None, |
| 1247 | + model_data: Union[str, PipelineVariable], |
| 1248 | + image_uri: Union[str, PipelineVariable], |
| 1249 | + role: str, |
| 1250 | + entry_point: str, |
| 1251 | + source_dir: Optional[str] = None, |
| 1252 | + predictor_cls: Optional[callable] = None, |
| 1253 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 1254 | + name: Optional[str] = None, |
| 1255 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 1256 | + code_location: Optional[str] = None, |
| 1257 | + sagemaker_session: Optional[Session] = None, |
| 1258 | + dependencies: Optional[List[str]] = None, |
| 1259 | + git_config: Optional[Dict[str, str]] = None, |
1255 | 1260 | **kwargs,
|
1256 | 1261 | ):
|
1257 | 1262 | """Initialize a ``FrameworkModel``.
|
|
0 commit comments