Skip to content

Commit 67ff491

Browse files
qidewenwhenDewen Qi
authored andcommitted
change: Add PipelineVariable annotation in amazon models (aws#3187)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 463ba0e commit 67ff491

13 files changed

+172
-11
lines changed

src/sagemaker/amazon/factorization_machines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class FactorizationMachines(AmazonAlgorithmEstimatorBase):
@@ -319,7 +323,13 @@ class FactorizationMachinesModel(Model):
319323
returns :class:`FactorizationMachinesPredictor`.
320324
"""
321325

322-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
326+
def __init__(
327+
self,
328+
model_data: Union[str, PipelineVariable],
329+
role: str,
330+
sagemaker_session: Optional[Session] = None,
331+
**kwargs
332+
):
323333
"""Initialization for FactorizationMachinesModel class.
324334
325335
Args:
@@ -343,6 +353,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
343353
sagemaker_session.boto_region_name,
344354
version=FactorizationMachines.repo_version,
345355
)
356+
pop_out_unused_kwarg("predictor_cls", kwargs, FactorizationMachinesPredictor.__name__)
357+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
346358
super(FactorizationMachinesModel, self).__init__(
347359
image_uri,
348360
model_data,

src/sagemaker/amazon/ipinsights.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
@@ -22,7 +24,9 @@
2224
from sagemaker.model import Model
2325
from sagemaker.serializers import CSVSerializer
2426
from sagemaker.session import Session
27+
from sagemaker.utils import pop_out_unused_kwarg
2528
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
29+
from sagemaker.workflow.entities import PipelineVariable
2630

2731

2832
class IPInsights(AmazonAlgorithmEstimatorBase):
@@ -222,7 +226,13 @@ class IPInsightsModel(Model):
222226
Predictor that calculates anomaly scores for data points.
223227
"""
224228

225-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
229+
def __init__(
230+
self,
231+
model_data: Union[str, PipelineVariable],
232+
role: str,
233+
sagemaker_session: Optional[Session] = None,
234+
**kwargs
235+
):
226236
"""Creates object to get insights on S3 model data.
227237
228238
Args:
@@ -246,6 +256,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
246256
sagemaker_session.boto_region_name,
247257
version=IPInsights.repo_version,
248258
)
259+
pop_out_unused_kwarg("predictor_cls", kwargs, IPInsightsPredictor.__name__)
260+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
249261
super(IPInsightsModel, self).__init__(
250262
image_uri,
251263
model_data,

src/sagemaker/amazon/kmeans.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class KMeans(AmazonAlgorithmEstimatorBase):
@@ -246,7 +250,13 @@ class KMeansModel(Model):
246250
Predictor to performs k-means cluster assignment.
247251
"""
248252

249-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
253+
def __init__(
254+
self,
255+
model_data: Union[str, PipelineVariable],
256+
role: str,
257+
sagemaker_session: Optional[Session] = None,
258+
**kwargs
259+
):
250260
"""Initialization for KMeansModel class.
251261
252262
Args:
@@ -270,6 +280,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
270280
sagemaker_session.boto_region_name,
271281
version=KMeans.repo_version,
272282
)
283+
pop_out_unused_kwarg("predictor_cls", kwargs, KMeansPredictor.__name__)
284+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
273285
super(KMeansModel, self).__init__(
274286
image_uri,
275287
model_data,

src/sagemaker/amazon/knn.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class KNN(AmazonAlgorithmEstimatorBase):
@@ -238,7 +242,13 @@ class KNNModel(Model):
238242
and returns :class:`KNNPredictor`.
239243
"""
240244

241-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
245+
def __init__(
246+
self,
247+
model_data: Union[str, PipelineVariable],
248+
role: str,
249+
sagemaker_session: Optional[Session] = None,
250+
**kwargs
251+
):
242252
"""Function to initialize KNNModel.
243253
244254
Args:
@@ -262,6 +272,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
262272
sagemaker_session.boto_region_name,
263273
version=KNN.repo_version,
264274
)
275+
pop_out_unused_kwarg("predictor_cls", kwargs, KNNPredictor.__name__)
276+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
265277
super(KNNModel, self).__init__(
266278
image_uri,
267279
model_data,

src/sagemaker/amazon/lda.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class LDA(AmazonAlgorithmEstimatorBase):
@@ -220,7 +224,13 @@ class LDAModel(Model):
220224
Predictor that transforms vectors to a lower-dimensional representation.
221225
"""
222226

223-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
227+
def __init__(
228+
self,
229+
model_data: Union[str, PipelineVariable],
230+
role: str,
231+
sagemaker_session: Optional[Session] = None,
232+
**kwargs
233+
):
224234
"""Initialization for LDAModel class.
225235
226236
Args:
@@ -244,6 +254,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
244254
sagemaker_session.boto_region_name,
245255
version=LDA.repo_version,
246256
)
257+
pop_out_unused_kwarg("predictor_cls", kwargs, LDAPredictor.__name__)
258+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
247259
super(LDAModel, self).__init__(
248260
image_uri,
249261
model_data,

src/sagemaker/amazon/linear_learner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class LinearLearner(AmazonAlgorithmEstimatorBase):
@@ -481,7 +485,13 @@ class LinearLearnerModel(Model):
481485
:class:`LinearLearnerPredictor`
482486
"""
483487

484-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
488+
def __init__(
489+
self,
490+
model_data: Union[str, PipelineVariable],
491+
role: str,
492+
sagemaker_session: Optional[Session] = None,
493+
**kwargs
494+
):
485495
"""Initialization for LinearLearnerModel.
486496
487497
Args:
@@ -505,6 +515,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
505515
sagemaker_session.boto_region_name,
506516
version=LinearLearner.repo_version,
507517
)
518+
pop_out_unused_kwarg("predictor_cls", kwargs, LinearLearnerPredictor.__name__)
519+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
508520
super(LinearLearnerModel, self).__init__(
509521
image_uri,
510522
model_data,

src/sagemaker/amazon/ntm.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class NTM(AmazonAlgorithmEstimatorBase):
@@ -249,7 +253,13 @@ class NTMModel(Model):
249253
Predictor that transforms vectors to a lower-dimensional representation.
250254
"""
251255

252-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
256+
def __init__(
257+
self,
258+
model_data: Union[str, PipelineVariable],
259+
role: str,
260+
sagemaker_session: Optional[Session] = None,
261+
**kwargs
262+
):
253263
"""Initialization for NTMModel class.
254264
255265
Args:
@@ -273,6 +283,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
273283
sagemaker_session.boto_region_name,
274284
version=NTM.repo_version,
275285
)
286+
pop_out_unused_kwarg("predictor_cls", kwargs, NTMPredictor.__name__)
287+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
276288
super(NTMModel, self).__init__(
277289
image_uri,
278290
model_data,

src/sagemaker/amazon/object2vec.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
1921
from sagemaker.amazon.validation import ge, le, isin
2022
from sagemaker.predictor import Predictor
2123
from sagemaker.model import Model
2224
from sagemaker.session import Session
25+
from sagemaker.utils import pop_out_unused_kwarg
2326
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27+
from sagemaker.workflow.entities import PipelineVariable
2428

2529

2630
def _list_check_subset(valid_super_list):
@@ -344,7 +348,13 @@ class Object2VecModel(Model):
344348
Predictor that calculates anomaly scores for datapoints.
345349
"""
346350

347-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
351+
def __init__(
352+
self,
353+
model_data: Union[str, PipelineVariable],
354+
role: str,
355+
sagemaker_session: Optional[Session] = None,
356+
**kwargs
357+
):
348358
"""Initialization for Object2VecModel class.
349359
350360
Args:
@@ -368,6 +378,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
368378
sagemaker_session.boto_region_name,
369379
version=Object2Vec.repo_version,
370380
)
381+
pop_out_unused_kwarg("predictor_cls", kwargs, Predictor.__name__)
382+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
371383
super(Object2VecModel, self).__init__(
372384
image_uri,
373385
model_data,

src/sagemaker/amazon/pca.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
from typing import Union, Optional
17+
1618
from sagemaker import image_uris
1719
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
1820
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
@@ -21,7 +23,9 @@
2123
from sagemaker.predictor import Predictor
2224
from sagemaker.model import Model
2325
from sagemaker.session import Session
26+
from sagemaker.utils import pop_out_unused_kwarg
2427
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
28+
from sagemaker.workflow.entities import PipelineVariable
2529

2630

2731
class PCA(AmazonAlgorithmEstimatorBase):
@@ -237,7 +241,13 @@ class PCAModel(Model):
237241
Predictor that transforms vectors to a lower-dimensional representation.
238242
"""
239243

240-
def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
244+
def __init__(
245+
self,
246+
model_data: Union[str, PipelineVariable],
247+
role: str,
248+
sagemaker_session: Optional[Session] = None,
249+
**kwargs
250+
):
241251
"""Initialization for PCAModel.
242252
243253
Args:
@@ -261,6 +271,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
261271
sagemaker_session.boto_region_name,
262272
version=PCA.repo_version,
263273
)
274+
pop_out_unused_kwarg("predictor_cls", kwargs, PCAPredictor.__name__)
275+
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
264276
super(PCAModel, self).__init__(
265277
image_uri,
266278
model_data,

0 commit comments

Comments
 (0)