@@ -57,6 +57,7 @@ class ProblemTypes(str, Enum):
57
57
TABULAR_REGRESSION = "Regression"
58
58
TABULAR_CLASSIFICATION = "Classification"
59
59
60
+
60
61
class Frameworks (str , Enum ):
61
62
"""Possible frameworks for JumpStart models"""
62
63
@@ -104,7 +105,7 @@ class Frameworks(str, Enum):
104
105
"LightGBM" : Frameworks .LIGHTGBM ,
105
106
"XGBoost" : Frameworks .XGBOOST ,
106
107
"ScikitLearn" : Frameworks .SCIKIT_LEARN ,
107
- "Source" : Frameworks .SOURCE
108
+ "Source" : Frameworks .SOURCE ,
108
109
}
109
110
110
111
@@ -117,16 +118,37 @@ class Frameworks(str, Enum):
117
118
(Tasks .OD , Frameworks .PYTORCH ): "algorithms/vision/object_detection_pytorch.rst" ,
118
119
(Tasks .OD , Frameworks .TENSORFLOW ): "algorithms/vision/object_detection_tensorflow.rst" ,
119
120
(Tasks .SEMSEG , Frameworks .GLUONCV ): "algorithms/vision/semantic_segmentation_mxnet.rst" ,
120
- (Tasks .TRANSLATION , Frameworks .HUGGINGFACE ): "algorithms/text/machine_translation_hugging_face.rst" ,
121
+ (
122
+ Tasks .TRANSLATION ,
123
+ Frameworks .HUGGINGFACE ,
124
+ ): "algorithms/text/machine_translation_hugging_face.rst" ,
121
125
(Tasks .NER , Frameworks .GLUONCV ): "algorithms/text/named_entity_recognition_hugging_face.rst" ,
122
126
(Tasks .EQA , Frameworks .PYTORCH ): "algorithms/text/question_answering_pytorch.rst" ,
123
- (Tasks .SPC , Frameworks .HUGGINGFACE ): "algorithms/text/sentence_pair_classification_hugging_face.rst" ,
124
- (Tasks .SPC , Frameworks .TENSORFLOW ): "algorithms/text/sentence_pair_classification_tensorflow.rst" ,
127
+ (
128
+ Tasks .SPC ,
129
+ Frameworks .HUGGINGFACE ,
130
+ ): "algorithms/text/sentence_pair_classification_hugging_face.rst" ,
131
+ (
132
+ Tasks .SPC ,
133
+ Frameworks .TENSORFLOW ,
134
+ ): "algorithms/text/sentence_pair_classification_tensorflow.rst" ,
125
135
(Tasks .TC , Frameworks .TENSORFLOW ): "algorithms/text/text_classification_tensorflow.rst" ,
126
- (Tasks .TC_EMBEDDING , Frameworks .GLUONCV ): "algorithms/text/text_embedding_tensorflow_mxnet.rst" ,
127
- (Tasks .TC_EMBEDDING , Frameworks .TENSORFLOW ): "algorithms/text/text_embedding_tensorflow_mxnet.rst" ,
128
- (Tasks .TEXT_GENERATION , Frameworks .HUGGINGFACE ): "algorithms/text/text_generation_hugging_face.rst" ,
129
- (Tasks .SUMMARIZATION , Frameworks .HUGGINGFACE ): "algorithms/text/text_summarization_hugging_face.rst" ,
136
+ (
137
+ Tasks .TC_EMBEDDING ,
138
+ Frameworks .GLUONCV ,
139
+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
140
+ (
141
+ Tasks .TC_EMBEDDING ,
142
+ Frameworks .TENSORFLOW ,
143
+ ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst" ,
144
+ (
145
+ Tasks .TEXT_GENERATION ,
146
+ Frameworks .HUGGINGFACE ,
147
+ ): "algorithms/text/text_generation_hugging_face.rst" ,
148
+ (
149
+ Tasks .SUMMARIZATION ,
150
+ Frameworks .HUGGINGFACE ,
151
+ ): "algorithms/text/text_summarization_hugging_face.rst" ,
130
152
}
131
153
132
154
@@ -246,7 +268,10 @@ def create_jumpstart_model_table():
246
268
if (string_model_task , TO_FRAMEWORK [model_source ]) in MODALITY_MAP :
247
269
file_content_single_entry = []
248
270
249
- if MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])] not in dynamic_table_files :
271
+ if (
272
+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
273
+ not in dynamic_table_files
274
+ ):
250
275
file_content_single_entry .append ("\n " )
251
276
file_content_single_entry .append (".. list-table:: Available Models\n " )
252
277
file_content_single_entry .append (" :widths: 50 20 20 20 30 20\n " )
@@ -259,16 +284,18 @@ def create_jumpstart_model_table():
259
284
file_content_single_entry .append (" - Min SDK Version\n " )
260
285
file_content_single_entry .append (" - Problem Type\n " )
261
286
file_content_single_entry .append (" - Source\n " )
262
-
263
- dynamic_table_files .append (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])])
287
+
288
+ dynamic_table_files .append (
289
+ MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])]
290
+ )
264
291
265
292
file_content_single_entry .append (" * - {}\n " .format (model_spec ["model_id" ]))
266
293
file_content_single_entry .append (" - {}\n " .format (model_spec ["training_supported" ]))
267
294
file_content_single_entry .append (" - {}\n " .format (model ["version" ]))
268
295
file_content_single_entry .append (" - {}\n " .format (model ["min_version" ]))
269
296
file_content_single_entry .append (" - {}\n " .format (model_task ))
270
297
file_content_single_entry .append (
271
- " - `{} <{}>`__ \n " .format (model_source , model_spec ["url" ])
298
+ " - `{} <{}>`__\n " .format (model_source , model_spec ["url" ])
272
299
)
273
300
f = open (MODALITY_MAP [(string_model_task , TO_FRAMEWORK [model_source ])], "a" )
274
301
f .writelines (file_content_single_entry )
0 commit comments