@@ -74,9 +74,12 @@ class Frameworks(str, Enum):
74
74
75
75
JUMPSTART_REGION = "eu-west-2"
76
76
SDK_MANIFEST_FILE = "models_manifest.json"
77
+ PROPRIETARY_SDK_MANIFEST_FILE = "proprietary-sdk-manifest.json"
77
78
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
78
79
JUMPSTART_REGION , JUMPSTART_REGION
79
80
)
81
+ PROPRIETARY_DOC_BUCKET = "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com"
82
+
80
83
TASK_MAP = {
81
84
Tasks .IC : ProblemTypes .IMAGE_CLASSIFICATION ,
82
85
Tasks .IC_EMBEDDING : ProblemTypes .IMAGE_EMBEDDING ,
@@ -152,18 +155,26 @@ class Frameworks(str, Enum):
152
155
}
153
156
154
157
155
- def get_jumpstart_sdk_manifest ():
156
- url = "{}/{}" .format (JUMPSTART_BUCKET_BASE_URL , SDK_MANIFEST_FILE )
158
+ def get_public_s3_json_object (url ):
157
159
with request .urlopen (url ) as f :
158
160
models_manifest = f .read ().decode ("utf-8" )
159
161
return json .loads (models_manifest )
160
162
161
163
162
- def get_jumpstart_sdk_spec (key ):
163
- url = "{}/{}" .format (JUMPSTART_BUCKET_BASE_URL , key )
164
- with request .urlopen (url ) as f :
165
- model_spec = f .read ().decode ("utf-8" )
166
- return json .loads (model_spec )
164
+ def get_jumpstart_sdk_manifest ():
165
+ return get_public_s3_json_object (f"{ JUMPSTART_BUCKET_BASE_URL } /{ SDK_MANIFEST_FILE } " )
166
+
167
+
168
+ def get_proprietary_sdk_manifest ():
169
+ return get_public_s3_json_object (f"{ PROPRIETARY_DOC_BUCKET } /{ PROPRIETARY_SDK_MANIFEST_FILE } " )
170
+
171
+
172
+ def get_jumpstart_sdk_spec (s3_key : str ):
173
+ return get_public_s3_json_object (f"{ JUMPSTART_BUCKET_BASE_URL } /{ s3_key } " )
174
+
175
+
176
+ def get_proprietary_sdk_spec (s3_key : str ):
177
+ return get_public_s3_json_object (f"{ PROPRIETARY_DOC_BUCKET } /{ s3_key } " )
167
178
168
179
169
180
def get_model_task (id ):
@@ -196,6 +207,45 @@ def get_model_source(url):
196
207
return "Source"
197
208
198
209
210
+ def create_proprietary_model_table ():
211
+ proprietary_content_intro = []
212
+ proprietary_content_intro .append ("\n " )
213
+ proprietary_content_intro .append (".. list-table:: Available Proprietary Models\n " )
214
+ proprietary_content_intro .append (" :widths: 50 20 20 20 20\n " )
215
+ proprietary_content_intro .append (" :header-rows: 1\n " )
216
+ proprietary_content_intro .append (" :class: datatable\n " )
217
+ proprietary_content_intro .append ("\n " )
218
+ proprietary_content_intro .append (" * - Model ID\n " )
219
+ proprietary_content_intro .append (" - Fine Tunable?\n " )
220
+ proprietary_content_intro .append (" - Supported Version\n " )
221
+ proprietary_content_intro .append (" - Min SDK Version\n " )
222
+ proprietary_content_intro .append (" - Source\n " )
223
+
224
+ sdk_manifest = get_proprietary_sdk_manifest ()
225
+ sdk_manifest_top_versions_for_models = {}
226
+
227
+ for model in sdk_manifest :
228
+ if model ["model_id" ] not in sdk_manifest_top_versions_for_models :
229
+ sdk_manifest_top_versions_for_models [model ["model_id" ]] = model
230
+ else :
231
+ if str (sdk_manifest_top_versions_for_models [model ["model_id" ]]["version" ]) < str (
232
+ model ["version" ]
233
+ ):
234
+ sdk_manifest_top_versions_for_models [model ["model_id" ]] = model
235
+
236
+ proprietary_content_entries = []
237
+ for model in sdk_manifest_top_versions_for_models .values ():
238
+ model_spec = get_proprietary_sdk_spec (model ["spec_key" ])
239
+ proprietary_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
240
+ proprietary_content_entries .append (" - {}\n " .format (False )) # TODO: support training
241
+ proprietary_content_entries .append (" - {}\n " .format (model ["version" ]))
242
+ proprietary_content_entries .append (" - {}\n " .format (model ["min_version" ]))
243
+ proprietary_content_entries .append (
244
+ " - `{} <{}>`__ |external-link|\n " .format ("Source" , model_spec .get ("url" ))
245
+ )
246
+ return proprietary_content_intro + proprietary_content_entries + ["\n " ]
247
+
248
+
199
249
def create_jumpstart_model_table ():
200
250
sdk_manifest = get_jumpstart_sdk_manifest ()
201
251
sdk_manifest_top_versions_for_models = {}
@@ -249,19 +299,19 @@ def create_jumpstart_model_table():
249
299
file_content_intro .append (" - Source\n " )
250
300
251
301
dynamic_table_files = []
252
- file_content_entries = []
302
+ open_weight_content_entries = []
253
303
254
304
for model in sdk_manifest_top_versions_for_models .values ():
255
305
model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
256
306
model_task = get_model_task (model_spec ["model_id" ])
257
307
string_model_task = get_string_model_task (model_spec ["model_id" ])
258
308
model_source = get_model_source (model_spec ["url" ])
259
- file_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
260
- file_content_entries .append (" - {}\n " .format (model_spec ["training_supported" ]))
261
- file_content_entries .append (" - {}\n " .format (model ["version" ]))
262
- file_content_entries .append (" - {}\n " .format (model ["min_version" ]))
263
- file_content_entries .append (" - {}\n " .format (model_task ))
264
- file_content_entries .append (
309
+ open_weight_content_entries .append (" * - {}\n " .format (model_spec ["model_id" ]))
310
+ open_weight_content_entries .append (" - {}\n " .format (model_spec ["training_supported" ]))
311
+ open_weight_content_entries .append (" - {}\n " .format (model ["version" ]))
312
+ open_weight_content_entries .append (" - {}\n " .format (model ["min_version" ]))
313
+ open_weight_content_entries .append (" - {}\n " .format (model_task ))
314
+ open_weight_content_entries .append (
265
315
" - `{} <{}>`__ |external-link|\n " .format (model_source , model_spec ["url" ])
266
316
)
267
317
@@ -299,7 +349,10 @@ def create_jumpstart_model_table():
299
349
f .writelines (file_content_single_entry )
300
350
f .close ()
301
351
352
+ proprietary_content_entries = create_proprietary_model_table ()
353
+
302
354
f = open ("doc_utils/pretrainedmodels.rst" , "a" )
303
355
f .writelines (file_content_intro )
304
- f .writelines (file_content_entries )
356
+ f .writelines (open_weight_content_entries )
357
+ f .writelines (proprietary_content_entries )
305
358
f .close ()
0 commit comments