Skip to content

fix: jumpstart cache using sagemaker session s3 client #4051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import List, Optional

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.session import Session


def retrieve_options(
Expand All @@ -23,6 +25,7 @@ def retrieve_options(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
"""Retrieves the supported accept types for the model matching the given arguments.

Expand All @@ -40,6 +43,10 @@ def retrieve_options(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please review this docstring with @judyheflin ?

Returns:
list: The supported accept types to use for the model.

Expand All @@ -57,6 +64,7 @@ def retrieve_options(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand All @@ -66,6 +74,7 @@ def retrieve_default(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Retrieves the default accept type for the model matching the given arguments.

Expand All @@ -83,6 +92,10 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
str: The default accept type to use for the model.

Expand All @@ -100,4 +113,5 @@ def retrieve_default(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
14 changes: 14 additions & 0 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import List, Optional

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.session import Session


def retrieve_options(
Expand All @@ -23,6 +25,7 @@ def retrieve_options(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
"""Retrieves the supported content types for the model matching the given arguments.

Expand All @@ -40,6 +43,10 @@ def retrieve_options(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
list: The supported content types to use for the model.

Expand All @@ -57,6 +64,7 @@ def retrieve_options(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand All @@ -66,6 +74,7 @@ def retrieve_default(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.

Expand All @@ -83,6 +92,10 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
str: The default content type to use for the model.

Expand All @@ -100,6 +113,7 @@ def retrieve_default(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
)

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.session import Session


def retrieve_options(
Expand All @@ -41,6 +43,7 @@ def retrieve_options(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[BaseDeserializer]:
"""Retrieves the supported deserializers for the model matching the given arguments.

Expand All @@ -58,6 +61,10 @@ def retrieve_options(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
List[BaseDeserializer]: The supported deserializers to use for the model.

Expand All @@ -76,6 +83,7 @@ def retrieve_options(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand All @@ -85,6 +93,7 @@ def retrieve_default(
model_version: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.

Expand All @@ -102,6 +111,10 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
BaseDeserializer: The default deserializer to use for the model.

Expand All @@ -120,4 +133,5 @@ def retrieve_default(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
8 changes: 8 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.session import Session

logger = logging.getLogger(__name__)

Expand All @@ -30,6 +32,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Expand All @@ -51,6 +54,10 @@ def retrieve_default(
should be included. The `Model` class of the SageMaker Python SDK inserts environment
variables that would be required when making the low-level AWS API call.
(Default: True).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
dict: The variables to use for the model.

Expand All @@ -70,4 +77,5 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
include_aws_sdk_env_vars,
sagemaker_session=sagemaker_session,
)
25 changes: 25 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import HyperparameterValidationMode
from sagemaker.jumpstart.validators import validate_hyperparameters
from sagemaker.session import Session

logger = logging.getLogger(__name__)

Expand All @@ -32,6 +34,7 @@ def retrieve_default(
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.

Expand All @@ -56,6 +59,10 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
dict: The hyperparameters to use for the model.

Expand All @@ -74,6 +81,7 @@ def retrieve_default(
include_container_hyperparameters,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand All @@ -83,6 +91,9 @@ def validate(
model_version: Optional[str] = None,
hyperparameters: Optional[dict] = None,
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> None:
"""Validates hyperparameters for models.

Expand All @@ -100,6 +111,17 @@ def validate(
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

Raises:
JumpStartHyperparametersError: If the hyperparameter is not formatted correctly,
Expand All @@ -125,4 +147,7 @@ def validate(
hyperparameters=hyperparameters,
validation_mode=validation_mode,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
7 changes: 7 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from packaging.version import Version

from sagemaker import utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts
Expand Down Expand Up @@ -60,6 +61,7 @@ def retrieve(
sdk_version=None,
inference_tool=None,
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -109,6 +111,10 @@ def retrieve(
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to serverless endpoint. Instance type is
not provided in serverless inference. So this is used to determine processor type.
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -147,6 +153,7 @@ def retrieve(
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.session import Session

logger = logging.getLogger(__name__)

Expand All @@ -30,6 +32,7 @@ def retrieve_default(
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.

Expand All @@ -49,6 +52,10 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
str: The default instance type to use for the model.

Expand All @@ -70,6 +77,7 @@ def retrieve_default(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)


Expand All @@ -80,6 +88,7 @@ def retrieve(
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> List[str]:
"""Retrieves the supported training instance types for the model matching the given arguments.

Expand All @@ -97,6 +106,10 @@ def retrieve(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
list: The supported instance types to use for the model.

Expand All @@ -118,4 +131,5 @@ def retrieve(
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
Loading