Skip to content

Commit d830446

Browse files
committed
chore: update jumpstart model table
1 parent 564e454 commit d830446

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

doc/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
html_js_files = [
7878
"https://a0.awsstatic.com/s_code/js/3.0/awshome_s_code.js",
7979
"https://cdn.datatables.net/1.10.23/js/jquery.dataTables.min.js",
80+
"https://kit.fontawesome.com/a076d05399.js",
8081
"js/datatable.js",
8182
]
8283

doc/doc_utils/jumpstart_doc_utils.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,29 @@ def get_model_task(id):
102102
return TASK_MAP[task_short] if task_short in TASK_MAP else "Source"
103103

104104

105+
def get_model_source(url):
106+
if 'tfhub' in url:
107+
return 'Tensorflow Hub'
108+
if 'pytorch' in url:
109+
return 'Pytorch Hub'
110+
if 'huggingface' in url:
111+
return "HuggingFace"
112+
if 'catboost' in url:
113+
return "Catboost"
114+
if 'gluon' in url:
115+
return "GluonCV"
116+
if 'catboost' in url:
117+
return 'Catboost'
118+
if 'lightgbm' in url:
119+
return 'LightGBM'
120+
if 'xgboost' in url:
121+
return 'XGBoost'
122+
if 'scikit' in url:
123+
return 'ScikitLearn'
124+
else:
125+
return 'Source'
126+
127+
105128
def create_jumpstart_model_table():
106129
sdk_manifest = get_jumpstart_sdk_manifest()
107130
sdk_manifest_top_versions_for_models = {}
@@ -117,47 +140,50 @@ def create_jumpstart_model_table():
117140

118141
file_content = []
119142

143+
file_content.append('.. |external-link| raw:: html\n\n')
144+
file_content.append(' <i class="fa fa-external-link"></i>\n\n')
145+
120146
file_content.append("==================================\n")
121147
file_content.append("JumpStart Available Model Table\n")
122148
file_content.append("==================================\n")
123149
file_content.append(
124150
"""
125151
JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary
126152
utilities. This table serves to provide the core material plus some extra information that can be useful
127-
in selecting the correct model ID and corresponding parameters.\n
128-
"""
153+
in selecting the correct model ID and corresponding parameters.\n"""
129154
)
130155
file_content.append(
131156
"""
132157
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
133-
We highly suggest pinning an exact model version however.\n
134-
"""
158+
We highly suggest pinning an exact model version however.\n"""
135159
)
136160
file_content.append(
137161
"""
138-
Click on the Problem Type to navigate to the source of the model.\n
139-
"""
162+
Click on the Problem Type to navigate to the source of the model.\n"""
140163
)
141164
file_content.append("\n")
142165
file_content.append(".. list-table:: Available Models\n")
143-
file_content.append(" :widths: 50 20 20 20 30\n")
166+
file_content.append(" :widths: 50 20 20 20 30 20\n")
144167
file_content.append(" :header-rows: 1\n")
145168
file_content.append(" :class: datatable\n")
146169
file_content.append("\n")
147170
file_content.append(" * - Model ID\n")
148171
file_content.append(" - Fine Tunable?\n")
149172
file_content.append(" - Latest Version\n")
150173
file_content.append(" - Min SDK Version\n")
151-
file_content.append(" - Problem Type/Source\n")
174+
file_content.append(" - Problem Type\n")
175+
file_content.append(" - Source\n")
152176

153177
for model in sdk_manifest_top_versions_for_models.values():
154178
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
155179
model_task = get_model_task(model_spec["model_id"])
180+
model_source = get_model_source(model_spec["url"])
156181
file_content.append(" * - {}\n".format(model_spec["model_id"]))
157182
file_content.append(" - {}\n".format(model_spec["training_supported"]))
158183
file_content.append(" - {}\n".format(model["version"]))
159184
file_content.append(" - {}\n".format(model["min_version"]))
160-
file_content.append(" - `{} <{}>`__\n".format(model_task, model_spec["url"]))
185+
file_content.append(" - {}\n".format(model_task))
186+
file_content.append(" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]))
161187

162188
f = open("doc_utils/jumpstart.rst", "w")
163189
f.writelines(file_content)

0 commit comments

Comments
 (0)