Skip to content

Commit 7965e69

Browse files
bencrabtreejerrypeng7773
authored andcommitted
feature: Add Jumpstart example notebooks (aws#3068)
1 parent 02acb53 commit 7965e69

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

doc/doc_utils/jumpstart_doc_utils.py

+72-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,73 @@
1414
from urllib import request
1515
import json
1616
from packaging.version import Version
17+
from enum import Enum
18+
19+
20+
class Tasks(str, Enum):
    """The ML task name as referenced in the infix of the model ID.

    Each member mixes in ``str``, so a member compares equal to (and hashes
    like) its lowercase value; a plain string such as ``"ic"`` can therefore
    be used to look up a dictionary keyed by these members.
    """

    # Vision tasks.
    IC = "ic"
    OD = "od"
    OD1 = "od1"
    SEMSEG = "semseg"
    IS = "is"
    # Text tasks.
    TC = "tc"
    SPC = "spc"
    EQA = "eqa"
    TEXT_GENERATION = "textgeneration"
    # Embedding tasks.
    IC_EMBEDDING = "icembedding"
    TC_EMBEDDING = "tcembedding"
    # More text tasks.
    NER = "ner"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    # Tabular tasks.
    TABULAR_REGRESSION = "regression"
    TABULAR_CLASSIFICATION = "classification"
39+
40+
41+
class ProblemTypes(str, Enum):
    """Possible problem types for JumpStart models.

    Each member's value is the human-readable, title-cased name of the
    problem type. Members mix in ``str``, so a member compares equal to its
    value.
    """

    # Vision problem types.
    IMAGE_CLASSIFICATION = "Image Classification"
    IMAGE_EMBEDDING = "Image Embedding"
    OBJECT_DETECTION = "Object Detection"
    SEMANTIC_SEGMENTATION = "Semantic Segmentation"
    INSTANCE_SEGMENTATION = "Instance Segmentation"
    # Text problem types.
    TEXT_CLASSIFICATION = "Text Classification"
    TEXT_EMBEDDING = "Text Embedding"
    QUESTION_ANSWERING = "Question Answering"
    SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification"
    TEXT_GENERATION = "Text Generation"
    TEXT_SUMMARIZATION = "Text Summarization"
    MACHINE_TRANSLATION = "Machine Translation"
    NAMED_ENTITY_RECOGNITION = "Named Entity Recognition"
    # Tabular problem types.
    TABULAR_REGRESSION = "Regression"
    TABULAR_CLASSIFICATION = "Classification"
59+
1760

1861
# Region name substituted into the JumpStart cache bucket URL below.
JUMPSTART_REGION = "eu-west-2"
# File name of the SDK models manifest.
SDK_MANIFEST_FILE = "models_manifest.json"
# Base URL of the JumpStart cache bucket; the region appears twice, once in
# the bucket name and once in the S3 endpoint host.
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
    JUMPSTART_REGION, JUMPSTART_REGION
)
66+
# Maps the task infix of a model ID (a ``Tasks`` member) to the problem type
# shown for it (a ``ProblemTypes`` member). Both object-detection infixes,
# ``od`` and ``od1``, map to the same problem type.
TASK_MAP = {
    Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
    Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
    Tasks.OD: ProblemTypes.OBJECT_DETECTION,
    Tasks.OD1: ProblemTypes.OBJECT_DETECTION,
    Tasks.SEMSEG: ProblemTypes.SEMANTIC_SEGMENTATION,
    Tasks.IS: ProblemTypes.INSTANCE_SEGMENTATION,
    Tasks.TC: ProblemTypes.TEXT_CLASSIFICATION,
    Tasks.TC_EMBEDDING: ProblemTypes.TEXT_EMBEDDING,
    Tasks.EQA: ProblemTypes.QUESTION_ANSWERING,
    Tasks.SPC: ProblemTypes.SENTENCE_PAIR_CLASSIFICATION,
    Tasks.TEXT_GENERATION: ProblemTypes.TEXT_GENERATION,
    Tasks.SUMMARIZATION: ProblemTypes.TEXT_SUMMARIZATION,
    Tasks.TRANSLATION: ProblemTypes.MACHINE_TRANSLATION,
    Tasks.NER: ProblemTypes.NAMED_ENTITY_RECOGNITION,
    Tasks.TABULAR_REGRESSION: ProblemTypes.TABULAR_REGRESSION,
    Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION,
}
2384

2485

2586
def get_jumpstart_sdk_manifest():
@@ -36,6 +97,11 @@ def get_jumpstart_sdk_spec(key):
3697
return json.loads(model_spec)
3798

3899

100+
def get_model_task(id):
    """Return the problem-type label encoded in a JumpStart model ID.

    The task is taken from the second ``-``-separated token of the model ID
    and looked up in ``TASK_MAP``. The lookup works with a plain string
    because the ``TASK_MAP`` keys are ``str``-mixin enum members.

    Args:
        id (str): The model ID. The parameter name shadows the ``id`` builtin
            but is kept so existing keyword callers continue to work.

    Returns:
        str: The human-readable problem type (for example
        ``"Image Classification"``), or ``"Source"`` when the ID has no task
        token or the token is not a known task.
    """
    tokens = id.split("-")
    # An ID without a hyphen has no task infix; use the generic label rather
    # than raising IndexError.
    if len(tokens) < 2:
        return "Source"

    problem_type = TASK_MAP.get(tokens[1])
    if problem_type is None:
        return "Source"

    # Return the plain string value: formatting a mixed-in enum member with
    # ``str.format`` renders "ProblemTypes.X" on newer Python versions, which
    # would leak the enum name into the generated documentation.
    return problem_type.value
103+
104+
39105
def create_jumpstart_model_table():
40106
sdk_manifest = get_jumpstart_sdk_manifest()
41107
sdk_manifest_top_versions_for_models = {}
@@ -69,26 +135,29 @@ def create_jumpstart_model_table():
69135
)
70136
file_content.append(
71137
"""
72-
Each model id is linked to an external page that describes the model.\n
138+
Click on the Problem Type to navigate to the source of the model.\n
73139
"""
74140
)
75141
file_content.append("\n")
76142
file_content.append(".. list-table:: Available Models\n")
77-
file_content.append(" :widths: 50 20 20 20\n")
143+
file_content.append(" :widths: 50 20 20 20 30\n")
78144
file_content.append(" :header-rows: 1\n")
79145
file_content.append(" :class: datatable\n")
80146
file_content.append("\n")
81147
file_content.append(" * - Model ID\n")
82148
file_content.append(" - Fine Tunable?\n")
83149
file_content.append(" - Latest Version\n")
84150
file_content.append(" - Min SDK Version\n")
151+
file_content.append(" - Problem Type/Source\n")
85152

86153
for model in sdk_manifest_top_versions_for_models.values():
87154
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
88-
file_content.append(" * - `{} <{}>`_\n".format(model_spec["model_id"], model_spec["url"]))
155+
model_task = get_model_task(model_spec["model_id"])
156+
file_content.append(" * - {}\n".format(model_spec["model_id"]))
89157
file_content.append(" - {}\n".format(model_spec["training_supported"]))
90158
file_content.append(" - {}\n".format(model["version"]))
91159
file_content.append(" - {}\n".format(model["min_version"]))
160+
file_content.append(" - `{} <{}>`__\n".format(model_task, model_spec["url"]))
92161

93162
f = open("doc_utils/jumpstart.rst", "w")
94163
f.writelines(file_content)

doc/overview.rst

+37-3
Original file line numberDiff line numberDiff line change
@@ -573,15 +573,49 @@ Here is an example:
573573
# When you are done using your endpoint
574574
model.sagemaker_session.delete_endpoint('my-endpoint')
575575
576-
********************************************
577-
Use Prebuilt Models with SageMaker JumpStart
578-
********************************************
576+
*********************************************************
577+
Use SageMaker JumpStart Algorithms with Pretrained Models
578+
*********************************************************
579+
580+
JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary
581+
utilities. The table below provides the core material plus some extra information that can be useful
582+
in selecting the correct model ID and corresponding parameters.
579583

580584
.. toctree::
581585
:maxdepth: 2
582586

583587
doc_utils/jumpstart
584588

589+
Example notebooks
590+
=================
591+
592+
JumpStart supports 15 different machine learning problem types. Below is a list of all the supported
593+
problem types with a link to a Jupyter notebook that provides example usage.
594+
595+
Vision
596+
- `Image Classification <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_image_classification/Amazon_JumpStart_Image_Classification.ipynb>`__
597+
- `Object Detection <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_object_detection/Amazon_JumpStart_Object_Detection.ipynb>`__
598+
- `Semantic Segmentation <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_semantic_segmentation/Amazon_JumpStart_Semantic_Segmentation.ipynb>`__
599+
- `Instance Segmentation <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_instance_segmentation/Amazon_JumpStart_Instance_Segmentation.ipynb>`__
600+
- `Image Embedding <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_image_embedding/Amazon_JumpStart_Image_Embedding.ipynb>`__
601+
602+
Text
603+
- `Text Classification <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_text_classification/Amazon_JumpStart_Text_Classification.ipynb>`__
604+
- `Sentence Pair Classification <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_sentence_pair_classification/Amazon_JumpStart_Sentence_Pair_Classification.ipynb>`__
605+
- `Question Answering <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_question_answering/Amazon_JumpStart_Question_Answering.ipynb>`__
606+
- `Named Entity Recognition <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_named_entity_recognition/Amazon_JumpStart_Named_Entity_Recognition.ipynb>`__
607+
- `Text Summarization <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_text_summarization/Amazon_JumpStart_Text_Summarization.ipynb>`__
608+
- `Text Generation <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_text_generation/Amazon_JumpStart_Text_Generation.ipynb>`__
609+
- `Machine Translation <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_machine_translation/Amazon_JumpStart_Machine_Translation.ipynb>`__
610+
- `Text Embedding <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_text_embedding/Amazon_JumpStart_Text_Embedding.ipynb>`__
611+
612+
Tabular
613+
- `Tabular Classification (LightGBM & Catboost) <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_tabular_classification/Amazon_JumpStart_Tabular_Classification_LightGBM_CatBoost.ipynb>`__
614+
- `Tabular Classification (XGBoost & Linear Learner) <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_tabular_classification/Amazon_JumpStart_Tabular_Classification_XGBoost_LinearLearner.ipynb>`__
615+
- `Tabular Regression (LightGBM & Catboost) <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_tabular_regression/Amazon_JumpStart_Tabular_Regression_LightGBM_CatBoost.ipynb>`__
616+
- `Tabular Regression (XGBoost & Linear Learner) <https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart_tabular_regression/Amazon_JumpStart_Tabular_Regression_XGBoost_LinearLearner.ipynb>`__
617+
618+
585619
`Amazon SageMaker JumpStart <https://aws.amazon.com/sagemaker/getting-started/>`__ is a
586620
SageMaker feature that helps users bring machine learning (ML)
587621
applications to market using prebuilt solutions for common use cases,

0 commit comments

Comments
 (0)