13
13
import sagemaker
14
14
from sagemaker .fw_utils import create_image_uri
15
15
from sagemaker .model import FrameworkModel , MODEL_SERVER_WORKERS_PARAM_NAME
16
- from sagemaker .pytorch .defaults import PYTORCH_VERSION
16
+ from sagemaker .pytorch .defaults import PYTORCH_VERSION , PYTHON_VERSION
17
17
from sagemaker .predictor import RealTimePredictor , json_serializer , json_deserializer
18
18
from sagemaker .utils import name_from_image
19
19
@@ -41,8 +41,9 @@ class PyTorchModel(FrameworkModel):
41
41
42
42
__framework_name__ = 'pytorch'
43
43
44
- def __init__ (self , model_data , role , entry_point , image = None , py_version = 'py2' , framework_version = PYTORCH_VERSION ,
45
- predictor_cls = PyTorchPredictor , model_server_workers = None , ** kwargs ):
44
+ def __init__ (self , model_data , role , entry_point , image = None , py_version = PYTHON_VERSION ,
45
+ framework_version = PYTORCH_VERSION , predictor_cls = PyTorchPredictor ,
46
+ model_server_workers = None , ** kwargs ):
46
47
"""Initialize an PyTorchModel.
47
48
48
49
Args:
@@ -54,7 +55,7 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
54
55
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
55
56
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
56
57
image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
57
- py_version (str): Python version you want to use for executing your model training code (default: 'py2 ').
58
+ py_version (str): Python version you want to use for executing your model training code (default: 'py3 ').
58
59
framework_version (str): PyTorch version you want to use for executing your model training code.
59
60
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
60
61
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
@@ -63,8 +64,7 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
63
64
If None, server will use one worker per vCPU.
64
65
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
65
66
"""
66
- super (PyTorchModel , self ).__init__ (model_data , image , role , entry_point , predictor_cls = predictor_cls ,
67
- ** kwargs )
67
+ super (PyTorchModel , self ).__init__ (model_data , image , role , entry_point , predictor_cls = predictor_cls , ** kwargs )
68
68
self .py_version = py_version
69
69
self .framework_version = framework_version
70
70
self .model_server_workers = model_server_workers
0 commit comments