Skip to content

Commit 4d82bdc

Browse files
committed
chore: update source column
1 parent 95b1056 commit 4d82bdc

File tree

1 file changed

+69
-4
lines changed

1 file changed

+69
-4
lines changed

doc/doc_utils/jumpstart_doc_utils.py

+69-4
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,71 @@
1313
from __future__ import absolute_import
1414
from urllib import request
1515
import json
16-
from packaging.version import Version
16+
from packaging.version import Version
17+
from enum import Enum
18+
19+
class Tasks(str, Enum):
20+
"""The ML task name as referenced in the infix of the model ID."""
21+
22+
IC = "ic"
23+
OD = "od"
24+
OD1 = "od1"
25+
SEMSEG = "semseg"
26+
IS = "is"
27+
TC = "tc"
28+
SPC = "spc"
29+
EQA = "eqa"
30+
TEXT_GENERATION = "textgeneration"
31+
IC_EMBEDDING = "icembedding"
32+
TC_EMBEDDING = "tcembedding"
33+
NER = "ner"
34+
SUMMARIZATION = "summarization"
35+
TRANSLATION = "translation"
36+
TABULAR_REGRESSION = "regression"
37+
TABULAR_CLASSIFICATION = "classification"
38+
39+
class ProblemTypes(str, Enum):
40+
"""Possible problem types for JumpStart models."""
41+
42+
IMAGE_CLASSIFICATION = "Image Classification"
43+
IMAGE_EMBEDDING = "Image Embedding"
44+
OBJECT_DETECTION = "Object Detection"
45+
SEMANTIC_SEGMENTATION = "Semantic Segmentation"
46+
INSTANCE_SEGMENTATION = "Instance Segmentation"
47+
TEXT_CLASSIFICATION = "Text Classification"
48+
TEXT_EMBEDDING = "Text Embedding"
49+
QUESTION_ANSWERING = "Question Answering"
50+
SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification"
51+
TEXT_GENERATION = "Text Generation"
52+
TEXT_SUMMARIZATION = "Text Summarization"
53+
MACHINE_TRANSLATION = "Machine Translation"
54+
NAMED_ENTITY_RECOGNITION = "Named Entity Recognition"
55+
TABULAR_REGRESSION = "Regression"
56+
TABULAR_CLASSIFICATION = "Classification"
1757

1858
JUMPSTART_REGION = "eu-west-2"
1959
SDK_MANIFEST_FILE = "models_manifest.json"
2060
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
2161
JUMPSTART_REGION, JUMPSTART_REGION
2262
)
63+
TASK_MAP = {
64+
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
65+
Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
66+
Tasks.OD: ProblemTypes.OBJECT_DETECTION,
67+
Tasks.OD1: ProblemTypes.OBJECT_DETECTION,
68+
Tasks.SEMSEG: ProblemTypes.SEMANTIC_SEGMENTATION,
69+
Tasks.IS: ProblemTypes.INSTANCE_SEGMENTATION,
70+
Tasks.TC: ProblemTypes.TEXT_CLASSIFICATION,
71+
Tasks.TC_EMBEDDING: ProblemTypes.TEXT_EMBEDDING,
72+
Tasks.EQA: ProblemTypes.QUESTION_ANSWERING,
73+
Tasks.SPC: ProblemTypes.SENTENCE_PAIR_CLASSIFICATION,
74+
Tasks.TEXT_GENERATION: ProblemTypes.TEXT_GENERATION,
75+
Tasks.SUMMARIZATION: ProblemTypes.TEXT_SUMMARIZATION,
76+
Tasks.TRANSLATION: ProblemTypes.MACHINE_TRANSLATION,
77+
Tasks.NER: ProblemTypes.NAMED_ENTITY_RECOGNITION,
78+
Tasks.TABULAR_REGRESSION: ProblemTypes.TABULAR_REGRESSION,
79+
Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION,
80+
}
2381

2482

2583
def get_jumpstart_sdk_manifest():
@@ -35,6 +93,10 @@ def get_jumpstart_sdk_spec(key):
3593
model_spec = f.read().decode("utf-8")
3694
return json.loads(model_spec)
3795

96+
def get_model_task(id):
97+
task_short = id.split('-')[1]
98+
return TASK_MAP[task_short] if task_short in TASK_MAP else 'Source'
99+
38100

39101
def create_jumpstart_model_table():
40102
sdk_manifest = get_jumpstart_sdk_manifest()
@@ -69,26 +131,29 @@ def create_jumpstart_model_table():
69131
)
70132
file_content.append(
71133
"""
72-
Each model id is linked to an external page that describes the model.\n
134+
Click on the Problem Type to navigate to the source of the model.\n
73135
"""
74136
)
75137
file_content.append("\n")
76138
file_content.append(".. list-table:: Available Models\n")
77-
file_content.append(" :widths: 50 20 20 20\n")
139+
file_content.append(" :widths: 50 20 20 20 30\n")
78140
file_content.append(" :header-rows: 1\n")
79141
file_content.append(" :class: datatable\n")
80142
file_content.append("\n")
81143
file_content.append(" * - Model ID\n")
82144
file_content.append(" - Fine Tunable?\n")
83145
file_content.append(" - Latest Version\n")
84146
file_content.append(" - Min SDK Version\n")
147+
file_content.append(" - Problem Type/Source\n")
85148

86149
for model in sdk_manifest_top_versions_for_models.values():
87150
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
88-
file_content.append(" * - `{} <{}>`_\n".format(model_spec["model_id"], model_spec["url"]))
151+
model_task = get_model_task(model_spec["model_id"])
152+
file_content.append(" * - {}\n".format(model_spec["model_id"]))
89153
file_content.append(" - {}\n".format(model_spec["training_supported"]))
90154
file_content.append(" - {}\n".format(model["version"]))
91155
file_content.append(" - {}\n".format(model["min_version"]))
156+
file_content.append(" - `{} <{}>`__\n".format(model_task, model_spec["url"]))
92157

93158
f = open("doc_utils/jumpstart.rst", "w")
94159
f.writelines(file_content)

0 commit comments

Comments
 (0)