Skip to content

Commit 4d4dd1f

Browse files
authored
breaking: rename image to image_uri (#1670)
1 parent 5dee4e4 commit 4d4dd1f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+310
-319
lines changed

src/sagemaker/amazon/amazon_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class constructor
156156
init_params[attribute] = init_params["hyperparameters"][value.name]
157157

158158
del init_params["hyperparameters"]
159-
del init_params["image"]
159+
del init_params["image_uri"]
160160
return init_params
161161

162162
def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):

src/sagemaker/amazon/factorization_machines.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
312312
"""
313313
sagemaker_session = sagemaker_session or Session()
314314
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
315-
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
315+
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
316316
super(FactorizationMachinesModel, self).__init__(
317-
image,
317+
image_uri,
318318
model_data,
319319
role,
320320
predictor_cls=FactorizationMachinesPredictor,

src/sagemaker/amazon/ipinsights.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
218218
"""
219219
sagemaker_session = sagemaker_session or Session()
220220
repo = "{}:{}".format(IPInsights.repo_name, IPInsights.repo_version)
221-
image = "{}/{}".format(
221+
image_uri = "{}/{}".format(
222222
registry(sagemaker_session.boto_session.region_name, IPInsights.repo_name), repo
223223
)
224224

225225
super(IPInsightsModel, self).__init__(
226-
image,
226+
image_uri,
227227
model_data,
228228
role,
229229
predictor_cls=IPInsightsPredictor,

src/sagemaker/amazon/kmeans.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
243243
"""
244244
sagemaker_session = sagemaker_session or Session()
245245
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
246-
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
246+
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
247247
super(KMeansModel, self).__init__(
248-
image,
248+
image_uri,
249249
model_data,
250250
role,
251251
predictor_cls=KMeansPredictor,

src/sagemaker/amazon/lda.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
215215
"""
216216
sagemaker_session = sagemaker_session or Session()
217217
repo = "{}:{}".format(LDA.repo_name, LDA.repo_version)
218-
image = "{}/{}".format(
218+
image_uri = "{}/{}".format(
219219
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
220220
)
221221
super(LDAModel, self).__init__(
222-
image,
222+
image_uri,
223223
model_data,
224224
role,
225225
predictor_cls=LDAPredictor,

src/sagemaker/amazon/linear_learner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
476476
"""
477477
sagemaker_session = sagemaker_session or Session()
478478
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
479-
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
479+
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
480480
super(LinearLearnerModel, self).__init__(
481-
image,
481+
image_uri,
482482
model_data,
483483
role,
484484
predictor_cls=LinearLearnerPredictor,

src/sagemaker/amazon/ntm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
245245
"""
246246
sagemaker_session = sagemaker_session or Session()
247247
repo = "{}:{}".format(NTM.repo_name, NTM.repo_version)
248-
image = "{}/{}".format(
248+
image_uri = "{}/{}".format(
249249
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
250250
)
251251
super(NTMModel, self).__init__(
252-
image,
252+
image_uri,
253253
model_data,
254254
role,
255255
predictor_cls=NTMPredictor,

src/sagemaker/amazon/object2vec.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
351351
"""
352352
sagemaker_session = sagemaker_session or Session()
353353
repo = "{}:{}".format(Object2Vec.repo_name, Object2Vec.repo_version)
354-
image = "{}/{}".format(
354+
image_uri = "{}/{}".format(
355355
registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo
356356
)
357357
super(Object2VecModel, self).__init__(
358-
image,
358+
image_uri,
359359
model_data,
360360
role,
361361
predictor_cls=Predictor,

src/sagemaker/amazon/pca.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
227227
"""
228228
sagemaker_session = sagemaker_session or Session()
229229
repo = "{}:{}".format(PCA.repo_name, PCA.repo_version)
230-
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
230+
image_uri = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
231231
super(PCAModel, self).__init__(
232-
image,
232+
image_uri,
233233
model_data,
234234
role,
235235
predictor_cls=PCAPredictor,

src/sagemaker/amazon/randomcutforest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,11 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
206206
"""
207207
sagemaker_session = sagemaker_session or Session()
208208
repo = "{}:{}".format(RandomCutForest.repo_name, RandomCutForest.repo_version)
209-
image = "{}/{}".format(
209+
image_uri = "{}/{}".format(
210210
registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo
211211
)
212212
super(RandomCutForestModel, self).__init__(
213-
image,
213+
image_uri,
214214
model_data,
215215
role,
216216
predictor_cls=RandomCutForestPredictor,

src/sagemaker/automl/automl.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,12 @@ def create_model(
307307
models = []
308308

309309
for container in inference_containers:
310-
image = container["Image"]
310+
image_uri = container["Image"]
311311
model_data = container["ModelDataUrl"]
312312
env = container["Environment"]
313313

314314
model = Model(
315-
image=image,
315+
image_uri=image_uri,
316316
model_data=model_data,
317317
role=self.role,
318318
env=env,

src/sagemaker/automl/candidate_estimator.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -211,24 +211,25 @@ def _get_train_args(
211211
Returns (dcit): a dictionary that can be used as args of
212212
sagemaker_session.train method.
213213
"""
214-
train_args = {}
215-
train_args["input_config"] = inputs
216-
train_args["job_name"] = name
217-
train_args["input_mode"] = desc["AlgorithmSpecification"]["TrainingInputMode"]
218-
train_args["role"] = desc["RoleArn"]
219-
train_args["output_config"] = desc["OutputDataConfig"]
220-
train_args["resource_config"] = desc["ResourceConfig"]
221-
train_args["image"] = desc["AlgorithmSpecification"]["TrainingImage"]
222-
train_args["enable_network_isolation"] = desc["EnableNetworkIsolation"]
223-
train_args["encrypt_inter_container_traffic"] = encrypt_inter_container_traffic
224-
train_args["train_use_spot_instances"] = desc["EnableManagedSpotTraining"]
225-
train_args["hyperparameters"] = {}
226-
train_args["stop_condition"] = {}
227-
train_args["metric_definitions"] = None
228-
train_args["checkpoint_s3_uri"] = None
229-
train_args["checkpoint_local_path"] = None
230-
train_args["tags"] = []
231-
train_args["vpc_config"] = None
214+
train_args = {
215+
"input_config": inputs,
216+
"job_name": name,
217+
"input_mode": desc["AlgorithmSpecification"]["TrainingInputMode"],
218+
"role": desc["RoleArn"],
219+
"output_config": desc["OutputDataConfig"],
220+
"resource_config": desc["ResourceConfig"],
221+
"image_uri": desc["AlgorithmSpecification"]["TrainingImage"],
222+
"enable_network_isolation": desc["EnableNetworkIsolation"],
223+
"encrypt_inter_container_traffic": encrypt_inter_container_traffic,
224+
"train_use_spot_instances": desc["EnableManagedSpotTraining"],
225+
"hyperparameters": {},
226+
"stop_condition": {},
227+
"metric_definitions": None,
228+
"checkpoint_s3_uri": None,
229+
"checkpoint_local_path": None,
230+
"tags": [],
231+
"vpc_config": None,
232+
}
232233

233234
if volume_kms_key is not None:
234235
train_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key

src/sagemaker/chainer/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ def create_model(
208208
"""
209209
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
210210

211-
if "image" not in kwargs:
212-
kwargs["image"] = self.image_uri
211+
if "image_uri" not in kwargs:
212+
kwargs["image_uri"] = self.image_uri
213213

214214
return ChainerModel(
215215
self.model_data,
@@ -257,7 +257,7 @@ class constructor
257257
if value:
258258
init_params[argument[len("sagemaker_") :]] = value
259259

260-
image_uri = init_params.pop("image")
260+
image_uri = init_params.pop("image_uri")
261261
framework, py_version, tag, _ = framework_name_from_image(image_uri)
262262

263263
if tag is None:

src/sagemaker/chainer/model.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
model_data,
6565
role,
6666
entry_point,
67-
image=None,
67+
image_uri=None,
6868
framework_version=None,
6969
py_version=None,
7070
predictor_cls=ChainerPredictor,
@@ -85,16 +85,16 @@ def __init__(
8585
file which should be executed as the entry point to model
8686
hosting. If ``source_dir`` is specified, then ``entry_point``
8787
must point to a file located at the root of ``source_dir``.
88-
image (str): A Docker image URI (default: None). If not specified, a
88+
image_uri (str): A Docker image URI (default: None). If not specified, a
8989
default image for Chainer will be used. If ``framework_version``
90-
or ``py_version`` are ``None``, then ``image`` is required. If
90+
or ``py_version`` are ``None``, then ``image_uri`` is required. If
9191
also ``None``, then a ``ValueError`` will be raised.
9292
framework_version (str): Chainer version you want to use for
9393
executing your model training code. Defaults to ``None``. Required
94-
unless ``image`` is provided.
94+
unless ``image_uri`` is provided.
9595
py_version (str): Python version you want to use for executing your
9696
model training code. Defaults to ``None``. Required unless
97-
``image`` is provided.
97+
``image_uri`` is provided.
9898
predictor_cls (callable[str, sagemaker.session.Session]): A function
9999
to call to create a predictor with an endpoint name and
100100
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -111,7 +111,7 @@ def __init__(
111111
:class:`~sagemaker.model.FrameworkModel` and
112112
:class:`~sagemaker.model.Model`.
113113
"""
114-
validate_version_or_image_args(framework_version, py_version, image)
114+
validate_version_or_image_args(framework_version, py_version, image_uri)
115115
if py_version == "py2":
116116
logger.warning(
117117
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
@@ -120,7 +120,7 @@ def __init__(
120120
self.py_version = py_version
121121

122122
super(ChainerModel, self).__init__(
123-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
123+
model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
124124
)
125125

126126
self.model_server_workers = model_server_workers
@@ -140,7 +140,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
140140
dict[str, str]: A container definition object usable with the
141141
CreateModel API.
142142
"""
143-
deploy_image = self.image
143+
deploy_image = self.image_uri
144144
if not deploy_image:
145145
if instance_type is None:
146146
raise ValueError(

src/sagemaker/estimator.py

+5-26
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ class constructor
795795
if "AlgorithmName" in job_details["AlgorithmSpecification"]:
796796
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
797797
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
798-
init_params["image"] = job_details["AlgorithmSpecification"]["TrainingImage"]
798+
init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"]
799799
else:
800800
raise RuntimeError(
801801
"Invalid AlgorithmSpecification. Either TrainingImage or "
@@ -1037,7 +1037,7 @@ def start_new(cls, estimator, inputs, experiment_config):
10371037
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
10381038
train_args["algorithm_arn"] = estimator.algorithm_arn
10391039
else:
1040-
train_args["image"] = estimator.train_image()
1040+
train_args["image_uri"] = estimator.train_image()
10411041

10421042
if estimator.debugger_rule_configs:
10431043
train_args["debugger_rule_configs"] = estimator.debugger_rule_configs
@@ -1331,7 +1331,7 @@ def hyperparameters(self):
13311331
def create_model(
13321332
self,
13331333
role=None,
1334-
image=None,
1334+
image_uri=None,
13351335
predictor_cls=None,
13361336
serializer=None,
13371337
deserializer=None,
@@ -1350,7 +1350,7 @@ def create_model(
13501350
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
13511351
which is also used during transform jobs. If not specified, the
13521352
role from the Estimator will be used.
1353-
image (str): An container image to use for deploying the model.
1353+
image_uri (str): A Docker image URI to use for deploying the model.
13541354
Defaults to the image used for training.
13551355
predictor_cls (Predictor): The predictor class to use when
13561356
deploying the model.
@@ -1393,7 +1393,7 @@ def predict_wrapper(endpoint, session):
13931393
kwargs["enable_network_isolation"] = self.enable_network_isolation()
13941394

13951395
return Model(
1396-
image or self.train_image(),
1396+
image_uri or self.train_image(),
13971397
self.model_data,
13981398
role,
13991399
vpc_config=self.get_vpc_config(vpc_config_override),
@@ -1402,27 +1402,6 @@ def predict_wrapper(endpoint, session):
14021402
**kwargs
14031403
)
14041404

1405-
@classmethod
1406-
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
1407-
"""Convert the job description to init params that can be handled by the
1408-
class constructor
1409-
1410-
Args:
1411-
job_details: the returned job details from a describe_training_job
1412-
API call.
1413-
model_channel_name (str): Name of the channel where pre-trained
1414-
model data will be downloaded
1415-
1416-
Returns:
1417-
dictionary: The transformed init_params
1418-
"""
1419-
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(
1420-
job_details, model_channel_name
1421-
)
1422-
1423-
init_params["image_uri"] = init_params.pop("image")
1424-
return init_params
1425-
14261405

14271406
class Framework(EstimatorBase):
14281407
"""Base class that cannot be instantiated directly.

0 commit comments

Comments
 (0)