Skip to content

Commit e4120da

Browse files
author
Raghav Dhall
committed
documentation: small fixes for unit tests
1 parent 468dadd commit e4120da

File tree

1 file changed

+39
-12
lines changed

1 file changed

+39
-12
lines changed

doc/doc_utils/jumpstart_doc_utils.py

+39-12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ProblemTypes(str, Enum):
5757
TABULAR_REGRESSION = "Regression"
5858
TABULAR_CLASSIFICATION = "Classification"
5959

60+
6061
class Frameworks(str, Enum):
6162
"""Possible frameworks for JumpStart models"""
6263

@@ -104,7 +105,7 @@ class Frameworks(str, Enum):
104105
"LightGBM": Frameworks.LIGHTGBM,
105106
"XGBoost": Frameworks.XGBOOST,
106107
"ScikitLearn": Frameworks.SCIKIT_LEARN,
107-
"Source": Frameworks.SOURCE
108+
"Source": Frameworks.SOURCE,
108109
}
109110

110111

@@ -117,16 +118,37 @@ class Frameworks(str, Enum):
117118
(Tasks.OD, Frameworks.PYTORCH): "algorithms/vision/object_detection_pytorch.rst",
118119
(Tasks.OD, Frameworks.TENSORFLOW): "algorithms/vision/object_detection_tensorflow.rst",
119120
(Tasks.SEMSEG, Frameworks.GLUONCV): "algorithms/vision/semantic_segmentation_mxnet.rst",
120-
(Tasks.TRANSLATION, Frameworks.HUGGINGFACE): "algorithms/text/machine_translation_hugging_face.rst",
121+
(
122+
Tasks.TRANSLATION,
123+
Frameworks.HUGGINGFACE,
124+
): "algorithms/text/machine_translation_hugging_face.rst",
121125
(Tasks.NER, Frameworks.GLUONCV): "algorithms/text/named_entity_recognition_hugging_face.rst",
122126
(Tasks.EQA, Frameworks.PYTORCH): "algorithms/text/question_answering_pytorch.rst",
123-
(Tasks.SPC, Frameworks.HUGGINGFACE): "algorithms/text/sentence_pair_classification_hugging_face.rst",
124-
(Tasks.SPC, Frameworks.TENSORFLOW): "algorithms/text/sentence_pair_classification_tensorflow.rst",
127+
(
128+
Tasks.SPC,
129+
Frameworks.HUGGINGFACE,
130+
): "algorithms/text/sentence_pair_classification_hugging_face.rst",
131+
(
132+
Tasks.SPC,
133+
Frameworks.TENSORFLOW,
134+
): "algorithms/text/sentence_pair_classification_tensorflow.rst",
125135
(Tasks.TC, Frameworks.TENSORFLOW): "algorithms/text/text_classification_tensorflow.rst",
126-
(Tasks.TC_EMBEDDING, Frameworks.GLUONCV): "algorithms/text/text_embedding_tensorflow_mxnet.rst",
127-
(Tasks.TC_EMBEDDING, Frameworks.TENSORFLOW): "algorithms/text/text_embedding_tensorflow_mxnet.rst",
128-
(Tasks.TEXT_GENERATION, Frameworks.HUGGINGFACE): "algorithms/text/text_generation_hugging_face.rst",
129-
(Tasks.SUMMARIZATION, Frameworks.HUGGINGFACE): "algorithms/text/text_summarization_hugging_face.rst",
136+
(
137+
Tasks.TC_EMBEDDING,
138+
Frameworks.GLUONCV,
139+
): "algorithms/vision/text_embedding_tensorflow_mxnet.rst",
140+
(
141+
Tasks.TC_EMBEDDING,
142+
Frameworks.TENSORFLOW,
143+
): "algorithms/vision/text_embedding_tensorflow_mxnet.rst",
144+
(
145+
Tasks.TEXT_GENERATION,
146+
Frameworks.HUGGINGFACE,
147+
): "algorithms/text/text_generation_hugging_face.rst",
148+
(
149+
Tasks.SUMMARIZATION,
150+
Frameworks.HUGGINGFACE,
151+
): "algorithms/text/text_summarization_hugging_face.rst",
130152
}
131153

132154

@@ -246,7 +268,10 @@ def create_jumpstart_model_table():
246268
if (string_model_task, TO_FRAMEWORK[model_source]) in MODALITY_MAP:
247269
file_content_single_entry = []
248270

249-
if MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])] not in dynamic_table_files:
271+
if (
272+
MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])]
273+
not in dynamic_table_files
274+
):
250275
file_content_single_entry.append("\n")
251276
file_content_single_entry.append(".. list-table:: Available Models\n")
252277
file_content_single_entry.append(" :widths: 50 20 20 20 30 20\n")
@@ -259,16 +284,18 @@ def create_jumpstart_model_table():
259284
file_content_single_entry.append(" - Min SDK Version\n")
260285
file_content_single_entry.append(" - Problem Type\n")
261286
file_content_single_entry.append(" - Source\n")
262-
263-
dynamic_table_files.append(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])])
287+
288+
dynamic_table_files.append(
289+
MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])]
290+
)
264291

265292
file_content_single_entry.append(" * - {}\n".format(model_spec["model_id"]))
266293
file_content_single_entry.append(" - {}\n".format(model_spec["training_supported"]))
267294
file_content_single_entry.append(" - {}\n".format(model["version"]))
268295
file_content_single_entry.append(" - {}\n".format(model["min_version"]))
269296
file_content_single_entry.append(" - {}\n".format(model_task))
270297
file_content_single_entry.append(
271-
" - `{} <{}>`__ \n".format(model_source, model_spec["url"])
298+
" - `{} <{}>`__\n".format(model_source, model_spec["url"])
272299
)
273300
f = open(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])], "a")
274301
f.writelines(file_content_single_entry)

0 commit comments

Comments
 (0)