13
13
from __future__ import absolute_import
14
14
from urllib import request
15
15
import json
16
- from packaging .version import Version
16
+ from packaging .version import Version
17
17
from enum import Enum
18
18
19
+
19
20
class Tasks (str , Enum ):
20
21
"""The ML task name as referenced in the infix of the model ID."""
21
22
@@ -36,6 +37,7 @@ class Tasks(str, Enum):
36
37
TABULAR_REGRESSION = "regression"
37
38
TABULAR_CLASSIFICATION = "classification"
38
39
40
+
39
41
class ProblemTypes (str , Enum ):
40
42
"""Possible problem types for JumpStart models."""
41
43
@@ -55,6 +57,7 @@ class ProblemTypes(str, Enum):
55
57
TABULAR_REGRESSION = "Regression"
56
58
TABULAR_CLASSIFICATION = "Classification"
57
59
60
+
58
61
JUMPSTART_REGION = "eu-west-2"
59
62
SDK_MANIFEST_FILE = "models_manifest.json"
60
63
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
@@ -93,9 +96,10 @@ def get_jumpstart_sdk_spec(key):
93
96
model_spec = f .read ().decode ("utf-8" )
94
97
return json .loads (model_spec )
95
98
99
+
96
100
def get_model_task (id ):
97
- task_short = id .split ('-' )[1 ]
98
- return TASK_MAP [task_short ] if task_short in TASK_MAP else ' Source'
101
+ task_short = id .split ("-" )[1 ]
102
+ return TASK_MAP [task_short ] if task_short in TASK_MAP else " Source"
99
103
100
104
101
105
def create_jumpstart_model_table ():
0 commit comments