Skip to content

Commit b15b7c2

Browse files
committed
Local mode does not propagate errors raised in the customer script.
Fix whitespaces. Make python3 to be default for pytorch estimator.
1 parent 2d4cb33 commit b15b7c2

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

src/sagemaker/pytorch/defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
PYTORCH_VERSION = '0.3'
14+
PYTHON_VERSION = 'py3'

src/sagemaker/pytorch/estimator.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from sagemaker.estimator import Framework
1414
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
15-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
15+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
1616
from sagemaker.pytorch.model import PyTorchModel
1717

1818

@@ -21,7 +21,7 @@ class PyTorch(Framework):
2121

2222
__framework_name__ = "pytorch"
2323

24-
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
24+
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
2525
framework_version=PYTORCH_VERSION, **kwargs):
2626
"""
2727
This ``Estimator`` executes an PyTorch script in a managed PyTorch execution environment, within a SageMaker
@@ -46,7 +46,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
4646
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
4747
For convenience, this accepts other types for keys and values, but ``str()`` will be called
4848
to convert them before training.
49-
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
49+
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
5050
One of 'py2' or 'py3'.
5151
framework_version (str): PyTorch version you want to use for executing your model training code.
5252
List of supported versions https://github.com/aws/sagemaker-python-sdk#pytorch-sagemaker-estimators
@@ -81,10 +81,10 @@ def create_model(self, model_server_workers=None):
8181
See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details.
8282
"""
8383
return PyTorchModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
84-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
85-
container_log_level=self.container_log_level, code_location=self.code_location,
86-
py_version=self.py_version, framework_version=self.framework_version,
87-
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
84+
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
85+
container_log_level=self.container_log_level, code_location=self.code_location,
86+
py_version=self.py_version, framework_version=self.framework_version,
87+
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
8888

8989
@classmethod
9090
def _prepare_init_params_from_job_description(cls, job_details):

src/sagemaker/pytorch/model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sagemaker
1414
from sagemaker.fw_utils import create_image_uri
1515
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
16-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
16+
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
1717
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
1818
from sagemaker.utils import name_from_image
1919

@@ -41,8 +41,9 @@ class PyTorchModel(FrameworkModel):
4141

4242
__framework_name__ = 'pytorch'
4343

44-
def __init__(self, model_data, role, entry_point, image=None, py_version='py2', framework_version=PYTORCH_VERSION,
45-
predictor_cls=PyTorchPredictor, model_server_workers=None, **kwargs):
44+
def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION,
45+
framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor,
46+
model_server_workers=None, **kwargs):
4647
"""Initialize an PyTorchModel.
4748
4849
Args:
@@ -54,7 +55,7 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
5455
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
5556
as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
5657
image (str): A Docker image URI (default: None). If not specified, a default image for PyTorch will be used.
57-
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
58+
py_version (str): Python version you want to use for executing your model training code (default: 'py3').
5859
framework_version (str): PyTorch version you want to use for executing your model training code.
5960
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a predictor
6061
with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of
@@ -63,8 +64,7 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2',
6364
If None, server will use one worker per vCPU.
6465
**kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
6566
"""
66-
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls,
67-
**kwargs)
67+
super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs)
6868
self.py_version = py_version
6969
self.framework_version = framework_version
7070
self.model_server_workers = model_server_workers

tests/integ/test_pytorch_train.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ def test_async_fit(sagemaker_session, pytorch_full_version, instance_type):
6868
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
6969

7070

71-
def test_failed_training_job(sagemaker_session, pytorch_full_version, instance_type):
71+
# TODO(nadiaya): Run against local mode when errors will be propagated
72+
def test_failed_training_job(sagemaker_session, pytorch_full_version):
7273
script_path = os.path.join(MNIST_DIR, 'failure_script.py')
7374

7475
with timeout(minutes=15):
7576
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', framework_version=pytorch_full_version,
76-
train_instance_count=1, train_instance_type=instance_type,
77+
train_instance_count=1, train_instance_type='ml.c4.xlarge',
7778
sagemaker_session=sagemaker_session)
7879

7980
with pytest.raises(ValueError) as e:

tests/unit/test_pytorch.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,19 @@ def fixture_sagemaker_session():
5252
return ims
5353

5454

55-
def _get_full_cpu_image_uri(version):
56-
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'cpu', PYTHON_VERSION)
55+
def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION):
56+
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'cpu', py_version)
5757

5858

59-
def _get_full_gpu_image_uri(version):
60-
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'gpu', PYTHON_VERSION)
59+
def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION):
60+
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'gpu', py_version)
6161

6262

6363
def _pytorch_estimator(sagemaker_session, framework_version=defaults.PYTORCH_VERSION, train_instance_type=None,
6464
enable_cloudwatch_metrics=False, base_job_name=None, **kwargs):
6565
return PyTorch(entry_point=SCRIPT_PATH,
6666
framework_version=framework_version,
67+
py_version=PYTHON_VERSION,
6768
role=ROLE,
6869
sagemaker_session=sagemaker_session,
6970
train_instance_count=INSTANCE_COUNT,
@@ -138,7 +139,7 @@ def test_create_model(sagemaker_session, pytorch_version):
138139
def test_pytorch(strftime, sagemaker_session, pytorch_version):
139140
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
140141
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
141-
framework_version=pytorch_version)
142+
framework_version=pytorch_version, py_version=PYTHON_VERSION)
142143

143144
inputs = 's3://mybucket/train'
144145

@@ -184,7 +185,7 @@ def test_train_image_default(sagemaker_session):
184185
pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
185186
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
186187

187-
assert _get_full_cpu_image_uri(defaults.PYTORCH_VERSION) in pytorch.train_image()
188+
assert _get_full_cpu_image_uri(defaults.PYTORCH_VERSION, defaults.PYTHON_VERSION) in pytorch.train_image()
188189

189190

190191
def test_train_image_cpu_instances(sagemaker_session, pytorch_version):

0 commit comments

Comments
 (0)