Skip to content

Commit 8b96b2f

Browse files
committed
change: enable no-else-return and no-else-raise pylint checks
1 parent 1705b13 commit 8b96b2f

16 files changed

+103
-131
lines changed

.pylintrc

-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,8 @@ disable=
8989
abstract-method, # TODO: Fix abstract methods
9090
unidiomatic-typecheck, # TODO: Fix typechecks
9191
wrong-import-order, # TODO: Fix import order
92-
no-else-return, # TODO: Remove unnecessary elses
9392
useless-object-inheritance, # TODO: Remove unnecessary imports
9493
cyclic-import, # TODO: Resolve cyclic imports
95-
no-else-raise, # TODO: Remove unnecessary elses
9694
no-self-use, # TODO: Convert methods to functions where appropriate
9795
inconsistent-return-statements, # TODO: Make returns consistent
9896
consider-merging-isinstance, # TODO: Merge isinstance where appropriate

src/sagemaker/amazon/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def read_recordio(f):
204204
def _resolve_type(dtype):
205205
if dtype == np.dtype(int):
206206
return "Int32"
207-
elif dtype == np.dtype(float):
207+
if dtype == np.dtype(float):
208208
return "Float64"
209-
elif dtype == np.dtype("float32"):
209+
if dtype == np.dtype("float32"):
210210
return "Float32"
211211
raise ValueError("Unsupported dtype {} on array".format(dtype))

src/sagemaker/estimator.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,7 @@ def get_vpc_config(self, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
643643
"""
644644
if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
645645
return vpc_utils.to_dict(self.subnets, self.security_group_ids)
646-
else:
647-
return vpc_utils.sanitize(vpc_config_override)
646+
return vpc_utils.sanitize(vpc_config_override)
648647

649648
def _ensure_latest_training_job(
650649
self, error_message="Estimator is not associated with a training job"
@@ -1235,14 +1234,13 @@ def train_image(self):
12351234
"""
12361235
if self.image_name:
12371236
return self.image_name
1238-
else:
1239-
return create_image_uri(
1240-
self.sagemaker_session.boto_region_name,
1241-
self.__framework_name__,
1242-
self.train_instance_type,
1243-
self.framework_version, # pylint: disable=no-member
1244-
py_version=self.py_version, # pylint: disable=no-member
1245-
)
1237+
return create_image_uri(
1238+
self.sagemaker_session.boto_region_name,
1239+
self.__framework_name__,
1240+
self.train_instance_type,
1241+
self.framework_version, # pylint: disable=no-member
1242+
py_version=self.py_version, # pylint: disable=no-member
1243+
)
12461244

12471245
@classmethod
12481246
def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="model"):
@@ -1404,13 +1402,12 @@ def _s3_uri_without_prefix_from_input(input_data):
14041402
for channel_name, channel_s3_uri in input_data.items():
14051403
response.update(_s3_uri_prefix(channel_name, channel_s3_uri))
14061404
return response
1407-
elif isinstance(input_data, str):
1405+
if isinstance(input_data, str):
14081406
return _s3_uri_prefix("training", input_data)
1409-
elif isinstance(input_data, s3_input):
1407+
if isinstance(input_data, s3_input):
14101408
return _s3_uri_prefix("training", input_data)
1411-
else:
1412-
raise ValueError(
1413-
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(
1414-
input_data
1415-
)
1409+
raise ValueError(
1410+
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(
1411+
input_data
14161412
)
1413+
)

src/sagemaker/fw_utils.py

+26-31
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ def _is_merged_versions(framework, framework_version):
8787
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
8888
if lowest_version_list:
8989
return is_version_equal_or_higher(lowest_version_list, framework_version)
90-
else:
91-
return False
90+
return False
9291

9392

9493
def _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
@@ -101,8 +100,7 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
101100
def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
102101
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
103102
return "763104351884"
104-
else:
105-
return VALID_ACCOUNTS_BY_REGION.get(region, account)
103+
return VALID_ACCOUNTS_BY_REGION.get(region, account)
106104

107105

108106
def create_image_uri(
@@ -182,10 +180,9 @@ def create_image_uri(
182180
return "{}/{}:{}".format(
183181
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
184182
)
185-
else:
186-
return "{}/sagemaker-{}:{}".format(
187-
get_ecr_image_uri_prefix(account, region), framework, tag
188-
)
183+
return "{}/sagemaker-{}:{}".format(
184+
get_ecr_image_uri_prefix(account, region), framework, tag
185+
)
189186

190187

191188
def _accelerator_type_valid_for_framework(
@@ -324,30 +321,28 @@ def framework_name_from_image(image_name):
324321
sagemaker_match = sagemaker_pattern.match(image_name)
325322
if sagemaker_match is None:
326323
return None, None, None, None
327-
else:
328-
# extract framework, python version and image tag
329-
# We must support both the legacy and current image name format.
330-
name_pattern = re.compile(
331-
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
324+
# extract framework, python version and image tag
325+
# We must support both the legacy and current image name format.
326+
name_pattern = re.compile(
327+
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
328+
)
329+
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
330+
331+
name_match = name_pattern.match(sagemaker_match.group(9))
332+
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
333+
334+
if name_match is not None:
335+
fw, scriptmode, ver, device, py = (
336+
name_match.group(1),
337+
name_match.group(2),
338+
name_match.group(3),
339+
name_match.group(4),
340+
name_match.group(5),
332341
)
333-
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
334-
335-
name_match = name_pattern.match(sagemaker_match.group(9))
336-
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
337-
338-
if name_match is not None:
339-
fw, scriptmode, ver, device, py = (
340-
name_match.group(1),
341-
name_match.group(2),
342-
name_match.group(3),
343-
name_match.group(4),
344-
name_match.group(5),
345-
)
346-
return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
347-
elif legacy_match is not None:
348-
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
349-
else:
350-
return None, None, None, None
342+
return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
343+
if legacy_match is not None:
344+
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
345+
return None, None, None, None
351346

352347

353348
def framework_version_from_tag(image_tag):

src/sagemaker/job.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -140,25 +140,24 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
140140
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
141141
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
142142
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
143-
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
143+
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
144144
return file_input(uri_input)
145-
elif isinstance(uri_input, str) and validate_uri:
145+
if isinstance(uri_input, str) and validate_uri:
146146
raise ValueError(
147147
'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
148148
'"file://"'.format(uri_input)
149149
)
150-
elif isinstance(uri_input, str):
150+
if isinstance(uri_input, str):
151151
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
152-
elif isinstance(uri_input, s3_input):
152+
if isinstance(uri_input, s3_input):
153153
return uri_input
154-
elif isinstance(uri_input, file_input):
154+
if isinstance(uri_input, file_input):
155155
return uri_input
156-
else:
157-
raise ValueError(
158-
"Cannot format input {}. Expecting one of str, s3_input, or file_input".format(
159-
uri_input
160-
)
156+
raise ValueError(
157+
"Cannot format input {}. Expecting one of str, s3_input, or file_input".format(
158+
uri_input
161159
)
160+
)
162161

163162
@staticmethod
164163
def _prepare_channel(
@@ -171,7 +170,7 @@ def _prepare_channel(
171170
):
172171
if not channel_uri:
173172
return
174-
elif not channel_name:
173+
if not channel_name:
175174
raise ValueError(
176175
"Expected a channel name if a channel URI {} is specified".format(channel_uri)
177176
)
@@ -197,23 +196,22 @@ def _format_model_uri_input(model_uri, validate_uri=True):
197196
distribution="FullyReplicated",
198197
content_type="application/x-sagemaker-model",
199198
)
200-
elif (
199+
if (
201200
isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://")
202201
):
203202
return file_input(model_uri)
204-
elif isinstance(model_uri, string_types) and validate_uri:
203+
if isinstance(model_uri, string_types) and validate_uri:
205204
raise ValueError(
206205
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
207206
)
208-
elif isinstance(model_uri, string_types):
207+
if isinstance(model_uri, string_types):
209208
return s3_input(
210209
model_uri,
211210
input_mode="File",
212211
distribution="FullyReplicated",
213212
content_type="application/x-sagemaker-model",
214213
)
215-
else:
216-
raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))
214+
raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))
217215

218216
@staticmethod
219217
def _format_record_set_list_input(inputs):

src/sagemaker/local/data.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_data_source_instance(data_source, sagemaker_session):
4545
parsed_uri = urlparse(data_source)
4646
if parsed_uri.scheme == "file":
4747
return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
48-
elif parsed_uri.scheme == "s3":
48+
if parsed_uri.scheme == "s3":
4949
return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
5050

5151

@@ -62,12 +62,11 @@ def get_splitter_instance(split_type):
6262
"""
6363
if split_type is None:
6464
return NoneSplitter()
65-
elif split_type == "Line":
65+
if split_type == "Line":
6666
return LineSplitter()
67-
elif split_type == "RecordIO":
67+
if split_type == "RecordIO":
6868
return RecordIOSplitter()
69-
else:
70-
raise ValueError("Invalid Split Type: %s" % split_type)
69+
raise ValueError("Invalid Split Type: %s" % split_type)
7170

7271

7372
def get_batch_strategy_instance(strategy, splitter):
@@ -82,12 +81,11 @@ def get_batch_strategy_instance(strategy, splitter):
8281
"""
8382
if strategy == "SingleRecord":
8483
return SingleRecordStrategy(splitter)
85-
elif strategy == "MultiRecord":
84+
if strategy == "MultiRecord":
8685
return MultiRecordStrategy(splitter)
87-
else:
88-
raise ValueError(
89-
'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"'
90-
)
86+
raise ValueError(
87+
'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"'
88+
)
9189

9290

9391
class DataSource(with_metaclass(ABCMeta, object)):
@@ -129,8 +127,7 @@ def get_file_list(self):
129127
for f in os.listdir(self.root_path)
130128
if os.path.isfile(os.path.join(self.root_path, f))
131129
]
132-
else:
133-
return [self.root_path]
130+
return [self.root_path]
134131

135132
def get_root_dir(self):
136133
"""Retrieve the absolute path to the root directory of this data source.
@@ -140,8 +137,7 @@ def get_root_dir(self):
140137
"""
141138
if os.path.isdir(self.root_path):
142139
return self.root_path
143-
else:
144-
return os.path.dirname(self.root_path)
140+
return os.path.dirname(self.root_path)
145141

146142

147143
class S3DataSource(DataSource):

src/sagemaker/local/image.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def _aws_credentials(session):
665665
"AWS_ACCESS_KEY_ID=%s" % (str(access_key)),
666666
"AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)),
667667
]
668-
elif not _aws_credentials_available_in_metadata_service():
668+
if not _aws_credentials_available_in_metadata_service():
669669
logger.warning(
670670
"Using the short-lived AWS credentials found in session. They might expire while running."
671671
)
@@ -674,11 +674,10 @@ def _aws_credentials(session):
674674
"AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)),
675675
"AWS_SESSION_TOKEN=%s" % (str(token)),
676676
]
677-
else:
678-
logger.info(
679-
"No AWS credentials found in session but credentials from EC2 Metadata Service are available."
680-
)
681-
return None
677+
logger.info(
678+
"No AWS credentials found in session but credentials from EC2 Metadata Service are available."
679+
)
680+
return None
682681
except Exception as e: # pylint: disable=broad-except
683682
logger.info("Could not get AWS credentials: %s", e)
684683

src/sagemaker/local/local_session.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def describe_training_job(self, TrainingJobName):
107107
}
108108
}
109109
raise ClientError(error_response, "describe_training_job")
110-
else:
111-
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
110+
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
112111

113112
def create_transform_job(
114113
self,
@@ -132,8 +131,7 @@ def describe_transform_job(self, TransformJobName):
132131
}
133132
}
134133
raise ClientError(error_response, "describe_transform_job")
135-
else:
136-
return LocalSagemakerClient._transform_jobs[TransformJobName].describe()
134+
return LocalSagemakerClient._transform_jobs[TransformJobName].describe()
137135

138136
def create_model(
139137
self, ModelName, PrimaryContainer, *args, **kwargs
@@ -152,13 +150,10 @@ def describe_model(self, ModelName):
152150
"Error": {"Code": "ValidationException", "Message": "Could not find local model"}
153151
}
154152
raise ClientError(error_response, "describe_model")
155-
else:
156-
return LocalSagemakerClient._models[ModelName].describe()
153+
return LocalSagemakerClient._models[ModelName].describe()
157154

158155
def describe_endpoint_config(self, EndpointConfigName):
159-
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
160-
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
161-
else:
156+
if EndpointConfigName not in LocalSagemakerClient._endpoint_configs:
162157
error_response = {
163158
"Error": {
164159
"Code": "ValidationException",
@@ -167,6 +162,8 @@ def describe_endpoint_config(self, EndpointConfigName):
167162
}
168163
raise ClientError(error_response, "describe_endpoint_config")
169164

165+
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
166+
170167
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
171168
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
172169
EndpointConfigName, ProductionVariants, Tags
@@ -178,8 +175,7 @@ def describe_endpoint(self, EndpointName):
178175
"Error": {"Code": "ValidationException", "Message": "Could not find local endpoint"}
179176
}
180177
raise ClientError(error_response, "describe_endpoint")
181-
else:
182-
return LocalSagemakerClient._endpoints[EndpointName].describe()
178+
return LocalSagemakerClient._endpoints[EndpointName].describe()
183179

184180
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
185181
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session)

src/sagemaker/model.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ def check_neo_region(self, region):
205205
"""
206206
if region in NEO_IMAGE_ACCOUNT:
207207
return True
208-
else:
209-
return False
208+
return False
210209

211210
def _neo_image_account(self, region):
212211
if region not in NEO_IMAGE_ACCOUNT:

0 commit comments

Comments
 (0)