Skip to content

Commit 900befd

Browse files
author
Dewen Qi
committed
change: Add Pipeline annotation in model base class and tensorflow estimator
1 parent 72c12bb commit 900befd

File tree

8 files changed

+125
-91
lines changed

8 files changed

+125
-91
lines changed

src/sagemaker/drift_check_baselines.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,25 @@
1313
"""This file contains code related to drift check baselines"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional
17+
18+
from sagemaker.model_metrics import MetricsSource, FileSource
19+
1620

1721
class DriftCheckBaselines(object):
1822
"""Accepts drift check baselines parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias_config_file=None,
27-
bias_pre_training_constraints=None,
28-
bias_post_training_constraints=None,
29-
explainability_constraints=None,
30-
explainability_config_file=None,
26+
model_statistics: Optional[MetricsSource] = None,
27+
model_constraints: Optional[MetricsSource] = None,
28+
model_data_statistics: Optional[MetricsSource] = None,
29+
model_data_constraints: Optional[MetricsSource] = None,
30+
bias_config_file: Optional[FileSource] = None,
31+
bias_pre_training_constraints: Optional[MetricsSource] = None,
32+
bias_post_training_constraints: Optional[MetricsSource] = None,
33+
explainability_constraints: Optional[MetricsSource] = None,
34+
explainability_config_file: Optional[FileSource] = None,
3135
):
3236
"""Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.
3337

src/sagemaker/metadata_properties.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
"""This file contains code related to metadata properties."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class MetadataProperties(object):
1822
"""Accepts metadata properties parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
commit_id=None,
23-
repository=None,
24-
generated_by=None,
25-
project_id=None,
26+
commit_id: Optional[Union[str, PipelineVariable]] = None,
27+
repository: Optional[Union[str, PipelineVariable]] = None,
28+
generated_by: Optional[Union[str, PipelineVariable]] = None,
29+
project_id: Optional[Union[str, PipelineVariable]] = None,
2630
):
2731
"""Initialize a ``MetadataProperties`` instance and turn parameters into dict.
2832

src/sagemaker/model.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import re
2121
import copy
22-
from typing import List, Dict
22+
from typing import List, Dict, Optional, Union
2323

2424
import sagemaker
2525
from sagemaker import (
@@ -30,15 +30,20 @@
3030
utils,
3131
git_utils,
3232
)
33+
from sagemaker.session import Session
34+
from sagemaker.model_metrics import ModelMetrics
3335
from sagemaker.deprecations import removed_kwargs
36+
from sagemaker.drift_check_baselines import DriftCheckBaselines
37+
from sagemaker.metadata_properties import MetadataProperties
3438
from sagemaker.predictor import PredictorBase
3539
from sagemaker.serverless import ServerlessInferenceConfig
3640
from sagemaker.transformer import Transformer
3741
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
38-
from sagemaker.utils import unique_name_from_base
42+
from sagemaker.utils import unique_name_from_base, to_string
3943
from sagemaker.async_inference import AsyncInferenceConfig
4044
from sagemaker.predictor_async import AsyncPredictor
4145
from sagemaker.workflow import is_pipeline_variable
46+
from sagemaker.workflow.entities import PipelineVariable
4247
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
4348

4449
LOGGER = logging.getLogger("sagemaker")
@@ -78,23 +83,23 @@ class Model(ModelBase):
7883

7984
def __init__(
8085
self,
81-
image_uri,
82-
model_data=None,
83-
role=None,
84-
predictor_cls=None,
85-
env=None,
86-
name=None,
87-
vpc_config=None,
88-
sagemaker_session=None,
89-
enable_network_isolation=False,
90-
model_kms_key=None,
91-
image_config=None,
92-
source_dir=None,
93-
code_location=None,
94-
entry_point=None,
95-
container_log_level=logging.INFO,
96-
dependencies=None,
97-
git_config=None,
86+
image_uri: Union[str, PipelineVariable],
87+
model_data: Optional[Union[str, PipelineVariable]] = None,
88+
role: Optional[str] = None,
89+
predictor_cls: Optional[callable] = None,
90+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
91+
name: Optional[str] = None,
92+
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
93+
sagemaker_session: Optional[Session] = None,
94+
enable_network_isolation: Union[bool, PipelineVariable] = False,
95+
model_kms_key: Optional[str] = None,
96+
image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
97+
source_dir: Optional[str] = None,
98+
code_location: Optional[str] = None,
99+
entry_point: Optional[str] = None,
100+
container_log_level: Union[int, PipelineVariable] = logging.INFO,
101+
dependencies: Optional[List[str]] = None,
102+
git_config: Optional[Dict[str, str]] = None,
98103
):
99104
"""Initialize an SageMaker ``Model``.
100105
@@ -294,22 +299,22 @@ def __init__(
294299
@runnable_by_pipeline
295300
def register(
296301
self,
297-
content_types,
298-
response_types,
299-
inference_instances=None,
300-
transform_instances=None,
301-
model_package_name=None,
302-
model_package_group_name=None,
303-
image_uri=None,
304-
model_metrics=None,
305-
metadata_properties=None,
306-
marketplace_cert=False,
307-
approval_status=None,
308-
description=None,
309-
drift_check_baselines=None,
310-
customer_metadata_properties=None,
311-
validation_specification=None,
312-
domain=None,
302+
content_types: List[Union[str, PipelineVariable]],
303+
response_types: List[Union[str, PipelineVariable]],
304+
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
305+
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
306+
model_package_name: Optional[Union[str, PipelineVariable]] = None,
307+
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
308+
image_uri: Optional[Union[str, PipelineVariable]] = None,
309+
model_metrics: Optional[ModelMetrics] = None,
310+
metadata_properties: Optional[MetadataProperties] = None,
311+
marketplace_cert: bool = False,
312+
approval_status: Optional[Union[str, PipelineVariable]] = None,
313+
description: Optional[str] = None,
314+
drift_check_baselines: Optional[DriftCheckBaselines] = None,
315+
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
316+
validation_specification: Optional[Union[str, PipelineVariable]] = None,
317+
domain: Optional[Union[str, PipelineVariable]] = None,
313318
):
314319
"""Creates a model package for creating SageMaker models or listing on Marketplace.
315320
@@ -385,10 +390,10 @@ def register(
385390
@runnable_by_pipeline
386391
def create(
387392
self,
388-
instance_type: str = None,
389-
accelerator_type: str = None,
390-
serverless_inference_config: ServerlessInferenceConfig = None,
391-
tags: List[Dict[str, str]] = None,
393+
instance_type: Optional[str] = None,
394+
accelerator_type: Optional[str] = None,
395+
serverless_inference_config: Optional[ServerlessInferenceConfig] = None,
396+
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
392397
):
393398
"""Create a SageMaker Model Entity
394399
@@ -570,7 +575,7 @@ def _script_mode_env_vars(self):
570575
return {
571576
SCRIPT_PARAM_NAME.upper(): script_name or str(),
572577
DIR_PARAM_NAME.upper(): dir_name or str(),
573-
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
578+
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
574579
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
575580
}
576581

@@ -1239,19 +1244,19 @@ class FrameworkModel(Model):
12391244

12401245
def __init__(
12411246
self,
1242-
model_data,
1243-
image_uri,
1244-
role,
1245-
entry_point,
1246-
source_dir=None,
1247-
predictor_cls=None,
1248-
env=None,
1249-
name=None,
1250-
container_log_level=logging.INFO,
1251-
code_location=None,
1252-
sagemaker_session=None,
1253-
dependencies=None,
1254-
git_config=None,
1247+
model_data: Union[str, PipelineVariable],
1248+
image_uri: Union[str, PipelineVariable],
1249+
role: str,
1250+
entry_point: str,
1251+
source_dir: Optional[str] = None,
1252+
predictor_cls: Optional[callable] = None,
1253+
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
1254+
name: Optional[str] = None,
1255+
container_log_level: int = logging.INFO,
1256+
code_location: Optional[str] = None,
1257+
sagemaker_session: Optional[Session] = None,
1258+
dependencies: Optional[List[str]] = None,
1259+
git_config: Optional[Dict[str, str]] = None,
12551260
**kwargs,
12561261
):
12571262
"""Initialize a ``FrameworkModel``.

src/sagemaker/model_metrics.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,24 @@
1313
"""This file contains code related to model metrics, including metric source and file source."""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
18+
from sagemaker.workflow.entities import PipelineVariable
19+
1620

1721
class ModelMetrics(object):
1822
"""Accepts model metrics parameters for conversion to request dict."""
1923

2024
def __init__(
2125
self,
22-
model_statistics=None,
23-
model_constraints=None,
24-
model_data_statistics=None,
25-
model_data_constraints=None,
26-
bias=None,
27-
explainability=None,
28-
bias_pre_training=None,
29-
bias_post_training=None,
26+
model_statistics: Optional["MetricsSource"] = None,
27+
model_constraints: Optional["MetricsSource"] = None,
28+
model_data_statistics: Optional["MetricsSource"] = None,
29+
model_data_constraints: Optional["MetricsSource"] = None,
30+
bias: Optional["MetricsSource"] = None,
31+
explainability: Optional["MetricsSource"] = None,
32+
bias_pre_training: Optional["MetricsSource"] = None,
33+
bias_post_training: Optional["MetricsSource"] = None,
3034
):
3135
"""Initialize a ``ModelMetrics`` instance and turn parameters into dict.
3236
@@ -99,9 +103,9 @@ class MetricsSource(object):
99103

100104
def __init__(
101105
self,
102-
content_type,
103-
s3_uri,
104-
content_digest=None,
106+
content_type: Union[str, PipelineVariable],
107+
s3_uri: Union[str, PipelineVariable],
108+
content_digest: Optional[Union[str, PipelineVariable]] = None,
105109
):
106110
"""Initialize a ``MetricsSource`` instance and turn parameters into dict.
107111
@@ -127,9 +131,9 @@ class FileSource(object):
127131

128132
def __init__(
129133
self,
130-
s3_uri,
131-
content_digest=None,
132-
content_type=None,
134+
s3_uri: Union[str, PipelineVariable],
135+
content_digest: Optional[Union[str, PipelineVariable]] = None,
136+
content_type: Optional[Union[str, PipelineVariable]] = None,
133137
):
134138
"""Initialize a ``FileSource`` instance and turn parameters into dict.
135139

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object):
2727

2828
def __init__(
2929
self,
30-
memory_size_in_mb=2048,
31-
max_concurrency=5,
30+
memory_size_in_mb: int = 2048,
31+
max_concurrency: int = 5,
3232
):
3333
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.
3434

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2633,7 +2633,9 @@ def _create_model_request(
26332633
request["VpcConfig"] = vpc_config
26342634

26352635
if enable_network_isolation:
2636-
request["EnableNetworkIsolation"] = True
2636+
# enable_network_isolation may be a pipeline variable which is
2637+
# parsed in execution time
2638+
request["EnableNetworkIsolation"] = enable_network_isolation
26372639

26382640
return request
26392641

src/sagemaker/tensorflow/estimator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17+
from typing import Optional, Union, Dict
1718

1819
from packaging import version
1920

@@ -27,6 +28,7 @@
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2829
from sagemaker.workflow import is_pipeline_variable
2930
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig
31+
from sagemaker.workflow.entities import PipelineVariable
3032

3133
logger = logging.getLogger("sagemaker")
3234

@@ -41,12 +43,12 @@ class TensorFlow(Framework):
4143

4244
def __init__(
4345
self,
44-
py_version=None,
45-
framework_version=None,
46-
model_dir=None,
47-
image_uri=None,
48-
distribution=None,
49-
compiler_config=None,
46+
py_version: Optional[str] = None,
47+
framework_version: Optional[str] = None,
48+
model_dir: Optional[Union[str, PipelineVariable]] = None,
49+
image_uri: Optional[Union[str, PipelineVariable]] = None,
50+
distribution: Optional[Dict[str, str]] = None,
51+
compiler_config: Optional[TrainingCompilerConfig] = None,
5052
**kwargs,
5153
):
5254
"""Initialize a ``TensorFlow`` estimator.
@@ -183,7 +185,7 @@ def __init__(
183185
self.py_version = py_version
184186
self.instance_type = instance_type
185187

186-
if distribution is not None:
188+
if distribution is not None and instance_type is not None:
187189
fw.warn_if_parameter_server_with_multi_gpu(
188190
training_instance_type=instance_type, distribution=distribution
189191
)
@@ -254,6 +256,8 @@ def _only_legacy_mode_supported(self):
254256

255257
def _only_python_3_supported(self):
256258
"""Placeholder docstring"""
259+
if not self.framework_version:
260+
return False
257261
return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION
258262

259263
@classmethod

src/sagemaker/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from sagemaker import deprecations
3535
from sagemaker.session_settings import SessionSettings
36-
36+
from sagemaker.workflow import is_pipeline_variable
3737

3838
ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
3939
MAX_BUCKET_PATHS_COUNT = 5
@@ -722,3 +722,14 @@ def get_data_bucket(self, region_requested=None):
722722

723723

724724
get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix")
725+
726+
727+
def to_string(obj: object):
728+
"""Convert an object to string
729+
730+
This helper function handles converting PipelineVariable object to string as well
731+
732+
Args:
733+
obj (object): The object to be converted
734+
"""
735+
return obj.to_string() if is_pipeline_variable(obj) else str(obj)

0 commit comments

Comments
 (0)