-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: support JumpStart proprietary models #4467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 34 commits
5aeaa35
4d58cfd
8ba6477
29ac23c
57a7d37
a8ffdc2
d71b727
359ea1c
33fc27e
04e1376
f74a3e4
a96ea08
cb8aee8
e21b98b
feef2a2
33a2b59
b045b44
8c90641
5bccb8e
f8258b7
896a2cf
d701211
e3e64ba
d5b9b76
27e14b9
ec816c3
07fa93e
bceb17b
6fb885e
22d2078
40542b7
abfadca
46ae293
5f053d5
9ec6f8e
c1c2a6f
2e70c75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,9 +74,12 @@ class Frameworks(str, Enum): | |
|
||
JUMPSTART_REGION = "eu-west-2" | ||
SDK_MANIFEST_FILE = "models_manifest.json" | ||
PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json" | ||
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( | ||
JUMPSTART_REGION, JUMPSTART_REGION | ||
) | ||
PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com" | ||
|
||
TASK_MAP = { | ||
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, | ||
Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, | ||
|
@@ -152,18 +155,26 @@ class Frameworks(str, Enum): | |
} | ||
|
||
|
||
def get_jumpstart_sdk_manifest(): | ||
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) | ||
def get_public_s3_json_object(url):
    """Download ``url`` and return its body parsed as JSON.

    Args:
        url: Address handed directly to ``request.urlopen``.

    Returns:
        The object ``json.loads`` produces from the UTF-8 decoded body.
    """
    with request.urlopen(url) as response:
        raw_body = response.read()
    return json.loads(raw_body.decode("utf-8"))
|
||
|
||
def get_jumpstart_sdk_spec(key): | ||
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) | ||
with request.urlopen(url) as f: | ||
model_spec = f.read().decode("utf-8") | ||
return json.loads(model_spec) | ||
def get_jumpstart_sdk_manifest():
    """Return the parsed JSON manifest stored as ``SDK_MANIFEST_FILE``.

    The manifest is read from ``JUMPSTART_BUCKET_BASE_URL``.
    """
    manifest_url = f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}"
    return get_public_s3_json_object(manifest_url)
|
||
|
||
def get_proprietary_sdk_manifest():
    """Return the parsed JSON manifest stored as ``PROPRIETARY_SDK_MANIFEST_FILE``.

    The manifest is read from ``PROPRIETARY_DOC_BUCKET``.
    """
    manifest_url = f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}"
    return get_public_s3_json_object(manifest_url)
|
||
|
||
def get_jumpstart_sdk_spec(s3_key: str):
    """Return the parsed JSON spec stored under ``s3_key``.

    Args:
        s3_key: Object key appended to ``JUMPSTART_BUCKET_BASE_URL``.
    """
    spec_url = f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}"
    return get_public_s3_json_object(spec_url)
|
||
|
||
def get_proprietary_sdk_spec(s3_key: str):
    """Return the parsed JSON spec stored under ``s3_key``.

    Args:
        s3_key: Object key appended to ``PROPRIETARY_DOC_BUCKET``.
    """
    spec_url = f"{PROPRIETARY_DOC_BUCKET}/{s3_key}"
    return get_public_s3_json_object(spec_url)
|
||
|
||
def get_model_task(id): | ||
|
@@ -196,6 +207,45 @@ def get_model_source(url): | |
return "Source" | ||
|
||
|
||
def _proprietary_version_sort_key(version):
    """Return a key that orders dotted version labels numerically.

    All-digit components compare as integers, so ``"1.10.0"`` ranks above
    ``"1.9.0"``; a plain string comparison would rank it below. Any other
    component is kept as text and tagged so that it sorts after numeric
    ones, which means an unusual version label never raises ``TypeError``.
    """
    return tuple(
        (0, int(part)) if part.isdecimal() else (1, part)
        for part in str(version).split(".")
    )


def create_proprietary_model_table():
    """Build the reST ``list-table`` lines for the proprietary models table.

    Fetches the proprietary SDK manifest, keeps the highest-versioned entry
    for each ``model_id``, fetches the spec of every kept entry and emits
    one five-cell table row per model.

    Returns:
        list[str]: Newline-terminated lines: the table header, one row per
        model, and a trailing blank line.
    """
    # list-table needs a uniform two-level bullet list: options and row
    # bullets are indented 3 spaces, continuation cells 5 spaces.
    proprietary_content_intro = [
        "\n",
        ".. list-table:: Available Proprietary Models\n",
        "   :widths: 50 20 20 20 20\n",
        "   :header-rows: 1\n",
        "   :class: datatable\n",
        "\n",
        "   * - Model ID\n",
        "     - Fine Tunable?\n",
        "     - Supported Version\n",
        "     - Min SDK Version\n",
        "     - Source\n",
    ]

    sdk_manifest = get_proprietary_sdk_manifest()

    # Keep one manifest entry per model ID. On equal versions the entry seen
    # first is retained.
    latest_by_model_id = {}
    for entry in sdk_manifest:
        model_id = entry["model_id"]
        current = latest_by_model_id.get(model_id)
        if current is None or _proprietary_version_sort_key(
            current["version"]
        ) < _proprietary_version_sort_key(entry["version"]):
            latest_by_model_id[model_id] = entry

    proprietary_content_entries = []
    for entry in latest_by_model_id.values():
        model_spec = get_proprietary_sdk_spec(entry["spec_key"])
        proprietary_content_entries.extend(
            [
                "   * - {}\n".format(model_spec["model_id"]),
                "     - {}\n".format(False),  # TODO: support training
                "     - {}\n".format(entry["version"]),
                "     - {}\n".format(entry["min_version"]),
                "     - `{} <{}>`__ |external-link|\n".format(
                    "Source", model_spec.get("url")
                ),
            ]
        )
    return proprietary_content_intro + proprietary_content_entries + ["\n"]
|
||
|
||
def create_jumpstart_model_table(): | ||
sdk_manifest = get_jumpstart_sdk_manifest() | ||
sdk_manifest_top_versions_for_models = {} | ||
|
@@ -249,19 +299,19 @@ def create_jumpstart_model_table(): | |
file_content_intro.append(" - Source\n") | ||
|
||
dynamic_table_files = [] | ||
file_content_entries = [] | ||
open_weight_content_entries = [] | ||
|
||
for model in sdk_manifest_top_versions_for_models.values(): | ||
model_spec = get_jumpstart_sdk_spec(model["spec_key"]) | ||
model_task = get_model_task(model_spec["model_id"]) | ||
string_model_task = get_string_model_task(model_spec["model_id"]) | ||
model_source = get_model_source(model_spec["url"]) | ||
file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) | ||
file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) | ||
file_content_entries.append(" - {}\n".format(model["version"])) | ||
file_content_entries.append(" - {}\n".format(model["min_version"])) | ||
file_content_entries.append(" - {}\n".format(model_task)) | ||
file_content_entries.append( | ||
open_weight_content_entries.append(" * - {}\n".format(model_spec["model_id"])) | ||
open_weight_content_entries.append(" - {}\n".format(model_spec["training_supported"])) | ||
open_weight_content_entries.append(" - {}\n".format(model["version"])) | ||
open_weight_content_entries.append(" - {}\n".format(model["min_version"])) | ||
open_weight_content_entries.append(" - {}\n".format(model_task)) | ||
open_weight_content_entries.append( | ||
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) | ||
) | ||
|
||
|
@@ -299,7 +349,10 @@ def create_jumpstart_model_table(): | |
f.writelines(file_content_single_entry) | ||
f.close() | ||
|
||
proprietary_content_entries = create_proprietary_model_table() | ||
|
||
f = open("doc_utils/pretrainedmodels.rst", "a") | ||
f.writelines(file_content_intro) | ||
f.writelines(file_content_entries) | ||
f.writelines(open_weight_content_entries) | ||
f.writelines(proprietary_content_entries) | ||
f.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For open-weights models we point to a wave-4 region, DUB, IIRC. Could you use the same for consistency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just found an issue with regionalization in some regions and am raising a fix today. I'll update this to DUB in a separate PR that updates the documentation of list_jumpstart_models, and I'd like to refactor the doc utils in that PR as well.