Skip to content

Commit 57d4763

Browse files
Feat/jumpstart model table update (#3087)
Co-authored-by: Shreya Pandit <[email protected]>
1 parent 0aea163 commit 57d4763

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

doc/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
html_js_files = [
7878
"https://a0.awsstatic.com/s_code/js/3.0/awshome_s_code.js",
7979
"https://cdn.datatables.net/1.10.23/js/jquery.dataTables.min.js",
80+
"https://kit.fontawesome.com/a076d05399.js",
8081
"js/datatable.js",
8182
]
8283

doc/doc_utils/jumpstart_doc_utils.py

+36-12
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,29 @@ def get_model_task(id):
102102
return TASK_MAP[task_short] if task_short in TASK_MAP else "Source"
103103

104104

105+
def get_model_source(url):
106+
if "tfhub" in url:
107+
return "Tensorflow Hub"
108+
if "pytorch" in url:
109+
return "Pytorch Hub"
110+
if "huggingface" in url:
111+
return "HuggingFace"
112+
if "catboost" in url:
113+
return "Catboost"
114+
if "gluon" in url:
115+
return "GluonCV"
116+
if "catboost" in url:
117+
return "Catboost"
118+
if "lightgbm" in url:
119+
return "LightGBM"
120+
if "xgboost" in url:
121+
return "XGBoost"
122+
if "scikit" in url:
123+
return "ScikitLearn"
124+
else:
125+
return "Source"
126+
127+
105128
def create_jumpstart_model_table():
106129
sdk_manifest = get_jumpstart_sdk_manifest()
107130
sdk_manifest_top_versions_for_models = {}
@@ -117,47 +140,48 @@ def create_jumpstart_model_table():
117140

118141
file_content = []
119142

143+
file_content.append(".. |external-link| raw:: html\n\n")
144+
file_content.append(' <i class="fa fa-external-link"></i>\n\n')
145+
120146
file_content.append("==================================\n")
121147
file_content.append("JumpStart Available Model Table\n")
122148
file_content.append("==================================\n")
123149
file_content.append(
124150
"""
125151
JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary
126152
utilities. This table serves to provide the core material plus some extra information that can be useful
127-
in selecting the correct model ID and corresponding parameters.\n
128-
"""
153+
in selecting the correct model ID and corresponding parameters.\n"""
129154
)
130155
file_content.append(
131156
"""
132157
If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute.
133-
We highly suggest pinning an exact model version however.\n
134-
"""
135-
)
136-
file_content.append(
137-
"""
138-
Click on the Problem Type to navigate to the source of the model.\n
139-
"""
158+
We highly suggest pinning an exact model version however.\n"""
140159
)
141160
file_content.append("\n")
142161
file_content.append(".. list-table:: Available Models\n")
143-
file_content.append(" :widths: 50 20 20 20 30\n")
162+
file_content.append(" :widths: 50 20 20 20 30 20\n")
144163
file_content.append(" :header-rows: 1\n")
145164
file_content.append(" :class: datatable\n")
146165
file_content.append("\n")
147166
file_content.append(" * - Model ID\n")
148167
file_content.append(" - Fine Tunable?\n")
149168
file_content.append(" - Latest Version\n")
150169
file_content.append(" - Min SDK Version\n")
151-
file_content.append(" - Problem Type/Source\n")
170+
file_content.append(" - Problem Type\n")
171+
file_content.append(" - Source\n")
152172

153173
for model in sdk_manifest_top_versions_for_models.values():
154174
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
155175
model_task = get_model_task(model_spec["model_id"])
176+
model_source = get_model_source(model_spec["url"])
156177
file_content.append(" * - {}\n".format(model_spec["model_id"]))
157178
file_content.append(" - {}\n".format(model_spec["training_supported"]))
158179
file_content.append(" - {}\n".format(model["version"]))
159180
file_content.append(" - {}\n".format(model["min_version"]))
160-
file_content.append(" - `{} <{}>`__\n".format(model_task, model_spec["url"]))
181+
file_content.append(" - {}\n".format(model_task))
182+
file_content.append(
183+
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
184+
)
161185

162186
f = open("doc_utils/jumpstart.rst", "w")
163187
f.writelines(file_content)

0 commit comments

Comments
 (0)