diff --git a/doc/_static/js/datatable.js b/doc/_static/js/datatable.js new file mode 100644 index 0000000000..897204e8df --- /dev/null +++ b/doc/_static/js/datatable.js @@ -0,0 +1,4 @@ +$(document).ready( function () { + $('table.datatable').DataTable(); + $('a.external').attr('target', '_blank'); +} ); \ No newline at end of file diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 07aea20f3e..47cf6e5f39 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -67,6 +67,11 @@ def create_jumpstart_model_table(): We highly suggest pinning an exact model version however.\n """ ) + file_content.append( + """ + Each model id is linked to an external page that describes the model.\n + """ + ) file_content.append("\n") file_content.append(".. list-table:: Available Models\n") file_content.append(" :widths: 50 20 20 20\n") @@ -80,7 +85,7 @@ def create_jumpstart_model_table(): for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) - file_content.append(" * - {}\n".format(model["model_id"])) + file_content.append(" * - `{} <{}>`_\n".format(model_spec["model_id"], model_spec["url"])) file_content.append(" - {}\n".format(model_spec["training_supported"])) file_content.append(" - {}\n".format(model["version"])) file_content.append(" - {}\n".format(model["min_version"])) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py new file mode 100644 index 0000000000..2ca4823b82 --- /dev/null +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores notebook utils related to SageMaker JumpStart.""" +from __future__ import absolute_import + +from sagemaker.jumpstart import accessors +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + + +def get_model_url( + model_id: str, model_version: str, region: str = JUMPSTART_DEFAULT_REGION_NAME +) -> str: + """Retrieve web url describing pretrained model. + + Args: + model_id (str): The model ID for which to retrieve the url. + model_version (str): The model version for which to retrieve the url. + region (str): Optional. The region from which to retrieve metadata. + (Default: JUMPSTART_DEFAULT_REGION_NAME) + """ + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=model_version + ) + return model_specs.url diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index b9384ca042..844d6d12b7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -272,6 +272,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): __slots__ = [ "model_id", + "url", "version", "min_sdk_version", "incremental_training_supported", @@ -308,6 +309,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. """ self.model_id: str = json_obj["model_id"] + self.url: str = json_obj["url"] self.version: str = json_obj["version"] self.min_sdk_version: str = json_obj["min_sdk_version"] self.incremental_training_supported: bool = bool(json_obj["incremental_training_supported"]) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index ebb3214e4c..9a616adfe7 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -15,6 +15,7 @@ PROTOTYPICAL_MODEL_SPECS_DICT = { "pytorch-eqa-bert-base-cased": { "model_id": "pytorch-eqa-bert-base-cased", + "url": "https://pytorch.org/hub/huggingface_pytorch-transformers/", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -146,6 +147,7 @@ }, "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1": { "model_id": "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "url": "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -258,6 +260,7 @@ }, "mxnet-semseg-fcn-resnet50-ade": { "model_id": "mxnet-semseg-fcn-resnet50-ade", + "url": "https://cv.gluon.ai/model_zoo/segmentation.html", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -369,6 +372,7 @@ }, "huggingface-spc-bert-base-cased": { "model_id": "huggingface-spc-bert-base-cased", + "url": "https://huggingface.co/bert-base-cased", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -482,6 +486,7 @@ }, "lightgbm-classification-model": { "model_id": "lightgbm-classification-model", + "url": "https://lightgbm.readthedocs.io/en/latest/", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -640,6 +645,7 @@ }, "catboost-classification-model": { "model_id": "catboost-classification-model", + "url": "https://catboost.ai/", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -792,6 +798,7 @@ }, "xgboost-classification-model": { "model_id": "xgboost-classification-model", + "url": "https://xgboost.readthedocs.io/en/latest/", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -945,6 +952,7 @@ }, "sklearn-classification-linear": { "model_id": "sklearn-classification-linear", + "url": "https://scikit-learn.org/stable/", "version": "1.0.0", "min_sdk_version": "2.68.1", "training_supported": True, @@ -1066,6 +1074,7 @@ BASE_SPEC = { "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", "version": "1.0.0", "min_sdk_version": "2.49.0", "training_supported": True, diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py new file mode 100644 index 0000000000..db5eb2b16d --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -0,0 +1,41 @@ +from __future__ import absolute_import +from unittest.mock import Mock, patch +from sagemaker.jumpstart import notebook_utils +from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_model_spec, +) + + +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_get_model_url( + patched_get_model_specs: Mock, +): + + patched_get_model_specs.side_effect = get_prototype_model_spec + + model_id, version = "xgboost-classification-model", "1.0.0" + assert "https://xgboost.readthedocs.io/en/latest/" == notebook_utils.get_model_url( + model_id, version + ) + + model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" + assert ( + "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1" + == notebook_utils.get_model_url(model_id, version) + ) + + model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" + region = "fake-region" + + patched_get_model_specs.reset_mock() + patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( + *largs, + region="us-west-2", + **{key: value for key, value in kwargs.items() if key != "region"} + ) + + notebook_utils.get_model_url(model_id, version, region=region) + + patched_get_model_specs.assert_called_once_with( + model_id=model_id, version=version, region=region + )