Skip to content

feat: jumpstart model id suggestions #2899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 18, 2022
Merged
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ SageMaker Python SDK is tested on:
- Python 3.6
- Python 3.7
- Python 3.8
- Python 3.9

AWS Permissions
~~~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions doc/workflows/pipelines/sagemaker.workflow.pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ Steps
.. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckConfig

.. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckStep

.. autoclass:: sagemaker.workflow.fail_step.FailStep
44 changes: 23 additions & 21 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def read_version():

# Declare minimal set for installation
required_packages = [
"attrs",
"attrs==20.3.0",
"boto3>=1.20.21",
"google-pasta",
"numpy>=1.9.0",
Expand All @@ -49,36 +49,38 @@ def read_version():
# Specific use case dependencies
extras = {
"local": [
"urllib3>=1.21.1,!=1.25,!=1.25.1",
"docker-compose>=1.25.2",
"docker==5.0.0",
"PyYAML>=5.3, <6", # PyYAML version has to match docker-compose requirements
"urllib3==1.26.8",
"docker-compose==1.29.2",
"docker~=5.0.0",
"PyYAML==5.4.1", # PyYAML version has to match docker-compose requirements
],
"scipy": ["scipy>=0.19.0"],
"scipy": ["scipy==1.5.4"],
}
# Meta dependency groups
extras["all"] = [item for group in extras.values() for item in group]
# Tests specific dependencies (do not need to be included in 'all')
extras["test"] = (
[
extras["all"],
"tox",
"flake8",
"pytest<6.1.0",
"pytest-cov",
"pytest-rerunfailures",
"pytest-timeout",
"tox==3.24.5",
"flake8==4.0.1",
"pytest==6.0.2",
"pytest-cov==3.0.0",
"pytest-rerunfailures==10.2",
"pytest-timeout==2.1.0",
"pytest-xdist==2.4.0",
"coverage<6.2",
"mock",
"contextlib2",
"awslogs",
"black",
"coverage>=5.2, <6.2",
"mock==4.0.3",
"contextlib2==21.6.0",
"awslogs==0.14.0",
"black==22.1.0",
"stopit==1.1.2",
"apache-airflow==1.10.11",
"fabric>=2.0",
"requests>=2.20.0, <3",
"sagemaker-experiments",
"apache-airflow==2.2.3",
"apache-airflow-providers-amazon==3.0.0",
"attrs==20.3.0",
"fabric==2.6.0",
"requests==2.27.1",
"sagemaker-experiments==0.1.35",
],
)

Expand Down
31 changes: 26 additions & 5 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module defines the JumpStartModelsCache class."""
from __future__ import absolute_import
import datetime
from difflib import get_close_matches
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice find!

from typing import List, Optional
import json
import boto3
Expand Down Expand Up @@ -204,14 +205,34 @@ def _get_manifest_key_from_model_id_semantic_version(
sm_version_to_use = sm_version_to_use_list[0]

error_msg = (
f"Unable to find model manifest for {model_id} with version {version} "
f"compatible with your SageMaker version ({sm_version}). "
f"Unable to find model manifest for '{model_id}' with version '{version}' "
f"compatible with your SageMaker version ('{sm_version}'). "
f"Consider upgrading your SageMaker library to at least version "
f"{sm_version_to_use} so you can use version "
f"{model_version_to_use_incompatible_with_sagemaker} of {model_id}."
f"'{sm_version_to_use}' so you can use version "
f"'{model_version_to_use_incompatible_with_sagemaker}' of '{model_id}'."
)
raise KeyError(error_msg)
error_msg = f"Unable to find model manifest for {model_id} with version {version}."

error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
error_msg += (
"Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html"
" for updated list of models. "
)

other_model_id_version = self._select_version(
"*", versions_incompatible_with_sagemaker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add a comment as to why you use this variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reiterate: please explain why we use a variable called versions_incompatitble_with_sagemaker here.

) # all versions here are incompatible with sagemaker
if other_model_id_version is not None:
error_msg += (
f"Consider using model ID '{model_id}' with version "
f"'{other_model_id_version}'."
)

else:
possible_model_ids = [header.model_id for header in manifest.values()]
closest_model_id = get_close_matches(model_id, possible_model_ids, n=1, cutoff=0)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: should you show only one? or say for example: at least 1, up to 3 when score > xx ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just do 1 and make it simple. If we want to make this really good, we're better off just having a separate utility for searching model ids.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another question, can you get an IndexError here if get_close_matches returns an empty list?

If that's possible, could you handle this edge case and add a unit test for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because the cutoff is 0, there will always be a match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I know this is super edge case, but could there be a case where possible_model_ids is an empty list?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if the manifest is empty

error_msg += f"Did you mean to use model ID '{closest_model_id}'?"

raise KeyError(error_msg)

def _get_file_from_s3(
Expand Down
12 changes: 6 additions & 6 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ def from_json(self, json_obj: Dict[str, str]) -> None:
class JumpStartECRSpecs(JumpStartDataHolderType):
"""Data class for JumpStart ECR specs."""

__slots__ = {
__slots__ = [
"framework",
"framework_version",
"py_version",
"huggingface_transformers_version",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartECRSpecs object from its json representation.
Expand Down Expand Up @@ -173,7 +173,7 @@ def to_json(self) -> Dict[str, Any]:
class JumpStartHyperparameter(JumpStartDataHolderType):
"""Data class for JumpStart hyperparameter definition in the training container."""

__slots__ = {
__slots__ = [
"name",
"type",
"options",
Expand All @@ -183,7 +183,7 @@ class JumpStartHyperparameter(JumpStartDataHolderType):
"max",
"exclusive_min",
"exclusive_max",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartHyperparameter object from its json representation.
Expand Down Expand Up @@ -234,12 +234,12 @@ def to_json(self) -> Dict[str, Any]:
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
"""Data class for JumpStart environment variable definitions in the hosting container."""

__slots__ = {
__slots__ = [
"name",
"type",
"default",
"scope",
}
]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartEnvironmentVariable object from its json representation.
Expand Down
22 changes: 11 additions & 11 deletions src/sagemaker/jumpstart/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _validate_hyperparameter(

if len(hyperparameter_spec) > 1:
raise JumpStartHyperparametersError(
f"Unable to perform validation -- found multiple hyperparameter "
"Unable to perform validation -- found multiple hyperparameter "
f"'{hyperparameter_name}' in model specs."
)

Expand All @@ -76,35 +76,35 @@ def _validate_hyperparameter(
if hyperparameter_value not in hyperparameter_spec.options:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
f"values: {', '.join(hyperparameter_spec.options)}"
f"values: {', '.join(hyperparameter_spec.options)}."
)

if hasattr(hyperparameter_spec, "min"):
if len(hyperparameter_value) < hyperparameter_spec.min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
f"{hyperparameter_spec.min}"
f"{hyperparameter_spec.min}."
)

if hasattr(hyperparameter_spec, "exclusive_min"):
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
f"{hyperparameter_spec.exclusive_min}"
f"{hyperparameter_spec.exclusive_min}."
)

if hasattr(hyperparameter_spec, "max"):
if len(hyperparameter_value) > hyperparameter_spec.max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
f"{hyperparameter_spec.max}"
f"{hyperparameter_spec.max}."
)

if hasattr(hyperparameter_spec, "exclusive_max"):
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must have length less than "
f"{hyperparameter_spec.exclusive_max}"
f"{hyperparameter_spec.exclusive_max}."
)

# validate numeric types
Expand All @@ -125,35 +125,35 @@ def _validate_hyperparameter(
if not hyperparameter_value_str[start_index:].isdigit():
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be integer type "
"('{hyperparameter_value}')."
f"('{hyperparameter_value}')."
)

if hasattr(hyperparameter_spec, "min"):
if numeric_hyperparam_value < hyperparameter_spec.min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' can be no less than "
"{hyperparameter_spec.min}."
f"{hyperparameter_spec.min}."
)

if hasattr(hyperparameter_spec, "max"):
if numeric_hyperparam_value > hyperparameter_spec.max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
"{hyperparameter_spec.max}."
f"{hyperparameter_spec.max}."
)

if hasattr(hyperparameter_spec, "exclusive_min"):
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be greater than "
"{hyperparameter_spec.exclusive_min}."
f"{hyperparameter_spec.exclusive_min}."
)

if hasattr(hyperparameter_spec, "exclusive_max"):
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
raise JumpStartHyperparametersError(
f"Hyperparameter '{hyperparameter_name}' must be less than "
"{hyperparameter_spec.exclusive_max}."
f"{hyperparameter_spec.exclusive_max}."
)


Expand Down
36 changes: 1 addition & 35 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import sagemaker
from sagemaker import (
fw_utils,
image_uris,
local,
s3,
session,
Expand Down Expand Up @@ -657,34 +656,6 @@ def _compilation_job_config(
"job_name": job_name,
}

def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
"""Retrieve the Neo or Inferentia image URI.

Args:
region (str): The AWS region.
target_instance_type (str): Identifies the device on which you want to run
your model after compilation, for example: ml_c5. For valid values, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
framework (str): The framework name.
framework_version (str): The framework version.
"""
framework_prefix = ""
framework_suffix = ""

if framework == "xgboost":
framework_suffix = "-neo"
elif target_instance_type.startswith("ml_inf"):
framework_prefix = "inferentia-"
else:
framework_prefix = "neo-"

return image_uris.retrieve(
"{}{}{}".format(framework_prefix, framework, framework_suffix),
region,
instance_type=target_instance_type,
version=framework_version,
)

def package_for_edge(
self,
output_path,
Expand Down Expand Up @@ -849,12 +820,7 @@ def compile(
if target_instance_family == "ml_eia2":
pass
elif target_instance_family.startswith("ml_"):
self.image_uri = self._compilation_image_uri(
self.sagemaker_session.boto_region_name,
target_instance_family,
framework,
framework_version,
)
self.image_uri = job_status.get("InferenceImage", None)
self._is_compiled_model = True
else:
LOGGER.warning(
Expand Down
71 changes: 71 additions & 0 deletions src/sagemaker/workflow/fail_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `Step` definitions for SageMaker Pipelines Workflows."""
from __future__ import absolute_import

from typing import List, Union

from sagemaker.workflow import PipelineNonPrimitiveInputTypes
from sagemaker.workflow.entities import (
RequestType,
)
from sagemaker.workflow.steps import Step, StepTypeEnum


class FailStep(Step):
"""`FailStep` for SageMaker Pipelines Workflows."""

def __init__(
self,
name: str,
error_message: Union[str, PipelineNonPrimitiveInputTypes] = None,
display_name: str = None,
description: str = None,
depends_on: Union[List[str], List[Step]] = None,
):
"""Constructs a `FailStep`.

Args:
name (str): The name of the `FailStep`. A name is required and must be
unique within a pipeline.
error_message (str or PipelineNonPrimitiveInputTypes):
An error message defined by the user.
Once the `FailStep` is reached, the execution fails and the
error message is set as the failure reason (default: None).
display_name (str): The display name of the `FailStep`.
The display name provides better UI readability. (default: None).
description (str): The description of the `FailStep` (default: None).
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
that this `FailStep` depends on.
If a listed `Step` name does not exist, an error is returned (default: None).
"""
super(FailStep, self).__init__(
name, display_name, description, StepTypeEnum.FAIL, depends_on
)
self.error_message = error_message if error_message is not None else ""

@property
def arguments(self) -> RequestType:
"""The arguments dictionary that is used to define the `FailStep`."""
return dict(ErrorMessage=self.error_message)

@property
def properties(self):
"""A `Properties` object is not available for the `FailStep`.

Executing a `FailStep` will terminate the pipeline.
`FailStep` properties should not be referenced.
"""
raise RuntimeError(
"FailStep is a terminal step and the Properties object is not available for it."
)
Loading