Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 22fd18c

Browse files
author
Dewen Qi
committedJul 27, 2022
change: Add PipelineVariable annotation in amazon models
1 parent b16630b commit 22fd18c

File tree

13 files changed

+172
-11
lines changed

13 files changed

+172
-11
lines changed
 

‎src/sagemaker/amazon/factorization_machines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class FactorizationMachines(AmazonAlgorithmEstimatorBase):
@@ -319,7 +323,13 @@ class FactorizationMachinesModel(Model):
319323
returns :class:`FactorizationMachinesPredictor`.
320324
"""
321325

322-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
326+
def __init__(
327+
self,
328+
model_data: Union[str, PipelineVariable],
329+
role: str,
330+
sagemaker_session: Optional[Session] = None,
331+
**kwargs
332+
):
323333
"""Initialization for FactorizationMachinesModel class.
324334
325335
Args:
@@ -343,6 +353,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
343353
sagemaker_session.boto_region_name,
344354
version=FactorizationMachines.repo_version,
345355
)
356+
pop_out_unused_kwarg("predictor_cls", kwargs, FactorizationMachinesPredictor.__name__)
357+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
346358
super(FactorizationMachinesModel, self).__init__(
347359
image_uri,
348360
model_data,

‎src/sagemaker/amazon/ipinsights.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
@@ -22,7 +24,9 @@
2224
from sagemaker.model import Model
2325
from sagemaker.serializers import CSVSerializer
2426
from sagemaker.session import Session
27+
from sagemaker.utils import pop_out_unused_kwarg
2528
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
29+
from sagemaker.workflow.entities import PipelineVariable
2630

2731

2832
class IPInsights(AmazonAlgorithmEstimatorBase):
@@ -222,7 +226,13 @@ class IPInsightsModel(Model):
222226
Predictor that calculates anomaly scores for data points.
223227
"""
224228

225-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
229+
def __init__(
230+
self,
231+
model_data: Union[str, PipelineVariable],
232+
role: str,
233+
sagemaker_session: Optional[Session] = None,
234+
**kwargs
235+
):
226236
"""Creates object to get insights on S3 model data.
227237
228238
Args:
@@ -246,6 +256,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
246256
sagemaker_session.boto_region_name,
247257
version=IPInsights.repo_version,
248258
)
259+
pop_out_unused_kwarg("predictor_cls", kwargs, IPInsightsPredictor.__name__)
260+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
249261
super(IPInsightsModel, self).__init__(
250262
image_uri,
251263
model_data,

‎src/sagemaker/amazon/kmeans.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class KMeans(AmazonAlgorithmEstimatorBase):
@@ -246,7 +250,13 @@ class KMeansModel(Model):
246250
Predictor to performs k-means cluster assignment.
247251
"""
248252

249-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
253+
def __init__(
254+
self,
255+
model_data: Union[str, PipelineVariable],
256+
role: str,
257+
sagemaker_session: Optional[Session] = None,
258+
**kwargs
259+
):
250260
"""Initialization for KMeansModel class.
251261
252262
Args:
@@ -270,6 +280,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
270280
sagemaker_session.boto_region_name,
271281
version=KMeans.repo_version,
272282
)
283+
pop_out_unused_kwarg("predictor_cls", kwargs, KMeansPredictor.__name__)
284+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
273285
super(KMeansModel, self).__init__(
274286
image_uri,
275287
model_data,

‎src/sagemaker/amazon/knn.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class KNN(AmazonAlgorithmEstimatorBase):
@@ -238,7 +242,13 @@ class KNNModel(Model):
238242
and returns :class:`KNNPredictor`.
239243
"""
240244

241-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
245+
def __init__(
246+
self,
247+
model_data: Union[str, PipelineVariable],
248+
role: str,
249+
sagemaker_session: Optional[Session] = None,
250+
**kwargs
251+
):
242252
"""Function to initialize KNNModel.
243253
244254
Args:
@@ -262,6 +272,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
262272
sagemaker_session.boto_region_name,
263273
version=KNN.repo_version,
264274
)
275+
pop_out_unused_kwarg("predictor_cls", kwargs, KNNPredictor.__name__)
276+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
265277
super(KNNModel, self).__init__(
266278
image_uri,
267279
model_data,

‎src/sagemaker/amazon/lda.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class LDA(AmazonAlgorithmEstimatorBase):
@@ -220,7 +224,13 @@ class LDAModel(Model):
220224
Predictor that transforms vectors to a lower-dimensional representation.
221225
"""
222226

223-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
227+
def __init__(
228+
self,
229+
model_data: Union[str, PipelineVariable],
230+
role: str,
231+
sagemaker_session: Optional[Session] = None,
232+
**kwargs
233+
):
224234
"""Initialization for LDAModel class.
225235
226236
Args:
@@ -244,6 +254,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
244254
sagemaker_session.boto_region_name,
245255
version=LDA.repo_version,
246256
)
257+
pop_out_unused_kwarg("predictor_cls", kwargs, LDAPredictor.__name__)
258+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
247259
super(LDAModel, self).__init__(
248260
image_uri,
249261
model_data,

‎src/sagemaker/amazon/linear_learner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class LinearLearner(AmazonAlgorithmEstimatorBase):
@@ -481,7 +485,13 @@ class LinearLearnerModel(Model):
481485
:class:`LinearLearnerPredictor`
482486
"""
483487

484-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
488+
def __init__(
489+
self,
490+
model_data: Union[str, PipelineVariable],
491+
role: str,
492+
sagemaker_session: Optional[Session] = None,
493+
**kwargs
494+
):
485495
"""Initialization for LinearLearnerModel.
486496
487497
Args:
@@ -505,6 +515,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
505515
sagemaker_session.boto_region_name,
506516
version=LinearLearner.repo_version,
507517
)
518+
pop_out_unused_kwarg("predictor_cls", kwargs, LinearLearnerPredictor.__name__)
519+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
508520
super(LinearLearnerModel, self).__init__(
509521
image_uri,
510522
model_data,

‎src/sagemaker/amazon/ntm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class NTM(AmazonAlgorithmEstimatorBase):
@@ -249,7 +253,13 @@ class NTMModel(Model):
249253
Predictor that transforms vectors to a lower-dimensional representation.
250254
"""
251255

252-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
256+
def __init__(
257+
self,
258+
model_data: Union[str, PipelineVariable],
259+
role: str,
260+
sagemaker_session: Optional[Session] = None,
261+
**kwargs
262+
):
253263
"""Initialization for NTMModel class.
254264
255265
Args:
@@ -273,6 +283,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
273283
sagemaker_session.boto_region_name,
274284
version=NTM.repo_version,
275285
)
286+
pop_out_unused_kwarg("predictor_cls", kwargs, NTMPredictor.__name__)
287+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
276288
super(NTMModel, self).__init__(
277289
image_uri,
278290
model_data,

‎src/sagemaker/amazon/object2vec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1921
from sagemaker.amazon.validation import ge, le, isin
2022
from sagemaker.predictor import Predictor
2123
from sagemaker.model import Model
2224
from sagemaker.session import Session
25+
from sagemaker.utils import pop_out_unused_kwarg
2326
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27+
from sagemaker.workflow.entities import PipelineVariable
2428

2529

2630
def _list_check_subset(valid_super_list):
@@ -344,7 +348,13 @@ class Object2VecModel(Model):
344348
Predictor that calculates anomaly scores for datapoints.
345349
"""
346350

347-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
351+
def __init__(
352+
self,
353+
model_data: Union[str, PipelineVariable],
354+
role: str,
355+
sagemaker_session: Optional[Session] = None,
356+
**kwargs
357+
):
348358
"""Initialization for Object2VecModel class.
349359
350360
Args:
@@ -368,6 +378,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
368378
sagemaker_session.boto_region_name,
369379
version=Object2Vec.repo_version,
370380
)
381+
pop_out_unused_kwarg("predictor_cls", kwargs, Predictor.__name__)
382+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
371383
super(Object2VecModel, self).__init__(
372384
image_uri,
373385
model_data,

‎src/sagemaker/amazon/pca.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class PCA(AmazonAlgorithmEstimatorBase):
@@ -237,7 +241,13 @@ class PCAModel(Model):
237241
Predictor that transforms vectors to a lower-dimensional representation.
238242
"""
239243

240-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
244+
def __init__(
245+
self,
246+
model_data: Union[str, PipelineVariable],
247+
role: str,
248+
sagemaker_session: Optional[Session] = None,
249+
**kwargs
250+
):
241251
"""Initialization for PCAModel.
242252
243253
Args:
@@ -261,6 +271,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
261271
sagemaker_session.boto_region_name,
262272
version=PCA.repo_version,
263273
)
274+
pop_out_unused_kwarg("predictor_cls", kwargs, PCAPredictor.__name__)
275+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
264276
super(PCAModel, self).__init__(
265277
image_uri,
266278
model_data,

‎src/sagemaker/amazon/randomcutforest.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Optional, Union
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class RandomCutForest(AmazonAlgorithmEstimatorBase):
@@ -209,7 +213,13 @@ class RandomCutForestModel(Model):
209213
Predictor that calculates anomaly scores for datapoints.
210214
"""
211215

212-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
216+
def __init__(
217+
self,
218+
model_data: Union[str, PipelineVariable],
219+
role: str,
220+
sagemaker_session: Optional[Session] = None,
221+
**kwargs
222+
):
213223
"""Initialization for RandomCutForestModel class.
214224
215225
Args:
@@ -233,6 +243,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
233243
sagemaker_session.boto_region_name,
234244
version=RandomCutForest.repo_version,
235245
)
246+
pop_out_unused_kwarg("predictor_cls", kwargs, RandomCutForestPredictor.__name__)
247+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
236248
super(RandomCutForestModel, self).__init__(
237249
image_uri,
238250
model_data,

‎src/sagemaker/sparkml/model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import Model, Predictor, Session, image_uris
1719
from sagemaker.serializers import CSVSerializer
20+
from sagemaker.utils import pop_out_unused_kwarg
21+
from sagemaker.workflow.entities import PipelineVariable
1822

1923
framework_name = "sparkml-serving"
2024

@@ -71,7 +75,12 @@ class SparkMLModel(Model):
7175
"""
7276

7377
def __init__(
74-
self, model_data, role=None, spark_version="2.4", sagemaker_session=None, **kwargs
78+
self,
79+
model_data: Union[str, PipelineVariable],
80+
role: Optional[str] = None,
81+
spark_version: str = "2.4",
82+
sagemaker_session: Optional[Session] = None,
83+
**kwargs,
7584
):
7685
"""Initialize a SparkMLModel.
7786
@@ -104,6 +113,8 @@ def __init__(
104113
# boto_region_name
105114
region_name = (sagemaker_session or Session()).boto_region_name
106115
image_uri = image_uris.retrieve(framework_name, region_name, version=spark_version)
116+
pop_out_unused_kwarg("predictor_cls", kwargs, SparkMLPredictor.__name__)
117+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
107118
super(SparkMLModel, self).__init__(
108119
image_uri,
109120
model_data,

‎src/sagemaker/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import abc
2828
import uuid
2929
from datetime import datetime
30+
from typing import Optional
3031

3132
import botocore
3233
from six.moves.urllib import parse
@@ -827,3 +828,20 @@ def construct_container_object(
827828
)
828829

829830
return obj
831+
832+
833+
def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None):
834+
"""Pop out the unused key-word argument and give a warning.
835+
836+
Args:
837+
arg_name (str): The name of the argument to be checked if it is unused.
838+
kwargs (dict): The key-word argument dict.
839+
override_val (str): The value used to override the unused argument (default: None).
840+
"""
841+
if arg_name not in kwargs:
842+
return
843+
warn_msg = "{} supplied in kwargs will be ignored".format(arg_name)
844+
if override_val:
845+
warn_msg += " and further overridden with {}.".format(override_val)
846+
logging.warning(warn_msg)
847+
kwargs.pop(arg_name)

‎tests/unit/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,3 +761,15 @@ def test_partition_by_region():
761761
assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov"
762762
assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso"
763763
assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b"
764+
765+
766+
def test_pop_out_unused_kwarg():
767+
# The given arg_name is in kwargs
768+
kwargs = dict(arg1=1, arg2=2)
769+
sagemaker.utils.pop_out_unused_kwarg("arg1", kwargs)
770+
assert "arg1" not in kwargs
771+
772+
# The given arg_name is not in kwargs
773+
kwargs = dict(arg1=1, arg2=2)
774+
sagemaker.utils.pop_out_unused_kwarg("arg3", kwargs)
775+
assert len(kwargs) == 2

0 commit comments

Comments
 (0)
Please sign in to comment.