Skip to content

change: enable no-else-return and no-else-raise pylint checks #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ disable=
protected-access, # TODO: Fix access
abstract-method, # TODO: Fix abstract methods
wrong-import-order, # TODO: Fix import order
no-else-return, # TODO: Remove unnecessary elses
useless-object-inheritance, # TODO: Remove unnecessary imports
cyclic-import, # TODO: Resolve cyclic imports
no-else-raise, # TODO: Remove unnecessary elses
no-self-use, # TODO: Convert methods to functions where appropriate
inconsistent-return-statements, # TODO: Make returns consistent
consider-merging-isinstance, # TODO: Merge isinstance where appropriate
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/amazon/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def read_recordio(f):
def _resolve_type(dtype):
if dtype == np.dtype(int):
return "Int32"
elif dtype == np.dtype(float):
if dtype == np.dtype(float):
return "Float64"
elif dtype == np.dtype("float32"):
if dtype == np.dtype("float32"):
return "Float32"
raise ValueError("Unsupported dtype {} on array".format(dtype))
31 changes: 13 additions & 18 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,7 @@ def get_vpc_config(self, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
"""
if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
return vpc_utils.to_dict(self.subnets, self.security_group_ids)
else:
return vpc_utils.sanitize(vpc_config_override)
return vpc_utils.sanitize(vpc_config_override)

def _ensure_latest_training_job(
self, error_message="Estimator is not associated with a training job"
Expand Down Expand Up @@ -1235,14 +1234,13 @@ def train_image(self):
"""
if self.image_name:
return self.image_name
else:
return create_image_uri(
self.sagemaker_session.boto_region_name,
self.__framework_name__,
self.train_instance_type,
self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
)
return create_image_uri(
self.sagemaker_session.boto_region_name,
self.__framework_name__,
self.train_instance_type,
self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
)

@classmethod
def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="model"):
Expand Down Expand Up @@ -1404,13 +1402,10 @@ def _s3_uri_without_prefix_from_input(input_data):
for channel_name, channel_s3_uri in input_data.items():
response.update(_s3_uri_prefix(channel_name, channel_s3_uri))
return response
elif isinstance(input_data, str):
if isinstance(input_data, str):
return _s3_uri_prefix("training", input_data)
elif isinstance(input_data, s3_input):
if isinstance(input_data, s3_input):
return _s3_uri_prefix("training", input_data)
else:
raise ValueError(
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(
input_data
)
)
raise ValueError(
"Unrecognized type for S3 input data config - not str or s3_input: {}".format(input_data)
)
55 changes: 24 additions & 31 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def _is_merged_versions(framework, framework_version):
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework)
if lowest_version_list:
return is_version_equal_or_higher(lowest_version_list, framework_version)
else:
return False
return False


def _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
Expand All @@ -101,8 +100,7 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
def _registry_id(region, framework, py_version, account, accelerator_type, framework_version):
if _using_merged_images(region, framework, py_version, accelerator_type, framework_version):
return "763104351884"
else:
return VALID_ACCOUNTS_BY_REGION.get(region, account)
return VALID_ACCOUNTS_BY_REGION.get(region, account)


def create_image_uri(
Expand Down Expand Up @@ -182,10 +180,7 @@ def create_image_uri(
return "{}/{}:{}".format(
get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag
)
else:
return "{}/sagemaker-{}:{}".format(
get_ecr_image_uri_prefix(account, region), framework, tag
)
return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag)


def _accelerator_type_valid_for_framework(
Expand Down Expand Up @@ -324,30 +319,28 @@ def framework_name_from_image(image_name):
sagemaker_match = sagemaker_pattern.match(image_name)
if sagemaker_match is None:
return None, None, None, None
else:
# extract framework, python version and image tag
# We must support both the legacy and current image name format.
name_pattern = re.compile(
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
# extract framework, python version and image tag
# We must support both the legacy and current image name format.
name_pattern = re.compile(
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501
)
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")

name_match = name_pattern.match(sagemaker_match.group(9))
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))

if name_match is not None:
fw, scriptmode, ver, device, py = (
name_match.group(1),
name_match.group(2),
name_match.group(3),
name_match.group(4),
name_match.group(5),
)
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")

name_match = name_pattern.match(sagemaker_match.group(9))
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))

if name_match is not None:
fw, scriptmode, ver, device, py = (
name_match.group(1),
name_match.group(2),
name_match.group(3),
name_match.group(4),
name_match.group(5),
)
return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
elif legacy_match is not None:
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
else:
return None, None, None, None
return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
if legacy_match is not None:
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
return None, None, None, None


def framework_version_from_tag(image_tag):
Expand Down
32 changes: 14 additions & 18 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,25 +140,24 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None):
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
return file_input(uri_input)
elif isinstance(uri_input, str) and validate_uri:
if isinstance(uri_input, str) and validate_uri:
raise ValueError(
'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"'.format(uri_input)
)
elif isinstance(uri_input, str):
if isinstance(uri_input, str):
return s3_input(uri_input, content_type=content_type, input_mode=input_mode)
elif isinstance(uri_input, s3_input):
if isinstance(uri_input, s3_input):
return uri_input
elif isinstance(uri_input, file_input):
if isinstance(uri_input, file_input):
return uri_input
else:
raise ValueError(
"Cannot format input {}. Expecting one of str, s3_input, or file_input".format(
uri_input
)
raise ValueError(
"Cannot format input {}. Expecting one of str, s3_input, or file_input".format(
uri_input
)
)

@staticmethod
def _prepare_channel(
Expand All @@ -171,7 +170,7 @@ def _prepare_channel(
):
if not channel_uri:
return
elif not channel_name:
if not channel_name:
raise ValueError(
"Expected a channel name if a channel URI {} is specified".format(channel_uri)
)
Expand All @@ -197,23 +196,20 @@ def _format_model_uri_input(model_uri, validate_uri=True):
distribution="FullyReplicated",
content_type="application/x-sagemaker-model",
)
elif (
isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://")
):
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://"):
return file_input(model_uri)
elif isinstance(model_uri, string_types) and validate_uri:
if isinstance(model_uri, string_types) and validate_uri:
raise ValueError(
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
)
elif isinstance(model_uri, string_types):
if isinstance(model_uri, string_types):
return s3_input(
model_uri,
input_mode="File",
distribution="FullyReplicated",
content_type="application/x-sagemaker-model",
)
else:
raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))
raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))

@staticmethod
def _format_record_set_list_input(inputs):
Expand Down
22 changes: 8 additions & 14 deletions src/sagemaker/local/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_data_source_instance(data_source, sagemaker_session):
parsed_uri = urlparse(data_source)
if parsed_uri.scheme == "file":
return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
elif parsed_uri.scheme == "s3":
if parsed_uri.scheme == "s3":
return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)


Expand All @@ -62,12 +62,11 @@ def get_splitter_instance(split_type):
"""
if split_type is None:
return NoneSplitter()
elif split_type == "Line":
if split_type == "Line":
return LineSplitter()
elif split_type == "RecordIO":
if split_type == "RecordIO":
return RecordIOSplitter()
else:
raise ValueError("Invalid Split Type: %s" % split_type)
raise ValueError("Invalid Split Type: %s" % split_type)


def get_batch_strategy_instance(strategy, splitter):
Expand All @@ -82,12 +81,9 @@ def get_batch_strategy_instance(strategy, splitter):
"""
if strategy == "SingleRecord":
return SingleRecordStrategy(splitter)
elif strategy == "MultiRecord":
if strategy == "MultiRecord":
return MultiRecordStrategy(splitter)
else:
raise ValueError(
'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"'
)
raise ValueError('Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"')


class DataSource(with_metaclass(ABCMeta, object)):
Expand Down Expand Up @@ -129,8 +125,7 @@ def get_file_list(self):
for f in os.listdir(self.root_path)
if os.path.isfile(os.path.join(self.root_path, f))
]
else:
return [self.root_path]
return [self.root_path]

def get_root_dir(self):
"""Retrieve the absolute path to the root directory of this data source.
Expand All @@ -140,8 +135,7 @@ def get_root_dir(self):
"""
if os.path.isdir(self.root_path):
return self.root_path
else:
return os.path.dirname(self.root_path)
return os.path.dirname(self.root_path)


class S3DataSource(DataSource):
Expand Down
11 changes: 5 additions & 6 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _aws_credentials(session):
"AWS_ACCESS_KEY_ID=%s" % (str(access_key)),
"AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)),
]
elif not _aws_credentials_available_in_metadata_service():
if not _aws_credentials_available_in_metadata_service():
logger.warning(
"Using the short-lived AWS credentials found in session. They might expire while running."
)
Expand All @@ -674,11 +674,10 @@ def _aws_credentials(session):
"AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)),
"AWS_SESSION_TOKEN=%s" % (str(token)),
]
else:
logger.info(
"No AWS credentials found in session but credentials from EC2 Metadata Service are available."
)
return None
logger.info(
"No AWS credentials found in session but credentials from EC2 Metadata Service are available."
)
return None
except Exception as e: # pylint: disable=broad-except
logger.info("Could not get AWS credentials: %s", e)

Expand Down
17 changes: 6 additions & 11 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def describe_training_job(self, TrainingJobName):
}
}
raise ClientError(error_response, "describe_training_job")
else:
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()

def create_transform_job(
self,
Expand All @@ -132,8 +131,7 @@ def describe_transform_job(self, TransformJobName):
}
}
raise ClientError(error_response, "describe_transform_job")
else:
return LocalSagemakerClient._transform_jobs[TransformJobName].describe()
return LocalSagemakerClient._transform_jobs[TransformJobName].describe()

def create_model(
self, ModelName, PrimaryContainer, *args, **kwargs
Expand All @@ -152,20 +150,18 @@ def describe_model(self, ModelName):
"Error": {"Code": "ValidationException", "Message": "Could not find local model"}
}
raise ClientError(error_response, "describe_model")
else:
return LocalSagemakerClient._models[ModelName].describe()
return LocalSagemakerClient._models[ModelName].describe()

def describe_endpoint_config(self, EndpointConfigName):
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
else:
if EndpointConfigName not in LocalSagemakerClient._endpoint_configs:
error_response = {
"Error": {
"Code": "ValidationException",
"Message": "Could not find local endpoint config",
}
}
raise ClientError(error_response, "describe_endpoint_config")
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()

def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
Expand All @@ -178,8 +174,7 @@ def describe_endpoint(self, EndpointName):
"Error": {"Code": "ValidationException", "Message": "Could not find local endpoint"}
}
raise ClientError(error_response, "describe_endpoint")
else:
return LocalSagemakerClient._endpoints[EndpointName].describe()
return LocalSagemakerClient._endpoints[EndpointName].describe()

def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session)
Expand Down
3 changes: 1 addition & 2 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ def check_neo_region(self, region):
"""
if region in NEO_IMAGE_ACCOUNT:
return True
else:
return False
return False

def _neo_image_account(self, region):
if region not in NEO_IMAGE_ACCOUNT:
Expand Down
Loading