|
18 | 18 | import logging
|
19 | 19 | import os
|
20 | 20 | import copy
|
21 |
| -from typing import List, Dict |
| 21 | +from typing import List, Dict, Optional, Union |
22 | 22 |
|
23 | 23 | import sagemaker
|
24 | 24 | from sagemaker import (
|
|
29 | 29 | utils,
|
30 | 30 | git_utils,
|
31 | 31 | )
|
| 32 | +from sagemaker.session import Session |
| 33 | +from sagemaker.model_metrics import ModelMetrics |
32 | 34 | from sagemaker.deprecations import removed_kwargs
|
| 35 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
| 36 | +from sagemaker.metadata_properties import MetadataProperties |
33 | 37 | from sagemaker.predictor import PredictorBase
|
34 | 38 | from sagemaker.serverless import ServerlessInferenceConfig
|
35 | 39 | from sagemaker.transformer import Transformer
|
36 | 40 | from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
|
37 | 41 | from sagemaker.utils import (
|
38 | 42 | unique_name_from_base,
|
39 | 43 | update_container_with_inference_params,
|
| 44 | + to_string, |
40 | 45 | )
|
41 | 46 | from sagemaker.async_inference import AsyncInferenceConfig
|
42 | 47 | from sagemaker.predictor_async import AsyncPredictor
|
43 | 48 | from sagemaker.workflow import is_pipeline_variable
|
| 49 | +from sagemaker.workflow.entities import PipelineVariable |
44 | 50 | from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
|
45 | 51 |
|
46 | 52 | LOGGER = logging.getLogger("sagemaker")
|
@@ -82,23 +88,23 @@ class Model(ModelBase):
|
82 | 88 |
|
83 | 89 | def __init__(
|
84 | 90 | self,
|
85 |
| - image_uri, |
86 |
| - model_data=None, |
87 |
| - role=None, |
88 |
| - predictor_cls=None, |
89 |
| - env=None, |
90 |
| - name=None, |
91 |
| - vpc_config=None, |
92 |
| - sagemaker_session=None, |
93 |
| - enable_network_isolation=False, |
94 |
| - model_kms_key=None, |
95 |
| - image_config=None, |
96 |
| - source_dir=None, |
97 |
| - code_location=None, |
98 |
| - entry_point=None, |
99 |
| - container_log_level=logging.INFO, |
100 |
| - dependencies=None, |
101 |
| - git_config=None, |
| 91 | + image_uri: Union[str, PipelineVariable], |
| 92 | + model_data: Optional[Union[str, PipelineVariable]] = None, |
| 93 | + role: Optional[str] = None, |
| 94 | + predictor_cls: Optional[callable] = None, |
| 95 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 96 | + name: Optional[str] = None, |
| 97 | + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, |
| 98 | + sagemaker_session: Optional[Session] = None, |
| 99 | + enable_network_isolation: Union[bool, PipelineVariable] = False, |
| 100 | + model_kms_key: Optional[str] = None, |
| 101 | + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 102 | + source_dir: Optional[str] = None, |
| 103 | + code_location: Optional[str] = None, |
| 104 | + entry_point: Optional[str] = None, |
| 105 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 106 | + dependencies: Optional[List[str]] = None, |
| 107 | + git_config: Optional[Dict[str, str]] = None, |
102 | 108 | ):
|
103 | 109 | """Initialize an SageMaker ``Model``.
|
104 | 110 |
|
@@ -298,28 +304,28 @@ def __init__(
|
298 | 304 | @runnable_by_pipeline
|
299 | 305 | def register(
|
300 | 306 | self,
|
301 |
| - content_types, |
302 |
| - response_types, |
303 |
| - inference_instances=None, |
304 |
| - transform_instances=None, |
305 |
| - model_package_name=None, |
306 |
| - model_package_group_name=None, |
307 |
| - image_uri=None, |
308 |
| - model_metrics=None, |
309 |
| - metadata_properties=None, |
310 |
| - marketplace_cert=False, |
311 |
| - approval_status=None, |
312 |
| - description=None, |
313 |
| - drift_check_baselines=None, |
314 |
| - customer_metadata_properties=None, |
315 |
| - validation_specification=None, |
316 |
| - domain=None, |
317 |
| - task=None, |
318 |
| - sample_payload_url=None, |
319 |
| - framework=None, |
320 |
| - framework_version=None, |
321 |
| - nearest_model_name=None, |
322 |
| - data_input_configuration=None, |
| 307 | + content_types: List[Union[str, PipelineVariable]], |
| 308 | + response_types: List[Union[str, PipelineVariable]], |
| 309 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 310 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 311 | + model_package_name: Optional[Union[str, PipelineVariable]] = None, |
| 312 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 313 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 314 | + model_metrics: Optional[ModelMetrics] = None, |
| 315 | + metadata_properties: Optional[MetadataProperties] = None, |
| 316 | + marketplace_cert: bool = False, |
| 317 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 318 | + description: Optional[str] = None, |
| 319 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 320 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 321 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 322 | + domain: Optional[Union[str, PipelineVariable]] = None, |
| 323 | + task: Optional[Union[str, PipelineVariable]] = None, |
| 324 | + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, |
| 325 | + framework: Optional[Union[str, PipelineVariable]] = None, |
| 326 | + framework_version: Optional[Union[str, PipelineVariable]] = None, |
| 327 | + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, |
| 328 | + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, |
323 | 329 | ):
|
324 | 330 | """Creates a model package for creating SageMaker models or listing on Marketplace.
|
325 | 331 |
|
@@ -349,11 +355,11 @@ def register(
|
349 | 355 | metadata properties (default: None).
|
350 | 356 | domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
|
351 | 357 | "MACHINE_LEARNING" (default: None).
|
352 |
| - sample_payload_url (str): The S3 path where the sample payload is stored |
353 |
| - (default: None). |
354 | 358 | task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
|
355 | 359 | "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
|
356 | 360 | "CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
|
| 361 | + sample_payload_url (str): The S3 path where the sample payload is stored |
| 362 | + (default: None). |
357 | 363 | framework (str): Machine learning framework of the model package container image
|
358 | 364 | (default: None).
|
359 | 365 | framework_version (str): Framework version of the Model Package Container Image
|
@@ -421,10 +427,10 @@ def register(
|
421 | 427 | @runnable_by_pipeline
|
422 | 428 | def create(
|
423 | 429 | self,
|
424 |
| - instance_type: str = None, |
425 |
| - accelerator_type: str = None, |
426 |
| - serverless_inference_config: ServerlessInferenceConfig = None, |
427 |
| - tags: List[Dict[str, str]] = None, |
| 430 | + instance_type: Optional[str] = None, |
| 431 | + accelerator_type: Optional[str] = None, |
| 432 | + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, |
| 433 | + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, |
428 | 434 | ):
|
429 | 435 | """Create a SageMaker Model Entity
|
430 | 436 |
|
@@ -608,7 +614,7 @@ def _script_mode_env_vars(self):
|
608 | 614 | return {
|
609 | 615 | SCRIPT_PARAM_NAME.upper(): script_name or str(),
|
610 | 616 | DIR_PARAM_NAME.upper(): dir_name or str(),
|
611 |
| - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), |
| 617 | + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), |
612 | 618 | SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
|
613 | 619 | }
|
614 | 620 |
|
@@ -1286,19 +1292,19 @@ class FrameworkModel(Model):
|
1286 | 1292 |
|
1287 | 1293 | def __init__(
|
1288 | 1294 | self,
|
1289 |
| - model_data, |
1290 |
| - image_uri, |
1291 |
| - role, |
1292 |
| - entry_point, |
1293 |
| - source_dir=None, |
1294 |
| - predictor_cls=None, |
1295 |
| - env=None, |
1296 |
| - name=None, |
1297 |
| - container_log_level=logging.INFO, |
1298 |
| - code_location=None, |
1299 |
| - sagemaker_session=None, |
1300 |
| - dependencies=None, |
1301 |
| - git_config=None, |
| 1295 | + model_data: Union[str, PipelineVariable], |
| 1296 | + image_uri: Union[str, PipelineVariable], |
| 1297 | + role: str, |
| 1298 | + entry_point: str, |
| 1299 | + source_dir: Optional[str] = None, |
| 1300 | + predictor_cls: Optional[callable] = None, |
| 1301 | + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 1302 | + name: Optional[str] = None, |
| 1303 | + container_log_level: Union[int, PipelineVariable] = logging.INFO, |
| 1304 | + code_location: Optional[str] = None, |
| 1305 | + sagemaker_session: Optional[Session] = None, |
| 1306 | + dependencies: Optional[List[str]] = None, |
| 1307 | + git_config: Optional[Dict[str, str]] = None, |
1302 | 1308 | **kwargs,
|
1303 | 1309 | ):
|
1304 | 1310 | """Initialize a ``FrameworkModel``.
|
|
0 commit comments