Skip to content

Commit 8e36074

Browse files
author
Dewen Qi
committed
change: Add Pipeline annotation in model base class and tensorflow estimator
Model annotate update
1 parent 7d30d8c commit 8e36074

File tree

8 files changed

+133
-98
lines changed

8 files changed

+133
-98
lines changed

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/estimator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149
code_location: Optional[str] = None,
150150
entry_point: Optional[Union[str, PipelineVariable]] = None,
151151
dependencies: Optional[List[Union[str]]] = None,
152-
instance_groups: Optional[Dict[str, Union[str, int]]] = None,
152+
instance_groups: Optional[Dict[str, Union[str, int]]] = None, # TODO test
153153
**kwargs,
154154
):
155155
"""Initialize an ``EstimatorBase`` instance.
@@ -2862,15 +2862,26 @@ def _validate_and_set_debugger_configs(self):
28622862
# Disable debugger if checkpointing is enabled by the customer
28632863
if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
28642864
if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
2865-
if self.instance_count > 1 or (
2865+
disable_debugger_hook_config = False
2866+
if not is_pipeline_variable(self.instance_count) and self.instance_count > 1:
2867+
disable_debugger_hook_config = True
2868+
if (
28662869
hasattr(self, "distribution")
28672870
and self.distribution is not None # pylint: disable=no-member
28682871
):
2872+
disable_debugger_hook_config = True
2873+
if disable_debugger_hook_config:
28692874
logger.info(
28702875
"SMDebug Does Not Currently Support \
28712876
Distributed Training Jobs With Checkpointing Enabled"
28722877
)
28732878
self.debugger_hook_config = False
2879+
elif is_pipeline_variable(self.instance_count):
2880+
logger.warning(
2881+
"SMDebug does not currently support distributed training jobs "
2882+
"with checkpointing enabled, which means"
2883+
"instance_count should not be > 1 in pipeline execution."
2884+
)
28742885

28752886
if self.debugger_hook_config is False:
28762887
if self.environment is None:

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

src/sagemaker/model.py

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
import os
2020
import copy
21-
from typing import List, Dict
21+
from typing import List, Dict, Optional, Union
2222

2323
import sagemaker
2424
from sagemaker import (
@@ -29,18 +29,24 @@
2929
utils,
3030
git_utils,
3131
)
32+
from sagemaker.session import Session
33+
from sagemaker.model_metrics import ModelMetrics
3234
from sagemaker.deprecations import removed_kwargs
35+
from sagemaker.drift_check_baselines import DriftCheckBaselines
36+
from sagemaker.metadata_properties import MetadataProperties
3337
from sagemaker.predictor import PredictorBase
3438
from sagemaker.serverless import ServerlessInferenceConfig
3539
from sagemaker.transformer import Transformer
3640
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3741
from sagemaker.utils import (
3842
unique_name_from_base,
3943
update_container_with_inference_params,
44+
to_string,
4045
)
4146
from sagemaker.async_inference import AsyncInferenceConfig
4247
from sagemaker.predictor_async import AsyncPredictor
4348
from sagemaker.workflow import is_pipeline_variable
49+
from sagemaker.workflow.entities import PipelineVariable
4450
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4551

4652
LOGGER = logging.getLogger("sagemaker")
@@ -82,23 +88,23 @@ class Model(ModelBase):
8288

8389
def __init__(
8490
self,
85-
image_uri,
86-
model_data=None,
87-
role=None,
88-
predictor_cls=None,
89-
env=None,
90-
name=None,
91-
vpc_config=None,
92-
sagemaker_session=None,
93-
enable_network_isolation=False,
94-
model_kms_key=None,
95-
image_config=None,
96-
source_dir=None,
97-
code_location=None,
98-
entry_point=None,
99-
container_log_level=logging.INFO,
100-
dependencies=None,
101-
git_config=None,
91+
image_uri: Union[str, PipelineVariable],
92+
model_data: Optional[Union[str, PipelineVariable]] = None,
93+
role: Optional[str] = None,
94+
predictor_cls: Optional[callable] = None,
95+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
96+
name: Optional[str] = None,
97+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
98+
sagemaker_session: Optional[Session] = None,
99+
enable_network_isolation: Union[bool, PipelineVariable] = False,
100+
model_kms_key: Optional[str] = None,
101+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
102+
source_dir: Optional[str] = None,
103+
code_location: Optional[str] = None,
104+
entry_point: Optional[str] = None,
105+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
106+
dependencies: Optional[List[str]] = None,
107+
git_config: Optional[Dict[str, str]] = None,
102108
):
103109
"""Initialize an SageMaker ``Model``.
104110
@@ -298,28 +304,28 @@ def __init__(
298304
@runnable_by_pipeline
299305
def register(
300306
self,
301-
content_types,
302-
response_types,
303-
inference_instances=None,
304-
transform_instances=None,
305-
model_package_name=None,
306-
model_package_group_name=None,
307-
image_uri=None,
308-
model_metrics=None,
309-
metadata_properties=None,
310-
marketplace_cert=False,
311-
approval_status=None,
312-
description=None,
313-
drift_check_baselines=None,
314-
customer_metadata_properties=None,
315-
validation_specification=None,
316-
domain=None,
317-
task=None,
318-
sample_payload_url=None,
319-
framework=None,
320-
framework_version=None,
321-
nearest_model_name=None,
322-
data_input_configuration=None,
307+
content_types: List[Union[str, PipelineVariable]],
308+
response_types: List[Union[str, PipelineVariable]],
309+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
310+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
311+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
312+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
313+
image_uri: Optional[Union[str, PipelineVariable]] = None,
314+
model_metrics: Optional[ModelMetrics] = None,
315+
metadata_properties: Optional[MetadataProperties] = None,
316+
marketplace_cert: bool = False,
317+
approval_status: Optional[Union[str, PipelineVariable]] = None,
318+
description: Optional[str] = None,
319+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
320+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
321+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
322+
domain: Optional[Union[str, PipelineVariable]] = None,
323+
task: Optional[Union[str, PipelineVariable]] = None,
324+
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
325+
framework: Optional[Union[str, PipelineVariable]] = None,
326+
framework_version: Optional[Union[str, PipelineVariable]] = None,
327+
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
328+
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
323329
):
324330
"""Creates a model package for creating SageMaker models or listing on Marketplace.
325331
@@ -349,11 +355,11 @@ def register(
349355
metadata properties (default: None).
350356
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
351357
"MACHINE_LEARNING" (default: None).
352-
sample_payload_url (str): The S3 path where the sample payload is stored
353-
(default: None).
354358
task (str): Task values which are supported by Inference Recommender are "FILL_MASK",
355359
"IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION",
356360
"CLASSIFICATION", "REGRESSION", "OTHER" (default: None).
361+
sample_payload_url (str): The S3 path where the sample payload is stored
362+
(default: None).
357363
framework (str): Machine learning framework of the model package container image
358364
(default: None).
359365
framework_version (str): Framework version of the Model Package Container Image
@@ -421,10 +427,10 @@ def register(
421427
@runnable_by_pipeline
422428
def create(
423429
self,
424-
instance_type: str = None,
425-
accelerator_type: str = None,
426-
serverless_inference_config: ServerlessInferenceConfig = None,
427-
tags: List[Dict[str, str]] = None,
430+
instance_type: Optional[str] = None,
431+
accelerator_type: Optional[str] = None,
432+
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
433+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
428434
):
429435
"""Create a SageMaker Model Entity
430436
@@ -608,7 +614,7 @@ def _script_mode_env_vars(self):
608614
return {
609615
SCRIPT_PARAM_NAME.upper(): script_name or str(),
610616
DIR_PARAM_NAME.upper(): dir_name or str(),
611-
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
617+
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
612618
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
613619
}
614620

@@ -1286,19 +1292,19 @@ class FrameworkModel(Model):
12861292

12871293
def __init__(
12881294
self,
1289-
model_data,
1290-
image_uri,
1291-
role,
1292-
entry_point,
1293-
source_dir=None,
1294-
predictor_cls=None,
1295-
env=None,
1296-
name=None,
1297-
container_log_level=logging.INFO,
1298-
code_location=None,
1299-
sagemaker_session=None,
1300-
dependencies=None,
1301-
git_config=None,
1295+
model_data: Union[str, PipelineVariable],
1296+
image_uri: Union[str, PipelineVariable],
1297+
role: str,
1298+
entry_point: str,
1299+
source_dir: Optional[str] = None,
1300+
predictor_cls: Optional[callable] = None,
1301+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
1302+
name: Optional[str] = None,
1303+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
1304+
code_location: Optional[str] = None,
1305+
sagemaker_session: Optional[Session] = None,
1306+
dependencies: Optional[List[str]] = None,
1307+
git_config: Optional[Dict[str, str]] = None,
13021308
**kwargs,
13031309
):
13041310
"""Initialize a ``FrameworkModel``.

src/sagemaker/model_metrics.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,24 @@
1313
"""This file contains code related to model metrics, including metric source and file source."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class ModelMetrics(object):
1822
"""Accepts model metrics parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias=None,
27-
explainability=None,
28-
bias_pre_training=None,
29-
bias_post_training=None,
26+
model_statistics: Optional["MetricsSource"] = None,
27+
model_constraints: Optional["MetricsSource"] = None,
28+
model_data_statistics: Optional["MetricsSource"] = None,
29+
model_data_constraints: Optional["MetricsSource"] = None,
30+
bias: Optional["MetricsSource"] = None,
31+
explainability: Optional["MetricsSource"] = None,
32+
bias_pre_training: Optional["MetricsSource"] = None,
33+
bias_post_training: Optional["MetricsSource"] = None,
3034
):
3135
"""Initialize a ``ModelMetrics`` instance and turn parameters into dict.
3236
@@ -99,9 +103,9 @@ class MetricsSource(object):
99103

100104
def __init__(
101105
self,
102-
content_type,
103-
s3_uri,
104-
content_digest=None,
106+
content_type: Union[str, PipelineVariable],
107+
s3_uri: Union[str, PipelineVariable],
108+
content_digest: Optional[Union[str, PipelineVariable]] = None,
105109
):
106110
"""Initialize a ``MetricsSource`` instance and turn parameters into dict.
107111
@@ -127,9 +131,9 @@ class FileSource(object):
127131

128132
def __init__(
129133
self,
130-
s3_uri,
131-
content_digest=None,
132-
content_type=None,
134+
s3_uri: Union[str, PipelineVariable],
135+
content_digest: Optional[Union[str, PipelineVariable]] = None,
136+
content_type: Optional[Union[str, PipelineVariable]] = None,
133137
):
134138
"""Initialize a ``FileSource`` instance and turn parameters into dict.
135139

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object):
2727

2828
def __init__(
2929
self,
30-
memory_size_in_mb=2048,
31-
max_concurrency=5,
30+
memory_size_in_mb: int = 2048,
31+
max_concurrency: int = 5,
3232
):
3333
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.
3434

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2633,7 +2633,9 @@ def _create_model_request(
26332633
request["VpcConfig"] = vpc_config
26342634

26352635
if enable_network_isolation:
2636-
request["EnableNetworkIsolation"] = True
2636+
# enable_network_isolation may be a pipeline variable which is
2637+
# parsed in execution time
2638+
request["EnableNetworkIsolation"] = enable_network_isolation
26372639

26382640
return request
26392641

0 commit comments

Comments
 (0)