Skip to content

Commit 17a479a

Browse files
ragdhallRaghav DhallBasilBeiroutiaaronmarkham
authored
documentation: remove Other tab in Built-in algorithms section and mi… (#3317)
Co-authored-by: Raghav Dhall <[email protected]> Co-authored-by: Basil Beirouti <[email protected]> Co-authored-by: Aaron Markham <[email protected]>
1 parent 736f503 commit 17a479a

11 files changed

+160
-2128
lines changed

doc/algorithms/index.rst

+6-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22
Built-in Algorithms
33
######################
44

5-
Amazon SageMaker provides implementations of some common machine learning algorithms optimized for GPU architecture and massive datasets.
5+
Built-in algorithms are offered in 2 modes:
6+
7+
* Container mode algorithms offered through :ref:`Estimators <estimators>` & :ref:`Amazon Estimators <amazon_estimators>`
8+
9+
* Script mode algorithms based on `pre-built SageMaker Docker Images <https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers-prebuilt.html>`__ offered through Estimators
610

711
.. toctree::
812
:maxdepth: 2
913

14+
sagemaker.amazon.amazon_estimator
1015
tabular/index
1116
text/index
1217
time_series/index
1318
unsupervised/index
1419
vision/index
15-
other/index

doc/algorithms/other/index.rst

-10
This file was deleted.

doc/algorithms/other/sagemaker.amazon.amazon_estimator.rst renamed to doc/algorithms/sagemaker.amazon.amazon_estimator.rst

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _amazon_estimators:
2+
13
Amazon Estimators
24
--------------------
35

doc/algorithms/tabular/index.rst

-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,3 @@ Amazon SageMaker provides built-in algorithms that are tailored to the analysis
1515
linear_learner
1616
tabtransformer
1717
xgboost
18-
object2vec

doc/algorithms/text/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Amazon SageMaker provides algorithms that are tailored to the analysis of textua
1010
blazing_text
1111
lda
1212
ntm
13+
object2vec
1314
sequence_to_sequence
1415
text_classification_tensorflow
1516
sentence_pair_classification_tensorflow
@@ -19,4 +20,3 @@ Amazon SageMaker provides algorithms that are tailored to the analysis of textua
1920
text_summarization_hugging_face
2021
text_generation_hugging_face
2122
machine_translation_hugging_face
22-
text_embedding_tensorflow_mxnet

doc/algorithms/vision/index.rst

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ Amazon SageMaker provides image processing algorithms that are used for image cl
77
.. toctree::
88
:maxdepth: 2
99

10-
image_classification_mxnet
11-
image_classification_pytorch
1210
image_classification_tensorflow
11+
image_classification_pytorch
12+
image_classification_mxnet
13+
image_embedding_tensorflow
14+
instance_segmentation_mxnet
15+
object_detection_tensorflow
16+
object_detection_pytorch
1317
object_detection_mxnet_gluoncv
1418
object_detection_mxnet
15-
object_detection_pytorch
16-
object_detection_tensorflow
1719
semantic_segmentation_mxnet_gluoncv
1820
semantic_segmentation_mxnet
19-
instance_segmentation_mxnet
20-
image_embedding_tensorflow
21+
text_embedding_tensorflow_mxnet

doc/api/training/estimators.rst

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _estimators:
2+
13
Estimators
24
----------
35

doc/doc_utils/jumpstart_doc_utils.py

+142-32
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ class ProblemTypes(str, Enum):
5858
TABULAR_CLASSIFICATION = "Classification"
5959

6060

61+
class Frameworks(str, Enum):
62+
"""Possible frameworks for JumpStart models"""
63+
64+
TENSORFLOW = "Tensorflow Hub"
65+
PYTORCH = "Pytorch Hub"
66+
HUGGINGFACE = "HuggingFace"
67+
CATBOOST = "Catboost"
68+
GLUONCV = "GluonCV"
69+
LIGHTGBM = "LightGBM"
70+
XGBOOST = "XGBoost"
71+
SCIKIT_LEARN = "ScikitLearn"
72+
SOURCE = "Source"
73+
74+
6175
JUMPSTART_REGION = "eu-west-2"
6276
SDK_MANIFEST_FILE = "models_manifest.json"
6377
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
@@ -82,6 +96,61 @@ class ProblemTypes(str, Enum):
8296
Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION,
8397
}
8498

99+
TO_FRAMEWORK = {
100+
"Tensorflow Hub": Frameworks.TENSORFLOW,
101+
"Pytorch Hub": Frameworks.PYTORCH,
102+
"HuggingFace": Frameworks.HUGGINGFACE,
103+
"Catboost": Frameworks.CATBOOST,
104+
"GluonCV": Frameworks.GLUONCV,
105+
"LightGBM": Frameworks.LIGHTGBM,
106+
"XGBoost": Frameworks.XGBOOST,
107+
"ScikitLearn": Frameworks.SCIKIT_LEARN,
108+
"Source": Frameworks.SOURCE,
109+
}
110+
111+
112+
MODALITY_MAP = {
113+
(Tasks.IC, Frameworks.PYTORCH): "algorithms/vision/image_classification_pytorch.rst",
114+
(Tasks.IC, Frameworks.TENSORFLOW): "algorithms/vision/image_classification_tensorflow.rst",
115+
(Tasks.IC_EMBEDDING, Frameworks.TENSORFLOW): "algorithms/vision/image_embedding_tensorflow.rst",
116+
(Tasks.IS, Frameworks.GLUONCV): "algorithms/vision/instance_segmentation_mxnet.rst",
117+
(Tasks.OD, Frameworks.GLUONCV): "algorithms/vision/object_detection_mxnet.rst",
118+
(Tasks.OD, Frameworks.PYTORCH): "algorithms/vision/object_detection_pytorch.rst",
119+
(Tasks.OD, Frameworks.TENSORFLOW): "algorithms/vision/object_detection_tensorflow.rst",
120+
(Tasks.SEMSEG, Frameworks.GLUONCV): "algorithms/vision/semantic_segmentation_mxnet.rst",
121+
(
122+
Tasks.TRANSLATION,
123+
Frameworks.HUGGINGFACE,
124+
): "algorithms/text/machine_translation_hugging_face.rst",
125+
(Tasks.NER, Frameworks.GLUONCV): "algorithms/text/named_entity_recognition_hugging_face.rst",
126+
(Tasks.EQA, Frameworks.PYTORCH): "algorithms/text/question_answering_pytorch.rst",
127+
(
128+
Tasks.SPC,
129+
Frameworks.HUGGINGFACE,
130+
): "algorithms/text/sentence_pair_classification_hugging_face.rst",
131+
(
132+
Tasks.SPC,
133+
Frameworks.TENSORFLOW,
134+
): "algorithms/text/sentence_pair_classification_tensorflow.rst",
135+
(Tasks.TC, Frameworks.TENSORFLOW): "algorithms/text/text_classification_tensorflow.rst",
136+
(
137+
Tasks.TC_EMBEDDING,
138+
Frameworks.GLUONCV,
139+
): "algorithms/vision/text_embedding_tensorflow_mxnet.rst",
140+
(
141+
Tasks.TC_EMBEDDING,
142+
Frameworks.TENSORFLOW,
143+
): "algorithms/vision/text_embedding_tensorflow_mxnet.rst",
144+
(
145+
Tasks.TEXT_GENERATION,
146+
Frameworks.HUGGINGFACE,
147+
): "algorithms/text/text_generation_hugging_face.rst",
148+
(
149+
Tasks.SUMMARIZATION,
150+
Frameworks.HUGGINGFACE,
151+
): "algorithms/text/text_summarization_hugging_face.rst",
152+
}
153+
85154

86155
def get_jumpstart_sdk_manifest():
87156
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
@@ -102,6 +171,10 @@ def get_model_task(id):
102171
return TASK_MAP[task_short] if task_short in TASK_MAP else "Source"
103172

104173

174+
def get_string_model_task(id):
175+
return id.split("-")[1]
176+
177+
105178
def get_model_source(url):
106179
if "tfhub" in url:
107180
return "Tensorflow Hub"
@@ -113,8 +186,6 @@ def get_model_source(url):
113186
return "Catboost"
114187
if "gluon" in url:
115188
return "GluonCV"
116-
if "catboost" in url:
117-
return "Catboost"
118189
if "lightgbm" in url:
119190
return "LightGBM"
120191
if "xgboost" in url:
@@ -138,58 +209,97 @@ def create_jumpstart_model_table():
138209
) < Version(model["version"]):
139210
sdk_manifest_top_versions_for_models[model["model_id"]] = model
140211

141-
file_content = []
212+
file_content_intro = []
142213

143-
file_content.append(".. _all-pretrained-models:\n\n")
144-
file_content.append(".. |external-link| raw:: html\n\n")
145-
file_content.append(' <i class="fa fa-external-link"></i>\n\n')
214+
file_content_intro.append(".. _all-pretrained-models:\n\n")
215+
file_content_intro.append(".. |external-link| raw:: html\n\n")
216+
file_content_intro.append(' <i class="fa fa-external-link"></i>\n\n')
146217

147-
file_content.append("================================================\n")
148-
file_content.append("Built-in Algorithms with pre-trained Model Table\n")
149-
file_content.append("================================================\n")
150-
file_content.append(
218+
file_content_intro.append("================================================\n")
219+
file_content_intro.append("Built-in Algorithms with pre-trained Model Table\n")
220+
file_content_intro.append("================================================\n")
221+
file_content_intro.append(
151222
"""
152223
The SageMaker Python SDK uses model IDs and model versions to access the necessary
153224
utilities for pre-trained models. This table serves to provide the core material plus
154225
some extra information that can be useful in selecting the correct model ID and
155226
corresponding parameters.\n"""
156227
)
157-
file_content.append(
228+
file_content_intro.append(
158229
"""
159230
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
160231
We highly suggest pinning an exact model version however.\n"""
161232
)
162-
file_content.append(
233+
file_content_intro.append(
163234
"""
164235
These models are also available through the
165236
`JumpStart UI in SageMaker Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n"""
166237
)
167-
file_content.append("\n")
168-
file_content.append(".. list-table:: Available Models\n")
169-
file_content.append(" :widths: 50 20 20 20 30 20\n")
170-
file_content.append(" :header-rows: 1\n")
171-
file_content.append(" :class: datatable\n")
172-
file_content.append("\n")
173-
file_content.append(" * - Model ID\n")
174-
file_content.append(" - Fine Tunable?\n")
175-
file_content.append(" - Latest Version\n")
176-
file_content.append(" - Min SDK Version\n")
177-
file_content.append(" - Problem Type\n")
178-
file_content.append(" - Source\n")
238+
file_content_intro.append("\n")
239+
file_content_intro.append(".. list-table:: Available Models\n")
240+
file_content_intro.append(" :widths: 50 20 20 20 30 20\n")
241+
file_content_intro.append(" :header-rows: 1\n")
242+
file_content_intro.append(" :class: datatable\n")
243+
file_content_intro.append("\n")
244+
file_content_intro.append(" * - Model ID\n")
245+
file_content_intro.append(" - Fine Tunable?\n")
246+
file_content_intro.append(" - Latest Version\n")
247+
file_content_intro.append(" - Min SDK Version\n")
248+
file_content_intro.append(" - Problem Type\n")
249+
file_content_intro.append(" - Source\n")
250+
251+
dynamic_table_files = []
252+
file_content_entries = []
179253

180254
for model in sdk_manifest_top_versions_for_models.values():
181255
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
182256
model_task = get_model_task(model_spec["model_id"])
257+
string_model_task = get_string_model_task(model_spec["model_id"])
183258
model_source = get_model_source(model_spec["url"])
184-
file_content.append(" * - {}\n".format(model_spec["model_id"]))
185-
file_content.append(" - {}\n".format(model_spec["training_supported"]))
186-
file_content.append(" - {}\n".format(model["version"]))
187-
file_content.append(" - {}\n".format(model["min_version"]))
188-
file_content.append(" - {}\n".format(model_task))
189-
file_content.append(
259+
file_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
260+
file_content_entries.append(" - {}\n".format(model_spec["training_supported"]))
261+
file_content_entries.append(" - {}\n".format(model["version"]))
262+
file_content_entries.append(" - {}\n".format(model["min_version"]))
263+
file_content_entries.append(" - {}\n".format(model_task))
264+
file_content_entries.append(
190265
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
191266
)
192267

193-
f = open("doc_utils/pretrainedmodels.rst", "w")
194-
f.writelines(file_content)
268+
if (string_model_task, TO_FRAMEWORK[model_source]) in MODALITY_MAP:
269+
file_content_single_entry = []
270+
271+
if (
272+
MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])]
273+
not in dynamic_table_files
274+
):
275+
file_content_single_entry.append("\n")
276+
file_content_single_entry.append(".. list-table:: Available Models\n")
277+
file_content_single_entry.append(" :widths: 50 20 20 20 20\n")
278+
file_content_single_entry.append(" :header-rows: 1\n")
279+
file_content_single_entry.append(" :class: datatable\n")
280+
file_content_single_entry.append("\n")
281+
file_content_single_entry.append(" * - Model ID\n")
282+
file_content_single_entry.append(" - Fine Tunable?\n")
283+
file_content_single_entry.append(" - Latest Version\n")
284+
file_content_single_entry.append(" - Min SDK Version\n")
285+
file_content_single_entry.append(" - Source\n")
286+
287+
dynamic_table_files.append(
288+
MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])]
289+
)
290+
291+
file_content_single_entry.append(" * - {}\n".format(model_spec["model_id"]))
292+
file_content_single_entry.append(" - {}\n".format(model_spec["training_supported"]))
293+
file_content_single_entry.append(" - {}\n".format(model["version"]))
294+
file_content_single_entry.append(" - {}\n".format(model["min_version"]))
295+
file_content_single_entry.append(
296+
" - `{} <{}>`__\n".format(model_source, model_spec["url"])
297+
)
298+
f = open(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])], "a")
299+
f.writelines(file_content_single_entry)
300+
f.close()
301+
302+
f = open("doc_utils/pretrainedmodels.rst", "a")
303+
f.writelines(file_content_intro)
304+
f.writelines(file_content_entries)
195305
f.close()

0 commit comments

Comments
 (0)