Skip to content

Commit 4fc7f2c

Browse files
authored
feat: jumpstart model url (#3036)
* feat: jumpstart model url * fix: jumpstart docs url * fix: jumpstart docstring * docs: add static datatable.js file * fix: docstring typo * fix: docstring capitalization
1 parent 9e34237 commit 4fc7f2c

File tree

6 files changed

+97
-1
lines changed

6 files changed

+97
-1
lines changed

doc/_static/js/datatable.js

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// Runs once the DOM is ready: enhances the rendered documentation page.
$(document).ready(function () {
    // Initialise the DataTables plugin on every table carrying the "datatable" class.
    var dataTables = $('table.datatable');
    dataTables.DataTable();

    // Set target="_blank" on every anchor carrying the "external" class.
    var externalLinks = $('a.external');
    externalLinks.attr('target', '_blank');
});

doc/doc_utils/jumpstart_doc_utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def create_jumpstart_model_table():
6767
We highly suggest pinning an exact model version however.\n
6868
"""
6969
)
70+
file_content.append(
71+
"""
72+
Each model id is linked to an external page that describes the model.\n
73+
"""
74+
)
7075
file_content.append("\n")
7176
file_content.append(".. list-table:: Available Models\n")
7277
file_content.append(" :widths: 50 20 20 20\n")
@@ -80,7 +85,7 @@ def create_jumpstart_model_table():
8085

8186
for model in sdk_manifest_top_versions_for_models.values():
8287
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
83-
file_content.append(" * - {}\n".format(model["model_id"]))
88+
file_content.append(" * - `{} <{}>`_\n".format(model_spec["model_id"], model_spec["url"]))
8489
file_content.append(" - {}\n".format(model_spec["training_supported"]))
8590
file_content.append(" - {}\n".format(model["version"]))
8691
file_content.append(" - {}\n".format(model["min_version"]))
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module stores notebook utils related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.jumpstart import accessors
17+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
18+
19+
20+
def get_model_url(
    model_id: str, model_version: str, region: str = JUMPSTART_DEFAULT_REGION_NAME
) -> str:
    """Return the web url of the page describing a pretrained model.

    Args:
        model_id (str): The model ID whose url is requested.
        model_version (str): The model version whose url is requested.
        region (str): Optional. The region from which the model metadata is
            retrieved. (Default: JUMPSTART_DEFAULT_REGION_NAME)

    Returns:
        str: The ``url`` attribute of the specs returned by
            ``JumpStartModelsAccessor.get_model_specs`` for this model.
    """
    # The specs are looked up with keyword arguments only; the url is read
    # straight off the returned specs object.
    return accessors.JumpStartModelsAccessor.get_model_specs(
        region=region,
        model_id=model_id,
        version=model_version,
    ).url

src/sagemaker/jumpstart/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
272272

273273
__slots__ = [
274274
"model_id",
275+
"url",
275276
"version",
276277
"min_sdk_version",
277278
"incremental_training_supported",
@@ -308,6 +309,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
308309
json_obj (Dict[str, Any]): Dictionary representation of spec.
309310
"""
310311
self.model_id: str = json_obj["model_id"]
312+
self.url: str = json_obj["url"]
311313
self.version: str = json_obj["version"]
312314
self.min_sdk_version: str = json_obj["min_sdk_version"]
313315
self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"])

tests/unit/sagemaker/jumpstart/constants.py

+9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
PROTOTYPICAL_MODEL_SPECS_DICT = {
1616
"pytorch-eqa-bert-base-cased": {
1717
"model_id": "pytorch-eqa-bert-base-cased",
18+
"url": "https://pytorch.org/hub/huggingface_pytorch-transformers/",
1819
"version": "1.0.0",
1920
"min_sdk_version": "2.68.1",
2021
"training_supported": True,
@@ -146,6 +147,7 @@
146147
},
147148
"tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1": {
148149
"model_id": "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
150+
"url": "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1",
149151
"version": "1.0.0",
150152
"min_sdk_version": "2.68.1",
151153
"training_supported": True,
@@ -258,6 +260,7 @@
258260
},
259261
"mxnet-semseg-fcn-resnet50-ade": {
260262
"model_id": "mxnet-semseg-fcn-resnet50-ade",
263+
"url": "https://cv.gluon.ai/model_zoo/segmentation.html",
261264
"version": "1.0.0",
262265
"min_sdk_version": "2.68.1",
263266
"training_supported": True,
@@ -369,6 +372,7 @@
369372
},
370373
"huggingface-spc-bert-base-cased": {
371374
"model_id": "huggingface-spc-bert-base-cased",
375+
"url": "https://huggingface.co/bert-base-cased",
372376
"version": "1.0.0",
373377
"min_sdk_version": "2.68.1",
374378
"training_supported": True,
@@ -482,6 +486,7 @@
482486
},
483487
"lightgbm-classification-model": {
484488
"model_id": "lightgbm-classification-model",
489+
"url": "https://lightgbm.readthedocs.io/en/latest/",
485490
"version": "1.0.0",
486491
"min_sdk_version": "2.68.1",
487492
"training_supported": True,
@@ -640,6 +645,7 @@
640645
},
641646
"catboost-classification-model": {
642647
"model_id": "catboost-classification-model",
648+
"url": "https://catboost.ai/",
643649
"version": "1.0.0",
644650
"min_sdk_version": "2.68.1",
645651
"training_supported": True,
@@ -792,6 +798,7 @@
792798
},
793799
"xgboost-classification-model": {
794800
"model_id": "xgboost-classification-model",
801+
"url": "https://xgboost.readthedocs.io/en/latest/",
795802
"version": "1.0.0",
796803
"min_sdk_version": "2.68.1",
797804
"training_supported": True,
@@ -945,6 +952,7 @@
945952
},
946953
"sklearn-classification-linear": {
947954
"model_id": "sklearn-classification-linear",
955+
"url": "https://scikit-learn.org/stable/",
948956
"version": "1.0.0",
949957
"min_sdk_version": "2.68.1",
950958
"training_supported": True,
@@ -1066,6 +1074,7 @@
10661074

10671075
BASE_SPEC = {
10681076
"model_id": "pytorch-ic-mobilenet-v2",
1077+
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",
10691078
"version": "1.0.0",
10701079
"min_sdk_version": "2.49.0",
10711080
"training_supported": True,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import absolute_import
2+
from unittest.mock import Mock, patch
3+
from sagemaker.jumpstart import notebook_utils
4+
from tests.unit.sagemaker.jumpstart.utils import (
5+
get_prototype_model_spec,
6+
)
7+
8+
9+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_get_model_url(
    patched_get_model_specs: Mock,
):
    """Check that get_model_url returns the spec url and forwards its arguments."""

    patched_get_model_specs.side_effect = get_prototype_model_spec

    # Case 1: url lookups with the default region.
    xgboost_url = "https://xgboost.readthedocs.io/en/latest/"
    assert xgboost_url == notebook_utils.get_model_url("xgboost-classification-model", "1.0.0")

    tensorflow_model_id = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"
    tensorflow_version = "1.0.0"
    tensorflow_url = "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1"
    assert tensorflow_url == notebook_utils.get_model_url(tensorflow_model_id, tensorflow_version)

    # Case 2: an explicit region must be passed through to the accessor unchanged.
    fake_region = "fake-region"

    def _specs_with_fixed_region(*largs, **kwargs):
        # Replace whatever region the caller supplied with "us-west-2" before
        # delegating to the prototype spec helper.
        remaining_kwargs = {key: value for key, value in kwargs.items() if key != "region"}
        return get_prototype_model_spec(*largs, region="us-west-2", **remaining_kwargs)

    patched_get_model_specs.reset_mock()
    patched_get_model_specs.side_effect = _specs_with_fixed_region

    notebook_utils.get_model_url(tensorflow_model_id, tensorflow_version, region=fake_region)

    patched_get_model_specs.assert_called_once_with(
        model_id=tensorflow_model_id, version=tensorflow_version, region=fake_region
    )

0 commit comments

Comments
 (0)