Skip to content

change: Add PipelineVariable annotation in framework models #3188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris
from sagemaker import image_uris, ModelMetrics
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import (
model_code_key_prefix,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -75,14 +81,14 @@ class ChainerModel(FrameworkModel):

def __init__(
self,
model_data,
role,
entry_point,
image_uri=None,
framework_version=None,
py_version=None,
predictor_cls=ChainerPredictor,
model_server_workers=None,
model_data: Union[str, PipelineVariable],
role: str,
entry_point: str,
image_uri: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
predictor_cls: callable = ChainerPredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs
):
"""Initialize an ChainerModel.
Expand Down Expand Up @@ -142,27 +148,27 @@ def __init__(

def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
sample_payload_url=None,
task=None,
framework=None,
framework_version=None,
nearest_model_name=None,
data_input_configuration=None,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
marketplace_cert: bool = False,
approval_status: Optional[Union[str, PipelineVariable]] = None,
description: Optional[str] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
domain: Optional[Union[str, PipelineVariable]] = None,
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
task: Optional[Union[str, PipelineVariable]] = None,
framework: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[Union[str, PipelineVariable]] = None,
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -218,6 +224,8 @@ def register(
region_name=self.sagemaker_session.boto_session.region_name,
instance_type=instance_type,
)
if not is_pipeline_variable(framework):
framework = (framework or self._framework_name).upper()
return super(ChainerModel, self).register(
content_types,
response_types,
Expand All @@ -236,7 +244,7 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=(framework or self._framework_name).upper(),
framework=framework,
framework_version=framework_version or self.framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
Expand Down Expand Up @@ -282,7 +290,9 @@ def prepare_container_def(
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
self.model_server_workers
)
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)

def serving_image_uri(
Expand Down
6 changes: 2 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
build_dict,
get_config_value,
name_from_base,
to_string,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
Expand Down Expand Up @@ -1947,10 +1948,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):

current_hyperparameters = estimator.hyperparameters()
if current_hyperparameters is not None:
hyperparameters = {
str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
for (k, v) in current_hyperparameters.items()
}
hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()}

train_args = config.copy()
train_args["input_mode"] = estimator.input_mode
Expand Down
88 changes: 49 additions & 39 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris
from sagemaker import image_uris, ModelMetrics
from sagemaker.deserializers import JSONDeserializer
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import (
model_code_key_prefix,
validate_version_or_image_args,
)
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session
from sagemaker.utils import to_string
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel):

def __init__(
self,
role,
model_data=None,
entry_point=None,
transformers_version=None,
tensorflow_version=None,
pytorch_version=None,
py_version=None,
image_uri=None,
predictor_cls=HuggingFacePredictor,
model_server_workers=None,
role: str,
model_data: Optional[Union[str, PipelineVariable]] = None,
entry_point: Optional[str] = None,
transformers_version: Optional[str] = None,
tensorflow_version: Optional[str] = None,
pytorch_version: Optional[str] = None,
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = HuggingFacePredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs,
):
"""Initialize a HuggingFaceModel.
Expand Down Expand Up @@ -299,27 +305,27 @@ def deploy(

def register(
self,
content_types,
response_types,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
model_metrics=None,
metadata_properties=None,
marketplace_cert=False,
approval_status=None,
description=None,
drift_check_baselines=None,
customer_metadata_properties=None,
domain=None,
sample_payload_url=None,
task=None,
framework=None,
framework_version=None,
nearest_model_name=None,
data_input_configuration=None,
content_types: List[Union[str, PipelineVariable]],
response_types: List[Union[str, PipelineVariable]],
inference_instances: Optional[List[Union[str, PipelineVariable]]] = None,
transform_instances: Optional[List[Union[str, PipelineVariable]]] = None,
model_package_name: Optional[Union[str, PipelineVariable]] = None,
model_package_group_name: Optional[Union[str, PipelineVariable]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
marketplace_cert: bool = False,
approval_status: Optional[Union[str, PipelineVariable]] = None,
description: Optional[str] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
domain: Optional[Union[str, PipelineVariable]] = None,
sample_payload_url: Optional[Union[str, PipelineVariable]] = None,
task: Optional[Union[str, PipelineVariable]] = None,
framework: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[Union[str, PipelineVariable]] = None,
nearest_model_name: Optional[Union[str, PipelineVariable]] = None,
data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.

Expand Down Expand Up @@ -377,6 +383,13 @@ def register(
region_name=self.sagemaker_session.boto_session.region_name,
instance_type=instance_type,
)
if not is_pipeline_variable(framework):
framework = (
framework
or fetch_framework_and_framework_version(
self.tensorflow_version, self.pytorch_version
)[0]
).upper()
return super(HuggingFaceModel, self).register(
content_types,
response_types,
Expand All @@ -395,12 +408,7 @@ def register(
domain=domain,
sample_payload_url=sample_payload_url,
task=task,
framework=(
framework
or fetch_framework_and_framework_version(
self.tensorflow_version, self.pytorch_version
)[0]
).upper(),
framework=framework,
framework_version=framework_version
or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[
1
Expand Down Expand Up @@ -449,7 +457,9 @@ def prepare_container_def(
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string(
self.model_server_workers
)
return sagemaker.container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
)
Expand Down
20 changes: 13 additions & 7 deletions src/sagemaker/multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

import os
from typing import Union, Optional

from six.moves.urllib.parse import urlparse

Expand All @@ -22,6 +23,8 @@
from sagemaker.deprecations import removed_kwargs
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.utils import pop_out_unused_kwarg
from sagemaker.workflow.entities import PipelineVariable

MULTI_MODEL_CONTAINER_MODE = "MultiModel"

Expand All @@ -34,12 +37,12 @@ class MultiDataModel(Model):

def __init__(
self,
name,
model_data_prefix,
model=None,
image_uri=None,
role=None,
sagemaker_session=None,
name: str,
model_data_prefix: str,
model: Optional[Model] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
role: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
**kwargs,
):
"""Initialize a ``MultiDataModel``.
Expand Down Expand Up @@ -106,6 +109,7 @@ def __init__(

# Set the ``Model`` parameters if the model parameter is not specified
if not self.model:
pop_out_unused_kwarg("model_data", kwargs, self.model_data_prefix)
super(MultiDataModel, self).__init__(
image_uri,
self.model_data_prefix,
Expand All @@ -115,7 +119,9 @@ def __init__(
**kwargs,
)

def prepare_container_def(self, instance_type=None, accelerator_type=None):
def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
):
"""Return a container definition set.

Definition set includes MultiModel mode, model data and other parameters
Expand Down
Loading