diff --git a/doc/conf.py b/doc/conf.py index b127715aa7..b7007175d9 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -77,6 +77,7 @@ html_js_files = [ "https://a0.awsstatic.com/s_code/js/3.0/awshome_s_code.js", "https://cdn.datatables.net/1.10.23/js/jquery.dataTables.min.js", + "https://kit.fontawesome.com/a076d05399.js", "js/datatable.js", ] diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 4d713283ba..d2658dca30 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -102,6 +102,29 @@ def get_model_task(id): return TASK_MAP[task_short] if task_short in TASK_MAP else "Source" +def get_model_source(url): + if "tfhub" in url: + return "Tensorflow Hub" + if "pytorch" in url: + return "Pytorch Hub" + if "huggingface" in url: + return "HuggingFace" + if "catboost" in url: + return "Catboost" + if "gluon" in url: + return "GluonCV" + if "catboost" in url: + return "Catboost" + if "lightgbm" in url: + return "LightGBM" + if "xgboost" in url: + return "XGBoost" + if "scikit" in url: + return "ScikitLearn" + else: + return "Source" + + def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} @@ -117,6 +140,9 @@ def create_jumpstart_model_table(): file_content = [] + file_content.append(".. |external-link| raw:: html\n\n") + file_content.append(' \n\n') + file_content.append("==================================\n") file_content.append("JumpStart Available Model Table\n") file_content.append("==================================\n") @@ -124,23 +150,16 @@ def create_jumpstart_model_table(): """ JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary utilities. This table serves to provide the core material plus some extra information that can be useful - in selecting the correct model ID and corresponding parameters.\n - """ + in selecting the correct model ID and corresponding parameters.\n""" ) file_content.append( """ If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. - We highly suggest pinning an exact model version however.\n - """ - ) - file_content.append( - """ - Click on the Problem Type to navigate to the source of the model.\n - """ + We highly suggest pinning an exact model version however.\n""" ) file_content.append("\n") file_content.append(".. list-table:: Available Models\n") - file_content.append(" :widths: 50 20 20 20 30\n") + file_content.append(" :widths: 50 20 20 20 30 20\n") file_content.append(" :header-rows: 1\n") file_content.append(" :class: datatable\n") file_content.append("\n") @@ -148,16 +167,21 @@ def create_jumpstart_model_table(): file_content.append(" - Fine Tunable?\n") file_content.append(" - Latest Version\n") file_content.append(" - Min SDK Version\n") - file_content.append(" - Problem Type/Source\n") + file_content.append(" - Problem Type\n") + file_content.append(" - Source\n") for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) + model_source = get_model_source(model_spec["url"]) file_content.append(" * - {}\n".format(model_spec["model_id"])) file_content.append(" - {}\n".format(model_spec["training_supported"])) file_content.append(" - {}\n".format(model["version"])) file_content.append(" - {}\n".format(model["min_version"])) - file_content.append(" - `{} <{}>`__\n".format(model_task, model_spec["url"])) + file_content.append(" - {}\n".format(model_task)) + file_content.append( + " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) + ) f = open("doc_utils/jumpstart.rst", "w") f.writelines(file_content)