Skip to content

Commit 8670d9d

Browse files
committed
Fix flake8 'too complex' error
1 parent 48c4e02 commit 8670d9d

File tree

3 files changed

+21
-14
lines changed

3 files changed

+21
-14
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read(fname):
4949

5050
extras_require={
5151
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',
52-
'mock', 'tensorflow>=1.3.0', 'contextlib2', 'awslogs']},
52+
'mock', 'tensorflow>=1.3.0', 'contextlib2', 'awslogs', 'pandas']},
5353

5454
entry_points={
5555
'console_scripts': ['sagemaker=sagemaker.cli.main:main'],

src/sagemaker/estimator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4444

4545
def __init__(self, role, train_instance_count, train_instance_type,
4646
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
47-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
48-
metric_definitions=None):
47+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
4948
"""Initialize an ``EstimatorBase`` instance.
5049
5150
Args:
@@ -74,7 +73,6 @@ def __init__(self, role, train_instance_count, train_instance_type,
7473
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
7574
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
7675
using the default AWS configuration chain.
77-
metric_definitions (list[dict]): Metrics definition with 'name' and 'regex' keys.
7876
"""
7977
self.role = role
8078
self.train_instance_count = train_instance_count
@@ -95,7 +93,6 @@ def __init__(self, role, train_instance_count, train_instance_type,
9593
self.output_path = output_path
9694
self.output_kms_key = output_kms_key
9795
self.latest_training_job = None
98-
self.metric_definitions = metric_definitions
9996

10097
@abstractmethod
10198
def train_image(self):

src/sagemaker/job.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _load_config(inputs, estimator):
6868

6969
@staticmethod
7070
def _format_inputs_to_input_config(inputs):
71-
# Circular dependency
71+
# Deferred import due to circular dependency
7272
from sagemaker.amazon.amazon_estimator import RecordSet
7373
if isinstance(inputs, RecordSet):
7474
inputs = inputs.data_channel()
@@ -84,14 +84,7 @@ def _format_inputs_to_input_config(inputs):
8484
for k, v in inputs.items():
8585
input_dict[k] = _Job._format_string_uri_input(v)
8686
elif isinstance(inputs, list):
87-
for record in inputs:
88-
if not isinstance(record, RecordSet):
89-
raise ValueError('List compatible only with RecordSets.')
90-
91-
if record.channel in input_dict:
92-
raise ValueError('Duplicate channels not allowed.')
93-
94-
input_dict[record.channel] = record.records_s3_input()
87+
input_dict = _Job._format_record_set_list_input(inputs)
9588
else:
9689
raise ValueError(
9790
'Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
@@ -123,6 +116,23 @@ def _format_string_uri_input(input):
123116
'Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(
124117
input))
125118

119+
@staticmethod
120+
def _format_record_set_list_input(inputs):
121+
# Deferred import due to circular dependency
122+
from sagemaker.amazon.amazon_estimator import RecordSet
123+
124+
input_dict = {}
125+
for record in inputs:
126+
if not isinstance(record, RecordSet):
127+
raise ValueError('List compatible only with RecordSets.')
128+
129+
if record.channel in input_dict:
130+
raise ValueError('Duplicate channels not allowed.')
131+
132+
input_dict[record.channel] = record.records_s3_input()
133+
134+
return input_dict
135+
126136
@staticmethod
127137
def _prepare_output_config(s3_path, kms_key_id):
128138
config = {'S3OutputPath': s3_path}

0 commit comments

Comments
 (0)