@@ -6128,7 +6128,14 @@ def update_args(args: Dict[str, Any], **kwargs):
6128
6128
args .update ({key : value })
6129
6129
6130
6130
6131
- def container_def (image_uri , model_data_url = None , env = None , container_mode = None , image_config = None ):
6131
+ def container_def (
6132
+ image_uri ,
6133
+ model_data_url = None ,
6134
+ env = None ,
6135
+ container_mode = None ,
6136
+ image_config = None ,
6137
+ accept_eula = None ,
6138
+ ):
6132
6139
"""Create a definition for executing a container as part of a SageMaker model.
6133
6140
6134
6141
Args:
@@ -6145,6 +6152,11 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
6145
6152
image_config (dict[str, str]): Specifies whether the image of model container is pulled
6146
6153
from ECR, or private registry in your VPC. By default it is set to pull model
6147
6154
container image from ECR. (default: None).
6155
+ accept_eula (bool): For models that require a Model Access Config, specify True or
6156
+ False to indicate whether model terms of use have been accepted.
6157
+ The `accept_eula` value must be explicitly defined as `True` in order to
6158
+ accept the end-user license agreement (EULA) that some
6159
+ models require. (Default: None).
6148
6160
6149
6161
Returns:
6150
6162
dict[str, str]: A complete container definition object usable with the CreateModel API if
@@ -6154,9 +6166,28 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None,
6154
6166
env = {}
6155
6167
c_def = {"Image" : image_uri , "Environment" : env }
6156
6168
6157
- if isinstance (model_data_url , dict ):
6158
- c_def ["ModelDataSource" ] = model_data_url
6159
- elif model_data_url :
6169
+ if isinstance (model_data_url , str ) and (
6170
+ not (model_data_url .startswith ("s3://" ) and model_data_url .endswith ("tar.gz" ))
6171
+ or accept_eula is None
6172
+ ):
6173
+ c_def ["ModelDataUrl" ] = model_data_url
6174
+
6175
+ elif isinstance (model_data_url , (dict , str )):
6176
+ if isinstance (model_data_url , dict ):
6177
+ c_def ["ModelDataSource" ] = model_data_url
6178
+ else :
6179
+ c_def ["ModelDataSource" ] = {
6180
+ "S3DataSource" : {
6181
+ "S3Uri" : model_data_url ,
6182
+ "S3DataType" : "S3Object" ,
6183
+ "CompressionType" : "Gzip" ,
6184
+ }
6185
+ }
6186
+ if accept_eula is not None :
6187
+ c_def ["ModelDataSource" ]["S3DataSource" ]["ModelAccessConfig" ] = {
6188
+ "AcceptEula" : accept_eula
6189
+ }
6190
+ elif model_data_url is not None :
6160
6191
c_def ["ModelDataUrl" ] = model_data_url
6161
6192
6162
6193
if container_mode :
0 commit comments