@@ -31,6 +31,7 @@ def retrieve_default(
31
31
region : Optional [str ] = None ,
32
32
model_id : Optional [str ] = None ,
33
33
model_version : Optional [str ] = None ,
34
+ hub_arn : Optional [str ] = None ,
34
35
instance_type : Optional [str ] = None ,
35
36
include_container_hyperparameters : bool = False ,
36
37
tolerate_vulnerable_model : bool = False ,
@@ -46,6 +47,8 @@ def retrieve_default(
46
47
retrieve the default hyperparameters. (Default: None).
47
48
model_version (str): The version of the model for which to retrieve the
48
49
default hyperparameters. (Default: None).
50
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51
+ model details from. (default: None).
49
52
instance_type (str): An instance type to optionally supply in order to get hyperparameters
50
53
specific for the instance type.
51
54
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
80
83
return artifacts ._retrieve_default_hyperparameters (
81
84
model_id = model_id ,
82
85
model_version = model_version ,
86
+ hub_arn = hub_arn ,
83
87
instance_type = instance_type ,
84
88
region = region ,
85
89
include_container_hyperparameters = include_container_hyperparameters ,
@@ -93,6 +97,7 @@ def validate(
93
97
region : Optional [str ] = None ,
94
98
model_id : Optional [str ] = None ,
95
99
model_version : Optional [str ] = None ,
100
+ hub_arn : Optional [str ] = None ,
96
101
hyperparameters : Optional [dict ] = None ,
97
102
validation_mode : HyperparameterValidationMode = HyperparameterValidationMode .VALIDATE_PROVIDED ,
98
103
tolerate_vulnerable_model : bool = False ,
@@ -107,6 +112,8 @@ def validate(
107
112
(Default: None).
108
113
model_version (str): The version of the model for which to validate hyperparameters.
109
114
(Default: None).
115
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116
+ model details from. (default: None).
110
117
hyperparameters (dict): Hyperparameters to validate.
111
118
(Default: None).
112
119
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148
155
return validate_hyperparameters (
149
156
model_id = model_id ,
150
157
model_version = model_version ,
158
+ hub_arn = hub_arn ,
151
159
hyperparameters = hyperparameters ,
152
160
validation_mode = validation_mode ,
153
161
region = region ,
0 commit comments