Skip to content

feat: support JumpStart proprietary models #4467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5aeaa35
feat: add proprietary manifest/specs parsing
Captainia Feb 22, 2024
4d58cfd
remove unused imports and fix docstyle
Captainia Feb 29, 2024
8ba6477
fix: remove unused args
Captainia Feb 29, 2024
29ac23c
fix: remove unused args
Captainia Feb 29, 2024
57a7d37
fix: more unused vars
Captainia Feb 29, 2024
a8ffdc2
fix: slow tests
Captainia Feb 29, 2024
d71b727
fix: unittests
Captainia Feb 29, 2024
359ea1c
added more tests to cover some lines
Captainia Feb 29, 2024
33fc27e
Merge branch 'master' into prop-model-deploy
Captainia Feb 29, 2024
04e1376
remove estimator warn check
Captainia Mar 1, 2024
f74a3e4
chore: address comments re performance
Captainia Mar 1, 2024
a96ea08
fix: address comments
Captainia Mar 4, 2024
cb8aee8
Merge branch 'master' into prop-model-deploy
Captainia Mar 5, 2024
e21b98b
complete list experience and other fixes
Captainia Mar 5, 2024
feef2a2
fix: pylint
Captainia Mar 5, 2024
33a2b59
add doc utils and fix pylint
Captainia Mar 5, 2024
b045b44
Merge branch 'master' into prop-model-deploy
Captainia Mar 5, 2024
8c90641
fix: docstyle
Captainia Mar 5, 2024
5bccb8e
fix: doc
Captainia Mar 5, 2024
f8258b7
fix: default payloads
Captainia Mar 5, 2024
896a2cf
fix: doc and tags and enums
Captainia Mar 6, 2024
d701211
fix: jumpstart doc
Captainia Mar 6, 2024
e3e64ba
rename to open_weights and fix filtering
Captainia Mar 8, 2024
d5b9b76
Merge branch 'master' into prop-model-deploy
Captainia Mar 8, 2024
27e14b9
update filter name
Captainia Mar 8, 2024
ec816c3
doc update
Captainia Mar 8, 2024
07fa93e
fix: black
Captainia Mar 8, 2024
bceb17b
rename to proprietary model and fix unittests
Captainia Mar 8, 2024
6fb885e
address comments
Captainia Mar 11, 2024
22d2078
Merge branch 'master' into prop-model-deploy
Captainia Mar 11, 2024
40542b7
fix: docstyle and flake8
Captainia Mar 11, 2024
abfadca
address more comments and fix doc
Captainia Mar 11, 2024
46ae293
put back doc utils for future refactoring
Captainia Mar 11, 2024
5f053d5
add prop model title in doc
Captainia Mar 11, 2024
9ec6f8e
doc update
Captainia Mar 11, 2024
c1c2a6f
Merge branch 'master' into prop-model-deploy
liujiaorr Mar 12, 2024
2e70c75
Merge branch 'master' into prop-model-deploy
liujiaorr Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions doc/doc_utils/jumpstart_doc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ class Frameworks(str, Enum):

JUMPSTART_REGION = "eu-west-2"
SDK_MANIFEST_FILE = "models_manifest.json"
PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json"
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
JUMPSTART_REGION, JUMPSTART_REGION
)
PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For open-weights models we point to a wave-4 region, DUB IIRC.

Could you use the same for consistency?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just found an issue with regionalization in some regions, raising a fix today. I'll update this to DUB in a separate PR to update documentation of list_jumpstart_models, and I'd like to refactor the doc utils in that PR as well


TASK_MAP = {
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
Expand Down Expand Up @@ -152,18 +155,26 @@ class Frameworks(str, Enum):
}


def get_jumpstart_sdk_manifest():
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
def get_public_s3_json_object(url):
with request.urlopen(url) as f:
models_manifest = f.read().decode("utf-8")
return json.loads(models_manifest)


def get_jumpstart_sdk_spec(key):
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key)
with request.urlopen(url) as f:
model_spec = f.read().decode("utf-8")
return json.loads(model_spec)
def get_jumpstart_sdk_manifest():
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}")


def get_proprietary_sdk_manifest():
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}")


def get_jumpstart_sdk_spec(s3_key: str):
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}")


def get_proprietary_sdk_spec(s3_key: str):
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}")


def get_model_task(id):
Expand Down Expand Up @@ -196,6 +207,45 @@ def get_model_source(url):
return "Source"


def create_proprietary_model_table():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is wild, I would just use f strings with triple quotes so that you can have line breaks and have a better visual of what the final string output would look like

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I wasn't refactoring the doc_utils file too much and followed existing patterns, but yes updated this section for sure

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you run this by @judyheflin , as this is a documentation change?

proprietary_content_intro = []
proprietary_content_intro.append("\n")
proprietary_content_intro.append(".. list-table:: Available Proprietary Models\n")
proprietary_content_intro.append(" :widths: 50 20 20 20 20\n")
proprietary_content_intro.append(" :header-rows: 1\n")
proprietary_content_intro.append(" :class: datatable\n")
proprietary_content_intro.append("\n")
proprietary_content_intro.append(" * - Model ID\n")
proprietary_content_intro.append(" - Fine Tunable?\n")
proprietary_content_intro.append(" - Supported Version\n")
proprietary_content_intro.append(" - Min SDK Version\n")
proprietary_content_intro.append(" - Source\n")

sdk_manifest = get_proprietary_sdk_manifest()
sdk_manifest_top_versions_for_models = {}

for model in sdk_manifest:
if model["model_id"] not in sdk_manifest_top_versions_for_models:
sdk_manifest_top_versions_for_models[model["model_id"]] = model
else:
if str(sdk_manifest_top_versions_for_models[model["model_id"]]["version"]) < str(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we not casting to Version (is it cause of proprietary versions)? This'll cause 2.100.0 > 2.9.0 to fail, unless we pad all the versions with zeros.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is, Version("v1.0.3-v1") throws, as they could have wild versioning system. This should get the top one in our metadata which is the latest one, with string comparison

model["version"]
):
sdk_manifest_top_versions_for_models[model["model_id"]] = model

proprietary_content_entries = []
for model in sdk_manifest_top_versions_for_models.values():
model_spec = get_proprietary_sdk_spec(model["spec_key"])
proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training
proprietary_content_entries.append(" - {}\n".format(model["version"]))
proprietary_content_entries.append(" - {}\n".format(model["min_version"]))
proprietary_content_entries.append(
" - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url"))
)
return proprietary_content_intro + proprietary_content_entries + ["\n"]


def create_jumpstart_model_table():
sdk_manifest = get_jumpstart_sdk_manifest()
sdk_manifest_top_versions_for_models = {}
Expand Down Expand Up @@ -249,19 +299,19 @@ def create_jumpstart_model_table():
file_content_intro.append(" - Source\n")

dynamic_table_files = []
file_content_entries = []
open_weight_content_entries = []

for model in sdk_manifest_top_versions_for_models.values():
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
model_task = get_model_task(model_spec["model_id"])
string_model_task = get_string_model_task(model_spec["model_id"])
model_source = get_model_source(model_spec["url"])
file_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
file_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
file_content_entries.append(" - {}\n".format(model["version"]))
file_content_entries.append(" - {}\n".format(model["min_version"]))
file_content_entries.append(" - {}\n".format(model_task))
file_content_entries.append(
open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
open_weight_content_entries.append(" - {}\n".format(model["version"]))
open_weight_content_entries.append(" - {}\n".format(model["min_version"]))
open_weight_content_entries.append(" - {}\n".format(model_task))
open_weight_content_entries.append(
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
)

Expand Down Expand Up @@ -299,7 +349,10 @@ def create_jumpstart_model_table():
f.writelines(file_content_single_entry)
f.close()

proprietary_content_entries = create_proprietary_model_table()

f = open("doc_utils/pretrainedmodels.rst", "a")
f.writelines(file_content_intro)
f.writelines(file_content_entries)
f.writelines(open_weight_content_entries)
f.writelines(proprietary_content_entries)
f.close()
3 changes: 3 additions & 0 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.session import Session


Expand Down Expand Up @@ -75,6 +76,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the default accept type for the model matching the given arguments.

Expand Down Expand Up @@ -114,4 +116,5 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
4 changes: 3 additions & 1 deletion src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME

from sagemaker.lineage.context import EndpointContext
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.compute_resource_requirements.resource_requirements import (
ResourceRequirements,
)

LOGGER = logging.getLogger("sagemaker")

Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.session import Session


Expand Down Expand Up @@ -75,6 +76,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.

Expand Down Expand Up @@ -114,6 +116,7 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.session import Session


Expand Down Expand Up @@ -95,6 +96,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.

Expand Down Expand Up @@ -135,4 +137,5 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
3 changes: 3 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.session import Session

logger = logging.getLogger(__name__)
Expand All @@ -34,6 +35,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.

Expand Down Expand Up @@ -85,6 +87,7 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
model_type=model_type,
)


Expand Down
29 changes: 22 additions & 7 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME

Expand Down Expand Up @@ -197,7 +198,9 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:

@staticmethod
def _get_manifest(
region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None
region: str = JUMPSTART_DEFAULT_REGION_NAME,
s3_client: Optional[boto3.client] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest.

Expand All @@ -215,13 +218,19 @@ def _get_manifest(
additional_kwargs.update({"s3_client": s3_client})

cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs},
region,
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
return JumpStartModelsAccessor._cache.get_manifest() # type: ignore
return JumpStartModelsAccessor._cache.get_manifest(model_type) # type: ignore

@staticmethod
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
def get_model_header(
region: str,
model_id: str,
version: str,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelHeader:
"""Returns model header from JumpStart models cache.

Args:
Expand All @@ -234,12 +243,18 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
return JumpStartModelsAccessor._cache.get_header( # type: ignore
model_id=model_id, semantic_version_str=version
model_id=model_id,
semantic_version_str=version,
model_type=model_type,
)

@staticmethod
def get_model_specs(
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
region: str,
model_id: str,
version: str,
s3_client: Optional[boto3.client] = None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.

Expand All @@ -260,7 +275,7 @@ def get_model_specs(
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
model_id=model_id, version_str=version, model_type=model_type
)

@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
verify_model_region_and_return_specs,
Expand All @@ -38,6 +39,7 @@ def _retrieve_default_instance_type(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> str:
"""Retrieves the default instance type for the model.

Expand Down Expand Up @@ -84,6 +86,7 @@ def _retrieve_default_instance_type(
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
)

Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.jumpstart.utils import (
verify_model_region_and_return_specs,
Expand All @@ -35,6 +36,7 @@ def _retrieve_model_init_kwargs(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
"""Retrieves kwargs for `Model`.

Expand Down Expand Up @@ -71,6 +73,7 @@ def _retrieve_model_init_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)

kwargs = deepcopy(model_specs.model_kwargs)
Expand All @@ -89,6 +92,7 @@ def _retrieve_model_deploy_kwargs(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> dict:
"""Retrieves kwargs for `Model.deploy`.

Expand Down Expand Up @@ -128,6 +132,7 @@ def _retrieve_model_deploy_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)

if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
JumpStartModelType,
)
from sagemaker.session import Session

Expand All @@ -35,6 +36,7 @@ def _retrieve_model_package_arn(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[str]:
"""Retrieves associated model pacakge arn for the model.

Expand Down Expand Up @@ -74,6 +76,7 @@ def _retrieve_model_package_arn(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)

if scope == JumpStartScriptScope.INFERENCE:
Expand Down
Loading