Skip to content

Commit 3a9dedd

Browse files
committed
Make pytorch code and test conform with latest python sdk guidlines.
1 parent d254128 commit 3a9dedd

File tree

7 files changed

+51
-38
lines changed

7 files changed

+51
-38
lines changed

src/sagemaker/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
from sagemaker.pytorch.estimator import PyTorch
1415
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor
1516

src/sagemaker/pytorch/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
1315
PYTORCH_VERSION = '0.4'
1416
PYTHON_VERSION = 'py3'

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
from sagemaker.estimator import Framework
1415
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
1516
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION

src/sagemaker/pytorch/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
import sagemaker
1415
from sagemaker.fw_utils import create_image_uri
1516
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME

tests/integ/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
import pytest
1415

1516

tests/integ/test_pytorch_train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
import numpy
1415
import os
1516
import sys

tests/unit/test_pytorch.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
1314
import logging
1415

1516
import json
@@ -44,12 +45,14 @@
4445
@pytest.fixture(name='sagemaker_session')
4546
def fixture_sagemaker_session():
4647
boto_mock = Mock(name='boto_session', region_name=REGION)
47-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
48-
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':
49-
{'S3ModelArtifacts': 's3://m/m.tar.gz'}})
50-
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
51-
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
52-
return ims
48+
session = Mock(name='sagemaker_session', boto_session=boto_mock,
49+
boto_region_name=REGION, config=None, local_mode=False)
50+
51+
describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
52+
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
53+
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
54+
session.expand_role = Mock(name="expand_role", return_value=ROLE)
55+
return session
5356

5457

5558
def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION):
@@ -75,39 +78,42 @@ def _pytorch_estimator(sagemaker_session, framework_version=defaults.PYTORCH_VER
7578

7679

7780
def _create_train_job(version):
78-
return {'image': _get_full_cpu_image_uri(version),
79-
'input_mode': 'File',
80-
'input_config': [{
81-
'ChannelName': 'training',
82-
'DataSource': {
83-
'S3DataSource': {
84-
'S3DataDistributionType': 'FullyReplicated',
85-
'S3DataType': 'S3Prefix'
86-
}
81+
return {
82+
'image': _get_full_cpu_image_uri(version),
83+
'input_mode': 'File',
84+
'input_config': [{
85+
'ChannelName': 'training',
86+
'DataSource': {
87+
'S3DataSource': {
88+
'S3DataDistributionType': 'FullyReplicated',
89+
'S3DataType': 'S3Prefix'
8790
}
88-
}],
89-
'role': ROLE,
90-
'job_name': JOB_NAME,
91-
'output_config': {
92-
'S3OutputPath': 's3://{}/'.format(BUCKET_NAME),
93-
},
94-
'resource_config': {
95-
'InstanceType': 'ml.c4.4xlarge',
96-
'InstanceCount': 1,
97-
'VolumeSizeInGB': 30,
98-
},
99-
'hyperparameters': {
100-
'sagemaker_program': json.dumps('dummy_script.py'),
101-
'sagemaker_enable_cloudwatch_metrics': 'false',
102-
'sagemaker_container_log_level': str(logging.INFO),
103-
'sagemaker_job_name': json.dumps(JOB_NAME),
104-
'sagemaker_submit_directory':
105-
json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)),
106-
'sagemaker_region': '"us-west-2"'
107-
},
108-
'stop_condition': {
109-
'MaxRuntimeInSeconds': 24 * 60 * 60
110-
}}
91+
}
92+
}],
93+
'role': ROLE,
94+
'job_name': JOB_NAME,
95+
'output_config': {
96+
'S3OutputPath': 's3://{}/'.format(BUCKET_NAME),
97+
},
98+
'resource_config': {
99+
'InstanceType': 'ml.c4.4xlarge',
100+
'InstanceCount': 1,
101+
'VolumeSizeInGB': 30,
102+
},
103+
'hyperparameters': {
104+
'sagemaker_program': json.dumps('dummy_script.py'),
105+
'sagemaker_enable_cloudwatch_metrics': 'false',
106+
'sagemaker_container_log_level': str(logging.INFO),
107+
'sagemaker_job_name': json.dumps(JOB_NAME),
108+
'sagemaker_submit_directory':
109+
json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)),
110+
'sagemaker_region': '"us-west-2"'
111+
},
112+
'stop_condition': {
113+
'MaxRuntimeInSeconds': 24 * 60 * 60
114+
},
115+
'tags': None
116+
}
111117

112118

113119
def test_create_model(sagemaker_session, pytorch_version):

0 commit comments

Comments
 (0)