@@ -58,6 +58,20 @@ class ProblemTypes(str, Enum):
58
58
TABULAR_CLASSIFICATION = "Classification"
59
59
60
60
61
+ class Frameworks (str , Enum ):
62
+ """Possible frameworks for JumpStart models"""
63
+
64
+ TENSORFLOW = "Tensorflow Hub"
65
+ PYTORCH = "Pytorch Hub"
66
+ HUGGINGFACE = "HuggingFace"
67
+ CATBOOST = "Catboost"
68
+ GLUONCV = "GluonCV"
69
+ LIGHTGBM = "LightGBM"
70
+ XGBOOST = "XGBoost"
71
+ SCIKIT_LEARN = "ScikitLearn"
72
+ SOURCE = "Source"
73
+
74
+
61
75
JUMPSTART_REGION = "eu-west-2"
62
76
SDK_MANIFEST_FILE = "models_manifest.json"
63
77
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
@@ -82,6 +96,61 @@ class ProblemTypes(str, Enum):
82
96
Tasks .TABULAR_CLASSIFICATION : ProblemTypes .TABULAR_CLASSIFICATION ,
83
97
}
84
98
99
+ TO_FRAMEWORK = {
100
+ "Tensorflow Hub" : Frameworks .TENSORFLOW ,
101
+ "Pytorch Hub" : Frameworks .PYTORCH ,
102
+ "HuggingFace" : Frameworks .HUGGINGFACE ,
103
+ "Catboost" : Frameworks .CATBOOST ,
104
+ "GluonCV" : Frameworks .GLUONCV ,
105
+ "LightGBM" : Frameworks .LIGHTGBM ,
106
+ "XGBoost" : Frameworks .XGBOOST ,
107
+ "ScikitLearn" : Frameworks .SCIKIT_LEARN ,
108
+ "Source" : Frameworks .SOURCE ,
109
+ }
110
+
111
+
112
+ MODALITY_MAP = {
113
+ (Tasks .IC , Frameworks .PYTORCH ): "algorithms/vision/image_classification_pytorch.rst" ,
114
+ (Tasks .IC , Frameworks .TENSORFLOW ): "algorithms/vision/image_classification_tensorflow.rst" ,
115
+ (Tasks .IC_EMBEDDING , Frameworks .TENSORFLOW ): "algorithms/vision/image_embedding_tensorflow.rst" ,
116
+ (Tasks .IS , Frameworks .GLUONCV ): "algorithms/vision/instance_segmentation_mxnet.rst" ,
117
+ (Tasks .OD , Frameworks .GLUONCV ): "algorithms/vision/object_detection_mxnet.rst" ,
118
+ (Tasks .OD , Frameworks .PYTORCH ): "algorithms/vision/object_detection_pytorch.rst" ,
119
+ (Tasks .OD , Frameworks .TENSORFLOW ): "algorithms/vision/object_detection_tensorflow.rst" ,
120
+ (Tasks .SEMSEG , Frameworks .GLUONCV ): "algorithms/vision/semantic_segmentation_mxnet.rst" ,
121
+ (
122
+ Tasks .TRANSLATION ,
123
+ Frameworks .HUGGINGFACE ,
124
+ ): "algorithms/text/machine_translation_hugging_face.rst" ,
125
+ (Tasks .NER , Frameworks .GLUONCV ): "algorithms/text/named_entity_recognition_hugging_face.rst" ,
126
+ (Tasks .EQA , Frameworks .PYTORCH ): "algorithms/text/question_answering_pytorch.rst" ,
127
+ (
128
+ Tasks .SPC ,
129
+ Frameworks .HUGGINGFACE ,
130
+ ): "algorithms/text/sentence_pair_classification_hugging_face.rst" ,
131
+ (
132
+ Tasks .SPC ,
133
+ Frameworks .TENSORFLOW ,
134
+ ): "algorithms/text/sentence_pair_classification_tensorflow.rst" ,
135
+ (Tasks .TC , Frameworks .TENSORFLOW ): "algorithms/text/text_classification_tensorflow.rst" ,
136
+ (
137
+ Tasks .TC_EMBEDDING ,
138
+ Frameworks .GLUONCV ,
139
+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
140
+ (
141
+ Tasks .TC_EMBEDDING ,
142
+ Frameworks .TENSORFLOW ,
143
+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
144
+ (
145
+ Tasks .TEXT_GENERATION ,
146
+ Frameworks .HUGGINGFACE ,
147
+ ): "algorithms/text/text_generation_hugging_face.rst" ,
148
+ (
149
+ Tasks .SUMMARIZATION ,
150
+ Frameworks .HUGGINGFACE ,
151
+ ): "algorithms/text/text_summarization_hugging_face.rst" ,
152
+ }
153
+
85
154
86
155
def get_jumpstart_sdk_manifest ():
87
156
url = "{}/{}" .format (JUMPSTART_BUCKET_BASE_URL , SDK_MANIFEST_FILE )
@@ -102,6 +171,10 @@ def get_model_task(id):
102
171
return TASK_MAP [task_short ] if task_short in TASK_MAP else "Source"
103
172
104
173
174
+ def get_string_model_task (id ):
175
+ return id .split ("-" )[1 ]
176
+
177
+
105
178
def get_model_source (url ):
106
179
if "tfhub" in url :
107
180
return "Tensorflow Hub"
@@ -113,8 +186,6 @@ def get_model_source(url):
113
186
return "Catboost"
114
187
if "gluon" in url :
115
188
return "GluonCV"
116
- if "catboost" in url :
117
- return "Catboost"
118
189
if "lightgbm" in url :
119
190
return "LightGBM"
120
191
if "xgboost" in url :
@@ -138,58 +209,97 @@ def create_jumpstart_model_table():
138
209
) < Version (model ["version" ]):
139
210
sdk_manifest_top_versions_for_models [model ["model_id" ]] = model
140
211
141
- file_content = []
212
+ file_content_intro = []
142
213
143
- file_content .append (".. _all-pretrained-models:\n \n " )
144
- file_content .append (".. |external-link| raw:: html\n \n " )
145
- file_content .append (' <i class="fa fa-external-link"></i>\n \n ' )
214
+ file_content_intro .append (".. _all-pretrained-models:\n \n " )
215
+ file_content_intro .append (".. |external-link| raw:: html\n \n " )
216
+ file_content_intro .append (' <i class="fa fa-external-link"></i>\n \n ' )
146
217
147
- file_content .append ("================================================\n " )
148
- file_content .append ("Built-in Algorithms with pre-trained Model Table\n " )
149
- file_content .append ("================================================\n " )
150
- file_content .append (
218
+ file_content_intro .append ("================================================\n " )
219
+ file_content_intro .append ("Built-in Algorithms with pre-trained Model Table\n " )
220
+ file_content_intro .append ("================================================\n " )
221
+ file_content_intro .append (
151
222
"""
152
223
The SageMaker Python SDK uses model IDs and model versions to access the necessary
153
224
utilities for pre-trained models. This table serves to provide the core material plus
154
225
some extra information that can be useful in selecting the correct model ID and
155
226
corresponding parameters.\n """
156
227
)
157
- file_content .append (
228
+ file_content_intro .append (
158
229
"""
159
230
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
160
231
We highly suggest pinning an exact model version however.\n """
161
232
)
162
- file_content .append (
233
+ file_content_intro .append (
163
234
"""
164
235
These models are also available through the
165
236
`JumpStart UI in SageMaker Studio <https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n """
166
237
)
167
- file_content .append ("\n " )
168
- file_content .append (".. list-table:: Available Models\n " )
169
- file_content .append (" :widths: 50 20 20 20 30 20\n " )
170
- file_content .append (" :header-rows: 1\n " )
171
- file_content .append (" :class: datatable\n " )
172
- file_content .append ("\n " )
173
- file_content .append (" * - Model ID\n " )
174
- file_content .append (" - Fine Tunable?\n " )
175
- file_content .append (" - Latest Version\n " )
176
- file_content .append (" - Min SDK Version\n " )
177
- file_content .append (" - Problem Type\n " )
178
- file_content .append (" - Source\n " )
238
+ file_content_intro .append ("\n " )
239
+ file_content_intro .append (".. list-table:: Available Models\n " )
240
+ file_content_intro .append (" :widths: 50 20 20 20 30 20\n " )
241
+ file_content_intro .append (" :header-rows: 1\n " )
242
+ file_content_intro .append (" :class: datatable\n " )
243
+ file_content_intro .append ("\n " )
244
+ file_content_intro .append (" * - Model ID\n " )
245
+ file_content_intro .append (" - Fine Tunable?\n " )
246
+ file_content_intro .append (" - Latest Version\n " )
247
+ file_content_intro .append (" - Min SDK Version\n " )
248
+ file_content_intro .append (" - Problem Type\n " )
249
+ file_content_intro .append (" - Source\n " )
250
+
251
+ dynamic_table_files = []
252
+ file_content_entries = []
179
253
180
254
for model in sdk_manifest_top_versions_for_models .values ():
181
255
model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
182
256
model_task = get_model_task (model_spec ["model_id" ])
257
+ string_model_task = get_string_model_task (model_spec ["model_id" ])
183
258
model_source = get_model_source (model_spec ["url" ])
184
- file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
185
- file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
186
- file_content .append (" - {}\n " .format (model ["version" ]))
187
- file_content .append (" - {}\n " .format (model ["min_version" ]))
188
- file_content .append (" - {}\n " .format (model_task ))
189
- file_content .append (
259
+ file_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
260
+ file_content_entries .append (" - {}\n " .format (model_spec ["training_supported" ]))
261
+ file_content_entries .append (" - {}\n " .format (model ["version" ]))
262
+ file_content_entries .append (" - {}\n " .format (model ["min_version" ]))
263
+ file_content_entries .append (" - {}\n " .format (model_task ))
264
+ file_content_entries .append (
190
265
" - `{} <{}>`__ |external-link|\n " .format (model_source , model_spec ["url" ])
191
266
)
192
267
193
- f = open ("doc_utils/pretrainedmodels.rst" , "w" )
194
- f .writelines (file_content )
268
+ if (string_model_task , TO_FRAMEWORK [model_source ]) in MODALITY_MAP :
269
+ file_content_single_entry = []
270
+
271
+ if (
272
+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
273
+ not in dynamic_table_files
274
+ ):
275
+ file_content_single_entry .append ("\n " )
276
+ file_content_single_entry .append (".. list-table:: Available Models\n " )
277
+ file_content_single_entry .append (" :widths: 50 20 20 20 20\n " )
278
+ file_content_single_entry .append (" :header-rows: 1\n " )
279
+ file_content_single_entry .append (" :class: datatable\n " )
280
+ file_content_single_entry .append ("\n " )
281
+ file_content_single_entry .append (" * - Model ID\n " )
282
+ file_content_single_entry .append (" - Fine Tunable?\n " )
283
+ file_content_single_entry .append (" - Latest Version\n " )
284
+ file_content_single_entry .append (" - Min SDK Version\n " )
285
+ file_content_single_entry .append (" - Source\n " )
286
+
287
+ dynamic_table_files .append (
288
+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
289
+ )
290
+
291
+ file_content_single_entry .append (" * - {}\n " .format (model_spec ["model_id" ]))
292
+ file_content_single_entry .append (" - {}\n " .format (model_spec ["training_supported" ]))
293
+ file_content_single_entry .append (" - {}\n " .format (model ["version" ]))
294
+ file_content_single_entry .append (" - {}\n " .format (model ["min_version" ]))
295
+ file_content_single_entry .append (
296
+ " - `{} <{}>`__\n " .format (model_source , model_spec ["url" ])
297
+ )
298
+ f = open (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])], "a" )
299
+ f .writelines (file_content_single_entry )
300
+ f .close ()
301
+
302
+ f = open ("doc_utils/pretrainedmodels.rst" , "a" )
303
+ f .writelines (file_content_intro )
304
+ f .writelines (file_content_entries )
195
305
f .close ()
0 commit comments