Skip to content

Commit 26c7231

Browse files
committed
reformat
1 parent ea36e5e commit 26c7231

File tree

442 files changed

+14354
-4872
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

442 files changed

+14354
-4872
lines changed

ci-scripts/queue_build.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,18 @@ def _wait_for_other_builds(ticket_number):
5151
file_ticket_number, build_id, source_version = _build_info_from_file(file)
5252
print(
5353
"%s -> %s %s, ticket number: %s status: %s"
54-
% (order, build_id, source_version, file_ticket_number, file.key.split("/")[1])
54+
% (
55+
order,
56+
build_id,
57+
source_version,
58+
file_ticket_number,
59+
file.key.split("/")[1],
60+
)
5561
)
5662
print()
57-
build_id = re.sub("[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID"))
63+
build_id = re.sub(
64+
"[_/]", "-", os.environ.get("CODEBUILD_BUILD_ID", "CODEBUILD-BUILD-ID")
65+
)
5866
source_version = re.sub(
5967
"[_/]",
6068
"-",
@@ -68,7 +76,9 @@ def _wait_for_other_builds(ticket_number):
6876
_cleanup_tickets_with_terminal_states()
6977
waiting_tickets = _list_tickets("waiting")
7078
if waiting_tickets:
71-
first_waiting_ticket_number, _, _ = _build_info_from_file(_list_tickets("waiting")[0])
79+
first_waiting_ticket_number, _, _ = _build_info_from_file(
80+
_list_tickets("waiting")[0]
81+
)
7282
else:
7383
first_waiting_ticket_number = ticket_number
7484

@@ -91,7 +101,9 @@ def last_in_progress_elapsed_time_check():
91101
in_progress_tickets = _list_tickets("in-progress")
92102
if not in_progress_tickets:
93103
return True
94-
last_in_progress_ticket, _, _ = _build_info_from_file(_list_tickets("in-progress")[-1])
104+
last_in_progress_ticket, _, _ = _build_info_from_file(
105+
_list_tickets("in-progress")[-1]
106+
)
95107
_elapsed_time = int(1000 * time.time()) - last_in_progress_ticket
96108
last_in_progress_elapsed_time = int(_elapsed_time / (1000 * 60)) # in minutes
97109
return last_in_progress_elapsed_time > INTERVAL_BETWEEN_CONCURRENT_RUNS

doc/conf.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,13 @@
8585
"https://cdn.datatables.net/1.10.23/css/jquery.dataTables.min.css",
8686
]
8787

88-
html_context = {"css_files": ["_static/theme_overrides.css", "_static/pagination.css", "_static/search_accessories.css"]}
88+
html_context = {
89+
"css_files": [
90+
"_static/theme_overrides.css",
91+
"_static/pagination.css",
92+
"_static/search_accessories.css",
93+
]
94+
}
8995

9096
# Example configuration for intersphinx: refer to the Python standard library.
9197
intersphinx_mapping = {"http://docs.python.org/": None}

doc/doc_utils/jumpstart_doc_utils.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ class ProblemTypes(str, Enum):
6060

6161
JUMPSTART_REGION = "eu-west-2"
6262
SDK_MANIFEST_FILE = "models_manifest.json"
63-
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
64-
JUMPSTART_REGION, JUMPSTART_REGION
63+
JUMPSTART_BUCKET_BASE_URL = (
64+
"https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
65+
JUMPSTART_REGION, JUMPSTART_REGION
66+
)
6567
)
6668
TASK_MAP = {
6769
Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
@@ -187,7 +189,9 @@ def create_jumpstart_model_table():
187189
file_content.append(" - {}\n".format(model["min_version"]))
188190
file_content.append(" - {}\n".format(model_task))
189191
file_content.append(
190-
" - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"])
192+
" - `{} <{}>`__ |external-link|\n".format(
193+
model_source, model_spec["url"]
194+
)
191195
)
192196

193197
f = open("doc_utils/pretrainedmodels.rst", "w")

setup.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,19 @@ def read_requirements(filename):
6969
# Meta dependency groups
7070
extras["all"] = [item for group in extras.values() for item in group]
7171
# Tests specific dependencies (do not need to be included in 'all')
72-
extras["test"] = (extras["all"] + read_requirements("requirements/extras/test_requirements.txt"),)
72+
extras["test"] = (
73+
extras["all"] + read_requirements("requirements/extras/test_requirements.txt"),
74+
)
7375

7476
setup(
7577
name="sagemaker",
7678
version=read_version(),
7779
description="Open source library for training and deploying models on Amazon SageMaker.",
7880
packages=find_packages("src"),
7981
package_dir={"": "src"},
80-
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")],
82+
py_modules=[
83+
os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")
84+
],
8185
include_package_data=True,
8286
long_description=read("README.rst"),
8387
author="Amazon Web Services",

src/sagemaker/__init__.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
FactorizationMachines,
2929
FactorizationMachinesModel,
3030
)
31-
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor # noqa: F401
31+
from sagemaker.amazon.factorization_machines import (
32+
FactorizationMachinesPredictor,
33+
) # noqa: F401
3234
from sagemaker.inputs import TrainingInput # noqa: F401
3335
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor # noqa: F401
3436
from sagemaker.amazon.randomcutforest import ( # noqa: F401
@@ -45,11 +47,18 @@
4547
)
4648

4749
from sagemaker.algorithm import AlgorithmEstimator # noqa: F401
48-
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401
50+
from sagemaker.analytics import (
51+
TrainingJobAnalytics,
52+
HyperparameterTuningJobAnalytics,
53+
) # noqa: F401
4954
from sagemaker.local.local_session import LocalSession # noqa: F401
5055

5156
from sagemaker.model import Model, ModelPackage # noqa: F401
52-
from sagemaker.model_metrics import ModelMetrics, MetricsSource, FileSource # noqa: F401
57+
from sagemaker.model_metrics import (
58+
ModelMetrics,
59+
MetricsSource,
60+
FileSource,
61+
) # noqa: F401
5362
from sagemaker.pipeline import PipelineModel # noqa: F401
5463
from sagemaker.predictor import Predictor # noqa: F401
5564
from sagemaker.processing import Processor, ScriptProcessor # noqa: F401
@@ -60,6 +69,9 @@
6069
from sagemaker.session import get_execution_role # noqa: F401
6170

6271
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401
63-
from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401
72+
from sagemaker.automl.candidate_estimator import (
73+
CandidateEstimator,
74+
CandidateStep,
75+
) # noqa: F401
6476

6577
__version__ = importlib_metadata.version("sagemaker")

src/sagemaker/algorithm.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,10 @@ def __init__(
168168
max_wait=max_wait,
169169
)
170170

171-
self.algorithm_spec = self.sagemaker_session.sagemaker_client.describe_algorithm(
172-
AlgorithmName=algorithm_arn
171+
self.algorithm_spec = (
172+
self.sagemaker_session.sagemaker_client.describe_algorithm(
173+
AlgorithmName=algorithm_arn
174+
)
173175
)
174176
self.validate_train_spec()
175177
self.hyperparameter_definitions = self._parse_hyperparameters()
@@ -185,7 +187,9 @@ def validate_train_spec(self):
185187

186188
# Check that the input mode provided is compatible with the training input modes for the
187189
# algorithm.
188-
input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"])
190+
input_modes = self._algorithm_training_input_modes(
191+
train_spec["TrainingChannels"]
192+
)
189193
if self.input_mode not in input_modes:
190194
raise ValueError(
191195
"Invalid input mode: %s. %s only supports: %s"
@@ -233,7 +237,9 @@ def training_image_uri(self):
233237
The fit() method, that does the model training, calls this method to
234238
find the image to use for model training.
235239
"""
236-
raise RuntimeError("training_image_uri is never meant to be called on Algorithm Estimators")
240+
raise RuntimeError(
241+
"training_image_uri is never meant to be called on Algorithm Estimators"
242+
)
237243

238244
def enable_network_isolation(self):
239245
"""Return True if this Estimator will need network isolation to run.
@@ -377,7 +383,9 @@ def transformer(
377383

378384
tags = tags or self.tags
379385
else:
380-
raise RuntimeError("No finished training job found associated with this estimator")
386+
raise RuntimeError(
387+
"No finished training job found associated with this estimator"
388+
)
381389

382390
return Transformer(
383391
model_name,
@@ -431,21 +439,29 @@ def _validate_input_channels(self, channels):
431439
for c in channels:
432440
if c not in training_channels:
433441
raise ValueError(
434-
"Unknown input channel: %s is not supported by: %s" % (c, algorithm_name)
442+
"Unknown input channel: %s is not supported by: %s"
443+
% (c, algorithm_name)
435444
)
436445

437446
# check for required channels that were not provided
438447
for name, channel in training_channels.items():
439-
if name not in channels and "IsRequired" in channel and channel["IsRequired"]:
440-
raise ValueError("Required input channel: %s Was not provided." % (name))
448+
if (
449+
name not in channels
450+
and "IsRequired" in channel
451+
and channel["IsRequired"]
452+
):
453+
raise ValueError(
454+
"Required input channel: %s Was not provided." % (name)
455+
)
441456

442457
def _validate_and_cast_hyperparameter(self, name, v):
443458
"""Placeholder docstring"""
444459
algorithm_name = self.algorithm_spec["AlgorithmName"]
445460

446461
if name not in self.hyperparameter_definitions:
447462
raise ValueError(
448-
"Invalid hyperparameter: %s is not supported by %s" % (name, algorithm_name)
463+
"Invalid hyperparameter: %s is not supported by %s"
464+
% (name, algorithm_name)
449465
)
450466

451467
definition = self.hyperparameter_definitions[name]
@@ -456,7 +472,9 @@ def _validate_and_cast_hyperparameter(self, name, v):
456472

457473
if "range" in definition and not definition["range"].is_valid(value):
458474
valid_range = definition["range"].as_tuning_range(name)
459-
raise ValueError("Invalid value: %s Supported range: %s" % (value, valid_range))
475+
raise ValueError(
476+
"Invalid value: %s Supported range: %s" % (value, valid_range)
477+
)
460478
return value
461479

462480
def _validate_and_set_default_hyperparameters(self):
@@ -544,7 +562,9 @@ def _algorithm_training_input_modes(self, training_channels):
544562
return current_input_modes
545563

546564
@classmethod
547-
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
565+
def _prepare_init_params_from_job_description(
566+
cls, job_details, model_channel_name=None
567+
):
548568
"""Convert the job description to init params that can be handled by the class constructor.
549569
550570
Args:
@@ -556,9 +576,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
556576
Returns:
557577
dict: The transformed init_params
558578
"""
559-
init_params = super(AlgorithmEstimator, cls)._prepare_init_params_from_job_description(
560-
job_details, model_channel_name
561-
)
579+
init_params = super(
580+
AlgorithmEstimator, cls
581+
)._prepare_init_params_from_job_description(job_details, model_channel_name)
562582

563583
# This hyperparameter is added by Amazon SageMaker Automatic Model Tuning.
564584
# It cannot be set through instantiating an estimator.

src/sagemaker/amazon/amazon_estimator.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,18 @@ def data_location(self, data_location):
117117
"""Placeholder docstring"""
118118
if not data_location.startswith("s3://"):
119119
raise ValueError(
120-
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
120+
'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(
121+
data_location
122+
)
121123
)
122124
if data_location[-1] != "/":
123125
data_location = data_location + "/"
124126
self._data_location = data_location
125127

126128
@classmethod
127-
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
129+
def _prepare_init_params_from_job_description(
130+
cls, job_details, model_channel_name=None
131+
):
128132
"""Convert the job description to init params that can be handled by the class constructor.
129133
130134
Args:
@@ -152,7 +156,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
152156
del init_params["image_uri"]
153157
return init_params
154158

155-
def prepare_workflow_for_training(self, records=None, mini_batch_size=None, job_name=None):
159+
def prepare_workflow_for_training(
160+
self, records=None, mini_batch_size=None, job_name=None
161+
):
156162
"""Calls _prepare_for_training. Used when setting up a workflow.
157163
158164
Args:
@@ -178,7 +184,9 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
178184
specified, one is generated, using the base name given to the
179185
constructor if applicable.
180186
"""
181-
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
187+
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(
188+
job_name=job_name
189+
)
182190

183191
feature_dim = None
184192

@@ -244,7 +252,9 @@ def fit(
244252
will be unassociated.
245253
* `TrialComponentDisplayName` is used for display in Studio.
246254
"""
247-
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
255+
self._prepare_for_training(
256+
records, job_name=job_name, mini_batch_size=mini_batch_size
257+
)
248258

249259
self.latest_training_job = _TrainingJob.start_new(
250260
self, records, experiment_config=experiment_config
@@ -287,7 +297,9 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
287297
)
288298
parsed_s3_url = urlparse(self.data_location)
289299
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
290-
key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp())
300+
key_prefix = key_prefix + "{}-{}/".format(
301+
type(self).__name__, sagemaker_timestamp()
302+
)
291303
key_prefix = key_prefix.lstrip("/")
292304
logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
293305
manifest_s3_file = upload_numpy_to_s3_shards(
@@ -404,7 +416,10 @@ def _build_shards(num_shards, array):
404416
shard_size = int(array.shape[0] / num_shards)
405417
if shard_size == 0:
406418
raise ValueError("Array length is less than num shards")
407-
shards = [array[i * shard_size : i * shard_size + shard_size] for i in range(num_shards - 1)]
419+
shards = [
420+
array[i * shard_size : i * shard_size + shard_size]
421+
for i in range(num_shards - 1)
422+
]
408423
shards.append(array[(num_shards - 1) * shard_size :])
409424
return shards
410425

@@ -451,7 +466,9 @@ def upload_numpy_to_s3_shards(
451466
manifest_str = json.dumps(
452467
[{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files
453468
)
454-
s3.Object(bucket, manifest_key).put(Body=manifest_str.encode("utf-8"), **extra_put_kwargs)
469+
s3.Object(bucket, manifest_key).put(
470+
Body=manifest_str.encode("utf-8"), **extra_put_kwargs
471+
)
455472
return "s3://{}/{}".format(bucket, manifest_key)
456473
except Exception as ex: # pylint: disable=broad-except
457474
try:

src/sagemaker/amazon/common.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def serialize(self, data):
5353

5454
if len(data.shape) != 2:
5555
raise ValueError(
56-
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
56+
"Expected a 1D or 2D array, but got a %dD array instead."
57+
% len(data.shape)
5758
)
5859

5960
buffer = io.BytesIO()
@@ -290,5 +291,7 @@ def _resolve_type(dtype):
290291
raise ValueError("Unsupported dtype {} on array".format(dtype))
291292

292293

293-
numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer")
294+
numpy_to_record_serializer = deprecated_class(
295+
RecordSerializer, "numpy_to_record_serializer"
296+
)
294297
record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer")

src/sagemaker/amazon/factorization_machines.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def __init__(
222222
:class:`~sagemaker.estimator.amazon_estimator.AmazonAlgorithmEstimatorBase` and
223223
:class:`~sagemaker.estimator.EstimatorBase`.
224224
"""
225-
super(FactorizationMachines, self).__init__(role, instance_count, instance_type, **kwargs)
225+
super(FactorizationMachines, self).__init__(
226+
role, instance_count, instance_type, **kwargs
227+
)
226228

227229
self.num_factors = num_factors
228230
self.predictor_type = predictor_type
@@ -353,7 +355,9 @@ def __init__(
353355
sagemaker_session.boto_region_name,
354356
version=FactorizationMachines.repo_version,
355357
)
356-
pop_out_unused_kwarg("predictor_cls", kwargs, FactorizationMachinesPredictor.__name__)
358+
pop_out_unused_kwarg(
359+
"predictor_cls", kwargs, FactorizationMachinesPredictor.__name__
360+
)
357361
pop_out_unused_kwarg("image_uri", kwargs, image_uri)
358362
super(FactorizationMachinesModel, self).__init__(
359363
image_uri,

0 commit comments

Comments
 (0)