diff --git a/buildspec-release.yml b/buildspec-release.yml index 1c32b51cdf..e8acf68bec 100644 --- a/buildspec-release.yml +++ b/buildspec-release.yml @@ -9,6 +9,9 @@ phases: # run linters - tox -e flake8,pylint + # run format verification + - tox -e black-check + # run package and docbuild checks - tox -e twine - tox -e sphinx diff --git a/buildspec.yml b/buildspec.yml index af7e071a4d..253e0d79bc 100644 --- a/buildspec.yml +++ b/buildspec.yml @@ -14,6 +14,9 @@ phases: - tox -e twine - tox -e sphinx + # run format verification + - tox -e black-check + # run unit tests - AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN= AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION= diff --git a/doc/conf.py b/doc/conf.py index f6538642e4..f9705e0089 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -28,51 +28,65 @@ def __getattr__(cls, name): return MagicMock() -MOCK_MODULES = ['tensorflow', 'tensorflow.core', 'tensorflow.core.framework', 'tensorflow.python', - 'tensorflow.python.framework', 'tensorflow_serving', 'tensorflow_serving.apis', - 'numpy', 'scipy', 'scipy.sparse'] +MOCK_MODULES = [ + "tensorflow", + "tensorflow.core", + "tensorflow.core.framework", + "tensorflow.python", + "tensorflow.python.framework", + "tensorflow_serving", + "tensorflow_serving.apis", + "numpy", + "scipy", + "scipy.sparse", +] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) -project = u'sagemaker' +project = u"sagemaker" version = pkg_resources.require(project)[0].version # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', 'sphinx.ext.todo', - 'sphinx.ext.coverage', 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] -source_suffix = '.rst' # The suffix of source filenames. -master_doc = 'index' # The master toctree document. +source_suffix = ".rst" # The suffix of source filenames. +master_doc = "index" # The master toctree document. -copyright = u'%s, Amazon' % datetime.now().year +copyright = u"%s, Amazon" % datetime.now().year # The full version, including alpha/beta/rc tags. release = version # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] -pygments_style = 'default' +pygments_style = "default" autoclass_content = "both" -autodoc_default_flags = ['show-inheritance', 'members', 'undoc-members'] -autodoc_member_order = 'bysource' +autodoc_default_flags = ["show-inheritance", "members", "undoc-members"] +autodoc_member_order = "bysource" -if 'READTHEDOCS' in os.environ: - html_theme = 'default' +if "READTHEDOCS" in os.environ: + html_theme = "default" else: - html_theme = 'haiku' + html_theme = "haiku" html_static_path = [] -htmlhelp_basename = '%sdoc' % project +htmlhelp_basename = "%sdoc" % project # Example configuration for intersphinx: refer to the Python standard library. 
-intersphinx_mapping = {'http://docs.python.org/': None} +intersphinx_mapping = {"http://docs.python.org/": None} # autosummary autosummary_generate = True diff --git a/examples/cli/host/script.py b/examples/cli/host/script.py index f3775b4310..506cc86c4e 100644 --- a/examples/cli/host/script.py +++ b/examples/cli/host/script.py @@ -12,12 +12,12 @@ def model_fn(model_dir): :param: model_dir The directory where model files are stored. :return: a model (in this case a Gluon network) """ - symbol = mx.sym.load('%s/model.json' % model_dir) - outputs = mx.symbol.softmax(data=symbol, name='softmax_label') - inputs = mx.sym.var('data') - param_dict = gluon.ParameterDict('model_') + symbol = mx.sym.load("%s/model.json" % model_dir) + outputs = mx.symbol.softmax(data=symbol, name="softmax_label") + inputs = mx.sym.var("data") + param_dict = gluon.ParameterDict("model_") net = gluon.SymbolBlock(outputs, inputs, param_dict) - net.load_params('%s/model.params' % model_dir, ctx=mx.cpu()) + net.load_params("%s/model.params" % model_dir, ctx=mx.cpu()) return net diff --git a/examples/cli/train/download_training_data.py b/examples/cli/train/download_training_data.py index eb33996904..2bc97d9588 100644 --- a/examples/cli/train/download_training_data.py +++ b/examples/cli/train/download_training_data.py @@ -2,8 +2,8 @@ def download_training_data(): - gluon.data.vision.MNIST('./data/training', train=True) - gluon.data.vision.MNIST('./data/training', train=False) + gluon.data.vision.MNIST("./data/training", train=True) + gluon.data.vision.MNIST("./data/training", train=False) if __name__ == "__main__": diff --git a/examples/cli/train/script.py b/examples/cli/train/script.py index a219548fcc..01d7ff3dbd 100644 --- a/examples/cli/train/script.py +++ b/examples/cli/train/script.py @@ -15,13 +15,13 @@ def train(channel_input_dirs, hyperparameters, **kwargs): ctx = mx.cpu() # retrieve the hyperparameters we set in notebook (with some defaults) - batch_size = hyperparameters.get('batch_size', 100) - epochs = hyperparameters.get('epochs', 10) - learning_rate = hyperparameters.get('learning_rate', 0.1) - momentum = hyperparameters.get('momentum', 0.9) - log_interval = hyperparameters.get('log_interval', 100) + batch_size = hyperparameters.get("batch_size", 100) + epochs = hyperparameters.get("epochs", 10) + learning_rate = hyperparameters.get("learning_rate", 0.1) + momentum = hyperparameters.get("momentum", 0.9) + log_interval = hyperparameters.get("log_interval", 100) - training_data = channel_input_dirs['training'] + training_data = channel_input_dirs["training"] # load training and validation data # we use the gluon.data.vision.MNIST class because of its built in mnist pre-processing logic, @@ -35,8 +35,9 @@ def train(channel_input_dirs, hyperparameters, **kwargs): # Collect all parameters from net and its children, then initialize them. net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) # Trainer is for updating parameters with gradient. 
- trainer = gluon.Trainer(net.collect_params(), 'sgd', - {'learning_rate': learning_rate, 'momentum': momentum}) + trainer = gluon.Trainer( + net.collect_params(), "sgd", {"learning_rate": learning_rate, "momentum": momentum} + ) metric = mx.metric.Accuracy() loss = gluon.loss.SoftmaxCrossEntropyLoss() @@ -61,32 +62,34 @@ def train(channel_input_dirs, hyperparameters, **kwargs): if i % log_interval == 0 and i > 0: name, acc = metric.get() - logger.info('[Epoch %d Batch %d] Training: %s=%f, %f samples/s' % - (epoch, i, name, acc, batch_size / (time.time() - btic))) + logger.info( + "[Epoch %d Batch %d] Training: %s=%f, %f samples/s" + % (epoch, i, name, acc, batch_size / (time.time() - btic)) + ) btic = time.time() name, acc = metric.get() - logger.info('[Epoch %d] Training: %s=%f' % (epoch, name, acc)) + logger.info("[Epoch %d] Training: %s=%f" % (epoch, name, acc)) name, val_acc = test(ctx, net, val_data) - logger.info('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc)) + logger.info("[Epoch %d] Validation: %s=%f" % (epoch, name, val_acc)) return net def save(net, model_dir): # save the model - y = net(mx.sym.var('data')) - y.save('%s/model.json' % model_dir) - net.collect_params().save('%s/model.params' % model_dir) + y = net(mx.sym.var("data")) + y.save("%s/model.json" % model_dir) + net.collect_params().save("%s/model.params" % model_dir) def define_network(): net = nn.Sequential() with net.name_scope(): - net.add(nn.Dense(128, activation='relu')) - net.add(nn.Dense(64, activation='relu')) + net.add(nn.Dense(128, activation="relu")) + net.add(nn.Dense(64, activation="relu")) net.add(nn.Dense(10)) return net @@ -99,13 +102,18 @@ def input_transformer(data, label): def get_train_data(data_dir, batch_size): return gluon.data.DataLoader( gluon.data.vision.MNIST(data_dir, train=True, transform=input_transformer), - batch_size=batch_size, shuffle=True, last_batch='discard') + batch_size=batch_size, + shuffle=True, + last_batch="discard", + ) def get_val_data(data_dir, batch_size): return gluon.data.DataLoader( gluon.data.vision.MNIST(data_dir, train=False, transform=input_transformer), - batch_size=batch_size, shuffle=False) + batch_size=batch_size, + shuffle=False, + ) def test(ctx, net, val_data): diff --git a/setup.py b/setup.py index 688fc0487f..13b5960a57 100644 --- a/setup.py +++ b/setup.py @@ -24,46 +24,62 @@ def read(fname): def read_version(): - return read('VERSION').strip() + return read("VERSION").strip() # Declare minimal set for installation -required_packages = ['boto3>=1.9.169', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', - 'urllib3>=1.21, <1.25', 'protobuf3-to-dict>=0.1.5', 'docker-compose>=1.23.0', - 'requests>=2.20.0, <2.21'] +required_packages = [ + "boto3>=1.9.169", + "numpy>=1.9.0", + "protobuf>=3.1", + "scipy>=0.19.0", + "urllib3>=1.21, <1.25", + "protobuf3-to-dict>=0.1.5", + "docker-compose>=1.23.0", + "requests>=2.20.0, <2.21", +] # enum is introduced in Python 3.4. 
Installing enum back port if sys.version_info < (3, 4): - required_packages.append('enum34>=1.1.6') + required_packages.append("enum34>=1.1.6") -setup(name="sagemaker", - version=read_version(), - description="Open source library for training and deploying models on Amazon SageMaker.", - packages=find_packages('src'), - package_dir={'': 'src'}, - py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob('src/*.py')], - long_description=read('README.rst'), - author="Amazon Web Services", - url='https://github.com/aws/sagemaker-python-sdk/', - license="Apache License 2.0", - keywords="ML Amazon AWS AI Tensorflow MXNet", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Natural Language :: English", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 2.7", - "Programming Language :: Python :: 3.6", - ], - - install_requires=required_packages, - - extras_require={ - 'test': ['tox', 'flake8', 'pytest==4.4.1', 'pytest-cov', 'pytest-rerunfailures', - 'pytest-xdist', 'mock', 'tensorflow>=1.3.0', 'contextlib2', - 'awslogs', 'pandas']}, - - entry_points={ - 'console_scripts': ['sagemaker=sagemaker.cli.main:main'], - }) +setup( + name="sagemaker", + version=read_version(), + description="Open source library for training and deploying models on Amazon SageMaker.", + packages=find_packages("src"), + package_dir={"": "src"}, + py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")], + long_description=read("README.rst"), + author="Amazon Web Services", + url="https://github.com/aws/sagemaker-python-sdk/", + license="Apache License 2.0", + keywords="ML Amazon AWS AI Tensorflow MXNet", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Natural Language :: English", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3.6", + ], + install_requires=required_packages, + extras_require={ + "test": [ + "tox", + "flake8", + "pytest==4.4.1", + "pytest-cov", + "pytest-rerunfailures", + "pytest-xdist", + "mock", + "tensorflow>=1.3.0", + "contextlib2", + "awslogs", + "pandas", + "black==19.3b0 ; python_version >= '3.6'", + ] + }, + entry_points={"console_scripts": ["sagemaker=sagemaker.cli.main:main"]}, +) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 5232e336f5..7993c2fcda 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -18,15 +18,29 @@ from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor # noqa: F401 from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor # noqa: F401 from sagemaker.amazon.lda import LDA, LDAModel, LDAPredictor # noqa: F401 -from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor # noqa: F401 -from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel # noqa: F401 +from sagemaker.amazon.linear_learner import ( # noqa: F401 + LinearLearner, + LinearLearnerModel, + LinearLearnerPredictor, +) +from sagemaker.amazon.factorization_machines import ( # noqa: F401 + FactorizationMachines, + FactorizationMachinesModel, +) from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor # noqa: F401 from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor # 
noqa: F401 -from sagemaker.amazon.randomcutforest import (RandomCutForest, RandomCutForestModel, # noqa: F401 - RandomCutForestPredictor) +from sagemaker.amazon.randomcutforest import ( # noqa: F401 + RandomCutForest, + RandomCutForestModel, + RandomCutForestPredictor, +) from sagemaker.amazon.knn import KNN, KNNModel, KNNPredictor # noqa: F401 from sagemaker.amazon.object2vec import Object2Vec, Object2VecModel # noqa: F401 -from sagemaker.amazon.ipinsights import IPInsights, IPInsightsModel, IPInsightsPredictor # noqa: F401 +from sagemaker.amazon.ipinsights import ( # noqa: F401 + IPInsights, + IPInsightsModel, + IPInsightsPredictor, +) from sagemaker.algorithm import AlgorithmEstimator # noqa: F401 from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401 @@ -41,4 +55,4 @@ from sagemaker.session import s3_input # noqa: F401 from sagemaker.session import get_execution_role # noqa: F401 -__version__ = pkg_resources.require('sagemaker')[0].version +__version__ = pkg_resources.require("sagemaker")[0].version diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index 2c78e98f6f..68a29053af 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -27,7 +27,7 @@ class AlgorithmEstimator(EstimatorBase): """ # These Hyperparameter Types have a range definition. - _hyperpameters_with_range = ('Integer', 'Continuous', 'Categorical') + _hyperpameters_with_range = ("Integer", "Continuous", "Categorical") def __init__( self, @@ -38,7 +38,7 @@ def __init__( train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, - input_mode='File', + input_mode="File", output_path=None, output_kms_key=None, base_job_name=None, @@ -48,9 +48,9 @@ def __init__( subnets=None, security_group_ids=None, model_uri=None, - model_channel_name='model', + model_channel_name="model", metric_definitions=None, - encrypt_inter_container_traffic=False + encrypt_inter_container_traffic=False, ): """Initialize an ``AlgorithmEstimator`` instance. @@ -123,7 +123,7 @@ def __init__( model_uri=model_uri, model_channel_name=model_channel_name, metric_definitions=metric_definitions, - encrypt_inter_container_traffic=encrypt_inter_container_traffic + encrypt_inter_container_traffic=encrypt_inter_container_traffic, ) self.algorithm_spec = self.sagemaker_session.sagemaker_client.describe_algorithm( @@ -137,35 +137,35 @@ def __init__( self.set_hyperparameters(**hyperparameters) def validate_train_spec(self): - train_spec = self.algorithm_spec['TrainingSpecification'] - algorithm_name = self.algorithm_spec['AlgorithmName'] + train_spec = self.algorithm_spec["TrainingSpecification"] + algorithm_name = self.algorithm_spec["AlgorithmName"] # Check that the input mode provided is compatible with the training input modes for the # algorithm. - train_input_modes = self._algorithm_training_input_modes(train_spec['TrainingChannels']) + train_input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"]) if self.input_mode not in train_input_modes: raise ValueError( - 'Invalid input mode: %s. %s only supports: %s' + "Invalid input mode: %s. %s only supports: %s" % (self.input_mode, algorithm_name, train_input_modes) ) # Check that the training instance type is compatible with the algorithm. - supported_instances = train_spec['SupportedTrainingInstanceTypes'] + supported_instances = train_spec["SupportedTrainingInstanceTypes"] if self.train_instance_type not in supported_instances: raise ValueError( - 'Invalid train_instance_type: %s. 
%s supports the following instance types: %s' + "Invalid train_instance_type: %s. %s supports the following instance types: %s" % (self.train_instance_type, algorithm_name, supported_instances) ) # Verify if distributed training is supported by the algorithm if ( self.train_instance_count > 1 - and 'SupportsDistributedTraining' in train_spec - and not train_spec['SupportsDistributedTraining'] + and "SupportsDistributedTraining" in train_spec + and not train_spec["SupportsDistributedTraining"] ): raise ValueError( - 'Distributed training is not supported by %s. ' - 'Please set train_instance_count=1' % algorithm_name + "Distributed training is not supported by %s. " + "Please set train_instance_count=1" % algorithm_name ) def set_hyperparameters(self, **kwargs): @@ -187,7 +187,7 @@ def train_image(self): The fit() method, that does the model training, calls this method to find the image to use for model training. """ - raise RuntimeError('train_image is never meant to be called on Algorithm Estimators') + raise RuntimeError("train_image is never meant to be called on Algorithm Estimators") def enable_network_isolation(self): """Return True if this Estimator will need network isolation to run. @@ -258,9 +258,22 @@ def predict_wrapper(endpoint, session): **kwargs ) - def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None, volume_kms_key=None): + def transformer( + self, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + role=None, + volume_kms_key=None, + ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. 
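Note on the pattern in the hunks above: the rewrites (single quotes to double quotes, long calls exploded into one keyword argument per line with a trailing comma) are the mechanical output of black, which the setup.py change earlier in this diff pins as black==19.3b0 for Python 3.6+. A minimal sketch of reproducing one such rewrite through black's Python API; it assumes black is installed locally, and uses format_str and FileMode, its Python-level formatting entry points:

import black

# Pre-patch formatting of one statement from examples/cli/train/script.py.
SRC = (
    "trainer = gluon.Trainer(net.collect_params(), 'sgd',\n"
    "                        {'learning_rate': learning_rate, 'momentum': momentum})\n"
)

# format_str parses and reprints the source without executing it, so the
# undefined names (gluon, net, ...) are harmless. Line length defaults to 88.
print(black.format_str(SRC, mode=black.FileMode()))

Running this prints the post-patch form seen in the hunk above: the call split across lines, with all arguments fitting on a single 88-column line inside the parentheses. This is also what the new `tox -e black-check` steps in the buildspecs verify, failing the build when a file is not already in black's canonical form.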
@@ -300,18 +313,28 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit tags = tags or self.tags else: - raise RuntimeError('No finished training job found associated with this estimator') - - return Transformer(model_name, instance_count, instance_type, strategy=strategy, - assemble_with=assemble_with, output_path=output_path, - output_kms_key=output_kms_key, accept=accept, - max_concurrent_transforms=max_concurrent_transforms, - max_payload=max_payload, env=transform_env, tags=tags, - base_transform_job_name=self.base_job_name, - volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) + raise RuntimeError("No finished training job found associated with this estimator") + + return Transformer( + model_name, + instance_count, + instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=transform_env, + tags=tags, + base_transform_job_name=self.base_job_name, + volume_kms_key=volume_kms_key, + sagemaker_session=self.sagemaker_session, + ) def _is_marketplace(self): - return 'ProductId' in self.algorithm_spec + return "ProductId" in self.algorithm_spec def _prepare_for_training(self, job_name=None): # Validate hyperparameters @@ -328,39 +351,39 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None): super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name) def _validate_input_channels(self, channels): - train_spec = self.algorithm_spec['TrainingSpecification'] - algorithm_name = self.algorithm_spec['AlgorithmName'] - training_channels = {c['Name']: c for c in train_spec['TrainingChannels']} + train_spec = self.algorithm_spec["TrainingSpecification"] + algorithm_name = self.algorithm_spec["AlgorithmName"] + training_channels = {c["Name"]: c for c in train_spec["TrainingChannels"]} # check for unknown channels that the algorithm does not support for c in channels: if c not in training_channels: raise ValueError( - 'Unknown input channel: %s is not supported by: %s' % (c, algorithm_name) + "Unknown input channel: %s is not supported by: %s" % (c, algorithm_name) ) # check for required channels that were not provided for name, channel in training_channels.items(): - if name not in channels and 'IsRequired' in channel and channel['IsRequired']: - raise ValueError('Required input channel: %s Was not provided.' % (name)) + if name not in channels and "IsRequired" in channel and channel["IsRequired"]: + raise ValueError("Required input channel: %s Was not provided." 
% (name)) def _validate_and_cast_hyperparameter(self, name, v): - algorithm_name = self.algorithm_spec['AlgorithmName'] + algorithm_name = self.algorithm_spec["AlgorithmName"] if name not in self.hyperparameter_definitions: raise ValueError( - 'Invalid hyperparameter: %s is not supported by %s' % (name, algorithm_name) + "Invalid hyperparameter: %s is not supported by %s" % (name, algorithm_name) ) definition = self.hyperparameter_definitions[name] - if 'class' in definition: - value = definition['class'].cast_to_type(v) + if "class" in definition: + value = definition["class"].cast_to_type(v) else: value = v - if 'range' in definition and not definition['range'].is_valid(value): - valid_range = definition['range'].as_tuning_range(name) - raise ValueError('Invalid value: %s Supported range: %s' % (value, valid_range)) + if "range" in definition and not definition["range"].is_valid(value): + valid_range = definition["range"].as_tuning_range(name) + raise ValueError("Invalid value: %s Supported range: %s" % (value, valid_range)) return value def _validate_and_set_default_hyperparameters(self): @@ -368,77 +391,77 @@ def _validate_and_set_default_hyperparameters(self): # for one, set it. for name, definition in self.hyperparameter_definitions.items(): if name not in self.hyperparam_dict: - spec = definition['spec'] - if 'DefaultValue' in spec: - self.hyperparam_dict[name] = spec['DefaultValue'] - elif 'IsRequired' in spec and spec['IsRequired']: - raise ValueError('Required hyperparameter: %s is not set' % name) + spec = definition["spec"] + if "DefaultValue" in spec: + self.hyperparam_dict[name] = spec["DefaultValue"] + elif "IsRequired" in spec and spec["IsRequired"]: + raise ValueError("Required hyperparameter: %s is not set" % name) def _parse_hyperparameters(self): definitions = {} - training_spec = self.algorithm_spec['TrainingSpecification'] - if 'SupportedHyperParameters' in training_spec: - hyperparameters = training_spec['SupportedHyperParameters'] + training_spec = self.algorithm_spec["TrainingSpecification"] + if "SupportedHyperParameters" in training_spec: + hyperparameters = training_spec["SupportedHyperParameters"] for h in hyperparameters: - parameter_type = h['Type'] - name = h['Name'] + parameter_type = h["Type"] + name = h["Name"] parameter_class, parameter_range = self._hyperparameter_range_and_class( parameter_type, h ) - definitions[name] = {'spec': h} + definitions[name] = {"spec": h} if parameter_range: - definitions[name]['range'] = parameter_range + definitions[name]["range"] = parameter_range if parameter_class: - definitions[name]['class'] = parameter_class + definitions[name]["class"] = parameter_class return definitions def _hyperparameter_range_and_class(self, parameter_type, hyperparameter): if parameter_type in self._hyperpameters_with_range: - range_name = parameter_type + 'ParameterRangeSpecification' + range_name = parameter_type + "ParameterRangeSpecification" parameter_class = None parameter_range = None - if parameter_type in ('Integer', 'Continuous'): + if parameter_type in ("Integer", "Continuous"): # Integer and Continuous are handled the same way. We get the min and max values # and just create an Instance of Parameter. Note that the range is optional for all # the Parameter Types. 
- if parameter_type == 'Integer': + if parameter_type == "Integer": parameter_class = sagemaker.parameter.IntegerParameter else: parameter_class = sagemaker.parameter.ContinuousParameter - if 'Range' in hyperparameter: + if "Range" in hyperparameter: min_value = parameter_class.cast_to_type( - hyperparameter['Range'][range_name]['MinValue'] + hyperparameter["Range"][range_name]["MinValue"] ) max_value = parameter_class.cast_to_type( - hyperparameter['Range'][range_name]['MaxValue'] + hyperparameter["Range"][range_name]["MaxValue"] ) parameter_range = parameter_class(min_value, max_value) - elif parameter_type == 'Categorical': + elif parameter_type == "Categorical": parameter_class = sagemaker.parameter.CategoricalParameter - if 'Range' in hyperparameter: - values = hyperparameter['Range'][range_name]['Values'] + if "Range" in hyperparameter: + values = hyperparameter["Range"][range_name]["Values"] parameter_range = sagemaker.parameter.CategoricalParameter(values) - elif parameter_type == 'FreeText': + elif parameter_type == "FreeText": pass else: raise ValueError( - 'Invalid Hyperparameter type: %s. Valid ones are:' - '(Integer, Continuous, Categorical, FreeText)' % parameter_type + "Invalid Hyperparameter type: %s. Valid ones are:" + "(Integer, Continuous, Categorical, FreeText)" % parameter_type ) return parameter_class, parameter_range def _algorithm_training_input_modes(self, training_channels): - current_input_modes = {'File', 'Pipe'} + current_input_modes = {"File", "Pipe"} for channel in training_channels: - supported_input_modes = set(channel['SupportedInputModes']) + supported_input_modes = set(channel["SupportedInputModes"]) current_input_modes = current_input_modes & supported_input_modes return current_input_modes diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index d43902b54d..94e1e3b340 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -32,12 +32,14 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase): """Base class for Amazon first-party Estimator implementations. This class isn't intended to be instantiated directly.""" - feature_dim = hp('feature_dim', validation.gt(0), data_type=int) - mini_batch_size = hp('mini_batch_size', validation.gt(0), data_type=int) + feature_dim = hp("feature_dim", validation.gt(0), data_type=int) + mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int) repo_name = None repo_version = None - def __init__(self, role, train_instance_count, train_instance_type, data_location=None, **kwargs): + def __init__( + self, role, train_instance_count, train_instance_type, data_location=None, **kwargs + ): """Initialize an AmazonAlgorithmEstimatorBase. Args: @@ -45,18 +47,19 @@ def __init__(self, role, train_instance_count, train_instance_type, data_locatio S3 url. For example "s3://example-bucket/some-key-prefix/". Objects will be saved in a unique sub-directory of the specified location. 
If None, a default data location will be used.""" - super(AmazonAlgorithmEstimatorBase, self).__init__(role, train_instance_count, train_instance_type, - **kwargs) + super(AmazonAlgorithmEstimatorBase, self).__init__( + role, train_instance_count, train_instance_type, **kwargs + ) data_location = data_location or "s3://{}/sagemaker-record-sets/".format( - self.sagemaker_session.default_bucket()) + self.sagemaker_session.default_bucket() + ) self.data_location = data_location def train_image(self): return get_image_uri( - self.sagemaker_session.boto_region_name, - type(self).repo_name, - type(self).repo_version) + self.sagemaker_session.boto_region_name, type(self).repo_name, type(self).repo_version + ) def hyperparameters(self): return hp.serialize_all(self) @@ -67,10 +70,12 @@ def data_location(self): @data_location.setter def data_location(self, data_location): - if not data_location.startswith('s3://'): - raise ValueError('Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)) - if data_location[-1] != '/': - data_location = data_location + '/' + if not data_location.startswith("s3://"): + raise ValueError( + 'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location) + ) + if data_location[-1] != "/": + data_location = data_location + "/" self._data_location = data_location @classmethod @@ -85,19 +90,20 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description( - job_details, model_channel_name) + init_params = super( + AmazonAlgorithmEstimatorBase, cls + )._prepare_init_params_from_job_description(job_details, model_channel_name) # The hyperparam names may not be the same as the class attribute that holds them, # for instance: local_lloyd_init_method is called local_init_method. We need to map these # and pass the correct name to the constructor. for attribute, value in cls.__dict__.items(): if isinstance(value, hp): - if value.name in init_params['hyperparameters']: - init_params[attribute] = init_params['hyperparameters'][value.name] + if value.name in init_params["hyperparameters"]: + init_params[attribute] = init_params["hyperparameters"][value.name] - del init_params['hyperparameters'] - del init_params['image'] + del init_params["hyperparameters"] + del init_params["image"] return init_params def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -116,11 +122,11 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): if isinstance(records, list): for record in records: - if record.channel == 'train': + if record.channel == "train": feature_dim = record.feature_dim break if feature_dim is None: - raise ValueError('Must provide train channel.') + raise ValueError("Must provide train channel.") else: feature_dim = records.feature_dim @@ -184,21 +190,28 @@ def record_set(self, train, labels=None, channel="train", encrypt=False): Returns: RecordSet: A RecordSet referencing the encoded, uploading training and label data. 
""" - s3 = self.sagemaker_session.boto_session.resource('s3') + s3 = self.sagemaker_session.boto_session.resource("s3") parsed_s3_url = urlparse(self.data_location) bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path - key_prefix = key_prefix + '{}-{}/'.format(type(self).__name__, sagemaker_timestamp()) - key_prefix = key_prefix.lstrip('/') - logger.debug('Uploading to bucket {} and key_prefix {}'.format(bucket, key_prefix)) - manifest_s3_file = upload_numpy_to_s3_shards(self.train_instance_count, s3, bucket, - key_prefix, train, labels, encrypt) + key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp()) + key_prefix = key_prefix.lstrip("/") + logger.debug("Uploading to bucket {} and key_prefix {}".format(bucket, key_prefix)) + manifest_s3_file = upload_numpy_to_s3_shards( + self.train_instance_count, s3, bucket, key_prefix, train, labels, encrypt + ) logger.debug("Created manifest file {}".format(manifest_s3_file)) - return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) + return RecordSet( + manifest_s3_file, + num_records=train.shape[0], + feature_dim=train.shape[1], + channel=channel, + ) class RecordSet(object): - - def __init__(self, s3_data, num_records, feature_dim, s3_data_type='ManifestFile', channel='train'): + def __init__( + self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train" + ): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. Args: @@ -228,7 +241,7 @@ def data_channel(self): def records_s3_input(self): """Return a s3_input to represent the training data""" - return s3_input(self.s3_data, distribution='ShardedByS3Key', s3_data_type=self.s3_data_type) + return s3_input(self.s3_data, distribution="ShardedByS3Key", s3_data_type=self.s3_data_type) def _build_shards(num_shards, array): @@ -237,12 +250,14 @@ def _build_shards(num_shards, array): shard_size = int(array.shape[0] / num_shards) if shard_size == 0: raise ValueError("Array length is less than num shards") - shards = [array[i * shard_size:i * shard_size + shard_size] for i in range(num_shards - 1)] - shards.append(array[(num_shards - 1) * shard_size:]) + shards = [array[i * shard_size : i * shard_size + shard_size] for i in range(num_shards - 1)] + shards.append(array[(num_shards - 1) * shard_size :]) return shards -def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False): +def upload_numpy_to_s3_shards( + num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False +): """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` S3 objects, stored in "s3://``bucket``/``key_prefix``/". 
Optionally ``encrypt`` the S3 objects using AES-256.""" @@ -250,9 +265,9 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= if labels is not None: label_shards = _build_shards(num_shards, labels) uploaded_files = [] - if key_prefix[-1] != '/': - key_prefix = key_prefix + '/' - extra_put_kwargs = {'ServerSideEncryption': 'AES256'} if encrypt else {} + if key_prefix[-1] != "/": + key_prefix = key_prefix + "/" + extra_put_kwargs = {"ServerSideEncryption": "AES256"} if encrypt else {} try: for shard_index, shard in enumerate(shards): with tempfile.TemporaryFile() as file: @@ -269,8 +284,9 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= uploaded_files.append(file_name) manifest_key = key_prefix + ".amazon.manifest" manifest_str = json.dumps( - [{'prefix': 's3://{}/{}'.format(bucket, key_prefix)}] + uploaded_files) - s3.Object(bucket, manifest_key).put(Body=manifest_str.encode('utf-8'), **extra_put_kwargs) + [{"prefix": "s3://{}/{}".format(bucket, key_prefix)}] + uploaded_files + ) + s3.Object(bucket, manifest_key).put(Body=manifest_str.encode("utf-8"), **extra_put_kwargs) return "s3://{}/{}".format(bucket, manifest_key) except Exception as ex: # pylint: disable=broad-except try: @@ -288,94 +304,112 @@ def registry(region_name, algorithm=None): https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon """ - if algorithm in [None, 'pca', 'kmeans', 'linear-learner', 'factorization-machines', 'ntm', - 'randomcutforest', 'knn', 'object2vec', 'ipinsights']: + if algorithm in [ + None, + "pca", + "kmeans", + "linear-learner", + "factorization-machines", + "ntm", + "randomcutforest", + "knn", + "object2vec", + "ipinsights", + ]: account_id = { - 'us-east-1': '382416733822', - 'us-east-2': '404615174143', - 'us-west-2': '174872318107', - 'eu-west-1': '438346466558', - 'eu-central-1': '664544806723', - 'ap-northeast-1': '351501993468', - 'ap-northeast-2': '835164637446', - 'ap-southeast-2': '712309505854', - 'us-gov-west-1': '226302683700', - 'ap-southeast-1': '475088953585', - 'ap-south-1': '991648021394', - 'ca-central-1': '469771592824', - 'eu-west-2': '644912444149', - 'us-west-1': '632365934929', - 'us-iso-east-1': '490574956308', + "us-east-1": "382416733822", + "us-east-2": "404615174143", + "us-west-2": "174872318107", + "eu-west-1": "438346466558", + "eu-central-1": "664544806723", + "ap-northeast-1": "351501993468", + "ap-northeast-2": "835164637446", + "ap-southeast-2": "712309505854", + "us-gov-west-1": "226302683700", + "ap-southeast-1": "475088953585", + "ap-south-1": "991648021394", + "ca-central-1": "469771592824", + "eu-west-2": "644912444149", + "us-west-1": "632365934929", + "us-iso-east-1": "490574956308", }[region_name] - elif algorithm in ['lda']: + elif algorithm in ["lda"]: account_id = { - 'us-east-1': '766337827248', - 'us-east-2': '999911452149', - 'us-west-2': '266724342769', - 'eu-west-1': '999678624901', - 'eu-central-1': '353608530281', - 'ap-northeast-1': '258307448986', - 'ap-northeast-2': '293181348795', - 'ap-southeast-2': '297031611018', - 'us-gov-west-1': '226302683700', - 'ap-southeast-1': '475088953585', - 'ap-south-1': '991648021394', - 'ca-central-1': '469771592824', - 'eu-west-2': '644912444149', - 'us-west-1': '632365934929', - 'us-iso-east-1': '490574956308', + "us-east-1": "766337827248", + "us-east-2": "999911452149", + "us-west-2": "266724342769", + "eu-west-1": "999678624901", + "eu-central-1": "353608530281", + "ap-northeast-1": "258307448986", + 
"ap-northeast-2": "293181348795", + "ap-southeast-2": "297031611018", + "us-gov-west-1": "226302683700", + "ap-southeast-1": "475088953585", + "ap-south-1": "991648021394", + "ca-central-1": "469771592824", + "eu-west-2": "644912444149", + "us-west-1": "632365934929", + "us-iso-east-1": "490574956308", }[region_name] - elif algorithm in ['forecasting-deepar']: + elif algorithm in ["forecasting-deepar"]: account_id = { - 'us-east-1': '522234722520', - 'us-east-2': '566113047672', - 'us-west-2': '156387875391', - 'eu-west-1': '224300973850', - 'eu-central-1': '495149712605', - 'ap-northeast-1': '633353088612', - 'ap-northeast-2': '204372634319', - 'ap-southeast-2': '514117268639', - 'us-gov-west-1': '226302683700', - 'ap-southeast-1': '475088953585', - 'ap-south-1': '991648021394', - 'ca-central-1': '469771592824', - 'eu-west-2': '644912444149', - 'us-west-1': '632365934929', - 'us-iso-east-1': '490574956308', + "us-east-1": "522234722520", + "us-east-2": "566113047672", + "us-west-2": "156387875391", + "eu-west-1": "224300973850", + "eu-central-1": "495149712605", + "ap-northeast-1": "633353088612", + "ap-northeast-2": "204372634319", + "ap-southeast-2": "514117268639", + "us-gov-west-1": "226302683700", + "ap-southeast-1": "475088953585", + "ap-south-1": "991648021394", + "ca-central-1": "469771592824", + "eu-west-2": "644912444149", + "us-west-1": "632365934929", + "us-iso-east-1": "490574956308", }[region_name] - elif algorithm in ['xgboost', 'seq2seq', 'image-classification', 'blazingtext', - 'object-detection', 'semantic-segmentation']: + elif algorithm in [ + "xgboost", + "seq2seq", + "image-classification", + "blazingtext", + "object-detection", + "semantic-segmentation", + ]: account_id = { - 'us-east-1': '811284229777', - 'us-east-2': '825641698319', - 'us-west-2': '433757028032', - 'eu-west-1': '685385470294', - 'eu-central-1': '813361260812', - 'ap-northeast-1': '501404015308', - 'ap-northeast-2': '306986355934', - 'ap-southeast-2': '544295431143', - 'us-gov-west-1': '226302683700', - 'ap-southeast-1': '475088953585', - 'ap-south-1': '991648021394', - 'ca-central-1': '469771592824', - 'eu-west-2': '644912444149', - 'us-west-1': '632365934929', - 'us-iso-east-1': '490574956308', + "us-east-1": "811284229777", + "us-east-2": "825641698319", + "us-west-2": "433757028032", + "eu-west-1": "685385470294", + "eu-central-1": "813361260812", + "ap-northeast-1": "501404015308", + "ap-northeast-2": "306986355934", + "ap-southeast-2": "544295431143", + "us-gov-west-1": "226302683700", + "ap-southeast-1": "475088953585", + "ap-south-1": "991648021394", + "ca-central-1": "469771592824", + "eu-west-2": "644912444149", + "us-west-1": "632365934929", + "us-iso-east-1": "490574956308", }[region_name] - elif algorithm in ['image-classification-neo', 'xgboost-neo']: + elif algorithm in ["image-classification-neo", "xgboost-neo"]: account_id = { - 'us-west-2': '301217895009', - 'us-east-1': '785573368785', - 'eu-west-1': '802834080501', - 'us-east-2': '007439368137', + "us-west-2": "301217895009", + "us-east-1": "785573368785", + "eu-west-1": "802834080501", + "us-east-2": "007439368137", }[region_name] else: - raise ValueError('Algorithm class:{} does not have mapping to account_id with images'.format(algorithm)) + raise ValueError( + "Algorithm class:{} does not have mapping to account_id with images".format(algorithm) + ) return get_ecr_image_uri_prefix(account_id, region_name) def get_image_uri(region_name, repo_name, repo_version=1): """Return algorithm image URI for the given AWS region, 
repository name, and repository version""" - repo = '{}:{}'.format(repo_name, repo_version) - return '{}/{}'.format(registry(region_name, repo_name), repo) + repo = "{}:{}".format(repo_name, repo_version) + return "{}/{}".format(registry(region_name, repo_name), repo) diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index e15cc09a4d..5a402a9c25 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -23,8 +23,7 @@ class numpy_to_record_serializer(object): - - def __init__(self, content_type='application/x-recordio-protobuf'): + def __init__(self, content_type="application/x-recordio-protobuf"): self.content_type = content_type def __call__(self, array): @@ -38,8 +37,7 @@ def __call__(self, array): class record_deserializer(object): - - def __init__(self, accept='application/x-recordio-protobuf'): + def __init__(self, accept="application/x-recordio-protobuf"): self.accept = accept def __call__(self, stream, content_type): @@ -95,8 +93,11 @@ def write_numpy_to_dense_tensor(file, array, labels=None): if not len(labels.shape) == 1: raise ValueError("Labels must be a Vector") if labels.shape[0] not in array.shape: - raise ValueError("Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape)) + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) resolved_label_type = _resolve_type(labels.dtype) resolved_type = _resolve_type(array.dtype) @@ -123,8 +124,11 @@ def write_spmatrix_to_sparse_tensor(file, array, labels=None): if not len(labels.shape) == 1: raise ValueError("Labels must be a Vector") if labels.shape[0] not in array.shape: - raise ValueError("Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape)) + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) resolved_label_type = _resolve_type(labels.dtype) resolved_type = _resolve_type(array.dtype) @@ -170,27 +174,27 @@ def read_records(file): else: padding[amount] = bytearray([0x00 for _ in range(amount)]) -_kmagic = 0xced7230a +_kmagic = 0xCED7230A def _write_recordio(f, data): """Writes a single data point as a RecordIO record to the given file.""" length = len(data) - f.write(struct.pack('I', _kmagic)) - f.write(struct.pack('I', length)) + f.write(struct.pack("I", _kmagic)) + f.write(struct.pack("I", length)) pad = (((length + 3) >> 2) << 2) - length f.write(data) f.write(padding[pad]) def read_recordio(f): - while(True): + while True: try: - read_kmagic, = struct.unpack('I', f.read(4)) + read_kmagic, = struct.unpack("I", f.read(4)) except struct.error: return assert read_kmagic == _kmagic - len_record, = struct.unpack('I', f.read(4)) + len_record, = struct.unpack("I", f.read(4)) pad = (((len_record + 3) >> 2) << 2) - len_record yield f.read(len_record) if pad: @@ -199,9 +203,9 @@ def read_recordio(f): def _resolve_type(dtype): if dtype == np.dtype(int): - return 'Int32' + return "Int32" elif dtype == np.dtype(float): - return 'Float64' - elif dtype == np.dtype('float32'): - return 'Float32' - raise ValueError('Unsupported dtype {} on array'.format(dtype)) + return "Float64" + elif dtype == np.dtype("float32"): + return "Float32" + raise ValueError("Unsupported dtype {} on array".format(dtype)) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index d01b45da35..ad377b5f2b 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ 
b/src/sagemaker/amazon/factorization_machines.py @@ -24,47 +24,85 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase): - repo_name = 'factorization-machines' + repo_name = "factorization-machines" repo_version = 1 - num_factors = hp('num_factors', gt(0), 'An integer greater than zero', int) - predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'), - 'Value "binary_classifier" or "regressor"', str) - epochs = hp('epochs', gt(0), "An integer greater than 0", int) - clip_gradient = hp('clip_gradient', (), "A float value", float) - eps = hp('eps', (), "A float value", float) - rescale_grad = hp('rescale_grad', (), "A float value", float) - bias_lr = hp('bias_lr', ge(0), "A non-negative float", float) - linear_lr = hp('linear_lr', ge(0), "A non-negative float", float) - factors_lr = hp('factors_lr', ge(0), "A non-negative float", float) - bias_wd = hp('bias_wd', ge(0), "A non-negative float", float) - linear_wd = hp('linear_wd', ge(0), "A non-negative float", float) - factors_wd = hp('factors_wd', ge(0), "A non-negative float", float) - bias_init_method = hp('bias_init_method', isin('normal', 'uniform', 'constant'), - 'Value "normal", "uniform" or "constant"', str) - bias_init_scale = hp('bias_init_scale', ge(0), "A non-negative float", float) - bias_init_sigma = hp('bias_init_sigma', ge(0), "A non-negative float", float) - bias_init_value = hp('bias_init_value', (), "A float value", float) - linear_init_method = hp('linear_init_method', isin('normal', 'uniform', 'constant'), - 'Value "normal", "uniform" or "constant"', str) - linear_init_scale = hp('linear_init_scale', ge(0), "A non-negative float", float) - linear_init_sigma = hp('linear_init_sigma', ge(0), "A non-negative float", float) - linear_init_value = hp('linear_init_value', (), "A float value", float) - factors_init_method = hp('factors_init_method', isin('normal', 'uniform', 'constant'), - 'Value "normal", "uniform" or "constant"', str) - factors_init_scale = hp('factors_init_scale', ge(0), "A non-negative float", float) - factors_init_sigma = hp('factors_init_sigma', ge(0), "A non-negative float", float) - factors_init_value = hp('factors_init_value', (), "A float value", float) - - def __init__(self, role, train_instance_count, train_instance_type, - num_factors, predictor_type, - epochs=None, clip_gradient=None, eps=None, rescale_grad=None, - bias_lr=None, linear_lr=None, factors_lr=None, - bias_wd=None, linear_wd=None, factors_wd=None, - bias_init_method=None, bias_init_scale=None, bias_init_sigma=None, bias_init_value=None, - linear_init_method=None, linear_init_scale=None, linear_init_sigma=None, linear_init_value=None, - factors_init_method=None, factors_init_scale=None, factors_init_sigma=None, factors_init_value=None, - **kwargs): + num_factors = hp("num_factors", gt(0), "An integer greater than zero", int) + predictor_type = hp( + "predictor_type", + isin("binary_classifier", "regressor"), + 'Value "binary_classifier" or "regressor"', + str, + ) + epochs = hp("epochs", gt(0), "An integer greater than 0", int) + clip_gradient = hp("clip_gradient", (), "A float value", float) + eps = hp("eps", (), "A float value", float) + rescale_grad = hp("rescale_grad", (), "A float value", float) + bias_lr = hp("bias_lr", ge(0), "A non-negative float", float) + linear_lr = hp("linear_lr", ge(0), "A non-negative float", float) + factors_lr = hp("factors_lr", ge(0), "A non-negative float", float) + bias_wd = hp("bias_wd", ge(0), "A non-negative float", float) + linear_wd = hp("linear_wd", ge(0), "A 
non-negative float", float) + factors_wd = hp("factors_wd", ge(0), "A non-negative float", float) + bias_init_method = hp( + "bias_init_method", + isin("normal", "uniform", "constant"), + 'Value "normal", "uniform" or "constant"', + str, + ) + bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float) + bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float) + bias_init_value = hp("bias_init_value", (), "A float value", float) + linear_init_method = hp( + "linear_init_method", + isin("normal", "uniform", "constant"), + 'Value "normal", "uniform" or "constant"', + str, + ) + linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float) + linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float) + linear_init_value = hp("linear_init_value", (), "A float value", float) + factors_init_method = hp( + "factors_init_method", + isin("normal", "uniform", "constant"), + 'Value "normal", "uniform" or "constant"', + str, + ) + factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float) + factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float) + factors_init_value = hp("factors_init_value", (), "A float value", float) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + num_factors, + predictor_type, + epochs=None, + clip_gradient=None, + eps=None, + rescale_grad=None, + bias_lr=None, + linear_lr=None, + factors_lr=None, + bias_wd=None, + linear_wd=None, + factors_wd=None, + bias_init_method=None, + bias_init_scale=None, + bias_init_sigma=None, + bias_init_value=None, + linear_init_method=None, + linear_init_scale=None, + linear_init_sigma=None, + linear_init_value=None, + factors_init_method=None, + factors_init_scale=None, + factors_init_sigma=None, + factors_init_value=None, + **kwargs + ): """Factorization Machines is :class:`Estimator` for general-purpose supervised learning. Amazon SageMaker Factorization Machines is a general-purpose supervised learning algorithm that you can use @@ -137,7 +175,9 @@ def __init__(self, role, train_instance_count, train_instance_type, effect when factors_init_method parameter is 'constant'. **kwargs: base class keyword argument values. """ - super(FactorizationMachines, self).__init__(role, train_instance_count, train_instance_type, **kwargs) + super(FactorizationMachines, self).__init__( + role, train_instance_count, train_instance_type, **kwargs + ) self.num_factors = num_factors self.predictor_type = predictor_type @@ -175,8 +215,12 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'SecurityGroupIds' (list[str]): List of security group ids. 
""" - return FactorizationMachinesModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return FactorizationMachinesModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) class FactorizationMachinesPredictor(RealTimePredictor): @@ -194,10 +238,12 @@ class FactorizationMachinesPredictor(RealTimePredictor): """ def __init__(self, endpoint, sagemaker_session=None): - super(FactorizationMachinesPredictor, self).__init__(endpoint, - sagemaker_session, - serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(FactorizationMachinesPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class FactorizationMachinesModel(Model): @@ -206,11 +252,13 @@ class FactorizationMachinesModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(FactorizationMachines.repo_name, FactorizationMachines.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) - super(FactorizationMachinesModel, self).__init__(model_data, - image, - role, - predictor_cls=FactorizationMachinesPredictor, - sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version) + image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo) + super(FactorizationMachinesModel, self).__init__( + model_data, + image, + role, + predictor_cls=FactorizationMachinesPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index ad25a84d14..06dc89e45c 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -46,7 +46,7 @@ def validate(self, value): raise ValueError(error_message) def __get__(self, obj, objtype): - if '_hyperparameters' not in dir(obj) or self.name not in obj._hyperparameters: + if "_hyperparameters" not in dir(obj) or self.name not in obj._hyperparameters: raise AttributeError() return obj._hyperparameters[self.name] @@ -54,7 +54,7 @@ def __set__(self, obj, value): """Validate the supplied value and set this hyperparameter to value""" value = None if value is None else self.data_type(value) self.validate(value) - if '_hyperparameters' not in dir(obj): + if "_hyperparameters" not in dir(obj): obj._hyperparameters = dict() obj._hyperparameters[self.name] = value @@ -65,6 +65,6 @@ def __delete__(self, obj): @staticmethod def serialize_all(obj): """Return all non-None ``hyperparameter`` values on ``obj`` as a ``dict[str,str].``""" - if '_hyperparameters' not in dir(obj): + if "_hyperparameters" not in dir(obj): return {} return {k: str(v) for k, v in obj._hyperparameters.items() if v is not None} diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index e1cb6c1f34..5844092196 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -22,26 +22,47 @@ class IPInsights(AmazonAlgorithmEstimatorBase): - repo_name = 'ipinsights' + repo_name = "ipinsights" repo_version = 1 MINI_BATCH_SIZE = 10000 - num_entity_vectors = hp('num_entity_vectors', (ge(1), le(250000000)), 'An integer in [1, 250000000]', int) - vector_dim = 
hp('vector_dim', (ge(4), le(4096)), 'An integer in [4, 4096]', int) - - batch_metrics_publish_interval = hp('batch_metrics_publish_interval', (ge(1)), 'An integer greater than 0', int) - epochs = hp('epochs', (ge(1)), 'An integer greater than 0', int) - learning_rate = hp('learning_rate', (ge(1e-6), le(10.0)), 'A float in [1e-6, 10.0]', float) - num_ip_encoder_layers = hp('num_ip_encoder_layers', (ge(0), le(100)), 'An integer in [0, 100]', int) - random_negative_sampling_rate = hp('random_negative_sampling_rate', (ge(0), le(500)), 'An integer in [0, 500]', int) - shuffled_negative_sampling_rate = hp('shuffled_negative_sampling_rate', (ge(0), le(500)), 'An integer in [0, 500]', - int) - weight_decay = hp('weight_decay', (ge(0.0), le(10.0)), 'A float in [0.0, 10.0]', float) - - def __init__(self, role, train_instance_count, train_instance_type, num_entity_vectors, vector_dim, - batch_metrics_publish_interval=None, epochs=None, learning_rate=None, - num_ip_encoder_layers=None, random_negative_sampling_rate=None, - shuffled_negative_sampling_rate=None, weight_decay=None, **kwargs): + num_entity_vectors = hp( + "num_entity_vectors", (ge(1), le(250000000)), "An integer in [1, 250000000]", int + ) + vector_dim = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int) + + batch_metrics_publish_interval = hp( + "batch_metrics_publish_interval", (ge(1)), "An integer greater than 0", int + ) + epochs = hp("epochs", (ge(1)), "An integer greater than 0", int) + learning_rate = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float) + num_ip_encoder_layers = hp( + "num_ip_encoder_layers", (ge(0), le(100)), "An integer in [0, 100]", int + ) + random_negative_sampling_rate = hp( + "random_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int + ) + shuffled_negative_sampling_rate = hp( + "shuffled_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int + ) + weight_decay = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + num_entity_vectors, + vector_dim, + batch_metrics_publish_interval=None, + epochs=None, + learning_rate=None, + num_ip_encoder_layers=None, + random_negative_sampling_rate=None, + shuffled_negative_sampling_rate=None, + weight_decay=None, + **kwargs + ): """This estimator is for IP Insights, an unsupervised algorithm that learns usage patterns of IP addresses. This Estimator may be fit via calls to @@ -102,13 +123,19 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): Returns: :class:`~sagemaker.amazon.IPInsightsModel`: references the latest s3 model data produced by this estimator. 
""" - return IPInsightsModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return IPInsightsModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 500000): raise ValueError("mini_batch_size must be in [1, 500000]") - super(IPInsights, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(IPInsights, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class IPInsightsPredictor(RealTimePredictor): @@ -121,9 +148,9 @@ class IPInsightsPredictor(RealTimePredictor): """ def __init__(self, endpoint, sagemaker_session=None): - super(IPInsightsPredictor, self).__init__(endpoint, sagemaker_session, - serializer=csv_serializer, - deserializer=json_deserializer) + super(IPInsightsPredictor, self).__init__( + endpoint, sagemaker_session, serializer=csv_serializer, deserializer=json_deserializer + ) class IPInsightsModel(Model): @@ -132,12 +159,16 @@ class IPInsightsModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(IPInsights.repo_name, IPInsights.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, - IPInsights.repo_name), repo) + repo = "{}:{}".format(IPInsights.repo_name, IPInsights.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, IPInsights.repo_name), repo + ) super(IPInsightsModel, self).__init__( - model_data, image, role, + model_data, + image, + role, predictor_cls=IPInsightsPredictor, sagemaker_session=sagemaker_session, - **kwargs) + **kwargs + ) diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index d7fb99a389..633d8dc9a6 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -24,24 +24,45 @@ class KMeans(AmazonAlgorithmEstimatorBase): - repo_name = 'kmeans' + repo_name = "kmeans" repo_version = 1 - k = hp('k', gt(1), 'An integer greater-than 1', int) - init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str) - max_iterations = hp('local_lloyd_max_iter', gt(0), 'An integer greater-than 0', int) - tol = hp('local_lloyd_tol', (ge(0), le(1)), 'An float in [0, 1]', float) - num_trials = hp('local_lloyd_num_trials', gt(0), 'An integer greater-than 0', int) - local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str) - half_life_time_size = hp('half_life_time_size', ge(0), 'An integer greater-than-or-equal-to 0', int) - epochs = hp('epochs', gt(0), 'An integer greater-than 0', int) - center_factor = hp('extra_center_factor', gt(0), 'An integer greater-than 0', int) - eval_metrics = hp(name='eval_metrics', validation_message='A comma separated list of "msd" or "ssd"', - data_type=list) - - def __init__(self, role, train_instance_count, train_instance_type, k, init_method=None, - max_iterations=None, tol=None, num_trials=None, local_init_method=None, - half_life_time_size=None, epochs=None, center_factor=None, eval_metrics=None, **kwargs): + k = hp("k", gt(1), "An integer greater-than 1", int) + init_method = hp("init_method", isin("random", "kmeans++"), 
'One of "random", "kmeans++"', str) + max_iterations = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int) + tol = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float) + num_trials = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int) + local_init_method = hp( + "local_lloyd_init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str + ) + half_life_time_size = hp( + "half_life_time_size", ge(0), "An integer greater-than-or-equal-to 0", int + ) + epochs = hp("epochs", gt(0), "An integer greater-than 0", int) + center_factor = hp("extra_center_factor", gt(0), "An integer greater-than 0", int) + eval_metrics = hp( + name="eval_metrics", + validation_message='A comma separated list of "msd" or "ssd"', + data_type=list, + ) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + k, + init_method=None, + max_iterations=None, + tol=None, + num_trials=None, + local_init_method=None, + half_life_time_size=None, + epochs=None, + center_factor=None, + eval_metrics=None, + **kwargs + ): """ A k-means clustering :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. Finds k clusters of data in an unlabeled dataset. @@ -113,15 +134,21 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return KMeansModel(self.model_data, self.role, self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return KMeansModel( + self.model_data, + self.role, + self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None): - super(KMeans, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(KMeans, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) def hyperparameters(self): """Return the SageMaker hyperparameters for training this KMeans Estimator""" - hp_dict = dict(force_dense='True') # KMeans requires this hp to fit on Record objects + hp_dict = dict(force_dense="True") # KMeans requires this hp to fit on Record objects hp_dict.update(super(KMeans, self).hyperparameters()) return hp_dict @@ -139,8 +166,12 @@ class KMeansPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(KMeansPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(KMeansPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class KMeansModel(Model): @@ -149,8 +180,13 @@ class KMeansModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(KMeans.repo_name, KMeans.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) - super(KMeansModel, self).__init__(model_data, image, role, predictor_cls=KMeansPredictor, - sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version) + image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo) + super(KMeansModel, self).__init__( + model_data, + image, + role, + predictor_cls=KMeansPredictor, + 
sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 0951b732d6..9df4325e34 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -23,26 +23,56 @@ class KNN(AmazonAlgorithmEstimatorBase): - repo_name = 'knn' + repo_name = "knn" repo_version = 1 - k = hp('k', (ge(1)), 'An integer greater than 0', int) - sample_size = hp('sample_size', (ge(1)), 'An integer greater than 0', int) - predictor_type = hp('predictor_type', isin('classifier', 'regressor'), - 'One of "classifier" or "regressor"', str) - dimension_reduction_target = hp('dimension_reduction_target', (ge(1)), - 'An integer greater than 0 and less than feature_dim', int) - dimension_reduction_type = hp('dimension_reduction_type', isin('sign', 'fjlt'), 'One of "sign" or "fjlt"', str) - index_metric = hp('index_metric', isin('COSINE', 'INNER_PRODUCT', 'L2'), - 'One of "COSINE", "INNER_PRODUCT", "L2"', str) - index_type = hp('index_type', isin('faiss.Flat', 'faiss.IVFFlat', 'faiss.IVFPQ'), - 'One of "faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"', str) - faiss_index_ivf_nlists = hp('faiss_index_ivf_nlists', (), '"auto" or an integer greater than 0', str) - faiss_index_pq_m = hp('faiss_index_pq_m', (ge(1)), 'An integer greater than 0', int) - - def __init__(self, role, train_instance_count, train_instance_type, k, sample_size, predictor_type, - dimension_reduction_type=None, dimension_reduction_target=None, index_type=None, - index_metric=None, faiss_index_ivf_nlists=None, faiss_index_pq_m=None, **kwargs): + k = hp("k", (ge(1)), "An integer greater than 0", int) + sample_size = hp("sample_size", (ge(1)), "An integer greater than 0", int) + predictor_type = hp( + "predictor_type", isin("classifier", "regressor"), 'One of "classifier" or "regressor"', str + ) + dimension_reduction_target = hp( + "dimension_reduction_target", + (ge(1)), + "An integer greater than 0 and less than feature_dim", + int, + ) + dimension_reduction_type = hp( + "dimension_reduction_type", isin("sign", "fjlt"), 'One of "sign" or "fjlt"', str + ) + index_metric = hp( + "index_metric", + isin("COSINE", "INNER_PRODUCT", "L2"), + 'One of "COSINE", "INNER_PRODUCT", "L2"', + str, + ) + index_type = hp( + "index_type", + isin("faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"), + 'One of "faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"', + str, + ) + faiss_index_ivf_nlists = hp( + "faiss_index_ivf_nlists", (), '"auto" or an integer greater than 0', str + ) + faiss_index_pq_m = hp("faiss_index_pq_m", (ge(1)), "An integer greater than 0", int) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + k, + sample_size, + predictor_type, + dimension_reduction_type=None, + dimension_reduction_target=None, + index_type=None, + index_metric=None, + faiss_index_ivf_nlists=None, + faiss_index_pq_m=None, + **kwargs + ): """k-nearest neighbors (KNN) is :class:`Estimator` used for classification and regression. This Estimator may be fit via calls to :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. 
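
A hypothetical sketch of the KNN constructor reformatted above; all values are placeholders, and the instance type is only an example.

```python
# Hypothetical sketch, not part of this change: a minimal KNN classifier
# using the required constructor arguments (k, sample_size, predictor_type).
from sagemaker.amazon.knn import KNN

knn = KNN(
    role="SageMakerRole",  # placeholder IAM role
    train_instance_count=1,
    train_instance_type="ml.m5.2xlarge",  # placeholder instance type
    k=10,
    sample_size=200000,
    predictor_type="classifier",
)
```

Note that `dimension_reduction_target` must accompany `dimension_reduction_type`, per the ValueError raised in `__init__` below.
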
It requires Amazon @@ -97,7 +127,9 @@ def __init__(self, role, train_instance_count, train_instance_type, k, sample_si self.faiss_index_ivf_nlists = faiss_index_ivf_nlists self.faiss_index_pq_m = faiss_index_pq_m if dimension_reduction_type and not dimension_reduction_target: - raise ValueError('"dimension_reduction_target" is required when "dimension_reduction_type" is set.') + raise ValueError( + '"dimension_reduction_target" is required when "dimension_reduction_type" is set.' + ) def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): """Return a :class:`~sagemaker.amazon.KNNModel` referencing the latest @@ -109,11 +141,17 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return KNNModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return KNNModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): - super(KNN, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(KNN, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class KNNPredictor(RealTimePredictor): @@ -129,8 +167,12 @@ class KNNPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(KNNPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(KNNPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class KNNModel(Model): @@ -139,7 +181,15 @@ class KNNModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(KNN.repo_name, KNN.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo) - super(KNNModel, self).__init__(model_data, image, role, predictor_cls=KNNPredictor, - sagemaker_session=sagemaker_session, **kwargs) + repo = "{}:{}".format(KNN.repo_name, KNN.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo + ) + super(KNNModel, self).__init__( + model_data, + image, + role, + predictor_cls=KNNPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 8c7235a8a8..bdcb7b71cc 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -24,17 +24,26 @@ class LDA(AmazonAlgorithmEstimatorBase): - repo_name = 'lda' + repo_name = "lda" repo_version = 1 - num_topics = hp('num_topics', gt(0), 'An integer greater than zero', int) - alpha0 = hp('alpha0', gt(0), 'A positive float', float) - max_restarts = hp('max_restarts', gt(0), 'An integer greater than zero', int) - max_iterations = hp('max_iterations', gt(0), 'An integer greater than zero', int) - tol = hp('tol', gt(0), 'A positive float', float) - - def __init__(self, role, train_instance_type, num_topics, - alpha0=None, max_restarts=None, max_iterations=None, tol=None, **kwargs): + num_topics = hp("num_topics", gt(0), "An integer greater than zero", int) + 
alpha0 = hp("alpha0", gt(0), "A positive float", float) + max_restarts = hp("max_restarts", gt(0), "An integer greater than zero", int) + max_iterations = hp("max_iterations", gt(0), "An integer greater than zero", int) + tol = hp("tol", gt(0), "A positive float", float) + + def __init__( + self, + role, + train_instance_type, + num_topics, + alpha0=None, + max_restarts=None, + max_iterations=None, + tol=None, + **kwargs + ): """Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. Amazon SageMaker Latent Dirichlet Allocation is an unsupervised learning algorithm that attempts to describe @@ -80,8 +89,12 @@ def __init__(self, role, train_instance_type, num_topics, **kwargs: base class keyword argument values. """ # this algorithm only supports single instance training - if kwargs.pop('train_instance_count', 1) != 1: - print('LDA only supports single instance training. Defaulting to 1 {}.'.format(train_instance_type)) + if kwargs.pop("train_instance_count", 1) != 1: + print( + "LDA only supports single instance training. Defaulting to 1 {}.".format( + train_instance_type + ) + ) super(LDA, self).__init__(role, 1, train_instance_type, **kwargs) self.num_topics = num_topics @@ -100,15 +113,21 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return LDAModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size, job_name=None): # mini_batch_size is required, prevent explicit calls with None if mini_batch_size is None: raise ValueError("mini_batch_size must be set") - super(LDA, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(LDA, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class LDAPredictor(RealTimePredictor): @@ -124,8 +143,12 @@ class LDAPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(LDAPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(LDAPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class LDAModel(Model): @@ -134,7 +157,15 @@ class LDAModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(LDA.repo_name, LDA.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo) - super(LDAModel, self).__init__(model_data, image, role, predictor_cls=LDAPredictor, - sagemaker_session=sagemaker_session, **kwargs) + repo = "{}:{}".format(LDA.repo_name, LDA.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo + ) + super(LDAModel, self).__init__( + model_data, + image, + role, + predictor_cls=LDAPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 1c113b22c6..9efb32f6a7 
100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -23,73 +23,146 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): - repo_name = 'linear-learner' + repo_name = "linear-learner" repo_version = 1 DEFAULT_MINI_BATCH_SIZE = 1000 - binary_classifier_model_selection_criteria = hp('binary_classifier_model_selection_criteria', - isin('accuracy', 'f1', 'f_beta', 'precision_at_target_recall', - 'recall_at_target_precision', 'cross_entropy_loss', - 'loss_function'), data_type=str) - target_recall = hp('target_recall', (gt(0), lt(1)), "A float in (0,1)", float) - target_precision = hp('target_precision', (gt(0), lt(1)), "A float in (0,1)", float) - positive_example_weight_mult = hp('positive_example_weight_mult', (), - "A float greater than 0 or 'auto' or 'balanced'", str) - epochs = hp('epochs', gt(0), "An integer greater-than 0", int) - predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor', 'multiclass_classifier'), - 'One of "binary_classifier" or "multiclass_classifier" or "regressor"', str) - use_bias = hp('use_bias', (), "Either True or False", bool) - num_models = hp('num_models', gt(0), "An integer greater-than 0", int) - num_calibration_samples = hp('num_calibration_samples', gt(0), "An integer greater-than 0", int) - init_method = hp('init_method', isin('uniform', 'normal'), 'One of "uniform" or "normal"', str) - init_scale = hp('init_scale', gt(0), 'A float greater-than 0', float) - init_sigma = hp('init_sigma', gt(0), 'A float greater-than 0', float) - init_bias = hp('init_bias', (), 'A number', float) - optimizer = hp('optimizer', isin('sgd', 'adam', 'rmsprop', 'auto'), 'One of "sgd", "adam", "rmsprop" or "auto', str) - loss = hp('loss', isin('logistic', 'squared_loss', 'absolute_loss', 'hinge_loss', 'eps_insensitive_squared_loss', - 'eps_insensitive_absolute_loss', 'quantile_loss', 'huber_loss', 'softmax_loss', 'auto'), - '"logistic", "squared_loss", "absolute_loss", "hinge_loss", "eps_insensitive_squared_loss", ' - '"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"', str) - wd = hp('wd', ge(0), 'A float greater-than or equal to 0', float) - l1 = hp('l1', ge(0), 'A float greater-than or equal to 0', float) - momentum = hp('momentum', (ge(0), lt(1)), 'A float in [0,1)', float) - learning_rate = hp('learning_rate', gt(0), 'A float greater-than 0', float) - beta_1 = hp('beta_1', (ge(0), lt(1)), 'A float in [0,1)', float) - beta_2 = hp('beta_2', (ge(0), lt(1)), 'A float in [0,1)', float) - bias_lr_mult = hp('bias_lr_mult', gt(0), 'A float greater-than 0', float) - bias_wd_mult = hp('bias_wd_mult', ge(0), 'A float greater-than or equal to 0', float) - use_lr_scheduler = hp('use_lr_scheduler', (), 'A boolean', bool) - lr_scheduler_step = hp('lr_scheduler_step', gt(0), 'An integer greater-than 0', int) - lr_scheduler_factor = hp('lr_scheduler_factor', (gt(0), lt(1)), 'A float in (0,1)', float) - lr_scheduler_minimum_lr = hp('lr_scheduler_minimum_lr', gt(0), 'A float greater-than 0', float) - normalize_data = hp('normalize_data', (), 'A boolean', bool) - normalize_label = hp('normalize_label', (), 'A boolean', bool) - unbias_data = hp('unbias_data', (), 'A boolean', bool) - unbias_label = hp('unbias_label', (), 'A boolean', bool) - num_point_for_scaler = hp('num_point_for_scaler', gt(0), 'An integer greater-than 0', int) - margin = hp('margin', ge(0), 'A float greater-than or equal to 0', float) - quantile = hp('quantile', (gt(0), lt(1)), 'A float in (0,1)', float) - 
loss_insensitivity = hp('loss_insensitivity', gt(0), 'A float greater-than 0', float)
-    huber_delta = hp('huber_delta', ge(0), 'A float greater-than or equal to 0', float)
-    early_stopping_patience = hp('early_stopping_patience', gt(0), 'An integer greater-than 0', int)
-    early_stopping_tolerance = hp('early_stopping_tolerance', gt(0), 'A float greater-than 0', float)
-    num_classes = hp('num_classes', (gt(0), le(1000000)), 'An integer in [1,1000000]', int)
-    accuracy_top_k = hp('accuracy_top_k', (gt(0), le(1000000)), 'An integer in [1,1000000]', int)
-    f_beta = hp('f_beta', gt(0), 'A float greater-than 0', float)
-    balance_multiclass_weights = hp('balance_multiclass_weights', (), 'A boolean', bool)
+    binary_classifier_model_selection_criteria = hp(
+        "binary_classifier_model_selection_criteria",
+        isin(
+            "accuracy",
+            "f1",
+            "f_beta",
+            "precision_at_target_recall",
+            "recall_at_target_precision",
+            "cross_entropy_loss",
+            "loss_function",
+        ),
+        data_type=str,
+    )
+    target_recall = hp("target_recall", (gt(0), lt(1)), "A float in (0,1)", float)
+    target_precision = hp("target_precision", (gt(0), lt(1)), "A float in (0,1)", float)
+    positive_example_weight_mult = hp(
+        "positive_example_weight_mult", (), "A float greater than 0 or 'auto' or 'balanced'", str
+    )
+    epochs = hp("epochs", gt(0), "An integer greater-than 0", int)
+    predictor_type = hp(
+        "predictor_type",
+        isin("binary_classifier", "regressor", "multiclass_classifier"),
+        'One of "binary_classifier" or "multiclass_classifier" or "regressor"',
+        str,
+    )
+    use_bias = hp("use_bias", (), "Either True or False", bool)
+    num_models = hp("num_models", gt(0), "An integer greater-than 0", int)
+    num_calibration_samples = hp("num_calibration_samples", gt(0), "An integer greater-than 0", int)
+    init_method = hp("init_method", isin("uniform", "normal"), 'One of "uniform" or "normal"', str)
+    init_scale = hp("init_scale", gt(0), "A float greater-than 0", float)
+    init_sigma = hp("init_sigma", gt(0), "A float greater-than 0", float)
+    init_bias = hp("init_bias", (), "A number", float)
+    optimizer = hp(
+        "optimizer",
+        isin("sgd", "adam", "rmsprop", "auto"),
+        'One of "sgd", "adam", "rmsprop" or "auto"',
+        str,
+    )
+    loss = hp(
+        "loss",
+        isin(
+            "logistic",
+            "squared_loss",
+            "absolute_loss",
+            "hinge_loss",
+            "eps_insensitive_squared_loss",
+            "eps_insensitive_absolute_loss",
+            "quantile_loss",
+            "huber_loss",
+            "softmax_loss",
+            "auto",
+        ),
+        '"logistic", "squared_loss", "absolute_loss", "hinge_loss", "eps_insensitive_squared_loss", '
+        '"eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"',
+        str,
+    )
+    wd = hp("wd", ge(0), "A float greater-than or equal to 0", float)
+    l1 = hp("l1", ge(0), "A float greater-than or equal to 0", float)
+    momentum = hp("momentum", (ge(0), lt(1)), "A float in [0,1)", float)
+    learning_rate = hp("learning_rate", gt(0), "A float greater-than 0", float)
+    beta_1 = hp("beta_1", (ge(0), lt(1)), "A float in [0,1)", float)
+    beta_2 = hp("beta_2", (ge(0), lt(1)), "A float in [0,1)", float)
+    bias_lr_mult = hp("bias_lr_mult", gt(0), "A float greater-than 0", float)
+    bias_wd_mult = hp("bias_wd_mult", ge(0), "A float greater-than or equal to 0", float)
+    use_lr_scheduler = hp("use_lr_scheduler", (), "A boolean", bool)
+    lr_scheduler_step = hp("lr_scheduler_step", gt(0), "An integer greater-than 0", int)
+    lr_scheduler_factor = hp("lr_scheduler_factor", (gt(0), lt(1)), "A float in (0,1)", float)
+    lr_scheduler_minimum_lr = hp("lr_scheduler_minimum_lr", gt(0), "A float
greater-than 0", float) + normalize_data = hp("normalize_data", (), "A boolean", bool) + normalize_label = hp("normalize_label", (), "A boolean", bool) + unbias_data = hp("unbias_data", (), "A boolean", bool) + unbias_label = hp("unbias_label", (), "A boolean", bool) + num_point_for_scaler = hp("num_point_for_scaler", gt(0), "An integer greater-than 0", int) + margin = hp("margin", ge(0), "A float greater-than or equal to 0", float) + quantile = hp("quantile", (gt(0), lt(1)), "A float in (0,1)", float) + loss_insensitivity = hp("loss_insensitivity", gt(0), "A float greater-than 0", float) + huber_delta = hp("huber_delta", ge(0), "A float greater-than or equal to 0", float) + early_stopping_patience = hp("early_stopping_patience", gt(0), "An integer greater-than 0", int) + early_stopping_tolerance = hp( + "early_stopping_tolerance", gt(0), "A float greater-than 0", float + ) + num_classes = hp("num_classes", (gt(0), le(1000000)), "An integer in [1,1000000]", int) + accuracy_top_k = hp("accuracy_top_k", (gt(0), le(1000000)), "An integer in [1,1000000]", int) + f_beta = hp("f_beta", gt(0), "A float greater-than 0", float) + balance_multiclass_weights = hp("balance_multiclass_weights", (), "A boolean", bool) - def __init__(self, role, train_instance_count, train_instance_type, predictor_type, - binary_classifier_model_selection_criteria=None, target_recall=None, target_precision=None, - positive_example_weight_mult=None, epochs=None, use_bias=None, num_models=None, - num_calibration_samples=None, init_method=None, init_scale=None, init_sigma=None, init_bias=None, - optimizer=None, loss=None, wd=None, l1=None, momentum=None, learning_rate=None, beta_1=None, - beta_2=None, bias_lr_mult=None, bias_wd_mult=None, use_lr_scheduler=None, lr_scheduler_step=None, - lr_scheduler_factor=None, lr_scheduler_minimum_lr=None, normalize_data=None, - normalize_label=None, unbias_data=None, unbias_label=None, num_point_for_scaler=None, margin=None, - quantile=None, loss_insensitivity=None, huber_delta=None, early_stopping_patience=None, - early_stopping_tolerance=None, num_classes=None, accuracy_top_k=None, f_beta=None, - balance_multiclass_weights=None, **kwargs): + def __init__( + self, + role, + train_instance_count, + train_instance_type, + predictor_type, + binary_classifier_model_selection_criteria=None, + target_recall=None, + target_precision=None, + positive_example_weight_mult=None, + epochs=None, + use_bias=None, + num_models=None, + num_calibration_samples=None, + init_method=None, + init_scale=None, + init_sigma=None, + init_bias=None, + optimizer=None, + loss=None, + wd=None, + l1=None, + momentum=None, + learning_rate=None, + beta_1=None, + beta_2=None, + bias_lr_mult=None, + bias_wd_mult=None, + use_lr_scheduler=None, + lr_scheduler_step=None, + lr_scheduler_factor=None, + lr_scheduler_minimum_lr=None, + normalize_data=None, + normalize_label=None, + unbias_data=None, + unbias_label=None, + num_point_for_scaler=None, + margin=None, + quantile=None, + loss_insensitivity=None, + huber_delta=None, + early_stopping_patience=None, + early_stopping_tolerance=None, + num_classes=None, + accuracy_top_k=None, + f_beta=None, + balance_multiclass_weights=None, + **kwargs + ): """An :class:`Estimator` for binary classification and regression. Amazon SageMaker Linear Learner provides a solution for both classification and regression problems, allowing @@ -199,7 +272,9 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty the loss function. 
Only used when predictor_type is multiclass_classifier. **kwargs: base class keyword argument values. """ - super(LinearLearner, self).__init__(role, train_instance_count, train_instance_type, **kwargs) + super(LinearLearner, self).__init__( + role, train_instance_count, train_instance_type, **kwargs + ) self.predictor_type = predictor_type self.binary_classifier_model_selection_criteria = binary_classifier_model_selection_criteria self.target_recall = target_recall @@ -243,9 +318,12 @@ def __init__(self, role, train_instance_count, train_instance_type, predictor_ty self.f_beta = f_beta self.balance_multiclass_weights = balance_multiclass_weights - if self.predictor_type == 'multiclass_classifier' and (num_classes is None or num_classes < 3): + if self.predictor_type == "multiclass_classifier" and ( + num_classes is None or num_classes < 3 + ): raise ValueError( - "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2.") + "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2." + ) def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): """Return a :class:`~sagemaker.amazon.LinearLearnerModel` referencing the latest @@ -257,26 +335,33 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return LinearLearnerModel(self.model_data, self.role, self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return LinearLearnerModel( + self.model_data, + self.role, + self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): num_records = None if isinstance(records, list): for record in records: - if record.channel == 'train': + if record.channel == "train": num_records = record.num_records break if num_records is None: - raise ValueError('Must provide train channel.') + raise ValueError("Must provide train channel.") else: num_records = records.num_records # mini_batch_size can't be greater than number of records or training job fails - default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE, - max(1, int(num_records / self.train_instance_count))) + default_mini_batch_size = min( + self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.train_instance_count)) + ) mini_batch_size = mini_batch_size or default_mini_batch_size - super(LinearLearner, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(LinearLearner, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class LinearLearnerPredictor(RealTimePredictor): @@ -292,9 +377,12 @@ class LinearLearnerPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(LinearLearnerPredictor, self).__init__(endpoint, sagemaker_session, - serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(LinearLearnerPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class LinearLearnerModel(Model): @@ -303,9 +391,13 @@ class LinearLearnerModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(LinearLearner.repo_name, 
LinearLearner.repo_version)
-        image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo)
-        super(LinearLearnerModel, self).__init__(model_data, image, role,
-                                                 predictor_cls=LinearLearnerPredictor,
-                                                 sagemaker_session=sagemaker_session,
-                                                 **kwargs)
+        repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
+        image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
+        super(LinearLearnerModel, self).__init__(
+            model_data,
+            image,
+            role,
+            predictor_cls=LinearLearnerPredictor,
+            sagemaker_session=sagemaker_session,
+            **kwargs
+        )
diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py
index 0fa11caeae..5e6faaae6a 100644
--- a/src/sagemaker/amazon/ntm.py
+++ b/src/sagemaker/amazon/ntm.py
@@ -24,29 +24,55 @@ class NTM(AmazonAlgorithmEstimatorBase):

-    repo_name = 'ntm'
+    repo_name = "ntm"
     repo_version = 1

-    num_topics = hp('num_topics', (ge(2), le(1000)), 'An integer in [2, 1000]', int)
-    encoder_layers = hp(name='encoder_layers', validation_message='A comma separated list of '
-                                                                  'positive integers', data_type=list)
-    epochs = hp('epochs', (ge(1), le(100)), 'An integer in [1, 100]', int)
-    encoder_layers_activation = hp('encoder_layers_activation', isin('sigmoid', 'tanh', 'relu'),
-                                   'One of "sigmoid", "tanh" or "relu"', str)
-    optimizer = hp('optimizer', isin('adagrad', 'adam', 'rmsprop', 'sgd', 'adadelta'),
-                   'One of "adagrad", "adam", "rmsprop", "sgd" and "adadelta"', str)
-    tolerance = hp('tolerance', (ge(1e-6), le(0.1)), 'A float in [1e-6, 0.1]', float)
-    num_patience_epochs = hp('num_patience_epochs', (ge(1), le(10)), 'An integer in [1, 10]', int)
-    batch_norm = hp(name='batch_norm', validation_message='Value must be a boolean', data_type=bool)
-    rescale_gradient = hp('rescale_gradient', (ge(1e-3), le(1.0)), 'A float in [1e-3, 1.0]', float)
-    clip_gradient = hp('clip_gradient', ge(1e-3), 'A float greater equal to 1e-3', float)
-    weight_decay = hp('weight_decay', (ge(0.0), le(1.0)), 'A float in [0.0, 1.0]', float)
-    learning_rate = hp('learning_rate', (ge(1e-6), le(1.0)), 'A float in [1e-6, 1.0]', float)
-
-    def __init__(self, role, train_instance_count, train_instance_type, num_topics,
-                 encoder_layers=None, epochs=None, encoder_layers_activation=None, optimizer=None, tolerance=None,
-                 num_patience_epochs=None, batch_norm=None, rescale_gradient=None, clip_gradient=None,
-                 weight_decay=None, learning_rate=None, **kwargs):
+    num_topics = hp("num_topics", (ge(2), le(1000)), "An integer in [2, 1000]", int)
+    encoder_layers = hp(
+        name="encoder_layers",
+        validation_message="A comma separated list of " "positive integers",
+        data_type=list,
+    )
+    epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int)
+    encoder_layers_activation = hp(
+        "encoder_layers_activation",
+        isin("sigmoid", "tanh", "relu"),
+        'One of "sigmoid", "tanh" or "relu"',
+        str,
+    )
+    optimizer = hp(
+        "optimizer",
+        isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"),
+        'One of "adagrad", "adam", "rmsprop", "sgd" or "adadelta"',
+        str,
+    )
+    tolerance = hp("tolerance", (ge(1e-6), le(0.1)), "A float in [1e-6, 0.1]", float)
+    num_patience_epochs = hp("num_patience_epochs", (ge(1), le(10)), "An integer in [1, 10]", int)
+    batch_norm = hp(name="batch_norm", validation_message="Value must be a boolean", data_type=bool)
+    rescale_gradient = hp("rescale_gradient", (ge(1e-3), le(1.0)), "A float in [1e-3, 1.0]", float)
+    clip_gradient = hp("clip_gradient", ge(1e-3), "A float greater than or equal to 1e-3", float)
+    weight_decay =
hp("weight_decay", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + learning_rate = hp("learning_rate", (ge(1e-6), le(1.0)), "A float in [1e-6, 1.0]", float) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + num_topics, + encoder_layers=None, + epochs=None, + encoder_layers_activation=None, + optimizer=None, + tolerance=None, + num_patience_epochs=None, + batch_norm=None, + rescale_gradient=None, + clip_gradient=None, + weight_decay=None, + learning_rate=None, + **kwargs + ): """Neural Topic Model (NTM) is :class:`Estimator` used for unsupervised learning. This Estimator may be fit via calls to @@ -118,13 +144,19 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return NTMModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size, job_name=None): if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 10000): raise ValueError("mini_batch_size must be in [1, 10000]") - super(NTM, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(NTM, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class NTMPredictor(RealTimePredictor): @@ -140,8 +172,12 @@ class NTMPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(NTMPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(NTMPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class NTMModel(Model): @@ -150,8 +186,15 @@ class NTMModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(NTM.repo_name, NTM.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo) - super(NTMModel, self).__init__(model_data, image, role, predictor_cls=NTMPredictor, - sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(NTM.repo_name, NTM.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo + ) + super(NTMModel, self).__init__( + model_data, + image, + role, + predictor_cls=NTMPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index 06dc965000..8aaad2bd47 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -28,7 +28,7 @@ def validate(value): if not isinstance(value, str): return False - val_list = [s.strip() for s in value.split(',')] + val_list = [s.strip() for s in value.split(",")] return set(val_list).issubset(valid_superset) return validate @@ -36,112 +36,134 @@ def validate(value): class Object2Vec(AmazonAlgorithmEstimatorBase): - repo_name = 'object2vec' + repo_name = "object2vec" repo_version = 1 MINI_BATCH_SIZE = 32 - enc_dim = hp('enc_dim', (ge(4), le(10000)), - 'An integer in [4, 
10000]', int) - mini_batch_size = hp('mini_batch_size', (ge(1), le(10000)), - 'An integer in [1, 10000]', int) - epochs = hp('epochs', (ge(1), le(100)), - 'An integer in [1, 100]', int) - early_stopping_patience = hp('early_stopping_patience', (ge(1), le(5)), - 'An integer in [1, 5]', int) - early_stopping_tolerance = hp('early_stopping_tolerance', (ge(1e-06), le(0.1)), - 'A float in [1e-06, 0.1]', float) - dropout = hp('dropout', (ge(0.0), le(1.0)), - 'A float in [0.0, 1.0]', float) - weight_decay = hp('weight_decay', (ge(0.0), le(10000.0)), - 'A float in [0.0, 10000.0]', float) - bucket_width = hp('bucket_width', (ge(0), le(100)), - 'An integer in [0, 100]', int) - num_classes = hp('num_classes', (ge(2), le(30)), - 'An integer in [2, 30]', int) - mlp_layers = hp('mlp_layers', (ge(1), le(10)), - 'An integer in [1, 10]', int) - mlp_dim = hp('mlp_dim', (ge(2), le(10000)), - 'An integer in [2, 10000]', int) - mlp_activation = hp('mlp_activation', isin("tanh", "relu", "linear"), - 'One of "tanh", "relu", "linear"', str) - output_layer = hp('output_layer', isin("softmax", "mean_squared_error"), - 'One of "softmax", "mean_squared_error"', str) - optimizer = hp('optimizer', isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), - 'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str) - learning_rate = hp('learning_rate', (ge(1e-06), le(1.0)), - 'A float in [1e-06, 1.0]', float) - - negative_sampling_rate = hp('negative_sampling_rate', (ge(0), le(100)), 'An integer in [0, 100]', int) - comparator_list = hp('comparator_list', _list_check_subset(["hadamard", "concat", "abs_diff"]), - 'Comma-separated of hadamard, concat, abs_diff. E.g. "hadamard,abs_diff"', str) - tied_token_embedding_weight = hp('tied_token_embedding_weight', (), 'Either True or False', bool) - token_embedding_storage_type = hp('token_embedding_storage_type', isin("dense", "row_sparse"), - 'One of "dense", "row_sparse"', str) - - enc0_network = hp('enc0_network', isin("hcnn", "bilstm", "pooled_embedding"), - 'One of "hcnn", "bilstm", "pooled_embedding"', str) - enc1_network = hp('enc1_network', isin("hcnn", "bilstm", "pooled_embedding", "enc0"), - 'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', str) - enc0_cnn_filter_width = hp('enc0_cnn_filter_width', (ge(1), le(9)), - 'An integer in [1, 9]', int) - enc1_cnn_filter_width = hp('enc1_cnn_filter_width', (ge(1), le(9)), - 'An integer in [1, 9]', int) - enc0_max_seq_len = hp('enc0_max_seq_len', (ge(1), le(5000)), - 'An integer in [1, 5000]', int) - enc1_max_seq_len = hp('enc1_max_seq_len', (ge(1), le(5000)), - 'An integer in [1, 5000]', int) - enc0_token_embedding_dim = hp('enc0_token_embedding_dim', (ge(2), le(1000)), - 'An integer in [2, 1000]', int) - enc1_token_embedding_dim = hp('enc1_token_embedding_dim', (ge(2), le(1000)), - 'An integer in [2, 1000]', int) - enc0_vocab_size = hp('enc0_vocab_size', (ge(2), le(3000000)), - 'An integer in [2, 3000000]', int) - enc1_vocab_size = hp('enc1_vocab_size', (ge(2), le(3000000)), - 'An integer in [2, 3000000]', int) - enc0_layers = hp('enc0_layers', (ge(1), le(4)), - 'An integer in [1, 4]', int) - enc1_layers = hp('enc1_layers', (ge(1), le(4)), - 'An integer in [1, 4]', int) - enc0_freeze_pretrained_embedding = hp('enc0_freeze_pretrained_embedding', (), - 'Either True or False', bool) - enc1_freeze_pretrained_embedding = hp('enc1_freeze_pretrained_embedding', (), - 'Either True or False', bool) - - def __init__(self, role, train_instance_count, train_instance_type, - epochs, - enc0_max_seq_len, - enc0_vocab_size, - 
enc_dim=None, - mini_batch_size=None, - early_stopping_patience=None, - early_stopping_tolerance=None, - dropout=None, - weight_decay=None, - bucket_width=None, - num_classes=None, - mlp_layers=None, - mlp_dim=None, - mlp_activation=None, - output_layer=None, - optimizer=None, - learning_rate=None, - negative_sampling_rate=None, - comparator_list=None, - tied_token_embedding_weight=None, - token_embedding_storage_type=None, - enc0_network=None, - enc1_network=None, - enc0_cnn_filter_width=None, - enc1_cnn_filter_width=None, - enc1_max_seq_len=None, - enc0_token_embedding_dim=None, - enc1_token_embedding_dim=None, - enc1_vocab_size=None, - enc0_layers=None, - enc1_layers=None, - enc0_freeze_pretrained_embedding=None, - enc1_freeze_pretrained_embedding=None, - **kwargs): + enc_dim = hp("enc_dim", (ge(4), le(10000)), "An integer in [4, 10000]", int) + mini_batch_size = hp("mini_batch_size", (ge(1), le(10000)), "An integer in [1, 10000]", int) + epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) + early_stopping_patience = hp( + "early_stopping_patience", (ge(1), le(5)), "An integer in [1, 5]", int + ) + early_stopping_tolerance = hp( + "early_stopping_tolerance", (ge(1e-06), le(0.1)), "A float in [1e-06, 0.1]", float + ) + dropout = hp("dropout", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + weight_decay = hp("weight_decay", (ge(0.0), le(10000.0)), "A float in [0.0, 10000.0]", float) + bucket_width = hp("bucket_width", (ge(0), le(100)), "An integer in [0, 100]", int) + num_classes = hp("num_classes", (ge(2), le(30)), "An integer in [2, 30]", int) + mlp_layers = hp("mlp_layers", (ge(1), le(10)), "An integer in [1, 10]", int) + mlp_dim = hp("mlp_dim", (ge(2), le(10000)), "An integer in [2, 10000]", int) + mlp_activation = hp( + "mlp_activation", isin("tanh", "relu", "linear"), 'One of "tanh", "relu", "linear"', str + ) + output_layer = hp( + "output_layer", + isin("softmax", "mean_squared_error"), + 'One of "softmax", "mean_squared_error"', + str, + ) + optimizer = hp( + "optimizer", + isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), + 'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', + str, + ) + learning_rate = hp("learning_rate", (ge(1e-06), le(1.0)), "A float in [1e-06, 1.0]", float) + + negative_sampling_rate = hp( + "negative_sampling_rate", (ge(0), le(100)), "An integer in [0, 100]", int + ) + comparator_list = hp( + "comparator_list", + _list_check_subset(["hadamard", "concat", "abs_diff"]), + 'Comma-separated of hadamard, concat, abs_diff. E.g. 
"hadamard,abs_diff"', + str, + ) + tied_token_embedding_weight = hp( + "tied_token_embedding_weight", (), "Either True or False", bool + ) + token_embedding_storage_type = hp( + "token_embedding_storage_type", + isin("dense", "row_sparse"), + 'One of "dense", "row_sparse"', + str, + ) + + enc0_network = hp( + "enc0_network", + isin("hcnn", "bilstm", "pooled_embedding"), + 'One of "hcnn", "bilstm", "pooled_embedding"', + str, + ) + enc1_network = hp( + "enc1_network", + isin("hcnn", "bilstm", "pooled_embedding", "enc0"), + 'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', + str, + ) + enc0_cnn_filter_width = hp("enc0_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) + enc1_cnn_filter_width = hp("enc1_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) + enc0_max_seq_len = hp("enc0_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc1_max_seq_len = hp("enc1_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc0_token_embedding_dim = hp( + "enc0_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int + ) + enc1_token_embedding_dim = hp( + "enc1_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int + ) + enc0_vocab_size = hp("enc0_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) + enc1_vocab_size = hp("enc1_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) + enc0_layers = hp("enc0_layers", (ge(1), le(4)), "An integer in [1, 4]", int) + enc1_layers = hp("enc1_layers", (ge(1), le(4)), "An integer in [1, 4]", int) + enc0_freeze_pretrained_embedding = hp( + "enc0_freeze_pretrained_embedding", (), "Either True or False", bool + ) + enc1_freeze_pretrained_embedding = hp( + "enc1_freeze_pretrained_embedding", (), "Either True or False", bool + ) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + epochs, + enc0_max_seq_len, + enc0_vocab_size, + enc_dim=None, + mini_batch_size=None, + early_stopping_patience=None, + early_stopping_tolerance=None, + dropout=None, + weight_decay=None, + bucket_width=None, + num_classes=None, + mlp_layers=None, + mlp_dim=None, + mlp_activation=None, + output_layer=None, + optimizer=None, + learning_rate=None, + negative_sampling_rate=None, + comparator_list=None, + tied_token_embedding_weight=None, + token_embedding_storage_type=None, + enc0_network=None, + enc1_network=None, + enc0_cnn_filter_width=None, + enc1_cnn_filter_width=None, + enc1_max_seq_len=None, + enc0_token_embedding_dim=None, + enc1_token_embedding_dim=None, + enc1_vocab_size=None, + enc0_layers=None, + enc1_layers=None, + enc0_freeze_pretrained_embedding=None, + enc1_freeze_pretrained_embedding=None, + **kwargs + ): """Object2Vec is :class:`Estimator` used for anomaly detection. This Estimator may be fit via calls to @@ -257,14 +279,20 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. 
""" - return Object2VecModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return Object2VecModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): if mini_batch_size is None: mini_batch_size = self.MINI_BATCH_SIZE - super(Object2Vec, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(Object2Vec, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class Object2VecModel(Model): @@ -273,10 +301,15 @@ class Object2VecModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(Object2Vec.repo_name, Object2Vec.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, - Object2Vec.repo_name), repo) - super(Object2VecModel, self).__init__(model_data, image, role, - predictor_cls=RealTimePredictor, - sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(Object2Vec.repo_name, Object2Vec.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo + ) + super(Object2VecModel, self).__init__( + model_data, + image, + role, + predictor_cls=RealTimePredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index b0e58f5f52..75db62e5a9 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -24,21 +24,38 @@ class PCA(AmazonAlgorithmEstimatorBase): - repo_name = 'pca' + repo_name = "pca" repo_version = 1 DEFAULT_MINI_BATCH_SIZE = 500 - num_components = hp('num_components', gt(0), 'Value must be an integer greater than zero', int) - algorithm_mode = hp('algorithm_mode', isin('regular', 'randomized'), - 'Value must be one of "regular" and "randomized"', str) - subtract_mean = hp(name='subtract_mean', validation_message='Value must be a boolean', data_type=bool) - extra_components = hp(name='extra_components', - validation_message="Value must be an integer greater than or equal to 0, or -1.", - data_type=int) - - def __init__(self, role, train_instance_count, train_instance_type, num_components, - algorithm_mode=None, subtract_mean=None, extra_components=None, **kwargs): + num_components = hp("num_components", gt(0), "Value must be an integer greater than zero", int) + algorithm_mode = hp( + "algorithm_mode", + isin("regular", "randomized"), + 'Value must be one of "regular" and "randomized"', + str, + ) + subtract_mean = hp( + name="subtract_mean", validation_message="Value must be a boolean", data_type=bool + ) + extra_components = hp( + name="extra_components", + validation_message="Value must be an integer greater than or equal to 0, or -1.", + data_type=int, + ) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + num_components, + algorithm_mode=None, + subtract_mean=None, + extra_components=None, + **kwargs + ): """A Principal Components Analysis (PCA) :class:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase`. This Estimator may be fit via calls to @@ -97,8 +114,12 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. 
* 'SecurityGroupIds' (list[str]): List of security group ids. """ - return PCAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return PCAModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): """Set hyperparameters needed for training. @@ -113,20 +134,23 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): num_records = None if isinstance(records, list): for record in records: - if record.channel == 'train': + if record.channel == "train": num_records = record.num_records break if num_records is None: - raise ValueError('Must provide train channel.') + raise ValueError("Must provide train channel.") else: num_records = records.num_records # mini_batch_size is a required parameter - default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE, - max(1, int(num_records / self.train_instance_count))) + default_mini_batch_size = min( + self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.train_instance_count)) + ) use_mini_batch_size = mini_batch_size or default_mini_batch_size - super(PCA, self)._prepare_for_training(records=records, mini_batch_size=use_mini_batch_size, job_name=job_name) + super(PCA, self)._prepare_for_training( + records=records, mini_batch_size=use_mini_batch_size, job_name=job_name + ) class PCAPredictor(RealTimePredictor): @@ -142,8 +166,12 @@ class PCAPredictor(RealTimePredictor): key of the ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(PCAPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(PCAPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class PCAModel(Model): @@ -152,8 +180,13 @@ class PCAModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(PCA.repo_name, PCA.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) - super(PCAModel, self).__init__(model_data, image, role, predictor_cls=PCAPredictor, - sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(PCA.repo_name, PCA.repo_version) + image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo) + super(PCAModel, self).__init__( + model_data, + image, + role, + predictor_cls=PCAPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index e4d11951e6..2d9e514553 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -24,20 +24,32 @@ class RandomCutForest(AmazonAlgorithmEstimatorBase): - repo_name = 'randomcutforest' + repo_name = "randomcutforest" repo_version = 1 MINI_BATCH_SIZE = 1000 - eval_metrics = hp(name='eval_metrics', - validation_message='A comma separated list of "accuracy" or "precision_recall_fscore"', - data_type=list) - - num_trees = hp('num_trees', (ge(50), le(1000)), 'An integer in [50, 1000]', int) - num_samples_per_tree = hp('num_samples_per_tree', (ge(1), le(2048)), 'An integer in [1, 2048]', int) - feature_dim = hp("feature_dim", (ge(1), le(10000)), 'An 
integer in [1, 10000]', int) - - def __init__(self, role, train_instance_count, train_instance_type, - num_samples_per_tree=None, num_trees=None, eval_metrics=None, **kwargs): + eval_metrics = hp( + name="eval_metrics", + validation_message='A comma separated list of "accuracy" or "precision_recall_fscore"', + data_type=list, + ) + + num_trees = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int) + num_samples_per_tree = hp( + "num_samples_per_tree", (ge(1), le(2048)), "An integer in [1, 2048]", int + ) + feature_dim = hp("feature_dim", (ge(1), le(10000)), "An integer in [1, 10000]", int) + + def __init__( + self, + role, + train_instance_count, + train_instance_type, + num_samples_per_tree=None, + num_trees=None, + eval_metrics=None, + **kwargs + ): """RandomCutForest is :class:`Estimator` used for anomaly detection. This Estimator may be fit via calls to @@ -77,7 +89,9 @@ def __init__(self, role, train_instance_count, train_instance_type, **kwargs: base class keyword argument values. """ - super(RandomCutForest, self).__init__(role, train_instance_count, train_instance_type, **kwargs) + super(RandomCutForest, self).__init__( + role, train_instance_count, train_instance_type, **kwargs + ) self.num_samples_per_tree = num_samples_per_tree self.num_trees = num_trees self.eval_metrics = eval_metrics @@ -92,16 +106,24 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT): * 'Subnets' (list[str]): List of subnet ids. * 'SecurityGroupIds' (list[str]): List of security group ids. """ - return RandomCutForestModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return RandomCutForestModel( + self.model_data, + self.role, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): if mini_batch_size is None: mini_batch_size = self.MINI_BATCH_SIZE elif mini_batch_size != self.MINI_BATCH_SIZE: - raise ValueError("Random Cut Forest uses a fixed mini_batch_size of {}".format(self.MINI_BATCH_SIZE)) + raise ValueError( + "Random Cut Forest uses a fixed mini_batch_size of {}".format(self.MINI_BATCH_SIZE) + ) - super(RandomCutForest, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name) + super(RandomCutForest, self)._prepare_for_training( + records, mini_batch_size=mini_batch_size, job_name=job_name + ) class RandomCutForestPredictor(RealTimePredictor): @@ -117,9 +139,12 @@ class RandomCutForestPredictor(RealTimePredictor): ``Record.label`` field.""" def __init__(self, endpoint, sagemaker_session=None): - super(RandomCutForestPredictor, self).__init__(endpoint, sagemaker_session, - serializer=numpy_to_record_serializer(), - deserializer=record_deserializer()) + super(RandomCutForestPredictor, self).__init__( + endpoint, + sagemaker_session, + serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), + ) class RandomCutForestModel(Model): @@ -128,10 +153,15 @@ class RandomCutForestModel(Model): def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session = sagemaker_session or Session() - repo = '{}:{}'.format(RandomCutForest.repo_name, RandomCutForest.repo_version) - image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, - RandomCutForest.repo_name), repo) - super(RandomCutForestModel, self).__init__(model_data, image, role, - predictor_cls=RandomCutForestPredictor, - 
sagemaker_session=sagemaker_session, - **kwargs) + repo = "{}:{}".format(RandomCutForest.repo_name, RandomCutForest.repo_version) + image = "{}/{}".format( + registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo + ) + super(RandomCutForestModel, self).__init__( + model_data, + image, + role, + predictor_cls=RandomCutForestPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/amazon/record_pb2.py b/src/sagemaker/amazon/record_pb2.py index 583dcd420b..cf4578c571 100644 --- a/src/sagemaker/amazon/record_pb2.py +++ b/src/sagemaker/amazon/record_pb2.py @@ -2,500 +2,793 @@ # source: record.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - DESCRIPTOR = _descriptor.FileDescriptor( - name='record.proto', - package='aialgs.data', - syntax='proto2', - serialized_pb=_b('\n\x0crecord.proto\x12\x0b\x61ialgs.data\"H\n\rFloat32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x02\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01\"H\n\rFloat64Tensor\x12\x12\n\x06values\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01\"F\n\x0bInt32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x05\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01\",\n\x05\x42ytes\x12\r\n\x05value\x18\x01 \x03(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x02 \x01(\t\"\xd3\x01\n\x05Value\x12\x34\n\x0e\x66loat32_tensor\x18\x02 \x01(\x0b\x32\x1a.aialgs.data.Float32TensorH\x00\x12\x34\n\x0e\x66loat64_tensor\x18\x03 \x01(\x0b\x32\x1a.aialgs.data.Float64TensorH\x00\x12\x30\n\x0cint32_tensor\x18\x07 \x01(\x0b\x32\x18.aialgs.data.Int32TensorH\x00\x12#\n\x05\x62ytes\x18\t \x01(\x0b\x32\x12.aialgs.data.BytesH\x00\x42\x07\n\x05value\"\xa9\x02\n\x06Record\x12\x33\n\x08\x66\x65\x61tures\x18\x01 \x03(\x0b\x32!.aialgs.data.Record.FeaturesEntry\x12-\n\x05label\x18\x02 \x03(\x0b\x32\x1e.aialgs.data.Record.LabelEntry\x12\x0b\n\x03uid\x18\x03 \x01(\t\x12\x10\n\x08metadata\x18\x04 \x01(\t\x12\x15\n\rconfiguration\x18\x05 \x01(\t\x1a\x43\n\rFeaturesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x1a@\n\nLabelEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x42\x30\n com.amazonaws.aialgorithms.protoB\x0cRecordProtos') + name="record.proto", + package="aialgs.data", + syntax="proto2", + serialized_pb=_b( + '\n\x0crecord.proto\x12\x0b\x61ialgs.data"H\n\rFloat32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x02\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01"H\n\rFloat64Tensor\x12\x12\n\x06values\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 \x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01"F\n\x0bInt32Tensor\x12\x12\n\x06values\x18\x01 \x03(\x05\x42\x02\x10\x01\x12\x10\n\x04keys\x18\x02 
\x03(\x04\x42\x02\x10\x01\x12\x11\n\x05shape\x18\x03 \x03(\x04\x42\x02\x10\x01",\n\x05\x42ytes\x12\r\n\x05value\x18\x01 \x03(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x02 \x01(\t"\xd3\x01\n\x05Value\x12\x34\n\x0e\x66loat32_tensor\x18\x02 \x01(\x0b\x32\x1a.aialgs.data.Float32TensorH\x00\x12\x34\n\x0e\x66loat64_tensor\x18\x03 \x01(\x0b\x32\x1a.aialgs.data.Float64TensorH\x00\x12\x30\n\x0cint32_tensor\x18\x07 \x01(\x0b\x32\x18.aialgs.data.Int32TensorH\x00\x12#\n\x05\x62ytes\x18\t \x01(\x0b\x32\x12.aialgs.data.BytesH\x00\x42\x07\n\x05value"\xa9\x02\n\x06Record\x12\x33\n\x08\x66\x65\x61tures\x18\x01 \x03(\x0b\x32!.aialgs.data.Record.FeaturesEntry\x12-\n\x05label\x18\x02 \x03(\x0b\x32\x1e.aialgs.data.Record.LabelEntry\x12\x0b\n\x03uid\x18\x03 \x01(\t\x12\x10\n\x08metadata\x18\x04 \x01(\t\x12\x15\n\rconfiguration\x18\x05 \x01(\t\x1a\x43\n\rFeaturesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x1a@\n\nLabelEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.aialgs.data.Value:\x02\x38\x01\x42\x30\n com.amazonaws.aialgorithms.protoB\x0cRecordProtos' + ), ) - - _FLOAT32TENSOR = _descriptor.Descriptor( - name='Float32Tensor', - full_name='aialgs.data.Float32Tensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='aialgs.data.Float32Tensor.values', index=0, - number=1, type=2, cpp_type=6, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='keys', full_name='aialgs.data.Float32Tensor.keys', index=1, - number=2, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='aialgs.data.Float32Tensor.shape', index=2, - number=3, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=29, - serialized_end=101, + name="Float32Tensor", + full_name="aialgs.data.Float32Tensor", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="aialgs.data.Float32Tensor.values", + index=0, + number=1, + type=2, + cpp_type=6, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="keys", + full_name="aialgs.data.Float32Tensor.keys", + index=1, + number=2, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + 
extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="shape", + full_name="aialgs.data.Float32Tensor.shape", + index=2, + number=3, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=29, + serialized_end=101, ) _FLOAT64TENSOR = _descriptor.Descriptor( - name='Float64Tensor', - full_name='aialgs.data.Float64Tensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='aialgs.data.Float64Tensor.values', index=0, - number=1, type=1, cpp_type=5, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='keys', full_name='aialgs.data.Float64Tensor.keys', index=1, - number=2, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='aialgs.data.Float64Tensor.shape', index=2, - number=3, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=103, - serialized_end=175, + name="Float64Tensor", + full_name="aialgs.data.Float64Tensor", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="aialgs.data.Float64Tensor.values", + index=0, + number=1, + type=1, + cpp_type=5, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="keys", + full_name="aialgs.data.Float64Tensor.keys", + index=1, + number=2, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="shape", + full_name="aialgs.data.Float64Tensor.shape", + index=2, + number=3, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + 
is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=103, + serialized_end=175, ) _INT32TENSOR = _descriptor.Descriptor( - name='Int32Tensor', - full_name='aialgs.data.Int32Tensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='values', full_name='aialgs.data.Int32Tensor.values', index=0, - number=1, type=5, cpp_type=1, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='keys', full_name='aialgs.data.Int32Tensor.keys', index=1, - number=2, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='aialgs.data.Int32Tensor.shape', index=2, - number=3, type=4, cpp_type=4, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')), file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=177, - serialized_end=247, + name="Int32Tensor", + full_name="aialgs.data.Int32Tensor", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="values", + full_name="aialgs.data.Int32Tensor.values", + index=0, + number=1, + type=5, + cpp_type=1, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="keys", + full_name="aialgs.data.Int32Tensor.keys", + index=1, + number=2, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="shape", + full_name="aialgs.data.Int32Tensor.shape", + index=2, + number=3, + type=4, + cpp_type=4, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=177, + serialized_end=247, ) _BYTES = _descriptor.Descriptor( - name='Bytes', - 
full_name='aialgs.data.Bytes', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='aialgs.data.Bytes.value', index=0, - number=1, type=12, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='content_type', full_name='aialgs.data.Bytes.content_type', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=249, - serialized_end=293, + name="Bytes", + full_name="aialgs.data.Bytes", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="aialgs.data.Bytes.value", + index=0, + number=1, + type=12, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="content_type", + full_name="aialgs.data.Bytes.content_type", + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=249, + serialized_end=293, ) _VALUE = _descriptor.Descriptor( - name='Value', - full_name='aialgs.data.Value', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='float32_tensor', full_name='aialgs.data.Value.float32_tensor', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='float64_tensor', full_name='aialgs.data.Value.float64_tensor', index=1, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int32_tensor', full_name='aialgs.data.Value.int32_tensor', index=2, - number=7, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bytes', full_name='aialgs.data.Value.bytes', index=3, - number=9, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - 
], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='value', full_name='aialgs.data.Value.value', - index=0, containing_type=None, fields=[]), - ], - serialized_start=296, - serialized_end=507, + name="Value", + full_name="aialgs.data.Value", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="float32_tensor", + full_name="aialgs.data.Value.float32_tensor", + index=0, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="float64_tensor", + full_name="aialgs.data.Value.float64_tensor", + index=1, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="int32_tensor", + full_name="aialgs.data.Value.int32_tensor", + index=2, + number=7, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="bytes", + full_name="aialgs.data.Value.bytes", + index=3, + number=9, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="value", + full_name="aialgs.data.Value.value", + index=0, + containing_type=None, + fields=[], + ) + ], + serialized_start=296, + serialized_end=507, ) _RECORD_FEATURESENTRY = _descriptor.Descriptor( - name='FeaturesEntry', - full_name='aialgs.data.Record.FeaturesEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='aialgs.data.Record.FeaturesEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='aialgs.data.Record.FeaturesEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=674, - serialized_end=741, + name="FeaturesEntry", + full_name="aialgs.data.Record.FeaturesEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="aialgs.data.Record.FeaturesEntry.key", + 
index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="aialgs.data.Record.FeaturesEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=674, + serialized_end=741, ) _RECORD_LABELENTRY = _descriptor.Descriptor( - name='LabelEntry', - full_name='aialgs.data.Record.LabelEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='aialgs.data.Record.LabelEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='aialgs.data.Record.LabelEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=743, - serialized_end=807, + name="LabelEntry", + full_name="aialgs.data.Record.LabelEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="aialgs.data.Record.LabelEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="aialgs.data.Record.LabelEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=743, + serialized_end=807, ) _RECORD = _descriptor.Descriptor( - name='Record', - full_name='aialgs.data.Record', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='features', full_name='aialgs.data.Record.features', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, 
containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='label', full_name='aialgs.data.Record.label', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='uid', full_name='aialgs.data.Record.uid', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='aialgs.data.Record.metadata', index=3, - number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='configuration', full_name='aialgs.data.Record.configuration', index=4, - number=5, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_RECORD_FEATURESENTRY, _RECORD_LABELENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=510, - serialized_end=807, + name="Record", + full_name="aialgs.data.Record", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="features", + full_name="aialgs.data.Record.features", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="label", + full_name="aialgs.data.Record.label", + index=1, + number=2, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="uid", + full_name="aialgs.data.Record.uid", + index=2, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="metadata", + full_name="aialgs.data.Record.metadata", + index=3, + number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="configuration", + full_name="aialgs.data.Record.configuration", + index=4, + number=5, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + 
extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[_RECORD_FEATURESENTRY, _RECORD_LABELENTRY], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto2", + extension_ranges=[], + oneofs=[], + serialized_start=510, + serialized_end=807, ) -_VALUE.fields_by_name['float32_tensor'].message_type = _FLOAT32TENSOR -_VALUE.fields_by_name['float64_tensor'].message_type = _FLOAT64TENSOR -_VALUE.fields_by_name['int32_tensor'].message_type = _INT32TENSOR -_VALUE.fields_by_name['bytes'].message_type = _BYTES -_VALUE.oneofs_by_name['value'].fields.append( - _VALUE.fields_by_name['float32_tensor']) -_VALUE.fields_by_name['float32_tensor'].containing_oneof = _VALUE.oneofs_by_name['value'] -_VALUE.oneofs_by_name['value'].fields.append( - _VALUE.fields_by_name['float64_tensor']) -_VALUE.fields_by_name['float64_tensor'].containing_oneof = _VALUE.oneofs_by_name['value'] -_VALUE.oneofs_by_name['value'].fields.append( - _VALUE.fields_by_name['int32_tensor']) -_VALUE.fields_by_name['int32_tensor'].containing_oneof = _VALUE.oneofs_by_name['value'] -_VALUE.oneofs_by_name['value'].fields.append( - _VALUE.fields_by_name['bytes']) -_VALUE.fields_by_name['bytes'].containing_oneof = _VALUE.oneofs_by_name['value'] -_RECORD_FEATURESENTRY.fields_by_name['value'].message_type = _VALUE +_VALUE.fields_by_name["float32_tensor"].message_type = _FLOAT32TENSOR +_VALUE.fields_by_name["float64_tensor"].message_type = _FLOAT64TENSOR +_VALUE.fields_by_name["int32_tensor"].message_type = _INT32TENSOR +_VALUE.fields_by_name["bytes"].message_type = _BYTES +_VALUE.oneofs_by_name["value"].fields.append(_VALUE.fields_by_name["float32_tensor"]) +_VALUE.fields_by_name["float32_tensor"].containing_oneof = _VALUE.oneofs_by_name["value"] +_VALUE.oneofs_by_name["value"].fields.append(_VALUE.fields_by_name["float64_tensor"]) +_VALUE.fields_by_name["float64_tensor"].containing_oneof = _VALUE.oneofs_by_name["value"] +_VALUE.oneofs_by_name["value"].fields.append(_VALUE.fields_by_name["int32_tensor"]) +_VALUE.fields_by_name["int32_tensor"].containing_oneof = _VALUE.oneofs_by_name["value"] +_VALUE.oneofs_by_name["value"].fields.append(_VALUE.fields_by_name["bytes"]) +_VALUE.fields_by_name["bytes"].containing_oneof = _VALUE.oneofs_by_name["value"] +_RECORD_FEATURESENTRY.fields_by_name["value"].message_type = _VALUE _RECORD_FEATURESENTRY.containing_type = _RECORD -_RECORD_LABELENTRY.fields_by_name['value'].message_type = _VALUE +_RECORD_LABELENTRY.fields_by_name["value"].message_type = _VALUE _RECORD_LABELENTRY.containing_type = _RECORD -_RECORD.fields_by_name['features'].message_type = _RECORD_FEATURESENTRY -_RECORD.fields_by_name['label'].message_type = _RECORD_LABELENTRY -DESCRIPTOR.message_types_by_name['Float32Tensor'] = _FLOAT32TENSOR -DESCRIPTOR.message_types_by_name['Float64Tensor'] = _FLOAT64TENSOR -DESCRIPTOR.message_types_by_name['Int32Tensor'] = _INT32TENSOR -DESCRIPTOR.message_types_by_name['Bytes'] = _BYTES -DESCRIPTOR.message_types_by_name['Value'] = _VALUE -DESCRIPTOR.message_types_by_name['Record'] = _RECORD +_RECORD.fields_by_name["features"].message_type = _RECORD_FEATURESENTRY +_RECORD.fields_by_name["label"].message_type = _RECORD_LABELENTRY +DESCRIPTOR.message_types_by_name["Float32Tensor"] = _FLOAT32TENSOR +DESCRIPTOR.message_types_by_name["Float64Tensor"] = _FLOAT64TENSOR +DESCRIPTOR.message_types_by_name["Int32Tensor"] = _INT32TENSOR +DESCRIPTOR.message_types_by_name["Bytes"] = _BYTES +DESCRIPTOR.message_types_by_name["Value"] = _VALUE 
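# --- Reviewer aside, not part of the patch: the descriptor wiring above is what
# makes the generated classes registered below behave as ordinary protobuf
# messages. A hedged round-trip sketch, assuming this module imports as
# sagemaker.amazon.record_pb2 (message and field names are taken from this file):
from sagemaker.amazon.record_pb2 import Record

record = Record()
# features is a map<string, Value>; indexing creates the Value entry, and
# mutating float32_tensor selects that arm of the Value oneof.
record.features["values"].float32_tensor.values.extend([1.0, 2.0, 3.0])
record.features["values"].float32_tensor.shape.extend([3])

payload = record.SerializeToString()  # wire-format bytes
parsed = Record()
parsed.ParseFromString(payload)
assert list(parsed.features["values"].float32_tensor.values) == [1.0, 2.0, 3.0]
# --- end aside ---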
+DESCRIPTOR.message_types_by_name["Record"] = _RECORD _sym_db.RegisterFileDescriptor(DESCRIPTOR) -Float32Tensor = _reflection.GeneratedProtocolMessageType('Float32Tensor', (_message.Message,), dict( - DESCRIPTOR = _FLOAT32TENSOR, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Float32Tensor) - )) +Float32Tensor = _reflection.GeneratedProtocolMessageType( + "Float32Tensor", + (_message.Message,), + dict( + DESCRIPTOR=_FLOAT32TENSOR, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Float32Tensor) + ), +) _sym_db.RegisterMessage(Float32Tensor) -Float64Tensor = _reflection.GeneratedProtocolMessageType('Float64Tensor', (_message.Message,), dict( - DESCRIPTOR = _FLOAT64TENSOR, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Float64Tensor) - )) +Float64Tensor = _reflection.GeneratedProtocolMessageType( + "Float64Tensor", + (_message.Message,), + dict( + DESCRIPTOR=_FLOAT64TENSOR, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Float64Tensor) + ), +) _sym_db.RegisterMessage(Float64Tensor) -Int32Tensor = _reflection.GeneratedProtocolMessageType('Int32Tensor', (_message.Message,), dict( - DESCRIPTOR = _INT32TENSOR, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Int32Tensor) - )) +Int32Tensor = _reflection.GeneratedProtocolMessageType( + "Int32Tensor", + (_message.Message,), + dict( + DESCRIPTOR=_INT32TENSOR, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Int32Tensor) + ), +) _sym_db.RegisterMessage(Int32Tensor) -Bytes = _reflection.GeneratedProtocolMessageType('Bytes', (_message.Message,), dict( - DESCRIPTOR = _BYTES, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Bytes) - )) +Bytes = _reflection.GeneratedProtocolMessageType( + "Bytes", + (_message.Message,), + dict( + DESCRIPTOR=_BYTES, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Bytes) + ), +) _sym_db.RegisterMessage(Bytes) -Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), dict( - DESCRIPTOR = _VALUE, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Value) - )) +Value = _reflection.GeneratedProtocolMessageType( + "Value", + (_message.Message,), + dict( + DESCRIPTOR=_VALUE, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Value) + ), +) _sym_db.RegisterMessage(Value) -Record = _reflection.GeneratedProtocolMessageType('Record', (_message.Message,), dict( - - FeaturesEntry = _reflection.GeneratedProtocolMessageType('FeaturesEntry', (_message.Message,), dict( - DESCRIPTOR = _RECORD_FEATURESENTRY, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Record.FeaturesEntry) - )) - , - - LabelEntry = _reflection.GeneratedProtocolMessageType('LabelEntry', (_message.Message,), dict( - DESCRIPTOR = _RECORD_LABELENTRY, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Record.LabelEntry) - )) - , - DESCRIPTOR = _RECORD, - __module__ = 'record_pb2' - # @@protoc_insertion_point(class_scope:aialgs.data.Record) - )) +Record = _reflection.GeneratedProtocolMessageType( + "Record", + (_message.Message,), + dict( + FeaturesEntry=_reflection.GeneratedProtocolMessageType( + "FeaturesEntry", + (_message.Message,), + dict( + DESCRIPTOR=_RECORD_FEATURESENTRY, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Record.FeaturesEntry) + ), + ), 
+ LabelEntry=_reflection.GeneratedProtocolMessageType( + "LabelEntry", + (_message.Message,), + dict( + DESCRIPTOR=_RECORD_LABELENTRY, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Record.LabelEntry) + ), + ), + DESCRIPTOR=_RECORD, + __module__="record_pb2" + # @@protoc_insertion_point(class_scope:aialgs.data.Record) + ), +) _sym_db.RegisterMessage(Record) _sym_db.RegisterMessage(Record.FeaturesEntry) _sym_db.RegisterMessage(Record.LabelEntry) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n com.amazonaws.aialgorithms.protoB\014RecordProtos')) -_FLOAT32TENSOR.fields_by_name['values'].has_options = True -_FLOAT32TENSOR.fields_by_name['values']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_FLOAT32TENSOR.fields_by_name['keys'].has_options = True -_FLOAT32TENSOR.fields_by_name['keys']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_FLOAT32TENSOR.fields_by_name['shape'].has_options = True -_FLOAT32TENSOR.fields_by_name['shape']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_FLOAT64TENSOR.fields_by_name['values'].has_options = True -_FLOAT64TENSOR.fields_by_name['values']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_FLOAT64TENSOR.fields_by_name['keys'].has_options = True -_FLOAT64TENSOR.fields_by_name['keys']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_FLOAT64TENSOR.fields_by_name['shape'].has_options = True -_FLOAT64TENSOR.fields_by_name['shape']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_INT32TENSOR.fields_by_name['values'].has_options = True -_INT32TENSOR.fields_by_name['values']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_INT32TENSOR.fields_by_name['keys'].has_options = True -_INT32TENSOR.fields_by_name['keys']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) -_INT32TENSOR.fields_by_name['shape'].has_options = True -_INT32TENSOR.fields_by_name['shape']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) +DESCRIPTOR._options = _descriptor._ParseOptions( + descriptor_pb2.FileOptions(), _b("\n com.amazonaws.aialgorithms.protoB\014RecordProtos") +) +_FLOAT32TENSOR.fields_by_name["values"].has_options = True +_FLOAT32TENSOR.fields_by_name["values"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_FLOAT32TENSOR.fields_by_name["keys"].has_options = True +_FLOAT32TENSOR.fields_by_name["keys"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_FLOAT32TENSOR.fields_by_name["shape"].has_options = True +_FLOAT32TENSOR.fields_by_name["shape"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_FLOAT64TENSOR.fields_by_name["values"].has_options = True +_FLOAT64TENSOR.fields_by_name["values"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_FLOAT64TENSOR.fields_by_name["keys"].has_options = True +_FLOAT64TENSOR.fields_by_name["keys"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_FLOAT64TENSOR.fields_by_name["shape"].has_options = True +_FLOAT64TENSOR.fields_by_name["shape"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") 
+) +_INT32TENSOR.fields_by_name["values"].has_options = True +_INT32TENSOR.fields_by_name["values"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_INT32TENSOR.fields_by_name["keys"].has_options = True +_INT32TENSOR.fields_by_name["keys"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) +_INT32TENSOR.fields_by_name["shape"].has_options = True +_INT32TENSOR.fields_by_name["shape"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("\020\001") +) _RECORD_FEATURESENTRY.has_options = True -_RECORD_FEATURESENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_RECORD_FEATURESENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) _RECORD_LABELENTRY.has_options = True -_RECORD_LABELENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_RECORD_LABELENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/amazon/validation.py b/src/sagemaker/amazon/validation.py index baaae4ad36..c6a3291e7e 100644 --- a/src/sagemaker/amazon/validation.py +++ b/src/sagemaker/amazon/validation.py @@ -16,34 +16,40 @@ def gt(minimum): def validate(value): return value > minimum + return validate def ge(minimum): def validate(value): return value >= minimum + return validate def lt(maximum): def validate(value): return value < maximum + return validate def le(maximum): def validate(value): return value <= maximum + return validate def isin(*expected): def validate(value): return value in expected + return validate def istype(expected): def validate(value): return isinstance(value, expected) + return validate diff --git a/src/sagemaker/analytics.py b/src/sagemaker/analytics.py index 91e5ec013a..11a594873b 100644 --- a/src/sagemaker/analytics.py +++ b/src/sagemaker/analytics.py @@ -111,28 +111,31 @@ def _fetch_dataframe(self): hyperparameters, results, and metadata. This also includes a column to indicate if a training job was the best seen so far. """ + def reshape(training_summary): # Helper method to reshape a single training job summary into a dataframe record out = {} - for k, v in training_summary['TunedHyperParameters'].items(): + for k, v in training_summary["TunedHyperParameters"].items(): # Something (bokeh?) 
gets confused with ints so convert to float try: v = float(v) except (TypeError, ValueError): pass out[k] = v - out['TrainingJobName'] = training_summary['TrainingJobName'] - out['TrainingJobStatus'] = training_summary['TrainingJobStatus'] - out['FinalObjectiveValue'] = training_summary.get('FinalHyperParameterTuningJobObjectiveMetric', - {}).get('Value') - - start_time = training_summary.get('TrainingStartTime', None) - end_time = training_summary.get('TrainingEndTime', None) - out['TrainingStartTime'] = start_time - out['TrainingEndTime'] = end_time + out["TrainingJobName"] = training_summary["TrainingJobName"] + out["TrainingJobStatus"] = training_summary["TrainingJobStatus"] + out["FinalObjectiveValue"] = training_summary.get( + "FinalHyperParameterTuningJobObjectiveMetric", {} + ).get("Value") + + start_time = training_summary.get("TrainingStartTime", None) + end_time = training_summary.get("TrainingEndTime", None) + out["TrainingStartTime"] = start_time + out["TrainingEndTime"] = end_time if start_time and end_time: - out['TrainingElapsedTimeSeconds'] = (end_time - start_time).total_seconds() + out["TrainingElapsedTimeSeconds"] = (end_time - start_time).total_seconds() return out + # Run that helper over all the summaries. df = pd.DataFrame([reshape(tjs) for tjs in self.training_job_summaries()]) return df @@ -143,9 +146,11 @@ def tuning_ranges(self): The keys are the names of the hyperparameter, and the values are the ranges. """ out = {} - for _, ranges in self.description()['HyperParameterTuningJobConfig']['ParameterRanges'].items(): + for _, ranges in self.description()["HyperParameterTuningJobConfig"][ + "ParameterRanges" + ].items(): for param in ranges: - out[param['Name']] = param + out[param["Name"]] = param return out def description(self, force_refresh=False): @@ -185,11 +190,13 @@ def training_job_summaries(self, force_refresh=False): raw_result = self._sage_client.list_training_jobs_for_hyper_parameter_tuning_job( HyperParameterTuningJobName=self.name, MaxResults=100, **next_args ) - new_output = raw_result['TrainingJobSummaries'] + new_output = raw_result["TrainingJobSummaries"] output.extend(new_output) - logging.debug("Got %d more TrainingJobs. Total so far: %d" % (len(new_output), len(output))) - if ('NextToken' in raw_result) and (len(new_output) > 0): - next_args['NextToken'] = raw_result['NextToken'] + logging.debug( + "Got %d more TrainingJobs. Total so far: %d" % (len(new_output), len(output)) + ) + if ("NextToken" in raw_result) and (len(new_output) > 0): + next_args["NextToken"] = raw_result["NextToken"] else: break self._training_job_summaries = output @@ -200,10 +207,17 @@ class TrainingJobAnalytics(AnalyticsMetricsBase): """Fetch training curve data from CloudWatch Metrics for a specific training job. """ - CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs' - - def __init__(self, training_job_name, metric_names=None, sagemaker_session=None, - start_time=None, end_time=None, period=None): + CLOUDWATCH_NAMESPACE = "/aws/sagemaker/TrainingJobs" + + def __init__( + self, + training_job_name, + metric_names=None, + sagemaker_session=None, + start_time=None, + end_time=None, + period=None, + ): """Initialize a ``TrainingJobAnalytics`` instance. 
Args: @@ -216,7 +230,7 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None, """ sagemaker_session = sagemaker_session or Session() self._sage_client = sagemaker_session.sagemaker_client - self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch') + self._cloudwatch = sagemaker_session.boto_session.client("cloudwatch") self._training_job_name = training_job_name self._start_time = start_time self._end_time = end_time @@ -251,19 +265,17 @@ def _determine_timeinterval(self): covering the interval of the training job """ description = self._sage_client.describe_training_job(TrainingJobName=self.name) - start_time = self._start_time or description[u'TrainingStartTime'] # datetime object + start_time = self._start_time or description[u"TrainingStartTime"] # datetime object # Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs. # This results in logs being searched in the time range in which the correct log line was not present. # Example - Log time - 2018-10-22 08:25:55 # Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition) # CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log. end_time = self._end_time or description.get( - u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1) + u"TrainingEndTime", datetime.datetime.utcnow() + ) + datetime.timedelta(minutes=1) - return { - 'start_time': start_time, - 'end_time': end_time, - } + return {"start_time": start_time, "end_time": end_time} def _fetch_dataframe(self): for metric_name in self._metric_names: @@ -274,30 +286,25 @@ def _fetch_metric(self, metric_name): """Fetch all the values of a named metric, and add them to _data """ request = { - 'Namespace': self.CLOUDWATCH_NAMESPACE, - 'MetricName': metric_name, - 'Dimensions': [ - { - 'Name': 'TrainingJobName', - 'Value': self.name - } - ], - 'StartTime': self._time_interval['start_time'], - 'EndTime': self._time_interval['end_time'], - 'Period': self._period, - 'Statistics': ['Average'], + "Namespace": self.CLOUDWATCH_NAMESPACE, + "MetricName": metric_name, + "Dimensions": [{"Name": "TrainingJobName", "Value": self.name}], + "StartTime": self._time_interval["start_time"], + "EndTime": self._time_interval["end_time"], + "Period": self._period, + "Statistics": ["Average"], } - raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints'] + raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)["Datapoints"] if len(raw_cwm_data) == 0: logging.warning("Warning: No metrics called %s found" % metric_name) return # Process data: normalize to starting time, and sort. - base_time = min(raw_cwm_data, key=lambda pt: pt['Timestamp'])['Timestamp'] + base_time = min(raw_cwm_data, key=lambda pt: pt["Timestamp"])["Timestamp"] all_xy = [] for pt in raw_cwm_data: - y = pt['Average'] - x = (pt['Timestamp'] - base_time).total_seconds() + y = pt["Average"] + x = (pt["Timestamp"] - base_time).total_seconds() all_xy.append([x, y]) all_xy = sorted(all_xy, key=lambda x: x[0]) @@ -311,16 +318,18 @@ def _add_single_metric(self, timestamp, metric_name, value): """ # note that this method is built this way to make it possible to # support live-refreshing charts in Bokeh at some point in the future. 
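# --- Reviewer aside, not part of the patch: the one-minute end-time padding in
# _determine_timeinterval exists because CloudWatch truncates seconds from the
# end time. A standalone sketch of the same query shape as _fetch_metric; the
# job name and metric name are placeholders:
import datetime

import boto3

cloudwatch = boto3.client("cloudwatch")
now = datetime.datetime.utcnow()
datapoints = cloudwatch.get_metric_statistics(
    Namespace="/aws/sagemaker/TrainingJobs",
    MetricName="train:error",  # placeholder metric name
    Dimensions=[{"Name": "TrainingJobName", "Value": "my-training-job"}],
    StartTime=now - datetime.timedelta(hours=1),
    EndTime=now + datetime.timedelta(minutes=1),  # padded by 1 min, as in the diff
    Period=60,
    Statistics=["Average"],
)["Datapoints"]

# Normalize timestamps to seconds-from-start and sort, mirroring _fetch_metric.
if datapoints:
    base_time = min(pt["Timestamp"] for pt in datapoints)
    all_xy = sorted(
        ((pt["Timestamp"] - base_time).total_seconds(), pt["Average"]) for pt in datapoints
    )
# --- end aside ---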
- self._data['timestamp'].append(timestamp) - self._data['metric_name'].append(metric_name) - self._data['value'].append(value) + self._data["timestamp"].append(timestamp) + self._data["metric_name"].append(metric_name) + self._data["value"].append(value) def _metric_names_for_training_job(self): """Helper method to discover the metrics defined for a training job. """ - training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name) + training_description = self._sage_client.describe_training_job( + TrainingJobName=self._training_job_name + ) - metric_definitions = training_description['AlgorithmSpecification']['MetricDefinitions'] - metric_names = [md['Name'] for md in metric_definitions] + metric_definitions = training_description["AlgorithmSpecification"]["MetricDefinitions"] + metric_names = [md["Name"] for md in metric_definitions] return metric_names diff --git a/src/sagemaker/chainer/defaults.py b/src/sagemaker/chainer/defaults.py index 69163cc6eb..878f76747a 100644 --- a/src/sagemaker/chainer/defaults.py +++ b/src/sagemaker/chainer/defaults.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -CHAINER_VERSION = '4.1.0' +CHAINER_VERSION = "4.1.0" """Default Chainer version for when the framework version is not specified. This is no longer updated so as to not break existing workflows. """ diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 7db0725578..6c553aafa4 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -15,13 +15,17 @@ import logging from sagemaker.estimator import Framework -from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \ - python_deprecation_warning +from sagemaker.fw_utils import ( + framework_name_from_image, + framework_version_from_tag, + empty_framework_version_warning, + python_deprecation_warning, +) from sagemaker.chainer.defaults import CHAINER_VERSION from sagemaker.chainer.model import ChainerModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class Chainer(Framework): @@ -35,12 +39,23 @@ class Chainer(Framework): _process_slots_per_host = "sagemaker_process_slots_per_host" _additional_mpi_options = "sagemaker_additional_mpi_options" - LATEST_VERSION = '5.0.0' + LATEST_VERSION = "5.0.0" """The latest version of Chainer included in the SageMaker pre-built Docker images.""" - def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_per_host=None, - additional_mpi_options=None, source_dir=None, hyperparameters=None, py_version='py3', - framework_version=None, image_name=None, **kwargs): + def __init__( + self, + entry_point, + use_mpi=None, + num_processes=None, + process_slots_per_host=None, + additional_mpi_options=None, + source_dir=None, + hyperparameters=None, + py_version="py3", + framework_version=None, + image_name=None, + **kwargs + ): """ This ``Estimator`` executes an Chainer script in a managed Chainer execution environment, within a SageMaker Training Job. 
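# --- Reviewer aside, not part of the patch: the exploded __init__ signature
# above makes the Chainer-specific MPI knobs much easier to scan. A hedged
# construction sketch; the role ARN, script path, and S3 URI are placeholders:
from sagemaker.chainer import Chainer

estimator = Chainer(
    entry_point="train.py",  # placeholder training script
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    train_instance_count=2,
    train_instance_type="ml.p3.2xlarge",
    framework_version="5.0.0",
    py_version="py3",
    use_mpi=True,  # serialized as the sagemaker_use_mpi hyperparameter
    num_processes=2,
    hyperparameters={"epochs": 10},
)
# estimator.fit({"training": "s3://my-bucket/chainer/train"})  # placeholder URI
# --- end aside ---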
The managed Chainer environment is an Amazon-built Docker container that executes functions @@ -89,10 +104,11 @@ def __init__(self, entry_point, use_mpi=None, num_processes=None, process_slots_ logger.warning(empty_framework_version_warning(CHAINER_VERSION, self.LATEST_VERSION)) self.framework_version = framework_version or CHAINER_VERSION - super(Chainer, self).__init__(entry_point, source_dir, hyperparameters, - image_name=image_name, **kwargs) + super(Chainer, self).__init__( + entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -105,17 +121,21 @@ def hyperparameters(self): """Return hyperparameters used by your custom Chainer code during training.""" hyperparameters = super(Chainer, self).hyperparameters() - additional_hyperparameters = {Chainer._use_mpi: self.use_mpi, - Chainer._num_processes: self.num_processes, - Chainer._process_slots_per_host: self.process_slots_per_host, - Chainer._additional_mpi_options: self.additional_mpi_options} + additional_hyperparameters = { + Chainer._use_mpi: self.use_mpi, + Chainer._num_processes: self.num_processes, + Chainer._process_slots_per_host: self.process_slots_per_host, + Chainer._additional_mpi_options: self.additional_mpi_options, + } # remove unset keys. additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v} hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) return hyperparameters - def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): + def create_model( + self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT + ): """Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``. Args: @@ -133,13 +153,23 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override See :func:`~sagemaker.chainer.model.ChainerModel` for full details. 
""" role = role or self.role - return ChainerModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, - container_log_level=self.container_log_level, code_location=self.code_location, - py_version=self.py_version, framework_version=self.framework_version, - model_server_workers=model_server_workers, image=self.image_name, - sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies) + return ChainerModel( + self.model_data, + role, + self.entry_point, + source_dir=self._model_source_dir(), + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + name=self._current_job_name, + container_log_level=self.container_log_level, + code_location=self.code_location, + py_version=self.py_version, + framework_version=self.framework_version, + model_server_workers=model_server_workers, + image=self.image_name, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + dependencies=self.dependencies, + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -153,29 +183,39 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) - - for argument in [Chainer._use_mpi, Chainer._num_processes, Chainer._process_slots_per_host, - Chainer._additional_mpi_options]: - - value = init_params['hyperparameters'].pop(argument, None) + init_params = super(Chainer, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) + + for argument in [ + Chainer._use_mpi, + Chainer._num_processes, + Chainer._process_slots_per_host, + Chainer._additional_mpi_options, + ]: + + value = init_params["hyperparameters"].pop(argument, None) if value: - init_params[argument[len('sagemaker_'):]] = value + init_params[argument[len("sagemaker_") :]] = value - image_name = init_params.pop('image') + image_name = init_params.pop("image") framework, py_version, tag, _ = framework_name_from_image(image_name) if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. 
- init_params['image_name'] = image_name + init_params["image_name"] = image_name return init_params - init_params['py_version'] = py_version - init_params['framework_version'] = framework_version_from_tag(tag) + init_params["py_version"] = py_version + init_params["framework_version"] = framework_version_from_tag(tag) - training_job_name = init_params['base_job_name'] + training_job_name = init_params["base_job_name"] if framework != cls.__framework_name__: - raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name)) + raise ValueError( + "Training job: {} didn't use image for requested framework".format( + training_job_name + ) + ) return init_params diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index ddab2d33ac..2fd2c8a70e 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -20,7 +20,7 @@ from sagemaker.chainer.defaults import CHAINER_VERSION from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class ChainerPredictor(RealTimePredictor): @@ -38,16 +38,28 @@ def __init__(self, endpoint_name, sagemaker_session=None): Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. """ - super(ChainerPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer) + super(ChainerPredictor, self).__init__( + endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer + ) class ChainerModel(FrameworkModel): """An Chainer SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - __framework_name__ = 'chainer' - - def __init__(self, model_data, role, entry_point, image=None, py_version='py3', framework_version=CHAINER_VERSION, - predictor_cls=ChainerPredictor, model_server_workers=None, **kwargs): + __framework_name__ = "chainer" + + def __init__( + self, + model_data, + role, + entry_point, + image=None, + py_version="py3", + framework_version=CHAINER_VERSION, + predictor_cls=ChainerPredictor, + model_server_workers=None, + **kwargs + ): """Initialize an ChainerModel. Args: @@ -68,9 +80,10 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3', If None, server will use one worker per vCPU. **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer. 
""" - super(ChainerModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, - **kwargs) - if py_version == 'py2': + super(ChainerModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -91,8 +104,14 @@ def prepare_container_def(self, instance_type, accelerator_type=None): deploy_image = self.image if not deploy_image: region_name = self.sagemaker_session.boto_session.region_name - deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type, - self.framework_version, self.py_version, accelerator_type=accelerator_type) + deploy_image = create_image_uri( + region_name, + self.__framework_name__, + instance_type, + self.framework_version, + self.py_version, + accelerator_type=accelerator_type, + ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) diff --git a/src/sagemaker/cli/common.py b/src/sagemaker/cli/common.py index 346ed5b570..ed5599c93a 100644 --- a/src/sagemaker/cli/common.py +++ b/src/sagemaker/cli/common.py @@ -34,12 +34,12 @@ def __init__(self, args): self.script = args.script self.instance_type = args.instance_type self.instance_count = args.instance_count - self.environment = {k: v for k, v in (kv.split('=') for kv in args.env)} + self.environment = {k: v for k, v in (kv.split("=") for kv in args.env)} self.session = sagemaker.Session() def upload_model(self): - prefix = '{}/model'.format(self.endpoint_name) + prefix = "{}/model".format(self.endpoint_name) archive = self.create_model_archive(self.data) model_uri = self.session.upload_data(path=archive, bucket=self.bucket, key_prefix=prefix) @@ -50,14 +50,14 @@ def upload_model(self): @staticmethod def create_model_archive(src): if os.path.isdir(src): - arcname = '.' + arcname = "." 
else: arcname = os.path.basename(src) tmp = tempfile.mkdtemp() - archive = os.path.join(tmp, 'model.tar.gz') + archive = os.path.join(tmp, "model.tar.gz") - with tarfile.open(archive, mode='w:gz') as t: + with tarfile.open(archive, mode="w:gz") as t: t.add(src, arcname=arcname) return archive @@ -67,8 +67,9 @@ def create_model(self, model_url): def start(self): model_url = self.upload_model() model = self.create_model(model_url) - predictor = model.deploy(initial_instance_count=self.instance_count, - instance_type=self.instance_type) + predictor = model.deploy( + initial_instance_count=self.instance_count, instance_type=self.instance_type + ) return predictor @@ -91,12 +92,12 @@ def __init__(self, args): def load_hyperparameters(src): hp = {} if src and os.path.exists(src): - with open(src, 'r') as f: + with open(src, "r") as f: hp = json.load(f) return hp def upload_training_data(self): - prefix = '{}/data'.format(self.job_name) + prefix = "{}/data".format(self.job_name) data_url = self.session.upload_data(path=self.data, bucket=self.bucket, key_prefix=prefix) return data_url @@ -107,6 +108,9 @@ def start(self): data_url = self.upload_training_data() estimator = self.create_estimator() estimator.fit(data_url) - logger.debug('code location: {}'.format(estimator.uploaded_code.s3_prefix)) - logger.debug('model location: {}{}/output/model.tar.gz'.format(estimator.output_path, - estimator._current_job_name)) + logger.debug("code location: {}".format(estimator.uploaded_code.s3_prefix)) + logger.debug( + "model location: {}{}/output/model.tar.gz".format( + estimator.output_path, estimator._current_job_name + ) + ) diff --git a/src/sagemaker/cli/main.py b/src/sagemaker/cli/main.py index a466482bf8..3a61a324c7 100644 --- a/src/sagemaker/cli/main.py +++ b/src/sagemaker/cli/main.py @@ -22,78 +22,110 @@ logger = logging.getLogger(__name__) -DEFAULT_LOG_LEVEL = 'info' -DEFAULT_BOTOCORE_LOG_LEVEL = 'warning' +DEFAULT_LOG_LEVEL = "info" +DEFAULT_BOTOCORE_LOG_LEVEL = "warning" def parse_arguments(args): - parser = argparse.ArgumentParser(description='Launch SageMaker training jobs or hosting endpoints') + parser = argparse.ArgumentParser( + description="Launch SageMaker training jobs or hosting endpoints" + ) parser.set_defaults(func=lambda x: parser.print_usage()) # common args for training/hosting/all frameworks common_parser = argparse.ArgumentParser(add_help=False) - common_parser.add_argument('--role-name', help='SageMaker execution role name', type=str, required=True) - common_parser.add_argument('--data', help='path to training data or model files', type=str, default='./data') - common_parser.add_argument('--script', help='path to script', type=str, default='./script.py') - common_parser.add_argument('--job-name', help='job or endpoint name', type=str, default=None) - common_parser.add_argument('--bucket-name', help='S3 bucket for training/model data and script files', - type=str, default=None) - common_parser.add_argument('--python', help='python version', type=str, default='py2') - - instance_group = common_parser.add_argument_group('instance settings') - instance_group.add_argument('--instance-type', type=str, help='instance type', default='ml.m4.xlarge') - instance_group.add_argument('--instance-count', type=int, help='instance count', default=1) + common_parser.add_argument( + "--role-name", help="SageMaker execution role name", type=str, required=True + ) + common_parser.add_argument( + "--data", help="path to training data or model files", type=str, default="./data" + ) + 
common_parser.add_argument("--script", help="path to script", type=str, default="./script.py") + common_parser.add_argument("--job-name", help="job or endpoint name", type=str, default=None) + common_parser.add_argument( + "--bucket-name", + help="S3 bucket for training/model data and script files", + type=str, + default=None, + ) + common_parser.add_argument("--python", help="python version", type=str, default="py2") + + instance_group = common_parser.add_argument_group("instance settings") + instance_group.add_argument( + "--instance-type", type=str, help="instance type", default="ml.m4.xlarge" + ) + instance_group.add_argument("--instance-count", type=int, help="instance count", default=1) # common training args common_train_parser = argparse.ArgumentParser(add_help=False) - common_train_parser.add_argument('--hyperparameters', help='path to training hyperparameters file', - type=str, default='./hyperparameters.json') + common_train_parser.add_argument( + "--hyperparameters", + help="path to training hyperparameters file", + type=str, + default="./hyperparameters.json", + ) # common hosting args common_host_parser = argparse.ArgumentParser(add_help=False) - common_host_parser.add_argument('--env', help='hosting environment variable(s)', type=str, nargs='*', default=[]) + common_host_parser.add_argument( + "--env", help="hosting environment variable(s)", type=str, nargs="*", default=[] + ) subparsers = parser.add_subparsers() # framework/algo subcommands - mxnet_parser = subparsers.add_parser('mxnet', help='use MXNet', parents=[]) + mxnet_parser = subparsers.add_parser("mxnet", help="use MXNet", parents=[]) mxnet_subparsers = mxnet_parser.add_subparsers() - mxnet_train_parser = mxnet_subparsers.add_parser('train', - help='start a training job', - parents=[common_parser, common_train_parser]) + mxnet_train_parser = mxnet_subparsers.add_parser( + "train", help="start a training job", parents=[common_parser, common_train_parser] + ) mxnet_train_parser.set_defaults(func=sagemaker.cli.mxnet.train) - mxnet_host_parser = mxnet_subparsers.add_parser('host', - help='start a hosting endpoint', - parents=[common_parser, common_host_parser]) + mxnet_host_parser = mxnet_subparsers.add_parser( + "host", help="start a hosting endpoint", parents=[common_parser, common_host_parser] + ) mxnet_host_parser.set_defaults(func=sagemaker.cli.mxnet.host) - tensorflow_parser = subparsers.add_parser('tensorflow', help='use TensorFlow', parents=[]) + tensorflow_parser = subparsers.add_parser("tensorflow", help="use TensorFlow", parents=[]) tensorflow_subparsers = tensorflow_parser.add_subparsers() - tensorflow_train_parser = tensorflow_subparsers.add_parser('train', - help='start a training job', - parents=[common_parser, common_train_parser]) - tensorflow_train_parser.add_argument('--training-steps', - help='number of training steps (tensorflow only)', type=int, default=None) - tensorflow_train_parser.add_argument('--evaluation-steps', - help='number of evaluation steps (tensorflow only)', type=int, default=None) + tensorflow_train_parser = tensorflow_subparsers.add_parser( + "train", help="start a training job", parents=[common_parser, common_train_parser] + ) + tensorflow_train_parser.add_argument( + "--training-steps", + help="number of training steps (tensorflow only)", + type=int, + default=None, + ) + tensorflow_train_parser.add_argument( + "--evaluation-steps", + help="number of evaluation steps (tensorflow only)", + type=int, + default=None, + ) 
tensorflow_train_parser.set_defaults(func=sagemaker.cli.tensorflow.train) - tensorflow_host_parser = tensorflow_subparsers.add_parser('host', - help='start a hosting endpoint', - parents=[common_parser, common_host_parser]) + tensorflow_host_parser = tensorflow_subparsers.add_parser( + "host", help="start a hosting endpoint", parents=[common_parser, common_host_parser] + ) tensorflow_host_parser.set_defaults(func=sagemaker.cli.tensorflow.host) - log_group = parser.add_argument_group('optional log settings') - log_group.add_argument('--log-level', help='log level for this command', type=str, default=DEFAULT_LOG_LEVEL) - log_group.add_argument('--botocore-log-level', help='log level for botocore', type=str, - default=DEFAULT_BOTOCORE_LOG_LEVEL) + log_group = parser.add_argument_group("optional log settings") + log_group.add_argument( + "--log-level", help="log level for this command", type=str, default=DEFAULT_LOG_LEVEL + ) + log_group.add_argument( + "--botocore-log-level", + help="log level for botocore", + type=str, + default=DEFAULT_BOTOCORE_LOG_LEVEL, + ) return parser.parse_args(args) def configure_logging(args): - log_format = '%(asctime)s %(levelname)s %(name)s: %(message)s' + log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s" log_level = logging.getLevelName(args.log_level.upper()) logging.basicConfig(format=log_format, level=log_level) logging.getLogger("botocore").setLevel(args.botocore_log_level.upper()) @@ -102,9 +134,9 @@ def configure_logging(args): def main(): args = parse_arguments(sys.argv[1:]) configure_logging(args) - logger.debug('args: {}'.format(args)) + logger.debug("args: {}".format(args)) args.func(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/sagemaker/cli/mxnet.py b/src/sagemaker/cli/mxnet.py index 46905e272d..d989d2d01f 100644 --- a/src/sagemaker/cli/mxnet.py +++ b/src/sagemaker/cli/mxnet.py @@ -26,17 +26,27 @@ def host(args): class MXNetTrainCommand(TrainCommand): def create_estimator(self): from sagemaker.mxnet.estimator import MXNet - return MXNet(self.script, - role=self.role_name, - base_job_name=self.job_name, - train_instance_count=self.instance_count, - train_instance_type=self.instance_type, - hyperparameters=self.hyperparameters, - py_version=self.python) + + return MXNet( + self.script, + role=self.role_name, + base_job_name=self.job_name, + train_instance_count=self.instance_count, + train_instance_type=self.instance_type, + hyperparameters=self.hyperparameters, + py_version=self.python, + ) class MXNetHostCommand(HostCommand): def create_model(self, model_url): from sagemaker.mxnet.model import MXNetModel - return MXNetModel(model_data=model_url, role=self.role_name, entry_point=self.script, - py_version=self.python, name=self.endpoint_name, env=self.environment) + + return MXNetModel( + model_data=model_url, + role=self.role_name, + entry_point=self.script, + py_version=self.python, + name=self.endpoint_name, + env=self.environment, + ) diff --git a/src/sagemaker/cli/tensorflow.py b/src/sagemaker/cli/tensorflow.py index 5c68e949c6..9fbd7dfa1f 100644 --- a/src/sagemaker/cli/tensorflow.py +++ b/src/sagemaker/cli/tensorflow.py @@ -31,19 +31,29 @@ def __init__(self, args): def create_estimator(self): from sagemaker.tensorflow import TensorFlow - return TensorFlow(training_steps=self.training_steps, - evaluation_steps=self.evaluation_steps, - py_version=self.python, - entry_point=self.script, - role=self.role_name, - base_job_name=self.job_name, - train_instance_count=self.instance_count, - 
train_instance_type=self.instance_type, - hyperparameters=self.hyperparameters) + + return TensorFlow( + training_steps=self.training_steps, + evaluation_steps=self.evaluation_steps, + py_version=self.python, + entry_point=self.script, + role=self.role_name, + base_job_name=self.job_name, + train_instance_count=self.instance_count, + train_instance_type=self.instance_type, + hyperparameters=self.hyperparameters, + ) class TensorFlowHostCommand(HostCommand): def create_model(self, model_url): from sagemaker.tensorflow.model import TensorFlowModel - return TensorFlowModel(model_data=model_url, role=self.role_name, entry_point=self.script, - py_version=self.python, name=self.endpoint_name, env=self.environment) + + return TensorFlowModel( + model_data=model_url, + role=self.role_name, + entry_point=self.script, + py_version=self.python, + name=self.endpoint_name, + env=self.environment, + ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 899039d79a..a126f3786c 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -CONTENT_TYPE_JSON = 'application/json' -CONTENT_TYPE_CSV = 'text/csv' -CONTENT_TYPE_OCTET_STREAM = 'application/octet-stream' -CONTENT_TYPE_NPY = 'application/x-npy' +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_CSV = "text/csv" +CONTENT_TYPE_OCTET_STREAM = "application/octet-stream" +CONTENT_TYPE_NPY = "application/x-npy" diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 2bc770f4c7..413335984f 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -23,13 +23,24 @@ import sagemaker from sagemaker.analytics import TrainingJobAnalytics -from sagemaker.fw_utils import (create_image_uri, tar_and_upload_dir, parse_s3_url, UploadedCode, - validate_source_dir) +from sagemaker.fw_utils import ( + create_image_uri, + tar_and_upload_dir, + parse_s3_url, + UploadedCode, + validate_source_dir, +) from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.model import Model, NEO_ALLOWED_TARGET_INSTANCE_FAMILY, NEO_ALLOWED_FRAMEWORKS -from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME, - CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME) +from sagemaker.model import ( + SCRIPT_PARAM_NAME, + DIR_PARAM_NAME, + CLOUDWATCH_METRICS_PARAM_NAME, + CONTAINER_LOG_LEVEL_PARAM_NAME, + JOB_NAME_PARAM_NAME, + SAGEMAKER_REGION_PARAM_NAME, +) from sagemaker.predictor import RealTimePredictor from sagemaker.session import Session from sagemaker.session import s3_input @@ -48,11 +59,27 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): what hyperparameters to use, and how to create an appropriate predictor instance. 
""" - def __init__(self, role, train_instance_count, train_instance_type, - train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File', - output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None, - subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model', - metric_definitions=None, encrypt_inter_container_traffic=False): + def __init__( + self, + role, + train_instance_count, + train_instance_type, + train_volume_size=30, + train_volume_kms_key=None, + train_max_run=24 * 60 * 60, + input_mode="File", + output_path=None, + output_kms_key=None, + base_job_name=None, + sagemaker_session=None, + tags=None, + subnets=None, + security_group_ids=None, + model_uri=None, + model_channel_name="model", + metric_definitions=None, + encrypt_inter_container_traffic=False, + ): """Initialize an ``EstimatorBase`` instance. Args: @@ -118,10 +145,10 @@ def __init__(self, role, train_instance_count, train_instance_type, self.model_uri = model_uri self.model_channel_name = model_channel_name self.code_uri = None - self.code_channel_name = 'code' + self.code_channel_name = "code" - if self.train_instance_type in ('local', 'local_gpu'): - if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1: + if self.train_instance_type in ("local", "local_gpu"): + if self.train_instance_type == "local_gpu" and self.train_instance_count > 1: raise RuntimeError("Distributed Training in Local GPU is not supported") self.sagemaker_session = sagemaker_session or LocalSession() else: @@ -129,9 +156,12 @@ def __init__(self, role, train_instance_count, train_instance_type, self.base_job_name = base_job_name self._current_job_name = None - if (not self.sagemaker_session.local_mode - and output_path and output_path.startswith('file://')): - raise RuntimeError('file:// output paths are only supported in Local Mode') + if ( + not self.sagemaker_session.local_mode + and output_path + and output_path.startswith("file://") + ): + raise RuntimeError("file:// output paths are only supported in Local Mode") self.output_path = output_path self.output_kms_key = output_kms_key self.latest_training_job = None @@ -188,7 +218,7 @@ def _prepare_for_training(self, job_name=None): if self.base_job_name: base_name = self.base_job_name elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator): - base_name = self.algorithm_arn.split('/')[-1] # pylint: disable=no-member + base_name = self.algorithm_arn.split("/")[-1] # pylint: disable=no-member else: base_name = base_name_from_image(self.train_image()) @@ -197,11 +227,11 @@ def _prepare_for_training(self, job_name=None): # if output_path was specified we use it otherwise initialize here. # For Local Mode with local_code=True we don't need an explicit output_path if self.output_path is None: - local_code = get_config_value('local.local_code', self.sagemaker_session.config) + local_code = get_config_value("local.local_code", self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: - self.output_path = '' + self.output_path = "" else: - self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket()) + self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket()) def fit(self, inputs=None, wait=True, logs=True, job_name=None): """Train a model using the input training dataset. 
@@ -239,10 +269,19 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None): def _compilation_job_name(self): base_name = self.base_job_name or base_name_from_image(self.train_image()) - return name_from_base('compilation-' + base_name) - - def compile_model(self, target_instance_family, input_shape, output_path, framework=None, framework_version=None, - compile_max_run=5 * 60, tags=None, **kwargs): + return name_from_base("compilation-" + base_name) + + def compile_model( + self, + target_instance_family, + input_shape, + output_path, + framework=None, + framework_version=None, + compile_max_run=5 * 60, + tags=None, + **kwargs + ): """Compile a Neo model using the input model. Args: @@ -267,29 +306,35 @@ def compile_model(self, target_instance_family, input_shape, output_path, framew sagemaker.model.Model: A SageMaker ``Model`` object. See :func:`~sagemaker.model.Model` for full details. """ if target_instance_family not in NEO_ALLOWED_TARGET_INSTANCE_FAMILY: - raise ValueError("Please use valid target_instance_family," - "allowed values: {}".format(NEO_ALLOWED_TARGET_INSTANCE_FAMILY)) + raise ValueError( + "Please use valid target_instance_family," + "allowed values: {}".format(NEO_ALLOWED_TARGET_INSTANCE_FAMILY) + ) if framework and framework not in NEO_ALLOWED_FRAMEWORKS: - raise ValueError("Please use valid framework, allowed values: {}".format(NEO_ALLOWED_FRAMEWORKS)) + raise ValueError( + "Please use valid framework, allowed values: {}".format(NEO_ALLOWED_FRAMEWORKS) + ) if (framework is None) != (framework_version is None): raise ValueError("You should provide framework and framework_version at the same time.") model = self.create_model(**kwargs) - self._compiled_models[target_instance_family] = model.compile(target_instance_family, - input_shape, - output_path, - self.role, - tags, - self._compilation_job_name(), - compile_max_run, - framework=framework, - framework_version=framework_version) + self._compiled_models[target_instance_family] = model.compile( + target_instance_family, + input_shape, + output_path, + self.role, + tags, + self._compilation_job_name(), + compile_max_run, + framework=framework, + framework_version=framework_version, + ) return self._compiled_models[target_instance_family] @classmethod - def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='model'): + def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="model"): """Attach to an existing training job. 
Create an Estimator bound to an existing training job, each subclass is responsible to implement @@ -320,20 +365,34 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m """ sagemaker_session = sagemaker_session or Session() - job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) + job_details = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job_name + ) init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name) - tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=job_details['TrainingJobArn'])['Tags'] + tags = sagemaker_session.sagemaker_client.list_tags( + ResourceArn=job_details["TrainingJobArn"] + )["Tags"] init_params.update(tags=tags) estimator = cls(sagemaker_session=sagemaker_session, **init_params) - estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session, - job_name=init_params['base_job_name']) + estimator.latest_training_job = _TrainingJob( + sagemaker_session=sagemaker_session, job_name=init_params["base_job_name"] + ) estimator._current_job_name = estimator.latest_training_job.name estimator.latest_training_job.wait() return estimator - def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, - use_compiled_model=False, update_endpoint=False, wait=True, **kwargs): + def deploy( + self, + initial_instance_count, + instance_type, + accelerator_type=None, + endpoint_name=None, + use_compiled_model=False, + update_endpoint=False, + wait=True, + **kwargs + ): """Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object. More information: @@ -371,10 +430,12 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e endpoint_name = endpoint_name or self.latest_training_job.name self.deploy_instance_type = instance_type if use_compiled_model: - family = '_'.join(instance_type.split('.')[:-1]) + family = "_".join(instance_type.split(".")[:-1]) if family not in self._compiled_models: - raise ValueError("No compiled model for {}. " - "Please compile one with compile_model before deploying.".format(family)) + raise ValueError( + "No compiled model for {}. " + "Please compile one with compile_model before deploying.".format(family) + ) model = self._compiled_models[family] else: model = self.create_model(**kwargs) @@ -385,18 +446,24 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e endpoint_name=endpoint_name, update_endpoint=update_endpoint, tags=self.tags, - wait=wait) + wait=wait, + ) @property def model_data(self): """str: The model location in S3. Only set if Estimator has been ``fit()``.""" if self.latest_training_job is not None: model_uri = self.sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=self.latest_training_job.name)['ModelArtifacts']['S3ModelArtifacts'] + TrainingJobName=self.latest_training_job.name + )["ModelArtifacts"]["S3ModelArtifacts"] else: - logging.warning('No finished training job found associated with this estimator. Please make sure' - 'this estimator is only used for building workflow config') - model_uri = os.path.join(self.output_path, self._current_job_name, 'output', 'model.tar.gz') + logging.warning( + "No finished training job found associated with this estimator. 
Please make sure" + "this estimator is only used for building workflow config" + ) + model_uri = os.path.join( + self.output_path, self._current_job_name, "output", "model.tar.gz" + ) return model_uri @@ -425,45 +492,50 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na """ init_params = dict() - init_params['role'] = job_details['RoleArn'] - init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount'] - init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType'] - init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB'] - init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds'] - init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode'] - init_params['base_job_name'] = job_details['TrainingJobName'] - init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath'] - init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId'] - - has_hps = 'HyperParameters' in job_details - init_params['hyperparameters'] = job_details['HyperParameters'] if has_hps else {} - - if 'TrainingImage' in job_details['AlgorithmSpecification']: - init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage'] - elif 'AlgorithmName' in job_details['AlgorithmSpecification']: - init_params['algorithm_arn'] = job_details['AlgorithmSpecification']['AlgorithmName'] + init_params["role"] = job_details["RoleArn"] + init_params["train_instance_count"] = job_details["ResourceConfig"]["InstanceCount"] + init_params["train_instance_type"] = job_details["ResourceConfig"]["InstanceType"] + init_params["train_volume_size"] = job_details["ResourceConfig"]["VolumeSizeInGB"] + init_params["train_max_run"] = job_details["StoppingCondition"]["MaxRuntimeInSeconds"] + init_params["input_mode"] = job_details["AlgorithmSpecification"]["TrainingInputMode"] + init_params["base_job_name"] = job_details["TrainingJobName"] + init_params["output_path"] = job_details["OutputDataConfig"]["S3OutputPath"] + init_params["output_kms_key"] = job_details["OutputDataConfig"]["KmsKeyId"] + + has_hps = "HyperParameters" in job_details + init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {} + + if "TrainingImage" in job_details["AlgorithmSpecification"]: + init_params["image"] = job_details["AlgorithmSpecification"]["TrainingImage"] + elif "AlgorithmName" in job_details["AlgorithmSpecification"]: + init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"] else: - raise RuntimeError('Invalid AlgorithmSpecification. Either TrainingImage or ' - 'AlgorithmName is expected. None was found.') + raise RuntimeError( + "Invalid AlgorithmSpecification. Either TrainingImage or " + "AlgorithmName is expected. None was found." 
+ ) - if 'MetricDefinitons' in job_details['AlgorithmSpecification']: - init_params['metric_definitions'] = job_details['AlgorithmSpecification']['MetricsDefinition'] + if "MetricDefinitons" in job_details["AlgorithmSpecification"]: + init_params["metric_definitions"] = job_details["AlgorithmSpecification"][ + "MetricsDefinition" + ] - if 'EnableInterContainerTrafficEncryption' in job_details: - init_params['encrypt_inter_container_traffic'] = \ - job_details['EnableInterContainerTrafficEncryption'] + if "EnableInterContainerTrafficEncryption" in job_details: + init_params["encrypt_inter_container_traffic"] = job_details[ + "EnableInterContainerTrafficEncryption" + ] subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY)) if subnets: - init_params['subnets'] = subnets + init_params["subnets"] = subnets if security_group_ids: - init_params['security_group_ids'] = security_group_ids + init_params["security_group_ids"] = security_group_ids - if 'InputDataConfig' in job_details and model_channel_name: - for channel in job_details['InputDataConfig']: - if channel['ChannelName'] == model_channel_name: - init_params['model_channel_name'] = model_channel_name - init_params['model_uri'] = channel['DataSource']['S3DataSource']['S3Uri'] + if "InputDataConfig" in job_details and model_channel_name: + for channel in job_details["InputDataConfig"]: + if channel["ChannelName"] == model_channel_name: + init_params["model_channel_name"] = model_channel_name + init_params["model_uri"] = channel["DataSource"]["S3DataSource"]["S3Uri"] break return init_params @@ -474,12 +546,25 @@ def delete_endpoint(self): Raises: ValueError: If the endpoint does not exist. """ - self._ensure_latest_training_job(error_message='Endpoint was not created yet') + self._ensure_latest_training_job(error_message="Endpoint was not created yet") self.sagemaker_session.delete_endpoint(self.latest_training_job.name) - def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None, volume_kms_key=None): + def transformer( + self, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + role=None, + volume_kms_key=None, + ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -507,26 +592,43 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit tags = tags or self.tags if self.latest_training_job is not None: - model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role, - tags=tags) + model_name = self.sagemaker_session.create_model_from_job( + self.latest_training_job.name, role=role, tags=tags + ) else: - logging.warning('No finished training job found associated with this estimator. Please make sure' - 'this estimator is only used for building workflow config') + logging.warning( + "No finished training job found associated with this estimator. 
Please make sure" + "this estimator is only used for building workflow config" + ) model_name = self._current_job_name - return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, - output_path=output_path, output_kms_key=output_kms_key, accept=accept, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=env, tags=tags, base_transform_job_name=self.base_job_name, - volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) + return Transformer( + model_name, + instance_count, + instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=env, + tags=tags, + base_transform_job_name=self.base_job_name, + volume_kms_key=volume_kms_key, + sagemaker_session=self.sagemaker_session, + ) @property def training_job_analytics(self): """Return a ``TrainingJobAnalytics`` object for the current training job. """ if self._current_job_name is None: - raise ValueError('Estimator is not associated with a TrainingJob') - return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session) + raise ValueError("Estimator is not associated with a TrainingJob") + return TrainingJobAnalytics( + self._current_job_name, sagemaker_session=self.sagemaker_session + ) def get_vpc_config(self, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT): """ @@ -538,7 +640,9 @@ def get_vpc_config(self, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT): else: return vpc_utils.sanitize(vpc_config_override) - def _ensure_latest_training_job(self, error_message='Estimator is not associated with a training job'): + def _ensure_latest_training_job( + self, error_message="Estimator is not associated with a training job" + ): if self.latest_training_job is None: raise ValueError(error_message) @@ -563,7 +667,9 @@ def start_new(cls, estimator, inputs): # Allow file:// input only in local mode if cls._is_local_channel(inputs) or cls._is_local_channel(model_uri): if not local_mode: - raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.') + raise ValueError( + "File URIs are supported in local mode only. Please use a S3 URI instead." + ) config = _Job._load_config(inputs, estimator) @@ -571,28 +677,31 @@ def start_new(cls, estimator, inputs): hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()} train_args = config.copy() - train_args['input_mode'] = estimator.input_mode - train_args['job_name'] = estimator._current_job_name - train_args['hyperparameters'] = hyperparameters - train_args['tags'] = estimator.tags - train_args['metric_definitions'] = estimator.metric_definitions + train_args["input_mode"] = estimator.input_mode + train_args["job_name"] = estimator._current_job_name + train_args["hyperparameters"] = hyperparameters + train_args["tags"] = estimator.tags + train_args["metric_definitions"] = estimator.metric_definitions if isinstance(inputs, s3_input): - if 'InputMode' in inputs.config: - logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.' 
- .format(inputs.config['InputMode'])) - train_args['input_mode'] = inputs.config['InputMode'] + if "InputMode" in inputs.config: + logging.debug( + "Selecting s3_input's input_mode ({}) for TrainingInputMode.".format( + inputs.config["InputMode"] + ) + ) + train_args["input_mode"] = inputs.config["InputMode"] if estimator.enable_network_isolation(): - train_args['enable_network_isolation'] = True + train_args["enable_network_isolation"] = True if estimator.encrypt_inter_container_traffic: - train_args['encrypt_inter_container_traffic'] = True + train_args["encrypt_inter_container_traffic"] = True if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator): - train_args['algorithm_arn'] = estimator.algorithm_arn + train_args["algorithm_arn"] = estimator.algorithm_arn else: - train_args['image'] = estimator.train_image() + train_args["image"] = estimator.train_image() estimator.sagemaker_session.train(**train_args) @@ -600,7 +709,7 @@ def start_new(cls, estimator, inputs): @classmethod def _is_local_channel(cls, input_uri): - return isinstance(input_uri, string_types) and input_uri.startswith('file://') + return isinstance(input_uri, string_types) and input_uri.startswith("file://") def wait(self, logs=True): if logs: @@ -615,12 +724,29 @@ class Estimator(EstimatorBase): algorithms that don't have their own, custom class. """ - def __init__(self, image_name, role, train_instance_count, train_instance_type, - train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, - input_mode='File', output_path=None, output_kms_key=None, base_job_name=None, - sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None, - model_uri=None, model_channel_name='model', metric_definitions=None, - encrypt_inter_container_traffic=False): + def __init__( + self, + image_name, + role, + train_instance_count, + train_instance_type, + train_volume_size=30, + train_volume_kms_key=None, + train_max_run=24 * 60 * 60, + input_mode="File", + output_path=None, + output_kms_key=None, + base_job_name=None, + sagemaker_session=None, + hyperparameters=None, + tags=None, + subnets=None, + security_group_ids=None, + model_uri=None, + model_channel_name="model", + metric_definitions=None, + encrypt_inter_container_traffic=False, + ): """Initialize an ``Estimator`` instance. 
Args: @@ -680,12 +806,26 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type, """ self.image_name = image_name self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {} - super(Estimator, self).__init__(role, train_instance_count, train_instance_type, - train_volume_size, train_volume_kms_key, train_max_run, input_mode, - output_path, output_kms_key, base_job_name, sagemaker_session, - tags, subnets, security_group_ids, model_uri=model_uri, - model_channel_name=model_channel_name, metric_definitions=metric_definitions, - encrypt_inter_container_traffic=encrypt_inter_container_traffic) + super(Estimator, self).__init__( + role, + train_instance_count, + train_instance_type, + train_volume_size, + train_volume_kms_key, + train_max_run, + input_mode, + output_path, + output_kms_key, + base_job_name, + sagemaker_session, + tags, + subnets, + security_group_ids, + model_uri=model_uri, + model_channel_name=model_channel_name, + metric_definitions=metric_definitions, + encrypt_inter_container_traffic=encrypt_inter_container_traffic, + ) def train_image(self): """ @@ -706,8 +846,18 @@ def hyperparameters(self): """ return self.hyperparam_dict - def create_model(self, role=None, image=None, predictor_cls=None, serializer=None, deserializer=None, - content_type=None, accept=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, **kwargs): + def create_model( + self, + role=None, + image=None, + predictor_cls=None, + serializer=None, + deserializer=None, + content_type=None, + accept=None, + vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + **kwargs + ): """ Create a model to deploy. @@ -735,15 +885,25 @@ def create_model(self, role=None, image=None, predictor_cls=None, serializer=Non Returns: a Model ready for deployment. 
""" if predictor_cls is None: + def predict_wrapper(endpoint, session): - return RealTimePredictor(endpoint, session, serializer, deserializer, content_type, accept) + return RealTimePredictor( + endpoint, session, serializer, deserializer, content_type, accept + ) + predictor_cls = predict_wrapper role = role or self.role - return Model(self.model_data, image or self.train_image(), role, - vpc_config=self.get_vpc_config(vpc_config_override), - sagemaker_session=self.sagemaker_session, predictor_cls=predictor_cls, **kwargs) + return Model( + self.model_data, + image or self.train_image(), + role, + vpc_config=self.get_vpc_config(vpc_config_override), + sagemaker_session=self.sagemaker_session, + predictor_cls=predictor_cls, + **kwargs + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -757,9 +917,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) + init_params = super(Estimator, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) - init_params['image_name'] = init_params.pop('image') + init_params["image_name"] = init_params.pop("image") return init_params @@ -771,15 +933,25 @@ class Framework(EstimatorBase): """ __framework_name__ = None - LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled' - LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled' - MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host' - MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options' - CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz' - - def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False, - container_log_level=logging.INFO, code_location=None, image_name=None, dependencies=None, - enable_network_isolation=False, **kwargs): + LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled" + LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled" + MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host" + MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options" + CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz" + + def __init__( + self, + entry_point, + source_dir=None, + hyperparameters=None, + enable_cloudwatch_metrics=False, + container_log_level=logging.INFO, + code_location=None, + image_name=None, + dependencies=None, + enable_network_isolation=False, + **kwargs + ): """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()`` Args: @@ -827,14 +999,20 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl **kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor. """ super(Framework, self).__init__(**kwargs) - if entry_point.startswith('s3://'): - raise ValueError('Invalid entry point script: {}. Must be a path to a local file.'.format(entry_point)) + if entry_point.startswith("s3://"): + raise ValueError( + "Invalid entry point script: {}. 
Must be a path to a local file.".format( + entry_point + ) + ) self.entry_point = entry_point self.source_dir = source_dir self.dependencies = dependencies or [] if enable_cloudwatch_metrics: - warnings.warn('enable_cloudwatch_metrics is now deprecated and will be removed in the future.', - DeprecationWarning) + warnings.warn( + "enable_cloudwatch_metrics is now deprecated and will be removed in the future.", + DeprecationWarning, + ) self.enable_cloudwatch_metrics = False self.container_log_level = container_log_level self.code_location = code_location @@ -862,19 +1040,19 @@ def _prepare_for_training(self, job_name=None): # validate source dir will raise a ValueError if there is something wrong with the # source directory. We are intentionally not handling it because this is a critical error. - if self.source_dir and not self.source_dir.lower().startswith('s3://'): + if self.source_dir and not self.source_dir.lower().startswith("s3://"): validate_source_dir(self.entry_point, self.source_dir) # if we are in local mode with local_code=True. We want the container to just # mount the source dir instead of uploading to S3. - local_code = get_config_value('local.local_code', self.sagemaker_session.config) + local_code = get_config_value("local.local_code", self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: # if there is no source dir, use the directory containing the entry point. if self.source_dir is None: self.source_dir = os.path.dirname(self.entry_point) self.entry_point = os.path.basename(self.entry_point) - code_dir = 'file://' + self.source_dir + code_dir = "file://" + self.source_dir script = self.entry_point elif self.enable_network_isolation() and self.entry_point: self.uploaded_code = self._stage_user_code_in_s3() @@ -900,31 +1078,33 @@ def _stage_user_code_in_s3(self): Returns: s3 uri """ - local_mode = self.output_path.startswith('file://') + local_mode = self.output_path.startswith("file://") if self.code_location is None and local_mode: code_bucket = self.sagemaker_session.default_bucket() - code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source') + code_s3_prefix = "{}/{}".format(self._current_job_name, "source") kms_key = None elif self.code_location is None: code_bucket, _ = parse_s3_url(self.output_path) - code_s3_prefix = '{}/{}'.format(self._current_job_name, 'source') + code_s3_prefix = "{}/{}".format(self._current_job_name, "source") kms_key = self.output_kms_key else: code_bucket, key_prefix = parse_s3_url(self.code_location) - code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source'])) + code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) output_bucket, _ = parse_s3_url(self.output_path) kms_key = self.output_kms_key if code_bucket == output_bucket else None - return tar_and_upload_dir(session=self.sagemaker_session.boto_session, - bucket=code_bucket, - s3_key_prefix=code_s3_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies, - kms_key=kms_key) + return tar_and_upload_dir( + session=self.sagemaker_session.boto_session, + bucket=code_bucket, + s3_key_prefix=code_s3_prefix, + script=self.entry_point, + directory=self.source_dir, + dependencies=self.dependencies, + kms_key=kms_key, + ) def _model_source_dir(self): """Get the appropriate value to pass as source_dir to model constructor on deploying @@ -932,7 +1112,9 @@ def _model_source_dir(self): Returns: str: Either a local or an S3 path pointing to the source_dir to be 
used for code by the model to be deployed """ - return self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix + return ( + self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix + ) def hyperparameters(self): """Return the hyperparameters as a dictionary to use for training. @@ -957,26 +1139,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) - - init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME)) - init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME)) - init_params['enable_cloudwatch_metrics'] = json.loads( - init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME)) - init_params['container_log_level'] = json.loads( - init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME)) + init_params = super(Framework, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) + + init_params["entry_point"] = json.loads( + init_params["hyperparameters"].get(SCRIPT_PARAM_NAME) + ) + init_params["source_dir"] = json.loads(init_params["hyperparameters"].get(DIR_PARAM_NAME)) + init_params["enable_cloudwatch_metrics"] = json.loads( + init_params["hyperparameters"].get(CLOUDWATCH_METRICS_PARAM_NAME) + ) + init_params["container_log_level"] = json.loads( + init_params["hyperparameters"].get(CONTAINER_LOG_LEVEL_PARAM_NAME) + ) hyperparameters = {} - for k, v in init_params['hyperparameters'].items(): + for k, v in init_params["hyperparameters"].items(): # Tuning jobs add this special hyperparameter which is not JSON serialized - if k == '_tuning_objective_metric': + if k == "_tuning_objective_metric": if v.startswith('"') and v.endswith('"'): v = v.strip('"') hyperparameters[k] = v else: hyperparameters[k] = json.loads(v) - init_params['hyperparameters'] = hyperparameters + init_params["hyperparameters"] = hyperparameters return init_params @@ -992,14 +1180,16 @@ def train_image(self): if self.image_name: return self.image_name else: - return create_image_uri(self.sagemaker_session.boto_region_name, - self.__framework_name__, - self.train_instance_type, - self.framework_version, # pylint: disable=no-member - py_version=self.py_version) # pylint: disable=no-member + return create_image_uri( + self.sagemaker_session.boto_region_name, + self.__framework_name__, + self.train_instance_type, + self.framework_version, # pylint: disable=no-member + py_version=self.py_version, # pylint: disable=no-member + ) @classmethod - def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='model'): + def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="model"): """Attach to an existing training job. Create an Estimator bound to an existing training job, each subclass is responsible to implement @@ -1028,12 +1218,15 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m Returns: Instance of the calling ``Estimator`` Class with the attached training job. 
""" - estimator = super(Framework, cls).attach(training_job_name, sagemaker_session, model_channel_name) + estimator = super(Framework, cls).attach( + training_job_name, sagemaker_session, model_channel_name + ) # pylint gets confused thinking that estimator is an EstimatorBase instance, but it actually # is a Framework or any of its derived classes. We can safely ignore the no-member errors. estimator.uploaded_code = UploadedCode( - estimator.source_dir, estimator.entry_point) # pylint: disable=no-member + estimator.source_dir, estimator.entry_point # pylint: disable=no-member + ) return estimator @staticmethod @@ -1050,9 +1243,23 @@ def _update_init_params(cls, hp, tf_arguments): updated_params[argument] = value return updated_params - def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None): + def transformer( + self, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + role=None, + model_server_workers=None, + volume_kms_key=None, + ): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -1085,34 +1292,50 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit model = self.create_model(role=role, model_server_workers=model_server_workers) container_def = model.prepare_container_def(instance_type) - model_name = model.name or name_from_image(container_def['Image']) + model_name = model.name or name_from_image(container_def["Image"]) vpc_config = model.vpc_config tags = tags or self.tags - self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags) + self.sagemaker_session.create_model( + model_name, role, container_def, vpc_config, tags=tags + ) transform_env = model.env.copy() if env is not None: transform_env.update(env) else: - logging.warning('No finished training job found associated with this estimator. Please make sure' - 'this estimator is only used for building workflow config') + logging.warning( + "No finished training job found associated with this estimator. 
Please make sure" + "this estimator is only used for building workflow config" + ) model_name = self._current_job_name transform_env = env or {} tags = tags or self.tags - return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, - output_path=output_path, output_kms_key=output_kms_key, accept=accept, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=transform_env, tags=tags, base_transform_job_name=self.base_job_name, - volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) + return Transformer( + model_name, + instance_count, + instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=transform_env, + tags=tags, + base_transform_job_name=self.base_job_name, + volume_kms_key=volume_kms_key, + sagemaker_session=self.sagemaker_session, + ) def _s3_uri_prefix(channel_name, s3_data): if isinstance(s3_data, s3_input): - s3_uri = s3_data.config['DataSource']['S3DataSource']['S3Uri'] + s3_uri = s3_data.config["DataSource"]["S3DataSource"]["S3Uri"] else: s3_uri = s3_data - if not s3_uri.startswith('s3://'): - raise ValueError('Expecting an s3 uri. Got {}'.format(s3_uri)) + if not s3_uri.startswith("s3://"): + raise ValueError("Expecting an s3 uri. Got {}".format(s3_uri)) return {channel_name: s3_uri[5:]} @@ -1126,8 +1349,12 @@ def _s3_uri_without_prefix_from_input(input_data): response.update(_s3_uri_prefix(channel_name, channel_s3_uri)) return response elif isinstance(input_data, str): - return _s3_uri_prefix('training', input_data) + return _s3_uri_prefix("training", input_data) elif isinstance(input_data, s3_input): - return _s3_uri_prefix('training', input_data) + return _s3_uri_prefix("training", input_data) else: - raise ValueError('Unrecognized type for S3 input data config - not str or s3_input: {}'.format(input_data)) + raise ValueError( + "Unrecognized type for S3 input data config - not str or s3_input: {}".format( + input_data + ) + ) diff --git a/src/sagemaker/fw_registry.py b/src/sagemaker/fw_registry.py index 9cab601998..19772a6fa6 100644 --- a/src/sagemaker/fw_registry.py +++ b/src/sagemaker/fw_registry.py @@ -16,66 +16,21 @@ from sagemaker.utils import get_ecr_image_uri_prefix image_registry_map = { - "us-west-1": { - "sparkml-serving": "746614075791", - "scikit-learn": "746614075791" - }, - "us-west-2": { - "sparkml-serving": "246618743249", - "scikit-learn": "246618743249" - }, - "us-east-1": { - "sparkml-serving": "683313688378", - "scikit-learn": "683313688378" - }, - "us-east-2": { - "sparkml-serving": "257758044811", - "scikit-learn": "257758044811" - }, - "ap-northeast-1": { - "sparkml-serving": "354813040037", - "scikit-learn": "354813040037" - }, - "ap-northeast-2": { - "sparkml-serving": "366743142698", - "scikit-learn": "366743142698" - }, - "ap-southeast-1": { - "sparkml-serving": "121021644041", - "scikit-learn": "121021644041" - }, - "ap-southeast-2": { - "sparkml-serving": "783357654285", - "scikit-learn": "783357654285" - }, - "ap-south-1": { - "sparkml-serving": "720646828776", - "scikit-learn": "720646828776" - }, - "eu-west-1": { - "sparkml-serving": "141502667606", - "scikit-learn": "141502667606" - }, - "eu-west-2": { - "sparkml-serving": "764974769150", - "scikit-learn": "764974769150" - }, - "eu-central-1": { - "sparkml-serving": "492215442770", - "scikit-learn": "492215442770" - }, 
- "ca-central-1": { - "sparkml-serving": "341280168497", - "scikit-learn": "341280168497" - }, - "us-gov-west-1": { - "sparkml-serving": "414596584902", - "scikit-learn": "414596584902" - }, - "us-iso-east-1": { - "sparkml-serving": "833128469047", - "scikit-learn": "833128469047" - } + "us-west-1": {"sparkml-serving": "746614075791", "scikit-learn": "746614075791"}, + "us-west-2": {"sparkml-serving": "246618743249", "scikit-learn": "246618743249"}, + "us-east-1": {"sparkml-serving": "683313688378", "scikit-learn": "683313688378"}, + "us-east-2": {"sparkml-serving": "257758044811", "scikit-learn": "257758044811"}, + "ap-northeast-1": {"sparkml-serving": "354813040037", "scikit-learn": "354813040037"}, + "ap-northeast-2": {"sparkml-serving": "366743142698", "scikit-learn": "366743142698"}, + "ap-southeast-1": {"sparkml-serving": "121021644041", "scikit-learn": "121021644041"}, + "ap-southeast-2": {"sparkml-serving": "783357654285", "scikit-learn": "783357654285"}, + "ap-south-1": {"sparkml-serving": "720646828776", "scikit-learn": "720646828776"}, + "eu-west-1": {"sparkml-serving": "141502667606", "scikit-learn": "141502667606"}, + "eu-west-2": {"sparkml-serving": "764974769150", "scikit-learn": "764974769150"}, + "eu-central-1": {"sparkml-serving": "492215442770", "scikit-learn": "492215442770"}, + "ca-central-1": {"sparkml-serving": "341280168497", "scikit-learn": "341280168497"}, + "us-gov-west-1": {"sparkml-serving": "414596584902", "scikit-learn": "414596584902"}, + "us-iso-east-1": {"sparkml-serving": "833128469047", "scikit-learn": "833128469047"}, } diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index f186fb8780..c05511ef70 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -23,35 +23,48 @@ from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN -_TAR_SOURCE_FILENAME = 'source.tar.gz' +_TAR_SOURCE_FILENAME = "source.tar.gz" -UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name']) +UploadedCode = namedtuple("UserCode", ["s3_prefix", "script_name"]) """sagemaker.fw_utils.UserCode: An object containing the S3 prefix and script name. This is for the source code used for the entry point with an ``Estimator``. It can be instantiated with positional or keyword arguments. """ -EMPTY_FRAMEWORK_VERSION_WARNING = 'No framework_version specified, defaulting to version {}.' -LATER_FRAMEWORK_VERSION_WARNING = 'This is not the latest supported version. ' \ - 'If you would like to use version {latest}, ' \ - 'please add framework_version={latest} to your constructor.' -PYTHON_2_DEPRECATION_WARNING = 'The Python 2 {framework} images will be soon deprecated and may not be ' \ - 'supported for newer upcoming versions of the {framework} images.\n' \ - 'Please set the argument \"py_version=\'py3\'\" to use the Python 3 {framework} image.' - - -EMPTY_FRAMEWORK_VERSION_ERROR = 'framework_version is required for script mode estimator. ' \ - 'Please add framework_version={} to your constructor to avoid this error.' - -VALID_PY_VERSIONS = ['py2', 'py3'] -VALID_EIA_FRAMEWORKS = ['tensorflow', 'tensorflow-serving', 'mxnet', 'mxnet-serving'] -VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436', - 'us-iso-east-1': '744548109606'} - - -def create_image_uri(region, framework, instance_type, framework_version, py_version=None, - account='520713654638', accelerator_type=None, optimized_families=None): +EMPTY_FRAMEWORK_VERSION_WARNING = "No framework_version specified, defaulting to version {}." 
+LATER_FRAMEWORK_VERSION_WARNING = ( + "This is not the latest supported version. " + "If you would like to use version {latest}, " + "please add framework_version={latest} to your constructor." +) +PYTHON_2_DEPRECATION_WARNING = ( + "The Python 2 {framework} images will be soon deprecated and may not be " + "supported for newer upcoming versions of the {framework} images.\n" + "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image." +) + + +EMPTY_FRAMEWORK_VERSION_ERROR = ( + "framework_version is required for script mode estimator. " + "Please add framework_version={} to your constructor to avoid this error." +) + +VALID_PY_VERSIONS = ["py2", "py3"] +VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"] +VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"} + + +def create_image_uri( + region, + framework, + instance_type, + framework_version, + py_version=None, + account="520713654638", + accelerator_type=None, + optimized_families=None, +): """Return the ECR URI of an image. Args: @@ -71,57 +84,69 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver optimized_families = optimized_families or [] if py_version and py_version not in VALID_PY_VERSIONS: - raise ValueError('invalid py_version argument: {}'.format(py_version)) + raise ValueError("invalid py_version argument: {}".format(py_version)) # Handle Account Number for Gov Cloud account = VALID_ACCOUNTS_BY_REGION.get(region, account) # Handle Local Mode - if instance_type.startswith('local'): - device_type = 'cpu' if instance_type == 'local' else 'gpu' - elif not instance_type.startswith('ml.'): - raise ValueError('{} is not a valid SageMaker instance type. See: ' - 'https://aws.amazon.com/sagemaker/pricing/instance-types/'.format(instance_type)) + if instance_type.startswith("local"): + device_type = "cpu" if instance_type == "local" else "gpu" + elif not instance_type.startswith("ml."): + raise ValueError( + "{} is not a valid SageMaker instance type. See: " + "https://aws.amazon.com/sagemaker/pricing/instance-types/".format(instance_type) + ) else: - family = instance_type.split('.')[1] + family = instance_type.split(".")[1] # For some frameworks, we have optimized images for specific families, e.g c5 or p3. In those cases, # we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'. 
if family in optimized_families: device_type = family - elif family[0] in ['g', 'p']: - device_type = 'gpu' + elif family[0] in ["g", "p"]: + device_type = "gpu" else: - device_type = 'cpu' + device_type = "cpu" if py_version: tag = "{}-{}-{}".format(framework_version, device_type, py_version) else: tag = "{}-{}".format(framework_version, device_type) - if _accelerator_type_valid_for_framework(framework=framework, accelerator_type=accelerator_type, - optimized_families=optimized_families): - framework += '-eia' + if _accelerator_type_valid_for_framework( + framework=framework, + accelerator_type=accelerator_type, + optimized_families=optimized_families, + ): + framework += "-eia" - return "{}/sagemaker-{}:{}" \ - .format(get_ecr_image_uri_prefix(account, region), framework, tag) + return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag) -def _accelerator_type_valid_for_framework(framework, accelerator_type=None, optimized_families=None): +def _accelerator_type_valid_for_framework( + framework, accelerator_type=None, optimized_families=None +): if accelerator_type is None: return False if framework not in VALID_EIA_FRAMEWORKS: - raise ValueError('{} is not supported with Amazon Elastic Inference. Currently only ' - 'Python-based TensorFlow and MXNet are supported.'.format(framework)) + raise ValueError( + "{} is not supported with Amazon Elastic Inference. Currently only " + "Python-based TensorFlow and MXNet are supported.".format(framework) + ) if optimized_families: - raise ValueError('Neo does not support Amazon Elastic Inference.') + raise ValueError("Neo does not support Amazon Elastic Inference.") - if not accelerator_type.startswith('ml.eia') and not accelerator_type == 'local_sagemaker_notebook': - raise ValueError('{} is not a valid SageMaker Elastic Inference accelerator type. ' - 'See: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html' - .format(accelerator_type)) + if ( + not accelerator_type.startswith("ml.eia") + and not accelerator_type == "local_sagemaker_notebook" + ): + raise ValueError( + "{} is not a valid SageMaker Elastic Inference accelerator type. " + "See: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html".format(accelerator_type) + ) return True @@ -138,13 +163,16 @@ def validate_source_dir(script, directory): """ if directory: if not os.path.isfile(os.path.join(directory, script)): - raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory)) + raise ValueError( + 'No file named "{}" was found in directory "{}".'.format(script, directory) + ) return True -def tar_and_upload_dir(session, bucket, s3_key_prefix, script, - directory=None, dependencies=None, kms_key=None): +def tar_and_upload_dir( + session, bucket, s3_key_prefix, script, directory=None, dependencies=None, kms_key=None +): """Package source files and upload a compress tar file to S3. The S3 location will be ``s3:///s3_key_prefix/sourcedir.tar.gz``. @@ -173,29 +201,30 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name. 
""" - if directory and directory.lower().startswith('s3://'): + if directory and directory.lower().startswith("s3://"): return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script)) script_name = script if directory else os.path.basename(script) dependencies = dependencies or [] - key = '%s/sourcedir.tar.gz' % s3_key_prefix + key = "%s/sourcedir.tar.gz" % s3_key_prefix tmp = tempfile.mkdtemp() try: source_files = _list_files_to_compress(script, directory) + dependencies - tar_file = sagemaker.utils.create_tar_file(source_files, - os.path.join(tmp, _TAR_SOURCE_FILENAME)) + tar_file = sagemaker.utils.create_tar_file( + source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME) + ) if kms_key: - extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key} + extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key} else: extra_args = None - session.resource('s3').Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args) + session.resource("s3").Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args) finally: shutil.rmtree(tmp) - return UploadedCode(s3_prefix='s3://%s/%s' % (bucket, key), script_name=script_name) + return UploadedCode(s3_prefix="s3://%s/%s" % (bucket, key), script_name=script_name) def _list_files_to_compress(script, directory): @@ -235,19 +264,24 @@ def framework_name_from_image(image_name): # extract framework, python version and image tag # We must support both the legacy and current image name format. name_pattern = re.compile( - r'^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$') # noqa - legacy_name_pattern = re.compile( - r'^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$') + r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 + ) + legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$") name_match = name_pattern.match(sagemaker_match.group(9)) legacy_match = legacy_name_pattern.match(sagemaker_match.group(9)) if name_match is not None: - fw, scriptmode, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3),\ - name_match.group(4), name_match.group(5) - return fw, py, '{}-{}-{}'.format(ver, device, py), scriptmode + fw, scriptmode, ver, device, py = ( + name_match.group(1), + name_match.group(2), + name_match.group(3), + name_match.group(4), + name_match.group(5), + ) + return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode elif legacy_match is not None: - return legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None + return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None) else: return None, None, None, None @@ -261,7 +295,7 @@ def framework_version_from_tag(image_tag): Returns: str: The framework version. 
""" - tag_pattern = re.compile('^(.*)-(cpu|gpu)-(py2|py3)$') + tag_pattern = re.compile("^(.*)-(cpu|gpu)-(py2|py3)$") tag_match = tag_pattern.match(image_tag) return None if tag_match is None else tag_match.group(1) @@ -280,7 +314,7 @@ def parse_s3_url(url): parsed_url = urlparse(url) if parsed_url.scheme != "s3": raise ValueError("Expecting 's3' scheme, got: {} in {}".format(parsed_url.scheme, url)) - return parsed_url.netloc, parsed_url.path.lstrip('/') + return parsed_url.netloc, parsed_url.path.lstrip("/") def model_code_key_prefix(code_location_key_prefix, model_name, image): @@ -299,14 +333,14 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image): str: the key prefix to be used in uploading code """ training_job_name = sagemaker.utils.name_from_image(image) - return '/'.join(filter(None, [code_location_key_prefix, model_name or training_job_name])) + return "/".join(filter(None, [code_location_key_prefix, model_name or training_job_name])) def empty_framework_version_warning(default_version, latest_version): msgs = [EMPTY_FRAMEWORK_VERSION_WARNING.format(default_version)] if default_version != latest_version: msgs.append(LATER_FRAMEWORK_VERSION_WARNING.format(latest=latest_version)) - return ' '.join(msgs) + return " ".join(msgs) def python_deprecation_warning(framework): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index bac1b78b3a..ee2658fe02 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -51,36 +51,50 @@ def wait(self): @staticmethod def _load_config(inputs, estimator, expand_role=True, validate_uri=True): input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) - role = estimator.sagemaker_session.expand_role(estimator.role) if expand_role else estimator.role + role = ( + estimator.sagemaker_session.expand_role(estimator.role) + if expand_role + else estimator.role + ) output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key) - resource_config = _Job._prepare_resource_config(estimator.train_instance_count, - estimator.train_instance_type, - estimator.train_volume_size, - estimator.train_volume_kms_key) + resource_config = _Job._prepare_resource_config( + estimator.train_instance_count, + estimator.train_instance_type, + estimator.train_volume_size, + estimator.train_volume_kms_key, + ) stop_condition = _Job._prepare_stop_condition(estimator.train_max_run) vpc_config = estimator.get_vpc_config() - model_channel = _Job._prepare_channel(input_config, estimator.model_uri, estimator.model_channel_name, - validate_uri, content_type='application/x-sagemaker-model', - input_mode='File') + model_channel = _Job._prepare_channel( + input_config, + estimator.model_uri, + estimator.model_channel_name, + validate_uri, + content_type="application/x-sagemaker-model", + input_mode="File", + ) if model_channel: input_config = [] if input_config is None else input_config input_config.append(model_channel) if estimator.enable_network_isolation(): - code_channel = _Job._prepare_channel(input_config, estimator.code_uri, estimator.code_channel_name, - validate_uri) + code_channel = _Job._prepare_channel( + input_config, estimator.code_uri, estimator.code_channel_name, validate_uri + ) if code_channel: input_config = [] if input_config is None else input_config input_config.append(code_channel) - return {'input_config': input_config, - 'role': role, - 'output_config': output_config, - 'resource_config': resource_config, - 'stop_condition': stop_condition, - 'vpc_config': vpc_config} + return { + 
"input_config": input_config, + "role": role, + "output_config": output_config, + "resource_config": resource_config, + "stop_condition": stop_condition, + "vpc_config": vpc_config, + } @staticmethod def _format_inputs_to_input_config(inputs, validate_uri=True): @@ -89,16 +103,17 @@ def _format_inputs_to_input_config(inputs, validate_uri=True): # Deferred import due to circular dependency from sagemaker.amazon.amazon_estimator import RecordSet + if isinstance(inputs, RecordSet): inputs = inputs.data_channel() input_dict = {} if isinstance(inputs, string_types): - input_dict['training'] = _Job._format_string_uri_input(inputs, validate_uri) + input_dict["training"] = _Job._format_string_uri_input(inputs, validate_uri) elif isinstance(inputs, s3_input): - input_dict['training'] = inputs + input_dict["training"] = inputs elif isinstance(inputs, file_input): - input_dict['training'] = inputs + input_dict["training"] = inputs elif isinstance(inputs, dict): for k, v in inputs.items(): input_dict[k] = _Job._format_string_uri_input(v, validate_uri) @@ -106,27 +121,32 @@ def _format_inputs_to_input_config(inputs, validate_uri=True): input_dict = _Job._format_record_set_list_input(inputs) else: raise ValueError( - 'Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs)) + "Cannot format input {}. Expecting one of str, dict or s3_input".format(inputs) + ) - channels = [_Job._convert_input_to_channel(name, input) for name, input in input_dict.items()] + channels = [ + _Job._convert_input_to_channel(name, input) for name, input in input_dict.items() + ] return channels @staticmethod def _convert_input_to_channel(channel_name, channel_s3_input): channel_config = channel_s3_input.config.copy() - channel_config['ChannelName'] = channel_name + channel_config["ChannelName"] = channel_name return channel_config @staticmethod def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None): - if isinstance(uri_input, str) and validate_uri and uri_input.startswith('s3://'): + if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input(uri_input, content_type=content_type, input_mode=input_mode) - elif isinstance(uri_input, str) and validate_uri and uri_input.startswith('file://'): + elif isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"): return file_input(uri_input) elif isinstance(uri_input, str) and validate_uri: - raise ValueError('URI input {} must be a valid S3 or FILE URI: must start with "s3://" or ' - '"file://"'.format(uri_input)) + raise ValueError( + 'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or ' + '"file://"'.format(uri_input) + ) elif isinstance(uri_input, str): return s3_input(uri_input, content_type=content_type, input_mode=input_mode) elif isinstance(uri_input, s3_input): @@ -134,41 +154,66 @@ def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, in elif isinstance(uri_input, file_input): return uri_input else: - raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(uri_input)) + raise ValueError( + "Cannot format input {}. 
Expecting one of str, s3_input, or file_input".format( + uri_input + ) + ) @staticmethod - def _prepare_channel(input_config, channel_uri=None, channel_name=None, validate_uri=True, content_type=None, - input_mode=None): + def _prepare_channel( + input_config, + channel_uri=None, + channel_name=None, + validate_uri=True, + content_type=None, + input_mode=None, + ): if not channel_uri: return elif not channel_name: - raise ValueError('Expected a channel name if a channel URI {} is specified'.format(channel_uri)) + raise ValueError( + "Expected a channel name if a channel URI {} is specified".format(channel_uri) + ) if input_config: for existing_channel in input_config: - if existing_channel['ChannelName'] == channel_name: - raise ValueError('Duplicate channel {} not allowed.'.format(channel_name)) + if existing_channel["ChannelName"] == channel_name: + raise ValueError("Duplicate channel {} not allowed.".format(channel_name)) - channel_input = _Job._format_string_uri_input(channel_uri, validate_uri, content_type, input_mode) + channel_input = _Job._format_string_uri_input( + channel_uri, validate_uri, content_type, input_mode + ) channel = _Job._convert_input_to_channel(channel_name, channel_input) return channel @staticmethod def _format_model_uri_input(model_uri, validate_uri=True): - if isinstance(model_uri, string_types)and validate_uri and model_uri.startswith('s3://'): - return s3_input(model_uri, input_mode='File', distribution='FullyReplicated', - content_type='application/x-sagemaker-model') - elif isinstance(model_uri, string_types) and validate_uri and model_uri.startswith('file://'): + if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"): + return s3_input( + model_uri, + input_mode="File", + distribution="FullyReplicated", + content_type="application/x-sagemaker-model", + ) + elif ( + isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://") + ): return file_input(model_uri) elif isinstance(model_uri, string_types) and validate_uri: - raise ValueError('Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' - '"file://') + raise ValueError( + 'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://' + ) elif isinstance(model_uri, string_types): - return s3_input(model_uri, input_mode='File', distribution='FullyReplicated', - content_type='application/x-sagemaker-model') + return s3_input( + model_uri, + input_mode="File", + distribution="FullyReplicated", + content_type="application/x-sagemaker-model", + ) else: - raise ValueError('Cannot format model URI {}. Expecting str'.format(model_uri)) + raise ValueError("Cannot format model URI {}. 
Expecting str".format(model_uri)) @staticmethod def _format_record_set_list_input(inputs): @@ -178,10 +223,10 @@ def _format_record_set_list_input(inputs): input_dict = {} for record in inputs: if not isinstance(record, RecordSet): - raise ValueError('List compatible only with RecordSets.') + raise ValueError("List compatible only with RecordSets.") if record.channel in input_dict: - raise ValueError('Duplicate channels not allowed.') + raise ValueError("Duplicate channels not allowed.") input_dict[record.channel] = record.records_s3_input() @@ -189,24 +234,26 @@ def _format_record_set_list_input(inputs): @staticmethod def _prepare_output_config(s3_path, kms_key_id): - config = {'S3OutputPath': s3_path} + config = {"S3OutputPath": s3_path} if kms_key_id is not None: - config['KmsKeyId'] = kms_key_id + config["KmsKeyId"] = kms_key_id return config @staticmethod def _prepare_resource_config(instance_count, instance_type, volume_size, train_volume_kms_key): - resource_config = {'InstanceCount': instance_count, - 'InstanceType': instance_type, - 'VolumeSizeInGB': volume_size} + resource_config = { + "InstanceCount": instance_count, + "InstanceType": instance_type, + "VolumeSizeInGB": volume_size, + } if train_volume_kms_key is not None: - resource_config['VolumeKmsKeyId'] = train_volume_kms_key + resource_config["VolumeKmsKeyId"] = train_volume_kms_key return resource_config @staticmethod def _prepare_stop_condition(max_run): - return {'MaxRuntimeInSeconds': max_run} + return {"MaxRuntimeInSeconds": max_run} @property def name(self): diff --git a/src/sagemaker/local/__init__.py b/src/sagemaker/local/__init__.py index 167b07ff2a..ad5dea8c3d 100644 --- a/src/sagemaker/local/__init__.py +++ b/src/sagemaker/local/__init__.py @@ -12,5 +12,9 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from .local_session import (file_input, LocalSagemakerClient, # noqa: F401 - LocalSagemakerRuntimeClient, LocalSession) +from .local_session import ( # noqa: F401 + file_input, + LocalSagemakerClient, + LocalSagemakerRuntimeClient, + LocalSession, +) diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 334024e968..88d5dde41b 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -43,9 +43,9 @@ def get_data_source_instance(data_source, sagemaker_session): """ parsed_uri = urlparse(data_source) - if parsed_uri.scheme == 'file': + if parsed_uri.scheme == "file": return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path) - elif parsed_uri.scheme == 's3': + elif parsed_uri.scheme == "s3": return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session) @@ -62,12 +62,12 @@ def get_splitter_instance(split_type): """ if split_type is None: return NoneSplitter() - elif split_type == 'Line': + elif split_type == "Line": return LineSplitter() - elif split_type == 'RecordIO': + elif split_type == "RecordIO": return RecordIOSplitter() else: - raise ValueError('Invalid Split Type: %s' % split_type) + raise ValueError("Invalid Split Type: %s" % split_type) def get_batch_strategy_instance(strategy, splitter): @@ -80,16 +80,17 @@ def get_batch_strategy_instance(strategy, splitter): Returns :class:`sagemaker.local.data.BatchStrategy`: an Instance of a BatchStrategy """ - if strategy == 'SingleRecord': + if strategy == "SingleRecord": return SingleRecordStrategy(splitter) - elif strategy == 'MultiRecord': + elif strategy == "MultiRecord": return MultiRecordStrategy(splitter) else: - raise ValueError('Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"') + raise ValueError( + 'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"' + ) class DataSource(with_metaclass(ABCMeta, object)): - @abstractmethod def get_file_list(self): """Retrieve the list of absolute paths to all the files in this data source. @@ -114,7 +115,7 @@ class LocalFileDataSource(DataSource): def __init__(self, root_path): self.root_path = os.path.abspath(root_path) if not os.path.exists(self.root_path): - raise RuntimeError('Invalid data source: %s does not exist.' % self.root_path) + raise RuntimeError("Invalid data source: %s does not exist." % self.root_path) def get_file_list(self): """Retrieve the list of absolute paths to all the files in this data source. @@ -123,8 +124,11 @@ def get_file_list(self): List[str] List of absolute paths. 
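        Example (illustrative): for a root_path "/tmp/data" holding "a.csv", "b.csv"
        and a subdirectory, only the two file paths are returned; subdirectories are
        not traversed recursively.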
""" if os.path.isdir(self.root_path): - return [os.path.join(self.root_path, f) for f in os.listdir(self.root_path) - if os.path.isfile(os.path.join(self.root_path, f))] + return [ + os.path.join(self.root_path, f) + for f in os.listdir(self.root_path) + if os.path.isfile(os.path.join(self.root_path, f)) + ] else: return [self.root_path] @@ -156,7 +160,9 @@ def __init__(self, bucket, prefix, sagemaker_session): """ # Create a temporary dir to store the S3 contents - root_dir = sagemaker.utils.get_config_value('local.container_root', sagemaker_session.config) + root_dir = sagemaker.utils.get_config_value( + "local.container_root", sagemaker_session.config + ) if root_dir: root_dir = os.path.abspath(root_dir) @@ -164,8 +170,8 @@ def __init__(self, bucket, prefix, sagemaker_session): # Docker cannot mount Mac OS /var folder properly see # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600 # Only apply this workaround if the user didn't provide an alternate storage root dir. - if root_dir is None and platform.system() == 'Darwin': - working_dir = '/private{}'.format(working_dir) + if root_dir is None and platform.system() == "Darwin": + working_dir = "/private{}".format(working_dir) sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session) self.files = LocalFileDataSource(working_dir) @@ -188,7 +194,6 @@ def get_root_dir(self): class Splitter(with_metaclass(ABCMeta, object)): - @abstractmethod def split(self, file): """Split a file into records using a specific strategy @@ -216,7 +221,7 @@ def split(self, file): Returns: generator for the individual records that were split from the file """ - with open(file, 'r') as f: + with open(file, "r") as f: yield f.read() @@ -234,7 +239,7 @@ def split(self, file): Returns: generator for the individual records that were split from the file """ - with open(file, 'r') as f: + with open(file, "r") as f: for line in f: yield line @@ -245,6 +250,7 @@ class RecordIOSplitter(Splitter): Not useful for string content. """ + def split(self, file): """Split a file into records using a specific strategy @@ -255,13 +261,12 @@ def split(self, file): Returns: generator for the individual records that were split from the file """ - with open(file, 'rb') as f: + with open(file, "rb") as f: for record in sagemaker.amazon.common.read_recordio(f): yield record class BatchStrategy(with_metaclass(ABCMeta, object)): - def __init__(self, splitter): """Create a Batch Strategy Instance @@ -290,6 +295,7 @@ class MultiRecordStrategy(BatchStrategy): Will group up as many records as possible within the payload specified. """ + def pad(self, file, size=6): """Group together as many records as possible to fit in the specified size @@ -301,7 +307,7 @@ def pad(self, file, size=6): Returns: generator of records """ - buffer = '' + buffer = "" for element in self.splitter.split(file): if _payload_size_within_limit(buffer + element, size): buffer += element @@ -318,6 +324,7 @@ class SingleRecordStrategy(BatchStrategy): If a single record does not fit within the payload specified it will throw a RuntimeError. """ + def pad(self, file, size=6): """Group together as many records as possible to fit in the specified size @@ -360,4 +367,4 @@ def _validate_payload_size(payload, size): if _payload_size_within_limit(payload, size): return True - raise RuntimeError('Record is larger than %sMB. Please increase your max_payload' % size) + raise RuntimeError("Record is larger than %sMB. 
Please increase your max_payload" % size) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 7ac8a2383c..a544b4cba0 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -27,71 +27,76 @@ logger = logging.getLogger(__name__) -_UNUSED_ARN = 'local:arn-does-not-matter' +_UNUSED_ARN = "local:arn-does-not-matter" HEALTH_CHECK_TIMEOUT_LIMIT = 120 class _LocalTrainingJob(object): - _STARTING = 'Starting' - _TRAINING = 'Training' - _COMPLETED = 'Completed' - _states = ['Starting', 'Training', 'Completed'] + _STARTING = "Starting" + _TRAINING = "Training" + _COMPLETED = "Completed" + _states = ["Starting", "Training", "Completed"] def __init__(self, container): self.container = container self.model_artifacts = None - self.state = 'created' + self.state = "created" self.start_time = None self.end_time = None def start(self, input_data_config, output_data_config, hyperparameters, job_name): for channel in input_data_config: - if channel['DataSource'] and 'S3DataSource' in channel['DataSource']: - data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType'] - data_uri = channel['DataSource']['S3DataSource']['S3Uri'] - elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']: - data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType'] - data_uri = channel['DataSource']['FileDataSource']['FileUri'] + if channel["DataSource"] and "S3DataSource" in channel["DataSource"]: + data_distribution = channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] + data_uri = channel["DataSource"]["S3DataSource"]["S3Uri"] + elif channel["DataSource"] and "FileDataSource" in channel["DataSource"]: + data_distribution = channel["DataSource"]["FileDataSource"][ + "FileDataDistributionType" + ] + data_uri = channel["DataSource"]["FileDataSource"]["FileUri"] else: - raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']') + raise ValueError( + "Need channel['DataSource'] to have ['S3DataSource'] or ['FileDataSource']" + ) # use a single Data URI - this makes handling S3 and File Data easier down the stack - channel['DataUri'] = data_uri + channel["DataUri"] = data_uri - if data_distribution != 'FullyReplicated': - raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' % - data_distribution) + if data_distribution != "FullyReplicated": + raise RuntimeError( + "DataDistribution: %s is not currently supported in Local Mode" + % data_distribution + ) self.start_time = datetime.datetime.now() self.state = self._TRAINING - self.model_artifacts = self.container.train(input_data_config, output_data_config, hyperparameters, job_name) + self.model_artifacts = self.container.train( + input_data_config, output_data_config, hyperparameters, job_name + ) self.end = datetime.datetime.now() self.state = self._COMPLETED def describe(self): response = { - 'ResourceConfig': { - 'InstanceCount': self.container.instance_count - }, - 'TrainingJobStatus': self.state, - 'TrainingStartTime': self.start_time, - 'TrainingEndTime': self.end_time, - 'ModelArtifacts': { - 'S3ModelArtifacts': self.model_artifacts - } + "ResourceConfig": {"InstanceCount": self.container.instance_count}, + "TrainingJobStatus": self.state, + "TrainingStartTime": self.start_time, + "TrainingEndTime": self.end_time, + "ModelArtifacts": {"S3ModelArtifacts": self.model_artifacts}, } return response class _LocalTransformJob(object): - _CREATING = 'Creating' - _COMPLETED 
= 'Completed' + _CREATING = "Creating" + _COMPLETED = "Completed" def __init__(self, transform_job_name, model_name, local_session=None): from sagemaker.local import LocalSession + self.local_session = local_session or LocalSession() local_client = self.local_session.sagemaker_client @@ -100,7 +105,7 @@ def __init__(self, transform_job_name, model_name, local_session=None): # TODO - support SageMaker Models not just local models. This is not # ideal but it may be a good thing to do. - self.primary_container = local_client.describe_model(model_name)['PrimaryContainer'] + self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"] self.container = None self.start_time = None self.end_time = None @@ -122,26 +127,28 @@ def start(self, input_data, output_data, transform_resources, **kwargs): self.input_data = input_data self.output_data = output_data - image = self.primary_container['Image'] - instance_type = transform_resources['InstanceType'] + image = self.primary_container["Image"] + instance_type = transform_resources["InstanceType"] instance_count = 1 environment = self._get_container_environment(**kwargs) # Start the container, pass the environment and wait for it to start up - self.container = _SageMakerContainer(instance_type, instance_count, image, self.local_session) - self.container.serve(self.primary_container['ModelDataUrl'], environment) + self.container = _SageMakerContainer( + instance_type, instance_count, image, self.local_session + ) + self.container.serve(self.primary_container["ModelDataUrl"], environment) - serving_port = get_config_value('local.serving_port', self.local_session.config) or 8080 + serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080 _wait_for_serving_container(serving_port) # Get capabilities from Container if needed - endpoint_url = 'http://localhost:%s/execution-parameters' % serving_port + endpoint_url = "http://localhost:%s/execution-parameters" % serving_port response, code = _perform_request(endpoint_url) if code == 200: execution_parameters = json.loads(response.read()) # MaxConcurrentTransforms is ignored because we currently only support 1 - for setting in ('BatchStrategy', 'MaxPayloadInMB'): + for setting in ("BatchStrategy", "MaxPayloadInMB"): if setting not in kwargs and setting in execution_parameters: kwargs[setting] = execution_parameters[setting] @@ -149,9 +156,9 @@ def start(self, input_data, output_data, transform_resources, **kwargs): kwargs.update(self._get_required_defaults(**kwargs)) self.start_time = datetime.datetime.now() - self.batch_strategy = kwargs['BatchStrategy'] - if 'Environment' in kwargs: - self.environment = kwargs['Environment'] + self.batch_strategy = kwargs["BatchStrategy"] + if "Environment" in kwargs: + self.environment = kwargs["Environment"] # run the batch inference requests self._perform_batch_inference(input_data, output_data, **kwargs) @@ -168,25 +175,25 @@ def describe(self): dict: description of this _LocalTransformJob """ response = { - 'TransformJobStatus': self.state, - 'ModelName': self.model_name, - 'TransformJobName': self.name, - 'TransformJobArn': _UNUSED_ARN, - 'TransformEndTime': self.end_time, - 'CreationTime': self.start_time, - 'TransformStartTime': self.start_time, - 'Environment': {}, - 'BatchStrategy': self.batch_strategy, + "TransformJobStatus": self.state, + "ModelName": self.model_name, + "TransformJobName": self.name, + "TransformJobArn": _UNUSED_ARN, + "TransformEndTime": self.end_time, + "CreationTime": self.start_time, + 
"TransformStartTime": self.start_time, + "Environment": {}, + "BatchStrategy": self.batch_strategy, } if self.transform_resources: - response['TransformResources'] = self.transform_resources + response["TransformResources"] = self.transform_resources if self.output_data: - response['TransformOutput'] = self.output_data + response["TransformOutput"] = self.output_data if self.input_data: - response['TransformInput'] = self.input_data + response["TransformInput"] = self.input_data return response @@ -204,29 +211,31 @@ def _get_container_environment(self, **kwargs): """ environment = {} - environment.update(self.primary_container['Environment']) - environment['SAGEMAKER_BATCH'] = 'True' - if 'MaxPayloadInMB' in kwargs: - environment['SAGEMAKER_MAX_PAYLOAD_IN_MB'] = str(kwargs['MaxPayloadInMB']) - - if 'BatchStrategy' in kwargs: - if kwargs['BatchStrategy'] == 'SingleRecord': - strategy_env_value = 'SINGLE_RECORD' - elif kwargs['BatchStrategy'] == 'MultiRecord': - strategy_env_value = 'MULTI_RECORD' + environment.update(self.primary_container["Environment"]) + environment["SAGEMAKER_BATCH"] = "True" + if "MaxPayloadInMB" in kwargs: + environment["SAGEMAKER_MAX_PAYLOAD_IN_MB"] = str(kwargs["MaxPayloadInMB"]) + + if "BatchStrategy" in kwargs: + if kwargs["BatchStrategy"] == "SingleRecord": + strategy_env_value = "SINGLE_RECORD" + elif kwargs["BatchStrategy"] == "MultiRecord": + strategy_env_value = "MULTI_RECORD" else: - raise ValueError('Invalid BatchStrategy, must be \'SingleRecord\' or \'MultiRecord\'') - environment['SAGEMAKER_BATCH_STRATEGY'] = strategy_env_value + raise ValueError("Invalid BatchStrategy, must be 'SingleRecord' or 'MultiRecord'") + environment["SAGEMAKER_BATCH_STRATEGY"] = strategy_env_value # we only do 1 max concurrent transform in Local Mode - if 'MaxConcurrentTransforms' in kwargs and int(kwargs['MaxConcurrentTransforms']) > 1: - logger.warning('Local Mode only supports 1 ConcurrentTransform. Setting MaxConcurrentTransforms to 1') - environment['SAGEMAKER_MAX_CONCURRENT_TRANSFORMS'] = '1' + if "MaxConcurrentTransforms" in kwargs and int(kwargs["MaxConcurrentTransforms"]) > 1: + logger.warning( + "Local Mode only supports 1 ConcurrentTransform. Setting MaxConcurrentTransforms to 1" + ) + environment["SAGEMAKER_MAX_CONCURRENT_TRANSFORMS"] = "1" # if there were environment variables passed to the Transformer we will pass them to the # container as well. - if 'Environment' in kwargs: - environment.update(kwargs['Environment']) + if "Environment" in kwargs: + environment.update(kwargs["Environment"]) return environment def _get_required_defaults(self, **kwargs): @@ -239,18 +248,18 @@ def _get_required_defaults(self, **kwargs): dict: key/values for the default parameters that are missing. """ defaults = {} - if 'BatchStrategy' not in kwargs: - defaults['BatchStrategy'] = 'MultiRecord' + if "BatchStrategy" not in kwargs: + defaults["BatchStrategy"] = "MultiRecord" - if 'MaxPayloadInMB' not in kwargs: - defaults['MaxPayloadInMB'] = 6 + if "MaxPayloadInMB" not in kwargs: + defaults["MaxPayloadInMB"] = 6 return defaults def _get_working_directory(self): # Root dir to use for intermediate data location. To make things simple we will write here regardless # of the final destination. At the end the files will either be moved or uploaded to S3 and deleted. 
- root_dir = get_config_value('local.container_root', self.local_session.config) + root_dir = get_config_value("local.container_root", self.local_session.config) if root_dir: root_dir = os.path.abspath(root_dir) @@ -258,10 +267,10 @@ def _get_working_directory(self): return working_dir def _prepare_data_transformation(self, input_data, batch_strategy): - input_path = input_data['DataSource']['S3DataSource']['S3Uri'] + input_path = input_data["DataSource"]["S3DataSource"]["S3Uri"] data_source = sagemaker.local.data.get_data_source_instance(input_path, self.local_session) - split_type = input_data['SplitType'] if 'SplitType' in input_data else None + split_type = input_data["SplitType"] if "SplitType" in input_data else None splitter = sagemaker.local.data.get_splitter_instance(split_type) batch_provider = sagemaker.local.data.get_batch_strategy_instance(batch_strategy, splitter) @@ -272,12 +281,12 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs): # from S3 or Local FileSystem. Split them as required (Line, RecordIO, None) and finally batch them # according to the batch strategy and limit the request size. - batch_strategy = kwargs['BatchStrategy'] - max_payload = int(kwargs['MaxPayloadInMB']) + batch_strategy = kwargs["BatchStrategy"] + max_payload = int(kwargs["MaxPayloadInMB"]) data_source, batch_provider = self._prepare_data_transformation(input_data, batch_strategy) # Output settings - accept = output_data['Accept'] if 'Accept' in output_data else None + accept = output_data["Accept"] if "Accept" in output_data else None working_dir = self._get_working_directory() dataset_dir = data_source.get_root_dir() @@ -287,27 +296,27 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs): relative_path = os.path.dirname(os.path.relpath(file, dataset_dir)) filename = os.path.basename(file) copy_directory_structure(working_dir, relative_path) - destination_path = os.path.join(working_dir, relative_path, filename + '.out') + destination_path = os.path.join(working_dir, relative_path, filename + ".out") - with open(destination_path, 'wb') as f: + with open(destination_path, "wb") as f: for item in batch_provider.pad(file, max_payload): # call the container and add the result to inference. 
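                    # (Illustrative: each padded batch is POSTed to the local serving
                    # container; responses are concatenated into "<input filename>.out",
                    # newline-separated when AssembleWith is "Line".)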
response = self.local_session.sagemaker_runtime_client.invoke_endpoint( - item, '', input_data['ContentType'], accept) + item, "", input_data["ContentType"], accept + ) - response_body = response['Body'] + response_body = response["Body"] data = response_body.read().strip() response_body.close() f.write(data) - if 'AssembleWith' in output_data and output_data['AssembleWith'] == 'Line': - f.write(b'\n') + if "AssembleWith" in output_data and output_data["AssembleWith"] == "Line": + f.write(b"\n") - move_to_destination(working_dir, output_data['S3OutputPath'], self.name, self.local_session) + move_to_destination(working_dir, output_data["S3OutputPath"], self.name, self.local_session) self.container.stop_serving() class _LocalModel(object): - def __init__(self, model_name, primary_container): self.model_name = model_name self.primary_container = primary_container @@ -315,17 +324,16 @@ def __init__(self, model_name, primary_container): def describe(self): response = { - 'ModelName': self.model_name, - 'CreationTime': self.creation_time, - 'ExecutionRoleArn': _UNUSED_ARN, - 'ModelArn': _UNUSED_ARN, - 'PrimaryContainer': self.primary_container + "ModelName": self.model_name, + "CreationTime": self.creation_time, + "ExecutionRoleArn": _UNUSED_ARN, + "ModelArn": _UNUSED_ARN, + "PrimaryContainer": self.primary_container, } return response class _LocalEndpointConfig(object): - def __init__(self, config_name, production_variants, tags=None): self.name = config_name self.production_variants = production_variants @@ -334,53 +342,60 @@ def __init__(self, config_name, production_variants, tags=None): def describe(self): response = { - 'EndpointConfigName': self.name, - 'EndpointConfigArn': _UNUSED_ARN, - 'Tags': self.tags, - 'CreationTime': self.creation_time, - 'ProductionVariants': self.production_variants + "EndpointConfigName": self.name, + "EndpointConfigArn": _UNUSED_ARN, + "Tags": self.tags, + "CreationTime": self.creation_time, + "ProductionVariants": self.production_variants, } return response class _LocalEndpoint(object): - _CREATING = 'Creating' - _IN_SERVICE = 'InService' - _FAILED = 'Failed' + _CREATING = "Creating" + _IN_SERVICE = "InService" + _FAILED = "Failed" def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None): # runtime import since there is a cyclic dependency between entities and local_session from sagemaker.local import LocalSession + self.local_session = local_session or LocalSession() local_client = self.local_session.sagemaker_client self.name = endpoint_name self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name) - self.production_variant = self.endpoint_config['ProductionVariants'][0] + self.production_variant = self.endpoint_config["ProductionVariants"][0] self.tags = tags - model_name = self.production_variant['ModelName'] - self.primary_container = local_client.describe_model(model_name)['PrimaryContainer'] + model_name = self.production_variant["ModelName"] + self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"] self.container = None self.create_time = None self.state = _LocalEndpoint._CREATING def serve(self): - image = self.primary_container['Image'] - instance_type = self.production_variant['InstanceType'] - instance_count = self.production_variant['InitialInstanceCount'] + image = self.primary_container["Image"] + instance_type = self.production_variant["InstanceType"] + instance_count = self.production_variant["InitialInstanceCount"] - accelerator_type = 
self.production_variant.get('AcceleratorType') - if accelerator_type == 'local_sagemaker_notebook': - self.primary_container['Environment']['SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'] = 'true' + accelerator_type = self.production_variant.get("AcceleratorType") + if accelerator_type == "local_sagemaker_notebook": + self.primary_container["Environment"][ + "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT" + ] = "true" self.create_time = datetime.datetime.now() - self.container = _SageMakerContainer(instance_type, instance_count, image, self.local_session) - self.container.serve(self.primary_container['ModelDataUrl'], self.primary_container['Environment']) - - serving_port = get_config_value('local.serving_port', self.local_session.config) or 8080 + self.container = _SageMakerContainer( + instance_type, instance_count, image, self.local_session + ) + self.container.serve( + self.primary_container["ModelDataUrl"], self.primary_container["Environment"] + ) + + serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080 _wait_for_serving_container(serving_port) # the container is running and it passed the healthcheck status is now InService self.state = _LocalEndpoint._IN_SERVICE @@ -391,13 +406,13 @@ def stop(self): def describe(self): response = { - 'EndpointConfigName': self.endpoint_config['EndpointConfigName'], - 'CreationTime': self.create_time, - 'ProductionVariants': self.endpoint_config['ProductionVariants'], - 'Tags': self.tags, - 'EndpointName': self.name, - 'EndpointArn': _UNUSED_ARN, - 'EndpointStatus': self.state + "EndpointConfigName": self.endpoint_config["EndpointConfigName"], + "CreationTime": self.create_time, + "ProductionVariants": self.endpoint_config["ProductionVariants"], + "Tags": self.tags, + "EndpointName": self.name, + "EndpointArn": _UNUSED_ARN, + "EndpointStatus": self.state, } return response @@ -406,16 +421,16 @@ def _wait_for_serving_container(serving_port): i = 0 http = urllib3.PoolManager() - endpoint_url = 'http://localhost:%s/ping' % serving_port + endpoint_url = "http://localhost:%s/ping" % serving_port while True: i += 5 if i >= HEALTH_CHECK_TIMEOUT_LIMIT: - raise RuntimeError('Giving up, endpoint didn\'t launch correctly') + raise RuntimeError("Giving up, endpoint didn't launch correctly") - logger.info('Checking if serving container is up, attempt: %s' % i) + logger.info("Checking if serving container is up, attempt: %s" % i) _, code = _perform_request(endpoint_url, http) if code != 200: - logger.info('Container still not up, got: %s' % code) + logger.info("Container still not up, got: %s" % code) else: return @@ -425,7 +440,7 @@ def _wait_for_serving_container(serving_port): def _perform_request(endpoint_url, pool_manager=None): http = pool_manager or urllib3.PoolManager() try: - r = http.request('GET', endpoint_url) + r = http.request("GET", endpoint_url) code = r.status except urllib3.exceptions.RequestError: return None, -1 diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 979725b2f3..783b84f5ff 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -38,15 +38,15 @@ import sagemaker.local.utils import sagemaker.utils -CONTAINER_PREFIX = 'algo' -DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml' -DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = 'COMPOSE_HTTP_TIMEOUT' -DOCKER_COMPOSE_HTTP_TIMEOUT = '120' +CONTAINER_PREFIX = "algo" +DOCKER_COMPOSE_FILENAME = "docker-compose.yaml" +DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT" +DOCKER_COMPOSE_HTTP_TIMEOUT = "120" # Environment variables to 
be set during training -REGION_ENV_NAME = 'AWS_REGION' -TRAINING_JOB_NAME_ENV_NAME = 'TRAINING_JOB_NAME' +REGION_ENV_NAME = "AWS_REGION" +TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" logger = logging.getLogger(__name__) @@ -73,14 +73,18 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None) with SageMaker. """ from sagemaker.local.local_session import LocalSession + self.sagemaker_session = sagemaker_session or LocalSession() self.instance_type = instance_type self.instance_count = instance_count self.image = image # Since we are using a single docker network, Generate a random suffix to attach to the container names. # This way multiple jobs can run in parallel. - suffix = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5)) - self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)] + suffix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5)) + self.hosts = [ + "{}-{}-{}".format(CONTAINER_PREFIX, i, suffix) + for i in range(1, self.instance_count + 1) + ] self.container_root = None self.container = None @@ -95,39 +99,45 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name Returns (str): Location of the trained model. """ self.container_root = self._create_tmp_folder() - os.mkdir(os.path.join(self.container_root, 'output')) + os.mkdir(os.path.join(self.container_root, "output")) # create output/data folder since sagemaker-containers 2.0 expects it - os.mkdir(os.path.join(self.container_root, 'output', 'data')) + os.mkdir(os.path.join(self.container_root, "output", "data")) # A shared directory for all the containers. It is only mounted if the training script is # Local. - shared_dir = os.path.join(self.container_root, 'shared') + shared_dir = os.path.join(self.container_root, "shared") os.mkdir(shared_dir) data_dir = self._create_tmp_folder() - volumes = self._prepare_training_volumes(data_dir, input_data_config, output_data_config, - hyperparameters) + volumes = self._prepare_training_volumes( + data_dir, input_data_config, output_data_config, hyperparameters + ) # If local, source directory needs to be updated to mounted /opt/ml/code path - hyperparameters = self._update_local_src_path(hyperparameters, key=sagemaker.estimator.DIR_PARAM_NAME) + hyperparameters = self._update_local_src_path( + hyperparameters, key=sagemaker.estimator.DIR_PARAM_NAME + ) # Create the configuration files for each container that we will create # Each container will map the additional local volumes (if any). 
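        # (Illustrative sketch, names assumed: with instance_count=2 the generated
        # hosts look like ["algo-1-abcde", "algo-2-abcde"]; each host directory gets
        # input/config and input/data populated before the compose file is generated
        # below.)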
for host in self.hosts: _create_config_file_directories(self.container_root, host) self.write_config_files(host, hyperparameters, input_data_config) - shutil.copytree(data_dir, os.path.join(self.container_root, host, 'input', 'data')) + shutil.copytree(data_dir, os.path.join(self.container_root, host, "input", "data")) training_env_vars = { REGION_ENV_NAME: self.sagemaker_session.boto_region_name, TRAINING_JOB_NAME_ENV_NAME: job_name, } - compose_data = self._generate_compose_file('train', additional_volumes=volumes, - additional_env_vars=training_env_vars) + compose_data = self._generate_compose_file( + "train", additional_volumes=volumes, additional_env_vars=training_env_vars + ) compose_command = self._compose() if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image): _pull_image(self.image) - process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + process = subprocess.Popen( + compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) try: _stream_output(process) @@ -147,7 +157,7 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name # Print our Job Complete line to have a similar experience to training on SageMaker where you # see this line at the end. - print('===== Job Complete =====') + print("===== Job Complete =====") return artifacts def serve(self, model_dir, environment): @@ -162,7 +172,7 @@ def serve(self, model_dir, environment): logger.info("serving") self.container_root = self._create_tmp_folder() - logger.info('creating hosting dir in {}'.format(self.container_root)) + logger.info("creating hosting dir in {}".format(self.container_root)) volumes = self._prepare_serving_volumes(model_dir) @@ -170,18 +180,18 @@ def serve(self, model_dir, environment): if sagemaker.estimator.DIR_PARAM_NAME.upper() in environment: script_dir = environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] parsed_uri = urlparse(script_dir) - if parsed_uri.scheme == 'file': - volumes.append(_Volume(parsed_uri.path, '/opt/ml/code')) + if parsed_uri.scheme == "file": + volumes.append(_Volume(parsed_uri.path, "/opt/ml/code")) # Update path to mount location environment = environment.copy() - environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = '/opt/ml/code' + environment[sagemaker.estimator.DIR_PARAM_NAME.upper()] = "/opt/ml/code" if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image): _pull_image(self.image) - self._generate_compose_file('serve', - additional_env_vars=environment, - additional_volumes=volumes) + self._generate_compose_file( + "serve", additional_env_vars=environment, additional_volumes=volumes + ) compose_command = self._compose() self.container = _HostingContainer(compose_command) self.container.start() @@ -213,12 +223,12 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): """ # We need a directory to store the artfiacts from all the nodes # and another one to contained the compressed final artifacts - artifacts = os.path.join(self.container_root, 'artifacts') - compressed_artifacts = os.path.join(self.container_root, 'compressed_artifacts') + artifacts = os.path.join(self.container_root, "artifacts") + compressed_artifacts = os.path.join(self.container_root, "compressed_artifacts") os.mkdir(artifacts) - model_artifacts = os.path.join(artifacts, 'model') - output_artifacts = os.path.join(artifacts, 'output') + model_artifacts = os.path.join(artifacts, "model") + output_artifacts = os.path.join(artifacts, "output") artifact_dirs = 
[model_artifacts, output_artifacts, compressed_artifacts] for d in artifact_dirs: @@ -226,38 +236,41 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): # Gather the artifacts from all nodes into artifacts/model and artifacts/output for host in self.hosts: - volumes = compose_data['services'][str(host)]['volumes'] + volumes = compose_data["services"][str(host)]["volumes"] for volume in volumes: - host_dir, container_dir = volume.split(':') - if container_dir == '/opt/ml/model': + host_dir, container_dir = volume.split(":") + if container_dir == "/opt/ml/model": sagemaker.local.utils.recursive_copy(host_dir, model_artifacts) - elif container_dir == '/opt/ml/output': + elif container_dir == "/opt/ml/output": sagemaker.local.utils.recursive_copy(host_dir, output_artifacts) # Tar Artifacts -> model.tar.gz and output.tar.gz - model_files = [os.path.join(model_artifacts, name) for name in - os.listdir(model_artifacts)] - output_files = [os.path.join(output_artifacts, name) for name in - os.listdir(output_artifacts)] - sagemaker.utils.create_tar_file(model_files, - os.path.join(compressed_artifacts, 'model.tar.gz')) - sagemaker.utils.create_tar_file(output_files, - os.path.join(compressed_artifacts, 'output.tar.gz')) - - if output_data_config['S3OutputPath'] == '': - output_data = 'file://%s' % compressed_artifacts + model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)] + output_files = [ + os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts) + ] + sagemaker.utils.create_tar_file( + model_files, os.path.join(compressed_artifacts, "model.tar.gz") + ) + sagemaker.utils.create_tar_file( + output_files, os.path.join(compressed_artifacts, "output.tar.gz") + ) + + if output_data_config["S3OutputPath"] == "": + output_data = "file://%s" % compressed_artifacts else: # Now we just need to move the compressed artifacts to wherever they are required output_data = sagemaker.local.utils.move_to_destination( compressed_artifacts, - output_data_config['S3OutputPath'], + output_data_config["S3OutputPath"], job_name, - self.sagemaker_session) + self.sagemaker_session, + ) _delete_tree(model_artifacts) _delete_tree(output_artifacts) - return os.path.join(output_data, 'model.tar.gz') + return os.path.join(output_data, "model.tar.gz") def write_config_files(self, host, hyperparameters, input_data_config): """Write the config files for the training containers. 
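For context on the next hunk: write_config_files emits three JSON files into each
container's input/config directory. A minimal sketch of their shapes, with illustrative
values (the channel name, hosts and hyperparameters below are assumptions, not taken
from this diff):

    # Python literals mirroring the three _write_json_file calls:
    hyperparameters = {"epochs": "10", "batch_size": "100"}  # hyperparameters.json
    resource_config = {  # resourceconfig.json
        "current_host": "algo-1-abcde",
        "hosts": ["algo-1-abcde", "algo-2-abcde"],
    }
    input_data_config = {  # inputdataconfig.json
        "training": {"TrainingInputMode": "File", "ContentType": "text/csv"}
    }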
@@ -272,39 +285,35 @@ def write_config_files(self, host, hyperparameters, input_data_config): Returns: None """ - config_path = os.path.join(self.container_root, host, 'input', 'config') + config_path = os.path.join(self.container_root, host, "input", "config") - resource_config = { - 'current_host': host, - 'hosts': self.hosts - } + resource_config = {"current_host": host, "hosts": self.hosts} json_input_data_config = {} for c in input_data_config: - channel_name = c['ChannelName'] - json_input_data_config[channel_name] = { - 'TrainingInputMode': 'File' - } - if 'ContentType' in c: - json_input_data_config[channel_name]['ContentType'] = c['ContentType'] - - _write_json_file(os.path.join(config_path, 'hyperparameters.json'), hyperparameters) - _write_json_file(os.path.join(config_path, 'resourceconfig.json'), resource_config) - _write_json_file(os.path.join(config_path, 'inputdataconfig.json'), json_input_data_config) - - def _prepare_training_volumes(self, data_dir, input_data_config, output_data_config, - hyperparameters): - shared_dir = os.path.join(self.container_root, 'shared') - model_dir = os.path.join(self.container_root, 'model') + channel_name = c["ChannelName"] + json_input_data_config[channel_name] = {"TrainingInputMode": "File"} + if "ContentType" in c: + json_input_data_config[channel_name]["ContentType"] = c["ContentType"] + + _write_json_file(os.path.join(config_path, "hyperparameters.json"), hyperparameters) + _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config) + _write_json_file(os.path.join(config_path, "inputdataconfig.json"), json_input_data_config) + + def _prepare_training_volumes( + self, data_dir, input_data_config, output_data_config, hyperparameters + ): + shared_dir = os.path.join(self.container_root, "shared") + model_dir = os.path.join(self.container_root, "model") volumes = [] - volumes.append(_Volume(model_dir, '/opt/ml/model')) + volumes.append(_Volume(model_dir, "/opt/ml/model")) # Set up the channels for the containers. For local data we will # mount the local directory to the container. For S3 Data we will download the S3 data # first. for channel in input_data_config: - uri = channel['DataUri'] - channel_name = channel['ChannelName'] + uri = channel["DataUri"] + channel_name = channel["ChannelName"] channel_dir = os.path.join(data_dir, channel_name) os.mkdir(channel_dir) @@ -316,17 +325,20 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con if sagemaker.estimator.DIR_PARAM_NAME in hyperparameters: training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME]) parsed_uri = urlparse(training_dir) - if parsed_uri.scheme == 'file': - volumes.append(_Volume(parsed_uri.path, '/opt/ml/code')) + if parsed_uri.scheme == "file": + volumes.append(_Volume(parsed_uri.path, "/opt/ml/code")) # Also mount a directory that all the containers can access. 
- volumes.append(_Volume(shared_dir, '/opt/ml/shared')) - - parsed_uri = urlparse(output_data_config['S3OutputPath']) - if parsed_uri.scheme == 'file' and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters: - intermediate_dir = os.path.join(parsed_uri.path, 'output', 'intermediate') + volumes.append(_Volume(shared_dir, "/opt/ml/shared")) + + parsed_uri = urlparse(output_data_config["S3OutputPath"]) + if ( + parsed_uri.scheme == "file" + and sagemaker.model.SAGEMAKER_OUTPUT_LOCATION in hyperparameters + ): + intermediate_dir = os.path.join(parsed_uri.path, "output", "intermediate") if not os.path.exists(intermediate_dir): os.makedirs(intermediate_dir) - volumes.append(_Volume(intermediate_dir, '/opt/ml/output/intermediate')) + volumes.append(_Volume(intermediate_dir, "/opt/ml/output/intermediate")) return volumes @@ -334,9 +346,9 @@ def _update_local_src_path(self, params, key): if key in params: src_dir = json.loads(params[key]) parsed_uri = urlparse(src_dir) - if parsed_uri.scheme == 'file': + if parsed_uri.scheme == "file": new_params = params.copy() - new_params[key] = json.dumps('/opt/ml/code') + new_params[key] = json.dumps("/opt/ml/code") return new_params return params @@ -350,14 +362,15 @@ def _prepare_serving_volumes(self, model_location): os.makedirs(host_dir) model_data_source = sagemaker.local.data.get_data_source_instance( - model_location, self.sagemaker_session) + model_location, self.sagemaker_session + ) for filename in model_data_source.get_file_list(): if tarfile.is_tarfile(filename): with tarfile.open(filename) as tar: tar.extractall(path=model_data_source.get_root_dir()) - volumes.append(_Volume(model_data_source.get_root_dir(), '/opt/ml/model')) + volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model")) return volumes @@ -388,53 +401,51 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en if aws_creds is not None: environment.extend(aws_creds) - additional_env_var_list = ['{}={}'.format(k, v) for k, v in additional_env_vars.items()] + additional_env_var_list = ["{}={}".format(k, v) for k, v in additional_env_vars.items()] environment.extend(additional_env_var_list) if os.environ.get(DOCKER_COMPOSE_HTTP_TIMEOUT_ENV) is None: os.environ[DOCKER_COMPOSE_HTTP_TIMEOUT_ENV] = DOCKER_COMPOSE_HTTP_TIMEOUT - if command == 'train': - optml_dirs = {'output', 'output/data', 'input'} + if command == "train": + optml_dirs = {"output", "output/data", "input"} services = { - h: self._create_docker_host(h, environment, optml_dirs, - command, additional_volumes) for h in self.hosts + h: self._create_docker_host(h, environment, optml_dirs, command, additional_volumes) + for h in self.hosts } content = { # Use version 2.3 as a minimum so that we can specify the runtime - 'version': '2.3', - 'services': services, - 'networks': { - 'sagemaker-local': {'name': 'sagemaker-local'} - } + "version": "2.3", + "services": services, + "networks": {"sagemaker-local": {"name": "sagemaker-local"}}, } docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME) yaml_content = yaml.dump(content, default_flow_style=False) - logger.info('docker compose file: \n{}'.format(yaml_content)) - with open(docker_compose_path, 'w') as f: + logger.info("docker compose file: \n{}".format(yaml_content)) + with open(docker_compose_path, "w") as f: f.write(yaml_content) return content def _compose(self, detached=False): - compose_cmd = 'docker-compose' + compose_cmd = "docker-compose" command = [ compose_cmd, - '-f', + "-f", 
os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME), - 'up', - '--build', - '--abort-on-container-exit' + "up", + "--build", + "--abort-on-container-exit", ] if detached: - command.append('-d') + command.append("-d") - logger.info('docker command: {}'.format(' '.join(command))) + logger.info("docker command: {}".format(" ".join(command))) return command def _create_docker_host(self, host, environment, optml_subdirs, command, volumes): @@ -442,38 +453,35 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes optml_volumes.extend(volumes) host_config = { - 'image': self.image, - 'stdin_open': True, - 'tty': True, - 'volumes': [v.map for v in optml_volumes], - 'environment': environment, - 'command': command, - 'networks': { - 'sagemaker-local': { - 'aliases': [host] - } - } + "image": self.image, + "stdin_open": True, + "tty": True, + "volumes": [v.map for v in optml_volumes], + "environment": environment, + "command": command, + "networks": {"sagemaker-local": {"aliases": [host]}}, } # for GPU support pass in nvidia as the runtime, this is equivalent # to setting --runtime=nvidia in the docker commandline. - if self.instance_type == 'local_gpu': - host_config['runtime'] = 'nvidia' - - if command == 'serve': - serving_port = sagemaker.utils.get_config_value('local.serving_port', - self.sagemaker_session.config) or 8080 - host_config.update({ - 'ports': [ - '%s:8080' % serving_port - ] - }) + if self.instance_type == "local_gpu": + host_config["runtime"] = "nvidia" + + if command == "serve": + serving_port = ( + sagemaker.utils.get_config_value( + "local.serving_port", self.sagemaker_session.config + ) + or 8080 + ) + host_config.update({"ports": ["%s:8080" % serving_port]}) return host_config def _create_tmp_folder(self): - root_dir = sagemaker.utils.get_config_value('local.container_root', - self.sagemaker_session.config) + root_dir = sagemaker.utils.get_config_value( + "local.container_root", self.sagemaker_session.config + ) if root_dir: root_dir = os.path.abspath(root_dir) @@ -482,8 +490,8 @@ def _create_tmp_folder(self): # Docker cannot mount Mac OS /var folder properly see # https://forums.docker.com/t/var-folders-isnt-mounted-properly/9600 # Only apply this workaround if the user didn't provide an alternate storage root dir. - if root_dir is None and platform.system() == 'Darwin': - working_dir = '/private{}'.format(working_dir) + if root_dir is None and platform.system() == "Darwin": + working_dir = "/private{}".format(working_dir) return os.path.abspath(working_dir) @@ -503,7 +511,7 @@ def _build_optml_volumes(self, host, subdirs): for subdir in subdirs: host_dir = os.path.join(self.container_root, host, subdir) - container_dir = '/opt/ml/{}'.format(subdir) + container_dir = "/opt/ml/{}".format(subdir) volume = _Volume(host_dir, container_dir) volumes.append(volume) @@ -527,9 +535,9 @@ def __init__(self, command): self.process = None def run(self): - self.process = subprocess.Popen(self.command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + self.process = subprocess.Popen( + self.command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) try: _stream_output(self.process) except RuntimeError as e: @@ -558,17 +566,19 @@ def __init__(self, host_dir, container_dir=None, channel=None): /opt/ml/input/data/ in the container. 
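        Example (illustrative): _Volume("/tmp/model", channel="training").map evaluates
        to "/tmp/model:/opt/ml/input/data/training", the form docker-compose expects
        for a volume mapping.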
""" if not container_dir and not channel: - raise ValueError('Either container_dir or channel must be declared.') + raise ValueError("Either container_dir or channel must be declared.") if container_dir and channel: - raise ValueError('container_dir and channel cannot be declared together.') + raise ValueError("container_dir and channel cannot be declared together.") - self.container_dir = container_dir if container_dir else os.path.join('/opt/ml/input/data', channel) + self.container_dir = ( + container_dir if container_dir else os.path.join("/opt/ml/input/data", channel) + ) self.host_dir = host_dir - if platform.system() == 'Darwin' and host_dir.startswith('/var'): - self.host_dir = os.path.join('/private', host_dir) + if platform.system() == "Darwin" and host_dir.startswith("/var"): + self.host_dir = os.path.join("/private", host_dir) - self.map = '{}:{}'.format(self.host_dir, self.container_dir) + self.map = "{}:{}".format(self.host_dir, self.container_dir) def _stream_output(process): @@ -616,7 +626,7 @@ def _check_output(cmd, *popenargs, **kwargs): def _create_config_file_directories(root, host): - for d in ['input', 'input/config', 'output', 'model']: + for d in ["input", "input/config", "output", "model"]: os.makedirs(os.path.join(root, host, d)) @@ -652,21 +662,25 @@ def _aws_credentials(session): if token is None: logger.info("Using the long-lived AWS credentials found in session") return [ - 'AWS_ACCESS_KEY_ID=%s' % (str(access_key)), - 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)) + "AWS_ACCESS_KEY_ID=%s" % (str(access_key)), + "AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)), ] elif not _aws_credentials_available_in_metadata_service(): - logger.warning("Using the short-lived AWS credentials found in session. They might expire while running.") + logger.warning( + "Using the short-lived AWS credentials found in session. They might expire while running." + ) return [ - 'AWS_ACCESS_KEY_ID=%s' % (str(access_key)), - 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)), - 'AWS_SESSION_TOKEN=%s' % (str(token)) + "AWS_ACCESS_KEY_ID=%s" % (str(access_key)), + "AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)), + "AWS_SESSION_TOKEN=%s" % (str(token)), ] else: - logger.info("No AWS credentials found in session but credentials from EC2 Metadata Service are available.") + logger.info( + "No AWS credentials found in session but credentials from EC2 Metadata Service are available." + ) return None except Exception as e: # pylint: disable=broad-except - logger.info('Could not get AWS credentials: %s' % e) + logger.info("Could not get AWS credentials: %s" % e) return None @@ -679,15 +693,16 @@ def _aws_credentials_available_in_metadata_service(): session = botocore.session.Session() instance_metadata_provider = InstanceMetadataProvider( iam_role_fetcher=InstanceMetadataFetcher( - timeout=session.get_config_variable('metadata_service_timeout'), - num_attempts=session.get_config_variable('metadata_service_num_attempts'), - user_agent=session.user_agent()) + timeout=session.get_config_variable("metadata_service_timeout"), + num_attempts=session.get_config_variable("metadata_service_num_attempts"), + user_agent=session.user_agent(), + ) ) return not (instance_metadata_provider.load() is None) def _write_json_file(filename, content): - with open(filename, 'w') as f: + with open(filename, "w") as f: json.dump(content, f) @@ -699,20 +714,22 @@ def _ecr_login_if_needed(boto_session, image): return False # do we have the image? 
- if _check_output('docker images -q %s' % image).strip(): + if _check_output("docker images -q %s" % image).strip(): return False if not boto_session: - raise RuntimeError('A boto session is required to login to ECR.' - 'Please pull the image: %s manually.' % image) + raise RuntimeError( + "A boto session is required to login to ECR." + "Please pull the image: %s manually." % image + ) - ecr = boto_session.client('ecr') - auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]]) - authorization_data = auth['authorizationData'][0] + ecr = boto_session.client("ecr") + auth = ecr.get_authorization_token(registryIds=[image.split(".")[0]]) + authorization_data = auth["authorizationData"][0] - raw_token = base64.b64decode(authorization_data['authorizationToken']) - token = raw_token.decode('utf-8').strip('AWS:') - ecr_url = auth['authorizationData'][0]['proxyEndpoint'] + raw_token = base64.b64decode(authorization_data["authorizationToken"]) + token = raw_token.decode("utf-8").strip("AWS:") + ecr_url = auth["authorizationData"][0]["proxyEndpoint"] cmd = "docker login -u AWS -p %s %s" % (token, ecr_url) subprocess.check_output(cmd, shell=True) @@ -721,8 +738,8 @@ def _ecr_login_if_needed(boto_session, image): def _pull_image(image): - pull_image_command = ('docker pull %s' % image).strip() - logger.info('docker command: {}'.format(pull_image_command)) + pull_image_command = ("docker pull %s" % image).strip() + logger.info("docker command: {}".format(pull_image_command)) subprocess.check_output(pull_image_command, shell=True) - logger.info('image pulled: {}'.format(image)) + logger.info("image pulled: {}".format(image)) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 75595fada5..247dc17790 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -20,8 +20,13 @@ from botocore.exceptions import ClientError from sagemaker.local.image import _SageMakerContainer -from sagemaker.local.entities import (_LocalEndpointConfig, _LocalEndpoint, _LocalModel, - _LocalTrainingJob, _LocalTransformJob) +from sagemaker.local.entities import ( + _LocalEndpointConfig, + _LocalEndpoint, + _LocalModel, + _LocalTrainingJob, + _LocalTransformJob, +) from sagemaker.session import Session from sagemaker.utils import get_config_value @@ -52,8 +57,15 @@ def __init__(self, sagemaker_session=None): """ self.sagemaker_session = sagemaker_session or LocalSession() - def create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig, - ResourceConfig, InputDataConfig=None, **kwargs): + def create_training_job( + self, + TrainingJobName, + AlgorithmSpecification, + OutputDataConfig, + ResourceConfig, + InputDataConfig=None, + **kwargs + ): """ Create a training job in Local Mode Args: @@ -66,10 +78,14 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDat the final model. 
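        Example (a minimal sketch; the image URI and paths are placeholders):
            client = LocalSagemakerClient()
            client.create_training_job(
                "my-job",
                {"TrainingImage": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image"},
                {"S3OutputPath": "file:///tmp/output"},
                {"InstanceType": "local", "InstanceCount": 1},
                HyperParameters={"epochs": "1"},
            )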
""" InputDataConfig = InputDataConfig or {} - container = _SageMakerContainer(ResourceConfig['InstanceType'], ResourceConfig['InstanceCount'], - AlgorithmSpecification['TrainingImage'], self.sagemaker_session) + container = _SageMakerContainer( + ResourceConfig["InstanceType"], + ResourceConfig["InstanceCount"], + AlgorithmSpecification["TrainingImage"], + self.sagemaker_session, + ) training_job = _LocalTrainingJob(container) - hyperparameters = kwargs['HyperParameters'] if 'HyperParameters' in kwargs else {} + hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {} training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName) LocalSagemakerClient._training_jobs[TrainingJobName] = training_job @@ -84,25 +100,44 @@ def describe_training_job(self, TrainingJobName): """ if TrainingJobName not in LocalSagemakerClient._training_jobs: - error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local training job'}} - raise ClientError(error_response, 'describe_training_job') + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "Could not find local training job", + } + } + raise ClientError(error_response, "describe_training_job") else: return LocalSagemakerClient._training_jobs[TrainingJobName].describe() - def create_transform_job(self, TransformJobName, ModelName, TransformInput, TransformOutput, - TransformResources, **kwargs): + def create_transform_job( + self, + TransformJobName, + ModelName, + TransformInput, + TransformOutput, + TransformResources, + **kwargs + ): transform_job = _LocalTransformJob(TransformJobName, ModelName, self.sagemaker_session) LocalSagemakerClient._transform_jobs[TransformJobName] = transform_job transform_job.start(TransformInput, TransformOutput, TransformResources, **kwargs) def describe_transform_job(self, TransformJobName): if TransformJobName not in LocalSagemakerClient._transform_jobs: - error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local transform job'}} - raise ClientError(error_response, 'describe_transform_job') + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "Could not find local transform job", + } + } + raise ClientError(error_response, "describe_transform_job") else: return LocalSagemakerClient._transform_jobs[TransformJobName].describe() - def create_model(self, ModelName, PrimaryContainer, *args, **kwargs): # pylint: disable=unused-argument + def create_model( + self, ModelName, PrimaryContainer, *args, **kwargs + ): # pylint: disable=unused-argument """Create a Local Model Object Args: @@ -113,8 +148,10 @@ def create_model(self, ModelName, PrimaryContainer, *args, **kwargs): # pylint: def describe_model(self, ModelName): if ModelName not in LocalSagemakerClient._models: - error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local model'}} - raise ClientError(error_response, 'describe_model') + error_response = { + "Error": {"Code": "ValidationException", "Message": "Could not find local model"} + } + raise ClientError(error_response, "describe_model") else: return LocalSagemakerClient._models[ModelName].describe() @@ -122,18 +159,25 @@ def describe_endpoint_config(self, EndpointConfigName): if EndpointConfigName in LocalSagemakerClient._endpoint_configs: return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe() else: - error_response = {'Error': { - 'Code': 'ValidationException', 'Message': 'Could not find local 
endpoint config'}} - raise ClientError(error_response, 'describe_endpoint_config') + error_response = { + "Error": { + "Code": "ValidationException", + "Message": "Could not find local endpoint config", + } + } + raise ClientError(error_response, "describe_endpoint_config") def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None): LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig( - EndpointConfigName, ProductionVariants, Tags) + EndpointConfigName, ProductionVariants, Tags + ) def describe_endpoint(self, EndpointName): if EndpointName not in LocalSagemakerClient._endpoints: - error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find local endpoint'}} - raise ClientError(error_response, 'describe_endpoint') + error_response = { + "Error": {"Code": "ValidationException", "Message": "Could not find local endpoint"} + } + raise ClientError(error_response, "describe_endpoint") else: return LocalSagemakerClient._endpoints[EndpointName].describe() @@ -143,7 +187,7 @@ def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): endpoint.serve() def update_endpoint(self, EndpointName, EndpointConfigName): # pylint: disable=unused-argument - raise NotImplementedError('Update endpoint name is not supported in local session.') + raise NotImplementedError("Update endpoint name is not supported in local session.") def delete_endpoint(self, EndpointName): if EndpointName in LocalSagemakerClient._endpoints: @@ -162,6 +206,7 @@ class LocalSagemakerRuntimeClient(object): """A SageMaker Runtime client that calls a local endpoint only. """ + def __init__(self, config=None): """Initializes a LocalSageMakerRuntimeClient @@ -172,34 +217,38 @@ def __init__(self, config=None): self.http = urllib3.PoolManager() self.serving_port = 8080 self.config = config - self.serving_port = get_config_value('local.serving_port', config) or 8080 - - def invoke_endpoint(self, Body, EndpointName, # pylint: disable=unused-argument - ContentType=None, Accept=None, CustomAttributes=None): + self.serving_port = get_config_value("local.serving_port", config) or 8080 + + def invoke_endpoint( + self, + Body, + EndpointName, # pylint: disable=unused-argument + ContentType=None, + Accept=None, + CustomAttributes=None, + ): url = "http://localhost:%s/invocations" % self.serving_port headers = {} if ContentType is not None: - headers['Content-type'] = ContentType + headers["Content-type"] = ContentType if Accept is not None: - headers['Accept'] = Accept + headers["Accept"] = Accept if CustomAttributes is not None: - headers['X-Amzn-SageMaker-Custom-Attributes'] = CustomAttributes + headers["X-Amzn-SageMaker-Custom-Attributes"] = CustomAttributes - r = self.http.request('POST', url, body=Body, preload_content=False, - headers=headers) + r = self.http.request("POST", url, body=Body, preload_content=False, headers=headers) - return {'Body': r, 'ContentType': Accept} + return {"Body": r, "ContentType": Accept} class LocalSession(Session): - def __init__(self, boto_session=None): super(LocalSession, self).__init__(boto_session) - if platform.system() == 'Windows': + if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): @@ -209,7 +258,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self._region_name = self.boto_session.region_name if self._region_name is None: - raise ValueError('Must 
setup local AWS configuration with a region supported by SageMaker.') + raise ValueError( + "Must setup local AWS configuration with a region supported by SageMaker." + ) self.sagemaker_client = LocalSagemakerClient(self) self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) @@ -232,13 +283,13 @@ def __init__(self, fileUri, content_type=None): """Create a definition for input data used by an SageMaker training job in local mode. """ self.config = { - 'DataSource': { - 'FileDataSource': { - 'FileDataDistributionType': 'FullyReplicated', - 'FileUri': fileUri + "DataSource": { + "FileDataSource": { + "FileDataDistributionType": "FullyReplicated", + "FileUri": fileUri, } } } if content_type is not None: - self.config['ContentType'] = content_type + self.config["ContentType"] = content_type diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 4c64458c55..ab316e21d3 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -53,16 +53,16 @@ def move_to_destination(source, destination, job_name, sagemaker_session): (str): destination URI """ parsed_uri = urlparse(destination) - if parsed_uri.scheme == 'file': + if parsed_uri.scheme == "file": recursive_copy(source, parsed_uri.path) final_uri = destination - elif parsed_uri.scheme == 's3': + elif parsed_uri.scheme == "s3": bucket = parsed_uri.netloc - path = "%s%s" % (parsed_uri.path.lstrip('/'), job_name) - final_uri = 's3://%s/%s' % (bucket, path) + path = "%s%s" % (parsed_uri.path.lstrip("/"), job_name) + final_uri = "s3://%s/%s" % (bucket, path) sagemaker_session.upload_data(source, bucket, path) else: - raise ValueError('Invalid destination URI, must be s3:// or file://, got: %s' % destination) + raise ValueError("Invalid destination URI, must be s3:// or file://, got: %s" % destination) shutil.rmtree(source) return final_uri diff --git a/src/sagemaker/logs.py b/src/sagemaker/logs.py index b167f1c84a..3d3582ab33 100644 --- a/src/sagemaker/logs.py +++ b/src/sagemaker/logs.py @@ -37,7 +37,7 @@ def __init__(self, force=False): Args: force (bool): If True, render colorizes output no matter where the output is (default: False). """ - self.colorize = force or sys.stdout.isatty() or os.environ.get('JPY_PARENT_PID', None) + self.colorize = force or sys.stdout.isatty() or os.environ.get("JPY_PARENT_PID", None) def __call__(self, index, s): """Print the output, colorized or not, depending on the environment. @@ -52,7 +52,7 @@ def __call__(self, index, s): print(s) def _color_wrap(self, index, s): - print('\x1b[{}m{}\x1b[0m'.format(self._stream_colors[index % len(self._stream_colors)], s)) + print("\x1b[{}m{}\x1b[0m".format(self._stream_colors[index % len(self._stream_colors)], s)) def argmin(arr, f): @@ -74,7 +74,7 @@ def some(arr): # Position is a tuple that includes the last read timestamp and the number of items that were read # at that time. This is used to figure out which event to start with on the next read. -Position = collections.namedtuple('Position', ['timestamp', 'skip']) +Position = collections.namedtuple("Position", ["timestamp", "skip"]) def multi_stream_iter(client, log_group, streams, positions=None): @@ -93,7 +93,9 @@ def multi_stream_iter(client, log_group, streams, positions=None): A tuple of (stream number, cloudwatch log event). 
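        Example (illustrative; assumes a boto3 CloudWatch Logs client and an
        existing stream name):
            for idx, event in multi_stream_iter(client, "/aws/sagemaker/TrainingJobs",
                                                ["my-job/algo-1-0001"]):
                print(idx, event["message"])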
""" positions = positions or {s: Position(timestamp=0, skip=0) for s in streams} - event_iters = [log_stream(client, log_group, s, positions[s].timestamp, positions[s].skip) for s in streams] + event_iters = [ + log_stream(client, log_group, s, positions[s].timestamp, positions[s].skip) for s in streams + ] events = [] for s in event_iters: if not s: @@ -105,7 +107,7 @@ def multi_stream_iter(client, log_group, streams, positions=None): events.append(None) while some(events): - i = argmin(events, lambda x: x['timestamp'] if x else 9999999999) + i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) yield (i, events[i]) try: events[i] = next(event_iters[i]) @@ -137,14 +139,19 @@ def log_stream(client, log_group, stream_name, start_time=0, skip=0): event_count = 1 while event_count > 0: if next_token is not None: - token_arg = {'nextToken': next_token} + token_arg = {"nextToken": next_token} else: token_arg = {} - response = client.get_log_events(logGroupName=log_group, logStreamName=stream_name, startTime=start_time, - startFromHead=True, **token_arg) - next_token = response['nextForwardToken'] - events = response['events'] + response = client.get_log_events( + logGroupName=log_group, + logStreamName=stream_name, + startTime=start_time, + startFromHead=True, + **token_arg + ) + next_token = response["nextForwardToken"] + events = response["events"] event_count = len(events) if event_count > skip: events = events[skip:] diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a5c7a96ed3..2ee5116505 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -21,28 +21,52 @@ from sagemaker.fw_utils import UploadedCode from sagemaker.transformer import Transformer -LOGGER = logging.getLogger('sagemaker') - -NEO_ALLOWED_TARGET_INSTANCE_FAMILY = set(['ml_c5', 'ml_m5', 'ml_c4', 'ml_m4', - 'jetson_tx1', 'jetson_tx2', 'jetson_nano', 'ml_p2', - 'ml_p3', 'deeplens', 'rasp3b', - 'rk3288', 'rk3399', 'sbe_c']) -NEO_ALLOWED_FRAMEWORKS = set(['mxnet', 'tensorflow', 'pytorch', 'onnx', 'xgboost']) +LOGGER = logging.getLogger("sagemaker") + +NEO_ALLOWED_TARGET_INSTANCE_FAMILY = set( + [ + "ml_c5", + "ml_m5", + "ml_c4", + "ml_m4", + "jetson_tx1", + "jetson_tx2", + "jetson_nano", + "ml_p2", + "ml_p3", + "deeplens", + "rasp3b", + "rk3288", + "rk3399", + "sbe_c", + ] +) +NEO_ALLOWED_FRAMEWORKS = set(["mxnet", "tensorflow", "pytorch", "onnx", "xgboost"]) NEO_IMAGE_ACCOUNT = { - 'us-west-2': '301217895009', - 'us-east-1': '785573368785', - 'eu-west-1': '802834080501', - 'us-east-2': '007439368137', - 'ap-northeast-1': '941853720454' + "us-west-2": "301217895009", + "us-east-1": "785573368785", + "eu-west-1": "802834080501", + "us-east-2": "007439368137", + "ap-northeast-1": "941853720454", } class Model(object): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" - def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, name=None, vpc_config=None, - sagemaker_session=None, enable_network_isolation=False): + def __init__( + self, + model_data, + image, + role=None, + predictor_cls=None, + env=None, + name=None, + vpc_config=None, + sagemaker_session=None, + enable_network_isolation=False, + ): """Initialize an SageMaker ``Model``. 
Args: @@ -79,7 +103,9 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n self._is_compiled_model = False self._enable_network_isolation = enable_network_isolation - def prepare_container_def(self, instance_type, accelerator_type=None): # pylint: disable=unused-argument + def prepare_container_def( + self, instance_type, accelerator_type=None + ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type. Subclasses can override this to provide custom container definitions for @@ -119,40 +145,55 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non """ container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type) - self.name = self.name or utils.name_from_image(container_def['Image']) + self.name = self.name or utils.name_from_image(container_def["Image"]) enable_network_isolation = self.enable_network_isolation() - self.sagemaker_session.create_model(self.name, self.role, - container_def, vpc_config=self.vpc_config, - enable_network_isolation=enable_network_isolation, - tags=tags) + self.sagemaker_session.create_model( + self.name, + self.role, + container_def, + vpc_config=self.vpc_config, + enable_network_isolation=enable_network_isolation, + tags=tags, + ) def _framework(self): - return getattr(self, '__framework_name__', None) + return getattr(self, "__framework_name__", None) def _get_framework_version(self): - return getattr(self, 'framework_version', None) - - def _compilation_job_config(self, target_instance_type, input_shape, output_path, role, compile_max_run, - job_name, framework, tags): + return getattr(self, "framework_version", None) + + def _compilation_job_config( + self, + target_instance_type, + input_shape, + output_path, + role, + compile_max_run, + job_name, + framework, + tags, + ): input_model_config = { - 'S3Uri': self.model_data, - 'DataInputConfig': input_shape if type(input_shape) != dict else json.dumps(input_shape), - 'Framework': framework + "S3Uri": self.model_data, + "DataInputConfig": input_shape + if type(input_shape) != dict + else json.dumps(input_shape), + "Framework": framework, } role = self.sagemaker_session.expand_role(role) output_model_config = { - 'TargetDevice': target_instance_type, - 'S3OutputLocation': output_path + "TargetDevice": target_instance_type, + "S3OutputLocation": output_path, } - return {'input_model_config': input_model_config, - 'output_model_config': output_model_config, - 'role': role, - 'stop_condition': { - 'MaxRuntimeInSeconds': compile_max_run - }, - 'tags': tags, - 'job_name': job_name} + return { + "input_model_config": input_model_config, + "output_model_config": output_model_config, + "role": role, + "stop_condition": {"MaxRuntimeInSeconds": compile_max_run}, + "tags": tags, + "job_name": job_name, + } def check_neo_region(self, region): """Check if this ``Model`` in the available region where neo support. 
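# A minimal sketch of the request dict assembled by _compilation_job_config above
# (the S3 paths and job name are placeholders; role expansion requires a live session):
#
#     config = model._compilation_job_config(
#         "ml_c5", {"data": [1, 3, 224, 224]}, "s3://bucket/output", "SageMakerRole",
#         300, "compile-job-1", "TENSORFLOW", None,
#     )
#     config["input_model_config"]["DataInputConfig"] == '{"data": [1, 3, 224, 224]}'
#     config["stop_condition"] == {"MaxRuntimeInSeconds": 300}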
@@ -169,20 +210,34 @@ def check_neo_region(self, region): def _neo_image_account(self, region): if region not in NEO_IMAGE_ACCOUNT: - raise ValueError("Neo is not currently supported in {}, " - "valid regions: {}".format(region, NEO_IMAGE_ACCOUNT.keys())) + raise ValueError( + "Neo is not currently supported in {}, " + "valid regions: {}".format(region, NEO_IMAGE_ACCOUNT.keys()) + ) return NEO_IMAGE_ACCOUNT[region] def _neo_image(self, region, target_instance_type, framework, framework_version): - return fw_utils.create_image_uri(region, - 'neo-' + framework.lower(), - target_instance_type.replace('_', '.'), - framework_version, - py_version='py3', - account=self._neo_image_account(region)) - - def compile(self, target_instance_family, input_shape, output_path, role, - tags=None, job_name=None, compile_max_run=5 * 60, framework=None, framework_version=None): + return fw_utils.create_image_uri( + region, + "neo-" + framework.lower(), + target_instance_type.replace("_", "."), + framework_version, + py_version="py3", + account=self._neo_image_account(region), + ) + + def compile( + self, + target_instance_family, + input_shape, + output_path, + role, + tags=None, + job_name=None, + compile_max_run=5 * 60, + framework=None, + framework_version=None, + ): """Compile this ``Model`` with SageMaker Neo. Args: @@ -207,31 +262,58 @@ def compile(self, target_instance_family, input_shape, output_path, role, """ framework = self._framework() or framework if framework is None: - raise ValueError("You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)) + raise ValueError( + "You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS) + ) if framework not in NEO_ALLOWED_FRAMEWORKS: - raise ValueError("You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)) + raise ValueError( + "You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS) + ) if job_name is None: raise ValueError("You must provide a compilation job name") framework = framework.upper() framework_version = self._get_framework_version() or framework_version - config = self._compilation_job_config(target_instance_family, input_shape, output_path, role, - compile_max_run, job_name, framework, tags) + config = self._compilation_job_config( + target_instance_family, + input_shape, + output_path, + role, + compile_max_run, + job_name, + framework, + tags, + ) self.sagemaker_session.compile_model(**config) job_status = self.sagemaker_session.wait_for_compilation_job(job_name) - self.model_data = job_status['ModelArtifacts']['S3ModelArtifacts'] - if target_instance_family.startswith('ml_'): - self.image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework, - framework_version) + self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"] + if target_instance_family.startswith("ml_"): + self.image = self._neo_image( + self.sagemaker_session.boto_region_name, + target_instance_family, + framework, + framework_version, + ) self._is_compiled_model = True else: - LOGGER.warning("The intance type {} is not supported to deploy via SageMaker," - "please deploy the model on the device by yourself.".format(target_instance_family)) + LOGGER.warning( + "The intance type {} is not supported to deploy via SageMaker," + "please deploy the model on the device by yourself.".format(target_instance_family) + ) return self - def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, - 
update_endpoint=False, tags=None, kms_key=None, wait=True): + def deploy( + self, + initial_instance_count, + instance_type, + accelerator_type=None, + endpoint_name=None, + update_endpoint=False, + tags=None, + kms_key=None, + wait=True, + ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``. @@ -266,7 +348,7 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ if not self.sagemaker_session: - if instance_type in ('local', 'local_gpu'): + if instance_type in ("local", "local_gpu"): self.sagemaker_session = local.LocalSession() else: self.sagemaker_session = session.Session() @@ -274,13 +356,14 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e if self.role is None: raise ValueError("Role can not be null for deploying a model") - compiled_model_suffix = '-'.join(instance_type.split('.')[:-1]) + compiled_model_suffix = "-".join(instance_type.split(".")[:-1]) if self._is_compiled_model: self.name += compiled_model_suffix self._create_sagemaker_model(instance_type, accelerator_type, tags) - production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count, - accelerator_type=accelerator_type) + production_variant = sagemaker.production_variant( + self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type + ) if endpoint_name: self.endpoint_name = endpoint_name else: @@ -296,18 +379,32 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e instance_type=instance_type, accelerator_type=accelerator_type, tags=tags, - kms_key=kms_key) + kms_key=kms_key, + ) self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name) else: - self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], - tags, kms_key, wait) + self.sagemaker_session.endpoint_from_production_variants( + self.endpoint_name, [production_variant], tags, kms_key, wait + ) if self.predictor_cls: return self.predictor_cls(self.endpoint_name, self.sagemaker_session) - def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, volume_kms_key=None): + def transformer( + self, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + volume_kms_key=None, + ): """Return a ``Transformer`` that uses this Model. 
Args: @@ -333,11 +430,23 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit if self.enable_network_isolation(): env = None - return Transformer(self.name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, - output_path=output_path, output_kms_key=output_kms_key, accept=accept, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=env, tags=tags, base_transform_job_name=self.name, - volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) + return Transformer( + self.name, + instance_count, + instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=env, + tags=tags, + base_transform_job_name=self.name, + volume_kms_key=volume_kms_key, + sagemaker_session=self.sagemaker_session, + ) def delete_model(self): """Delete an Amazon SageMaker Model. @@ -347,18 +456,20 @@ def delete_model(self): """ if self.name is None: - raise ValueError('The SageMaker model must be created first before attempting to delete.') + raise ValueError( + "The SageMaker model must be created first before attempting to delete." + ) self.sagemaker_session.delete_model(self.name) -SCRIPT_PARAM_NAME = 'sagemaker_program' -DIR_PARAM_NAME = 'sagemaker_submit_directory' -CLOUDWATCH_METRICS_PARAM_NAME = 'sagemaker_enable_cloudwatch_metrics' -CONTAINER_LOG_LEVEL_PARAM_NAME = 'sagemaker_container_log_level' -JOB_NAME_PARAM_NAME = 'sagemaker_job_name' -MODEL_SERVER_WORKERS_PARAM_NAME = 'sagemaker_model_server_workers' -SAGEMAKER_REGION_PARAM_NAME = 'sagemaker_region' -SAGEMAKER_OUTPUT_LOCATION = 'sagemaker_s3_output' +SCRIPT_PARAM_NAME = "sagemaker_program" +DIR_PARAM_NAME = "sagemaker_submit_directory" +CLOUDWATCH_METRICS_PARAM_NAME = "sagemaker_enable_cloudwatch_metrics" +CONTAINER_LOG_LEVEL_PARAM_NAME = "sagemaker_container_log_level" +JOB_NAME_PARAM_NAME = "sagemaker_job_name" +MODEL_SERVER_WORKERS_PARAM_NAME = "sagemaker_model_server_workers" +SAGEMAKER_REGION_PARAM_NAME = "sagemaker_region" +SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output" class FrameworkModel(Model): @@ -367,9 +478,23 @@ class FrameworkModel(Model): This class hosts user-defined code in S3 and sets code location and configuration in model environment variables. """ - def __init__(self, model_data, image, role, entry_point, source_dir=None, predictor_cls=None, env=None, name=None, - enable_cloudwatch_metrics=False, container_log_level=logging.INFO, code_location=None, - sagemaker_session=None, dependencies=None, **kwargs): + def __init__( + self, + model_data, + image, + role, + entry_point, + source_dir=None, + predictor_cls=None, + env=None, + name=None, + enable_cloudwatch_metrics=False, + container_log_level=logging.INFO, + code_location=None, + sagemaker_session=None, + dependencies=None, + **kwargs + ): """Initialize a ``FrameworkModel``. Args: @@ -415,8 +540,16 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic interactions (default: None). If not specified, one is created using the default AWS configuration chain. **kwargs: Keyword arguments passed to the ``Model`` initializer. 
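        Example (a minimal sketch; the S3 URI, image, and role are placeholders,
        and in practice a framework-specific subclass such as ``MXNetModel`` is
        used rather than this base class):
            model = FrameworkModel(
                "s3://bucket/model.tar.gz",
                "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
                "SageMakerRole",
                "inference.py",
            )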
""" - super(FrameworkModel, self).__init__(model_data, image, role, predictor_cls=predictor_cls, env=env, name=name, - sagemaker_session=sagemaker_session, **kwargs) + super(FrameworkModel, self).__init__( + model_data, + image, + role, + predictor_cls=predictor_cls, + env=env, + name=name, + sagemaker_session=sagemaker_session, + **kwargs + ) self.entry_point = entry_point self.source_dir = source_dir self.dependencies = dependencies or [] @@ -429,7 +562,9 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic self.uploaded_code = None self.repacked_model_data = None - def prepare_container_def(self, instance_type, accelerator_type=None): # pylint disable=unused-argument + def prepare_container_def( + self, instance_type, accelerator_type=None + ): # pylint disable=unused-argument """Return a container definition with framework configuration set in model environment variables. This also uploads user-supplied code to S3. @@ -449,32 +584,37 @@ def prepare_container_def(self, instance_type, accelerator_type=None): # pylint return sagemaker.container_def(self.image, self.model_data, deploy_env) def _upload_code(self, key_prefix, repack=False): - local_code = utils.get_config_value('local.local_code', self.sagemaker_session.config) + local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: self.uploaded_code = None elif not repack: bucket = self.bucket or self.sagemaker_session.default_bucket() - self.uploaded_code = fw_utils.tar_and_upload_dir(session=self.sagemaker_session.boto_session, - bucket=bucket, - s3_key_prefix=key_prefix, - script=self.entry_point, - directory=self.source_dir, - dependencies=self.dependencies) + self.uploaded_code = fw_utils.tar_and_upload_dir( + session=self.sagemaker_session.boto_session, + bucket=bucket, + s3_key_prefix=key_prefix, + script=self.entry_point, + directory=self.source_dir, + dependencies=self.dependencies, + ) if repack: bucket = self.bucket or self.sagemaker_session.default_bucket() - repacked_model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') + repacked_model_data = "s3://" + os.path.join(bucket, key_prefix, "model.tar.gz") - utils.repack_model(inference_script=self.entry_point, - source_directory=self.source_dir, - dependencies=self.dependencies, - model_uri=self.model_data, - repacked_model_uri=repacked_model_data, - sagemaker_session=self.sagemaker_session) + utils.repack_model( + inference_script=self.entry_point, + source_directory=self.source_dir, + dependencies=self.dependencies, + model_uri=self.model_data, + repacked_model_uri=repacked_model_data, + sagemaker_session=self.sagemaker_session, + ) self.repacked_model_data = repacked_model_data - self.uploaded_code = UploadedCode(s3_prefix=self.repacked_model_data, - script_name=os.path.basename(self.entry_point)) + self.uploaded_code = UploadedCode( + s3_prefix=self.repacked_model_data, script_name=os.path.basename(self.entry_point) + ) def _framework_env_vars(self): if self.uploaded_code: @@ -482,22 +622,21 @@ def _framework_env_vars(self): dir_name = self.uploaded_code.s3_prefix else: script_name = self.entry_point - dir_name = 'file://' + self.source_dir + dir_name = "file://" + self.source_dir return { SCRIPT_PARAM_NAME.upper(): script_name, DIR_PARAM_NAME.upper(): dir_name, CLOUDWATCH_METRICS_PARAM_NAME.upper(): str(self.enable_cloudwatch_metrics).lower(), CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), - 
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name + SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } class ModelPackage(Model): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" - def __init__(self, role, model_data=None, algorithm_arn=None, - model_package_arn=None, **kwargs): + def __init__(self, role, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs): """Initialize a SageMaker ModelPackage. Args: @@ -513,26 +652,24 @@ def __init__(self, role, model_data=None, algorithm_arn=None, your account owns the Model Package. ``model_data`` is not required. **kwargs: Additional kwargs passed to the Model constructor. """ - super(ModelPackage, self).__init__( - role=role, - model_data=model_data, - image=None, - **kwargs - ) + super(ModelPackage, self).__init__(role=role, model_data=model_data, image=None, **kwargs) if model_package_arn and algorithm_arn: - raise ValueError('model_package_arn and algorithm_arn are mutually exclusive.' - 'Both were provided: model_package_arn: %s algorithm_arn: %s' % - (model_package_arn, algorithm_arn)) + raise ValueError( + "model_package_arn and algorithm_arn are mutually exclusive." + "Both were provided: model_package_arn: %s algorithm_arn: %s" + % (model_package_arn, algorithm_arn) + ) if model_package_arn is None and algorithm_arn is None: - raise ValueError('either model_package_arn or algorithm_arn is required.' - ' None was provided.') + raise ValueError( + "either model_package_arn or algorithm_arn is required." " None was provided." + ) self.algorithm_arn = algorithm_arn if self.algorithm_arn is not None: if model_data is None: - raise ValueError('model_data must be provided with algorithm_arn') + raise ValueError("model_data must be provided with algorithm_arn") self.model_data = model_data self.model_package_arn = model_package_arn @@ -540,13 +677,13 @@ def __init__(self, role, model_data=None, algorithm_arn=None, def _create_sagemaker_model_package(self): if self.algorithm_arn is None: - raise ValueError('No algorithm_arn was provided to create a SageMaker Model Pacakge') + raise ValueError("No algorithm_arn was provided to create a SageMaker Model Pacakge") - name = self.name or utils.name_from_base(self.algorithm_arn.split('/')[-1]) - description = 'Model Package created from training with %s' % self.algorithm_arn - self.sagemaker_session.create_model_package_from_algorithm(name, description, - self.algorithm_arn, - self.model_data) + name = self.name or utils.name_from_base(self.algorithm_arn.split("/")[-1]) + description = "Model Package created from training with %s" % self.algorithm_arn + self.sagemaker_session.create_model_package_from_algorithm( + name, description, self.algorithm_arn, self.model_data + ) return name def enable_network_isolation(self): @@ -569,8 +706,8 @@ def _is_marketplace(self): model_package_desc = sagemaker_session.sagemaker_client.describe_model_package( ModelPackageName=model_package_name ) - for container in model_package_desc['InferenceSpecification']['Containers']: - if 'ProductId' in container: + for container in model_package_desc["InferenceSpecification"]["Containers"]: + if "ProductId" in container: return True return False @@ -593,16 +730,18 @@ def _create_sagemaker_model(self, *args): # pylint: disable=unused-argument # When a ModelPackageArn is provided we just create the Model model_package_name = self.model_package_arn - container_def = { - 'ModelPackageName': model_package_name, - } + container_def = 
{"ModelPackageName": model_package_name} if self.env != {}: - container_def['Environment'] = self.env + container_def["Environment"] = self.env - model_package_short_name = model_package_name.split('/')[-1] + model_package_short_name = model_package_name.split("/")[-1] enable_network_isolation = self.enable_network_isolation() self.name = self.name or utils.name_from_base(model_package_short_name) - self.sagemaker_session.create_model(self.name, self.role, container_def, - vpc_config=self.vpc_config, - enable_network_isolation=enable_network_isolation) + self.sagemaker_session.create_model( + self.name, + self.role, + container_def, + vpc_config=self.vpc_config, + enable_network_isolation=enable_network_isolation, + ) diff --git a/src/sagemaker/mxnet/defaults.py b/src/sagemaker/mxnet/defaults.py index 43fc487efc..9e693c87bc 100644 --- a/src/sagemaker/mxnet/defaults.py +++ b/src/sagemaker/mxnet/defaults.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -MXNET_VERSION = '1.2' +MXNET_VERSION = "1.2" """Default MXNet version for when the framework version is not specified. This is no longer updated so as to not break existing workflows. """ diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 43634b1bba..81a81a768b 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -15,26 +15,39 @@ import logging from sagemaker.estimator import Framework -from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \ - python_deprecation_warning +from sagemaker.fw_utils import ( + framework_name_from_image, + framework_version_from_tag, + empty_framework_version_warning, + python_deprecation_warning, +) from sagemaker.mxnet.defaults import MXNET_VERSION from sagemaker.mxnet.model import MXNetModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class MXNet(Framework): """Handle end-to-end training and deployment of custom MXNet code.""" - __framework_name__ = 'mxnet' - _LOWEST_SCRIPT_MODE_VERSION = ['1', '3'] + __framework_name__ = "mxnet" + _LOWEST_SCRIPT_MODE_VERSION = ["1", "3"] - LATEST_VERSION = '1.4' + LATEST_VERSION = "1.4" """The latest version of MXNet included in the SageMaker pre-built Docker images.""" - def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2', - framework_version=None, image_name=None, distributions=None, **kwargs): + def __init__( + self, + entry_point, + source_dir=None, + hyperparameters=None, + py_version="py2", + framework_version=None, + image_name=None, + distributions=None, + **kwargs + ): """ This ``Estimator`` executes an MXNet script in a managed MXNet execution environment, within a SageMaker Training Job. 
The managed MXNet environment is an Amazon-built Docker container that executes functions @@ -78,10 +91,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio logger.warning(empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION)) self.framework_version = framework_version or MXNET_VERSION - super(MXNet, self).__init__(entry_point, source_dir, hyperparameters, - image_name=image_name, **kwargs) + super(MXNet, self).__init__( + entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -91,15 +105,20 @@ def _configure_distribution(self, distributions): if distributions is None: return - if self.framework_version.split('.') < self._LOWEST_SCRIPT_MODE_VERSION: - raise ValueError('The distributions option is valid for only versions {} and higher' - .format('.'.join(self._LOWEST_SCRIPT_MODE_VERSION))) + if self.framework_version.split(".") < self._LOWEST_SCRIPT_MODE_VERSION: + raise ValueError( + "The distributions option is valid for only versions {} and higher".format( + ".".join(self._LOWEST_SCRIPT_MODE_VERSION) + ) + ) - if 'parameter_server' in distributions: - enabled = distributions['parameter_server'].get('enabled', False) + if "parameter_server" in distributions: + enabled = distributions["parameter_server"].get("enabled", False) self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled - def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): + def create_model( + self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT + ): """Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``. Args: @@ -117,12 +136,23 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override See :func:`~sagemaker.mxnet.model.MXNetModel` for full details. 
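        Example (illustrative; assumes a fitted ``MXNet`` estimator):
            model = estimator.create_model(model_server_workers=2)
            predictor = model.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")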
""" role = role or self.role - return MXNetModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, - container_log_level=self.container_log_level, code_location=self.code_location, - py_version=self.py_version, framework_version=self.framework_version, image=self.image_name, - model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies) + return MXNetModel( + self.model_data, + role, + self.entry_point, + source_dir=self._model_source_dir(), + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + name=self._current_job_name, + container_log_level=self.container_log_level, + code_location=self.code_location, + py_version=self.py_version, + framework_version=self.framework_version, + image=self.image_name, + model_server_workers=model_server_workers, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + dependencies=self.dependencies, + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -136,27 +166,35 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) - image_name = init_params.pop('image') + init_params = super(MXNet, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) + image_name = init_params.pop("image") framework, py_version, tag, _ = framework_name_from_image(image_name) if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. - init_params['image_name'] = image_name + init_params["image_name"] = image_name return init_params - init_params['py_version'] = py_version + init_params["py_version"] = py_version # We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive # containing framework version, device type and python version (e.g. '0.12-gpu-py2'). # For backward compatibility map deprecated image tag '1.0' to a '0.12' framework version # otherwise extract framework version from the tag itself. 
- init_params['framework_version'] = '0.12' if tag == '1.0' else framework_version_from_tag(tag) + init_params["framework_version"] = ( + "0.12" if tag == "1.0" else framework_version_from_tag(tag) + ) - training_job_name = init_params['base_job_name'] + training_job_name = init_params["base_job_name"] if framework != cls.__framework_name__: - raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name)) + raise ValueError( + "Training job: {} didn't use image for requested framework".format( + training_job_name + ) + ) return init_params diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 29f5040c5d..2107f0103e 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -22,7 +22,7 @@ from sagemaker.mxnet.defaults import MXNET_VERSION from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class MXNetPredictor(RealTimePredictor): @@ -40,17 +40,29 @@ def __init__(self, endpoint_name, sagemaker_session=None): Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. """ - super(MXNetPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer) + super(MXNetPredictor, self).__init__( + endpoint_name, sagemaker_session, json_serializer, json_deserializer + ) class MXNetModel(FrameworkModel): """An MXNet SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - __framework_name__ = 'mxnet' - _LOWEST_MMS_VERSION = '1.4' - - def __init__(self, model_data, role, entry_point, image=None, py_version='py2', framework_version=MXNET_VERSION, - predictor_cls=MXNetPredictor, model_server_workers=None, **kwargs): + __framework_name__ = "mxnet" + _LOWEST_MMS_VERSION = "1.4" + + def __init__( + self, + model_data, + role, + entry_point, + image=None, + py_version="py2", + framework_version=MXNET_VERSION, + predictor_cls=MXNetPredictor, + model_server_workers=None, + **kwargs + ): """Initialize an MXNetModel. Args: @@ -71,10 +83,11 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2', If None, server will use one worker per vCPU. **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer. """ - super(MXNetModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, - **kwargs) + super(MXNetModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -92,7 +105,9 @@ def prepare_container_def(self, instance_type, accelerator_type=None): Returns: dict[str, str]: A container definition object usable with the CreateModel API. 
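        Example (a minimal sketch; the instance type is illustrative):
            container_def = mxnet_model.prepare_container_def("ml.c5.xlarge")
            # expected keys: "Image", "ModelDataUrl", "Environment"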
""" - is_mms_version = parse_version(self.framework_version) >= parse_version(self._LOWEST_MMS_VERSION) + is_mms_version = parse_version(self.framework_version) >= parse_version( + self._LOWEST_MMS_VERSION + ) deploy_image = self.image if not deploy_image: @@ -100,10 +115,16 @@ def prepare_container_def(self, instance_type, accelerator_type=None): framework_name = self.__framework_name__ if is_mms_version: - framework_name += '-serving' + framework_name += "-serving" - deploy_image = create_image_uri(region_name, framework_name, instance_type, - self.framework_version, self.py_version, accelerator_type=accelerator_type) + deploy_image = create_image_uri( + region_name, + framework_name, + instance_type, + self.framework_version, + self.py_version, + accelerator_type=accelerator_type, + ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix, is_mms_version) @@ -112,4 +133,6 @@ def prepare_container_def(self, instance_type, accelerator_type=None): if self.model_server_workers: deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) - return sagemaker.container_def(deploy_image, self.repacked_model_data or self.model_data, deploy_env) + return sagemaker.container_def( + deploy_image, self.repacked_model_data or self.model_data, deploy_env + ) diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 69e74941f3..16372d80d4 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -23,9 +23,9 @@ class ParameterRange(object): """ - __all_types__ = ('Continuous', 'Categorical', 'Integer') + __all_types__ = ("Continuous", "Categorical", "Integer") - def __init__(self, min_value, max_value, scaling_type='Auto'): + def __init__(self, min_value, max_value, scaling_type="Auto"): """Initialize a parameter range. Args: @@ -63,10 +63,12 @@ def as_tuning_range(self, name): Returns: dict[str, str]: A dictionary that contains the name and values of the hyperparameter. """ - return {'Name': name, - 'MinValue': to_str(self.min_value), - 'MaxValue': to_str(self.max_value), - 'ScalingType': self.scaling_type} + return { + "Name": name, + "MinValue": to_str(self.min_value), + "MaxValue": to_str(self.max_value), + "ScalingType": self.scaling_type, + } class ContinuousParameter(ParameterRange): @@ -75,7 +77,8 @@ class ContinuousParameter(ParameterRange): min_value (float): The minimum value for the range. max_value (float): The maximum value for the range. """ - __name__ = 'Continuous' + + __name__ = "Continuous" @classmethod def cast_to_type(cls, value): @@ -85,7 +88,8 @@ def cast_to_type(cls, value): class CategoricalParameter(ParameterRange): """A class for representing hyperparameters that have a discrete list of possible values. """ - __name__ = 'Categorical' + + __name__ = "Categorical" def __init__(self, values): # pylint: disable=super-init-not-called """Initialize a ``CategoricalParameter``. @@ -109,7 +113,7 @@ def as_tuning_range(self, name): Returns: dict[str, list[str]]: A dictionary that contains the name and values of the hyperparameter. """ - return {'Name': name, 'Values': self.values} + return {"Name": name, "Values": self.values} def as_json_range(self, name): """Represent the parameter range as a dictionary suitable for a request to @@ -124,7 +128,7 @@ def as_json_range(self, name): dict[str, list[str]]: A dictionary that contains the name and values of the hyperparameter, where the values are serialized as JSON. 
""" - return {'Name': name, 'Values': [json.dumps(v) for v in self.values]} + return {"Name": name, "Values": [json.dumps(v) for v in self.values]} def is_valid(self, value): return value in self.values @@ -140,7 +144,8 @@ class IntegerParameter(ParameterRange): min_value (int): The minimum value for the range. max_value (int): The maximum value for the range. """ - __name__ = 'Integer' + + __name__ = "Integer" @classmethod def cast_to_type(cls, value): diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index c0103d6d06..027f7a9c16 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -21,7 +21,9 @@ class PipelineModel(object): """A pipeline of SageMaker ``Model``s that can be deployed to an ``Endpoint``.""" - def __init__(self, models, role, predictor_cls=None, name=None, vpc_config=None, sagemaker_session=None): + def __init__( + self, models, role, predictor_cls=None, name=None, vpc_config=None, sagemaker_session=None + ): """Initialize an SageMaker ``Model`` which can be used to build an Inference Pipeline comprising of multiple model containers. @@ -67,7 +69,9 @@ def pipeline_container_def(self, instance_type): return sagemaker.pipeline_container_def(self.models, instance_type) - def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True): + def deploy( + self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True + ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``. @@ -97,13 +101,18 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags containers = self.pipeline_container_def(instance_type) - self.name = self.name or name_from_image(containers[0]['Image']) - self.sagemaker_session.create_model(self.name, self.role, containers, vpc_config=self.vpc_config) + self.name = self.name or name_from_image(containers[0]["Image"]) + self.sagemaker_session.create_model( + self.name, self.role, containers, vpc_config=self.vpc_config + ) - production_variant = sagemaker.production_variant(self.name, instance_type, initial_instance_count) + production_variant = sagemaker.production_variant( + self.name, instance_type, initial_instance_count + ) self.endpoint_name = endpoint_name or self.name - self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags, - wait=wait) + self.sagemaker_session.endpoint_from_production_variants( + self.endpoint_name, [production_variant], tags, wait=wait + ) if self.predictor_cls: return self.predictor_cls(self.endpoint_name, self.sagemaker_session) @@ -122,12 +131,26 @@ def _create_sagemaker_pipeline_model(self, instance_type): containers = self.pipeline_container_def(instance_type) - self.name = self.name or name_from_image(containers[0]['Image']) - self.sagemaker_session.create_model(self.name, self.role, containers, vpc_config=self.vpc_config) - - def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, volume_kms_key=None): + self.name = self.name or name_from_image(containers[0]["Image"]) + self.sagemaker_session.create_model( + self.name, self.role, containers, vpc_config=self.vpc_config + ) + + def transformer( + self, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + 
output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + volume_kms_key=None, + ): """Return a ``Transformer`` that uses this Model. Args: @@ -151,11 +174,23 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit """ self._create_sagemaker_pipeline_model(instance_type) - return Transformer(self.name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with, - output_path=output_path, output_kms_key=output_kms_key, accept=accept, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=env, tags=tags, base_transform_job_name=self.name, - volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) + return Transformer( + self.name, + instance_count, + instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=env, + tags=tags, + base_transform_job_name=self.name, + volume_kms_key=volume_kms_key, + sagemaker_session=self.sagemaker_session, + ) def delete_model(self): """Delete the SageMaker model backing this pipeline model. This does not delete the list of SageMaker models used @@ -164,6 +199,6 @@ def delete_model(self): """ if self.name is None: - raise ValueError('The SageMaker model must be created before attempting to delete.') + raise ValueError("The SageMaker model must be created before attempting to delete.") self.sagemaker_session.delete_model(self.name) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 10edbfe727..9b4f777178 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -27,8 +27,15 @@ class RealTimePredictor(object): """Make prediction requests to an Amazon SageMaker endpoint. """ - def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializer=None, - content_type=None, accept=None): + def __init__( + self, + endpoint, + sagemaker_session=None, + serializer=None, + deserializer=None, + content_type=None, + accept=None, + ): """Initialize a ``RealTimePredictor``. 
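
# Hedged sketch of the PipelineModel flow reformatted above: deploy() creates
# the SageMaker Model, EndpointConfig, and Endpoint in one call. The model
# images, S3 paths, and role ARN below are illustrative placeholders, not
# values from this diff.
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel

preprocessor = Model(
    model_data="s3://my-bucket/preprocess/model.tar.gz",
    image="123456789012.dkr.ecr.us-west-2.amazonaws.com/preprocess:latest",
)
inference_model = Model(
    model_data="s3://my-bucket/inference/model.tar.gz",
    image="123456789012.dkr.ecr.us-west-2.amazonaws.com/inference:latest",
)
pipeline = PipelineModel(
    models=[preprocessor, inference_model],
    role="arn:aws:iam::123456789012:role/SageMakerRole",
)
predictor = pipeline.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
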
Behavior for serialization of input data and deserialization of result data @@ -54,8 +61,8 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ self.sagemaker_session = sagemaker_session or Session() self.serializer = serializer self.deserializer = deserializer - self.content_type = content_type or getattr(serializer, 'content_type', None) - self.accept = accept or getattr(deserializer, 'accept', None) + self.content_type = content_type or getattr(serializer, "content_type", None) + self.accept = accept or getattr(deserializer, "accept", None) self._endpoint_config_name = self._get_endpoint_config_name() self._model_names = self._get_model_names() @@ -81,10 +88,10 @@ def predict(self, data, initial_args=None): return self._handle_response(response) def _handle_response(self, response): - response_body = response['Body'] + response_body = response["Body"] if self.deserializer is not None: # It's the deserializer's responsibility to close the stream - return self.deserializer(response_body, response['ContentType']) + return self.deserializer(response_body, response["ContentType"]) data = response_body.read() response_body.close() return data @@ -92,19 +99,19 @@ def _handle_response(self, response): def _create_request_args(self, data, initial_args=None): args = dict(initial_args) if initial_args else {} - if 'EndpointName' not in args: - args['EndpointName'] = self.endpoint + if "EndpointName" not in args: + args["EndpointName"] = self.endpoint - if self.content_type and 'ContentType' not in args: - args['ContentType'] = self.content_type + if self.content_type and "ContentType" not in args: + args["ContentType"] = self.content_type - if self.accept and 'Accept' not in args: - args['Accept'] = self.accept + if self.accept and "Accept" not in args: + args["Accept"] = self.accept if self.serializer is not None: data = self.serializer(data) - args['Body'] = data + args["Body"] = data return args def _delete_endpoint_config(self): @@ -142,19 +149,24 @@ def delete_model(self): failed_models.append(model_name) if request_failed: - raise Exception('One or more models cannot be deleted, please retry. \n' - 'Failed models: {}'.format(', '.join(failed_models))) + raise Exception( + "One or more models cannot be deleted, please retry. 
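
# Hedged sketch of the fallback wiring above: content_type and accept default
# to attributes of the serializer/deserializer when not passed explicitly, and
# predict() builds the InvokeEndpoint arguments from them. Requires AWS
# credentials and an existing endpoint; "my-endpoint" is a placeholder.
from sagemaker.predictor import RealTimePredictor, csv_serializer, json_deserializer

predictor = RealTimePredictor(
    "my-endpoint", serializer=csv_serializer, deserializer=json_deserializer
)
assert predictor.content_type == "text/csv"    # taken from csv_serializer
assert predictor.accept == "application/json"  # taken from json_deserializer
result = predictor.predict([1.0, 2.0, 3.0])
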
\n" + "Failed models: {}".format(", ".join(failed_models)) + ) def _get_endpoint_config_name(self): - endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint) - endpoint_config_name = endpoint_desc['EndpointConfigName'] + endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint( + EndpointName=self.endpoint + ) + endpoint_config_name = endpoint_desc["EndpointConfigName"] return endpoint_config_name def _get_model_names(self): endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config( - EndpointConfigName=self._endpoint_config_name) - production_variants = endpoint_config['ProductionVariants'] - return map(lambda d: d['ModelName'], production_variants) + EndpointConfigName=self._endpoint_config_name + ) + production_variants = endpoint_config["ProductionVariants"] + return map(lambda d: d["ModelName"], production_variants) class _CsvSerializer(object): @@ -172,7 +184,7 @@ def __call__(self, data): """ # For inputs which represent multiple "rows", the result should be newline-separated CSV rows if _is_mutable_sequence_like(data) and len(data) > 0 and _is_sequence_like(data[0]): - return '\n'.join([_CsvSerializer._serialize_row(row) for row in data]) + return "\n".join([_CsvSerializer._serialize_row(row) for row in data]) return _CsvSerializer._serialize_row(data) @staticmethod @@ -182,14 +194,14 @@ def _serialize_row(data): return data if isinstance(data, np.ndarray): data = np.ndarray.flatten(data) - if hasattr(data, '__len__'): + if hasattr(data, "__len__"): if len(data): return _csv_serialize_python_array(data) else: raise ValueError("Cannot serialize empty array") # files and buffers - if hasattr(data, 'read'): + if hasattr(data, "read"): return _csv_serialize_from_buffer(data) raise ValueError("Unable to handle input format: ", type(data)) @@ -206,31 +218,31 @@ def _csv_serialize_from_buffer(buff): def _csv_serialize_object(data): csv_buffer = StringIO() - csv_writer = csv.writer(csv_buffer, delimiter=',') + csv_writer = csv.writer(csv_buffer, delimiter=",") csv_writer.writerow(data) - return csv_buffer.getvalue().rstrip('\r\n') + return csv_buffer.getvalue().rstrip("\r\n") csv_serializer = _CsvSerializer() def _is_mutable_sequence_like(obj): - return _is_sequence_like(obj) and hasattr(obj, '__setitem__') + return _is_sequence_like(obj) and hasattr(obj, "__setitem__") def _is_sequence_like(obj): # Need to explicitly check on str since str lacks the iterable magic methods in Python 2 - return (hasattr(obj, '__iter__') and hasattr(obj, '__getitem__')) or isinstance(obj, str) + return (hasattr(obj, "__iter__") and hasattr(obj, "__getitem__")) or isinstance(obj, str) def _row_to_csv(obj): if isinstance(obj, str): return obj - return ','.join(obj) + return ",".join(obj) class _CsvDeserializer(object): - def __init__(self, encoding='utf-8'): + def __init__(self, encoding="utf-8"): self.accept = CONTENT_TYPE_CSV self.encoding = encoding @@ -269,7 +281,7 @@ class StringDeserializer(object): accept (str): The Accept header to send to the server (optional). 
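
# Minimal sketch of the serializers in this file: the CSV rules above turn a
# flat sequence into one row and a nested sequence into newline-separated
# rows, while the .npy pair further below round-trips numpy arrays. Runs
# locally; no endpoint involved.
from io import BytesIO

import numpy as np

from sagemaker.predictor import csv_serializer, npy_serializer, numpy_deserializer

print(csv_serializer([1, 2, 3]))         # -> "1,2,3"
print(csv_serializer([[1, 2], [3, 4]]))  # -> "1,2\n3,4"

arr = np.array([[1.0, 2.0], [3.0, 4.0]])
payload = npy_serializer(arr)  # bytes in .npy format
assert np.array_equal(arr, numpy_deserializer(BytesIO(payload), "application/x-npy"))
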
""" - def __init__(self, encoding='utf-8', accept=None): + def __init__(self, encoding="utf-8", accept=None): self.encoding = encoding self.accept = accept @@ -315,7 +327,7 @@ def __call__(self, data): return json.dumps({k: _ndarray_to_list(v) for k, v in six.iteritems(data)}) # files and buffers - if hasattr(data, 'read'): + if hasattr(data, "read"): return _json_serialize_from_buffer(data) return json.dumps(_ndarray_to_list(data)) @@ -347,7 +359,7 @@ def __call__(self, stream, content_type): object: Body of the response deserialized into a JSON object. """ try: - return json.load(codecs.getreader('utf-8')(stream)) + return json.load(codecs.getreader("utf-8")(stream)) finally: stream.close() @@ -372,9 +384,11 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY): """ try: if content_type == CONTENT_TYPE_CSV: - return np.genfromtxt(codecs.getreader('utf-8')(stream), delimiter=',', dtype=self.dtype) + return np.genfromtxt( + codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype + ) elif content_type == CONTENT_TYPE_JSON: - return np.array(json.load(codecs.getreader('utf-8')(stream)), dtype=self.dtype) + return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype) elif content_type == CONTENT_TYPE_NPY: return np.load(BytesIO(stream.read())) finally: @@ -408,7 +422,7 @@ def __call__(self, data, dtype=None): return _npy_serialize(np.array(data, dtype)) # files and buffers. Assumed to hold npy-formatted data. - if hasattr(data, 'read'): + if hasattr(data, "read"): return data.read() return _npy_serialize(np.array(data)) diff --git a/src/sagemaker/pytorch/defaults.py b/src/sagemaker/pytorch/defaults.py index 154cd35b65..21629fa66f 100644 --- a/src/sagemaker/pytorch/defaults.py +++ b/src/sagemaker/pytorch/defaults.py @@ -12,9 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -PYTORCH_VERSION = '0.4' +PYTORCH_VERSION = "0.4" """Default PyTorch version for when the framework version is not specified. This is no longer updated so as to not break existing workflows. 
""" -PYTHON_VERSION = 'py3' +PYTHON_VERSION = "py3" diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 446b33bfc1..f6a441d8e6 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -15,13 +15,17 @@ import logging from sagemaker.estimator import Framework -from sagemaker.fw_utils import framework_name_from_image, framework_version_from_tag, empty_framework_version_warning, \ - python_deprecation_warning +from sagemaker.fw_utils import ( + framework_name_from_image, + framework_version_from_tag, + empty_framework_version_warning, + python_deprecation_warning, +) from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class PyTorch(Framework): @@ -29,11 +33,19 @@ class PyTorch(Framework): __framework_name__ = "pytorch" - LATEST_VERSION = '1.0' + LATEST_VERSION = "1.0" """The latest version of PyTorch included in the SageMaker pre-built Docker images.""" - def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION, - framework_version=None, image_name=None, **kwargs): + def __init__( + self, + entry_point, + source_dir=None, + hyperparameters=None, + py_version=PYTHON_VERSION, + framework_version=None, + image_name=None, + **kwargs + ): """ This ``Estimator`` executes an PyTorch script in a managed PyTorch execution environment, within a SageMaker Training Job. The managed PyTorch environment is an Amazon-built Docker container that executes functions @@ -74,14 +86,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio logger.warning(empty_framework_version_warning(PYTORCH_VERSION, PYTORCH_VERSION)) self.framework_version = framework_version or PYTORCH_VERSION - super(PyTorch, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs) + super(PyTorch, self).__init__( + entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version - def create_model(self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): + def create_model( + self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT + ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. Args: @@ -99,12 +115,23 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override See :func:`~sagemaker.pytorch.model.PyTorchModel` for full details. 
""" role = role or self.role - return PyTorchModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, - container_log_level=self.container_log_level, code_location=self.code_location, - py_version=self.py_version, framework_version=self.framework_version, image=self.image_name, - model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override), dependencies=self.dependencies) + return PyTorchModel( + self.model_data, + role, + self.entry_point, + source_dir=self._model_source_dir(), + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + name=self._current_job_name, + container_log_level=self.container_log_level, + code_location=self.code_location, + py_version=self.py_version, + framework_version=self.framework_version, + image=self.image_name, + model_server_workers=model_server_workers, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + dependencies=self.dependencies, + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -118,22 +145,28 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details, model_channel_name) - image_name = init_params.pop('image') + init_params = super(PyTorch, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) + image_name = init_params.pop("image") framework, py_version, tag, _ = framework_name_from_image(image_name) if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. - init_params['image_name'] = image_name + init_params["image_name"] = image_name return init_params - init_params['py_version'] = py_version - init_params['framework_version'] = framework_version_from_tag(tag) + init_params["py_version"] = py_version + init_params["framework_version"] = framework_version_from_tag(tag) - training_job_name = init_params['base_job_name'] + training_job_name = init_params["base_job_name"] if framework != cls.__framework_name__: - raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name)) + raise ValueError( + "Training job: {} didn't use image for requested framework".format( + training_job_name + ) + ) return init_params diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index db05fff1ed..11d0caf04d 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -20,7 +20,7 @@ from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class PyTorchPredictor(RealTimePredictor): @@ -38,17 +38,28 @@ def __init__(self, endpoint_name, sagemaker_session=None): Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. 
""" - super(PyTorchPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer) + super(PyTorchPredictor, self).__init__( + endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer + ) class PyTorchModel(FrameworkModel): """An PyTorch SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - __framework_name__ = 'pytorch' - - def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_VERSION, - framework_version=PYTORCH_VERSION, predictor_cls=PyTorchPredictor, - model_server_workers=None, **kwargs): + __framework_name__ = "pytorch" + + def __init__( + self, + model_data, + role, + entry_point, + image=None, + py_version=PYTHON_VERSION, + framework_version=PYTORCH_VERSION, + predictor_cls=PyTorchPredictor, + model_server_workers=None, + **kwargs + ): """Initialize an PyTorchModel. Args: @@ -69,9 +80,11 @@ def __init__(self, model_data, role, entry_point, image=None, py_version=PYTHON_ If None, server will use one worker per vCPU. **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer. """ - super(PyTorchModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs) + super(PyTorchModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -92,8 +105,14 @@ def prepare_container_def(self, instance_type, accelerator_type=None): deploy_image = self.image if not deploy_image: region_name = self.sagemaker_session.boto_session.region_name - deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type, - self.framework_version, self.py_version, accelerator_type=accelerator_type) + deploy_image = create_image_uri( + region_name, + self.__framework_name__, + instance_type, + self.framework_version, + self.py_version, + accelerator_type=accelerator_type, + ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) deploy_env = dict(self.env) diff --git a/src/sagemaker/rl/__init__.py b/src/sagemaker/rl/__init__.py index cf7b222c93..5acabd42f4 100644 --- a/src/sagemaker/rl/__init__.py +++ b/src/sagemaker/rl/__init__.py @@ -12,5 +12,9 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from sagemaker.rl.estimator import (RLEstimator, RLFramework, RLToolkit, # noqa: F401 - TOOLKIT_FRAMEWORK_VERSION_MAP) +from sagemaker.rl.estimator import ( # noqa: F401 + RLEstimator, + RLFramework, + RLToolkit, + TOOLKIT_FRAMEWORK_VERSION_MAP, +) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index b929ea17b0..bcc54cd61e 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -22,68 +22,57 @@ from sagemaker.mxnet.model import MXNetModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") -SAGEMAKER_ESTIMATOR = 'sagemaker_estimator' -SAGEMAKER_ESTIMATOR_VALUE = 'RLEstimator' -PYTHON_VERSION = 'py3' +SAGEMAKER_ESTIMATOR = "sagemaker_estimator" +SAGEMAKER_ESTIMATOR_VALUE = "RLEstimator" +PYTHON_VERSION = "py3" TOOLKIT_FRAMEWORK_VERSION_MAP = { - 'coach': { - '0.10.1': { - 'tensorflow': '1.11' - }, - '0.10': { - 'tensorflow': '1.11' - }, - '0.11.0': { - 'tensorflow': '1.11', - 'mxnet': '1.3' - }, - '0.11.1': { - 'tensorflow': '1.12', - }, - '0.11': { - 'tensorflow': '1.12', - 'mxnet': '1.3' - } + "coach": { + "0.10.1": {"tensorflow": "1.11"}, + "0.10": {"tensorflow": "1.11"}, + "0.11.0": {"tensorflow": "1.11", "mxnet": "1.3"}, + "0.11.1": {"tensorflow": "1.12"}, + "0.11": {"tensorflow": "1.12", "mxnet": "1.3"}, + }, + "ray": { + "0.5.3": {"tensorflow": "1.11"}, + "0.5": {"tensorflow": "1.11"}, + "0.6.5": {"tensorflow": "1.12"}, + "0.6": {"tensorflow": "1.12"}, }, - 'ray': { - '0.5.3': { - 'tensorflow': '1.11' - }, - '0.5': { - 'tensorflow': '1.11' - }, - '0.6.5': { - 'tensorflow': '1.12' - }, - '0.6': { - 'tensorflow': '1.12' - }, - } } class RLToolkit(enum.Enum): - COACH = 'coach' - RAY = 'ray' + COACH = "coach" + RAY = "ray" class RLFramework(enum.Enum): - TENSORFLOW = 'tensorflow' - MXNET = 'mxnet' + TENSORFLOW = "tensorflow" + MXNET = "mxnet" class RLEstimator(Framework): """Handle end-to-end training and deployment of custom RLEstimator code.""" - COACH_LATEST_VERSION_TF = '0.11.1' - COACH_LATEST_VERSION_MXNET = '0.11.0' - RAY_LATEST_VERSION = '0.6.5' - - def __init__(self, entry_point, toolkit=None, toolkit_version=None, framework=None, - source_dir=None, hyperparameters=None, image_name=None, - metric_definitions=None, **kwargs): + COACH_LATEST_VERSION_TF = "0.11.1" + COACH_LATEST_VERSION_MXNET = "0.11.0" + RAY_LATEST_VERSION = "0.6.5" + + def __init__( + self, + entry_point, + toolkit=None, + toolkit_version=None, + framework=None, + source_dir=None, + hyperparameters=None, + image_name=None, + metric_definitions=None, + **kwargs + ): """This Estimator executes an RLEstimator script in a managed Reinforcement Learning (RL) execution environment within a SageMaker Training Job. 
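
# Hedged sketch tying together the PyTorch estimator and model hunks above
# (src/sagemaker/pytorch/): fit() runs a training job, and deploy() goes
# through create_model() to stand up an endpoint. The script name, role ARN,
# S3 path, and instance types are placeholders.
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    framework_version="1.0",
    train_instance_count=1,
    train_instance_type="ml.p2.xlarge",
    hyperparameters={"epochs": 10, "batch-size": 64},
)
estimator.fit({"training": "s3://my-bucket/pytorch/training"})
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
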
The managed RL environment is an Amazon-built Docker container that executes @@ -137,19 +126,31 @@ def __init__(self, entry_point, toolkit=None, toolkit_version=None, framework=No self.toolkit = toolkit.value self.toolkit_version = toolkit_version self.framework = framework.value - self.framework_version = \ - TOOLKIT_FRAMEWORK_VERSION_MAP[self.toolkit][self.toolkit_version][self.framework] + self.framework_version = TOOLKIT_FRAMEWORK_VERSION_MAP[self.toolkit][ + self.toolkit_version + ][self.framework] # set default metric_definitions based on the toolkit if not metric_definitions: metric_definitions = self.default_metric_definitions(toolkit) - super(RLEstimator, self).__init__(entry_point, source_dir, hyperparameters, - image_name=image_name, - metric_definitions=metric_definitions, **kwargs) - - def create_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, - entry_point=None, source_dir=None, dependencies=None): + super(RLEstimator, self).__init__( + entry_point, + source_dir, + hyperparameters, + image_name=image_name, + metric_definitions=metric_definitions, + **kwargs + ) + + def create_model( + self, + role=None, + vpc_config_override=VPC_CONFIG_DEFAULT, + entry_point=None, + source_dir=None, + dependencies=None, + ): """Create a SageMaker ``RLEstimatorModel`` object that can be deployed to an Endpoint. Args: @@ -185,26 +186,30 @@ def create_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, TensorFlow was used as RL backend. """ - base_args = dict(model_data=self.model_data, - role=role or self.role, - image=self.image_name, - name=self._current_job_name, - container_log_level=self.container_log_level, - sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + base_args = dict( + model_data=self.model_data, + role=role or self.role, + image=self.image_name, + name=self._current_job_name, + container_log_level=self.container_log_level, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) if not entry_point and (source_dir or dependencies): - raise AttributeError('Please provide an `entry_point`.') + raise AttributeError("Please provide an `entry_point`.") entry_point = entry_point or self.entry_point source_dir = source_dir or self._model_source_dir() dependencies = dependencies or self.dependencies - extended_args = dict(entry_point=entry_point, - source_dir=source_dir, - code_location=self.code_location, - dependencies=dependencies, - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics) + extended_args = dict( + entry_point=entry_point, + source_dir=source_dir, + code_location=self.code_location, + dependencies=dependencies, + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + ) extended_args.update(base_args) if self.image_name: @@ -212,17 +217,19 @@ def create_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, if self.toolkit == RLToolkit.RAY.value: raise NotImplementedError( - 'Automatic deployment of Ray models is not currently available.' - ' Train policy parameters are available in model checkpoints' - ' in the TrainingJob output.' + "Automatic deployment of Ray models is not currently available." + " Train policy parameters are available in model checkpoints" + " in the TrainingJob output." 
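
# Sketch of how the TOOLKIT_FRAMEWORK_VERSION_MAP above resolves the deep
# learning framework version for a given toolkit release.
from sagemaker.rl import TOOLKIT_FRAMEWORK_VERSION_MAP

print(TOOLKIT_FRAMEWORK_VERSION_MAP["coach"]["0.11.0"]["mxnet"])    # -> "1.3"
print(TOOLKIT_FRAMEWORK_VERSION_MAP["ray"]["0.6.5"]["tensorflow"])  # -> "1.12"
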
) if self.framework == RLFramework.TENSORFLOW.value: from sagemaker.tensorflow.serving import Model as tfsModel + return tfsModel(framework_version=self.framework_version, **base_args) elif self.framework == RLFramework.MXNET.value: - return MXNetModel(framework_version=self.framework_version, py_version=PYTHON_VERSION, - **extended_args) + return MXNetModel( + framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args + ) def train_image(self): """Return the Docker image to use for training. @@ -236,11 +243,13 @@ def train_image(self): if self.image_name: return self.image_name else: - return fw_utils.create_image_uri(self.sagemaker_session.boto_region_name, - self._image_framework(), - self.train_instance_type, - self._image_version(), - py_version=PYTHON_VERSION) + return fw_utils.create_image_uri( + self.sagemaker_session.boto_region_name, + self._image_framework(), + self.train_instance_type, + self._image_version(), + py_version=PYTHON_VERSION, + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -254,30 +263,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na Returns: dictionary: The transformed init_params """ - init_params = super(RLEstimator, cls)\ - ._prepare_init_params_from_job_description(job_details, model_channel_name) + init_params = super(RLEstimator, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) - image_name = init_params.pop('image') + image_name = init_params.pop("image") framework, _, tag, _ = fw_utils.framework_name_from_image(image_name) if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. 
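
# Hedged sketch of constructing the RLEstimator defined above. When toolkit,
# toolkit_version, and framework are given instead of image_name,
# train_image() above builds an image URI for the "rl-tensorflow" repository
# with a version tag like "coach0.11.1". Role and instance settings are
# placeholders.
from sagemaker.rl import RLEstimator, RLFramework, RLToolkit

estimator = RLEstimator(
    entry_point="train_cartpole.py",
    toolkit=RLToolkit.COACH,
    toolkit_version=RLEstimator.COACH_LATEST_VERSION_TF,
    framework=RLFramework.TENSORFLOW,
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m4.xlarge",
)
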
- init_params['image_name'] = image_name + init_params["image_name"] = image_name return init_params toolkit, toolkit_version = cls._toolkit_and_version_from_tag(tag) if not cls._is_combination_supported(toolkit, toolkit_version, framework): - training_job_name = init_params['base_job_name'] + training_job_name = init_params["base_job_name"] raise ValueError( "Training job: {} didn't use image for requested framework".format( - training_job_name) + training_job_name + ) ) - init_params['toolkit'] = RLToolkit(toolkit) - init_params['toolkit_version'] = toolkit_version - init_params['framework'] = RLFramework(framework) + init_params["toolkit"] = RLToolkit(toolkit) + init_params["toolkit_version"] = toolkit_version + init_params["framework"] = RLFramework(framework) return init_params @@ -285,16 +296,20 @@ def hyperparameters(self): """Return hyperparameters used by your custom TensorFlow code during model training.""" hyperparameters = super(RLEstimator, self).hyperparameters() - additional_hyperparameters = {SAGEMAKER_OUTPUT_LOCATION: self.output_path, - # TODO: can be applied to all other estimators - SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE} + additional_hyperparameters = { + SAGEMAKER_OUTPUT_LOCATION: self.output_path, + # TODO: can be applied to all other estimators + SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE, + } hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) return hyperparameters @classmethod def _toolkit_and_version_from_tag(cls, image_tag): - tag_pattern = re.compile('^([A-Z]*|[a-z]*)(\d.*)-(cpu|gpu)-(py2|py3)$') # noqa: W605,E501 pylint: disable=anomalous-backslash-in-string + tag_pattern = re.compile( + "^([A-Z]*|[a-z]*)(\d.*)-(cpu|gpu)-(py2|py3)$" # noqa: W605,E501 pylint: disable=anomalous-backslash-in-string + ) tag_match = tag_pattern.match(image_tag) if tag_match is not None: return tag_match.group(1), tag_match.group(2) @@ -304,16 +319,18 @@ def _toolkit_and_version_from_tag(cls, image_tag): def _validate_framework_format(cls, framework): if framework and framework not in RLFramework: raise ValueError( - 'Invalid type: {}, valid RL frameworks types are: [{}]'.format( - framework, [t for t in RLFramework]) + "Invalid type: {}, valid RL frameworks types are: [{}]".format( + framework, [t for t in RLFramework] + ) ) @classmethod def _validate_toolkit_format(cls, toolkit): if toolkit and toolkit not in RLToolkit: raise ValueError( - 'Invalid type: {}, valid RL toolkits types are: [{}]'.format( - toolkit, [t for t in RLToolkit]) + "Invalid type: {}, valid RL toolkits types are: [{}]".format( + toolkit, [t for t in RLToolkit] + ) ) @classmethod @@ -324,29 +341,31 @@ def _validate_images_args(cls, toolkit, toolkit_version, framework, image_name): if not image_name: not_found_args = [] if not toolkit: - not_found_args.append('toolkit') + not_found_args.append("toolkit") if not toolkit_version: - not_found_args.append('toolkit_version') + not_found_args.append("toolkit_version") if not framework: - not_found_args.append('framework') + not_found_args.append("framework") if not_found_args: raise AttributeError( - 'Please provide `{}` or `image_name` parameter.' 
- .format('`, `'.join(not_found_args)) + "Please provide `{}` or `image_name` parameter.".format( + "`, `".join(not_found_args) + ) ) else: found_args = [] if toolkit: - found_args.append('toolkit') + found_args.append("toolkit") if toolkit_version: - found_args.append('toolkit_version') + found_args.append("toolkit_version") if framework: - found_args.append('framework') + found_args.append("framework") if found_args: logger.warning( - 'Parameter `image_name` is specified, ' - '`{}` are going to be ignored when choosing the image.' - .format('`, `'.join(found_args)) + "Parameter `image_name` is specified, " + "`{}` are going to be ignored when choosing the image.".format( + "`, `".join(found_args) + ) ) @classmethod @@ -362,15 +381,16 @@ def _is_combination_supported(cls, toolkit, toolkit_version, framework): def _validate_toolkit_support(cls, toolkit, toolkit_version, framework): if not cls._is_combination_supported(toolkit, toolkit_version, framework): raise AttributeError( - 'Provided `{}-{}` and `{}` combination is not supported.' - .format(toolkit, toolkit_version, framework) + "Provided `{}-{}` and `{}` combination is not supported.".format( + toolkit, toolkit_version, framework + ) ) def _image_version(self): - return '{}{}'.format(self.toolkit, self.toolkit_version) + return "{}{}".format(self.toolkit, self.toolkit_version) def _image_framework(self): - return 'rl-{}'.format(self.framework) + return "rl-{}".format(self.framework) @classmethod def default_metric_definitions(cls, toolkit): @@ -384,16 +404,13 @@ def default_metric_definitions(cls, toolkit): """ if toolkit is RLToolkit.COACH: return [ - {'Name': 'reward-training', - 'Regex': '^Training>.*Total reward=(.*?),'}, - {'Name': 'reward-testing', - 'Regex': '^Testing>.*Total reward=(.*?),'} + {"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"}, + {"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"}, ] elif toolkit is RLToolkit.RAY: - float_regex = "[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501 pylint: disable=anomalous-backslash-in-string + float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" 
# noqa: W605, E501 + return [ - {'Name': 'episode_reward_mean', - 'Regex': 'episode_reward_mean: (%s)' % float_regex}, - {'Name': 'episode_reward_max', - 'Regex': 'episode_reward_max: (%s)' % float_regex} + {"Name": "episode_reward_mean", "Regex": "episode_reward_mean: (%s)" % float_regex}, + {"Name": "episode_reward_max", "Regex": "episode_reward_max: (%s)" % float_regex}, ] diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 75b7697668..90cffa3eca 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -29,17 +29,21 @@ import sagemaker.logs from sagemaker import vpc_utils from sagemaker.user_agent import prepend_user_agent -from sagemaker.utils import name_from_image, secondary_training_status_changed, secondary_training_status_message +from sagemaker.utils import ( + name_from_image, + secondary_training_status_changed, + secondary_training_status_message, +) -LOGGER = logging.getLogger('sagemaker') +LOGGER = logging.getLogger("sagemaker") _STATUS_CODE_TABLE = { - 'COMPLETED': 'Completed', - 'INPROGRESS': 'InProgress', - 'FAILED': 'Failed', - 'STOPPED': 'Stopped', - 'STOPPING': 'Stopping', - 'STARTING': 'Starting' + "COMPLETED": "Completed", + "INPROGRESS": "InProgress", + "FAILED": "Failed", + "STOPPED": "Stopped", + "STOPPING": "Stopping", + "STARTING": "Starting", } @@ -78,9 +82,9 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c """ self._default_bucket = None - sagemaker_config_file = os.path.join(os.path.expanduser('~'), '.sagemaker', 'config.yaml') + sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") if os.path.exists(sagemaker_config_file): - self.config = yaml.load(open(sagemaker_config_file, 'r')) + self.config = yaml.load(open(sagemaker_config_file, "r")) else: self.config = None @@ -96,16 +100,20 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self._region_name = self.boto_session.region_name if self._region_name is None: - raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.') + raise ValueError( + "Must setup local AWS configuration with a region supported by SageMaker." + ) - self.sagemaker_client = sagemaker_client or self.boto_session.client('sagemaker') + self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker") prepend_user_agent(self.sagemaker_client) if sagemaker_runtime_client is not None: self.sagemaker_runtime_client = sagemaker_runtime_client else: config = botocore.config.Config(read_timeout=80) - self.sagemaker_runtime_client = self.boto_session.client('runtime.sagemaker', config=config) + self.sagemaker_runtime_client = self.boto_session.client( + "runtime.sagemaker", config=config + ) prepend_user_agent(self.sagemaker_runtime_client) @@ -115,7 +123,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): def boto_region_name(self): return self._region_name - def upload_data(self, path, bucket=None, key_prefix='data', extra_args=None): + def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): """Upload local file or directory to S3. 
If a single file is specified for upload, the resulting S3 object key is ``{key_prefix}/{filename}`` @@ -149,27 +157,29 @@ def upload_data(self, path, bucket=None, key_prefix='data', extra_args=None): for dirpath, _, filenames in os.walk(path): for name in filenames: local_path = os.path.join(dirpath, name) - s3_relative_prefix = '' if path == dirpath else os.path.relpath(dirpath, start=path) + '/' - s3_key = '{}/{}{}'.format(key_prefix, s3_relative_prefix, name) + s3_relative_prefix = ( + "" if path == dirpath else os.path.relpath(dirpath, start=path) + "/" + ) + s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name) files.append((local_path, s3_key)) else: _, name = os.path.split(path) - s3_key = '{}/{}'.format(key_prefix, name) + s3_key = "{}/{}".format(key_prefix, name) files.append((path, s3_key)) key_suffix = name bucket = bucket or self.default_bucket() - s3 = self.boto_session.resource('s3') + s3 = self.boto_session.resource("s3") for local_path, s3_key in files: s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args) - s3_uri = 's3://{}/{}'.format(bucket, key_prefix) + s3_uri = "s3://{}/{}".format(bucket, key_prefix) # If a specific file was used as input (instead of a directory), we return the full S3 key # of the uploaded object. This prevents unintentionally using other files under the same prefix # during training. if key_suffix: - s3_uri = '{}/{}'.format(s3_uri, key_suffix) + s3_uri = "{}/{}".format(s3_uri, key_suffix) return s3_uri def default_bucket(self): @@ -181,30 +191,34 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket - account = self.boto_session.client('sts').get_caller_identity()['Account'] + account = self.boto_session.client("sts").get_caller_identity()["Account"] region = self.boto_session.region_name - default_bucket = 'sagemaker-{}-{}'.format(region, account) + default_bucket = "sagemaker-{}-{}".format(region, account) - s3 = self.boto_session.resource('s3') + s3 = self.boto_session.resource("s3") try: # 'us-east-1' cannot be specified because it is the default region: # https://github.com/boto/boto3/issues/125 - if region == 'us-east-1': + if region == "us-east-1": s3.create_bucket(Bucket=default_bucket) else: - s3.create_bucket(Bucket=default_bucket, CreateBucketConfiguration={'LocationConstraint': region}) + s3.create_bucket( + Bucket=default_bucket, CreateBucketConfiguration={"LocationConstraint": region} + ) - LOGGER.info('Created S3 bucket: {}'.format(default_bucket)) + LOGGER.info("Created S3 bucket: {}".format(default_bucket)) except ClientError as e: - error_code = e.response['Error']['Code'] - message = e.response['Error']['Message'] + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] - if error_code == 'BucketAlreadyOwnedByYou': + if error_code == "BucketAlreadyOwnedByYou": pass - elif error_code == 'OperationAborted' and 'conflicting conditional operation' in message: + elif ( + error_code == "OperationAborted" and "conflicting conditional operation" in message + ): # If this bucket is already being concurrently created, we don't need to create it again. 
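
# Hedged sketch of the two Session helpers here: default_bucket() resolves
# (and lazily creates) the "sagemaker-{region}-{account}" bucket, and
# upload_data() pushes a local path into S3 and returns the resulting URI.
# Requires AWS credentials; printed values are illustrative.
import sagemaker

session = sagemaker.Session()
print(session.default_bucket())  # e.g. sagemaker-us-west-2-123456789012
s3_uri = session.upload_data(path="./data", key_prefix="my-training-data")
print(s3_uri)  # e.g. s3://sagemaker-us-west-2-123456789012/my-training-data
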
pass - elif error_code == 'TooManyBuckets': + elif error_code == "TooManyBuckets": # Succeed if the default bucket exists s3.meta.client.head_bucket(Bucket=default_bucket) else: @@ -214,10 +228,24 @@ def default_bucket(self): return self._default_bucket - def train(self, input_mode, input_config, role, job_name, output_config, # noqa: C901 - resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions, - enable_network_isolation=False, image=None, algorithm_arn=None, - encrypt_inter_container_traffic=False): + def train( # noqa: C901 + self, + input_mode, + input_config, + role, + job_name, + output_config, + resource_config, + vpc_config, + hyperparameters, + stop_condition, + tags, + metric_definitions, + enable_network_isolation=False, + image=None, + algorithm_arn=None, + encrypt_inter_container_traffic=False, + ): """Create an Amazon SageMaker training job. Args: @@ -271,57 +299,57 @@ def train(self, input_mode, input_config, role, job_name, output_config, # noqa """ train_request = { - 'AlgorithmSpecification': { - 'TrainingInputMode': input_mode - }, - 'OutputDataConfig': output_config, - 'TrainingJobName': job_name, - 'StoppingCondition': stop_condition, - 'ResourceConfig': resource_config, - 'RoleArn': role, + "AlgorithmSpecification": {"TrainingInputMode": input_mode}, + "OutputDataConfig": output_config, + "TrainingJobName": job_name, + "StoppingCondition": stop_condition, + "ResourceConfig": resource_config, + "RoleArn": role, } if image and algorithm_arn: - raise ValueError('image and algorithm_arn are mutually exclusive.' - 'Both were provided: image: %s algorithm_arn: %s' % (image, algorithm_arn)) + raise ValueError( + "image and algorithm_arn are mutually exclusive." + "Both were provided: image: %s algorithm_arn: %s" % (image, algorithm_arn) + ) if image is None and algorithm_arn is None: - raise ValueError('either image or algorithm_arn is required. None was provided.') + raise ValueError("either image or algorithm_arn is required. 
None was provided.") if image is not None: - train_request['AlgorithmSpecification']['TrainingImage'] = image + train_request["AlgorithmSpecification"]["TrainingImage"] = image if algorithm_arn is not None: - train_request['AlgorithmSpecification']['AlgorithmName'] = algorithm_arn + train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn if input_config is not None: - train_request['InputDataConfig'] = input_config + train_request["InputDataConfig"] = input_config if metric_definitions is not None: - train_request['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions + train_request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions if hyperparameters and len(hyperparameters) > 0: - train_request['HyperParameters'] = hyperparameters + train_request["HyperParameters"] = hyperparameters if tags is not None: - train_request['Tags'] = tags + train_request["Tags"] = tags if vpc_config is not None: - train_request['VpcConfig'] = vpc_config + train_request["VpcConfig"] = vpc_config if enable_network_isolation: - train_request['EnableNetworkIsolation'] = enable_network_isolation + train_request["EnableNetworkIsolation"] = enable_network_isolation if encrypt_inter_container_traffic: - train_request['EnableInterContainerTrafficEncryption'] = \ - encrypt_inter_container_traffic + train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic - LOGGER.info('Creating training-job with name: {}'.format(job_name)) - LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4))) + LOGGER.info("Creating training-job with name: {}".format(job_name)) + LOGGER.debug("train request: {}".format(json.dumps(train_request, indent=4))) self.sagemaker_client.create_training_job(**train_request) - def compile_model(self, input_model_config, output_model_config, role, - job_name, stop_condition, tags): + def compile_model( + self, input_model_config, output_model_config, role, job_name, stop_condition, tags + ): """Create an Amazon SageMaker Neo compilation job. 
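
# Hedged sketch of the request fragments train() above assembles; the dicts
# mirror the CreateTrainingJob API shapes, and every value is a placeholder.
input_config = [
    {
        "ChannelName": "training",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://my-bucket/training",
                "S3DataDistributionType": "FullyReplicated",
            }
        },
    }
]
resource_config = {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 30}
stop_condition = {"MaxRuntimeInSeconds": 24 * 60 * 60}
output_config = {"S3OutputPath": "s3://my-bucket/output"}
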
Args: @@ -341,25 +369,45 @@ def compile_model(self, input_model_config, output_model_config, role, """ compilation_job_request = { - 'InputConfig': input_model_config, - 'OutputConfig': output_model_config, - 'RoleArn': role, - 'StoppingCondition': stop_condition, - 'CompilationJobName': job_name + "InputConfig": input_model_config, + "OutputConfig": output_model_config, + "RoleArn": role, + "StoppingCondition": stop_condition, + "CompilationJobName": job_name, } if tags is not None: - compilation_job_request['Tags'] = tags + compilation_job_request["Tags"] = tags - LOGGER.info('Creating compilation-job with name: {}'.format(job_name)) + LOGGER.info("Creating compilation-job with name: {}".format(job_name)) self.sagemaker_client.create_compilation_job(**compilation_job_request) - def tune(self, job_name, strategy, objective_type, objective_metric_name, - max_jobs, max_parallel_jobs, parameter_ranges, - static_hyperparameters, input_mode, metric_definitions, - role, input_config, output_config, resource_config, stop_condition, tags, - warm_start_config, enable_network_isolation=False, image=None, algorithm_arn=None, - early_stopping_type='Off', encrypt_inter_container_traffic=False, vpc_config=None): + def tune( + self, + job_name, + strategy, + objective_type, + objective_metric_name, + max_jobs, + max_parallel_jobs, + parameter_ranges, + static_hyperparameters, + input_mode, + metric_definitions, + role, + input_config, + output_config, + resource_config, + stop_condition, + tags, + warm_start_config, + enable_network_isolation=False, + image=None, + algorithm_arn=None, + early_stopping_type="Off", + encrypt_inter_container_traffic=False, + vpc_config=None, + ): """Create an Amazon SageMaker hyperparameter tuning job Args: @@ -419,62 +467,62 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name, """ tune_request = { - 'HyperParameterTuningJobName': job_name, - 'HyperParameterTuningJobConfig': { - 'Strategy': strategy, - 'HyperParameterTuningJobObjective': { - 'Type': objective_type, - 'MetricName': objective_metric_name, + "HyperParameterTuningJobName": job_name, + "HyperParameterTuningJobConfig": { + "Strategy": strategy, + "HyperParameterTuningJobObjective": { + "Type": objective_type, + "MetricName": objective_metric_name, }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': max_jobs, - 'MaxParallelTrainingJobs': max_parallel_jobs, + "ResourceLimits": { + "MaxNumberOfTrainingJobs": max_jobs, + "MaxParallelTrainingJobs": max_parallel_jobs, }, - 'ParameterRanges': parameter_ranges, - 'TrainingJobEarlyStoppingType': early_stopping_type, + "ParameterRanges": parameter_ranges, + "TrainingJobEarlyStoppingType": early_stopping_type, + }, + "TrainingJobDefinition": { + "StaticHyperParameters": static_hyperparameters, + "RoleArn": role, + "OutputDataConfig": output_config, + "ResourceConfig": resource_config, + "StoppingCondition": stop_condition, }, - 'TrainingJobDefinition': { - 'StaticHyperParameters': static_hyperparameters, - 'RoleArn': role, - 'OutputDataConfig': output_config, - 'ResourceConfig': resource_config, - 'StoppingCondition': stop_condition, - } } - algorithm_spec = { - 'TrainingInputMode': input_mode - } + algorithm_spec = {"TrainingInputMode": input_mode} if algorithm_arn: - algorithm_spec['AlgorithmName'] = algorithm_arn + algorithm_spec["AlgorithmName"] = algorithm_arn else: - algorithm_spec['TrainingImage'] = image + algorithm_spec["TrainingImage"] = image - tune_request['TrainingJobDefinition']['AlgorithmSpecification'] = algorithm_spec + 
tune_request["TrainingJobDefinition"]["AlgorithmSpecification"] = algorithm_spec if input_config is not None: - tune_request['TrainingJobDefinition']['InputDataConfig'] = input_config + tune_request["TrainingJobDefinition"]["InputDataConfig"] = input_config if warm_start_config: - tune_request['WarmStartConfig'] = warm_start_config + tune_request["WarmStartConfig"] = warm_start_config if metric_definitions is not None: - tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions + tune_request["TrainingJobDefinition"]["AlgorithmSpecification"][ + "MetricDefinitions" + ] = metric_definitions if tags is not None: - tune_request['Tags'] = tags + tune_request["Tags"] = tags if vpc_config is not None: - tune_request['TrainingJobDefinition']['VpcConfig'] = vpc_config + tune_request["TrainingJobDefinition"]["VpcConfig"] = vpc_config if enable_network_isolation: - tune_request['TrainingJobDefinition']['EnableNetworkIsolation'] = True + tune_request["TrainingJobDefinition"]["EnableNetworkIsolation"] = True if encrypt_inter_container_traffic: - tune_request['TrainingJobDefinition']['EnableInterContainerTrafficEncryption'] = True + tune_request["TrainingJobDefinition"]["EnableInterContainerTrafficEncryption"] = True - LOGGER.info('Creating hyperparameter tuning job with name: {}'.format(job_name)) - LOGGER.debug('tune request: {}'.format(json.dumps(tune_request, indent=4))) + LOGGER.info("Creating hyperparameter tuning job with name: {}".format(job_name)) + LOGGER.debug("tune request: {}".format(json.dumps(tune_request, indent=4))) self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request) def stop_tuning_job(self, name): @@ -487,19 +535,35 @@ def stop_tuning_job(self, name): ClientError: If an error occurs while trying to stop the hyperparameter tuning job. """ try: - LOGGER.info('Stopping tuning job: {}'.format(name)) + LOGGER.info("Stopping tuning job: {}".format(name)) self.sagemaker_client.stop_hyper_parameter_tuning_job(HyperParameterTuningJobName=name) except ClientError as e: - error_code = e.response['Error']['Code'] + error_code = e.response["Error"]["Code"] # allow to pass if the job already stopped - if error_code == 'ValidationException': - LOGGER.info('Tuning job: {} is already stopped or not running.'.format(name)) + if error_code == "ValidationException": + LOGGER.info("Tuning job: {} is already stopped or not running.".format(name)) else: - LOGGER.error('Error occurred while attempting to stop tuning job: {}. Please try again.'.format(name)) + LOGGER.error( + "Error occurred while attempting to stop tuning job: {}. Please try again.".format( + name + ) + ) raise - def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env, - input_config, output_config, resource_config, tags, data_processing): + def transform( + self, + job_name, + model_name, + strategy, + max_concurrent_transforms, + max_payload, + env, + input_config, + output_config, + resource_config, + tags, + data_processing, + ): """Create an Amazon SageMaker transform job. Args: @@ -519,38 +583,45 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. 
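
# Hedged sketch of the ParameterRanges structure tune() above consumes, built
# with the classes from src/sagemaker/parameter.py reformatted earlier in
# this diff. Hyperparameter names and bounds are placeholders.
from sagemaker.parameter import CategoricalParameter, ContinuousParameter, IntegerParameter

parameter_ranges = {
    "ContinuousParameterRanges": [
        ContinuousParameter(0.01, 0.2).as_tuning_range("learning_rate")
    ],
    "IntegerParameterRanges": [IntegerParameter(10, 50).as_tuning_range("epochs")],
    "CategoricalParameterRanges": [
        CategoricalParameter(["sgd", "adam"]).as_tuning_range("optimizer")
    ],
}
# ContinuousParameter(0.01, 0.2).as_tuning_range("learning_rate") returns
# {'Name': 'learning_rate', 'MinValue': '0.01', 'MaxValue': '0.2', 'ScalingType': 'Auto'}
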
""" transform_request = { - 'TransformJobName': job_name, - 'ModelName': model_name, - 'TransformInput': input_config, - 'TransformOutput': output_config, - 'TransformResources': resource_config, + "TransformJobName": job_name, + "ModelName": model_name, + "TransformInput": input_config, + "TransformOutput": output_config, + "TransformResources": resource_config, } if strategy is not None: - transform_request['BatchStrategy'] = strategy + transform_request["BatchStrategy"] = strategy if max_concurrent_transforms is not None: - transform_request['MaxConcurrentTransforms'] = max_concurrent_transforms + transform_request["MaxConcurrentTransforms"] = max_concurrent_transforms if max_payload is not None: - transform_request['MaxPayloadInMB'] = max_payload + transform_request["MaxPayloadInMB"] = max_payload if env is not None: - transform_request['Environment'] = env + transform_request["Environment"] = env if tags is not None: - transform_request['Tags'] = tags + transform_request["Tags"] = tags if data_processing is not None: - transform_request['DataProcessing'] = data_processing + transform_request["DataProcessing"] = data_processing - LOGGER.info('Creating transform job with name: {}'.format(job_name)) - LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4))) + LOGGER.info("Creating transform job with name: {}".format(job_name)) + LOGGER.debug("Transform request: {}".format(json.dumps(transform_request, indent=4))) self.sagemaker_client.create_transform_job(**transform_request) - def create_model(self, name, role, container_defs, vpc_config=None, - enable_network_isolation=False, primary_container=None, - tags=None): + def create_model( + self, + name, + role, + container_defs, + vpc_config=None, + enable_network_isolation=False, + primary_container=None, + tags=None, + ): """Create an Amazon SageMaker ``Model``. Specify the S3 location of the model artifacts and Docker image containing the inference code. Amazon SageMaker uses this information to deploy the @@ -586,10 +657,10 @@ def create_model(self, name, role, container_defs, vpc_config=None, str: Name of the Amazon SageMaker ``Model`` created. """ if container_defs and primary_container: - raise ValueError('Both container_defs and primary_container can not be passed as input') + raise ValueError("Both container_defs and primary_container can not be passed as input") if primary_container: - msg = 'primary_container is going to be deprecated in a future release. Please use container_defs instead.' + msg = "primary_container is going to be deprecated in a future release. Please use container_defs instead." 
warnings.warn(msg, DeprecationWarning) container_defs = primary_container @@ -600,36 +671,46 @@ def create_model(self, name, role, container_defs, vpc_config=None, else: container_definition = _expand_container_def(container_defs) - create_model_request = _create_model_request(name=name, - role=role, - container_def=container_definition, - tags=tags) + create_model_request = _create_model_request( + name=name, role=role, container_def=container_definition, tags=tags + ) if vpc_config: - create_model_request['VpcConfig'] = vpc_config + create_model_request["VpcConfig"] = vpc_config if enable_network_isolation: - create_model_request['EnableNetworkIsolation'] = True + create_model_request["EnableNetworkIsolation"] = True - LOGGER.info('Creating model with name: {}'.format(name)) - LOGGER.debug('CreateModel request: {}'.format(json.dumps(create_model_request, indent=4))) + LOGGER.info("Creating model with name: {}".format(name)) + LOGGER.debug("CreateModel request: {}".format(json.dumps(create_model_request, indent=4))) try: self.sagemaker_client.create_model(**create_model_request) except ClientError as e: - error_code = e.response['Error']['Code'] - message = e.response['Error']['Message'] + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] - if error_code == 'ValidationException' and 'Cannot create already existing model' in message: - LOGGER.warning('Using already existing model: {}'.format(name)) + if ( + error_code == "ValidationException" + and "Cannot create already existing model" in message + ): + LOGGER.warning("Using already existing model: {}".format(name)) else: raise return name - def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None, - model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, - tags=None): + def create_model_from_job( + self, + training_job_name, + name=None, + role=None, + primary_container_image=None, + model_data_url=None, + env=None, + vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + tags=None, + ): """Create an Amazon SageMaker ``Model`` from a SageMaker Training Job. Args: @@ -653,14 +734,17 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary Returns: str: The name of the created ``Model``. 
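
# Sketch of the two model-creation paths above: create_model() takes an
# explicit container definition, while create_model_from_job() derives one
# from a finished training job. Names, URIs, and the role ARN are
# placeholders.
import sagemaker

session = sagemaker.Session()
container = sagemaker.container_def(
    "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
    model_data_url="s3://my-bucket/model.tar.gz",
)
session.create_model("my-model", "arn:aws:iam::123456789012:role/SageMakerRole", container)
session.create_model_from_job("my-training-job-2019-04-01-12-00-00-000")
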
""" - training_job = self.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) + training_job = self.sagemaker_client.describe_training_job( + TrainingJobName=training_job_name + ) name = name or training_job_name - role = role or training_job['RoleArn'] + role = role or training_job["RoleArn"] env = env or {} primary_container = container_def( - primary_container_image or training_job['AlgorithmSpecification']['TrainingImage'], - model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'], - env=env) + primary_container_image or training_job["AlgorithmSpecification"]["TrainingImage"], + model_data_url=model_data_url or training_job["ModelArtifacts"]["S3ModelArtifacts"], + env=env, + ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) return self.create_model(name, role, primary_container, vpc_config=vpc_config, tags=tags) @@ -674,29 +758,21 @@ def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data (str): s3 URI to the model artifacts produced by training """ request = { - 'ModelPackageName': name, - 'ModelPackageDescription': description, - 'SourceAlgorithmSpecification': { - 'SourceAlgorithms': [ - { - 'AlgorithmName': algorithm_arn, - 'ModelDataUrl': model_data - } - ] - } + "ModelPackageName": name, + "ModelPackageDescription": description, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [{"AlgorithmName": algorithm_arn, "ModelDataUrl": model_data}] + }, } try: - LOGGER.info('Creating model package with name: {}'.format(name)) + LOGGER.info("Creating model package with name: {}".format(name)) self.sagemaker_client.create_model_package(**request) except ClientError as e: - error_code = e.response['Error']['Code'] - message = e.response['Error']['Message'] + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] - if ( - error_code == 'ValidationException' - and 'ModelPackage already exists' in message - ): - LOGGER.warning('Using already existing model package: {}'.format(name)) + if error_code == "ValidationException" and "ModelPackage already exists" in message: + LOGGER.warning("Using already existing model package: {}".format(name)) else: raise @@ -710,18 +786,30 @@ def wait_for_model_package(self, model_package_name, poll=5): Returns: dict: Return value from the ``DescribeEndpoint`` API. """ - desc = _wait_until(lambda: _create_model_package_status(self.sagemaker_client, model_package_name), - poll) - status = desc['ModelPackageStatus'] - - if status != 'Completed': - reason = desc.get('FailureReason', None) - raise ValueError('Error creating model package {}: {} Reason: {}'.format( - model_package_name, status, reason)) + desc = _wait_until( + lambda: _create_model_package_status(self.sagemaker_client, model_package_name), poll + ) + status = desc["ModelPackageStatus"] + + if status != "Completed": + reason = desc.get("FailureReason", None) + raise ValueError( + "Error creating model package {}: {} Reason: {}".format( + model_package_name, status, reason + ) + ) return desc - def create_endpoint_config(self, name, model_name, initial_instance_count, instance_type, - accelerator_type=None, tags=None, kms_key=None): + def create_endpoint_config( + self, + name, + model_name, + initial_instance_count, + instance_type, + accelerator_type=None, + tags=None, + kms_key=None, + ): """Create an Amazon SageMaker endpoint configuration. 
        The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -745,23 +833,27 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
         Returns:
             str: Name of the endpoint configuration created.
         """
-        LOGGER.info('Creating endpoint-config with name {}'.format(name))
+        LOGGER.info("Creating endpoint-config with name {}".format(name))
 
         tags = tags or []
 
         request = {
-            'EndpointConfigName': name,
-            'ProductionVariants': [
-                production_variant(model_name, instance_type, initial_instance_count,
-                                   accelerator_type=accelerator_type)
+            "EndpointConfigName": name,
+            "ProductionVariants": [
+                production_variant(
+                    model_name,
+                    instance_type,
+                    initial_instance_count,
+                    accelerator_type=accelerator_type,
+                )
             ],
         }
 
         if tags is not None:
-            request['Tags'] = tags
+            request["Tags"] = tags
 
         if kms_key is not None:
-            request['KmsKeyId'] = kms_key
+            request["KmsKeyId"] = kms_key
 
         self.sagemaker_client.create_endpoint_config(**request)
         return name
 
@@ -780,11 +872,13 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
         Returns:
             str: Name of the Amazon SageMaker ``Endpoint`` created.
         """
-        LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
+        LOGGER.info("Creating endpoint with name {}".format(endpoint_name))
 
         tags = tags or []
 
-        self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
+        self.sagemaker_client.create_endpoint(
+            EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
+        )
         if wait:
             self.wait_for_endpoint(endpoint_name)
         return endpoint_name
 
@@ -801,12 +895,18 @@ def update_endpoint(self, endpoint_name, endpoint_config_name):
         Returns:
             str: Name of the Amazon SageMaker ``Endpoint`` being updated.
         """
-        if not _deployment_entity_exists(lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)):
-            raise ValueError('Endpoint with name "{}" does not exist; please use an existing endpoint name'
-                             .format(endpoint_name))
-
-        self.sagemaker_client.update_endpoint(EndpointName=endpoint_name,
-                                              EndpointConfigName=endpoint_config_name)
+        if not _deployment_entity_exists(
+            lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
+        ):
+            raise ValueError(
+                'Endpoint with name "{}" does not exist; please use an existing endpoint name'.format(
+                    endpoint_name
+                )
+            )
+
+        self.sagemaker_client.update_endpoint(
+            EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
+        )
         return endpoint_name
 
     def delete_endpoint(self, endpoint_name):
@@ -815,7 +915,7 @@ def delete_endpoint(self, endpoint_name):
         Args:
             endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to delete.
         """
-        LOGGER.info('Deleting endpoint with name: {}'.format(endpoint_name))
+        LOGGER.info("Deleting endpoint with name: {}".format(endpoint_name))
         self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
 
     def delete_endpoint_config(self, endpoint_config_name):
@@ -824,7 +924,7 @@ def delete_endpoint_config(self, endpoint_config_name):
         Args:
            endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
""" - LOGGER.info('Deleting endpoint configuration with name: {}'.format(endpoint_config_name)) + LOGGER.info("Deleting endpoint configuration with name: {}".format(endpoint_config_name)) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) def delete_model(self, model_name): @@ -834,7 +934,7 @@ def delete_model(self, model_name): model_name (str): Name of the Amazon SageMaker model to delete. """ - LOGGER.info('Deleting model with name: {}'.format(model_name)) + LOGGER.info("Deleting model with name: {}".format(model_name)) self.sagemaker_client.delete_model(ModelName=model_name) def wait_for_job(self, job, poll=5): @@ -850,9 +950,10 @@ def wait_for_job(self, job, poll=5): Raises: ValueError: If the training job fails. """ - desc = _wait_until_training_done(lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), - None, poll) - self._check_job_status(job, desc, 'TrainingJobStatus') + desc = _wait_until_training_done( + lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll + ) + self._check_job_status(job, desc, "TrainingJobStatus") return desc def wait_for_compilation_job(self, job, poll=5): @@ -869,7 +970,7 @@ def wait_for_compilation_job(self, job, poll=5): ValueError: If the compilation job fails. """ desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, 'CompilationJobStatus') + self._check_job_status(job, desc, "CompilationJobStatus") return desc def wait_for_tuning_job(self, job, poll=5): @@ -886,7 +987,7 @@ def wait_for_tuning_job(self, job, poll=5): ValueError: If the hyperparameter tuning job fails. """ desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, 'HyperParameterTuningJobStatus') + self._check_job_status(job, desc, "HyperParameterTuningJobStatus") return desc def wait_for_transform_job(self, job, poll=5): @@ -903,7 +1004,7 @@ def wait_for_transform_job(self, job, poll=5): ValueError: If the transform job fails. """ desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll) - self._check_job_status(job, desc, 'TransformJobStatus') + self._check_job_status(job, desc, "TransformJobStatus") return desc def _check_job_status(self, job, desc, status_key_name): @@ -922,10 +1023,10 @@ def _check_job_status(self, job, desc, status_key_name): # If the status is capital case, then convert it to Camel case status = _STATUS_CODE_TABLE.get(status, status) - if status != 'Completed' and status != 'Stopped': - reason = desc.get('FailureReason', '(No reason provided)') - job_type = status_key_name.replace('JobStatus', ' job') - raise ValueError('Error for {} {}: {} Reason: {}'.format(job_type, job, status, reason)) + if status != "Completed" and status != "Stopped": + reason = desc.get("FailureReason", "(No reason provided)") + job_type = status_key_name.replace("JobStatus", " job") + raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason)) def wait_for_endpoint(self, endpoint, poll=5): """Wait for an Amazon SageMaker endpoint deployment to complete. @@ -938,17 +1039,28 @@ def wait_for_endpoint(self, endpoint, poll=5): dict: Return value from the ``DescribeEndpoint`` API. 
""" desc = _wait_until(lambda: _deploy_done(self.sagemaker_client, endpoint), poll) - status = desc['EndpointStatus'] + status = desc["EndpointStatus"] - if status != 'InService': - reason = desc.get('FailureReason', None) - raise ValueError('Error hosting endpoint {}: {} Reason: {}'.format(endpoint, status, reason)) + if status != "InService": + reason = desc.get("FailureReason", None) + raise ValueError( + "Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason) + ) return desc - def endpoint_from_job(self, job_name, initial_instance_count, instance_type, - deployment_image=None, name=None, role=None, wait=True, - model_environment_vars=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, - accelerator_type=None): + def endpoint_from_job( + self, + job_name, + initial_instance_count, + instance_type, + deployment_image=None, + name=None, + role=None, + wait=True, + model_environment_vars=None, + vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, + accelerator_type=None, + ): """Create an ``Endpoint`` using the results of a successful training job. Specify the job name, Docker image containing the inference code, and hardware configuration to deploy @@ -983,21 +1095,38 @@ def endpoint_from_job(self, job_name, initial_instance_count, instance_type, str: Name of the ``Endpoint`` that is created. """ job_desc = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) - output_url = job_desc['ModelArtifacts']['S3ModelArtifacts'] - deployment_image = deployment_image or job_desc['AlgorithmSpecification']['TrainingImage'] - role = role or job_desc['RoleArn'] + output_url = job_desc["ModelArtifacts"]["S3ModelArtifacts"] + deployment_image = deployment_image or job_desc["AlgorithmSpecification"]["TrainingImage"] + role = role or job_desc["RoleArn"] name = name or job_name vpc_config_override = _vpc_config_from_training_job(job_desc, vpc_config_override) - return self.endpoint_from_model_data(model_s3_location=output_url, deployment_image=deployment_image, - initial_instance_count=initial_instance_count, instance_type=instance_type, - name=name, role=role, wait=wait, - model_environment_vars=model_environment_vars, - model_vpc_config=vpc_config_override, accelerator_type=accelerator_type) - - def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_instance_count, instance_type, - name=None, role=None, wait=True, model_environment_vars=None, model_vpc_config=None, - accelerator_type=None): + return self.endpoint_from_model_data( + model_s3_location=output_url, + deployment_image=deployment_image, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + name=name, + role=role, + wait=wait, + model_environment_vars=model_environment_vars, + model_vpc_config=vpc_config_override, + accelerator_type=accelerator_type, + ) + + def endpoint_from_model_data( + self, + model_s3_location, + deployment_image, + initial_instance_count, + instance_type, + name=None, + role=None, + wait=True, + model_environment_vars=None, + model_vpc_config=None, + accelerator_type=None, + ): """Create and deploy to an ``Endpoint`` using existing model data stored in S3. 
        Args:
@@ -1029,30 +1158,40 @@ def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_
         name = name or name_from_image(deployment_image)
         model_vpc_config = vpc_utils.sanitize(model_vpc_config)
 
-        if _deployment_entity_exists(lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)):
-            raise ValueError('Endpoint with name "{}" already exists; please pick a different name.'.format(name))
+        if _deployment_entity_exists(
+            lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)
+        ):
+            raise ValueError(
+                'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
+            )
 
-        if not _deployment_entity_exists(lambda: self.sagemaker_client.describe_model(ModelName=name)):
-            primary_container = container_def(image=deployment_image,
-                                              model_data_url=model_s3_location,
-                                              env=model_environment_vars)
-            self.create_model(name=name,
-                              role=role,
-                              container_defs=primary_container,
-                              vpc_config=model_vpc_config)
+        if not _deployment_entity_exists(
+            lambda: self.sagemaker_client.describe_model(ModelName=name)
+        ):
+            primary_container = container_def(
+                image=deployment_image, model_data_url=model_s3_location, env=model_environment_vars
+            )
+            self.create_model(
+                name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
+            )
 
         if not _deployment_entity_exists(
-                lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)):
-            self.create_endpoint_config(name=name,
-                                        model_name=name,
-                                        initial_instance_count=initial_instance_count,
-                                        instance_type=instance_type,
-                                        accelerator_type=accelerator_type)
+            lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
+        ):
+            self.create_endpoint_config(
+                name=name,
+                model_name=name,
+                initial_instance_count=initial_instance_count,
+                instance_type=instance_type,
+                accelerator_type=accelerator_type,
+            )
 
         self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
         return name
 
-    def endpoint_from_production_variants(self, name, production_variants, tags=None, kms_key=None, wait=True):
+    def endpoint_from_production_variants(
+        self, name, production_variants, tags=None, kms_key=None, wait=True
+    ):
        """Create a SageMaker ``Endpoint`` from a list of production variants.
 
        Args:
@@ -1068,12 +1207,13 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
         """
         if not _deployment_entity_exists(
-                lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)):
-            config_options = {'EndpointConfigName': name, 'ProductionVariants': production_variants}
+            lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
+        ):
+            config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
             if tags:
-                config_options['Tags'] = tags
+                config_options["Tags"] = tags
             if kms_key:
-                config_options['KmsKeyId'] = kms_key
+                config_options["KmsKeyId"] = kms_key
 
             self.sagemaker_client.create_endpoint_config(**config_options)
 
         return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)
 
@@ -1090,35 +1230,44 @@ def expand_role(self, role):
         Returns:
             str: The corresponding AWS IAM role ARN.
         """
-        if '/' in role:
+        if "/" in role:
             return role
         else:
-            return self.boto_session.resource('iam').Role(role).arn
+            return self.boto_session.resource("iam").Role(role).arn
 
     def get_caller_identity_arn(self):
        """Returns the ARN of the user or role whose credentials are used to call the API.
        Returns:
            (str): The ARN of the user or role
        """
-        assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn']
-
-        if 'AmazonSageMaker-ExecutionRole' in assumed_role:
-            role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/service-role/\3', assumed_role)
+        assumed_role = self.boto_session.client("sts").get_caller_identity()["Arn"]
+
+        if "AmazonSageMaker-ExecutionRole" in assumed_role:
+            role = re.sub(
+                r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$",
+                r"\1iam::\2:role/service-role/\3",
+                assumed_role,
+            )
             return role
 
-        role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/\3', assumed_role)
+        role = re.sub(r"^(.+)sts::(\d+):assumed-role/(.+?)/.*$", r"\1iam::\2:role/\3", assumed_role)
 
         # Call IAM to get the role's path
-        role_name = role[role.rfind('/') + 1:]
+        role_name = role[role.rfind("/") + 1 :]
         try:
-            role = self.boto_session.client('iam').get_role(RoleName=role_name)['Role']['Arn']
+            role = self.boto_session.client("iam").get_role(RoleName=role_name)["Role"]["Arn"]
         except ClientError:
-            LOGGER.warning("Couldn't call 'get_role' to get Role ARN from role name {} to get Role path."
-                           .format(role_name))
+            LOGGER.warning(
+                "Couldn't call 'get_role' to get Role ARN from role name {} to get Role path.".format(
+                    role_name
+                )
+            )
         return role
 
-    def logs_for_job(self, job_name, wait=False, poll=10):  # noqa: C901 - suppress complexity warning for this method
+    def logs_for_job(  # noqa: C901 - suppress complexity warning for this method
+        self, job_name, wait=False, poll=10
+    ):
         """Display the logs for a given training job, optionally tailing them until the job is complete.
 
         If the output is a tty or a Jupyter cell, it will be color-coded based on
         which instance the log entry is from.
 
@@ -1133,20 +1282,22 @@ def logs_for_job(self, job_name, wait=False, poll=10):  # noqa: C901 - suppress
         """
         description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
-        print(secondary_training_status_message(description, None), end='')
-        instance_count = description['ResourceConfig']['InstanceCount']
-        status = description['TrainingJobStatus']
+        print(secondary_training_status_message(description, None), end="")
+        instance_count = description["ResourceConfig"]["InstanceCount"]
+        status = description["TrainingJobStatus"]
 
         stream_names = []  # The list of log streams
-        positions = {} # The current position in each stream, map of stream name -> position
+        positions = {}  # The current position in each stream, map of stream name -> position
 
         # Increase retries allowed (from default of 4), as we don't want waiting for a training job
         # to be interrupted by a transient exception.
- config = botocore.config.Config(retries={'max_attempts': 15}) - client = self.boto_session.client('logs', config=config) - log_group = '/aws/sagemaker/TrainingJobs' + config = botocore.config.Config(retries={"max_attempts": 15}) + client = self.boto_session.client("logs", config=config) + log_group = "/aws/sagemaker/TrainingJobs" - job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False + job_already_completed = ( + True if status == "Completed" or status == "Failed" or status == "Stopped" else False + ) state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE dot = False @@ -1179,32 +1330,47 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress # Log streams are created whenever a container starts writing to stdout/err, so this list # may be dynamic until we have a stream for every instance. try: - streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/', - orderBy='LogStreamName', limit=instance_count) - stream_names = [s['logStreamName'] for s in streams['logStreams']] - positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0)) - for s in stream_names if s not in positions]) + streams = client.describe_log_streams( + logGroupName=log_group, + logStreamNamePrefix=job_name + "/", + orderBy="LogStreamName", + limit=instance_count, + ) + stream_names = [s["logStreamName"] for s in streams["logStreams"]] + positions.update( + [ + (s, sagemaker.logs.Position(timestamp=0, skip=0)) + for s in stream_names + if s not in positions + ] + ) except ClientError as e: # On the very first training job run on an account, there's no log group until # the container starts logging, so ignore any errors thrown about that - err = e.response.get('Error', {}) - if err.get('Code', None) != 'ResourceNotFoundException': + err = e.response.get("Error", {}) + if err.get("Code", None) != "ResourceNotFoundException": raise if len(stream_names) > 0: if dot: - print('') + print("") dot = False - for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions): - color_wrap(idx, event['message']) + for idx, event in sagemaker.logs.multi_stream_iter( + client, log_group, stream_names, positions + ): + color_wrap(idx, event["message"]) ts, count = positions[stream_names[idx]] - if event['timestamp'] == ts: - positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1) + if event["timestamp"] == ts: + positions[stream_names[idx]] = sagemaker.logs.Position( + timestamp=ts, skip=count + 1 + ) else: - positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1) + positions[stream_names[idx]] = sagemaker.logs.Position( + timestamp=event["timestamp"], skip=1 + ) else: dot = True - print('.', end='') + print(".", end="") sys.stdout.flush() if state == LogState.COMPLETE: break @@ -1219,22 +1385,24 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress if secondary_training_status_changed(description, last_description): print() - print(secondary_training_status_message(description, last_description), end='') + print(secondary_training_status_message(description, last_description), end="") last_description = description - status = description['TrainingJobStatus'] + status = description["TrainingJobStatus"] - if status == 'Completed' or status == 'Failed' or status == 'Stopped': + if status == "Completed" or status == "Failed" or status == "Stopped": print() 
                state = LogState.JOB_COMPLETE
 
         if wait:
-            self._check_job_status(job_name, description, 'TrainingJobStatus')
+            self._check_job_status(job_name, description, "TrainingJobStatus")
             if dot:
                 print()
             # Customers are not billed for hardware provisioning, so billable time is less than total time
-            billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
-            print('Billable seconds:', int(billable_time.total_seconds()) + 1)
+            billable_time = (
+                description["TrainingEndTime"] - description["TrainingStartTime"]
+            ) * instance_count
+            print("Billable seconds:", int(billable_time.total_seconds()) + 1)
 
 
 def container_def(image, model_data_url=None, env=None):
@@ -1251,9 +1419,9 @@ def container_def(image, model_data_url=None, env=None):
     """
     if env is None:
         env = {}
-    c_def = {'Image': image, 'Environment': env}
+    c_def = {"Image": image, "Environment": env}
     if model_data_url:
-        c_def['ModelDataUrl'] = model_data_url
+        c_def["ModelDataUrl"] = model_data_url
     return c_def
 
@@ -1274,8 +1442,14 @@ def pipeline_container_def(models, instance_type=None):
     return c_defs
 
 
-def production_variant(model_name, instance_type, initial_instance_count=1, variant_name='AllTraffic',
-                       initial_weight=1, accelerator_type=None):
+def production_variant(
+    model_name,
+    instance_type,
+    initial_instance_count=1,
+    variant_name="AllTraffic",
+    initial_weight=1,
+    accelerator_type=None,
+):
     """Create a production variant description suitable for use in a ``ProductionVariant`` list as part of a
     ``CreateEndpointConfig`` request.
 
@@ -1292,15 +1466,15 @@ def production_variant(model_name, instance_type, initial_instance_count=1, vari
         dict[str, str]: A SageMaker ``ProductionVariant`` description
     """
     production_variant_configuration = {
-        'ModelName': model_name,
-        'InstanceType': instance_type,
-        'InitialInstanceCount': initial_instance_count,
-        'VariantName': variant_name,
-        'InitialVariantWeight': initial_weight
+        "ModelName": model_name,
+        "InstanceType": instance_type,
+        "InitialInstanceCount": initial_instance_count,
+        "VariantName": variant_name,
+        "InitialVariantWeight": initial_weight,
     }
 
     if accelerator_type:
-        production_variant_configuration['AcceleratorType'] = accelerator_type
+        production_variant_configuration["AcceleratorType"] = accelerator_type
 
     return production_variant_configuration
 
@@ -1317,9 +1491,9 @@ def get_execution_role(sagemaker_session=None):
         sagemaker_session = Session()
     arn = sagemaker_session.get_caller_identity_arn()
 
-    if ':role/' in arn:
+    if ":role/" in arn:
         return arn
-    message = 'The current AWS identity is not a role: {}, therefore it cannot be used as a SageMaker execution role'
+    message = "The current AWS identity is not a role: {}, therefore it cannot be used as a SageMaker execution role"
     raise ValueError(message.format(arn))
 
 
@@ -1330,9 +1504,18 @@ class s3_input(object):
         config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``S3DataSource``.
     """
 
-    def __init__(self, s3_data, distribution='FullyReplicated', compression=None,
-                 content_type=None, record_wrapping=None, s3_data_type='S3Prefix',
-                 input_mode=None, attribute_names=None, shuffle_config=None):
+    def __init__(
+        self,
+        s3_data,
+        distribution="FullyReplicated",
+        compression=None,
+        content_type=None,
+        record_wrapping=None,
+        s3_data_type="S3Prefix",
+        input_mode=None,
+        attribute_names=None,
+        shuffle_config=None,
+    ):
        """Create a definition for input data used by a SageMaker training job.
See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. @@ -1366,27 +1549,27 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, """ self.config = { - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': distribution, - 'S3DataType': s3_data_type, - 'S3Uri': s3_data + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": distribution, + "S3DataType": s3_data_type, + "S3Uri": s3_data, } } } if compression is not None: - self.config['CompressionType'] = compression + self.config["CompressionType"] = compression if content_type is not None: - self.config['ContentType'] = content_type + self.config["ContentType"] = content_type if record_wrapping is not None: - self.config['RecordWrapperType'] = record_wrapping + self.config["RecordWrapperType"] = record_wrapping if input_mode is not None: - self.config['InputMode'] = input_mode + self.config["InputMode"] = input_mode if attribute_names is not None: - self.config['DataSource']['S3DataSource']['AttributeNames'] = attribute_names + self.config["DataSource"]["S3DataSource"]["AttributeNames"] = attribute_names if shuffle_config is not None: - self.config['ShuffleConfig'] = {'Seed': shuffle_config.seed} + self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} class ShuffleConfig(object): @@ -1394,6 +1577,7 @@ class ShuffleConfig(object): Used to configure channel shuffling using a seed. See SageMaker documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ + def __init__(self, seed): """ Create a ShuffleConfig. @@ -1425,16 +1609,18 @@ def __init__(self, model_data, image, env=None): self.env = env -def _create_model_request(name, role, container_def=None, tags=None): # pylint: disable=redefined-outer-name - request = {'ModelName': name, 'ExecutionRoleArn': role} +def _create_model_request( + name, role, container_def=None, tags=None +): # pylint: disable=redefined-outer-name + request = {"ModelName": name, "ExecutionRoleArn": role} if isinstance(container_def, list): - request['Containers'] = container_def + request["Containers"] = container_def else: - request['PrimaryContainer'] = container_def + request["PrimaryContainer"] = container_def if tags: - request['Tags'] = tags + request["Tags"] = tags return request @@ -1444,23 +1630,26 @@ def _deployment_entity_exists(describe_fn): describe_fn() return True except ClientError as ce: - error_code = ce.response['Error']['Code'] - if not (error_code == 'ValidationException' and 'Could not find' in ce.response['Error']['Message']): + error_code = ce.response["Error"]["Code"] + if not ( + error_code == "ValidationException" + and "Could not find" in ce.response["Error"]["Message"] + ): raise ce return False def _train_done(sagemaker_client, job_name, last_desc): - in_progress_statuses = ['InProgress', 'Created'] + in_progress_statuses = ["InProgress", "Created"] desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) - status = desc['TrainingJobStatus'] + status = desc["TrainingJobStatus"] if secondary_training_status_changed(desc, last_desc): print() - print(secondary_training_status_message(desc, last_desc), end='') + print(secondary_training_status_message(desc, last_desc), end="") else: - print('.', end='') + print(".", end="") sys.stdout.flush() if status in in_progress_statuses: @@ -1472,19 +1661,19 @@ def _train_done(sagemaker_client, job_name, last_desc): def _compilation_job_status(sagemaker_client, job_name): compile_status_codes = { - 
'Completed': '!', - 'InProgress': '.', - 'Failed': '*', - 'Stopped': 's', - 'Stopping': '_' + "Completed": "!", + "InProgress": ".", + "Failed": "*", + "Stopped": "s", + "Stopping": "_", } - in_progress_statuses = ['InProgress', 'Stopping', 'Starting'] + in_progress_statuses = ["InProgress", "Stopping", "Starting"] desc = sagemaker_client.describe_compilation_job(CompilationJobName=job_name) - status = desc['CompilationJobStatus'] + status = desc["CompilationJobStatus"] status = _STATUS_CODE_TABLE.get(status, status) - print(compile_status_codes.get(status, '?'), end='') + print(compile_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: @@ -1495,62 +1684,64 @@ def _compilation_job_status(sagemaker_client, job_name): def _tuning_job_status(sagemaker_client, job_name): tuning_status_codes = { - 'Completed': '!', - 'InProgress': '.', - 'Failed': '*', - 'Stopped': 's', - 'Stopping': '_' + "Completed": "!", + "InProgress": ".", + "Failed": "*", + "Stopped": "s", + "Stopping": "_", } - in_progress_statuses = ['InProgress', 'Stopping'] + in_progress_statuses = ["InProgress", "Stopping"] - desc = sagemaker_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=job_name) - status = desc['HyperParameterTuningJobStatus'] + desc = sagemaker_client.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=job_name + ) + status = desc["HyperParameterTuningJobStatus"] - print(tuning_status_codes.get(status, '?'), end='') + print(tuning_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None - print('') + print("") return desc def _transform_job_status(sagemaker_client, job_name): transform_job_status_codes = { - 'Completed': '!', - 'InProgress': '.', - 'Failed': '*', - 'Stopped': 's', - 'Stopping': '_' + "Completed": "!", + "InProgress": ".", + "Failed": "*", + "Stopped": "s", + "Stopping": "_", } - in_progress_statuses = ['InProgress', 'Stopping'] + in_progress_statuses = ["InProgress", "Stopping"] desc = sagemaker_client.describe_transform_job(TransformJobName=job_name) - status = desc['TransformJobStatus'] + status = desc["TransformJobStatus"] - print(transform_job_status_codes.get(status, '?'), end='') + print(transform_job_status_codes.get(status, "?"), end="") sys.stdout.flush() if status in in_progress_statuses: return None - print('') + print("") return desc def _create_model_package_status(sagemaker_client, model_package_name): - in_progress_statuses = ['InProgress', 'Pending'] + in_progress_statuses = ["InProgress", "Pending"] desc = sagemaker_client.describe_model_package(ModelPackageName=model_package_name) - status = desc['ModelPackageStatus'] - print('.', end='') + status = desc["ModelPackageStatus"] + print(".", end="") sys.stdout.flush() if status in in_progress_statuses: return None - print('') + print("") return desc @@ -1562,14 +1753,14 @@ def _deploy_done(sagemaker_client, endpoint_name): "InService": "!", "RollingBack": "<", "Deleting": "o", - "Failed": "*" + "Failed": "*", } - in_progress_statuses = ['Creating', 'Updating'] + in_progress_statuses = ["Creating", "Updating"] desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name) - status = desc['EndpointStatus'] + status = desc["EndpointStatus"] - print(hosting_status_codes.get(status, '?'), end='') + print(hosting_status_codes.get(status, "?"), end="") sys.stdout.flush() return None if status in in_progress_statuses else desc @@ -1597,7 +1788,9 @@ def _expand_container_def(c_def): return c_def -def 
_vpc_config_from_training_job(training_job_desc, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
+def _vpc_config_from_training_job(
+    training_job_desc, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT
+):
     if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT:
         return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY)
     else:
diff --git a/src/sagemaker/sklearn/defaults.py b/src/sagemaker/sklearn/defaults.py
index c3d5edce51..8bd416cc9a 100644
--- a/src/sagemaker/sklearn/defaults.py
+++ b/src/sagemaker/sklearn/defaults.py
@@ -12,6 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-SKLEARN_NAME = 'scikit-learn'
+SKLEARN_NAME = "scikit-learn"
 
-SKLEARN_VERSION = '0.20.0'
+SKLEARN_VERSION = "0.20.0"
diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py
index 269d3bea29..e066980b75 100644
--- a/src/sagemaker/sklearn/estimator.py
+++ b/src/sagemaker/sklearn/estimator.py
@@ -16,12 +16,16 @@
 
 from sagemaker.estimator import Framework
 from sagemaker.fw_registry import default_framework_uri
-from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning, python_deprecation_warning
+from sagemaker.fw_utils import (
+    framework_name_from_image,
+    empty_framework_version_warning,
+    python_deprecation_warning,
+)
 from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
 from sagemaker.sklearn.model import SKLearnModel
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 
-logger = logging.getLogger('sagemaker')
+logger = logging.getLogger("sagemaker")
 
 
 class SKLearn(Framework):
@@ -29,8 +33,16 @@ class SKLearn(Framework):
 
     __framework_name__ = SKLEARN_NAME
 
-    def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=None, hyperparameters=None,
-                 py_version='py3', image_name=None, **kwargs):
+    def __init__(
+        self,
+        entry_point,
+        framework_version=SKLEARN_VERSION,
+        source_dir=None,
+        hyperparameters=None,
+        py_version="py3",
+        image_name=None,
+        **kwargs
+    ):
         """
         This ``Estimator`` executes a Scikit-learn script in a managed Scikit-learn execution environment,
         within a SageMaker Training Job. The managed Scikit-learn environment is an Amazon-built Docker container that executes
@@ -67,19 +79,26 @@ def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=No
             **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor.
         """
         # SciKit-Learn does not support distributed training or training on GPU instance types. Fail fast.
-        train_instance_type = kwargs.get('train_instance_type')
+        train_instance_type = kwargs.get("train_instance_type")
         _validate_not_gpu_instance_type(train_instance_type)
 
-        train_instance_count = kwargs.get('train_instance_count')
+        train_instance_count = kwargs.get("train_instance_count")
         if train_instance_count:
             if train_instance_count != 1:
-                raise AttributeError("Scikit-Learn does not support distributed training. "
-                                     "Please remove the 'train_instance_count' argument or set "
-                                     "'train_instance_count=1' when initializing SKLearn.")
-        super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name,
-                                      **dict(kwargs, train_instance_count=1))
-
-        if py_version == 'py2':
+                raise AttributeError(
+                    "Scikit-Learn does not support distributed training. "
+                    "Please remove the 'train_instance_count' argument or set "
+                    "'train_instance_count=1' when initializing SKLearn."
+ ) + super(SKLearn, self).__init__( + entry_point, + source_dir, + hyperparameters, + image_name=image_name, + **dict(kwargs, train_instance_count=1) + ) + + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -91,12 +110,12 @@ def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=No if image_name is None: image_tag = "{}-{}-{}".format(framework_version, "cpu", py_version) self.image_name = default_framework_uri( - SKLearn.__framework_name__, - self.sagemaker_session.boto_region_name, - image_tag) + SKLearn.__framework_name__, self.sagemaker_session.boto_region_name, image_tag + ) - def create_model(self, model_server_workers=None, role=None, - vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): + def create_model( + self, model_server_workers=None, role=None, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs + ): """Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``. Args: @@ -115,14 +134,23 @@ def create_model(self, model_server_workers=None, role=None, See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details. """ role = role or self.role - return SKLearnModel(self.model_data, role, self.entry_point, source_dir=self._model_source_dir(), - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name, - container_log_level=self.container_log_level, code_location=self.code_location, - py_version=self.py_version, framework_version=self.framework_version, - model_server_workers=model_server_workers, image=self.image_name, - sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs) + return SKLearnModel( + self.model_data, + role, + self.entry_point, + source_dir=self._model_source_dir(), + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + name=self._current_job_name, + container_log_level=self.container_log_level, + code_location=self.code_location, + py_version=self.py_version, + framework_version=self.framework_version, + model_server_workers=model_server_workers, + image=self.image_name, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + **kwargs + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -137,25 +165,37 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na """ init_params = super(SKLearn, cls)._prepare_init_params_from_job_description(job_details) - image_name = init_params.pop('image') + image_name = init_params.pop("image") framework, py_version, _, _ = framework_name_from_image(image_name) - init_params['py_version'] = py_version + init_params["py_version"] = py_version if framework and framework != cls.__framework_name__: - training_job_name = init_params['base_job_name'] - raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name)) + training_job_name = init_params["base_job_name"] + raise ValueError( + "Training job: {} didn't use image for requested framework".format( + training_job_name + ) + ) elif not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. 
-            init_params['image_name'] = image_name
+            init_params["image_name"] = image_name
 
         return init_params
 
 
 def _validate_not_gpu_instance_type(training_instance_type):
-    gpu_instance_types = ['ml.p2.xlarge', 'ml.p2.8xlarge', 'ml.p2.16xlarge',
-                          'ml.p3.xlarge', 'ml.p3.8xlarge', 'ml.p3.16xlarge']
+    gpu_instance_types = [
+        "ml.p2.xlarge",
+        "ml.p2.8xlarge",
+        "ml.p2.16xlarge",
+        "ml.p3.xlarge",
+        "ml.p3.8xlarge",
+        "ml.p3.16xlarge",
+    ]
 
     if training_instance_type in gpu_instance_types:
-        raise ValueError("GPU training in not supported for Scikit-Learn. "
-                         "Please pick a different instance type from here: "
-                         "https://aws.amazon.com/ec2/instance-types/")
+        raise ValueError(
+            "GPU training is not supported for Scikit-Learn. "
+            "Please pick a different instance type from here: "
+            "https://aws.amazon.com/ec2/instance-types/"
+        )
diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py
index d10a35d5fe..4ed9ee5e60 100644
--- a/src/sagemaker/sklearn/model.py
+++ b/src/sagemaker/sklearn/model.py
@@ -21,7 +21,7 @@
 from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
 from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
 
-logger = logging.getLogger('sagemaker')
+logger = logging.getLogger("sagemaker")
 
 
 class SKLearnPredictor(RealTimePredictor):
@@ -39,7 +39,9 @@ def __init__(self, endpoint_name, sagemaker_session=None):
             Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator
             creates one using the default AWS configuration chain.
         """
-        super(SKLearnPredictor, self).__init__(endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer)
+        super(SKLearnPredictor, self).__init__(
+            endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
+        )
 
 
 class SKLearnModel(FrameworkModel):
@@ -47,8 +49,18 @@ class SKLearnModel(FrameworkModel):
 
     __framework_name__ = SKLEARN_NAME
 
-    def __init__(self, model_data, role, entry_point, image=None, py_version='py3', framework_version=SKLEARN_VERSION,
-                 predictor_cls=SKLearnPredictor, model_server_workers=None, **kwargs):
+    def __init__(
+        self,
+        model_data,
+        role,
+        entry_point,
+        image=None,
+        py_version="py3",
+        framework_version=SKLEARN_VERSION,
+        predictor_cls=SKLearnPredictor,
+        model_server_workers=None,
+        **kwargs
+    ):
         """Initialize an SKLearnModel.
 
         Args:
@@ -70,10 +82,11 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py3',
             If None, server will use one worker per vCPU.
             **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
""" - super(SKLearnModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, - **kwargs) + super(SKLearnModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -99,9 +112,8 @@ def prepare_container_def(self, instance_type, accelerator_type=None): if not deploy_image: image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version) deploy_image = default_framework_uri( - self.__framework_name__, - self.sagemaker_session.boto_region_name, - image_tag) + self.__framework_name__, self.sagemaker_session.boto_region_name, image_tag + ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) diff --git a/src/sagemaker/sparkml/model.py b/src/sagemaker/sparkml/model.py index c95efef38a..ab898de490 100644 --- a/src/sagemaker/sparkml/model.py +++ b/src/sagemaker/sparkml/model.py @@ -17,8 +17,8 @@ from sagemaker.fw_registry import registry from sagemaker.predictor import csv_serializer -framework_name = 'sparkml-serving' -repo_name = 'sagemaker-sparkml-serving' +framework_name = "sparkml-serving" +repo_name = "sagemaker-sparkml-serving" class SparkMLPredictor(RealTimePredictor): @@ -45,8 +45,12 @@ def __init__(self, endpoint, sagemaker_session=None): using the default AWS configuration chain. """ sagemaker_session = sagemaker_session or Session() - super(SparkMLPredictor, self).__init__(endpoint=endpoint, sagemaker_session=sagemaker_session, - serializer=csv_serializer, content_type=CONTENT_TYPE_CSV) + super(SparkMLPredictor, self).__init__( + endpoint=endpoint, + sagemaker_session=sagemaker_session, + serializer=csv_serializer, + content_type=CONTENT_TYPE_CSV, + ) class SparkMLModel(Model): @@ -74,5 +78,11 @@ def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=N # for local mode, sagemaker_session should be passed as None but we need a session to get boto_region_name region_name = (sagemaker_session or Session()).boto_region_name image = "{}/{}:{}".format(registry(region_name, framework_name), repo_name, spark_version) - super(SparkMLModel, self).__init__(model_data, image, role, predictor_cls=SparkMLPredictor, - sagemaker_session=sagemaker_session, **kwargs) + super(SparkMLModel, self).__init__( + model_data, + image, + role, + predictor_cls=SparkMLPredictor, + sagemaker_session=sagemaker_session, + **kwargs + ) diff --git a/src/sagemaker/tensorflow/defaults.py b/src/sagemaker/tensorflow/defaults.py index fb68b1a59d..52c3bc5369 100644 --- a/src/sagemaker/tensorflow/defaults.py +++ b/src/sagemaker/tensorflow/defaults.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -TF_VERSION = '1.11' +TF_VERSION = "1.11" """Default TF version for when the framework version is not specified. This is no longer updated so as to not break existing workflows. 
""" diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 983e743315..ef22bd15ca 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -29,17 +29,26 @@ from sagemaker.utils import get_config_value from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -logger = logging.getLogger('sagemaker') - - -_FRAMEWORK_MODE_ARGS = ('training_steps', 'evaluation_steps', 'requirements_file', 'checkpoint_path') -_SCRIPT_MODE = 'tensorflow-scriptmode' -_SCRIPT_MODE_SERVING_ERROR_MSG = 'Script mode containers does not support serving yet. ' \ - 'Please use our new tensorflow-serving container by creating the model ' \ - 'with \'endpoint_type\' set to \'tensorflow-serving\'.' -_SCRIPT_MODE_TENSORBOARD_WARNING = 'Tensorboard is not supported with script mode. You can run the following ' \ - 'command: tensorboard --logdir {} --host localhost --port 6006 This can be ' \ - 'run from anywhere with access to the S3 URI used as the logdir.' +logger = logging.getLogger("sagemaker") + + +_FRAMEWORK_MODE_ARGS = ( + "training_steps", + "evaluation_steps", + "requirements_file", + "checkpoint_path", +) +_SCRIPT_MODE = "tensorflow-scriptmode" +_SCRIPT_MODE_SERVING_ERROR_MSG = ( + "Script mode containers does not support serving yet. " + "Please use our new tensorflow-serving container by creating the model " + "with 'endpoint_type' set to 'tensorflow-serving'." +) +_SCRIPT_MODE_TENSORBOARD_WARNING = ( + "Tensorboard is not supported with script mode. You can run the following " + "command: tensorboard --logdir {} --host localhost --port 6006 This can be " + "run from anywhere with access to the S3 URI used as the logdir." +) class Tensorboard(threading.Thread): @@ -88,7 +97,7 @@ def _sync_directories(from_directory, to_directory): for fname in files: from_file = os.path.join(root, fname) to_file = os.path.join(to_root, fname) - with open(from_file, 'rb') as a, open(to_file, 'wb') as b: + with open(from_file, "rb") as a, open(to_file, "wb") as b: b.write(a.read()) @staticmethod @@ -111,15 +120,17 @@ def validate_requirements(self): Raises: EnvironmentError: If at least one requirement is not installed. """ - if not self._cmd_exists('tensorboard'): + if not self._cmd_exists("tensorboard"): raise EnvironmentError( - 'TensorBoard is not installed in the system. Please install TensorBoard using the' - ' following command: \n pip install tensorboard') + "TensorBoard is not installed in the system. Please install TensorBoard using the" + " following command: \n pip install tensorboard" + ) - if not self._cmd_exists('aws'): + if not self._cmd_exists("aws"): raise EnvironmentError( - 'The AWS CLI is not installed in the system. Please install the AWS CLI using the' - ' following command: \n pip install awscli') + "The AWS CLI is not installed in the system. Please install the AWS CLI using the" + " following command: \n pip install awscli" + ) def create_tensorboard_process(self): """Create a TensorBoard process. @@ -136,10 +147,17 @@ def create_tensorboard_process(self): for _ in range(100): p = subprocess.Popen( - ["tensorboard", "--logdir", self.logdir, "--host", "localhost", "--port", - str(port)], + [ + "tensorboard", + "--logdir", + self.logdir, + "--host", + "localhost", + "--port", + str(port), + ], stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) self.event.wait(5) if p.poll(): @@ -148,18 +166,19 @@ def create_tensorboard_process(self): return port, p raise OSError( - 'No available ports to start TensorBoard. 
Attempted all ports between 6006 and 6105') + "No available ports to start TensorBoard. Attempted all ports between 6006 and 6105" + ) def run(self): """Run TensorBoard process.""" port, tensorboard_process = self.create_tensorboard_process() - logger.info('TensorBoard 0.1.7 at http://localhost:{}'.format(port)) + logger.info("TensorBoard 0.1.7 at http://localhost:{}".format(port)) while not self.estimator.checkpoint_path: self.event.wait(1) with self._temporary_directory() as aws_sync_dir: while not self.event.is_set(): - args = ['aws', 's3', 'sync', self.estimator.checkpoint_path, aws_sync_dir] + args = ["aws", "s3", "sync", self.estimator.checkpoint_path, aws_sync_dir] subprocess.call(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) self._sync_directories(aws_sync_dir, self.logdir) self.event.wait(10) @@ -169,14 +188,25 @@ def run(self): class TensorFlow(Framework): """Handle end-to-end training and deployment of user-provided TensorFlow code.""" - __framework_name__ = 'tensorflow' + __framework_name__ = "tensorflow" - LATEST_VERSION = '1.12' + LATEST_VERSION = "1.12" """The latest version of TensorFlow included in the SageMaker pre-built Docker images.""" - def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', - framework_version=None, model_dir=None, requirements_file='', image_name=None, - script_mode=False, distributions=None, **kwargs): + def __init__( + self, + training_steps=None, + evaluation_steps=None, + checkpoint_path=None, + py_version="py2", + framework_version=None, + model_dir=None, + requirements_file="", + image_name=None, + script_mode=False, + distributions=None, + **kwargs + ): """Initialize a ``TensorFlow`` estimator. Args: @@ -237,8 +267,8 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N super(TensorFlow, self).__init__(image_name=image_name, **kwargs) self.checkpoint_path = checkpoint_path - if py_version == 'py2': - logger.warning('tensorflow py2 container will be deprecated soon.') + if py_version == "py2": + logger.warning("tensorflow py2 container will be deprecated soon.") self.py_version = py_version self.training_steps = training_steps @@ -247,33 +277,48 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N self.script_mode = script_mode self.distributions = distributions or {} - self._validate_args(py_version=py_version, script_mode=script_mode, framework_version=framework_version, - training_steps=training_steps, evaluation_steps=evaluation_steps, - requirements_file=requirements_file, checkpoint_path=checkpoint_path) + self._validate_args( + py_version=py_version, + script_mode=script_mode, + framework_version=framework_version, + training_steps=training_steps, + evaluation_steps=evaluation_steps, + requirements_file=requirements_file, + checkpoint_path=checkpoint_path, + ) self._validate_requirements_file(requirements_file) self.requirements_file = requirements_file - def _validate_args(self, py_version, script_mode, framework_version, training_steps, - evaluation_steps, requirements_file, checkpoint_path): + def _validate_args( + self, + py_version, + script_mode, + framework_version, + training_steps, + evaluation_steps, + requirements_file, + checkpoint_path, + ): - if py_version == 'py3' or script_mode: + if py_version == "py3" or script_mode: if framework_version is None: raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR) found_args = [] if training_steps: - found_args.append('training_steps') + 
found_args.append("training_steps") if evaluation_steps: - found_args.append('evaluation_steps') + found_args.append("evaluation_steps") if requirements_file: - found_args.append('requirements_file') + found_args.append("requirements_file") if checkpoint_path: - found_args.append('checkpoint_path') + found_args.append("checkpoint_path") if found_args: raise AttributeError( - '{} are deprecated in script mode. Please do not set {}.' - .format(', '.join(_FRAMEWORK_MODE_ARGS), ', '.join(found_args)) + "{} are deprecated in script mode. Please do not set {}.".format( + ", ".join(_FRAMEWORK_MODE_ARGS), ", ".join(found_args) + ) ) def _validate_requirements_file(self, requirements_file): @@ -281,17 +326,20 @@ def _validate_requirements_file(self, requirements_file): return if not self.source_dir: - raise ValueError('Must specify source_dir along with a requirements file.') + raise ValueError("Must specify source_dir along with a requirements file.") - if self.source_dir.lower().startswith('s3://'): + if self.source_dir.lower().startswith("s3://"): return if os.path.isabs(requirements_file): - raise ValueError('Requirements file {} is not a path relative to source_dir.'.format( - requirements_file)) + raise ValueError( + "Requirements file {} is not a path relative to source_dir.".format( + requirements_file + ) + ) if not os.path.exists(os.path.join(self.source_dir, requirements_file)): - raise ValueError('Requirements file {} does not exist.'.format(requirements_file)) + raise ValueError("Requirements file {} does not exist.".format(requirements_file)) def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_locally=False): """Train a model using the input training dataset. @@ -355,44 +403,54 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na dictionary: The transformed init_params """ - init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details, - model_channel_name) + init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description( + job_details, model_channel_name + ) # Move some of the tensorflow specific init params from hyperparameters into the main init params. - for argument in ('checkpoint_path', 'training_steps', 'evaluation_steps', 'model_dir'): - value = init_params['hyperparameters'].pop(argument, None) + for argument in ("checkpoint_path", "training_steps", "evaluation_steps", "model_dir"): + value = init_params["hyperparameters"].pop(argument, None) if value is not None: init_params[argument] = value - image_name = init_params.pop('image') + image_name = init_params.pop("image") framework, py_version, tag, script_mode = fw.framework_name_from_image(image_name) if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. - init_params['image_name'] = image_name + init_params["image_name"] = image_name return init_params if script_mode: - init_params['script_mode'] = True + init_params["script_mode"] = True - init_params['py_version'] = py_version + init_params["py_version"] = py_version # We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive # containing framework version, device type and python version (e.g. '1.5-gpu-py2'). # For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version # otherwise extract framework version from the tag itself. 
- init_params['framework_version'] = '1.4' if tag == '1.0' else fw.framework_version_from_tag( - tag) + init_params["framework_version"] = ( + "1.4" if tag == "1.0" else fw.framework_version_from_tag(tag) + ) - training_job_name = init_params['base_job_name'] + training_job_name = init_params["base_job_name"] if framework != cls.__framework_name__: - raise ValueError("Training job: {} didn't use image for requested framework".format( - training_job_name)) + raise ValueError( + "Training job: {} didn't use image for requested framework".format( + training_job_name + ) + ) return init_params - def create_model(self, model_server_workers=None, role=None, - vpc_config_override=VPC_CONFIG_DEFAULT, endpoint_type=None): + def create_model( + self, + model_server_workers=None, + role=None, + vpc_config_override=VPC_CONFIG_DEFAULT, + endpoint_type=None, + ): """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``. Args: @@ -415,89 +473,110 @@ def create_model(self, model_server_workers=None, role=None, """ role = role or self.role - if endpoint_type == 'tensorflow-serving' or self._script_mode_enabled(): + if endpoint_type == "tensorflow-serving" or self._script_mode_enabled(): return self._create_tfs_model(role=role, vpc_config_override=vpc_config_override) - return self._create_default_model(model_server_workers=model_server_workers, role=role, - vpc_config_override=vpc_config_override) + return self._create_default_model( + model_server_workers=model_server_workers, + role=role, + vpc_config_override=vpc_config_override, + ) def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT): - return Model(model_data=self.model_data, - role=role, - image=self.image_name, - name=self._current_job_name, - container_log_level=self.container_log_level, - framework_version=self.framework_version, - sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override)) + return Model( + model_data=self.model_data, + role=role, + image=self.image_name, + name=self._current_job_name, + container_log_level=self.container_log_level, + framework_version=self.framework_version, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), + ) def _create_default_model(self, model_server_workers, role, vpc_config_override): - return TensorFlowModel(self.model_data, role, self.entry_point, - source_dir=self._model_source_dir(), - enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, - env={'SAGEMAKER_REQUIREMENTS': self.requirements_file}, - image=self.image_name, - name=self._current_job_name, - container_log_level=self.container_log_level, - code_location=self.code_location, py_version=self.py_version, - framework_version=self.framework_version, - model_server_workers=model_server_workers, - sagemaker_session=self.sagemaker_session, - vpc_config=self.get_vpc_config(vpc_config_override), - dependencies=self.dependencies) + return TensorFlowModel( + self.model_data, + role, + self.entry_point, + source_dir=self._model_source_dir(), + enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, + env={"SAGEMAKER_REQUIREMENTS": self.requirements_file}, + image=self.image_name, + name=self._current_job_name, + container_log_level=self.container_log_level, + code_location=self.code_location, + py_version=self.py_version, + framework_version=self.framework_version, + model_server_workers=model_server_workers, + sagemaker_session=self.sagemaker_session, + vpc_config=self.get_vpc_config(vpc_config_override), 
+ dependencies=self.dependencies, + ) def hyperparameters(self): """Return hyperparameters used by your custom TensorFlow code during model training.""" hyperparameters = super(TensorFlow, self).hyperparameters() - self.checkpoint_path = self.checkpoint_path or self._default_s3_path('checkpoints') + self.checkpoint_path = self.checkpoint_path or self._default_s3_path("checkpoints") mpi_enabled = False if self._script_mode_enabled(): additional_hyperparameters = {} - if 'parameter_server' in self.distributions: - ps_enabled = self.distributions['parameter_server'].get('enabled', False) + if "parameter_server" in self.distributions: + ps_enabled = self.distributions["parameter_server"].get("enabled", False) additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled - if 'mpi' in self.distributions: - mpi_dict = self.distributions['mpi'] - mpi_enabled = mpi_dict.get('enabled', False) + if "mpi" in self.distributions: + mpi_dict = self.distributions["mpi"] + mpi_enabled = mpi_dict.get("enabled", False) additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled - additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get('processes_per_host', 1) - additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get('custom_mpi_options', '') + additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get( + "processes_per_host", 1 + ) + additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get( + "custom_mpi_options", "" + ) - self.model_dir = self.model_dir or self._default_s3_path('model', mpi=mpi_enabled) - additional_hyperparameters['model_dir'] = self.model_dir + self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled) + additional_hyperparameters["model_dir"] = self.model_dir else: - additional_hyperparameters = {'checkpoint_path': self.checkpoint_path, - 'training_steps': self.training_steps, - 'evaluation_steps': self.evaluation_steps, - 'sagemaker_requirements': self.requirements_file} + additional_hyperparameters = { + "checkpoint_path": self.checkpoint_path, + "training_steps": self.training_steps, + "evaluation_steps": self.evaluation_steps, + "sagemaker_requirements": self.requirements_file, + } hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters)) return hyperparameters def _default_s3_path(self, directory, mpi=False): - local_code = get_config_value('local.local_code', self.sagemaker_session.config) + local_code = get_config_value("local.local_code", self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: - return '/opt/ml/shared/{}'.format(directory) + return "/opt/ml/shared/{}".format(directory) elif mpi: - return '/opt/ml/model' + return "/opt/ml/model" elif self._current_job_name: return os.path.join(self.output_path, self._current_job_name, directory) else: return None def _script_mode_enabled(self): - return self.py_version == 'py3' or self.script_mode + return self.py_version == "py3" or self.script_mode def train_image(self): if self.image_name: return self.image_name if self._script_mode_enabled(): - return fw.create_image_uri(self.sagemaker_session.boto_region_name, _SCRIPT_MODE, - self.train_instance_type, self.framework_version, self.py_version) + return fw.create_image_uri( + self.sagemaker_session.boto_region_name, + _SCRIPT_MODE, + self.train_instance_type, + self.framework_version, + self.py_version, + ) return super(TensorFlow, self).train_image() diff --git a/src/sagemaker/tensorflow/model.py 
b/src/sagemaker/tensorflow/model.py index 4c5207d83e..73a22684e6 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -21,7 +21,7 @@ from sagemaker.tensorflow.defaults import TF_VERSION from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer -logger = logging.getLogger('sagemaker') +logger = logging.getLogger("sagemaker") class TensorFlowPredictor(RealTimePredictor): @@ -29,6 +29,7 @@ class TensorFlowPredictor(RealTimePredictor): This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for inference""" + def __init__(self, endpoint_name, sagemaker_session=None): """Initialize a ``TensorFlowPredictor``. Args: @@ -38,16 +39,27 @@ def __init__(self, endpoint_name, sagemaker_session=None): Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. """ - super(TensorFlowPredictor, self).__init__(endpoint_name, sagemaker_session, tf_json_serializer, - tf_json_deserializer) + super(TensorFlowPredictor, self).__init__( + endpoint_name, sagemaker_session, tf_json_serializer, tf_json_deserializer + ) class TensorFlowModel(FrameworkModel): - __framework_name__ = 'tensorflow' - - def __init__(self, model_data, role, entry_point, image=None, py_version='py2', framework_version=TF_VERSION, - predictor_cls=TensorFlowPredictor, model_server_workers=None, **kwargs): + __framework_name__ = "tensorflow" + + def __init__( + self, + model_data, + role, + entry_point, + image=None, + py_version="py2", + framework_version=TF_VERSION, + predictor_cls=TensorFlowPredictor, + model_server_workers=None, + **kwargs + ): """Initialize a TensorFlowModel. Args: @@ -69,10 +81,11 @@ def __init__(self, model_data, role, entry_point, image=None, py_version='py2', If None, server will use one worker per vCPU. **kwargs: Keyword arguments passed to the ``FrameworkModel`` initializer.
""" - super(TensorFlowModel, self).__init__(model_data, image, role, entry_point, predictor_cls=predictor_cls, - **kwargs) + super(TensorFlowModel, self).__init__( + model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs + ) - if py_version == 'py2': + if py_version == "py2": logger.warning(python_deprecation_warning(self.__framework_name__)) self.py_version = py_version @@ -95,8 +108,14 @@ def prepare_container_def(self, instance_type, accelerator_type=None): deploy_image = self.image if not deploy_image: region_name = self.sagemaker_session.boto_region_name - deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type, - self.framework_version, self.py_version, accelerator_type=accelerator_type) + deploy_image = create_image_uri( + region_name, + self.__framework_name__, + instance_type, + self.framework_version, + self.py_version, + accelerator_type=accelerator_type, + ) deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) self._upload_code(deploy_key_prefix) diff --git a/src/sagemaker/tensorflow/predictor.py b/src/sagemaker/tensorflow/predictor.py index 359dd062b8..c56f72ddc9 100644 --- a/src/sagemaker/tensorflow/predictor.py +++ b/src/sagemaker/tensorflow/predictor.py @@ -24,14 +24,18 @@ from sagemaker.predictor import json_serializer, csv_serializer from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2 -_POSSIBLE_RESPONSES = [predict_pb2.PredictResponse, classification_pb2.ClassificationResponse, - inference_pb2.MultiInferenceResponse, regression_pb2.RegressionResponse, - tensor_pb2.TensorProto] +_POSSIBLE_RESPONSES = [ + predict_pb2.PredictResponse, + classification_pb2.ClassificationResponse, + inference_pb2.MultiInferenceResponse, + regression_pb2.RegressionResponse, + tensor_pb2.TensorProto, +] -REGRESSION_REQUEST = 'RegressionRequest' -MULTI_INFERENCE_REQUEST = 'MultiInferenceRequest' -CLASSIFICATION_REQUEST = 'ClassificationRequest' -PREDICT_REQUEST = 'PredictRequest' +REGRESSION_REQUEST = "RegressionRequest" +MULTI_INFERENCE_REQUEST = "MultiInferenceRequest" +CLASSIFICATION_REQUEST = "ClassificationRequest" +PREDICT_REQUEST = "PredictRequest" class _TFProtobufSerializer(object): @@ -43,10 +47,15 @@ def __call__(self, data): # for example sagemaker.tensorflow.tensorflow_serving.regression_pb2 and tensorflow_serving.apis.regression_pb2 predict_type = data.__class__.__name__ - available_requests = [PREDICT_REQUEST, CLASSIFICATION_REQUEST, MULTI_INFERENCE_REQUEST, REGRESSION_REQUEST] + available_requests = [ + PREDICT_REQUEST, + CLASSIFICATION_REQUEST, + MULTI_INFERENCE_REQUEST, + REGRESSION_REQUEST, + ] if predict_type not in available_requests: - raise ValueError('request type {} is not supported'.format(predict_type)) + raise ValueError("request type {} is not supported".format(predict_type)) return data.SerializeToString() @@ -72,7 +81,7 @@ def __call__(self, stream, content_type): # given that the payload does not have the response type, there no way to infer # the response without keeping state, so I'm iterating all the options. 
pass - raise ValueError('data is not in the expected format') + raise ValueError("data is not in the expected format") tf_deserializer = _TFProtobufDeserializer() diff --git a/src/sagemaker/tensorflow/serving.py b/src/sagemaker/tensorflow/serving.py index 7a37318d10..e54e58bae6 100644 --- a/src/sagemaker/tensorflow/serving.py +++ b/src/sagemaker/tensorflow/serving.py @@ -26,12 +26,16 @@ class Predictor(sagemaker.RealTimePredictor): """A ``RealTimePredictor`` implementation for inference against TensorFlow Serving endpoints. """ - def __init__(self, endpoint_name, sagemaker_session=None, - serializer=json_serializer, - deserializer=json_deserializer, - content_type=None, - model_name=None, - model_version=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=json_serializer, + deserializer=json_deserializer, + content_type=None, + model_name=None, + model_version=None, + ): """Initialize a ``TFSPredictor``. See ``sagemaker.RealTimePredictor`` for more info about parameters. @@ -50,59 +54,67 @@ def __init__(self, endpoint_name, sagemaker_session=None, model_version (str): Optional. The version of the SavedModel model that should handle the request. If not specified, the latest version of the model will be used. """ - super(Predictor, self).__init__(endpoint_name, sagemaker_session, serializer, - deserializer, content_type) + super(Predictor, self).__init__( + endpoint_name, sagemaker_session, serializer, deserializer, content_type + ) attributes = [] if model_name: - attributes.append('tfs-model-name={}'.format(model_name)) + attributes.append("tfs-model-name={}".format(model_name)) if model_version: - attributes.append('tfs-model-version={}'.format(model_version)) - self._model_attributes = ','.join(attributes) if attributes else None + attributes.append("tfs-model-version={}".format(model_version)) + self._model_attributes = ",".join(attributes) if attributes else None def classify(self, data): - return self._classify_or_regress(data, 'classify') + return self._classify_or_regress(data, "classify") def regress(self, data): - return self._classify_or_regress(data, 'regress') + return self._classify_or_regress(data, "regress") def _classify_or_regress(self, data, method): - if method not in ['classify', 'regress']: - raise ValueError('invalid TensorFlow Serving method: {}'.format(method)) + if method not in ["classify", "regress"]: + raise ValueError("invalid TensorFlow Serving method: {}".format(method)) if self.content_type != CONTENT_TYPE_JSON: - raise ValueError('The {} api requires json requests.'.format(method)) + raise ValueError("The {} api requires json requests.".format(method)) - args = { - 'CustomAttributes': 'tfs-method={}'.format(method) - } + args = {"CustomAttributes": "tfs-method={}".format(method)} return self.predict(data, args) def predict(self, data, initial_args=None): args = dict(initial_args) if initial_args else {} if self._model_attributes: - if 'CustomAttributes' in args: - args['CustomAttributes'] += ',' + self._model_attributes + if "CustomAttributes" in args: + args["CustomAttributes"] += "," + self._model_attributes else: - args['CustomAttributes'] = self._model_attributes + args["CustomAttributes"] = self._model_attributes return super(Predictor, self).predict(data, args) class Model(sagemaker.model.FrameworkModel): - FRAMEWORK_NAME = 'tensorflow-serving' - LOG_LEVEL_PARAM_NAME = 'SAGEMAKER_TFS_NGINX_LOGLEVEL' + FRAMEWORK_NAME = "tensorflow-serving" + LOG_LEVEL_PARAM_NAME = "SAGEMAKER_TFS_NGINX_LOGLEVEL" LOG_LEVEL_MAP = { - 
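Above, model_name and model_version become tfs-* custom attributes that predict() merges with any per-call attributes; a short sketch of the composed header, assuming a hypothetical endpoint serving JSON:

    payload = {"examples": [{"x": 1.0}]}  # assumed classify-style JSON input
    predictor = Predictor("my-endpoint", model_name="resnet", model_version="2")
    # composed once: _model_attributes == "tfs-model-name=resnet,tfs-model-version=2"
    predictor.classify(payload)
    # -> InvokeEndpoint CustomAttributes:
    #    "tfs-method=classify,tfs-model-name=resnet,tfs-model-version=2"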
logging.DEBUG: 'debug', - logging.INFO: 'info', - logging.WARNING: 'warn', - logging.ERROR: 'error', - logging.CRITICAL: 'crit', + logging.DEBUG: "debug", + logging.INFO: "info", + logging.WARNING: "warn", + logging.ERROR: "error", + logging.CRITICAL: "crit", } - def __init__(self, model_data, role, entry_point=None, image=None, framework_version=TF_VERSION, - container_log_level=None, predictor_cls=Predictor, **kwargs): + def __init__( + self, + model_data, + role, + entry_point=None, + image=None, + framework_version=TF_VERSION, + container_log_level=None, + predictor_cls=Predictor, + **kwargs + ): """Initialize a Model. Args: @@ -119,8 +131,14 @@ def __init__(self, model_data, role, entry_point=None, image=None, framework_ver returns the result of invoking this function on the created endpoint name. **kwargs: Keyword arguments passed to the ``Model`` initializer. """ - super(Model, self).__init__(model_data=model_data, role=role, image=image, - predictor_cls=predictor_cls, entry_point=entry_point, **kwargs) + super(Model, self).__init__( + model_data=model_data, + role=role, + image=image, + predictor_cls=predictor_cls, + entry_point=entry_point, + **kwargs + ) self._framework_version = framework_version self._container_log_level = container_log_level @@ -132,14 +150,16 @@ def prepare_container_def(self, instance_type, accelerator_type=None): key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image) bucket = self.bucket or self.sagemaker_session.default_bucket() - model_data = 's3://' + os.path.join(bucket, key_prefix, 'model.tar.gz') - - sagemaker.utils.repack_model(self.entry_point, - self.source_dir, - self.dependencies, - self.model_data, - model_data, - self.sagemaker_session) + model_data = "s3://" + os.path.join(bucket, key_prefix, "model.tar.gz") + + sagemaker.utils.repack_model( + self.entry_point, + self.source_dir, + self.dependencies, + self.model_data, + model_data, + self.sagemaker_session, + ) else: model_data = self.model_data @@ -150,7 +170,7 @@ def _get_container_env(self): return self.env if self._container_log_level not in Model.LOG_LEVEL_MAP: - logging.warning('ignoring invalid container log level: %s', self._container_log_level) + logging.warning("ignoring invalid container log level: %s", self._container_log_level) return self.env env = dict(self.env) @@ -162,5 +182,10 @@ def _get_image_uri(self, instance_type, accelerator_type=None): return self.image region_name = self.sagemaker_session.boto_region_name - return create_image_uri(region_name, Model.FRAMEWORK_NAME, instance_type, - self._framework_version, accelerator_type=accelerator_type) + return create_image_uri( + region_name, + Model.FRAMEWORK_NAME, + instance_type, + self._framework_version, + accelerator_type=accelerator_type, + ) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/classification_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/classification_pb2.py index 69748f2292..89e4558b23 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/classification_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/classification_pb2.py @@ -2,12 +2,14 @@ # source: tensorflow_serving/apis/classification.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from 
google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -18,240 +20,342 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/classification.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n,tensorflow_serving/apis/classification.proto\x12\x12tensorflow.serving\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto\"%\n\x05\x43lass\x12\r\n\x05label\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\"=\n\x0f\x43lassifications\x12*\n\x07\x63lasses\x18\x01 \x03(\x0b\x32\x19.tensorflow.serving.Class\"T\n\x14\x43lassificationResult\x12<\n\x0f\x63lassifications\x18\x01 \x03(\x0b\x32#.tensorflow.serving.Classifications\"t\n\x15\x43lassificationRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input\"\x85\x01\n\x16\x43lassificationResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x38\n\x06result\x18\x01 \x01(\x0b\x32(.tensorflow.serving.ClassificationResultB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/classification.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n,tensorflow_serving/apis/classification.proto\x12\x12tensorflow.serving\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto"%\n\x05\x43lass\x12\r\n\x05label\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02"=\n\x0f\x43lassifications\x12*\n\x07\x63lasses\x18\x01 \x03(\x0b\x32\x19.tensorflow.serving.Class"T\n\x14\x43lassificationResult\x12<\n\x0f\x63lassifications\x18\x01 \x03(\x0b\x32#.tensorflow.serving.Classifications"t\n\x15\x43lassificationRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input"\x85\x01\n\x16\x43lassificationResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x38\n\x06result\x18\x01 \x01(\x0b\x32(.tensorflow.serving.ClassificationResultB\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[ + tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR, + ], +) _CLASS = _descriptor.Descriptor( - name='Class', - full_name='tensorflow.serving.Class', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='label', full_name='tensorflow.serving.Class.label', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='score', full_name='tensorflow.serving.Class.score', index=1, - number=2, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=142, - serialized_end=179, + 
name="Class", + full_name="tensorflow.serving.Class", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="label", + full_name="tensorflow.serving.Class.label", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="score", + full_name="tensorflow.serving.Class.score", + index=1, + number=2, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=142, + serialized_end=179, ) _CLASSIFICATIONS = _descriptor.Descriptor( - name='Classifications', - full_name='tensorflow.serving.Classifications', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='classes', full_name='tensorflow.serving.Classifications.classes', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=181, - serialized_end=242, + name="Classifications", + full_name="tensorflow.serving.Classifications", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="classes", + full_name="tensorflow.serving.Classifications.classes", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=181, + serialized_end=242, ) _CLASSIFICATIONRESULT = _descriptor.Descriptor( - name='ClassificationResult', - full_name='tensorflow.serving.ClassificationResult', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='classifications', full_name='tensorflow.serving.ClassificationResult.classifications', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=244, - serialized_end=328, + name="ClassificationResult", + full_name="tensorflow.serving.ClassificationResult", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="classifications", + full_name="tensorflow.serving.ClassificationResult.classifications", + 
index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=244, + serialized_end=328, ) _CLASSIFICATIONREQUEST = _descriptor.Descriptor( - name='ClassificationRequest', - full_name='tensorflow.serving.ClassificationRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.ClassificationRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input', full_name='tensorflow.serving.ClassificationRequest.input', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=330, - serialized_end=446, + name="ClassificationRequest", + full_name="tensorflow.serving.ClassificationRequest", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.ClassificationRequest.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="input", + full_name="tensorflow.serving.ClassificationRequest.input", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=330, + serialized_end=446, ) _CLASSIFICATIONRESPONSE = _descriptor.Descriptor( - name='ClassificationResponse', - full_name='tensorflow.serving.ClassificationResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.ClassificationResponse.model_spec', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='result', full_name='tensorflow.serving.ClassificationResponse.result', index=1, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - 
options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=449, - serialized_end=582, + name="ClassificationResponse", + full_name="tensorflow.serving.ClassificationResponse", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.ClassificationResponse.model_spec", + index=0, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="result", + full_name="tensorflow.serving.ClassificationResponse.result", + index=1, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=449, + serialized_end=582, ) -_CLASSIFICATIONS.fields_by_name['classes'].message_type = _CLASS -_CLASSIFICATIONRESULT.fields_by_name['classifications'].message_type = _CLASSIFICATIONS -_CLASSIFICATIONREQUEST.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_CLASSIFICATIONREQUEST.fields_by_name['input'].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT -_CLASSIFICATIONRESPONSE.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_CLASSIFICATIONRESPONSE.fields_by_name['result'].message_type = _CLASSIFICATIONRESULT -DESCRIPTOR.message_types_by_name['Class'] = _CLASS -DESCRIPTOR.message_types_by_name['Classifications'] = _CLASSIFICATIONS -DESCRIPTOR.message_types_by_name['ClassificationResult'] = _CLASSIFICATIONRESULT -DESCRIPTOR.message_types_by_name['ClassificationRequest'] = _CLASSIFICATIONREQUEST -DESCRIPTOR.message_types_by_name['ClassificationResponse'] = _CLASSIFICATIONRESPONSE +_CLASSIFICATIONS.fields_by_name["classes"].message_type = _CLASS +_CLASSIFICATIONRESULT.fields_by_name["classifications"].message_type = _CLASSIFICATIONS +_CLASSIFICATIONREQUEST.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_CLASSIFICATIONREQUEST.fields_by_name[ + "input" +].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT +_CLASSIFICATIONRESPONSE.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_CLASSIFICATIONRESPONSE.fields_by_name["result"].message_type = _CLASSIFICATIONRESULT +DESCRIPTOR.message_types_by_name["Class"] = _CLASS +DESCRIPTOR.message_types_by_name["Classifications"] = _CLASSIFICATIONS +DESCRIPTOR.message_types_by_name["ClassificationResult"] = _CLASSIFICATIONRESULT +DESCRIPTOR.message_types_by_name["ClassificationRequest"] = _CLASSIFICATIONREQUEST +DESCRIPTOR.message_types_by_name["ClassificationResponse"] = _CLASSIFICATIONRESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -Class = _reflection.GeneratedProtocolMessageType('Class', (_message.Message,), dict( - DESCRIPTOR = _CLASS, - __module__ = 'tensorflow_serving.apis.classification_pb2' - # 
@@protoc_insertion_point(class_scope:tensorflow.serving.Class) - )) +Class = _reflection.GeneratedProtocolMessageType( + "Class", + (_message.Message,), + dict( + DESCRIPTOR=_CLASS, + __module__="tensorflow_serving.apis.classification_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.Class) + ), +) _sym_db.RegisterMessage(Class) -Classifications = _reflection.GeneratedProtocolMessageType('Classifications', (_message.Message,), dict( - DESCRIPTOR = _CLASSIFICATIONS, - __module__ = 'tensorflow_serving.apis.classification_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.Classifications) - )) +Classifications = _reflection.GeneratedProtocolMessageType( + "Classifications", + (_message.Message,), + dict( + DESCRIPTOR=_CLASSIFICATIONS, + __module__="tensorflow_serving.apis.classification_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.Classifications) + ), +) _sym_db.RegisterMessage(Classifications) -ClassificationResult = _reflection.GeneratedProtocolMessageType('ClassificationResult', (_message.Message,), dict( - DESCRIPTOR = _CLASSIFICATIONRESULT, - __module__ = 'tensorflow_serving.apis.classification_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationResult) - )) +ClassificationResult = _reflection.GeneratedProtocolMessageType( + "ClassificationResult", + (_message.Message,), + dict( + DESCRIPTOR=_CLASSIFICATIONRESULT, + __module__="tensorflow_serving.apis.classification_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationResult) + ), +) _sym_db.RegisterMessage(ClassificationResult) -ClassificationRequest = _reflection.GeneratedProtocolMessageType('ClassificationRequest', (_message.Message,), dict( - DESCRIPTOR = _CLASSIFICATIONREQUEST, - __module__ = 'tensorflow_serving.apis.classification_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationRequest) - )) +ClassificationRequest = _reflection.GeneratedProtocolMessageType( + "ClassificationRequest", + (_message.Message,), + dict( + DESCRIPTOR=_CLASSIFICATIONREQUEST, + __module__="tensorflow_serving.apis.classification_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationRequest) + ), +) _sym_db.RegisterMessage(ClassificationRequest) -ClassificationResponse = _reflection.GeneratedProtocolMessageType('ClassificationResponse', (_message.Message,), dict( - DESCRIPTOR = _CLASSIFICATIONRESPONSE, - __module__ = 'tensorflow_serving.apis.classification_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationResponse) - )) +ClassificationResponse = _reflection.GeneratedProtocolMessageType( + "ClassificationResponse", + (_message.Message,), + dict( + DESCRIPTOR=_CLASSIFICATIONRESPONSE, + __module__="tensorflow_serving.apis.classification_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ClassificationResponse) + ), +) _sym_db.RegisterMessage(ClassificationResponse) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/get_model_metadata_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/get_model_metadata_pb2.py index 30ddf4b992..07186710fe 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/get_model_metadata_pb2.py +++ 
b/src/sagemaker/tensorflow/tensorflow_serving/apis/get_model_metadata_pb2.py @@ -2,268 +2,390 @@ # source: tensorflow_serving/apis/get_model_metadata.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from tensorflow.core.protobuf import meta_graph_pb2 as tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2 +from tensorflow.core.protobuf import ( + meta_graph_pb2 as tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2, +) from tensorflow_serving.apis import model_pb2 as tensorflow__serving_dot_apis_dot_model__pb2 DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/get_model_metadata.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n0tensorflow_serving/apis/get_model_metadata.proto\x12\x12tensorflow.serving\x1a\x19google/protobuf/any.proto\x1a)tensorflow/core/protobuf/meta_graph.proto\x1a#tensorflow_serving/apis/model.proto\"\xae\x01\n\x0fSignatureDefMap\x12L\n\rsignature_def\x18\x01 \x03(\x0b\x32\x35.tensorflow.serving.SignatureDefMap.SignatureDefEntry\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01\"d\n\x17GetModelMetadataRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x16\n\x0emetadata_field\x18\x02 \x03(\t\"\xe2\x01\n\x18GetModelMetadataResponse\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12L\n\x08metadata\x18\x02 \x03(\x0b\x32:.tensorflow.serving.GetModelMetadataResponse.MetadataEntry\x1a\x45\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/get_model_metadata.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n0tensorflow_serving/apis/get_model_metadata.proto\x12\x12tensorflow.serving\x1a\x19google/protobuf/any.proto\x1a)tensorflow/core/protobuf/meta_graph.proto\x1a#tensorflow_serving/apis/model.proto"\xae\x01\n\x0fSignatureDefMap\x12L\n\rsignature_def\x18\x01 \x03(\x0b\x32\x35.tensorflow.serving.SignatureDefMap.SignatureDefEntry\x1aM\n\x11SignatureDefEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.tensorflow.SignatureDef:\x02\x38\x01"d\n\x17GetModelMetadataRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x16\n\x0emetadata_field\x18\x02 \x03(\t"\xe2\x01\n\x18GetModelMetadataResponse\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12L\n\x08metadata\x18\x02 \x03(\x0b\x32:.tensorflow.serving.GetModelMetadataResponse.MetadataEntry\x1a\x45\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 
\x01(\x0b\x32\x14.google.protobuf.Any:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[ + google_dot_protobuf_dot_any__pb2.DESCRIPTOR, + tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR, + ], +) _SIGNATUREDEFMAP_SIGNATUREDEFENTRY = _descriptor.Descriptor( - name='SignatureDefEntry', - full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.SignatureDefMap.SignatureDefEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=277, - serialized_end=354, + name="SignatureDefEntry", + full_name="tensorflow.serving.SignatureDefMap.SignatureDefEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tensorflow.serving.SignatureDefMap.SignatureDefEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tensorflow.serving.SignatureDefMap.SignatureDefEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=277, + serialized_end=354, ) _SIGNATUREDEFMAP = _descriptor.Descriptor( - name='SignatureDefMap', - full_name='tensorflow.serving.SignatureDefMap', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='signature_def', full_name='tensorflow.serving.SignatureDefMap.signature_def', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SIGNATUREDEFMAP_SIGNATUREDEFENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=180, - serialized_end=354, + name="SignatureDefMap", + 
full_name="tensorflow.serving.SignatureDefMap", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="signature_def", + full_name="tensorflow.serving.SignatureDefMap.signature_def", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[_SIGNATUREDEFMAP_SIGNATUREDEFENTRY], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=180, + serialized_end=354, ) _GETMODELMETADATAREQUEST = _descriptor.Descriptor( - name='GetModelMetadataRequest', - full_name='tensorflow.serving.GetModelMetadataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata_field', full_name='tensorflow.serving.GetModelMetadataRequest.metadata_field', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=356, - serialized_end=456, + name="GetModelMetadataRequest", + full_name="tensorflow.serving.GetModelMetadataRequest", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.GetModelMetadataRequest.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="metadata_field", + full_name="tensorflow.serving.GetModelMetadataRequest.metadata_field", + index=1, + number=2, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=356, + serialized_end=456, ) _GETMODELMETADATARESPONSE_METADATAENTRY = _descriptor.Descriptor( - name='MetadataEntry', - full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, 
extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.GetModelMetadataResponse.MetadataEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=616, - serialized_end=685, + name="MetadataEntry", + full_name="tensorflow.serving.GetModelMetadataResponse.MetadataEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tensorflow.serving.GetModelMetadataResponse.MetadataEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tensorflow.serving.GetModelMetadataResponse.MetadataEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=616, + serialized_end=685, ) _GETMODELMETADATARESPONSE = _descriptor.Descriptor( - name='GetModelMetadataResponse', - full_name='tensorflow.serving.GetModelMetadataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.GetModelMetadataResponse.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='metadata', full_name='tensorflow.serving.GetModelMetadataResponse.metadata', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_GETMODELMETADATARESPONSE_METADATAENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=459, - serialized_end=685, + name="GetModelMetadataResponse", + full_name="tensorflow.serving.GetModelMetadataResponse", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.GetModelMetadataResponse.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + 
is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="metadata", + full_name="tensorflow.serving.GetModelMetadataResponse.metadata", + index=1, + number=2, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[_GETMODELMETADATARESPONSE_METADATAENTRY], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=459, + serialized_end=685, ) -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2._SIGNATUREDEF +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY.fields_by_name[ + "value" +].message_type = tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2._SIGNATUREDEF _SIGNATUREDEFMAP_SIGNATUREDEFENTRY.containing_type = _SIGNATUREDEFMAP -_SIGNATUREDEFMAP.fields_by_name['signature_def'].message_type = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY -_GETMODELMETADATAREQUEST.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE_METADATAENTRY.fields_by_name['value'].message_type = google_dot_protobuf_dot_any__pb2._ANY +_SIGNATUREDEFMAP.fields_by_name["signature_def"].message_type = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY +_GETMODELMETADATAREQUEST.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_GETMODELMETADATARESPONSE_METADATAENTRY.fields_by_name[ + "value" +].message_type = google_dot_protobuf_dot_any__pb2._ANY _GETMODELMETADATARESPONSE_METADATAENTRY.containing_type = _GETMODELMETADATARESPONSE -_GETMODELMETADATARESPONSE.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_GETMODELMETADATARESPONSE.fields_by_name['metadata'].message_type = _GETMODELMETADATARESPONSE_METADATAENTRY -DESCRIPTOR.message_types_by_name['SignatureDefMap'] = _SIGNATUREDEFMAP -DESCRIPTOR.message_types_by_name['GetModelMetadataRequest'] = _GETMODELMETADATAREQUEST -DESCRIPTOR.message_types_by_name['GetModelMetadataResponse'] = _GETMODELMETADATARESPONSE +_GETMODELMETADATARESPONSE.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_GETMODELMETADATARESPONSE.fields_by_name[ + "metadata" +].message_type = _GETMODELMETADATARESPONSE_METADATAENTRY +DESCRIPTOR.message_types_by_name["SignatureDefMap"] = _SIGNATUREDEFMAP +DESCRIPTOR.message_types_by_name["GetModelMetadataRequest"] = _GETMODELMETADATAREQUEST +DESCRIPTOR.message_types_by_name["GetModelMetadataResponse"] = _GETMODELMETADATARESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -SignatureDefMap = _reflection.GeneratedProtocolMessageType('SignatureDefMap', (_message.Message,), dict( - - SignatureDefEntry = _reflection.GeneratedProtocolMessageType('SignatureDefEntry', (_message.Message,), dict( - DESCRIPTOR = _SIGNATUREDEFMAP_SIGNATUREDEFENTRY, - __module__ = 'tensorflow_serving.apis.get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap.SignatureDefEntry) - )) - , - DESCRIPTOR = _SIGNATUREDEFMAP, - __module__ = 'tensorflow_serving.apis.get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap) - )) +SignatureDefMap = _reflection.GeneratedProtocolMessageType( + 
"SignatureDefMap", + (_message.Message,), + dict( + SignatureDefEntry=_reflection.GeneratedProtocolMessageType( + "SignatureDefEntry", + (_message.Message,), + dict( + DESCRIPTOR=_SIGNATUREDEFMAP_SIGNATUREDEFENTRY, + __module__="tensorflow_serving.apis.get_model_metadata_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap.SignatureDefEntry) + ), + ), + DESCRIPTOR=_SIGNATUREDEFMAP, + __module__="tensorflow_serving.apis.get_model_metadata_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.SignatureDefMap) + ), +) _sym_db.RegisterMessage(SignatureDefMap) _sym_db.RegisterMessage(SignatureDefMap.SignatureDefEntry) -GetModelMetadataRequest = _reflection.GeneratedProtocolMessageType('GetModelMetadataRequest', (_message.Message,), dict( - DESCRIPTOR = _GETMODELMETADATAREQUEST, - __module__ = 'tensorflow_serving.apis.get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataRequest) - )) +GetModelMetadataRequest = _reflection.GeneratedProtocolMessageType( + "GetModelMetadataRequest", + (_message.Message,), + dict( + DESCRIPTOR=_GETMODELMETADATAREQUEST, + __module__="tensorflow_serving.apis.get_model_metadata_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataRequest) + ), +) _sym_db.RegisterMessage(GetModelMetadataRequest) -GetModelMetadataResponse = _reflection.GeneratedProtocolMessageType('GetModelMetadataResponse', (_message.Message,), dict( - - MetadataEntry = _reflection.GeneratedProtocolMessageType('MetadataEntry', (_message.Message,), dict( - DESCRIPTOR = _GETMODELMETADATARESPONSE_METADATAENTRY, - __module__ = 'tensorflow_serving.apis.get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse.MetadataEntry) - )) - , - DESCRIPTOR = _GETMODELMETADATARESPONSE, - __module__ = 'tensorflow_serving.apis.get_model_metadata_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse) - )) +GetModelMetadataResponse = _reflection.GeneratedProtocolMessageType( + "GetModelMetadataResponse", + (_message.Message,), + dict( + MetadataEntry=_reflection.GeneratedProtocolMessageType( + "MetadataEntry", + (_message.Message,), + dict( + DESCRIPTOR=_GETMODELMETADATARESPONSE_METADATAENTRY, + __module__="tensorflow_serving.apis.get_model_metadata_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse.MetadataEntry) + ), + ), + DESCRIPTOR=_GETMODELMETADATARESPONSE, + __module__="tensorflow_serving.apis.get_model_metadata_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.GetModelMetadataResponse) + ), +) _sym_db.RegisterMessage(GetModelMetadataResponse) _sym_db.RegisterMessage(GetModelMetadataResponse.MetadataEntry) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) _SIGNATUREDEFMAP_SIGNATUREDEFENTRY.has_options = True -_SIGNATUREDEFMAP_SIGNATUREDEFENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_SIGNATUREDEFMAP_SIGNATUREDEFENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) _GETMODELMETADATARESPONSE_METADATAENTRY.has_options = True -_GETMODELMETADATARESPONSE_METADATAENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_GETMODELMETADATARESPONSE_METADATAENTRY._options = 
_descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/inference_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/inference_pb2.py index 2698fe2e17..3db862fe9b 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/inference_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/inference_pb2.py @@ -2,236 +2,360 @@ # source: tensorflow_serving/apis/inference.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2 +from tensorflow_serving.apis import ( + classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2, +) from tensorflow_serving.apis import input_pb2 as tensorflow__serving_dot_apis_dot_input__pb2 from tensorflow_serving.apis import model_pb2 as tensorflow__serving_dot_apis_dot_model__pb2 -from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2 +from tensorflow_serving.apis import ( + regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2, +) DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/inference.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n\'tensorflow_serving/apis/inference.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto\x1a(tensorflow_serving/apis/regression.proto\"W\n\rInferenceTask\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x13\n\x0bmethod_name\x18\x02 \x01(\t\"\xdc\x01\n\x0fInferenceResult\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12I\n\x15\x63lassification_result\x18\x02 \x01(\x0b\x32(.tensorflow.serving.ClassificationResultH\x00\x12\x41\n\x11regression_result\x18\x03 \x01(\x0b\x32$.tensorflow.serving.RegressionResultH\x00\x42\x08\n\x06result\"s\n\x15MultiInferenceRequest\x12\x30\n\x05tasks\x18\x01 \x03(\x0b\x32!.tensorflow.serving.InferenceTask\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input\"N\n\x16MultiInferenceResponse\x12\x34\n\x07results\x18\x01 \x03(\x0b\x32#.tensorflow.serving.InferenceResultB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/inference.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n\'tensorflow_serving/apis/inference.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto\x1a(tensorflow_serving/apis/regression.proto"W\n\rInferenceTask\x12\x31\n\nmodel_spec\x18\x01 
\x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x13\n\x0bmethod_name\x18\x02 \x01(\t"\xdc\x01\n\x0fInferenceResult\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12I\n\x15\x63lassification_result\x18\x02 \x01(\x0b\x32(.tensorflow.serving.ClassificationResultH\x00\x12\x41\n\x11regression_result\x18\x03 \x01(\x0b\x32$.tensorflow.serving.RegressionResultH\x00\x42\x08\n\x06result"s\n\x15MultiInferenceRequest\x12\x30\n\x05tasks\x18\x01 \x03(\x0b\x32!.tensorflow.serving.InferenceTask\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input"N\n\x16MultiInferenceResponse\x12\x34\n\x07results\x18\x01 \x03(\x0b\x32#.tensorflow.serving.InferenceResultB\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[ + tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR, + ], +) _INFERENCETASK = _descriptor.Descriptor( - name='InferenceTask', - full_name='tensorflow.serving.InferenceTask', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.InferenceTask.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='method_name', full_name='tensorflow.serving.InferenceTask.method_name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=225, - serialized_end=312, + name="InferenceTask", + full_name="tensorflow.serving.InferenceTask", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.InferenceTask.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="method_name", + full_name="tensorflow.serving.InferenceTask.method_name", + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=225, + serialized_end=312, ) _INFERENCERESULT = _descriptor.Descriptor( - name='InferenceResult', - full_name='tensorflow.serving.InferenceResult', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.InferenceResult.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - 
has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='classification_result', full_name='tensorflow.serving.InferenceResult.classification_result', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='regression_result', full_name='tensorflow.serving.InferenceResult.regression_result', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='result', full_name='tensorflow.serving.InferenceResult.result', - index=0, containing_type=None, fields=[]), - ], - serialized_start=315, - serialized_end=535, + name="InferenceResult", + full_name="tensorflow.serving.InferenceResult", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.InferenceResult.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="classification_result", + full_name="tensorflow.serving.InferenceResult.classification_result", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="regression_result", + full_name="tensorflow.serving.InferenceResult.regression_result", + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="result", + full_name="tensorflow.serving.InferenceResult.result", + index=0, + containing_type=None, + fields=[], + ) + ], + serialized_start=315, + serialized_end=535, ) _MULTIINFERENCEREQUEST = _descriptor.Descriptor( - name='MultiInferenceRequest', - full_name='tensorflow.serving.MultiInferenceRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tasks', full_name='tensorflow.serving.MultiInferenceRequest.tasks', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input', 
full_name='tensorflow.serving.MultiInferenceRequest.input', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=537, - serialized_end=652, + name="MultiInferenceRequest", + full_name="tensorflow.serving.MultiInferenceRequest", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="tasks", + full_name="tensorflow.serving.MultiInferenceRequest.tasks", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="input", + full_name="tensorflow.serving.MultiInferenceRequest.input", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=537, + serialized_end=652, ) _MULTIINFERENCERESPONSE = _descriptor.Descriptor( - name='MultiInferenceResponse', - full_name='tensorflow.serving.MultiInferenceResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='results', full_name='tensorflow.serving.MultiInferenceResponse.results', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=654, - serialized_end=732, + name="MultiInferenceResponse", + full_name="tensorflow.serving.MultiInferenceResponse", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="results", + full_name="tensorflow.serving.MultiInferenceResponse.results", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=654, + serialized_end=732, ) -_INFERENCETASK.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_INFERENCERESULT.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_INFERENCERESULT.fields_by_name['classification_result'].message_type = tensorflow__serving_dot_apis_dot_classification__pb2._CLASSIFICATIONRESULT -_INFERENCERESULT.fields_by_name['regression_result'].message_type = 
tensorflow__serving_dot_apis_dot_regression__pb2._REGRESSIONRESULT -_INFERENCERESULT.oneofs_by_name['result'].fields.append( - _INFERENCERESULT.fields_by_name['classification_result']) -_INFERENCERESULT.fields_by_name['classification_result'].containing_oneof = _INFERENCERESULT.oneofs_by_name['result'] -_INFERENCERESULT.oneofs_by_name['result'].fields.append( - _INFERENCERESULT.fields_by_name['regression_result']) -_INFERENCERESULT.fields_by_name['regression_result'].containing_oneof = _INFERENCERESULT.oneofs_by_name['result'] -_MULTIINFERENCEREQUEST.fields_by_name['tasks'].message_type = _INFERENCETASK -_MULTIINFERENCEREQUEST.fields_by_name['input'].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT -_MULTIINFERENCERESPONSE.fields_by_name['results'].message_type = _INFERENCERESULT -DESCRIPTOR.message_types_by_name['InferenceTask'] = _INFERENCETASK -DESCRIPTOR.message_types_by_name['InferenceResult'] = _INFERENCERESULT -DESCRIPTOR.message_types_by_name['MultiInferenceRequest'] = _MULTIINFERENCEREQUEST -DESCRIPTOR.message_types_by_name['MultiInferenceResponse'] = _MULTIINFERENCERESPONSE +_INFERENCETASK.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_INFERENCERESULT.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_INFERENCERESULT.fields_by_name[ + "classification_result" +].message_type = tensorflow__serving_dot_apis_dot_classification__pb2._CLASSIFICATIONRESULT +_INFERENCERESULT.fields_by_name[ + "regression_result" +].message_type = tensorflow__serving_dot_apis_dot_regression__pb2._REGRESSIONRESULT +_INFERENCERESULT.oneofs_by_name["result"].fields.append( + _INFERENCERESULT.fields_by_name["classification_result"] +) +_INFERENCERESULT.fields_by_name[ + "classification_result" +].containing_oneof = _INFERENCERESULT.oneofs_by_name["result"] +_INFERENCERESULT.oneofs_by_name["result"].fields.append( + _INFERENCERESULT.fields_by_name["regression_result"] +) +_INFERENCERESULT.fields_by_name[ + "regression_result" +].containing_oneof = _INFERENCERESULT.oneofs_by_name["result"] +_MULTIINFERENCEREQUEST.fields_by_name["tasks"].message_type = _INFERENCETASK +_MULTIINFERENCEREQUEST.fields_by_name[ + "input" +].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT +_MULTIINFERENCERESPONSE.fields_by_name["results"].message_type = _INFERENCERESULT +DESCRIPTOR.message_types_by_name["InferenceTask"] = _INFERENCETASK +DESCRIPTOR.message_types_by_name["InferenceResult"] = _INFERENCERESULT +DESCRIPTOR.message_types_by_name["MultiInferenceRequest"] = _MULTIINFERENCEREQUEST +DESCRIPTOR.message_types_by_name["MultiInferenceResponse"] = _MULTIINFERENCERESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -InferenceTask = _reflection.GeneratedProtocolMessageType('InferenceTask', (_message.Message,), dict( - DESCRIPTOR = _INFERENCETASK, - __module__ = 'tensorflow_serving.apis.inference_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.InferenceTask) - )) +InferenceTask = _reflection.GeneratedProtocolMessageType( + "InferenceTask", + (_message.Message,), + dict( + DESCRIPTOR=_INFERENCETASK, + __module__="tensorflow_serving.apis.inference_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.InferenceTask) + ), +) _sym_db.RegisterMessage(InferenceTask) -InferenceResult = _reflection.GeneratedProtocolMessageType('InferenceResult', (_message.Message,), dict( - DESCRIPTOR = _INFERENCERESULT, - __module__ = 'tensorflow_serving.apis.inference_pb2' - # 
@@protoc_insertion_point(class_scope:tensorflow.serving.InferenceResult) - )) +InferenceResult = _reflection.GeneratedProtocolMessageType( + "InferenceResult", + (_message.Message,), + dict( + DESCRIPTOR=_INFERENCERESULT, + __module__="tensorflow_serving.apis.inference_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.InferenceResult) + ), +) _sym_db.RegisterMessage(InferenceResult) -MultiInferenceRequest = _reflection.GeneratedProtocolMessageType('MultiInferenceRequest', (_message.Message,), dict( - DESCRIPTOR = _MULTIINFERENCEREQUEST, - __module__ = 'tensorflow_serving.apis.inference_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.MultiInferenceRequest) - )) +MultiInferenceRequest = _reflection.GeneratedProtocolMessageType( + "MultiInferenceRequest", + (_message.Message,), + dict( + DESCRIPTOR=_MULTIINFERENCEREQUEST, + __module__="tensorflow_serving.apis.inference_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.MultiInferenceRequest) + ), +) _sym_db.RegisterMessage(MultiInferenceRequest) -MultiInferenceResponse = _reflection.GeneratedProtocolMessageType('MultiInferenceResponse', (_message.Message,), dict( - DESCRIPTOR = _MULTIINFERENCERESPONSE, - __module__ = 'tensorflow_serving.apis.inference_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.MultiInferenceResponse) - )) +MultiInferenceResponse = _reflection.GeneratedProtocolMessageType( + "MultiInferenceResponse", + (_message.Message,), + dict( + DESCRIPTOR=_MULTIINFERENCERESPONSE, + __module__="tensorflow_serving.apis.inference_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.MultiInferenceResponse) + ), +) _sym_db.RegisterMessage(MultiInferenceResponse) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/input_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/input_pb2.py index cf754e3157..f8df1ab2a5 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/input_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/input_pb2.py @@ -2,12 +2,14 @@ # source: tensorflow_serving/apis/input.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,167 +19,238 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/input.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n#tensorflow_serving/apis/input.proto\x12\x12tensorflow.serving\x1a%tensorflow/core/example/example.proto\"4\n\x0b\x45xampleList\x12%\n\x08\x65xamples\x18\x01 \x03(\x0b\x32\x13.tensorflow.Example\"e\n\x16\x45xampleListWithContext\x12%\n\x08\x65xamples\x18\x01 \x03(\x0b\x32\x13.tensorflow.Example\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.tensorflow.Example\"\xa1\x01\n\x05Input\x12;\n\x0c\x65xample_list\x18\x01 
\x01(\x0b\x32\x1f.tensorflow.serving.ExampleListB\x02(\x01H\x00\x12S\n\x19\x65xample_list_with_context\x18\x02 \x01(\x0b\x32*.tensorflow.serving.ExampleListWithContextB\x02(\x01H\x00\x42\x06\n\x04kindB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow_dot_core_dot_example_dot_example__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/input.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n#tensorflow_serving/apis/input.proto\x12\x12tensorflow.serving\x1a%tensorflow/core/example/example.proto"4\n\x0b\x45xampleList\x12%\n\x08\x65xamples\x18\x01 \x03(\x0b\x32\x13.tensorflow.Example"e\n\x16\x45xampleListWithContext\x12%\n\x08\x65xamples\x18\x01 \x03(\x0b\x32\x13.tensorflow.Example\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.tensorflow.Example"\xa1\x01\n\x05Input\x12;\n\x0c\x65xample_list\x18\x01 \x01(\x0b\x32\x1f.tensorflow.serving.ExampleListB\x02(\x01H\x00\x12S\n\x19\x65xample_list_with_context\x18\x02 \x01(\x0b\x32*.tensorflow.serving.ExampleListWithContextB\x02(\x01H\x00\x42\x06\n\x04kindB\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[tensorflow_dot_core_dot_example_dot_example__pb2.DESCRIPTOR], +) _EXAMPLELIST = _descriptor.Descriptor( - name='ExampleList', - full_name='tensorflow.serving.ExampleList', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='examples', full_name='tensorflow.serving.ExampleList.examples', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=98, - serialized_end=150, + name="ExampleList", + full_name="tensorflow.serving.ExampleList", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="examples", + full_name="tensorflow.serving.ExampleList.examples", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=98, + serialized_end=150, ) _EXAMPLELISTWITHCONTEXT = _descriptor.Descriptor( - name='ExampleListWithContext', - full_name='tensorflow.serving.ExampleListWithContext', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='examples', full_name='tensorflow.serving.ExampleListWithContext.examples', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='context', full_name='tensorflow.serving.ExampleListWithContext.context', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - 
enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=152, - serialized_end=253, + name="ExampleListWithContext", + full_name="tensorflow.serving.ExampleListWithContext", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="examples", + full_name="tensorflow.serving.ExampleListWithContext.examples", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="context", + full_name="tensorflow.serving.ExampleListWithContext.context", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=152, + serialized_end=253, ) _INPUT = _descriptor.Descriptor( - name='Input', - full_name='tensorflow.serving.Input', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='example_list', full_name='tensorflow.serving.Input.example_list', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('(\001')), file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='example_list_with_context', full_name='tensorflow.serving.Input.example_list_with_context', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('(\001')), file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.serving.Input.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=256, - serialized_end=417, + name="Input", + full_name="tensorflow.serving.Input", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="example_list", + full_name="tensorflow.serving.Input.example_list", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("(\001")), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="example_list_with_context", + full_name="tensorflow.serving.Input.example_list_with_context", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + 
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("(\001")), + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name="kind", + full_name="tensorflow.serving.Input.kind", + index=0, + containing_type=None, + fields=[], + ) + ], + serialized_start=256, + serialized_end=417, ) -_EXAMPLELIST.fields_by_name['examples'].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE -_EXAMPLELISTWITHCONTEXT.fields_by_name['examples'].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE -_EXAMPLELISTWITHCONTEXT.fields_by_name['context'].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE -_INPUT.fields_by_name['example_list'].message_type = _EXAMPLELIST -_INPUT.fields_by_name['example_list_with_context'].message_type = _EXAMPLELISTWITHCONTEXT -_INPUT.oneofs_by_name['kind'].fields.append( - _INPUT.fields_by_name['example_list']) -_INPUT.fields_by_name['example_list'].containing_oneof = _INPUT.oneofs_by_name['kind'] -_INPUT.oneofs_by_name['kind'].fields.append( - _INPUT.fields_by_name['example_list_with_context']) -_INPUT.fields_by_name['example_list_with_context'].containing_oneof = _INPUT.oneofs_by_name['kind'] -DESCRIPTOR.message_types_by_name['ExampleList'] = _EXAMPLELIST -DESCRIPTOR.message_types_by_name['ExampleListWithContext'] = _EXAMPLELISTWITHCONTEXT -DESCRIPTOR.message_types_by_name['Input'] = _INPUT +_EXAMPLELIST.fields_by_name[ + "examples" +].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE +_EXAMPLELISTWITHCONTEXT.fields_by_name[ + "examples" +].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE +_EXAMPLELISTWITHCONTEXT.fields_by_name[ + "context" +].message_type = tensorflow_dot_core_dot_example_dot_example__pb2._EXAMPLE +_INPUT.fields_by_name["example_list"].message_type = _EXAMPLELIST +_INPUT.fields_by_name["example_list_with_context"].message_type = _EXAMPLELISTWITHCONTEXT +_INPUT.oneofs_by_name["kind"].fields.append(_INPUT.fields_by_name["example_list"]) +_INPUT.fields_by_name["example_list"].containing_oneof = _INPUT.oneofs_by_name["kind"] +_INPUT.oneofs_by_name["kind"].fields.append(_INPUT.fields_by_name["example_list_with_context"]) +_INPUT.fields_by_name["example_list_with_context"].containing_oneof = _INPUT.oneofs_by_name["kind"] +DESCRIPTOR.message_types_by_name["ExampleList"] = _EXAMPLELIST +DESCRIPTOR.message_types_by_name["ExampleListWithContext"] = _EXAMPLELISTWITHCONTEXT +DESCRIPTOR.message_types_by_name["Input"] = _INPUT _sym_db.RegisterFileDescriptor(DESCRIPTOR) -ExampleList = _reflection.GeneratedProtocolMessageType('ExampleList', (_message.Message,), dict( - DESCRIPTOR = _EXAMPLELIST, - __module__ = 'tensorflow_serving.apis.input_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ExampleList) - )) +ExampleList = _reflection.GeneratedProtocolMessageType( + "ExampleList", + (_message.Message,), + dict( + DESCRIPTOR=_EXAMPLELIST, + __module__="tensorflow_serving.apis.input_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ExampleList) + ), +) _sym_db.RegisterMessage(ExampleList) -ExampleListWithContext = _reflection.GeneratedProtocolMessageType('ExampleListWithContext', (_message.Message,), dict( - DESCRIPTOR = _EXAMPLELISTWITHCONTEXT, - __module__ = 'tensorflow_serving.apis.input_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ExampleListWithContext) - )) 
+ExampleListWithContext = _reflection.GeneratedProtocolMessageType( + "ExampleListWithContext", + (_message.Message,), + dict( + DESCRIPTOR=_EXAMPLELISTWITHCONTEXT, + __module__="tensorflow_serving.apis.input_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ExampleListWithContext) + ), +) _sym_db.RegisterMessage(ExampleListWithContext) -Input = _reflection.GeneratedProtocolMessageType('Input', (_message.Message,), dict( - DESCRIPTOR = _INPUT, - __module__ = 'tensorflow_serving.apis.input_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.Input) - )) +Input = _reflection.GeneratedProtocolMessageType( + "Input", + (_message.Message,), + dict( + DESCRIPTOR=_INPUT, + __module__="tensorflow_serving.apis.input_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.Input) + ), +) _sym_db.RegisterMessage(Input) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) -_INPUT.fields_by_name['example_list'].has_options = True -_INPUT.fields_by_name['example_list']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('(\001')) -_INPUT.fields_by_name['example_list_with_context'].has_options = True -_INPUT.fields_by_name['example_list_with_context']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('(\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) +_INPUT.fields_by_name["example_list"].has_options = True +_INPUT.fields_by_name["example_list"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("(\001") +) +_INPUT.fields_by_name["example_list_with_context"].has_options = True +_INPUT.fields_by_name["example_list_with_context"]._options = _descriptor._ParseOptions( + descriptor_pb2.FieldOptions(), _b("(\001") +) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_pb2.py index 71aeff6620..5e07a816b7 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_pb2.py @@ -2,12 +2,14 @@ # source: tensorflow_serving/apis/model.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,72 +19,108 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/model.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n#tensorflow_serving/apis/model.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto\"_\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12,\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64Value\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/model.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + 
'\n#tensorflow_serving/apis/model.proto\x12\x12tensorflow.serving\x1a\x1egoogle/protobuf/wrappers.proto"_\n\tModelSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12,\n\x07version\x18\x02 \x01(\x0b\x32\x1b.google.protobuf.Int64Value\x12\x16\n\x0esignature_name\x18\x03 \x01(\tB\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[google_dot_protobuf_dot_wrappers__pb2.DESCRIPTOR], +) _MODELSPEC = _descriptor.Descriptor( - name='ModelSpec', - full_name='tensorflow.serving.ModelSpec', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.serving.ModelSpec.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='version', full_name='tensorflow.serving.ModelSpec.version', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='signature_name', full_name='tensorflow.serving.ModelSpec.signature_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=91, - serialized_end=186, + name="ModelSpec", + full_name="tensorflow.serving.ModelSpec", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="name", + full_name="tensorflow.serving.ModelSpec.name", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="version", + full_name="tensorflow.serving.ModelSpec.version", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="signature_name", + full_name="tensorflow.serving.ModelSpec.signature_name", + index=2, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=91, + serialized_end=186, ) -_MODELSPEC.fields_by_name['version'].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE -DESCRIPTOR.message_types_by_name['ModelSpec'] = _MODELSPEC +_MODELSPEC.fields_by_name[ + "version" +].message_type = google_dot_protobuf_dot_wrappers__pb2._INT64VALUE 
+DESCRIPTOR.message_types_by_name["ModelSpec"] = _MODELSPEC _sym_db.RegisterFileDescriptor(DESCRIPTOR) -ModelSpec = _reflection.GeneratedProtocolMessageType('ModelSpec', (_message.Message,), dict( - DESCRIPTOR = _MODELSPEC, - __module__ = 'tensorflow_serving.apis.model_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) - )) +ModelSpec = _reflection.GeneratedProtocolMessageType( + "ModelSpec", + (_message.Message,), + dict( + DESCRIPTOR=_MODELSPEC, + __module__="tensorflow_serving.apis.model_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.ModelSpec) + ), +) _sym_db.RegisterMessage(ModelSpec) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2.py index 5c8787abd7..512c9952ac 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2.py @@ -18,165 +18,204 @@ # python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. tensorflow_serving/apis/model_service.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -from tensorflow_serving.apis import get_model_status_pb2 as tensorflow__serving_dot_apis_dot_get__model__status__pb2 +from tensorflow_serving.apis import ( + get_model_status_pb2 as tensorflow__serving_dot_apis_dot_get__model__status__pb2, +) DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/model_service.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n+tensorflow_serving/apis/model_service.proto\x12\x12tensorflow.serving\x1a.tensorflow_serving/apis/get_model_status.proto2w\n\x0cModelService\x12g\n\x0eGetModelStatus\x12).tensorflow.serving.GetModelStatusRequest\x1a*.tensorflow.serving.GetModelStatusResponseB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow__serving_dot_apis_dot_get__model__status__pb2.DESCRIPTOR,]) + name="tensorflow_serving/apis/model_service.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + "\n+tensorflow_serving/apis/model_service.proto\x12\x12tensorflow.serving\x1a.tensorflow_serving/apis/get_model_status.proto2w\n\x0cModelService\x12g\n\x0eGetModelStatus\x12).tensorflow.serving.GetModelStatusRequest\x1a*.tensorflow.serving.GetModelStatusResponseB\x03\xf8\x01\x01\x62\x06proto3" + ), + dependencies=[tensorflow__serving_dot_apis_dot_get__model__status__pb2.DESCRIPTOR], +) _sym_db.RegisterFileDescriptor(DESCRIPTOR) - - - DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) try: - # THESE ELEMENTS WILL BE DEPRECATED. 
- # Please use the generated *_pb2_grpc.py files instead. - import grpc - from grpc.framework.common import cardinality - from grpc.framework.interfaces.face import utilities as face_utilities - from grpc.beta import implementations as beta_implementations - from grpc.beta import interfaces as beta_interfaces - - - class ModelServiceStub(object): - """ModelService provides access to information about model versions + # THESE ELEMENTS WILL BE DEPRECATED. + # Please use the generated *_pb2_grpc.py files instead. + import grpc + from grpc.framework.common import cardinality + from grpc.framework.interfaces.face import utilities as face_utilities + from grpc.beta import implementations as beta_implementations + from grpc.beta import interfaces as beta_interfaces + + class ModelServiceStub(object): + """ModelService provides access to information about model versions that have been handled by the model server. """ - def __init__(self, channel): - """Constructor. + def __init__(self, channel): + """Constructor. Args: channel: A grpc.Channel. """ - self.GetModelStatus = channel.unary_unary( - '/tensorflow.serving.ModelService/GetModelStatus', - request_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString, - ) - - - class ModelServiceServicer(object): - """ModelService provides access to information about model versions + self.GetModelStatus = channel.unary_unary( + "/tensorflow.serving.ModelService/GetModelStatus", + request_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString, + ) + + class ModelServiceServicer(object): + """ModelService provides access to information about model versions that have been handled by the model server. """ - def GetModelStatus(self, request, context): - """Gets status of model. If the ModelSpec in the request does not specify + def GetModelStatus(self, request, context): + """Gets status of model. If the ModelSpec in the request does not specify version, information about all versions of the model will be returned. If the ModelSpec in the request does specify a version, the status of only that version will be returned. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - - def add_ModelServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'GetModelStatus': grpc.unary_unary_rpc_method_handler( - servicer.GetModelStatus, - request_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.ModelService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - class BetaModelServiceServicer(object): - """The Beta API is deprecated for 0.15.0 and later. 
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def add_ModelServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "GetModelStatus": grpc.unary_unary_rpc_method_handler( + servicer.GetModelStatus, + request_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString, + ) + } + generic_handler = grpc.method_handlers_generic_handler( + "tensorflow.serving.ModelService", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + class BetaModelServiceServicer(object): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. This class was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" - """ModelService provides access to information about model versions + + """ModelService provides access to information about model versions that have been handled by the model server. """ - def GetModelStatus(self, request, context): - """Gets status of model. If the ModelSpec in the request does not specify + + def GetModelStatus(self, request, context): + """Gets status of model. If the ModelSpec in the request does not specify version, information about all versions of the model will be returned. If the ModelSpec in the request does specify a version, the status of only that version will be returned. """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - - class BetaModelServiceStub(object): - """The Beta API is deprecated for 0.15.0 and later. + class BetaModelServiceStub(object): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. This class was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" - """ModelService provides access to information about model versions + + """ModelService provides access to information about model versions that have been handled by the model server. """ - def GetModelStatus(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """Gets status of model. If the ModelSpec in the request does not specify + + def GetModelStatus( + self, request, timeout, metadata=None, with_call=False, protocol_options=None + ): + """Gets status of model. If the ModelSpec in the request does not specify version, information about all versions of the model will be returned. If the ModelSpec in the request does specify a version, the status of only that version will be returned. """ - raise NotImplementedError() - GetModelStatus.future = None + raise NotImplementedError() + GetModelStatus.future = None - def beta_create_ModelService_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None): - """The Beta API is deprecated for 0.15.0 and later. + def beta_create_ModelService_server( + servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None + ): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. 
This function was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" - request_deserializers = { - ('tensorflow.serving.ModelService', 'GetModelStatus'): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString, - } - response_serializers = { - ('tensorflow.serving.ModelService', 'GetModelStatus'): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString, - } - method_implementations = { - ('tensorflow.serving.ModelService', 'GetModelStatus'): face_utilities.unary_unary_inline(servicer.GetModelStatus), - } - server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout) - return beta_implementations.server(method_implementations, options=server_options) - - - def beta_create_ModelService_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None): - """The Beta API is deprecated for 0.15.0 and later. + request_deserializers = { + ( + "tensorflow.serving.ModelService", + "GetModelStatus", + ): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString + } + response_serializers = { + ( + "tensorflow.serving.ModelService", + "GetModelStatus", + ): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString + } + method_implementations = { + ( + "tensorflow.serving.ModelService", + "GetModelStatus", + ): face_utilities.unary_unary_inline(servicer.GetModelStatus) + } + server_options = beta_implementations.server_options( + request_deserializers=request_deserializers, + response_serializers=response_serializers, + thread_pool=pool, + thread_pool_size=pool_size, + default_timeout=default_timeout, + maximum_timeout=maximum_timeout, + ) + return beta_implementations.server(method_implementations, options=server_options) + + def beta_create_ModelService_stub( + channel, host=None, metadata_transformer=None, pool=None, pool_size=None + ): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. 
This function was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" - request_serializers = { - ('tensorflow.serving.ModelService', 'GetModelStatus'): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString, - } - response_deserializers = { - ('tensorflow.serving.ModelService', 'GetModelStatus'): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString, - } - cardinalities = { - 'GetModelStatus': cardinality.Cardinality.UNARY_UNARY, - } - stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size) - return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.ModelService', cardinalities, options=stub_options) + request_serializers = { + ( + "tensorflow.serving.ModelService", + "GetModelStatus", + ): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString + } + response_deserializers = { + ( + "tensorflow.serving.ModelService", + "GetModelStatus", + ): tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString + } + cardinalities = {"GetModelStatus": cardinality.Cardinality.UNARY_UNARY} + stub_options = beta_implementations.stub_options( + host=host, + metadata_transformer=metadata_transformer, + request_serializers=request_serializers, + response_deserializers=response_deserializers, + thread_pool=pool, + thread_pool_size=pool_size, + ) + return beta_implementations.dynamic_stub( + channel, "tensorflow.serving.ModelService", cardinalities, options=stub_options + ) + + except ImportError: - pass + pass # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2_grpc.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2_grpc.py index a44c7c7f53..17a83e6be1 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2_grpc.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/model_service_pb2_grpc.py @@ -25,47 +25,48 @@ class ModelServiceStub(object): - """ModelService provides access to information about model versions + """ModelService provides access to information about model versions that have been handled by the model server. """ - def __init__(self, channel): - """Constructor. + def __init__(self, channel): + """Constructor. Args: channel: A grpc.Channel. """ - self.GetModelStatus = channel.unary_unary( - '/tensorflow.serving.ModelService/GetModelStatus', - request_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString, + self.GetModelStatus = channel.unary_unary( + "/tensorflow.serving.ModelService/GetModelStatus", + request_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.FromString, ) class ModelServiceServicer(object): - """ModelService provides access to information about model versions + """ModelService provides access to information about model versions that have been handled by the model server. """ - def GetModelStatus(self, request, context): - """Gets status of model. 
If the ModelSpec in the request does not specify + def GetModelStatus(self, request, context): + """Gets status of model. If the ModelSpec in the request does not specify version, information about all versions of the model will be returned. If the ModelSpec in the request does specify a version, the status of only that version will be returned. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_ModelServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'GetModelStatus': grpc.unary_unary_rpc_method_handler( - servicer.GetModelStatus, - request_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.ModelService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) + rpc_method_handlers = { + "GetModelStatus": grpc.unary_unary_rpc_method_handler( + servicer.GetModelStatus, + request_deserializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_get__model__status__pb2.GetModelStatusResponse.SerializeToString, + ) + } + generic_handler = grpc.method_handlers_generic_handler( + "tensorflow.serving.ModelService", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/predict_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/predict_pb2.py index 4db8f34668..5b0d8d9137 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/predict_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/predict_pb2.py @@ -2,235 +2,353 @@ # source: tensorflow_serving/apis/predict.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2 +from tensorflow.core.framework import ( + tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2, +) from tensorflow_serving.apis import model_pb2 as tensorflow__serving_dot_apis_dot_model__pb2 DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/predict.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n%tensorflow_serving/apis/predict.proto\x12\x12tensorflow.serving\x1a&tensorflow/core/framework/tensor.proto\x1a#tensorflow_serving/apis/model.proto\"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 
\x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"\xd0\x01\n\x0fPredictResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow_dot_core_dot_framework_dot_tensor__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/predict.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n%tensorflow_serving/apis/predict.proto\x12\x12tensorflow.serving\x1a&tensorflow/core/framework/tensor.proto\x1a#tensorflow_serving/apis/model.proto"\xe2\x01\n\x0ePredictRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12>\n\x06inputs\x18\x02 \x03(\x0b\x32..tensorflow.serving.PredictRequest.InputsEntry\x12\x15\n\routput_filter\x18\x03 \x03(\t\x1a\x46\n\x0bInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01"\xd0\x01\n\x0fPredictResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x41\n\x07outputs\x18\x01 \x03(\x0b\x32\x30.tensorflow.serving.PredictResponse.OutputsEntry\x1aG\n\x0cOutputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[ + tensorflow_dot_core_dot_framework_dot_tensor__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR, + ], +) _PREDICTREQUEST_INPUTSENTRY = _descriptor.Descriptor( - name='InputsEntry', - full_name='tensorflow.serving.PredictRequest.InputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictRequest.InputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictRequest.InputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=295, - serialized_end=365, + name="InputsEntry", + full_name="tensorflow.serving.PredictRequest.InputsEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tensorflow.serving.PredictRequest.InputsEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + 
_descriptor.FieldDescriptor( + name="value", + full_name="tensorflow.serving.PredictRequest.InputsEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=295, + serialized_end=365, ) _PREDICTREQUEST = _descriptor.Descriptor( - name='PredictRequest', - full_name='tensorflow.serving.PredictRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='inputs', full_name='tensorflow.serving.PredictRequest.inputs', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='output_filter', full_name='tensorflow.serving.PredictRequest.output_filter', index=2, - number=3, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTREQUEST_INPUTSENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=139, - serialized_end=365, + name="PredictRequest", + full_name="tensorflow.serving.PredictRequest", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.PredictRequest.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="inputs", + full_name="tensorflow.serving.PredictRequest.inputs", + index=1, + number=2, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="output_filter", + full_name="tensorflow.serving.PredictRequest.output_filter", + index=2, + number=3, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[_PREDICTREQUEST_INPUTSENTRY], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=139, + 
serialized_end=365, ) _PREDICTRESPONSE_OUTPUTSENTRY = _descriptor.Descriptor( - name='OutputsEntry', - full_name='tensorflow.serving.PredictResponse.OutputsEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.serving.PredictResponse.OutputsEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.PredictResponse.OutputsEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=505, - serialized_end=576, + name="OutputsEntry", + full_name="tensorflow.serving.PredictResponse.OutputsEntry", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="key", + full_name="tensorflow.serving.PredictResponse.OutputsEntry.key", + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode("utf-8"), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="value", + full_name="tensorflow.serving.PredictResponse.OutputsEntry.value", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=505, + serialized_end=576, ) _PREDICTRESPONSE = _descriptor.Descriptor( - name='PredictResponse', - full_name='tensorflow.serving.PredictResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.PredictResponse.model_spec', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='outputs', full_name='tensorflow.serving.PredictResponse.outputs', index=1, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_PREDICTRESPONSE_OUTPUTSENTRY, ], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=368, - serialized_end=576, + name="PredictResponse", + 
full_name="tensorflow.serving.PredictResponse", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.PredictResponse.model_spec", + index=0, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="outputs", + full_name="tensorflow.serving.PredictResponse.outputs", + index=1, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[_PREDICTRESPONSE_OUTPUTSENTRY], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=368, + serialized_end=576, ) -_PREDICTREQUEST_INPUTSENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO +_PREDICTREQUEST_INPUTSENTRY.fields_by_name[ + "value" +].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO _PREDICTREQUEST_INPUTSENTRY.containing_type = _PREDICTREQUEST -_PREDICTREQUEST.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_PREDICTREQUEST.fields_by_name['inputs'].message_type = _PREDICTREQUEST_INPUTSENTRY -_PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name['value'].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO +_PREDICTREQUEST.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_PREDICTREQUEST.fields_by_name["inputs"].message_type = _PREDICTREQUEST_INPUTSENTRY +_PREDICTRESPONSE_OUTPUTSENTRY.fields_by_name[ + "value" +].message_type = tensorflow_dot_core_dot_framework_dot_tensor__pb2._TENSORPROTO _PREDICTRESPONSE_OUTPUTSENTRY.containing_type = _PREDICTRESPONSE -_PREDICTRESPONSE.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_PREDICTRESPONSE.fields_by_name['outputs'].message_type = _PREDICTRESPONSE_OUTPUTSENTRY -DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST -DESCRIPTOR.message_types_by_name['PredictResponse'] = _PREDICTRESPONSE +_PREDICTRESPONSE.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_PREDICTRESPONSE.fields_by_name["outputs"].message_type = _PREDICTRESPONSE_OUTPUTSENTRY +DESCRIPTOR.message_types_by_name["PredictRequest"] = _PREDICTREQUEST +DESCRIPTOR.message_types_by_name["PredictResponse"] = _PREDICTRESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), dict( - - InputsEntry = _reflection.GeneratedProtocolMessageType('InputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PREDICTREQUEST_INPUTSENTRY, - __module__ = 'tensorflow_serving.apis.predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) - )) - , - DESCRIPTOR = _PREDICTREQUEST, - __module__ = 'tensorflow_serving.apis.predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) - )) +PredictRequest = _reflection.GeneratedProtocolMessageType( + "PredictRequest", + 
(_message.Message,), + dict( + InputsEntry=_reflection.GeneratedProtocolMessageType( + "InputsEntry", + (_message.Message,), + dict( + DESCRIPTOR=_PREDICTREQUEST_INPUTSENTRY, + __module__="tensorflow_serving.apis.predict_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest.InputsEntry) + ), + ), + DESCRIPTOR=_PREDICTREQUEST, + __module__="tensorflow_serving.apis.predict_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictRequest) + ), +) _sym_db.RegisterMessage(PredictRequest) _sym_db.RegisterMessage(PredictRequest.InputsEntry) -PredictResponse = _reflection.GeneratedProtocolMessageType('PredictResponse', (_message.Message,), dict( - - OutputsEntry = _reflection.GeneratedProtocolMessageType('OutputsEntry', (_message.Message,), dict( - DESCRIPTOR = _PREDICTRESPONSE_OUTPUTSENTRY, - __module__ = 'tensorflow_serving.apis.predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) - )) - , - DESCRIPTOR = _PREDICTRESPONSE, - __module__ = 'tensorflow_serving.apis.predict_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) - )) +PredictResponse = _reflection.GeneratedProtocolMessageType( + "PredictResponse", + (_message.Message,), + dict( + OutputsEntry=_reflection.GeneratedProtocolMessageType( + "OutputsEntry", + (_message.Message,), + dict( + DESCRIPTOR=_PREDICTRESPONSE_OUTPUTSENTRY, + __module__="tensorflow_serving.apis.predict_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse.OutputsEntry) + ), + ), + DESCRIPTOR=_PREDICTRESPONSE, + __module__="tensorflow_serving.apis.predict_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.PredictResponse) + ), +) _sym_db.RegisterMessage(PredictResponse) _sym_db.RegisterMessage(PredictResponse.OutputsEntry) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) _PREDICTREQUEST_INPUTSENTRY.has_options = True -_PREDICTREQUEST_INPUTSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_PREDICTREQUEST_INPUTSENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) _PREDICTRESPONSE_OUTPUTSENTRY.has_options = True -_PREDICTRESPONSE_OUTPUTSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) +_PREDICTRESPONSE_OUTPUTSENTRY._options = _descriptor._ParseOptions( + descriptor_pb2.MessageOptions(), _b("8\001") +) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/prediction_service_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/prediction_service_pb2.py index a522541e9d..bc2924f17b 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/prediction_service_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/prediction_service_pb2.py @@ -18,292 +18,414 @@ # python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. 
tensorflow_serving/apis/prediction_service.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2 -from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2 +from tensorflow_serving.apis import ( + classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2, +) +from tensorflow_serving.apis import ( + get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2, +) from tensorflow_serving.apis import inference_pb2 as tensorflow__serving_dot_apis_dot_inference__pb2 from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2 -from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2 +from tensorflow_serving.apis import ( + regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2, +) DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/prediction_service.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a\'tensorflow_serving/apis/inference.proto\x1a%tensorflow_serving/apis/predict.proto\x1a(tensorflow_serving/apis/regression.proto2\xfc\x03\n\x11PredictionService\x12\x61\n\x08\x43lassify\x12).tensorflow.serving.ClassificationRequest\x1a*.tensorflow.serving.ClassificationResponse\x12X\n\x07Regress\x12%.tensorflow.serving.RegressionRequest\x1a&.tensorflow.serving.RegressionResponse\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12g\n\x0eMultiInference\x12).tensorflow.serving.MultiInferenceRequest\x1a*.tensorflow.serving.MultiInferenceResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_inference__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR,]) + name="tensorflow_serving/apis/prediction_service.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + 
"\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a'tensorflow_serving/apis/inference.proto\x1a%tensorflow_serving/apis/predict.proto\x1a(tensorflow_serving/apis/regression.proto2\xfc\x03\n\x11PredictionService\x12\x61\n\x08\x43lassify\x12).tensorflow.serving.ClassificationRequest\x1a*.tensorflow.serving.ClassificationResponse\x12X\n\x07Regress\x12%.tensorflow.serving.RegressionRequest\x1a&.tensorflow.serving.RegressionResponse\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12g\n\x0eMultiInference\x12).tensorflow.serving.MultiInferenceRequest\x1a*.tensorflow.serving.MultiInferenceResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3" + ), + dependencies=[ + tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_inference__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR, + ], +) _sym_db.RegisterFileDescriptor(DESCRIPTOR) - - - DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) try: - # THESE ELEMENTS WILL BE DEPRECATED. - # Please use the generated *_pb2_grpc.py files instead. - import grpc - from grpc.framework.common import cardinality - from grpc.framework.interfaces.face import utilities as face_utilities - from grpc.beta import implementations as beta_implementations - from grpc.beta import interfaces as beta_interfaces - - - class PredictionServiceStub(object): - """open source marker; do not remove + # THESE ELEMENTS WILL BE DEPRECATED. + # Please use the generated *_pb2_grpc.py files instead. + import grpc + from grpc.framework.common import cardinality + from grpc.framework.interfaces.face import utilities as face_utilities + from grpc.beta import implementations as beta_implementations + from grpc.beta import interfaces as beta_interfaces + + class PredictionServiceStub(object): + """open source marker; do not remove PredictionService provides access to machine-learned models loaded by model_servers. """ - def __init__(self, channel): - """Constructor. + def __init__(self, channel): + """Constructor. Args: channel: A grpc.Channel. 
""" - self.Classify = channel.unary_unary( - '/tensorflow.serving.PredictionService/Classify', - request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString, - ) - self.Regress = channel.unary_unary( - '/tensorflow.serving.PredictionService/Regress', - request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString, - ) - self.Predict = channel.unary_unary( - '/tensorflow.serving.PredictionService/Predict', - request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString, - ) - self.MultiInference = channel.unary_unary( - '/tensorflow.serving.PredictionService/MultiInference', - request_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString, - ) - self.GetModelMetadata = channel.unary_unary( - '/tensorflow.serving.PredictionService/GetModelMetadata', - request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, - response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString, - ) - - - class PredictionServiceServicer(object): - """open source marker; do not remove + self.Classify = channel.unary_unary( + "/tensorflow.serving.PredictionService/Classify", + request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString, + ) + self.Regress = channel.unary_unary( + "/tensorflow.serving.PredictionService/Regress", + request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString, + ) + self.Predict = channel.unary_unary( + "/tensorflow.serving.PredictionService/Predict", + request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString, + ) + self.MultiInference = channel.unary_unary( + "/tensorflow.serving.PredictionService/MultiInference", + request_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString, + ) + self.GetModelMetadata = channel.unary_unary( + "/tensorflow.serving.PredictionService/GetModelMetadata", + request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, + response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString, + ) + + class PredictionServiceServicer(object): + """open source marker; do not remove PredictionService provides access to machine-learned models loaded by model_servers. """ - def Classify(self, request, context): - """Classify. 
+ def Classify(self, request, context): + """Classify. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") - def Regress(self, request, context): - """Regress. + def Regress(self, request, context): + """Regress. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") - def Predict(self, request, context): - """Predict -- provides access to loaded TensorFlow model. + def Predict(self, request, context): + """Predict -- provides access to loaded TensorFlow model. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") - def MultiInference(self, request, context): - """MultiInference API for multi-headed models. + def MultiInference(self, request, context): + """MultiInference API for multi-headed models. """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") - def GetModelMetadata(self, request, context): - """GetModelMetadata - provides access to metadata for loaded models. + def GetModelMetadata(self, request, context): + """GetModelMetadata - provides access to metadata for loaded models. 
""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - - def add_PredictionServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Classify': grpc.unary_unary_rpc_method_handler( - servicer.Classify, - request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString, - ), - 'Regress': grpc.unary_unary_rpc_method_handler( - servicer.Regress, - request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString, - ), - 'Predict': grpc.unary_unary_rpc_method_handler( - servicer.Predict, - request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString, - ), - 'MultiInference': grpc.unary_unary_rpc_method_handler( - servicer.MultiInference, - request_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString, - ), - 'GetModelMetadata': grpc.unary_unary_rpc_method_handler( - servicer.GetModelMetadata, - request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString, - response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'tensorflow.serving.PredictionService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - class BetaPredictionServiceServicer(object): - """The Beta API is deprecated for 0.15.0 and later. 
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def add_PredictionServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "Classify": grpc.unary_unary_rpc_method_handler( + servicer.Classify, + request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString, + ), + "Regress": grpc.unary_unary_rpc_method_handler( + servicer.Regress, + request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString, + ), + "Predict": grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString, + ), + "MultiInference": grpc.unary_unary_rpc_method_handler( + servicer.MultiInference, + request_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString, + ), + "GetModelMetadata": grpc.unary_unary_rpc_method_handler( + servicer.GetModelMetadata, + request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString, + response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "tensorflow.serving.PredictionService", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + class BetaPredictionServiceServicer(object): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. This class was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" - """open source marker; do not remove + + """open source marker; do not remove PredictionService provides access to machine-learned models loaded by model_servers. """ - def Classify(self, request, context): - """Classify. - """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - def Regress(self, request, context): - """Regress. + + def Classify(self, request, context): + """Classify. """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - def Predict(self, request, context): - """Predict -- provides access to loaded TensorFlow model. + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + + def Regress(self, request, context): + """Regress. """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - def MultiInference(self, request, context): - """MultiInference API for multi-headed models. + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + + def Predict(self, request, context): + """Predict -- provides access to loaded TensorFlow model. """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - def GetModelMetadata(self, request, context): - """GetModelMetadata - provides access to metadata for loaded models. 
+ context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + + def MultiInference(self, request, context): + """MultiInference API for multi-headed models. """ - context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + def GetModelMetadata(self, request, context): + """GetModelMetadata - provides access to metadata for loaded models. + """ + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) - class BetaPredictionServiceStub(object): - """The Beta API is deprecated for 0.15.0 and later. + class BetaPredictionServiceStub(object): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. This class was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" - """open source marker; do not remove + + """open source marker; do not remove PredictionService provides access to machine-learned models loaded by model_servers. """ - def Classify(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """Classify. + + def Classify(self, request, timeout, metadata=None, with_call=False, protocol_options=None): + """Classify. """ - raise NotImplementedError() - Classify.future = None - def Regress(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """Regress. + raise NotImplementedError() + + Classify.future = None + + def Regress(self, request, timeout, metadata=None, with_call=False, protocol_options=None): + """Regress. """ - raise NotImplementedError() - Regress.future = None - def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """Predict -- provides access to loaded TensorFlow model. + raise NotImplementedError() + + Regress.future = None + + def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None): + """Predict -- provides access to loaded TensorFlow model. """ - raise NotImplementedError() - Predict.future = None - def MultiInference(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """MultiInference API for multi-headed models. + raise NotImplementedError() + + Predict.future = None + + def MultiInference( + self, request, timeout, metadata=None, with_call=False, protocol_options=None + ): + """MultiInference API for multi-headed models. """ - raise NotImplementedError() - MultiInference.future = None - def GetModelMetadata(self, request, timeout, metadata=None, with_call=False, protocol_options=None): - """GetModelMetadata - provides access to metadata for loaded models. + raise NotImplementedError() + + MultiInference.future = None + + def GetModelMetadata( + self, request, timeout, metadata=None, with_call=False, protocol_options=None + ): + """GetModelMetadata - provides access to metadata for loaded models. """ - raise NotImplementedError() - GetModelMetadata.future = None + raise NotImplementedError() + GetModelMetadata.future = None - def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None): - """The Beta API is deprecated for 0.15.0 and later. + def beta_create_PredictionService_server( + servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None + ): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. 
This function was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" - request_deserializers = { - ('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString, - ('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString, - ('tensorflow.serving.PredictionService', 'MultiInference'): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString, - ('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString, - ('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString, - } - response_serializers = { - ('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString, - ('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, - ('tensorflow.serving.PredictionService', 'MultiInference'): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString, - ('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString, - ('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString, - } - method_implementations = { - ('tensorflow.serving.PredictionService', 'Classify'): face_utilities.unary_unary_inline(servicer.Classify), - ('tensorflow.serving.PredictionService', 'GetModelMetadata'): face_utilities.unary_unary_inline(servicer.GetModelMetadata), - ('tensorflow.serving.PredictionService', 'MultiInference'): face_utilities.unary_unary_inline(servicer.MultiInference), - ('tensorflow.serving.PredictionService', 'Predict'): face_utilities.unary_unary_inline(servicer.Predict), - ('tensorflow.serving.PredictionService', 'Regress'): face_utilities.unary_unary_inline(servicer.Regress), - } - server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout) - return beta_implementations.server(method_implementations, options=server_options) - - - def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None): - """The Beta API is deprecated for 0.15.0 and later. 
+ request_deserializers = { + ( + "tensorflow.serving.PredictionService", + "Classify", + ): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString, + ( + "tensorflow.serving.PredictionService", + "GetModelMetadata", + ): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString, + ( + "tensorflow.serving.PredictionService", + "MultiInference", + ): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString, + ( + "tensorflow.serving.PredictionService", + "Predict", + ): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString, + ( + "tensorflow.serving.PredictionService", + "Regress", + ): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString, + } + response_serializers = { + ( + "tensorflow.serving.PredictionService", + "Classify", + ): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "GetModelMetadata", + ): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "MultiInference", + ): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "Predict", + ): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "Regress", + ): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString, + } + method_implementations = { + ("tensorflow.serving.PredictionService", "Classify"): face_utilities.unary_unary_inline( + servicer.Classify + ), + ( + "tensorflow.serving.PredictionService", + "GetModelMetadata", + ): face_utilities.unary_unary_inline(servicer.GetModelMetadata), + ( + "tensorflow.serving.PredictionService", + "MultiInference", + ): face_utilities.unary_unary_inline(servicer.MultiInference), + ("tensorflow.serving.PredictionService", "Predict"): face_utilities.unary_unary_inline( + servicer.Predict + ), + ("tensorflow.serving.PredictionService", "Regress"): face_utilities.unary_unary_inline( + servicer.Regress + ), + } + server_options = beta_implementations.server_options( + request_deserializers=request_deserializers, + response_serializers=response_serializers, + thread_pool=pool, + thread_pool_size=pool_size, + default_timeout=default_timeout, + maximum_timeout=maximum_timeout, + ) + return beta_implementations.server(method_implementations, options=server_options) + + def beta_create_PredictionService_stub( + channel, host=None, metadata_transformer=None, pool=None, pool_size=None + ): + """The Beta API is deprecated for 0.15.0 and later. It is recommended to use the GA API (classes and functions in this file not marked beta) for all further purposes. 
This function was generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" - request_serializers = { - ('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString, - ('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, - ('tensorflow.serving.PredictionService', 'MultiInference'): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString, - ('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString, - ('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString, - } - response_deserializers = { - ('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString, - ('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString, - ('tensorflow.serving.PredictionService', 'MultiInference'): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString, - ('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString, - ('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString, - } - cardinalities = { - 'Classify': cardinality.Cardinality.UNARY_UNARY, - 'GetModelMetadata': cardinality.Cardinality.UNARY_UNARY, - 'MultiInference': cardinality.Cardinality.UNARY_UNARY, - 'Predict': cardinality.Cardinality.UNARY_UNARY, - 'Regress': cardinality.Cardinality.UNARY_UNARY, - } - stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size) - return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.PredictionService', cardinalities, options=stub_options) + request_serializers = { + ( + "tensorflow.serving.PredictionService", + "Classify", + ): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "GetModelMetadata", + ): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "MultiInference", + ): tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "Predict", + ): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString, + ( + "tensorflow.serving.PredictionService", + "Regress", + ): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString, + } + response_deserializers = { + ( + "tensorflow.serving.PredictionService", + "Classify", + ): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString, + ( + "tensorflow.serving.PredictionService", + "GetModelMetadata", + ): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString, + ( + "tensorflow.serving.PredictionService", + "MultiInference", + ): 
tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString, + ( + "tensorflow.serving.PredictionService", + "Predict", + ): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString, + ( + "tensorflow.serving.PredictionService", + "Regress", + ): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString, + } + cardinalities = { + "Classify": cardinality.Cardinality.UNARY_UNARY, + "GetModelMetadata": cardinality.Cardinality.UNARY_UNARY, + "MultiInference": cardinality.Cardinality.UNARY_UNARY, + "Predict": cardinality.Cardinality.UNARY_UNARY, + "Regress": cardinality.Cardinality.UNARY_UNARY, + } + stub_options = beta_implementations.stub_options( + host=host, + metadata_transformer=metadata_transformer, + request_serializers=request_serializers, + response_deserializers=response_deserializers, + thread_pool=pool, + thread_pool_size=pool_size, + ) + return beta_implementations.dynamic_stub( + channel, "tensorflow.serving.PredictionService", cardinalities, options=stub_options + ) + + except ImportError: - pass + pass # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/tensorflow/tensorflow_serving/apis/regression_pb2.py b/src/sagemaker/tensorflow/tensorflow_serving/apis/regression_pb2.py index a61dc8c6c6..d3c92f9539 100755 --- a/src/sagemaker/tensorflow/tensorflow_serving/apis/regression_pb2.py +++ b/src/sagemaker/tensorflow/tensorflow_serving/apis/regression_pb2.py @@ -2,12 +2,14 @@ # source: tensorflow_serving/apis/regression.proto import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -18,193 +20,272 @@ DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_serving/apis/regression.proto', - package='tensorflow.serving', - syntax='proto3', - serialized_pb=_b('\n(tensorflow_serving/apis/regression.proto\x12\x12tensorflow.serving\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto\"\x1b\n\nRegression\x12\r\n\x05value\x18\x01 \x01(\x02\"G\n\x10RegressionResult\x12\x33\n\x0bregressions\x18\x01 \x03(\x0b\x32\x1e.tensorflow.serving.Regression\"p\n\x11RegressionRequest\x12\x31\n\nmodel_spec\x18\x01 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input\"}\n\x12RegressionResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x34\n\x06result\x18\x01 \x01(\x0b\x32$.tensorflow.serving.RegressionResultB\x03\xf8\x01\x01\x62\x06proto3') - , - dependencies=[tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR,]) - - + name="tensorflow_serving/apis/regression.proto", + package="tensorflow.serving", + syntax="proto3", + serialized_pb=_b( + '\n(tensorflow_serving/apis/regression.proto\x12\x12tensorflow.serving\x1a#tensorflow_serving/apis/input.proto\x1a#tensorflow_serving/apis/model.proto"\x1b\n\nRegression\x12\r\n\x05value\x18\x01 \x01(\x02"G\n\x10RegressionResult\x12\x33\n\x0bregressions\x18\x01 \x03(\x0b\x32\x1e.tensorflow.serving.Regression"p\n\x11RegressionRequest\x12\x31\n\nmodel_spec\x18\x01 
\x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12(\n\x05input\x18\x02 \x01(\x0b\x32\x19.tensorflow.serving.Input"}\n\x12RegressionResponse\x12\x31\n\nmodel_spec\x18\x02 \x01(\x0b\x32\x1d.tensorflow.serving.ModelSpec\x12\x34\n\x06result\x18\x01 \x01(\x0b\x32$.tensorflow.serving.RegressionResultB\x03\xf8\x01\x01\x62\x06proto3' + ), + dependencies=[ + tensorflow__serving_dot_apis_dot_input__pb2.DESCRIPTOR, + tensorflow__serving_dot_apis_dot_model__pb2.DESCRIPTOR, + ], +) _REGRESSION = _descriptor.Descriptor( - name='Regression', - full_name='tensorflow.serving.Regression', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.serving.Regression.value', index=0, - number=1, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=138, - serialized_end=165, + name="Regression", + full_name="tensorflow.serving.Regression", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="value", + full_name="tensorflow.serving.Regression.value", + index=0, + number=1, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=138, + serialized_end=165, ) _REGRESSIONRESULT = _descriptor.Descriptor( - name='RegressionResult', - full_name='tensorflow.serving.RegressionResult', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='regressions', full_name='tensorflow.serving.RegressionResult.regressions', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=167, - serialized_end=238, + name="RegressionResult", + full_name="tensorflow.serving.RegressionResult", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="regressions", + full_name="tensorflow.serving.RegressionResult.regressions", + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ) + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=167, + serialized_end=238, ) _REGRESSIONREQUEST = _descriptor.Descriptor( - name='RegressionRequest', - full_name='tensorflow.serving.RegressionRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - 
_descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.RegressionRequest.model_spec', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='input', full_name='tensorflow.serving.RegressionRequest.input', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=240, - serialized_end=352, + name="RegressionRequest", + full_name="tensorflow.serving.RegressionRequest", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.RegressionRequest.model_spec", + index=0, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="input", + full_name="tensorflow.serving.RegressionRequest.input", + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=240, + serialized_end=352, ) _REGRESSIONRESPONSE = _descriptor.Descriptor( - name='RegressionResponse', - full_name='tensorflow.serving.RegressionResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='model_spec', full_name='tensorflow.serving.RegressionResponse.model_spec', index=0, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='result', full_name='tensorflow.serving.RegressionResponse.result', index=1, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=354, - serialized_end=479, + name="RegressionResponse", + full_name="tensorflow.serving.RegressionResponse", + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name="model_spec", + full_name="tensorflow.serving.RegressionResponse.model_spec", + index=0, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + 
extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name="result", + full_name="tensorflow.serving.RegressionResponse.result", + index=1, + number=1, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax="proto3", + extension_ranges=[], + oneofs=[], + serialized_start=354, + serialized_end=479, ) -_REGRESSIONRESULT.fields_by_name['regressions'].message_type = _REGRESSION -_REGRESSIONREQUEST.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_REGRESSIONREQUEST.fields_by_name['input'].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT -_REGRESSIONRESPONSE.fields_by_name['model_spec'].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC -_REGRESSIONRESPONSE.fields_by_name['result'].message_type = _REGRESSIONRESULT -DESCRIPTOR.message_types_by_name['Regression'] = _REGRESSION -DESCRIPTOR.message_types_by_name['RegressionResult'] = _REGRESSIONRESULT -DESCRIPTOR.message_types_by_name['RegressionRequest'] = _REGRESSIONREQUEST -DESCRIPTOR.message_types_by_name['RegressionResponse'] = _REGRESSIONRESPONSE +_REGRESSIONRESULT.fields_by_name["regressions"].message_type = _REGRESSION +_REGRESSIONREQUEST.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_REGRESSIONREQUEST.fields_by_name[ + "input" +].message_type = tensorflow__serving_dot_apis_dot_input__pb2._INPUT +_REGRESSIONRESPONSE.fields_by_name[ + "model_spec" +].message_type = tensorflow__serving_dot_apis_dot_model__pb2._MODELSPEC +_REGRESSIONRESPONSE.fields_by_name["result"].message_type = _REGRESSIONRESULT +DESCRIPTOR.message_types_by_name["Regression"] = _REGRESSION +DESCRIPTOR.message_types_by_name["RegressionResult"] = _REGRESSIONRESULT +DESCRIPTOR.message_types_by_name["RegressionRequest"] = _REGRESSIONREQUEST +DESCRIPTOR.message_types_by_name["RegressionResponse"] = _REGRESSIONRESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) -Regression = _reflection.GeneratedProtocolMessageType('Regression', (_message.Message,), dict( - DESCRIPTOR = _REGRESSION, - __module__ = 'tensorflow_serving.apis.regression_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.Regression) - )) +Regression = _reflection.GeneratedProtocolMessageType( + "Regression", + (_message.Message,), + dict( + DESCRIPTOR=_REGRESSION, + __module__="tensorflow_serving.apis.regression_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.Regression) + ), +) _sym_db.RegisterMessage(Regression) -RegressionResult = _reflection.GeneratedProtocolMessageType('RegressionResult', (_message.Message,), dict( - DESCRIPTOR = _REGRESSIONRESULT, - __module__ = 'tensorflow_serving.apis.regression_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionResult) - )) +RegressionResult = _reflection.GeneratedProtocolMessageType( + "RegressionResult", + (_message.Message,), + dict( + DESCRIPTOR=_REGRESSIONRESULT, + __module__="tensorflow_serving.apis.regression_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionResult) + ), +) _sym_db.RegisterMessage(RegressionResult) -RegressionRequest = _reflection.GeneratedProtocolMessageType('RegressionRequest', (_message.Message,), dict( - 
DESCRIPTOR = _REGRESSIONREQUEST, - __module__ = 'tensorflow_serving.apis.regression_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionRequest) - )) +RegressionRequest = _reflection.GeneratedProtocolMessageType( + "RegressionRequest", + (_message.Message,), + dict( + DESCRIPTOR=_REGRESSIONREQUEST, + __module__="tensorflow_serving.apis.regression_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionRequest) + ), +) _sym_db.RegisterMessage(RegressionRequest) -RegressionResponse = _reflection.GeneratedProtocolMessageType('RegressionResponse', (_message.Message,), dict( - DESCRIPTOR = _REGRESSIONRESPONSE, - __module__ = 'tensorflow_serving.apis.regression_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionResponse) - )) +RegressionResponse = _reflection.GeneratedProtocolMessageType( + "RegressionResponse", + (_message.Message,), + dict( + DESCRIPTOR=_REGRESSIONRESPONSE, + __module__="tensorflow_serving.apis.regression_pb2" + # @@protoc_insertion_point(class_scope:tensorflow.serving.RegressionResponse) + ), +) _sym_db.RegisterMessage(RegressionResponse) DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001')) +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) # @@protoc_insertion_point(module_scope) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 136d1f7f2e..0c2f36b414 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -23,9 +23,24 @@ class Transformer(object): """A class for handling creating and interacting with Amazon SageMaker transform jobs. """ - def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None, - env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None): + def __init__( + self, + model_name, + instance_count, + instance_type, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + env=None, + base_transform_job_name=None, + sagemaker_session=None, + volume_kms_key=None, + ): """Initialize a ``Transformer``. Args: @@ -78,8 +93,18 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass self.sagemaker_session = sagemaker_session or Session() - def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, - job_name=None, input_filter=None, output_filter=None, join_source=None): + def transform( + self, + data, + data_type="S3Prefix", + content_type=None, + compression_type=None, + split_type=None, + job_name=None, + input_filter=None, + output_filter=None, + join_source=None, + ): """Start a new transform job. Args: @@ -108,8 +133,8 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t Valid values: Input, None. 
""" local_mode = self.sagemaker_session.local_mode - if not local_mode and not data.startswith('s3://'): - raise ValueError('Invalid S3 URI: {}'.format(data)) + if not local_mode and not data.startswith("s3://"): + raise ValueError("Invalid S3 URI: {}".format(data)) if job_name is not None: self._current_job_name = job_name @@ -122,10 +147,21 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t self._current_job_name = name_from_base(base_name) if self.output_path is None: - self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name) - - self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type, - split_type, input_filter, output_filter, join_source) + self.output_path = "s3://{}/{}".format( + self.sagemaker_session.default_bucket(), self._current_job_name + ) + + self.latest_transform_job = _TransformJob.start_new( + self, + data, + data_type, + content_type, + compression_type, + split_type, + input_filter, + output_filter, + join_source, + ) def delete_model(self): """Delete the corresponding SageMaker model for this Transformer. @@ -143,23 +179,26 @@ def _retrieve_base_name(self): def _retrieve_image_name(self): try: - model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name) + model_desc = self.sagemaker_session.sagemaker_client.describe_model( + ModelName=self.model_name + ) - primary_container = model_desc.get('PrimaryContainer') + primary_container = model_desc.get("PrimaryContainer") if primary_container: - return primary_container.get('Image') + return primary_container.get("Image") - containers = model_desc.get('Containers') + containers = model_desc.get("Containers") if containers: - return containers[0].get('Image') + return containers[0].get("Image") return None except exceptions.ClientError: - raise ValueError('Failed to fetch model information for %s. ' - 'Please ensure that the model exists. ' - 'Local instance types require locally created models.' - % self.model_name) + raise ValueError( + "Failed to fetch model information for %s. " + "Please ensure that the model exists. " + "Local instance types require locally created models." 
% self.model_name + ) def wait(self): self._ensure_last_transform_job() @@ -167,7 +206,7 @@ def wait(self): def _ensure_last_transform_job(self): if self.latest_transform_job is None: - raise ValueError('No transform job available') + raise ValueError("No transform job available") @classmethod def attach(cls, transform_job_name, sagemaker_session=None): @@ -185,11 +224,14 @@ def attach(cls, transform_job_name, sagemaker_session=None): """ sagemaker_session = sagemaker_session or Session() - job_details = sagemaker_session.sagemaker_client.describe_transform_job(TransformJobName=transform_job_name) + job_details = sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=transform_job_name + ) init_params = cls._prepare_init_params_from_job_description(job_details) transformer = cls(sagemaker_session=sagemaker_session, **init_params) - transformer.latest_transform_job = _TransformJob(sagemaker_session=sagemaker_session, - job_name=init_params['base_transform_job_name']) + transformer.latest_transform_job = _TransformJob( + sagemaker_session=sagemaker_session, job_name=init_params["base_transform_job_name"] + ) return transformer @@ -205,37 +247,56 @@ def _prepare_init_params_from_job_description(cls, job_details): """ init_params = dict() - init_params['model_name'] = job_details['ModelName'] - init_params['instance_count'] = job_details['TransformResources']['InstanceCount'] - init_params['instance_type'] = job_details['TransformResources']['InstanceType'] - init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId') - init_params['strategy'] = job_details.get('BatchStrategy') - init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith') - init_params['output_path'] = job_details['TransformOutput']['S3OutputPath'] - init_params['output_kms_key'] = job_details['TransformOutput'].get('KmsKeyId') - init_params['accept'] = job_details['TransformOutput'].get('Accept') - init_params['max_concurrent_transforms'] = job_details.get('MaxConcurrentTransforms') - init_params['max_payload'] = job_details.get('MaxPayloadInMB') - init_params['base_transform_job_name'] = job_details['TransformJobName'] + init_params["model_name"] = job_details["ModelName"] + init_params["instance_count"] = job_details["TransformResources"]["InstanceCount"] + init_params["instance_type"] = job_details["TransformResources"]["InstanceType"] + init_params["volume_kms_key"] = job_details["TransformResources"].get("VolumeKmsKeyId") + init_params["strategy"] = job_details.get("BatchStrategy") + init_params["assemble_with"] = job_details["TransformOutput"].get("AssembleWith") + init_params["output_path"] = job_details["TransformOutput"]["S3OutputPath"] + init_params["output_kms_key"] = job_details["TransformOutput"].get("KmsKeyId") + init_params["accept"] = job_details["TransformOutput"].get("Accept") + init_params["max_concurrent_transforms"] = job_details.get("MaxConcurrentTransforms") + init_params["max_payload"] = job_details.get("MaxPayloadInMB") + init_params["base_transform_job_name"] = job_details["TransformJobName"] return init_params class _TransformJob(_Job): @classmethod - def start_new(cls, transformer, data, data_type, content_type, compression_type, - split_type, input_filter, output_filter, join_source): - config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer) - data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source) - - 
transformer.sagemaker_session.transform(job_name=transformer._current_job_name, - model_name=transformer.model_name, strategy=transformer.strategy, - max_concurrent_transforms=transformer.max_concurrent_transforms, - max_payload=transformer.max_payload, env=transformer.env, - input_config=config['input_config'], - output_config=config['output_config'], - resource_config=config['resource_config'], - tags=transformer.tags, data_processing=data_processing) + def start_new( + cls, + transformer, + data, + data_type, + content_type, + compression_type, + split_type, + input_filter, + output_filter, + join_source, + ): + config = _TransformJob._load_config( + data, data_type, content_type, compression_type, split_type, transformer + ) + data_processing = _TransformJob._prepare_data_processing( + input_filter, output_filter, join_source + ) + + transformer.sagemaker_session.transform( + job_name=transformer._current_job_name, + model_name=transformer.model_name, + strategy=transformer.strategy, + max_concurrent_transforms=transformer.max_concurrent_transforms, + max_payload=transformer.max_payload, + env=transformer.env, + input_config=config["input_config"], + output_config=config["output_config"], + resource_config=config["resource_config"], + tags=transformer.tags, + data_processing=data_processing, + ) return cls(transformer.sagemaker_session, transformer._current_job_name) @@ -244,38 +305,39 @@ def wait(self): @staticmethod def _load_config(data, data_type, content_type, compression_type, split_type, transformer): - input_config = _TransformJob._format_inputs_to_input_config(data, data_type, content_type, - compression_type, split_type) - - output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key, - transformer.assemble_with, transformer.accept) - - resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type, - transformer.volume_kms_key) - - return {'input_config': input_config, - 'output_config': output_config, - 'resource_config': resource_config} + input_config = _TransformJob._format_inputs_to_input_config( + data, data_type, content_type, compression_type, split_type + ) + + output_config = _TransformJob._prepare_output_config( + transformer.output_path, + transformer.output_kms_key, + transformer.assemble_with, + transformer.accept, + ) + + resource_config = _TransformJob._prepare_resource_config( + transformer.instance_count, transformer.instance_type, transformer.volume_kms_key + ) + + return { + "input_config": input_config, + "output_config": output_config, + "resource_config": resource_config, + } @staticmethod def _format_inputs_to_input_config(data, data_type, content_type, compression_type, split_type): - config = { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': data_type, - 'S3Uri': data, - } - } - } + config = {"DataSource": {"S3DataSource": {"S3DataType": data_type, "S3Uri": data}}} if content_type is not None: - config['ContentType'] = content_type + config["ContentType"] = content_type if compression_type is not None: - config['CompressionType'] = compression_type + config["CompressionType"] = compression_type if split_type is not None: - config['SplitType'] = split_type + config["SplitType"] = split_type return config @@ -284,19 +346,19 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept): config = super(_TransformJob, _TransformJob)._prepare_output_config(s3_path, kms_key_id) if assemble_with is not None: - config['AssembleWith'] = assemble_with 
+ config["AssembleWith"] = assemble_with if accept is not None: - config['Accept'] = accept + config["Accept"] = accept return config @staticmethod def _prepare_resource_config(instance_count, instance_type, volume_kms_key): - config = {'InstanceCount': instance_count, 'InstanceType': instance_type} + config = {"InstanceCount": instance_count, "InstanceType": instance_type} if volume_kms_key is not None: - config['VolumeKmsKeyId'] = volume_kms_key + config["VolumeKmsKeyId"] = volume_kms_key return config @@ -305,13 +367,13 @@ def _prepare_data_processing(input_filter, output_filter, join_source): config = {} if input_filter is not None: - config['InputFilter'] = input_filter + config["InputFilter"] = input_filter if output_filter is not None: - config['OutputFilter'] = output_filter + config["OutputFilter"] = output_filter if join_source is not None: - config['JoinSource'] = join_source + config["JoinSource"] = join_source if len(config) == 0: return None diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 4f4615e522..1ad5457602 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -24,26 +24,30 @@ from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.estimator import Framework from sagemaker.job import _Job -from sagemaker.parameter import (CategoricalParameter, ContinuousParameter, - IntegerParameter, ParameterRange) +from sagemaker.parameter import ( + CategoricalParameter, + ContinuousParameter, + IntegerParameter, + ParameterRange, +) from sagemaker.session import Session from sagemaker.session import s3_input from sagemaker.utils import base_name_from_image, name_from_base, to_str -AMAZON_ESTIMATOR_MODULE = 'sagemaker' +AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { - 'factorization-machines': 'FactorizationMachines', - 'kmeans': 'KMeans', - 'lda': 'LDA', - 'linear-learner': 'LinearLearner', - 'ntm': 'NTM', - 'randomcutforest': 'RandomCutForest', - 'knn': 'KNN', - 'object2vec': 'Object2Vec', + "factorization-machines": "FactorizationMachines", + "kmeans": "KMeans", + "lda": "LDA", + "linear-learner": "LinearLearner", + "ntm": "NTM", + "randomcutforest": "RandomCutForest", + "knn": "KNN", + "object2vec": "Object2Vec", } -HYPERPARAMETER_TUNING_JOB_NAME = 'HyperParameterTuningJobName' -PARENT_HYPERPARAMETER_TUNING_JOBS = 'ParentHyperParameterTuningJobs' -WARM_START_TYPE = 'WarmStartType' +HYPERPARAMETER_TUNING_JOB_NAME = "HyperParameterTuningJobName" +PARENT_HYPERPARAMETER_TUNING_JOBS = "ParentHyperParameterTuningJobs" +WARM_START_TYPE = "WarmStartType" class WarmStartTypes(Enum): @@ -53,6 +57,7 @@ class WarmStartTypes(Enum): * TransferLearning: Type of warm start that allows users to reuse training results from existing tuning jobs that have similar algorithm code and datasets. 
""" + IDENTICAL_DATA_AND_ALGORITHM = "IdenticalDataAndAlgorithm" TRANSFER_LEARNING = "TransferLearning" @@ -80,11 +85,15 @@ def __init__(self, warm_start_type, parents): if warm_start_type not in WarmStartTypes: raise ValueError( - "Invalid type: {}, valid warm start types are: [{}]".format(warm_start_type, - [t for t in WarmStartTypes])) + "Invalid type: {}, valid warm start types are: [{}]".format( + warm_start_type, [t for t in WarmStartTypes] + ) + ) if not parents: - raise ValueError("Invalid parents: {}, parents should not be None/empty".format(parents)) + raise ValueError( + "Invalid parents: {}, parents should not be None/empty".format(parents) + ) self.type = warm_start_type self.parents = set(parents) @@ -118,17 +127,20 @@ def from_job_desc(cls, warm_start_config): >>> warm_start_config.parents ["p1","p2"] """ - if not warm_start_config or \ - WARM_START_TYPE not in warm_start_config or \ - PARENT_HYPERPARAMETER_TUNING_JOBS not in warm_start_config: + if ( + not warm_start_config + or WARM_START_TYPE not in warm_start_config + or PARENT_HYPERPARAMETER_TUNING_JOBS not in warm_start_config + ): return None parents = [] for parent in warm_start_config[PARENT_HYPERPARAMETER_TUNING_JOBS]: parents.append(parent[HYPERPARAMETER_TUNING_JOB_NAME]) - return cls(warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), - parents=parents) + return cls( + warm_start_type=WarmStartTypes(warm_start_config[WARM_START_TYPE]), parents=parents + ) def to_input_req(self): """Converts the ``self`` instance to the desired input request format. @@ -149,7 +161,9 @@ def to_input_req(self): """ return { WARM_START_TYPE: self.type.value, - PARENT_HYPERPARAMETER_TUNING_JOBS: [{HYPERPARAMETER_TUNING_JOB_NAME: parent} for parent in self.parents] + PARENT_HYPERPARAMETER_TUNING_JOBS: [ + {HYPERPARAMETER_TUNING_JOB_NAME: parent} for parent in self.parents + ], } @@ -157,17 +171,30 @@ class HyperparameterTuner(object): """A class for creating and interacting with Amazon SageMaker hyperparameter tuning jobs, as well as deploying the resulting model(s). """ - TUNING_JOB_NAME_MAX_LENGTH = 32 - - SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module' - SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name' - DEFAULT_ESTIMATOR_MODULE = 'sagemaker.estimator' - DEFAULT_ESTIMATOR_CLS_NAME = 'Estimator' + TUNING_JOB_NAME_MAX_LENGTH = 32 - def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions=None, - strategy='Bayesian', objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, - tags=None, base_tuning_job_name=None, warm_start_config=None, early_stopping_type='Off'): + SAGEMAKER_ESTIMATOR_MODULE = "sagemaker_estimator_module" + SAGEMAKER_ESTIMATOR_CLASS_NAME = "sagemaker_estimator_class_name" + + DEFAULT_ESTIMATOR_MODULE = "sagemaker.estimator" + DEFAULT_ESTIMATOR_CLS_NAME = "Estimator" + + def __init__( + self, + estimator, + objective_metric_name, + hyperparameter_ranges, + metric_definitions=None, + strategy="Bayesian", + objective_type="Maximize", + max_jobs=1, + max_parallel_jobs=1, + tags=None, + base_tuning_job_name=None, + warm_start_config=None, + early_stopping_type="Off", + ): """Initialize a ``HyperparameterTuner``. It takes an estimator to obtain configuration information for training jobs that are created as the result of a hyperparameter tuning job. 
@@ -202,7 +229,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr """ self._hyperparameter_ranges = hyperparameter_ranges if self._hyperparameter_ranges is None or len(self._hyperparameter_ranges) == 0: - raise ValueError('Need to specify hyperparameter ranges') + raise ValueError("Need to specify hyperparameter ranges") self.estimator = estimator self.objective_metric_name = objective_metric_name @@ -225,10 +252,16 @@ def _prepare_for_training(self, job_name=None, include_cls_metadata=False): if job_name is not None: self._current_job_name = job_name else: - base_name = self.base_tuning_job_name or base_name_from_image(self.estimator.train_image()) - self._current_job_name = name_from_base(base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True) - - self.static_hyperparameters = {to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items()} + base_name = self.base_tuning_job_name or base_name_from_image( + self.estimator.train_image() + ) + self._current_job_name = name_from_base( + base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True + ) + + self.static_hyperparameters = { + to_str(k): to_str(v) for (k, v) in self.estimator.hyperparameters().items() + } for hyperparameter_name in self._hyperparameter_ranges.keys(): self.static_hyperparameters.pop(hyperparameter_name, None) @@ -236,8 +269,11 @@ def _prepare_for_training(self, job_name=None, include_cls_metadata=False): # (other algorithms may not accept extra hyperparameters) if include_cls_metadata or isinstance(self.estimator, Framework): self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps( - self.estimator.__class__.__name__) - self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__) + self.estimator.__class__.__name__ + ) + self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps( + self.estimator.__module__ + ) def fit(self, inputs=None, job_name=None, include_cls_metadata=False, **kwargs): """Start a hyperparameter tuning job. 
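The ``_prepare_for_training`` hunk above only re-wraps existing behavior: names present in ``hyperparameter_ranges`` are popped from the estimator's hyperparameters and the remainder becomes the static set. A hedged construction sketch (the image URI, role, bucket, metric name, and regex below are placeholders, not values from this patch):

from sagemaker.estimator import Estimator
from sagemaker.parameter import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner

# Hypothetical bring-your-own-algorithm estimator.
estimator = Estimator(
    image_name="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest",
    role="SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m4.xlarge",
    output_path="s3://my-bucket/output",
)
# "epochs" stays static; "learning_rate" is tuned and therefore popped.
estimator.set_hyperparameters(epochs=10, learning_rate=0.1)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
    metric_definitions=[{"Name": "validation:accuracy", "Regex": "accuracy=([0-9\\.]+)"}],
    max_jobs=4,
    max_parallel_jobs=2,
)
tuner.fit({"train": "s3://my-bucket/train"})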
@@ -313,21 +349,34 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim sagemaker_session = sagemaker_session or Session() if job_details is None: - job_details = sagemaker_session.sagemaker_client \ - .describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuning_job_name) - - estimator_cls = cls._prepare_estimator_cls(estimator_cls, job_details['TrainingJobDefinition']) - estimator = cls._prepare_estimator_from_job_description(estimator_cls, job_details['TrainingJobDefinition'], - sagemaker_session) + job_details = sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + + estimator_cls = cls._prepare_estimator_cls( + estimator_cls, job_details["TrainingJobDefinition"] + ) + estimator = cls._prepare_estimator_from_job_description( + estimator_cls, job_details["TrainingJobDefinition"], sagemaker_session + ) init_params = cls._prepare_init_params_from_job_description(job_details) tuner = cls(estimator=estimator, **init_params) - tuner.latest_tuning_job = _TuningJob(sagemaker_session=sagemaker_session, job_name=tuning_job_name) + tuner.latest_tuning_job = _TuningJob( + sagemaker_session=sagemaker_session, job_name=tuning_job_name + ) return tuner - def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, wait=True, - **kwargs): + def deploy( + self, + initial_instance_count, + instance_type, + accelerator_type=None, + endpoint_name=None, + wait=True, + **kwargs + ): """Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object. @@ -352,11 +401,17 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences. """ endpoint_name = endpoint_name or self.best_training_job() - best_estimator = self.estimator.attach(self.best_training_job(), - sagemaker_session=self.estimator.sagemaker_session) - return best_estimator.deploy(initial_instance_count, instance_type, - accelerator_type=accelerator_type, - endpoint_name=endpoint_name, wait=wait, **kwargs) + best_estimator = self.estimator.attach( + self.best_training_job(), sagemaker_session=self.estimator.sagemaker_session + ) + return best_estimator.deploy( + initial_instance_count, + instance_type, + accelerator_type=accelerator_type, + endpoint_name=endpoint_name, + wait=wait, + **kwargs + ) def stop_tuning_job(self): """Stop latest running hyperparameter tuning job. @@ -378,14 +433,18 @@ def best_training_job(self): """ self._ensure_last_tuning_job() - tuning_job_describe_result = \ - self.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job( - HyperParameterTuningJobName=self.latest_tuning_job.name) + tuning_job_describe_result = self.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job( # noqa: E501 + HyperParameterTuningJobName=self.latest_tuning_job.name + ) try: - return tuning_job_describe_result['BestTrainingJob']['TrainingJobName'] + return tuning_job_describe_result["BestTrainingJob"]["TrainingJobName"] except KeyError: - raise Exception('Best training job not available for tuning job: {}'.format(self.latest_tuning_job.name)) + raise Exception( + "Best training job not available for tuning job: {}".format( + self.latest_tuning_job.name + ) + ) def delete_endpoint(self, endpoint_name=None): """Delete an Amazon SageMaker endpoint. 
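``attach`` and ``deploy``, reformatted in the hunk above, are the typical read path; a short sketch assuming a completed tuning job named "my-tuning-job" (a hypothetical name):

from sagemaker.tuner import HyperparameterTuner

tuner = HyperparameterTuner.attach("my-tuning-job")

# Raises if the tuning job has not produced a best training job yet.
print(tuner.best_training_job())

# deploy() re-attaches the best training job's estimator and deploys it;
# the endpoint name defaults to the best training job's name.
predictor = tuner.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

# Tear the endpoint down when done.
tuner.delete_endpoint()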
@@ -401,76 +460,97 @@ def delete_endpoint(self, endpoint_name=None): def _ensure_last_tuning_job(self): if self.latest_tuning_job is None: - raise ValueError('No tuning job available') + raise ValueError("No tuning job available") @classmethod def _prepare_estimator_cls(cls, estimator_cls, training_details): # Check for customer-specified estimator first if estimator_cls is not None: - module, cls_name = estimator_cls.rsplit('.', 1) + module, cls_name = estimator_cls.rsplit(".", 1) return getattr(importlib.import_module(module), cls_name) # Then check for estimator class in hyperparameters - hyperparameters = training_details['StaticHyperParameters'] - if cls.SAGEMAKER_ESTIMATOR_CLASS_NAME in hyperparameters and cls.SAGEMAKER_ESTIMATOR_MODULE in hyperparameters: + hyperparameters = training_details["StaticHyperParameters"] + if ( + cls.SAGEMAKER_ESTIMATOR_CLASS_NAME in hyperparameters + and cls.SAGEMAKER_ESTIMATOR_MODULE in hyperparameters + ): module = hyperparameters.get(cls.SAGEMAKER_ESTIMATOR_MODULE) cls_name = hyperparameters.get(cls.SAGEMAKER_ESTIMATOR_CLASS_NAME) return getattr(importlib.import_module(json.loads(module)), json.loads(cls_name)) # Then try to derive the estimator from the image name for 1P algorithms - image_name = training_details['AlgorithmSpecification']['TrainingImage'] - algorithm = image_name[image_name.find('/') + 1:image_name.find(':')] + image_name = training_details["AlgorithmSpecification"]["TrainingImage"] + algorithm = image_name[image_name.find("/") + 1 : image_name.find(":")] if algorithm in AMAZON_ESTIMATOR_CLS_NAMES: cls_name = AMAZON_ESTIMATOR_CLS_NAMES[algorithm] return getattr(importlib.import_module(AMAZON_ESTIMATOR_MODULE), cls_name) # Default to the BYO estimator - return getattr(importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME) + return getattr( + importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME + ) @classmethod - def _prepare_estimator_from_job_description(cls, estimator_cls, training_details, sagemaker_session): + def _prepare_estimator_from_job_description( + cls, estimator_cls, training_details, sagemaker_session + ): # Swap name for static hyperparameters to what an estimator would expect - training_details['HyperParameters'] = training_details['StaticHyperParameters'] - del training_details['StaticHyperParameters'] + training_details["HyperParameters"] = training_details["StaticHyperParameters"] + del training_details["StaticHyperParameters"] # Remove hyperparameter reserved by SageMaker for tuning jobs - del training_details['HyperParameters']['_tuning_objective_metric'] + del training_details["HyperParameters"]["_tuning_objective_metric"] # Add items expected by the estimator (but aren't needed otherwise) - training_details['TrainingJobName'] = '' - if 'KmsKeyId' not in training_details['OutputDataConfig']: - training_details['OutputDataConfig']['KmsKeyId'] = '' + training_details["TrainingJobName"] = "" + if "KmsKeyId" not in training_details["OutputDataConfig"]: + training_details["OutputDataConfig"]["KmsKeyId"] = "" - estimator_init_params = estimator_cls._prepare_init_params_from_job_description(training_details) + estimator_init_params = estimator_cls._prepare_init_params_from_job_description( + training_details + ) return estimator_cls(sagemaker_session=sagemaker_session, **estimator_init_params) @classmethod def _prepare_init_params_from_job_description(cls, job_details): - tuning_config = job_details['HyperParameterTuningJobConfig'] + tuning_config = 
job_details["HyperParameterTuningJobConfig"] return { - 'metric_definitions': job_details['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'], - 'objective_metric_name': tuning_config['HyperParameterTuningJobObjective']['MetricName'], - 'objective_type': tuning_config['HyperParameterTuningJobObjective']['Type'], - 'hyperparameter_ranges': cls._prepare_parameter_ranges(tuning_config['ParameterRanges']), - 'strategy': tuning_config['Strategy'], - 'max_jobs': tuning_config['ResourceLimits']['MaxNumberOfTrainingJobs'], - 'max_parallel_jobs': tuning_config['ResourceLimits']['MaxParallelTrainingJobs'], - 'warm_start_config': WarmStartConfig.from_job_desc(job_details.get('WarmStartConfig', None)), - 'early_stopping_type': tuning_config['TrainingJobEarlyStoppingType'] + "metric_definitions": job_details["TrainingJobDefinition"]["AlgorithmSpecification"][ + "MetricDefinitions" + ], + "objective_metric_name": tuning_config["HyperParameterTuningJobObjective"][ + "MetricName" + ], + "objective_type": tuning_config["HyperParameterTuningJobObjective"]["Type"], + "hyperparameter_ranges": cls._prepare_parameter_ranges( + tuning_config["ParameterRanges"] + ), + "strategy": tuning_config["Strategy"], + "max_jobs": tuning_config["ResourceLimits"]["MaxNumberOfTrainingJobs"], + "max_parallel_jobs": tuning_config["ResourceLimits"]["MaxParallelTrainingJobs"], + "warm_start_config": WarmStartConfig.from_job_desc( + job_details.get("WarmStartConfig", None) + ), + "early_stopping_type": tuning_config["TrainingJobEarlyStoppingType"], } @classmethod def _prepare_parameter_ranges(cls, parameter_ranges): ranges = {} - for parameter in parameter_ranges['CategoricalParameterRanges']: - ranges[parameter['Name']] = CategoricalParameter(parameter['Values']) + for parameter in parameter_ranges["CategoricalParameterRanges"]: + ranges[parameter["Name"]] = CategoricalParameter(parameter["Values"]) - for parameter in parameter_ranges['ContinuousParameterRanges']: - ranges[parameter['Name']] = ContinuousParameter(float(parameter['MinValue']), float(parameter['MaxValue'])) + for parameter in parameter_ranges["ContinuousParameterRanges"]: + ranges[parameter["Name"]] = ContinuousParameter( + float(parameter["MinValue"]), float(parameter["MaxValue"]) + ) - for parameter in parameter_ranges['IntegerParameterRanges']: - ranges[parameter['Name']] = IntegerParameter(int(parameter['MinValue']), int(parameter['MaxValue'])) + for parameter in parameter_ranges["IntegerParameterRanges"]: + ranges[parameter["Name"]] = IntegerParameter( + int(parameter["MinValue"]), int(parameter["MaxValue"]) + ) return ranges @@ -484,12 +564,14 @@ def hyperparameter_ranges(self): for parameter_name, parameter in self._hyperparameter_ranges.items(): if parameter is not None and parameter.__name__ == range_type: # Categorical parameters needed to be serialized as JSON for our framework containers - if isinstance(parameter, CategoricalParameter) and isinstance(self.estimator, Framework): + if isinstance(parameter, CategoricalParameter) and isinstance( + self.estimator, Framework + ): tuning_range = parameter.as_json_range(parameter_name) else: tuning_range = parameter.as_tuning_range(parameter_name) parameter_ranges.append(tuning_range) - hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_ranges + hyperparameter_ranges[range_type + "ParameterRanges"] = parameter_ranges return hyperparameter_ranges @property @@ -521,8 +603,8 @@ def _validate_parameter_ranges(self): pass def _validate_parameter_range(self, value_hp, 
parameter_range): - for parameter_range_key, parameter_range_value in parameter_range.__dict__.items(): - if parameter_range_key == 'scaling_type': + for (parameter_range_key, parameter_range_value) in parameter_range.__dict__.items(): + if parameter_range_key == "scaling_type": continue # Categorical ranges @@ -556,9 +638,11 @@ def transfer_learning_tuner(self, additional_parents=None, estimator=None): >>> transfer_learning_tuner.fit(inputs={}) """ - return self._create_warm_start_tuner(additional_parents=additional_parents, - warm_start_type=WarmStartTypes.TRANSFER_LEARNING, - estimator=estimator) + return self._create_warm_start_tuner( + additional_parents=additional_parents, + warm_start_type=WarmStartTypes.TRANSFER_LEARNING, + estimator=estimator, + ) def identical_dataset_and_algorithm_tuner(self, additional_parents=None): """Creates a new ``HyperparameterTuner`` by copying the request fields from the provided parent to the new @@ -581,8 +665,10 @@ def identical_dataset_and_algorithm_tuner(self, additional_parents=None): >>> identical_dataset_algo_tuner.fit(inputs={}) """ - return self._create_warm_start_tuner(additional_parents=additional_parents, - warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM) + return self._create_warm_start_tuner( + additional_parents=additional_parents, + warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, + ) def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimator=None): """Creates a new ``HyperparameterTuner`` with ``WarmStartConfig``, where type will be equal to @@ -600,14 +686,15 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato if additional_parents: all_parents = all_parents.union(additional_parents) - return HyperparameterTuner(estimator=estimator if estimator else self.estimator, - objective_metric_name=self.objective_metric_name, - hyperparameter_ranges=self._hyperparameter_ranges, - objective_type=self.objective_type, - max_jobs=self.max_jobs, - max_parallel_jobs=self.max_parallel_jobs, - warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, - parents=all_parents)) + return HyperparameterTuner( + estimator=estimator if estimator else self.estimator, + objective_metric_name=self.objective_metric_name, + hyperparameter_ranges=self._hyperparameter_ranges, + objective_type=self.objective_type, + max_jobs=self.max_jobs, + max_parallel_jobs=self.max_parallel_jobs, + warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents), + ) class _TuningJob(_Job): @@ -630,34 +717,38 @@ def start_new(cls, tuner, inputs): tuner_args = config.copy() - tuner_args['job_name'] = tuner._current_job_name - tuner_args['strategy'] = tuner.strategy - tuner_args['objective_type'] = tuner.objective_type - tuner_args['objective_metric_name'] = tuner.objective_metric_name - tuner_args['max_jobs'] = tuner.max_jobs - tuner_args['max_parallel_jobs'] = tuner.max_parallel_jobs - tuner_args['parameter_ranges'] = tuner.hyperparameter_ranges() - tuner_args['static_hyperparameters'] = tuner.static_hyperparameters - tuner_args['input_mode'] = tuner.estimator.input_mode - tuner_args['metric_definitions'] = tuner.metric_definitions - tuner_args['tags'] = tuner.tags - tuner_args['warm_start_config'] = warm_start_config_req - tuner_args['early_stopping_type'] = tuner.early_stopping_type + tuner_args["job_name"] = tuner._current_job_name + tuner_args["strategy"] = tuner.strategy + tuner_args["objective_type"] = tuner.objective_type + tuner_args["objective_metric_name"] = 
tuner.objective_metric_name + tuner_args["max_jobs"] = tuner.max_jobs + tuner_args["max_parallel_jobs"] = tuner.max_parallel_jobs + tuner_args["parameter_ranges"] = tuner.hyperparameter_ranges() + tuner_args["static_hyperparameters"] = tuner.static_hyperparameters + tuner_args["input_mode"] = tuner.estimator.input_mode + tuner_args["metric_definitions"] = tuner.metric_definitions + tuner_args["tags"] = tuner.tags + tuner_args["warm_start_config"] = warm_start_config_req + tuner_args["early_stopping_type"] = tuner.early_stopping_type if isinstance(inputs, s3_input): - if 'InputMode' in inputs.config: - logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.' - .format(inputs.config['InputMode'])) - tuner_args['input_mode'] = inputs.config['InputMode'] + if "InputMode" in inputs.config: + logging.debug( + "Selecting s3_input's input_mode ({}) for TrainingInputMode.".format( + inputs.config["InputMode"] + ) + ) + tuner_args["input_mode"] = inputs.config["InputMode"] if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator): - tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn + tuner_args["algorithm_arn"] = tuner.estimator.algorithm_arn else: - tuner_args['image'] = tuner.estimator.train_image() + tuner_args["image"] = tuner.estimator.train_image() - tuner_args['enable_network_isolation'] = tuner.estimator.enable_network_isolation() - tuner_args['encrypt_inter_container_traffic'] = \ - tuner.estimator.encrypt_inter_container_traffic + tuner_args["enable_network_isolation"] = tuner.estimator.enable_network_isolation() + tuner_args[ + "encrypt_inter_container_traffic" + ] = tuner.estimator.encrypt_inter_container_traffic tuner.estimator.sagemaker_session.tune(**tuner_args) @@ -670,7 +761,9 @@ def wait(self): self.sagemaker_session.wait_for_tuning_job(self.name) -def create_identical_dataset_and_algorithm_tuner(parent, additional_parents=None, sagemaker_session=None): +def create_identical_dataset_and_algorithm_tuner( + parent, additional_parents=None, sagemaker_session=None +): """Creates a new tuner by copying the request fields from the provided parent to the new instance of ``HyperparameterTuner`` followed by addition of warm start configuration with the type as "IdenticalDataAndAlgorithm" and ``parents`` as the union of provided list of ``additional_parents`` and the @@ -689,11 +782,15 @@ def create_identical_dataset_and_algorithm_tuner(parent, additional_parents=None hyperparameter tuning job """ - parent_tuner = HyperparameterTuner.attach(tuning_job_name=parent, sagemaker_session=sagemaker_session) + parent_tuner = HyperparameterTuner.attach( + tuning_job_name=parent, sagemaker_session=sagemaker_session + ) return parent_tuner.identical_dataset_and_algorithm_tuner(additional_parents=additional_parents) -def create_transfer_learning_tuner(parent, additional_parents=None, estimator=None, sagemaker_session=None): +def create_transfer_learning_tuner( + parent, additional_parents=None, estimator=None, sagemaker_session=None +): """Creates a new ``HyperParameterTuner`` by copying the request fields from the provided parent to the new instance of ``HyperparameterTuner`` followed by addition of warm start configuration with the type as "TransferLearning" and ``parents`` as the union of provided list of ``additional_parents`` and the ``parent``. 
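The two module-level helpers in this and the following hunk are thin wrappers over ``HyperparameterTuner.attach`` plus the corresponding instance methods; a usage sketch with hypothetical job names:

from sagemaker.tuner import (
    create_identical_dataset_and_algorithm_tuner,
    create_transfer_learning_tuner,
)

# Warm start from "parent-tuner", reusing the same data and algorithm.
identical_tuner = create_identical_dataset_and_algorithm_tuner(
    parent="parent-tuner", additional_parents={"other-parent-tuner"}
)
identical_tuner.fit({"train": "s3://my-bucket/train"})

# Transfer learning permits a different (but similar) estimator.
transfer_tuner = create_transfer_learning_tuner(
    parent="parent-tuner", additional_parents={"other-parent-tuner"}
)
transfer_tuner.fit({"train": "s3://my-bucket/train"})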
@@ -712,6 +809,9 @@ def create_transfer_learning_tuner(parent, additional_parents=None, estimator=No sagemaker.tuner.HyperparameterTuner: New instance of warm started HyperparameterTuner """ - parent_tuner = HyperparameterTuner.attach(tuning_job_name=parent, sagemaker_session=sagemaker_session) - return parent_tuner.transfer_learning_tuner(additional_parents=additional_parents, - estimator=estimator) + parent_tuner = HyperparameterTuner.attach( + tuning_job_name=parent, sagemaker_session=sagemaker_session + ) + return parent_tuner.transfer_learning_tuner( + additional_parents=additional_parents, estimator=estimator + ) diff --git a/src/sagemaker/user_agent.py b/src/sagemaker/user_agent.py index f4ed19ec18..a63d2bb6bb 100644 --- a/src/sagemaker/user_agent.py +++ b/src/sagemaker/user_agent.py @@ -19,19 +19,24 @@ import boto3 import botocore -SDK_VERSION = pkg_resources.require('sagemaker')[0].version -OS_NAME = platform.system() or 'UnresolvedOS' -OS_VERSION = platform.release() or 'UnresolvedOSVersion' -PYTHON_VERSION = '{}.{}.{}'.format(sys.version_info.major, sys.version_info.minor, sys.version_info.micro) +SDK_VERSION = pkg_resources.require("sagemaker")[0].version +OS_NAME = platform.system() or "UnresolvedOS" +OS_VERSION = platform.release() or "UnresolvedOSVersion" +PYTHON_VERSION = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro +) def determine_prefix(): - prefix = 'AWS-SageMaker-Python-SDK/{} Python/{} {}/{} Boto3/{} Botocore/{}'\ - .format(SDK_VERSION, PYTHON_VERSION, OS_NAME, OS_VERSION, boto3.__version__, botocore.__version__) + prefix = "AWS-SageMaker-Python-SDK/{} Python/{} {}/{} Boto3/{} Botocore/{}".format( + SDK_VERSION, PYTHON_VERSION, OS_NAME, OS_VERSION, boto3.__version__, botocore.__version__ + ) try: - with open('/etc/opt/ml/sagemaker-notebook-instance-version.txt') as sagemaker_nbi_file: - prefix = 'AWS-SageMaker-Notebook-Instance/{} {}'.format(sagemaker_nbi_file.read().strip(), prefix) + with open("/etc/opt/ml/sagemaker-notebook-instance-version.txt") as sagemaker_nbi_file: + prefix = "AWS-SageMaker-Notebook-Instance/{} {}".format( + sagemaker_nbi_file.read().strip(), prefix + ) except IOError: # This file isn't expected to always exist, and we DO want to silently ignore failures. pass @@ -45,4 +50,4 @@ def prepend_user_agent(client): if client._client_config.user_agent is None: client._client_config.user_agent = prefix else: - client._client_config.user_agent = '{} {}'.format(prefix, client._client_config.user_agent) + client._client_config.user_agent = "{} {}".format(prefix, client._client_config.user_agent) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 9d1d139cb3..68f5f9d47f 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -29,7 +29,7 @@ import six -ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' +ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$" # Use the base name of the image as the job name if the user doesn't give us one @@ -60,16 +60,16 @@ def name_from_base(base, max_length=63, short=False): str: Input parameter with appended timestamp. 
""" timestamp = sagemaker_short_timestamp() if short else sagemaker_timestamp() - trimmed_base = base[:max_length - len(timestamp) - 1] - return '{}-{}'.format(trimmed_base, timestamp) + trimmed_base = base[: max_length - len(timestamp) - 1] + return "{}-{}".format(trimmed_base, timestamp) def unique_name_from_base(base, max_length=63): - unique = '%04x' % random.randrange(16**4) # 4-digit hex + unique = "%04x" % random.randrange(16 ** 4) # 4-digit hex ts = str(int(time.time())) available_length = max_length - 2 - len(ts) - len(unique) trimmed = base[:available_length] - return '{}-{}-{}'.format(trimmed, ts, unique) + return "{}-{}-{}".format(trimmed, ts, unique) def base_name_from_image(image): @@ -89,17 +89,18 @@ def base_name_from_image(image): def sagemaker_timestamp(): """Return a timestamp with millisecond precision.""" moment = time.time() - moment_ms = repr(moment).split('.')[1][:3] + moment_ms = repr(moment).split(".")[1][:3] return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_ms), time.gmtime(moment)) def sagemaker_short_timestamp(): """Return a timestamp that is relatively short in length""" - return time.strftime('%y%m%d-%H%M') + return time.strftime("%y%m%d-%H%M") def debug(func): """Print the function name and arguments for debugging.""" + @wraps(func) def wrapper(*args, **kwargs): print("{} args: {} kwargs: {}".format(func.__name__, args, kwargs)) @@ -113,7 +114,7 @@ def get_config_value(key_path, config): return None current_section = config - for key in key_path.split('.'): + for key in key_path.split("."): if key in current_section: current_section = current_section[key] else: @@ -144,10 +145,10 @@ def extract_name_from_job_arn(arn): """Returns the name used in the API given a full ARN for a training job or hyperparameter tuning job. """ - slash_pos = arn.find('/') + slash_pos = arn.find("/") if slash_pos == -1: raise ValueError("Cannot parse invalid ARN: %s" % arn) - return arn[(slash_pos + 1):] + return arn[(slash_pos + 1) :] def secondary_training_status_changed(current_job_description, prev_job_description): @@ -161,17 +162,27 @@ def secondary_training_status_changed(current_job_description, prev_job_descript boolean: Whether the secondary status message of a training job changed or not. 
""" - current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions') - if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0: + current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") + if ( + current_secondary_status_transitions is None + or len(current_secondary_status_transitions) == 0 + ): return False - prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \ - if prev_job_description is not None else None + prev_job_secondary_status_transitions = ( + prev_job_description.get("SecondaryStatusTransitions") + if prev_job_description is not None + else None + ) - last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \ - if prev_job_secondary_status_transitions is not None and len(prev_job_secondary_status_transitions) > 0 else '' + last_message = ( + prev_job_secondary_status_transitions[-1]["StatusMessage"] + if prev_job_secondary_status_transitions is not None + and len(prev_job_secondary_status_transitions) > 0 + else "" + ) - message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage'] + message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"] return message != last_message @@ -188,31 +199,41 @@ def secondary_training_status_message(job_description, prev_description): """ - if job_description is None or job_description.get('SecondaryStatusTransitions') is None\ - or len(job_description.get('SecondaryStatusTransitions')) == 0: - return '' - - prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\ - if prev_description is not None else None - prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\ - if prev_description_secondary_transitions is not None else 0 - current_transitions = job_description['SecondaryStatusTransitions'] + if ( + job_description is None + or job_description.get("SecondaryStatusTransitions") is None + or len(job_description.get("SecondaryStatusTransitions")) == 0 + ): + return "" + + prev_description_secondary_transitions = ( + prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None + ) + prev_transitions_num = ( + len(prev_description["SecondaryStatusTransitions"]) + if prev_description_secondary_transitions is not None + else 0 + ) + current_transitions = job_description["SecondaryStatusTransitions"] if len(current_transitions) == prev_transitions_num: # Secondary status is not changed but the message changed. transitions_to_print = current_transitions[-1:] else: # Secondary status is changed we need to print all the entries. 
- transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):] + transitions_to_print = current_transitions[ + prev_transitions_num - len(current_transitions) : + ] status_strs = [] for transition in transitions_to_print: - message = transition['StatusMessage'] + message = transition["StatusMessage"] time_str = datetime.utcfromtimestamp( - time.mktime(job_description['LastModifiedTime'].timetuple())).strftime('%Y-%m-%d %H:%M:%S') - status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message)) + time.mktime(job_description["LastModifiedTime"].timetuple()) + ).strftime("%Y-%m-%d %H:%M:%S") + status_strs.append("{} {} - {}".format(time_str, transition["Status"], message)) - return '\n'.join(status_strs) + return "\n".join(status_strs) def download_folder(bucket_name, prefix, target, sagemaker_session): @@ -226,26 +247,26 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): """ boto_session = sagemaker_session.boto_session - s3 = boto_session.resource('s3') + s3 = boto_session.resource("s3") bucket = s3.Bucket(bucket_name) - prefix = prefix.lstrip('/') + prefix = prefix.lstrip("/") # there is a chance that the prefix points to a file and not a 'directory' if that is the case # we should just download it. objects = list(bucket.objects.filter(Prefix=prefix)) - if len(objects) > 0 and objects[0].key == prefix and prefix[-1] != '/': + if len(objects) > 0 and objects[0].key == prefix and prefix[-1] != "/": s3.Object(bucket_name, prefix).download_file(os.path.join(target, os.path.basename(prefix))) return # the prefix points to an s3 'directory' download the whole thing for obj_sum in bucket.objects.filter(Prefix=prefix): # if obj_sum is a folder object skip it. - if obj_sum.key != '' and obj_sum.key[-1] == '/': + if obj_sum.key != "" and obj_sum.key[-1] == "/": continue obj = s3.Object(obj_sum.bucket_name, obj_sum.key) - s3_relative_path = obj_sum.key[len(prefix):].lstrip('/') + s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/") file_path = os.path.join(target, s3_relative_path) try: @@ -270,7 +291,7 @@ def create_tar_file(source_files, target=None): else: _, filename = tempfile.mkstemp() - with tarfile.open(filename, mode='w:gz') as t: + with tarfile.open(filename, mode="w:gz") as t: for sf in source_files: # Add all files from the directory into the root of the directory structure of the tar t.add(sf, arcname=os.path.basename(sf)) @@ -278,7 +299,7 @@ def create_tar_file(source_files, target=None): @contextlib.contextmanager -def _tmpdir(suffix='', prefix='tmp'): +def _tmpdir(suffix="", prefix="tmp"): """Create a temporary directory with a context manager. The file is deleted when the context exits. The prefix, suffix, and dir arguments are the same as for mkstemp(). @@ -298,12 +319,14 @@ def _tmpdir(suffix='', prefix='tmp'): shutil.rmtree(tmp) -def repack_model(inference_script, - source_directory, - dependencies, - model_uri, - repacked_model_uri, - sagemaker_session): +def repack_model( + inference_script, + source_directory, + dependencies, + model_uri, + repacked_model_uri, + sagemaker_session, +): """Unpack model tarball and creates a new model tarball with the provided code script. 
This function does the following: @@ -342,37 +365,41 @@ def repack_model(inference_script, with _tmpdir() as tmp: model_dir = _extract_model(model_uri, sagemaker_session, tmp) - _create_or_update_code_dir(model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp) + _create_or_update_code_dir( + model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp + ) - tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz') - with tarfile.open(tmp_model_path, mode='w:gz') as t: + tmp_model_path = os.path.join(tmp, "temp-model.tar.gz") + with tarfile.open(tmp_model_path, mode="w:gz") as t: t.add(model_dir, arcname=os.path.sep) _save_model(repacked_model_uri, tmp_model_path, sagemaker_session) def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session): - if repacked_model_uri.lower().startswith('s3://'): + if repacked_model_uri.lower().startswith("s3://"): url = parse.urlparse(repacked_model_uri) - bucket, key = url.netloc, url.path.lstrip('/') + bucket, key = url.netloc, url.path.lstrip("/") new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri)) - sagemaker_session.boto_session.resource('s3').Object(bucket, new_key).upload_file( - tmp_model_path) + sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file( + tmp_model_path + ) else: - shutil.move(tmp_model_path, repacked_model_uri.replace('file://', '')) + shutil.move(tmp_model_path, repacked_model_uri.replace("file://", "")) -def _create_or_update_code_dir(model_dir, inference_script, source_directory, - dependencies, sagemaker_session, tmp): - code_dir = os.path.join(model_dir, 'code') +def _create_or_update_code_dir( + model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp +): + code_dir = os.path.join(model_dir, "code") if os.path.exists(code_dir): shutil.rmtree(code_dir, ignore_errors=True) - if source_directory and source_directory.lower().startswith('s3://'): - local_code_path = os.path.join(tmp, 'local_code.tar.gz') + if source_directory and source_directory.lower().startswith("s3://"): + local_code_path = os.path.join(tmp, "local_code.tar.gz") download_file_from_url(source_directory, local_code_path, sagemaker_session) - with tarfile.open(name=local_code_path, mode='r:gz') as t: + with tarfile.open(name=local_code_path, mode="r:gz") as t: t.extractall(path=code_dir) elif source_directory: @@ -389,21 +416,21 @@ def _create_or_update_code_dir(model_dir, inference_script, source_directory, def _extract_model(model_uri, sagemaker_session, tmp): - tmp_model_dir = os.path.join(tmp, 'model') + tmp_model_dir = os.path.join(tmp, "model") os.mkdir(tmp_model_dir) - if model_uri.lower().startswith('s3://'): - local_model_path = os.path.join(tmp, 'tar_file') + if model_uri.lower().startswith("s3://"): + local_model_path = os.path.join(tmp, "tar_file") download_file_from_url(model_uri, local_model_path, sagemaker_session) else: - local_model_path = model_uri.replace('file://', '') - with tarfile.open(name=local_model_path, mode='r:gz') as t: + local_model_path = model_uri.replace("file://", "") + with tarfile.open(name=local_model_path, mode="r:gz") as t: t.extractall(path=tmp_model_dir) return tmp_model_dir def download_file_from_url(url, dst, sagemaker_session): url = parse.urlparse(url) - bucket, key = url.netloc, url.path.lstrip('/') + bucket, key = url.netloc, url.path.lstrip("/") download_file(bucket, key, dst, sagemaker_session) @@ -417,10 +444,10 @@ def download_file(bucket_name, path, target, 
sagemaker_session): target (str): destination directory for the downloaded file. sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to interact with S3. """ - path = path.lstrip('/') + path = path.lstrip("/") boto_session = sagemaker_session.boto_session - s3 = boto_session.resource('s3') + s3 = boto_session.resource("s3") bucket = s3.Bucket(bucket_name) bucket.download_file(path, target) @@ -435,8 +462,8 @@ def get_ecr_image_uri_prefix(account, region): Returns: (str): URI prefix of ECR image """ - domain = 'c2s.ic.gov' if region == 'us-iso-east-1' else 'amazonaws.com' - return '{}.dkr.ecr.{}.{}'.format(account, region, domain) + domain = "c2s.ic.gov" if region == "us-iso-east-1" else "amazonaws.com" + return "{}.dkr.ecr.{}.{}".format(account, region, domain) class DeferredError(object): diff --git a/src/sagemaker/vpc_utils.py b/src/sagemaker/vpc_utils.py index 8fb84d7006..1be426b424 100644 --- a/src/sagemaker/vpc_utils.py +++ b/src/sagemaker/vpc_utils.py @@ -12,13 +12,13 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -SUBNETS_KEY = 'Subnets' -SECURITY_GROUP_IDS_KEY = 'SecurityGroupIds' -VPC_CONFIG_KEY = 'VpcConfig' +SUBNETS_KEY = "Subnets" +SECURITY_GROUP_IDS_KEY = "SecurityGroupIds" +VPC_CONFIG_KEY = "VpcConfig" # A global constant value for methods which can optionally override VpcConfig # Using the default implies that VpcConfig should be reused from an existing Estimator or Training Job -VPC_CONFIG_DEFAULT = 'VPC_CONFIG_DEFAULT' +VPC_CONFIG_DEFAULT = "VPC_CONFIG_DEFAULT" def to_dict(subnets, security_group_ids): @@ -37,8 +37,7 @@ def to_dict(subnets, security_group_ids): """ if subnets is None or security_group_ids is None: return None - return {SUBNETS_KEY: subnets, - SECURITY_GROUP_IDS_KEY: security_group_ids} + return {SUBNETS_KEY: subnets, SECURITY_GROUP_IDS_KEY: security_group_ids} def from_dict(vpc_config, do_sanitize=False): @@ -85,24 +84,28 @@ def sanitize(vpc_config): if vpc_config is None: return vpc_config elif type(vpc_config) is not dict: - raise ValueError('vpc_config is not a dict: {}'.format(vpc_config)) + raise ValueError("vpc_config is not a dict: {}".format(vpc_config)) elif not vpc_config: - raise ValueError('vpc_config is empty') + raise ValueError("vpc_config is empty") subnets = vpc_config.get(SUBNETS_KEY) if subnets is None: - raise ValueError('vpc_config is missing key: {}'.format(SUBNETS_KEY)) + raise ValueError("vpc_config is missing key: {}".format(SUBNETS_KEY)) if type(subnets) is not list: - raise ValueError('vpc_config value for {} is not a list: {}'.format(SUBNETS_KEY, subnets)) + raise ValueError("vpc_config value for {} is not a list: {}".format(SUBNETS_KEY, subnets)) elif not subnets: - raise ValueError('vpc_config value for {} is empty'.format(SUBNETS_KEY)) + raise ValueError("vpc_config value for {} is empty".format(SUBNETS_KEY)) security_group_ids = vpc_config.get(SECURITY_GROUP_IDS_KEY) if security_group_ids is None: - raise ValueError('vpc_config is missing key: {}'.format(SECURITY_GROUP_IDS_KEY)) + raise ValueError("vpc_config is missing key: {}".format(SECURITY_GROUP_IDS_KEY)) if type(security_group_ids) is not list: - raise ValueError('vpc_config value for {} is not a list: {}'.format(SECURITY_GROUP_IDS_KEY, security_group_ids)) + raise ValueError( + "vpc_config value for {} is not a list: {}".format( + SECURITY_GROUP_IDS_KEY, security_group_ids + ) + ) elif not security_group_ids: - raise ValueError('vpc_config value for {} is 
empty'.format(SECURITY_GROUP_IDS_KEY)) + raise ValueError("vpc_config value for {} is empty".format(SECURITY_GROUP_IDS_KEY)) return to_dict(subnets, security_group_ids) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index b38140148d..92de8b22b2 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -30,31 +30,32 @@ def prepare_framework(estimator, s3_operations): """ if estimator.code_location is not None: bucket, key = fw_utils.parse_s3_url(estimator.code_location) - key = os.path.join(key, estimator._current_job_name, 'source', 'sourcedir.tar.gz') + key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz") else: bucket = estimator.sagemaker_session._default_bucket - key = os.path.join(estimator._current_job_name, 'source', 'sourcedir.tar.gz') + key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz") script = os.path.basename(estimator.entry_point) - if estimator.source_dir and estimator.source_dir.lower().startswith('s3://'): + if estimator.source_dir and estimator.source_dir.lower().startswith("s3://"): code_dir = estimator.source_dir estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) else: - code_dir = 's3://{}/{}'.format(bucket, key) + code_dir = "s3://{}/{}".format(bucket, key) estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) - s3_operations['S3Upload'] = [{ - 'Path': estimator.source_dir or script, - 'Bucket': bucket, - 'Key': key, - 'Tar': True - }] + s3_operations["S3Upload"] = [ + {"Path": estimator.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True} + ] estimator._hyperparameters[sagemaker.model.DIR_PARAM_NAME] = code_dir estimator._hyperparameters[sagemaker.model.SCRIPT_PARAM_NAME] = script - estimator._hyperparameters[sagemaker.model.CLOUDWATCH_METRICS_PARAM_NAME] = \ - estimator.enable_cloudwatch_metrics - estimator._hyperparameters[sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME] = estimator.container_log_level + estimator._hyperparameters[ + sagemaker.model.CLOUDWATCH_METRICS_PARAM_NAME + ] = estimator.enable_cloudwatch_metrics + estimator._hyperparameters[ + sagemaker.model.CONTAINER_LOG_LEVEL_PARAM_NAME + ] = estimator.container_log_level estimator._hyperparameters[sagemaker.model.JOB_NAME_PARAM_NAME] = estimator._current_job_name - estimator._hyperparameters[sagemaker.model.SAGEMAKER_REGION_PARAM_NAME] = \ - estimator.sagemaker_session.boto_region_name + estimator._hyperparameters[ + sagemaker.model.SAGEMAKER_REGION_PARAM_NAME + ] = estimator.sagemaker_session.boto_region_name def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None): @@ -73,13 +74,13 @@ def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None): """ if isinstance(inputs, list): for record in inputs: - if isinstance(record, amazon_estimator.RecordSet) and record.channel == 'train': + if isinstance(record, amazon_estimator.RecordSet) and record.channel == "train": estimator.feature_dim = record.feature_dim break elif isinstance(inputs, amazon_estimator.RecordSet): estimator.feature_dim = inputs.feature_dim else: - raise TypeError('Training data must be represented in RecordSet or list of RecordSets') + raise TypeError("Training data must be represented in RecordSet or list of RecordSets") estimator.mini_batch_size = mini_batch_size @@ -124,7 +125,7 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size= estimator._current_job_name = 
utils.name_from_base(base_name) if estimator.output_path is None: - estimator.output_path = 's3://{}/'.format(default_bucket) + estimator.output_path = "s3://{}/".format(default_bucket) if isinstance(estimator, sagemaker.estimator.Framework): prepare_framework(estimator, s3_operations) @@ -134,30 +135,30 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size= job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False) train_config = { - 'AlgorithmSpecification': { - 'TrainingImage': estimator.train_image(), - 'TrainingInputMode': estimator.input_mode + "AlgorithmSpecification": { + "TrainingImage": estimator.train_image(), + "TrainingInputMode": estimator.input_mode, }, - 'OutputDataConfig': job_config['output_config'], - 'StoppingCondition': job_config['stop_condition'], - 'ResourceConfig': job_config['resource_config'], - 'RoleArn': job_config['role'], + "OutputDataConfig": job_config["output_config"], + "StoppingCondition": job_config["stop_condition"], + "ResourceConfig": job_config["resource_config"], + "RoleArn": job_config["role"], } - if job_config['input_config'] is not None: - train_config['InputDataConfig'] = job_config['input_config'] + if job_config["input_config"] is not None: + train_config["InputDataConfig"] = job_config["input_config"] - if job_config['vpc_config'] is not None: - train_config['VpcConfig'] = job_config['vpc_config'] + if job_config["vpc_config"] is not None: + train_config["VpcConfig"] = job_config["vpc_config"] if estimator.hyperparameters() is not None: hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()} if hyperparameters and len(hyperparameters) > 0: - train_config['HyperParameters'] = hyperparameters + train_config["HyperParameters"] = hyperparameters if s3_operations: - train_config['S3Operations'] = s3_operations + train_config["S3Operations"] = s3_operations return train_config @@ -196,10 +197,10 @@ def training_config(estimator, inputs=None, job_name=None, mini_batch_size=None) train_config = training_base_config(estimator, inputs, job_name, mini_batch_size) - train_config['TrainingJobName'] = estimator._current_job_name + train_config["TrainingJobName"] = estimator._current_job_name if estimator.tags is not None: - train_config['Tags'] = estimator.tags + train_config["Tags"] = estimator.tags return train_config @@ -232,50 +233,56 @@ def tuning_config(tuner, inputs, job_name=None): dict: Tuning config that can be directly used by SageMakerTuningOperator in Airflow. 
""" train_config = training_base_config(tuner.estimator, inputs) - hyperparameters = train_config.pop('HyperParameters', None) - s3_operations = train_config.pop('S3Operations', None) + hyperparameters = train_config.pop("HyperParameters", None) + s3_operations = train_config.pop("S3Operations", None) if hyperparameters and len(hyperparameters) > 0: - tuner.static_hyperparameters = \ - {utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items()} + tuner.static_hyperparameters = { + utils.to_str(k): utils.to_str(v) for (k, v) in hyperparameters.items() + } if job_name is not None: tuner._current_job_name = job_name else: - base_name = tuner.base_tuning_job_name or utils.base_name_from_image(tuner.estimator.train_image()) - tuner._current_job_name = utils.name_from_base(base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True) + base_name = tuner.base_tuning_job_name or utils.base_name_from_image( + tuner.estimator.train_image() + ) + tuner._current_job_name = utils.name_from_base( + base_name, tuner.TUNING_JOB_NAME_MAX_LENGTH, True + ) for hyperparameter_name in tuner._hyperparameter_ranges.keys(): tuner.static_hyperparameters.pop(hyperparameter_name, None) - train_config['StaticHyperParameters'] = tuner.static_hyperparameters + train_config["StaticHyperParameters"] = tuner.static_hyperparameters tune_config = { - 'HyperParameterTuningJobName': tuner._current_job_name, - 'HyperParameterTuningJobConfig': { - 'Strategy': tuner.strategy, - 'HyperParameterTuningJobObjective': { - 'Type': tuner.objective_type, - 'MetricName': tuner.objective_metric_name, + "HyperParameterTuningJobName": tuner._current_job_name, + "HyperParameterTuningJobConfig": { + "Strategy": tuner.strategy, + "HyperParameterTuningJobObjective": { + "Type": tuner.objective_type, + "MetricName": tuner.objective_metric_name, }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': tuner.max_jobs, - 'MaxParallelTrainingJobs': tuner.max_parallel_jobs, + "ResourceLimits": { + "MaxNumberOfTrainingJobs": tuner.max_jobs, + "MaxParallelTrainingJobs": tuner.max_parallel_jobs, }, - 'ParameterRanges': tuner.hyperparameter_ranges(), + "ParameterRanges": tuner.hyperparameter_ranges(), }, - 'TrainingJobDefinition': train_config + "TrainingJobDefinition": train_config, } if tuner.metric_definitions is not None: - tune_config['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = \ - tuner.metric_definitions + tune_config["TrainingJobDefinition"]["AlgorithmSpecification"][ + "MetricDefinitions" + ] = tuner.metric_definitions if tuner.tags is not None: - tune_config['Tags'] = tuner.tags + tune_config["Tags"] = tuner.tags if s3_operations is not None: - tune_config['S3Operations'] = s3_operations + tune_config["S3Operations"] = s3_operations return tune_config @@ -293,7 +300,7 @@ def update_submit_s3_uri(estimator, job_name): if estimator.uploaded_code is None: return - pattern = r'(?<=/)[^/]+?(?=/source/sourcedir.tar.gz)' + pattern = r"(?<=/)[^/]+?(?=/source/sourcedir.tar.gz)" # update the S3 URI with the latest training job. 
# s3://path/old_job/source/sourcedir.tar.gz will become s3://path/new_job/source/sourcedir.tar.gz @@ -315,14 +322,19 @@ def update_estimator_from_task(estimator, task_id, task_type): """ if task_type is None: return - if task_type.lower() == 'training': + if task_type.lower() == "training": training_job = "{{ ti.xcom_pull(task_ids='%s')['Training']['TrainingJobName'] }}" % task_id job_name = training_job - elif task_type.lower() == 'tuning': - training_job = "{{ ti.xcom_pull(task_ids='%s')['Tuning']['BestTrainingJob']['TrainingJobName'] }}" % task_id + elif task_type.lower() == "tuning": + training_job = ( + "{{ ti.xcom_pull(task_ids='%s')['Tuning']['BestTrainingJob']['TrainingJobName'] }}" + % task_id + ) # need to strip the double quotes in json to get the string - job_name = "{{ ti.xcom_pull(task_ids='%s')['Tuning']['TrainingJobDefinition']['StaticHyperParameters']" \ - "['sagemaker_job_name'].strip('%s') }}" % (task_id, '"') + job_name = ( + "{{ ti.xcom_pull(task_ids='%s')['Tuning']['TrainingJobDefinition']['StaticHyperParameters']" + "['sagemaker_job_name'].strip('%s') }}" % (task_id, '"') + ) else: raise ValueError("task_type must be either 'training', 'tuning' or None.") estimator._current_job_name = training_job @@ -346,34 +358,38 @@ def prepare_framework_container_def(model, instance_type, s3_operations): if not deploy_image: region_name = model.sagemaker_session.boto_session.region_name deploy_image = fw_utils.create_image_uri( - region_name, model.__framework_name__, instance_type, model.framework_version, model.py_version) + region_name, + model.__framework_name__, + instance_type, + model.framework_version, + model.py_version, + ) base_name = utils.base_name_from_image(deploy_image) model.name = model.name or utils.name_from_base(base_name) bucket = model.bucket or model.sagemaker_session._default_bucket script = os.path.basename(model.entry_point) - key = '{}/source/sourcedir.tar.gz'.format(model.name) + key = "{}/source/sourcedir.tar.gz".format(model.name) - if model.source_dir and model.source_dir.lower().startswith('s3://'): + if model.source_dir and model.source_dir.lower().startswith("s3://"): code_dir = model.source_dir model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) else: - code_dir = 's3://{}/{}'.format(bucket, key) + code_dir = "s3://{}/{}".format(bucket, key) model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) - s3_operations['S3Upload'] = [{ - 'Path': model.source_dir or script, - 'Bucket': bucket, - 'Key': key, - 'Tar': True - }] + s3_operations["S3Upload"] = [ + {"Path": model.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True} + ] deploy_env = dict(model.env) deploy_env.update(model._framework_env_vars()) try: if model.model_server_workers: - deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(model.model_server_workers) + deploy_env[sagemaker.model.MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str( + model.model_server_workers + ) except AttributeError: # This applies to a FrameworkModel which is not SageMaker Deep Learning Framework Model pass @@ -401,28 +417,37 @@ def model_config(instance_type, model, role=None, image=None): container_def = prepare_framework_container_def(model, instance_type, s3_operations) else: container_def = model.prepare_container_def(instance_type) - base_name = utils.base_name_from_image(container_def['Image']) + base_name = utils.base_name_from_image(container_def["Image"]) model.name = model.name or utils.name_from_base(base_name) 
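    # session._expand_container_def normalizes container_def before it is embedded
    # in the request: a bare image-URI string is expanded into a {"Image": image_uri}
    # dict, while an already-complete definition passes through unchanged, so
    # "PrimaryContainer" below is always a full container definition.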
primary_container = session._expand_container_def(container_def) config = { - 'ModelName': model.name, - 'PrimaryContainer': primary_container, - 'ExecutionRoleArn': role or model.role + "ModelName": model.name, + "PrimaryContainer": primary_container, + "ExecutionRoleArn": role or model.role, } if model.vpc_config: - config['VpcConfig'] = model.vpc_config + config["VpcConfig"] = model.vpc_config if s3_operations: - config['S3Operations'] = s3_operations + config["S3Operations"] = s3_operations return config -def model_config_from_estimator(instance_type, estimator, task_id, task_type, role=None, image=None, name=None, - model_server_workers=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT): +def model_config_from_estimator( + instance_type, + estimator, + task_id, + task_type, + role=None, + image=None, + name=None, + model_server_workers=None, + vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, +): """Export Airflow model config from a SageMaker estimator Args: @@ -451,22 +476,36 @@ def model_config_from_estimator(instance_type, estimator, task_id, task_type, ro """ update_estimator_from_task(estimator, task_id, task_type) if isinstance(estimator, sagemaker.estimator.Estimator): - model = estimator.create_model(role=role, image=image, vpc_config_override=vpc_config_override) + model = estimator.create_model( + role=role, image=image, vpc_config_override=vpc_config_override + ) elif isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase): model = estimator.create_model(vpc_config_override=vpc_config_override) elif isinstance(estimator, sagemaker.estimator.Framework): - model = estimator.create_model(model_server_workers=model_server_workers, role=role, - vpc_config_override=vpc_config_override) + model = estimator.create_model( + model_server_workers=model_server_workers, + role=role, + vpc_config_override=vpc_config_override, + ) else: - raise TypeError('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework' - ' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.') + raise TypeError( + "Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework" + " or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase." 
+ ) model.name = name return model_config(instance_type, model, role, image) -def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None, - split_type=None, job_name=None): +def transform_config( + transformer, + data, + data_type="S3Prefix", + content_type=None, + compression_type=None, + split_type=None, + job_name=None, +): """Export Airflow transform config from a SageMaker transformer Args: @@ -494,48 +533,73 @@ def transform_config(transformer, data, data_type='S3Prefix', content_type=None, transformer._current_job_name = job_name else: base_name = transformer.base_transform_job_name - transformer._current_job_name = utils.name_from_base(base_name) \ - if base_name is not None else transformer.model_name + transformer._current_job_name = ( + utils.name_from_base(base_name) if base_name is not None else transformer.model_name + ) if transformer.output_path is None: - transformer.output_path = 's3://{}/{}'.format( - transformer.sagemaker_session.default_bucket(), transformer._current_job_name) + transformer.output_path = "s3://{}/{}".format( + transformer.sagemaker_session.default_bucket(), transformer._current_job_name + ) job_config = sagemaker.transformer._TransformJob._load_config( - data, data_type, content_type, compression_type, split_type, transformer) + data, data_type, content_type, compression_type, split_type, transformer + ) config = { - 'TransformJobName': transformer._current_job_name, - 'ModelName': transformer.model_name, - 'TransformInput': job_config['input_config'], - 'TransformOutput': job_config['output_config'], - 'TransformResources': job_config['resource_config'], + "TransformJobName": transformer._current_job_name, + "ModelName": transformer.model_name, + "TransformInput": job_config["input_config"], + "TransformOutput": job_config["output_config"], + "TransformResources": job_config["resource_config"], } if transformer.strategy is not None: - config['BatchStrategy'] = transformer.strategy + config["BatchStrategy"] = transformer.strategy if transformer.max_concurrent_transforms is not None: - config['MaxConcurrentTransforms'] = transformer.max_concurrent_transforms + config["MaxConcurrentTransforms"] = transformer.max_concurrent_transforms if transformer.max_payload is not None: - config['MaxPayloadInMB'] = transformer.max_payload + config["MaxPayloadInMB"] = transformer.max_payload if transformer.env is not None: - config['Environment'] = transformer.env + config["Environment"] = transformer.env if transformer.tags is not None: - config['Tags'] = transformer.tags + config["Tags"] = transformer.tags return config -def transform_config_from_estimator(estimator, task_id, task_type, instance_count, instance_type, data, - data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, - job_name=None, model_name=None, strategy=None, assemble_with=None, output_path=None, - output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None, volume_kms_key=None, - model_server_workers=None, image=None, vpc_config_override=None): +def transform_config_from_estimator( + estimator, + task_id, + task_type, + instance_count, + instance_type, + data, + data_type="S3Prefix", + content_type=None, + compression_type=None, + split_type=None, + job_name=None, + model_name=None, + strategy=None, + assemble_with=None, + output_path=None, + output_kms_key=None, + accept=None, + env=None, + max_concurrent_transforms=None, + max_payload=None, + tags=None, + role=None, + 
volume_kms_key=None, + model_server_workers=None, + image=None, + vpc_config_override=None, +): """Export Airflow transform config from a SageMaker estimator Args: @@ -591,28 +655,58 @@ def transform_config_from_estimator(estimator, task_id, task_type, instance_coun Returns: dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow. """ - model_base_config = model_config_from_estimator(instance_type=instance_type, estimator=estimator, task_id=task_id, - task_type=task_type, role=role, image=image, name=model_name, - model_server_workers=model_server_workers, - vpc_config_override=vpc_config_override) + model_base_config = model_config_from_estimator( + instance_type=instance_type, + estimator=estimator, + task_id=task_id, + task_type=task_type, + role=role, + image=image, + name=model_name, + model_server_workers=model_server_workers, + vpc_config_override=vpc_config_override, + ) if isinstance(estimator, sagemaker.estimator.Framework): - transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, - output_kms_key, accept, env, max_concurrent_transforms, - max_payload, tags, role, model_server_workers, volume_kms_key) + transformer = estimator.transformer( + instance_count, + instance_type, + strategy, + assemble_with, + output_path, + output_kms_key, + accept, + env, + max_concurrent_transforms, + max_payload, + tags, + role, + model_server_workers, + volume_kms_key, + ) else: - transformer = estimator.transformer(instance_count, instance_type, strategy, assemble_with, output_path, - output_kms_key, accept, env, max_concurrent_transforms, - max_payload, tags, role, volume_kms_key) - transformer.model_name = model_base_config['ModelName'] - - transform_base_config = transform_config(transformer, data, data_type, content_type, compression_type, - split_type, job_name) - - config = { - 'Model': model_base_config, - 'Transform': transform_base_config - } + transformer = estimator.transformer( + instance_count, + instance_type, + strategy, + assemble_with, + output_path, + output_kms_key, + accept, + env, + max_concurrent_transforms, + max_payload, + tags, + role, + volume_kms_key, + ) + transformer.model_name = model_base_config["ModelName"] + + transform_base_config = transform_config( + transformer, data, data_type, content_type, compression_type, split_type, job_name + ) + + config = {"Model": model_base_config, "Transform": transform_base_config} return config @@ -636,34 +730,42 @@ def deploy_config(model, initial_instance_count, instance_type, endpoint_name=No """ model_base_config = model_config(instance_type, model) - production_variant = sagemaker.production_variant(model.name, instance_type, initial_instance_count) + production_variant = sagemaker.production_variant( + model.name, instance_type, initial_instance_count + ) name = model.name - config_options = {'EndpointConfigName': name, 'ProductionVariants': [production_variant]} + config_options = {"EndpointConfigName": name, "ProductionVariants": [production_variant]} if tags is not None: - config_options['Tags'] = tags + config_options["Tags"] = tags endpoint_name = endpoint_name or name - endpoint_base_config = { - 'EndpointName': endpoint_name, - 'EndpointConfigName': name - } + endpoint_base_config = {"EndpointName": endpoint_name, "EndpointConfigName": name} config = { - 'Model': model_base_config, - 'EndpointConfig': config_options, - 'Endpoint': endpoint_base_config + "Model": model_base_config, + "EndpointConfig": config_options, + "Endpoint": 
endpoint_base_config, } # if there is s3 operations needed for model, move it to root level of config - s3_operations = model_base_config.pop('S3Operations', None) + s3_operations = model_base_config.pop("S3Operations", None) if s3_operations is not None: - config['S3Operations'] = s3_operations + config["S3Operations"] = s3_operations return config -def deploy_config_from_estimator(estimator, task_id, task_type, initial_instance_count, instance_type, - model_name=None, endpoint_name=None, tags=None, **kwargs): +def deploy_config_from_estimator( + estimator, + task_id, + task_type, + initial_instance_count, + instance_type, + model_name=None, + endpoint_name=None, + tags=None, + **kwargs +): """Export Airflow deploy config from a SageMaker estimator Args: diff --git a/tests/component/test_mxnet_estimator.py b/tests/component/test_mxnet_estimator.py index 5316d56130..7d8ea7afb0 100644 --- a/tests/component/test_mxnet_estimator.py +++ b/tests/component/test_mxnet_estimator.py @@ -17,56 +17,71 @@ from sagemaker.mxnet import MXNet -SCRIPT = 'resnet_cifar_10.py' -TIMESTAMP = '2017-11-06-14:14:15.673' +SCRIPT = "resnet_cifar_10.py" +TIMESTAMP = "2017-11-06-14:14:15.673" TIME = 1510006209.073025 -BUCKET_NAME = 'mybucket' +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE_GPU = 'ml.p2.xlarge' -INSTANCE_TYPE_CPU = 'ml.m4.xlarge' -CPU_IMAGE_NAME = 'sagemaker-mxnet-py2-cpu' -GPU_IMAGE_NAME = 'sagemaker-mxnet-py2-gpu' -REGION = 'us-west-2' +INSTANCE_TYPE_GPU = "ml.p2.xlarge" +INSTANCE_TYPE_CPU = "ml.m4.xlarge" +CPU_IMAGE_NAME = "sagemaker-mxnet-py2-cpu" +GPU_IMAGE_NAME = "sagemaker-mxnet-py2-gpu" +REGION = "us-west-2" IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" -REGION = 'us-west-2' -ROLE = 'SagemakerRole' -SOURCE_DIR = 's3://fefergerger' +REGION = "us-west-2" +ROLE = "SagemakerRole" +SOURCE_DIR = "s3://fefergerger" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - ims = Mock(name='sagemaker_session', boto_session=boto_mock, - config=None, local_mode=False, region_name=REGION) + boto_mock = Mock(name="boto_session", region_name=REGION) + ims = Mock( + name="sagemaker_session", + boto_session=boto_mock, + config=None, + local_mode=False, + region_name=REGION, + ) - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.expand_role = Mock(name="expand_role", return_value=ROLE) - ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts': - {'S3ModelArtifacts': 's3://m/m.tar.gz'}}) + ims.sagemaker_client.describe_training_job = Mock( + return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + ) return ims # Test that we pass all necessary fields from estimator to the session when we call deploy def test_deploy(sagemaker_session, tf_version): - estimator = MXNet(entry_point=SCRIPT, source_dir=SOURCE_DIR, role=ROLE, - framework_version=tf_version, - train_instance_count=2, train_instance_type=INSTANCE_TYPE_GPU, - sagemaker_session=sagemaker_session, - base_job_name='test-cifar') + estimator = MXNet( + entry_point=SCRIPT, + source_dir=SOURCE_DIR, + role=ROLE, + framework_version=tf_version, + train_instance_count=2, + train_instance_type=INSTANCE_TYPE_GPU, + sagemaker_session=sagemaker_session, + base_job_name="test-cifar", + ) - estimator.fit('s3://mybucket/train') - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + 
estimator.fit("s3://mybucket/train") + print("job succeeded: {}".format(estimator.latest_training_job.name)) estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE_CPU) - image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, 'cpu', 'py2') + image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, "cpu", "py2") sagemaker_session.create_model.assert_called_with( estimator._current_job_name, ROLE, - {'Environment': - {'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_SUBMIT_DIRECTORY': SOURCE_DIR, - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_PROGRAM': SCRIPT}, - 'Image': image, - 'ModelDataUrl': 's3://m/m.tar.gz'}) + { + "Environment": { + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR, + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_PROGRAM": SCRIPT, + }, + "Image": image, + "ModelDataUrl": "s3://m/m.tar.gz", + }, + ) diff --git a/tests/component/test_tf_estimator.py b/tests/component/test_tf_estimator.py index 8d7ede0429..b2e4b90fd7 100644 --- a/tests/component/test_tf_estimator.py +++ b/tests/component/test_tf_estimator.py @@ -17,56 +17,71 @@ from sagemaker.tensorflow import TensorFlow -SCRIPT = 'resnet_cifar_10.py' -TIMESTAMP = '2017-11-06-14:14:15.673' +SCRIPT = "resnet_cifar_10.py" +TIMESTAMP = "2017-11-06-14:14:15.673" TIME = 1510006209.073025 -BUCKET_NAME = 'mybucket' +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE_GPU = 'ml.p2.xlarge' -INSTANCE_TYPE_CPU = 'ml.m4.xlarge' -CPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-cpu' -GPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-gpu' -REGION = 'us-west-2' +INSTANCE_TYPE_GPU = "ml.p2.xlarge" +INSTANCE_TYPE_CPU = "ml.m4.xlarge" +CPU_IMAGE_NAME = "sagemaker-tensorflow-py2-cpu" +GPU_IMAGE_NAME = "sagemaker-tensorflow-py2-gpu" +REGION = "us-west-2" IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" -REGION = 'us-west-2' -ROLE = 'SagemakerRole' -SOURCE_DIR = 's3://fefergerger' +REGION = "us-west-2" +ROLE = "SagemakerRole" +SOURCE_DIR = "s3://fefergerger" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - ims = Mock(name='sagemaker_session', boto_session=boto_mock, config=None, - local_mode=False, region_name=REGION) - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + ims = Mock( + name="sagemaker_session", + boto_session=boto_mock, + config=None, + local_mode=False, + region_name=REGION, + ) + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.expand_role = Mock(name="expand_role", return_value=ROLE) - ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts': - {'S3ModelArtifacts': 's3://m/m.tar.gz'}}) + ims.sagemaker_client.describe_training_job = Mock( + return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + ) return ims # Test that we pass all necessary fields from estimator to the session when we call deploy def test_deploy(sagemaker_session, tf_version): - estimator = TensorFlow(entry_point=SCRIPT, source_dir=SOURCE_DIR, role=ROLE, - framework_version=tf_version, - train_instance_count=2, train_instance_type=INSTANCE_TYPE_CPU, - sagemaker_session=sagemaker_session, - base_job_name='test-cifar') + estimator = TensorFlow( + entry_point=SCRIPT, + source_dir=SOURCE_DIR, + role=ROLE, + framework_version=tf_version, + 
train_instance_count=2, + train_instance_type=INSTANCE_TYPE_CPU, + sagemaker_session=sagemaker_session, + base_job_name="test-cifar", + ) - estimator.fit('s3://mybucket/train') - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + estimator.fit("s3://mybucket/train") + print("job succeeded: {}".format(estimator.latest_training_job.name)) estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE_CPU) - image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, 'cpu', 'py2') + image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, "cpu", "py2") sagemaker_session.create_model.assert_called_with( estimator._current_job_name, ROLE, - {'Environment': - {'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_SUBMIT_DIRECTORY': SOURCE_DIR, - 'SAGEMAKER_REQUIREMENTS': '', - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_PROGRAM': SCRIPT}, - 'Image': image, - 'ModelDataUrl': 's3://m/m.tar.gz'}) + { + "Environment": { + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR, + "SAGEMAKER_REQUIREMENTS": "", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_PROGRAM": SCRIPT, + }, + "Image": image, + "ModelDataUrl": "s3://m/m.tar.gz", + }, + ) diff --git a/tests/conftest.py b/tests/conftest.py index a110adbd29..00033e6fd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,72 +28,85 @@ from sagemaker.sklearn.defaults import SKLEARN_VERSION from sagemaker.tensorflow.estimator import TensorFlow -DEFAULT_REGION = 'us-west-2' +DEFAULT_REGION = "us-west-2" def pytest_addoption(parser): - parser.addoption('--sagemaker-client-config', action='store', default=None) - parser.addoption('--sagemaker-runtime-config', action='store', default=None) - parser.addoption('--boto-config', action='store', default=None) - parser.addoption('--chainer-full-version', action='store', default=Chainer.LATEST_VERSION) - parser.addoption('--mxnet-full-version', action='store', default=MXNet.LATEST_VERSION) - parser.addoption('--ei-mxnet-full-version', action='store', default=MXNet.LATEST_VERSION) - parser.addoption('--pytorch-full-version', action='store', default=PyTorch.LATEST_VERSION) - parser.addoption('--rl-coach-mxnet-full-version', action='store', - default=RLEstimator.COACH_LATEST_VERSION_MXNET) - parser.addoption('--rl-coach-tf-full-version', action='store', - default=RLEstimator.COACH_LATEST_VERSION_TF) - parser.addoption('--rl-ray-full-version', action='store', - default=RLEstimator.RAY_LATEST_VERSION) - parser.addoption('--sklearn-full-version', action='store', default=SKLEARN_VERSION) - parser.addoption('--tf-full-version', action='store', default=TensorFlow.LATEST_VERSION) - parser.addoption('--ei-tf-full-version', action='store', default=TensorFlow.LATEST_VERSION) + parser.addoption("--sagemaker-client-config", action="store", default=None) + parser.addoption("--sagemaker-runtime-config", action="store", default=None) + parser.addoption("--boto-config", action="store", default=None) + parser.addoption("--chainer-full-version", action="store", default=Chainer.LATEST_VERSION) + parser.addoption("--mxnet-full-version", action="store", default=MXNet.LATEST_VERSION) + parser.addoption("--ei-mxnet-full-version", action="store", default=MXNet.LATEST_VERSION) + parser.addoption("--pytorch-full-version", action="store", default=PyTorch.LATEST_VERSION) + parser.addoption( + "--rl-coach-mxnet-full-version", + action="store", + 
default=RLEstimator.COACH_LATEST_VERSION_MXNET, + ) + parser.addoption( + "--rl-coach-tf-full-version", action="store", default=RLEstimator.COACH_LATEST_VERSION_TF + ) + parser.addoption( + "--rl-ray-full-version", action="store", default=RLEstimator.RAY_LATEST_VERSION + ) + parser.addoption("--sklearn-full-version", action="store", default=SKLEARN_VERSION) + parser.addoption("--tf-full-version", action="store", default=TensorFlow.LATEST_VERSION) + parser.addoption("--ei-tf-full-version", action="store", default=TensorFlow.LATEST_VERSION) def pytest_configure(config): - bc = config.getoption('--boto-config') + bc = config.getoption("--boto-config") parsed = json.loads(bc) if bc else {} - region = parsed.get('region_name', boto3.session.Session().region_name) + region = parsed.get("region_name", boto3.session.Session().region_name) if region: - os.environ['TEST_AWS_REGION_NAME'] = region + os.environ["TEST_AWS_REGION_NAME"] = region -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sagemaker_client_config(request): - config = request.config.getoption('--sagemaker-client-config') + config = request.config.getoption("--sagemaker-client-config") return json.loads(config) if config else dict() -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sagemaker_runtime_config(request): - config = request.config.getoption('--sagemaker-runtime-config') + config = request.config.getoption("--sagemaker-runtime-config") return json.loads(config) if config else None -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def boto_config(request): - config = request.config.getoption('--boto-config') + config = request.config.getoption("--boto-config") return json.loads(config) if config else None -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_config): - boto_session = boto3.Session(**boto_config) if boto_config else boto3.Session( - region_name=DEFAULT_REGION) - sagemaker_client_config.setdefault('config', Config(retries=dict(max_attempts=10))) - sagemaker_client = boto_session.client('sagemaker', - **sagemaker_client_config) if sagemaker_client_config else None - runtime_client = (boto_session.client('sagemaker-runtime', - **sagemaker_runtime_config) if sagemaker_runtime_config - else None) - - return Session(boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=runtime_client) - - -@pytest.fixture(scope='session') + boto_session = ( + boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION) + ) + sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) + sagemaker_client = ( + boto_session.client("sagemaker", **sagemaker_client_config) + if sagemaker_client_config + else None + ) + runtime_client = ( + boto_session.client("sagemaker-runtime", **sagemaker_runtime_config) + if sagemaker_runtime_config + else None + ) + + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=runtime_client, + ) + + +@pytest.fixture(scope="session") def sagemaker_local_session(boto_config): if boto_config: boto_session = boto3.Session(**boto_config) @@ -102,94 +115,129 @@ def sagemaker_local_session(boto_config): return LocalSession(boto_session=boto_session) -@pytest.fixture(scope='module', params=['4.0', '4.0.0', '4.1', '4.1.0', '5.0', '5.0.0']) +@pytest.fixture(scope="module", params=["4.0", "4.0.0", "4.1", "4.1.0", "5.0", 
"5.0.0"]) def chainer_version(request): return request.param -@pytest.fixture(scope='module', params=['0.12', '0.12.1', '1.0', '1.0.0', '1.1', '1.1.0', '1.2', - '1.2.1', '1.3', '1.3.0', '1.4', '1.4.0']) +@pytest.fixture( + scope="module", + params=[ + "0.12", + "0.12.1", + "1.0", + "1.0.0", + "1.1", + "1.1.0", + "1.2", + "1.2.1", + "1.3", + "1.3.0", + "1.4", + "1.4.0", + ], +) def mxnet_version(request): return request.param -@pytest.fixture(scope='module', params=['0.4', '0.4.0', '1.0', '1.0.0']) +@pytest.fixture(scope="module", params=["0.4", "0.4.0", "1.0", "1.0.0"]) def pytorch_version(request): return request.param -@pytest.fixture(scope='module', params=['0.20.0']) +@pytest.fixture(scope="module", params=["0.20.0"]) def sklearn_version(request): return request.param -@pytest.fixture(scope='module', params=['1.4', '1.4.1', '1.5', '1.5.0', '1.6', '1.6.0', - '1.7', '1.7.0', '1.8', '1.8.0', '1.9', '1.9.0', - '1.10', '1.10.0', '1.11', '1.11.0', '1.12', '1.12.0']) +@pytest.fixture( + scope="module", + params=[ + "1.4", + "1.4.1", + "1.5", + "1.5.0", + "1.6", + "1.6.0", + "1.7", + "1.7.0", + "1.8", + "1.8.0", + "1.9", + "1.9.0", + "1.10", + "1.10.0", + "1.11", + "1.11.0", + "1.12", + "1.12.0", + ], +) def tf_version(request): return request.param -@pytest.fixture(scope='module', params=['0.10.1', '0.10.1', '0.11', '0.11.0', '0.11.1']) +@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"]) def rl_coach_tf_version(request): return request.param -@pytest.fixture(scope='module', params=['0.11', '0.11.0']) +@pytest.fixture(scope="module", params=["0.11", "0.11.0"]) def rl_coach_mxnet_version(request): return request.param -@pytest.fixture(scope='module', params=['0.5', '0.5.3', '0.6', '0.6.5']) +@pytest.fixture(scope="module", params=["0.5", "0.5.3", "0.6", "0.6.5"]) def rl_ray_version(request): return request.param -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def chainer_full_version(request): - return request.config.getoption('--chainer-full-version') + return request.config.getoption("--chainer-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mxnet_full_version(request): - return request.config.getoption('--mxnet-full-version') + return request.config.getoption("--mxnet-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ei_mxnet_full_version(request): - return request.config.getoption('--ei-mxnet-full-version') + return request.config.getoption("--ei-mxnet-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def pytorch_full_version(request): - return request.config.getoption('--pytorch-full-version') + return request.config.getoption("--pytorch-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def rl_coach_mxnet_full_version(request): - return request.config.getoption('--rl-coach-mxnet-full-version') + return request.config.getoption("--rl-coach-mxnet-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def rl_coach_tf_full_version(request): - return request.config.getoption('--rl-coach-tf-full-version') + return request.config.getoption("--rl-coach-tf-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def rl_ray_full_version(request): - return request.config.getoption('--rl-ray-full-version') + return request.config.getoption("--rl-ray-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def sklearn_full_version(request): - 
return request.config.getoption('--sklearn-full-version') + return request.config.getoption("--sklearn-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tf_full_version(request): - return request.config.getoption('--tf-full-version') + return request.config.getoption("--tf-full-version") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def ei_tf_full_version(request): - return request.config.getoption('--ei-tf-full-version') + return request.config.getoption("--ei-tf-full-version") diff --git a/tests/data/chainer_mnist/distributed_mnist.py b/tests/data/chainer_mnist/distributed_mnist.py index e43b613725..7507fd1ab3 100644 --- a/tests/data/chainer_mnist/distributed_mnist.py +++ b/tests/data/chainer_mnist/distributed_mnist.py @@ -46,7 +46,7 @@ def __call__(self, x): def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format): - images = raw['x'] + images = raw["x"] if ndim == 2: images = images.reshape(-1, 28, 28) elif ndim == 3: @@ -54,74 +54,73 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb if rgb_format: images = np.broadcast_to(images, (len(images), 3) + images.shape[2:]) elif ndim != 1: - raise ValueError('invalid ndim for MNIST dataset') + raise ValueError("invalid ndim for MNIST dataset") images = images.astype(image_dtype) - images *= scale / 255. + images *= scale / 255.0 if withlabel: - labels = raw['y'].astype(label_dtype) + labels = raw["y"].astype(label_dtype) return tuple_dataset.TupleDataset(images, labels) return images -if __name__ == '__main__': +if __name__ == "__main__": env = sagemaker_containers.training_env() parser = argparse.ArgumentParser() # Data and model checkpoints directories - parser.add_argument('--epochs', type=int, default=1) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--communicator', type=str, default='pure_nccl') - parser.add_argument('--frequency', type=int, default=20) - parser.add_argument('--units', type=int, default=1000) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--communicator", type=str, default="pure_nccl") + parser.add_argument("--frequency", type=int, default=20) + parser.add_argument("--units", type=int, default=1000) - parser.add_argument('--model-dir', type=str) - parser.add_argument('--output-data-dir', type=str, default=env.output_data_dir) - parser.add_argument('--host', type=str, default=env.current_host) - parser.add_argument('--num-gpus', type=int, default=env.num_gpus) + parser.add_argument("--model-dir", type=str) + parser.add_argument("--output-data-dir", type=str, default=env.output_data_dir) + parser.add_argument("--host", type=str, default=env.current_host) + parser.add_argument("--num-gpus", type=int, default=env.num_gpus) - parser.add_argument('--train', type=str, default=env.channel_input_dirs['train']) - parser.add_argument('--test', type=str, default=env.channel_input_dirs['test']) + parser.add_argument("--train", type=str, default=env.channel_input_dirs["train"]) + parser.add_argument("--test", type=str, default=env.channel_input_dirs["test"]) args = parser.parse_args() - train_file = np.load(os.path.join(args.train, 'train.npz')) - test_file = np.load(os.path.join(args.test, 'test.npz')) + train_file = np.load(os.path.join(args.train, "train.npz")) + test_file = np.load(os.path.join(args.test, "test.npz")) - logger.info('Current host: {}'.format(args.host)) + 
logger.info("Current host: {}".format(args.host)) - communicator = 'naive' if args.num_gpus == 0 else args.communicator + communicator = "naive" if args.num_gpus == 0 else args.communicator comm = chainermn.create_communicator(communicator) device = comm.intra_rank if args.num_gpus > 0 else -1 - print('==========================================') - print('Using {} communicator'.format(comm)) - print('Num unit: {}'.format(args.units)) - print('Num Minibatch-size: {}'.format(args.batch_size)) - print('Num epoch: {}'.format(args.epochs)) - print('==========================================') + print("==========================================") + print("Using {} communicator".format(comm)) + print("Num unit: {}".format(args.units)) + print("Num Minibatch-size: {}".format(args.batch_size)) + print("Num epoch: {}".format(args.epochs)) + print("==========================================") model = L.Classifier(MLP(args.units, 10)) if device >= 0: chainer.cuda.get_device(device).use() # Create a multi node optimizer from a standard Chainer optimizer. - optimizer = chainermn.create_multi_node_optimizer( - chainer.optimizers.Adam(), comm) + optimizer = chainermn.create_multi_node_optimizer(chainer.optimizers.Adam(), comm) optimizer.setup(model) - train_file = np.load(os.path.join(args.train, 'train.npz')) - test_file = np.load(os.path.join(args.test, 'test.npz')) + train_file = np.load(os.path.join(args.train, "train.npz")) + test_file = np.load(os.path.join(args.test, "test.npz")) preprocess_mnist_options = { - 'withlabel': True, - 'ndim': 1, - 'scale': 1., - 'image_dtype': np.float32, - 'label_dtype': np.int32, - 'rgb_format': False + "withlabel": True, + "ndim": 1, + "scale": 1.0, + "image_dtype": np.float32, + "label_dtype": np.int32, + "rgb_format": False, } train_dataset = _preprocess_mnist(train_file, **preprocess_mnist_options) @@ -129,10 +128,11 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb train_iter = chainer.iterators.SerialIterator(train_dataset, args.batch_size) test_iter = chainer.iterators.SerialIterator( - test_dataset, args.batch_size, repeat=False, shuffle=False) + test_dataset, args.batch_size, repeat=False, shuffle=False + ) updater = training.StandardUpdater(train_iter, optimizer, device=device) - trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.output_data_dir) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.output_data_dir) # Create a multi node evaluator from a standard Chainer evaluator. 
evaluator = extensions.Evaluator(test_iter, model, device=device) @@ -145,32 +145,39 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb if extensions.PlotReport.available(): trainer.extend( extensions.PlotReport( - ['main/loss', 'validation/main/loss'], - 'epoch', - file_name='loss.png')) + ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png" + ) + ) trainer.extend( extensions.PlotReport( - ['main/accuracy', 'validation/main/accuracy'], - 'epoch', - file_name='accuracy.png')) - trainer.extend(extensions.snapshot(), trigger=(args.frequency, 'epoch')) - trainer.extend(extensions.dump_graph('main/loss')) + ["main/accuracy", "validation/main/accuracy"], "epoch", file_name="accuracy.png" + ) + ) + trainer.extend(extensions.snapshot(), trigger=(args.frequency, "epoch")) + trainer.extend(extensions.dump_graph("main/loss")) trainer.extend(extensions.LogReport()) trainer.extend( - extensions.PrintReport([ - 'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', - 'validation/main/accuracy', 'elapsed_time' - ])) + extensions.PrintReport( + [ + "epoch", + "main/loss", + "validation/main/loss", + "main/accuracy", + "validation/main/accuracy", + "elapsed_time", + ] + ) + ) trainer.extend(extensions.ProgressBar()) trainer.run() # only save the model in the master node - if args.host == 'algo-1': - serializers.save_npz(os.path.join(env.model_dir, 'model.npz'), model) + if args.host == "algo-1": + serializers.save_npz(os.path.join(env.model_dir, "model.npz"), model) def model_fn(model_dir): model = L.Classifier(MLP(1000, 10)) - serializers.load_npz(os.path.join(model_dir, 'model.npz'), model) + serializers.load_npz(os.path.join(model_dir, "model.npz"), model) return model.predictor diff --git a/tests/data/chainer_mnist/failure_script.py b/tests/data/chainer_mnist/failure_script.py index b528919896..b19dd46c01 100644 --- a/tests/data/chainer_mnist/failure_script.py +++ b/tests/data/chainer_mnist/failure_script.py @@ -11,6 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -if __name__ == '__main__': +if __name__ == "__main__": """For use with integration tests expecting failures.""" - raise Exception('This failure is expected.') + raise Exception("This failure is expected.") diff --git a/tests/data/chainer_mnist/mnist.py b/tests/data/chainer_mnist/mnist.py index f52d982966..c31a0167db 100644 --- a/tests/data/chainer_mnist/mnist.py +++ b/tests/data/chainer_mnist/mnist.py @@ -42,55 +42,56 @@ def __call__(self, x): def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format): - images = raw['x'] + images = raw["x"] if ndim == 2: images = images.reshape(-1, 28, 28) elif ndim == 3: images = images.reshape(-1, 1, 28, 28) if rgb_format: - images = np.broadcast_to(images, - (len(images), 3) + images.shape[2:]) + images = np.broadcast_to(images, (len(images), 3) + images.shape[2:]) elif ndim != 1: - raise ValueError('invalid ndim for MNIST dataset') + raise ValueError("invalid ndim for MNIST dataset") images = images.astype(image_dtype) - images *= scale / 255. 
+ images *= scale / 255.0 if withlabel: - labels = raw['y'].astype(label_dtype) + labels = raw["y"].astype(label_dtype) return tuple_dataset.TupleDataset(images, labels) else: return images -if __name__ == '__main__': +if __name__ == "__main__": env = sagemaker_containers.training_env() parser = argparse.ArgumentParser() # Data and model checkpoints directories - parser.add_argument('--units', type=int, default=1000) - parser.add_argument('--epochs', type=int, default=20) - parser.add_argument('--frequency', type=int, default=20) - parser.add_argument('--batch-size', type=int, default=100) - parser.add_argument('--alpha', type=float, default=0.001) - parser.add_argument('--model-dir', type=str, default=env.model_dir) + parser.add_argument("--units", type=int, default=1000) + parser.add_argument("--epochs", type=int, default=20) + parser.add_argument("--frequency", type=int, default=20) + parser.add_argument("--batch-size", type=int, default=100) + parser.add_argument("--alpha", type=float, default=0.001) + parser.add_argument("--model-dir", type=str, default=env.model_dir) - parser.add_argument('--train', type=str, default=env.channel_input_dirs['train']) - parser.add_argument('--test', type=str, default=env.channel_input_dirs['test']) + parser.add_argument("--train", type=str, default=env.channel_input_dirs["train"]) + parser.add_argument("--test", type=str, default=env.channel_input_dirs["test"]) - parser.add_argument('--num-gpus', type=int, default=env.num_gpus) + parser.add_argument("--num-gpus", type=int, default=env.num_gpus) args = parser.parse_args() - train_file = np.load(os.path.join(args.train, 'train.npz')) - test_file = np.load(os.path.join(args.test, 'test.npz')) + train_file = np.load(os.path.join(args.train, "train.npz")) + test_file = np.load(os.path.join(args.test, "test.npz")) - preprocess_mnist_options = {'withlabel': True, - 'ndim': 1, - 'scale': 1., - 'image_dtype': np.float32, - 'label_dtype': np.int32, - 'rgb_format': False} + preprocess_mnist_options = { + "withlabel": True, + "ndim": 1, + "scale": 1.0, + "image_dtype": np.float32, + "label_dtype": np.int32, + "rgb_format": False, + } train = _preprocess_mnist(train_file, **preprocess_mnist_options) test = _preprocess_mnist(test_file, **preprocess_mnist_options) @@ -109,16 +110,14 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb # Load the MNIST dataset train_iter = chainer.iterators.SerialIterator(train, args.batch_size) - test_iter = chainer.iterators.SerialIterator(test, args.batch_size, - repeat=False, shuffle=False) + test_iter = chainer.iterators.SerialIterator(test, args.batch_size, repeat=False, shuffle=False) # Set up a trainer device = 0 if chainer.cuda.available else -1 # -1 indicates CPU, 0 indicates first GPU device. if chainer.cuda.available: def device_name(device_intra_rank): - return 'main' if device_intra_rank == 0 else str(device_intra_rank) - + return "main" if device_intra_rank == 0 else str(device_intra_rank) devices = {device_name(device): device for device in range(args.num_gpus)} updater = training.updater.ParallelUpdater( @@ -126,13 +125,14 @@ def device_name(device_intra_rank): optimizer, # The device of the name 'main' is used as a "master", while others are # used as slaves. Names other than 'main' are arbitrary. - devices=devices) + devices=devices, + ) else: updater = training.updater.StandardUpdater(train_iter, optimizer, device=device) # Write output files to output_data_dir. # These are zipped and uploaded to S3 output path as output.tar.gz. 
- trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=env.output_data_dir) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=env.output_data_dir) # Evaluate the model with the test dataset for each epoch @@ -140,10 +140,10 @@ def device_name(device_intra_rank): # Dump a computational graph from 'loss' variable at the first iteration # The "main" refers to the target link of the "main" optimizer. - trainer.extend(extensions.dump_graph('main/loss')) + trainer.extend(extensions.dump_graph("main/loss")) # Take a snapshot for each specified epoch - trainer.extend(extensions.snapshot(), trigger=(args.frequency, 'epoch')) + trainer.extend(extensions.snapshot(), trigger=(args.frequency, "epoch")) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport()) @@ -151,21 +151,33 @@ def device_name(device_intra_rank): # Save two plot images to the result dir if extensions.PlotReport.available(): trainer.extend( - extensions.PlotReport(['main/loss', 'validation/main/loss'], - 'epoch', file_name='loss.png')) + extensions.PlotReport( + ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png" + ) + ) trainer.extend( extensions.PlotReport( - ['main/accuracy', 'validation/main/accuracy'], - 'epoch', file_name='accuracy.png')) + ["main/accuracy", "validation/main/accuracy"], "epoch", file_name="accuracy.png" + ) + ) # Print selected entries of the log to stdout # Here "main" refers to the target link of the "main" optimizer again, and # "validation" refers to the default name of the Evaluator extension. # Entries other than 'epoch' are reported by the Classifier link, called by # either the updater or the evaluator. - trainer.extend(extensions.PrintReport( - ['epoch', 'main/loss', 'validation/main/loss', - 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) + trainer.extend( + extensions.PrintReport( + [ + "epoch", + "main/loss", + "validation/main/loss", + "main/accuracy", + "validation/main/accuracy", + "elapsed_time", + ] + ) + ) # Print a progress bar to stdout trainer.extend(extensions.ProgressBar()) @@ -173,10 +185,10 @@ def device_name(device_intra_rank): # Run the training trainer.run() - serializers.save_npz(os.path.join(args.model_dir, 'model.npz'), model) + serializers.save_npz(os.path.join(args.model_dir, "model.npz"), model) def model_fn(model_dir): model = L.Classifier(MLP(1000, 10)) - serializers.load_npz(os.path.join(model_dir, 'model.npz'), model) + serializers.load_npz(os.path.join(model_dir, "model.npz"), model) return model.predictor diff --git a/tests/data/cifar_10/source/keras_cnn_cifar_10.py b/tests/data/cifar_10/source/keras_cnn_cifar_10.py index b478644d70..be1195d779 100644 --- a/tests/data/cifar_10/source/keras_cnn_cifar_10.py +++ b/tests/data/cifar_10/source/keras_cnn_cifar_10.py @@ -24,7 +24,7 @@ NUM_DATA_BATCHES = 5 NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES BATCH_SIZE = 128 -INPUT_TENSOR_NAME = 'inputs_input' # needs to match the name of the first layer + "_input" +INPUT_TENSOR_NAME = "inputs_input" # needs to match the name of the first layer + "_input" def keras_model_fn(hyperparameters): @@ -38,32 +38,32 @@ def keras_model_fn(hyperparameters): """ model = Sequential() - model.add(Conv2D(32, (3, 3), padding='same', name='inputs', input_shape=(HEIGHT, WIDTH, DEPTH))) - model.add(Activation('relu')) + model.add(Conv2D(32, (3, 3), padding="same", name="inputs", input_shape=(HEIGHT, WIDTH, DEPTH))) + model.add(Activation("relu")) model.add(Conv2D(32, (3, 3))) - 
model.add(Activation('relu')) + model.add(Activation("relu")) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) - model.add(Conv2D(64, (3, 3), padding='same')) - model.add(Activation('relu')) + model.add(Conv2D(64, (3, 3), padding="same")) + model.add(Activation("relu")) model.add(Conv2D(64, (3, 3))) - model.add(Activation('relu')) + model.add(Activation("relu")) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(512)) - model.add(Activation('relu')) + model.add(Activation("relu")) model.add(Dropout(0.5)) model.add(Dense(NUM_CLASSES)) - model.add(Activation('softmax')) + model.add(Activation("softmax")) - opt = RMSPropOptimizer(learning_rate=hyperparameters['learning_rate'], decay=hyperparameters['decay']) + opt = RMSPropOptimizer( + learning_rate=hyperparameters["learning_rate"], decay=hyperparameters["decay"] + ) - model.compile(loss='categorical_crossentropy', - optimizer=opt, - metrics=['accuracy']) + model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"]) return model @@ -84,18 +84,17 @@ def eval_input_fn(training_dir, hyperparameters): def _generate_synthetic_data(mode, batch_size): input_shape = [batch_size, HEIGHT, WIDTH, DEPTH] images = tf.truncated_normal( - input_shape, - dtype=tf.float32, - stddev=1e-1, - name='synthetic_images') + input_shape, dtype=tf.float32, stddev=1e-1, name="synthetic_images" + ) labels = tf.random_uniform( [batch_size, NUM_CLASSES], minval=0, maxval=NUM_CLASSES - 1, dtype=tf.float32, - name='synthetic_labels') + name="synthetic_labels", + ) - images = tf.contrib.framework.local_variable(images, name='images') - labels = tf.contrib.framework.local_variable(labels, name='labels') + images = tf.contrib.framework.local_variable(images, name="images") + labels = tf.contrib.framework.local_variable(labels, name="labels") return {INPUT_TENSOR_NAME: images}, labels diff --git a/tests/data/cifar_10/source/resnet_cifar_10.py b/tests/data/cifar_10/source/resnet_cifar_10.py index 39101f1fb6..7708190119 100644 --- a/tests/data/cifar_10/source/resnet_cifar_10.py +++ b/tests/data/cifar_10/source/resnet_cifar_10.py @@ -48,7 +48,7 @@ def model_fn(features, labels, mode, params): """Model function for CIFAR-10.""" inputs = features[INPUT_TENSOR_NAME] - tf.summary.image('images', inputs, max_outputs=6) + tf.summary.image("images", inputs, max_outputs=6) network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES) @@ -57,27 +57,27 @@ def model_fn(features, labels, mode, params): logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN) predictions = { - 'classes': tf.argmax(logits, axis=1), - 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') + "classes": tf.argmax(logits, axis=1), + "probabilities": tf.nn.softmax(logits, name="softmax_tensor"), } if mode == tf.estimator.ModeKeys.PREDICT: - export_outputs = { - SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions) - } - return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs) + export_outputs = {SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)} + return tf.estimator.EstimatorSpec( + mode=mode, predictions=predictions, export_outputs=export_outputs + ) # Calculate loss, which includes softmax cross entropy and L2 regularization. 
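    # Spelled out: loss = softmax_cross_entropy(labels, logits)
    #                     + _WEIGHT_DECAY * sum(tf.nn.l2_loss(v) for v in trainable variables),
    # where tf.nn.l2_loss(v) computes sum(v ** 2) / 2.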
-    cross_entropy = tf.losses.softmax_cross_entropy(
-        logits=logits, onehot_labels=labels)
+    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
 
     # Create a tensor named cross_entropy for logging purposes.
-    tf.identity(cross_entropy, name='cross_entropy')
-    tf.summary.scalar('cross_entropy', cross_entropy)
+    tf.identity(cross_entropy, name="cross_entropy")
+    tf.summary.scalar("cross_entropy", cross_entropy)
 
     # Add weight decay to the loss.
     loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
-        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
+        [tf.nn.l2_loss(v) for v in tf.trainable_variables()]
+    )
 
     if mode == tf.estimator.ModeKeys.TRAIN:
         global_step = tf.train.get_or_create_global_step()
@@ -86,15 +86,14 @@ def model_fn(features, labels, mode, params):
         boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
         values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
         learning_rate = tf.train.piecewise_constant(
-            tf.cast(global_step, tf.int32), boundaries, values)
+            tf.cast(global_step, tf.int32), boundaries, values
+        )
 
         # Create a tensor named learning_rate for logging purposes
-        tf.identity(learning_rate, name='learning_rate')
-        tf.summary.scalar('learning_rate', learning_rate)
+        tf.identity(learning_rate, name="learning_rate")
+        tf.summary.scalar("learning_rate", learning_rate)
 
-        optimizer = tf.train.MomentumOptimizer(
-            learning_rate=learning_rate,
-            momentum=_MOMENTUM)
+        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=_MOMENTUM)
 
         # Batch norm requires update ops to be added as a dependency to the train_op
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
@@ -103,20 +102,16 @@ def model_fn(features, labels, mode, params):
     else:
         train_op = None
 
-    accuracy = tf.metrics.accuracy(
-        tf.argmax(labels, axis=1), predictions['classes'])
-    metrics = {'accuracy': accuracy}
+    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predictions["classes"])
+    metrics = {"accuracy": accuracy}
 
     # Create a tensor named train_accuracy for logging purposes
-    tf.identity(accuracy[1], name='train_accuracy')
-    tf.summary.scalar('train_accuracy', accuracy[1])
+    tf.identity(accuracy[1], name="train_accuracy")
+    tf.summary.scalar("train_accuracy", accuracy[1])
 
     return tf.estimator.EstimatorSpec(
-        mode=mode,
-        predictions=predictions,
-        loss=loss,
-        train_op=train_op,
-        eval_metric_ops=metrics)
+        mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops=metrics
+    )
 
 
 def serving_input_fn(hyperpameters):
@@ -135,19 +130,18 @@ def eval_input_fn(training_dir, hyperparameters):
 def _generate_synthetic_data(mode, batch_size):
     input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
     images = tf.truncated_normal(
-        input_shape,
-        dtype=tf.float32,
-        stddev=1e-1,
-        name='synthetic_images')
+        input_shape, dtype=tf.float32, stddev=1e-1, name="synthetic_images"
+    )
 
     labels = tf.random_uniform(
         [batch_size, NUM_CLASSES],
         minval=0,
         maxval=NUM_CLASSES - 1,
         dtype=tf.int32,
-        name='synthetic_labels')
+        name="synthetic_labels",
+    )
 
-    images = tf.contrib.framework.local_variable(images, name='images')
-    labels = tf.contrib.framework.local_variable(labels, name='labels')
+    images = tf.contrib.framework.local_variable(images, name="images")
+    labels = tf.contrib.framework.local_variable(labels, name="labels")
 
     return {INPUT_TENSOR_NAME: images}, labels
diff --git a/tests/data/cifar_10/source/resnet_model.py b/tests/data/cifar_10/source/resnet_model.py
index 098ae8f701..d9b8547d49 100644
--- a/tests/data/cifar_10/source/resnet_model.py
+++ b/tests/data/cifar_10/source/resnet_model.py
@@ -41,9 +41,15 @@ def batch_norm_relu(inputs, is_training, data_format):
     # We set fused=True for a significant performance boost. See
     # https://www.tensorflow.org/performance/performance_guide#common_fused_ops
     inputs = tf.layers.batch_normalization(
-        inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
-        momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
-        scale=True, training=is_training, fused=True)
+        inputs=inputs,
+        axis=1 if data_format == "channels_first" else 3,
+        momentum=_BATCH_NORM_DECAY,
+        epsilon=_BATCH_NORM_EPSILON,
+        center=True,
+        scale=True,
+        training=is_training,
+        fused=True,
+    )
     inputs = tf.nn.relu(inputs)
     return inputs
 
@@ -66,12 +72,10 @@ def fixed_padding(inputs, kernel_size, data_format):
     pad_beg = pad_total // 2
     pad_end = pad_total - pad_beg
 
-    if data_format == 'channels_first':
-        padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
-                                        [pad_beg, pad_end], [pad_beg, pad_end]])
+    if data_format == "channels_first":
+        padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
     else:
-        padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
-                                        [pad_beg, pad_end], [0, 0]])
+        padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
     return padded_inputs
 
 
@@ -83,14 +87,18 @@ def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format):
         inputs = fixed_padding(inputs, kernel_size, data_format)
 
     return tf.layers.conv2d(
-        inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
-        padding=('SAME' if strides == 1 else 'VALID'), use_bias=False,
+        inputs=inputs,
+        filters=filters,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=("SAME" if strides == 1 else "VALID"),
+        use_bias=False,
         kernel_initializer=tf.variance_scaling_initializer(),
-        data_format=data_format)
+        data_format=data_format,
+    )
 
 
-def building_block(inputs, filters, is_training, projection_shortcut, strides,
-                   data_format):
+def building_block(inputs, filters, is_training, projection_shortcut, strides, data_format):
     """Standard building block for residual networks with BN before convolutions.
 
     Args:
@@ -117,19 +125,18 @@ def building_block(inputs, filters, is_training, projection_shortcut, strides,
         shortcut = projection_shortcut(inputs)
 
     inputs = conv2d_fixed_padding(
-        inputs=inputs, filters=filters, kernel_size=3, strides=strides,
-        data_format=data_format)
+        inputs=inputs, filters=filters, kernel_size=3, strides=strides, data_format=data_format
+    )
     inputs = batch_norm_relu(inputs, is_training, data_format)
     inputs = conv2d_fixed_padding(
-        inputs=inputs, filters=filters, kernel_size=3, strides=1,
-        data_format=data_format)
+        inputs=inputs, filters=filters, kernel_size=3, strides=1, data_format=data_format
+    )
 
     return inputs + shortcut
 
 
-def bottleneck_block(inputs, filters, is_training, projection_shortcut,
-                     strides, data_format):
+def bottleneck_block(inputs, filters, is_training, projection_shortcut, strides, data_format):
     """Bottleneck block variant for residual networks with BN before convolutions.
 
Args: @@ -157,24 +164,23 @@ def bottleneck_block(inputs, filters, is_training, projection_shortcut, shortcut = projection_shortcut(inputs) inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=1, strides=1, - data_format=data_format) + inputs=inputs, filters=filters, kernel_size=1, strides=1, data_format=data_format + ) inputs = batch_norm_relu(inputs, is_training, data_format) inputs = conv2d_fixed_padding( - inputs=inputs, filters=filters, kernel_size=3, strides=strides, - data_format=data_format) + inputs=inputs, filters=filters, kernel_size=3, strides=strides, data_format=data_format + ) inputs = batch_norm_relu(inputs, is_training, data_format) inputs = conv2d_fixed_padding( - inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, - data_format=data_format) + inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, data_format=data_format + ) return inputs + shortcut -def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name, - data_format): +def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name, data_format): """Creates one layer of blocks for the ResNet model. Args: @@ -199,12 +205,15 @@ def block_layer(inputs, filters, block_fn, blocks, strides, is_training, name, def projection_shortcut(inputs): return conv2d_fixed_padding( - inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, - data_format=data_format) + inputs=inputs, + filters=filters_out, + kernel_size=1, + strides=strides, + data_format=data_format, + ) # Only the first block per block_layer uses projection_shortcut and strides - inputs = block_fn(inputs, filters, is_training, projection_shortcut, strides, - data_format) + inputs = block_fn(inputs, filters, is_training, projection_shortcut, strides, data_format) for _ in range(1, blocks): inputs = block_fn(inputs, filters, is_training, None, 1, data_format) @@ -229,56 +238,72 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None): ValueError: If `resnet_size` is invalid. """ if resnet_size % 6 != 2: - raise ValueError('resnet_size must be 6n + 2:', resnet_size) + raise ValueError("resnet_size must be 6n + 2:", resnet_size) num_blocks = (resnet_size - 2) // 6 if data_format is None: - data_format = ( - 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') + data_format = "channels_first" if tf.test.is_built_with_cuda() else "channels_last" def model(inputs, is_training): """Constructs the ResNet model given the inputs.""" - if data_format == 'channels_first': + if data_format == "channels_first": # Convert from channels_last (NHWC) to channels_first (NCHW). This # provides a large performance boost on GPU. 
See # https://www.tensorflow.org/performance/performance_guide#data_formats inputs = tf.transpose(inputs, [0, 3, 1, 2]) inputs = conv2d_fixed_padding( - inputs=inputs, filters=16, kernel_size=3, strides=1, - data_format=data_format) - inputs = tf.identity(inputs, 'initial_conv') + inputs=inputs, filters=16, kernel_size=3, strides=1, data_format=data_format + ) + inputs = tf.identity(inputs, "initial_conv") inputs = block_layer( - inputs=inputs, filters=16, block_fn=building_block, blocks=num_blocks, - strides=1, is_training=is_training, name='block_layer1', - data_format=data_format) + inputs=inputs, + filters=16, + block_fn=building_block, + blocks=num_blocks, + strides=1, + is_training=is_training, + name="block_layer1", + data_format=data_format, + ) inputs = block_layer( - inputs=inputs, filters=32, block_fn=building_block, blocks=num_blocks, - strides=2, is_training=is_training, name='block_layer2', - data_format=data_format) + inputs=inputs, + filters=32, + block_fn=building_block, + blocks=num_blocks, + strides=2, + is_training=is_training, + name="block_layer2", + data_format=data_format, + ) inputs = block_layer( - inputs=inputs, filters=64, block_fn=building_block, blocks=num_blocks, - strides=2, is_training=is_training, name='block_layer3', - data_format=data_format) + inputs=inputs, + filters=64, + block_fn=building_block, + blocks=num_blocks, + strides=2, + is_training=is_training, + name="block_layer3", + data_format=data_format, + ) inputs = batch_norm_relu(inputs, is_training, data_format) inputs = tf.layers.average_pooling2d( - inputs=inputs, pool_size=8, strides=1, padding='VALID', - data_format=data_format) - inputs = tf.identity(inputs, 'final_avg_pool') + inputs=inputs, pool_size=8, strides=1, padding="VALID", data_format=data_format + ) + inputs = tf.identity(inputs, "final_avg_pool") inputs = tf.reshape(inputs, [-1, 64]) inputs = tf.layers.dense(inputs=inputs, units=num_classes) - inputs = tf.identity(inputs, 'final_dense') + inputs = tf.identity(inputs, "final_dense") return inputs model.default_image_size = 32 return model -def imagenet_resnet_v2_generator(block_fn, layers, num_classes, - data_format=None): +def imagenet_resnet_v2_generator(block_fn, layers, num_classes, data_format=None): """Generator for ImageNet ResNet v2 models. Args: @@ -295,51 +320,73 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes, returns the output tensor of the ResNet model. """ if data_format is None: - data_format = ( - 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') + data_format = "channels_first" if tf.test.is_built_with_cuda() else "channels_last" def model(inputs, is_training): """Constructs the ResNet model given the inputs.""" - if data_format == 'channels_first': + if data_format == "channels_first": # Convert from channels_last (NHWC) to channels_first (NCHW). This # provides a large performance boost on GPU. 
         inputs = tf.transpose(inputs, [0, 3, 1, 2])
 
         inputs = conv2d_fixed_padding(
-            inputs=inputs, filters=64, kernel_size=7, strides=2,
-            data_format=data_format)
-        inputs = tf.identity(inputs, 'initial_conv')
+            inputs=inputs, filters=64, kernel_size=7, strides=2, data_format=data_format
+        )
+        inputs = tf.identity(inputs, "initial_conv")
 
         inputs = tf.layers.max_pooling2d(
-            inputs=inputs, pool_size=3, strides=2, padding='SAME',
-            data_format=data_format)
-        inputs = tf.identity(inputs, 'initial_max_pool')
+            inputs=inputs, pool_size=3, strides=2, padding="SAME", data_format=data_format
+        )
+        inputs = tf.identity(inputs, "initial_max_pool")
 
         inputs = block_layer(
-            inputs=inputs, filters=64, block_fn=block_fn, blocks=layers[0],
-            strides=1, is_training=is_training, name='block_layer1',
-            data_format=data_format)
+            inputs=inputs,
+            filters=64,
+            block_fn=block_fn,
+            blocks=layers[0],
+            strides=1,
+            is_training=is_training,
+            name="block_layer1",
+            data_format=data_format,
+        )
         inputs = block_layer(
-            inputs=inputs, filters=128, block_fn=block_fn, blocks=layers[1],
-            strides=2, is_training=is_training, name='block_layer2',
-            data_format=data_format)
+            inputs=inputs,
+            filters=128,
+            block_fn=block_fn,
+            blocks=layers[1],
+            strides=2,
+            is_training=is_training,
+            name="block_layer2",
+            data_format=data_format,
+        )
         inputs = block_layer(
-            inputs=inputs, filters=256, block_fn=block_fn, blocks=layers[2],
-            strides=2, is_training=is_training, name='block_layer3',
-            data_format=data_format)
+            inputs=inputs,
+            filters=256,
+            block_fn=block_fn,
+            blocks=layers[2],
+            strides=2,
+            is_training=is_training,
+            name="block_layer3",
+            data_format=data_format,
+        )
         inputs = block_layer(
-            inputs=inputs, filters=512, block_fn=block_fn, blocks=layers[3],
-            strides=2, is_training=is_training, name='block_layer4',
-            data_format=data_format)
+            inputs=inputs,
+            filters=512,
+            block_fn=block_fn,
+            blocks=layers[3],
+            strides=2,
+            is_training=is_training,
+            name="block_layer4",
+            data_format=data_format,
+        )
 
         inputs = batch_norm_relu(inputs, is_training, data_format)
         inputs = tf.layers.average_pooling2d(
-            inputs=inputs, pool_size=7, strides=1, padding='VALID',
-            data_format=data_format)
-        inputs = tf.identity(inputs, 'final_avg_pool')
-        inputs = tf.reshape(inputs,
-                            [-1, 512 if block_fn is building_block else 2048])
+            inputs=inputs, pool_size=7, strides=1, padding="VALID", data_format=data_format
+        )
+        inputs = tf.identity(inputs, "final_avg_pool")
+        inputs = tf.reshape(inputs, [-1, 512 if block_fn is building_block else 2048])
         inputs = tf.layers.dense(inputs=inputs, units=num_classes)
-        inputs = tf.identity(inputs, 'final_dense')
+        inputs = tf.identity(inputs, "final_dense")
         return inputs
 
     model.default_image_size = 224
@@ -349,17 +396,16 @@ def model(inputs, is_training):
 def resnet_v2(resnet_size, num_classes, data_format=None):
     """Returns the ResNet model for a given size and number of output classes."""
     model_params = {
-        18: {'block': building_block, 'layers': [2, 2, 2, 2]},
-        34: {'block': building_block, 'layers': [3, 4, 6, 3]},
-        50: {'block': bottleneck_block, 'layers': [3, 4, 6, 3]},
-        101: {'block': bottleneck_block, 'layers': [3, 4, 23, 3]},
-        152: {'block': bottleneck_block, 'layers': [3, 8, 36, 3]},
-        200: {'block': bottleneck_block, 'layers': [3, 24, 36, 3]}
+        18: {"block": building_block, "layers": [2, 2, 2, 2]},
+        34: {"block": building_block, "layers": [3, 4, 6, 3]},
+        50: {"block": bottleneck_block, "layers": [3, 4, 6, 3]},
+        101: {"block": bottleneck_block, "layers": [3, 4, 23, 3]},
+        152: {"block": bottleneck_block, "layers": [3, 8, 36, 3]},
+        200: {"block": bottleneck_block, "layers": [3, 24, 36, 3]},
     }
 
     if resnet_size not in model_params:
-        raise ValueError('Not a valid resnet_size:', resnet_size)
+        raise ValueError("Not a valid resnet_size:", resnet_size)
 
     params = model_params[resnet_size]
-    return imagenet_resnet_v2_generator(
-        params['block'], params['layers'], num_classes, data_format)
+    return imagenet_resnet_v2_generator(params["block"], params["layers"], num_classes, data_format)
diff --git a/tests/data/coach_cartpole/mxnet_deploy.py b/tests/data/coach_cartpole/mxnet_deploy.py
index 04915aed8f..5b599dce9b 100644
--- a/tests/data/coach_cartpole/mxnet_deploy.py
+++ b/tests/data/coach_cartpole/mxnet_deploy.py
@@ -13,12 +13,12 @@ def model_fn(model_dir):
     :return: a model
     """
     onnx_path = os.path.join(model_dir, "model.onnx")
-    ctx = mx.cpu() # todo: pass into function
+    ctx = mx.cpu()  # todo: pass into function
     # load onnx model symbol and parameters
     sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)
     model_metadata = onnx_mxnet.get_model_metadata(onnx_path)
     # first index is name, second index is shape
-    input_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
+    input_names = [inputs[0] for inputs in model_metadata.get("input_tensor_data")]
     input_symbols = [mx.sym.var(i) for i in input_names]
     net = gluon.nn.SymbolBlock(outputs=sym, inputs=input_symbols)
     net_params = net.collect_params()
@@ -49,4 +49,4 @@ def transform_fn(net, data, input_content_type, output_content_type):
     output_nd = net(input_nd)
     output_np = output_nd.asnumpy()
     output_list = output_np.tolist()
-    return json.dumps(output_list), output_content_type
\ No newline at end of file
+    return json.dumps(output_list), output_content_type
diff --git a/tests/data/coach_cartpole/preset_cartpole_clippedppo.py b/tests/data/coach_cartpole/preset_cartpole_clippedppo.py
index 15999cdec0..b80e7fbdb2 100644
--- a/tests/data/coach_cartpole/preset_cartpole_clippedppo.py
+++ b/tests/data/coach_cartpole/preset_cartpole_clippedppo.py
@@ -1,11 +1,17 @@
 from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
 from rl_coach.architectures.layers import Dense
-from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
+from rl_coach.base_parameters import (
+    VisualizationParameters,
+    PresetValidationParameters,
+    DistributedCoachSynchronizationType,
+)
 from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
-from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
+from rl_coach.filters.observation.observation_normalization_filter import (
+    ObservationNormalizationFilter,
+)
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.graph_managers.graph_manager import ScheduleParameters
 from rl_coach.schedules import LinearSchedule
@@ -31,14 +37,16 @@
 
 agent_params = ClippedPPOAgentParameters()
 
-agent_params.network_wrappers['main'].learning_rate = 0.0003
-agent_params.network_wrappers['main'].input_embedders_parameters['observation'].activation_function = 'tanh'
-agent_params.network_wrappers['main'].input_embedders_parameters['observation'].scheme = [Dense(64)]
-agent_params.network_wrappers['main'].middleware_parameters.scheme = [Dense(64)] -agent_params.network_wrappers['main'].middleware_parameters.activation_function = 'tanh' -agent_params.network_wrappers['main'].batch_size = 64 -agent_params.network_wrappers['main'].optimizer_epsilon = 1e-5 -agent_params.network_wrappers['main'].adam_optimizer_beta2 = 0.999 +agent_params.network_wrappers["main"].learning_rate = 0.0003 +agent_params.network_wrappers["main"].input_embedders_parameters[ + "observation" +].activation_function = "tanh" +agent_params.network_wrappers["main"].input_embedders_parameters["observation"].scheme = [Dense(64)] +agent_params.network_wrappers["main"].middleware_parameters.scheme = [Dense(64)] +agent_params.network_wrappers["main"].middleware_parameters.activation_function = "tanh" +agent_params.network_wrappers["main"].batch_size = 64 +agent_params.network_wrappers["main"].optimizer_epsilon = 1e-5 +agent_params.network_wrappers["main"].adam_optimizer_beta2 = 0.999 agent_params.algorithm.clip_likelihood_ratio_using_epsilon = 0.2 agent_params.algorithm.clipping_decay_schedule = LinearSchedule(1.0, 0, 1000000) @@ -50,17 +58,22 @@ agent_params.algorithm.num_steps_between_copying_online_weights_to_target = EnvironmentSteps(2048) # Distributed Coach synchronization type. -agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC +agent_params.algorithm.distributed_coach_synchronization_type = ( + DistributedCoachSynchronizationType.SYNC +) agent_params.exploration = EGreedyParameters() agent_params.exploration.epsilon_schedule = LinearSchedule(1.0, 0.01, 10000) -agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation', - ObservationNormalizationFilter(name='normalize_observation')) +agent_params.pre_network_filter.add_observation_filter( + "observation", + "normalize_observation", + ObservationNormalizationFilter(name="normalize_observation"), +) ############### # Environment # ############### -env_params = GymVectorEnvironment(level='CartPole-v0') +env_params = GymVectorEnvironment(level="CartPole-v0") ################# # Visualization # @@ -77,6 +90,10 @@ preset_validation_params.min_reward_threshold = 150 preset_validation_params.max_episodes_to_achieve_reward = 400 -graph_manager = BasicRLGraphManager(agent_params=agent_params, env_params=env_params, - schedule_params=schedule_params, vis_params=vis_params, - preset_validation_params=preset_validation_params) +graph_manager = BasicRLGraphManager( + agent_params=agent_params, + env_params=env_params, + schedule_params=schedule_params, + vis_params=vis_params, + preset_validation_params=preset_validation_params, +) diff --git a/tests/data/coach_cartpole/train_coach.py b/tests/data/coach_cartpole/train_coach.py index 83b092225a..8656143e78 100644 --- a/tests/data/coach_cartpole/train_coach.py +++ b/tests/data/coach_cartpole/train_coach.py @@ -2,12 +2,11 @@ class MyLauncher(SageMakerCoachPresetLauncher): - def default_preset_name(self): """This points to a .py file that configures everything about the RL job. It can be overridden at runtime by specifying the RLCOACH_PRESET hyperparameter. """ - return 'preset-cartpole-dqn' + return "preset-cartpole-dqn" def map_hyperparameter(self, name, value): """Here we configure some shortcut names for hyperparameters that we expect to use frequently. 
@@ -18,7 +17,7 @@ def map_hyperparameter(self, name, value): mapping = { "discount": "rl.agent_params.algorithm.discount", "evaluation_episodes": "rl.evaluation_steps:EnvironmentEpisodes", - "improve_steps": "rl.improve_steps:TrainingSteps" + "improve_steps": "rl.improve_steps:TrainingSteps", } if name in mapping: self.apply_hyperparameter(mapping[name], value) @@ -26,5 +25,5 @@ def map_hyperparameter(self, name, value): super().map_hyperparameter(name, value) -if __name__ == '__main__': +if __name__ == "__main__": MyLauncher.train_main() diff --git a/tests/data/dummy_script.py b/tests/data/dummy_script.py index 6381f25211..c5a8884153 100644 --- a/tests/data/dummy_script.py +++ b/tests/data/dummy_script.py @@ -12,5 +12,5 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -print('This is definitely code which does machine learning stuff') -print('and not just a random file I threw together for unit testing.') +print("This is definitely code which does machine learning stuff") +print("and not just a random file I threw together for unit testing.") diff --git a/tests/data/horovod/test_hvd_basic.py b/tests/data/horovod/test_hvd_basic.py index d37214defc..df15d739e8 100644 --- a/tests/data/horovod/test_hvd_basic.py +++ b/tests/data/horovod/test_hvd_basic.py @@ -3,13 +3,12 @@ import horovod.tensorflow as hvd -if __name__ == '__main__': +if __name__ == "__main__": hvd.init() - with open(os.path.join('/opt/ml/model/rank-%s' % hvd.rank()), 'w+') as f: - basic_info = {'rank': hvd.rank(), 'size': hvd.size()} + with open(os.path.join("/opt/ml/model/rank-%s" % hvd.rank()), "w+") as f: + basic_info = {"rank": hvd.rank(), "size": hvd.size()} json.dump(basic_info, f) print('Saved file "rank-%s": %s' % (hvd.rank(), basic_info)) - diff --git a/tests/data/iris/failure_script.py b/tests/data/iris/failure_script.py index 2dd8783372..5b6f0f531d 100644 --- a/tests/data/iris/failure_script.py +++ b/tests/data/iris/failure_script.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import + def estimator_fn(run_config, params): """For use with integration tests expecting failures.""" - raise Exception('This failure is expected.') + raise Exception("This failure is expected.") diff --git a/tests/data/iris/iris-dnn-classifier.py b/tests/data/iris/iris-dnn-classifier.py index 6dcaa48832..6d737f35e9 100644 --- a/tests/data/iris/iris-dnn-classifier.py +++ b/tests/data/iris/iris-dnn-classifier.py @@ -18,42 +18,46 @@ def estimator_fn(run_config, hyperparameters): - input_tensor_name = hyperparameters.get('input_tensor_name', 'inputs') - learning_rate = hyperparameters.get('learning_rate', 0.05) + input_tensor_name = hyperparameters.get("input_tensor_name", "inputs") + learning_rate = hyperparameters.get("learning_rate", 0.05) feature_columns = [tf.feature_column.numeric_column(input_tensor_name, shape=[4])] - return tf.estimator.DNNClassifier(feature_columns=feature_columns, - hidden_units=[10, 20, 10], - optimizer=tf.train.AdagradOptimizer(learning_rate=learning_rate), - n_classes=3, - config=run_config) + return tf.estimator.DNNClassifier( + feature_columns=feature_columns, + hidden_units=[10, 20, 10], + optimizer=tf.train.AdagradOptimizer(learning_rate=learning_rate), + n_classes=3, + config=run_config, + ) def serving_input_fn(hyperparameters): - input_tensor_name = hyperparameters['input_tensor_name'] + input_tensor_name = hyperparameters["input_tensor_name"] feature_spec = {input_tensor_name: tf.FixedLenFeature(dtype=tf.float32, shape=[4])} return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)() def train_input_fn(training_dir, hyperparameters): """Returns input function that would feed the model during training""" - return _generate_input_fn(training_dir, 'iris_training.csv', hyperparameters) + return _generate_input_fn(training_dir, "iris_training.csv", hyperparameters) def eval_input_fn(training_dir, hyperparameters): """Returns input function that would feed the model during evaluation""" - return _generate_input_fn(training_dir, 'iris_test.csv', hyperparameters) + return _generate_input_fn(training_dir, "iris_test.csv", hyperparameters) def _generate_input_fn(training_dir, training_filename, hyperparameters): - input_tensor_name = hyperparameters['input_tensor_name'] + input_tensor_name = hyperparameters["input_tensor_name"] training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=os.path.join(training_dir, training_filename), target_dtype=np.int, - features_dtype=np.float32) + features_dtype=np.float32, + ) return tf.estimator.inputs.numpy_input_fn( x={input_tensor_name: np.array(training_set.data)}, y=np.array(training_set.target), num_epochs=None, - shuffle=True)() + shuffle=True, + )() diff --git a/tests/data/mxnet_mnist/failure_script.py b/tests/data/mxnet_mnist/failure_script.py index 7f2362c047..ddc2a061c8 100644 --- a/tests/data/mxnet_mnist/failure_script.py +++ b/tests/data/mxnet_mnist/failure_script.py @@ -14,4 +14,4 @@ # For use with integration tests expecting failures. 
-raise Exception('This failure is expected.') +raise Exception("This failure is expected.") diff --git a/tests/data/mxnet_mnist/mnist.py b/tests/data/mxnet_mnist/mnist.py index 16ab7c2e98..90e4c65af9 100644 --- a/tests/data/mxnet_mnist/mnist.py +++ b/tests/data/mxnet_mnist/mnist.py @@ -39,14 +39,14 @@ def find_file(root_path, file_name): def build_graph(): - data = mx.sym.var('data') + data = mx.sym.var("data") data = mx.sym.flatten(data=data) fc1 = mx.sym.FullyConnected(data=data, num_hidden=128) act1 = mx.sym.Activation(data=fc1, act_type="relu") fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64) act2 = mx.sym.Activation(data=fc2, act_type="relu") fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10) - return mx.sym.SoftmaxOutput(data=fc3, name='softmax') + return mx.sym.SoftmaxOutput(data=fc3, name="softmax") def get_train_context(num_gpus): @@ -56,8 +56,17 @@ def get_train_context(num_gpus): return mx.cpu() -def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel, - hosts, current_host, model_dir): +def train( + batch_size, + epochs, + learning_rate, + num_gpus, + training_channel, + testing_channel, + hosts, + current_host, + model_dir, +): (train_labels, train_images) = load_data(training_channel) (test_labels, test_images) = load_data(testing_channel) @@ -70,56 +79,69 @@ def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing end = start + shard_size break - train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, - shuffle=True) + train_iter = mx.io.NDArrayIter( + train_images[start:end], train_labels[start:end], batch_size, shuffle=True + ) val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size) logging.getLogger().setLevel(logging.DEBUG) - kvstore = 'local' if len(hosts) == 1 else 'dist_sync' + kvstore = "local" if len(hosts) == 1 else "dist_sync" - mlp_model = mx.mod.Module(symbol=build_graph(), - context=get_train_context(num_gpus)) - mlp_model.fit(train_iter, - eval_data=val_iter, - kvstore=kvstore, - optimizer='sgd', - optimizer_params={'learning_rate': learning_rate}, - eval_metric='acc', - batch_end_callback=mx.callback.Speedometer(batch_size, 100), - num_epoch=epochs) + mlp_model = mx.mod.Module(symbol=build_graph(), context=get_train_context(num_gpus)) + mlp_model.fit( + train_iter, + eval_data=val_iter, + kvstore=kvstore, + optimizer="sgd", + optimizer_params={"learning_rate": learning_rate}, + eval_metric="acc", + batch_end_callback=mx.callback.Speedometer(batch_size, 100), + num_epoch=epochs, + ) if len(hosts) == 1 or current_host == hosts[0]: save(model_dir, mlp_model) def save(model_dir, model): - model.symbol.save(os.path.join(model_dir, 'model-symbol.json')) - model.save_params(os.path.join(model_dir, 'model-0000.params')) - - signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]} - for data_desc in model.data_shapes] - with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f: + model.symbol.save(os.path.join(model_dir, "model-symbol.json")) + model.save_params(os.path.join(model_dir, "model-0000.params")) + + signature = [ + {"name": data_desc.name, "shape": [dim for dim in data_desc.shape]} + for data_desc in model.data_shapes + ] + with open(os.path.join(model_dir, "model-shapes.json"), "w") as f: json.dump(signature, f) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch-size', type=int, default=100) - parser.add_argument('--epochs', type=int, 
default=10) - parser.add_argument('--learning-rate', type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=100) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--learning-rate", type=float, default=0.1) - parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) - parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) - parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) + parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + parser.add_argument("--test", type=str, default=os.environ["SM_CHANNEL_TEST"]) - parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST']) - parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS'])) + parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"]) + parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"])) args = parser.parse_args() - num_gpus = int(os.environ['SM_NUM_GPUS']) - - train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test, - args.hosts, args.current_host, args.model_dir) + num_gpus = int(os.environ["SM_NUM_GPUS"]) + + train( + args.batch_size, + args.epochs, + args.learning_rate, + num_gpus, + args.train, + args.test, + args.hosts, + args.current_host, + args.model_dir, + ) diff --git a/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py b/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py index af3a802d5b..e001d841e5 100644 --- a/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py +++ b/tests/data/mxnet_mnist/mnist_hosting_with_custom_handlers.py @@ -26,9 +26,10 @@ # it is possible to specify own code to load the model, otherwise a default model loading takes place def model_fn(path_to_model_files): from mxnet.io import DataDesc + loaded_symbol = mx.symbol.load(os.path.join(path_to_model_files, "symbol")) created_module = mx.mod.Module(symbol=loaded_symbol) - created_module.bind([DataDesc("data", (1L, 1L, 28L, 28L))]) + created_module.bind([DataDesc("data", (1, 1, 28, 28))]) created_module.load_params(os.path.join(path_to_model_files, "params")) return created_module @@ -38,8 +39,13 @@ def model_fn(path_to_model_files): # returns serialized data and content type it has used def transform_fn(model, request_data, input_content_type, requested_output_content_type): # for demonstration purposes we will be calling handlers from Option2 - return output_fn(process_request_fn(model, request_data, input_content_type), requested_output_content_type), \ - requested_output_content_type + return ( + output_fn( + process_request_fn(model, request_data, input_content_type), + requested_output_content_type, + ), + requested_output_content_type, + ) # --- Option 2 - overwrite container's default input/output behavior with handlers --- @@ -50,7 +56,9 @@ def process_request_fn(model, data, input_content_type): elif input_content_type == "application/json": prediction_input = handle_json_input(data) else: - raise NotImplementedError("This model doesnt support requested input type: " + input_content_type) + raise NotImplementedError( + "This model doesnt support requested input type: " + input_content_type + ) return model.predict(prediction_input) @@ -61,9 +69,11 @@ def handle_s3_file_path(path): if sys.version_info.major == 2: import urlparse + 
parse_cmd = urlparse.urlparse else: import urllib + parse_cmd = urllib.parse.urlparse import boto3 @@ -73,20 +83,23 @@ def handle_s3_file_path(path): parsed_url = parse_cmd(path) # get S3 client - s3 = boto3.resource('s3') + s3 = boto3.resource("s3") # read file content and pass it down - obj = s3.Object(parsed_url.netloc, parsed_url.path.lstrip('/')) - print ("loading file: " + str(obj)) + obj = s3.Object(parsed_url.netloc, parsed_url.path.lstrip("/")) + print("loading file: " + str(obj)) try: - data = obj.get()['Body'] + data = obj.get()["Body"] except ClientError as ce: - raise ValueError("Can't download from S3 path: " + path + " : " + ce.response['Error']['Message']) + raise ValueError( + "Can't download from S3 path: " + path + " : " + ce.response["Error"]["Message"] + ) import StringIO + buf = StringIO(data.read()) - img = gzip.GzipFile(mode='rb', fileobj=buf) + img = gzip.GzipFile(mode="rb", fileobj=buf) _, _, rows, cols = struct.unpack(">IIII", img.read(16)) images = np.fromstring(img.read(), dtype=np.uint8).reshape(10000, rows, cols) @@ -109,4 +122,6 @@ def output_fn(prediction_output, requested_output_content_type): if requested_output_content_type == "application/json": json.dumps(data_to_return.tolist), requested_output_content_type - raise NotImplementedError("Model doesn't support requested output type: " + requested_output_content_type) + raise NotImplementedError( + "Model doesn't support requested output type: " + requested_output_content_type + ) diff --git a/tests/data/mxnet_mnist/mnist_neo.py b/tests/data/mxnet_mnist/mnist_neo.py index cdb9a94d8c..fb5b60a493 100644 --- a/tests/data/mxnet_mnist/mnist_neo.py +++ b/tests/data/mxnet_mnist/mnist_neo.py @@ -41,14 +41,14 @@ def find_file(root_path, file_name): def build_graph(): - data = mx.sym.var('data') + data = mx.sym.var("data") data = mx.sym.flatten(data=data) fc1 = mx.sym.FullyConnected(data=data, num_hidden=128) act1 = mx.sym.Activation(data=fc1, act_type="relu") fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64) act2 = mx.sym.Activation(data=fc2, act_type="relu") fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10) - return mx.sym.SoftmaxOutput(data=fc3, name='softmax') + return mx.sym.SoftmaxOutput(data=fc3, name="softmax") def get_train_context(num_gpus): @@ -58,8 +58,17 @@ def get_train_context(num_gpus): return mx.cpu() -def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel, - hosts, current_host, model_dir): +def train( + batch_size, + epochs, + learning_rate, + num_gpus, + training_channel, + testing_channel, + hosts, + current_host, + model_dir, +): (train_labels, train_images) = load_data(training_channel) (test_labels, test_images) = load_data(testing_channel) @@ -72,40 +81,44 @@ def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing end = start + shard_size break - train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, - shuffle=True) + train_iter = mx.io.NDArrayIter( + train_images[start:end], train_labels[start:end], batch_size, shuffle=True + ) val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size) logging.getLogger().setLevel(logging.DEBUG) - kvstore = 'local' if len(hosts) == 1 else 'dist_sync' + kvstore = "local" if len(hosts) == 1 else "dist_sync" - mlp_model = mx.mod.Module(symbol=build_graph(), - context=get_train_context(num_gpus)) - mlp_model.fit(train_iter, - eval_data=val_iter, - kvstore=kvstore, - optimizer='sgd', - optimizer_params={'learning_rate': learning_rate}, - 
eval_metric='acc', - batch_end_callback=mx.callback.Speedometer(batch_size, 100), - num_epoch=epochs) + mlp_model = mx.mod.Module(symbol=build_graph(), context=get_train_context(num_gpus)) + mlp_model.fit( + train_iter, + eval_data=val_iter, + kvstore=kvstore, + optimizer="sgd", + optimizer_params={"learning_rate": learning_rate}, + eval_metric="acc", + batch_end_callback=mx.callback.Speedometer(batch_size, 100), + num_epoch=epochs, + ) if len(hosts) == 1 or current_host == scheduler_host(hosts): save(model_dir, mlp_model) + def neo_preprocess(payload, content_type): - logging.info('Invoking user-defined pre-processing function') + logging.info("Invoking user-defined pre-processing function") - if content_type != 'application/vnd+python.numpy+binary': - raise RuntimeError('Content type must be application/vnd+python.numpy+binary') + if content_type != "application/vnd+python.numpy+binary": + raise RuntimeError("Content type must be application/vnd+python.numpy+binary") f = io.BytesIO(payload) return np.load(f) + ### NOTE: this function cannot use MXNet def neo_postprocess(result): - logging.info('Invoking user-defined post-processing function') + logging.info("Invoking user-defined post-processing function") # Softmax (assumes batch size 1) result = np.squeeze(result) @@ -113,37 +126,49 @@ def neo_postprocess(result): result = result_exp / np.sum(result_exp) response_body = json.dumps(result.tolist()) - content_type = 'application/json' + content_type = "application/json" return response_body, content_type -def save(model_dir, model): - model.symbol.save(os.path.join(model_dir, 'model-symbol.json')) - model.save_params(os.path.join(model_dir, 'model-0000.params')) - signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]} - for data_desc in model.data_shapes] - with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f: +def save(model_dir, model): + model.symbol.save(os.path.join(model_dir, "model-symbol.json")) + model.save_params(os.path.join(model_dir, "model-0000.params")) + + signature = [ + {"name": data_desc.name, "shape": [dim for dim in data_desc.shape]} + for data_desc in model.data_shapes + ] + with open(os.path.join(model_dir, "model-shapes.json"), "w") as f: json.dump(signature, f) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch-size', type=int, default=100) - parser.add_argument('--epochs', type=int, default=10) - parser.add_argument('--learning-rate', type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=100) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--learning-rate", type=float, default=0.1) - parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) - parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) - parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) + parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + parser.add_argument("--test", type=str, default=os.environ["SM_CHANNEL_TEST"]) - parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST']) - parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS'])) + parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"]) + parser.add_argument("--hosts", type=list, 
default=json.loads(os.environ["SM_HOSTS"])) args = parser.parse_args() - num_gpus = int(os.environ['SM_NUM_GPUS']) - - train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test, - args.hosts, args.current_host, args.model_dir) + num_gpus = int(os.environ["SM_NUM_GPUS"]) + + train( + args.batch_size, + args.epochs, + args.learning_rate, + num_gpus, + args.train, + args.test, + args.hosts, + args.current_host, + args.model_dir, + ) diff --git a/tests/data/pytorch_mnist/failure_script.py b/tests/data/pytorch_mnist/failure_script.py index 7a173f6aa5..4b9abfbd16 100644 --- a/tests/data/pytorch_mnist/failure_script.py +++ b/tests/data/pytorch_mnist/failure_script.py @@ -1,3 +1,3 @@ -if __name__ == '__main__': +if __name__ == "__main__": """For use with integration tests expecting failures.""" - raise Exception('This failure is expected.') + raise Exception("This failure is expected.") diff --git a/tests/data/pytorch_mnist/mnist.py b/tests/data/pytorch_mnist/mnist.py index 5ee20c4bbc..6eb2d43228 100644 --- a/tests/data/pytorch_mnist/mnist.py +++ b/tests/data/pytorch_mnist/mnist.py @@ -20,7 +20,7 @@ class Net(nn.Module): # Based on https://github.com/pytorch/examples/blob/master/mnist/main.py def __init__(self): - logger.info('Create neural network module') + logger.info("Create neural network module") super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) @@ -40,25 +40,41 @@ def forward(self, x): def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs): - logger.info('Get train data loader') - dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])) - train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None - train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None, - sampler=train_sampler, **kwargs) + logger.info("Get train data loader") + dataset = datasets.MNIST( + training_dir, + train=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + train_sampler = ( + torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None + ) + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=train_sampler is None, + sampler=train_sampler, + **kwargs + ) return train_sampler, train_loader def _get_test_data_loader(training_dir, **kwargs): - logger.info('Get test data loader') + logger.info("Get test data loader") return torch.utils.data.DataLoader( - datasets.MNIST(training_dir, train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=1000, shuffle=True, **kwargs) + datasets.MNIST( + training_dir, + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=1000, + shuffle=True, + **kwargs + ) def _average_gradients(model): @@ -72,21 +88,26 @@ def _average_gradients(model): def train(args): world_size = len(args.hosts) is_distributed = world_size > 1 - logger.debug('Number of hosts {}. Distributed training - {}'.format(world_size, is_distributed)) + logger.debug("Number of hosts {}. 
Distributed training - {}".format(world_size, is_distributed)) use_cuda = args.num_gpus > 0 - logger.debug('Number of gpus available - {}'.format(args.num_gpus)) - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} - device = torch.device('cuda' if use_cuda else 'cpu') + logger.debug("Number of gpus available - {}".format(args.num_gpus)) + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + device = torch.device("cuda" if use_cuda else "cpu") if is_distributed: # Initialize the distributed environment. - backend = 'gloo' - os.environ['WORLD_SIZE'] = str(world_size) + backend = "gloo" + os.environ["WORLD_SIZE"] = str(world_size) host_rank = args.hosts.index(args.current_host) dist.init_process_group(backend=backend, rank=host_rank, world_size=world_size) - logger.info('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format( - backend, dist.get_world_size()) + 'Current host rank is {}. Is cuda available: {}. Number of gpus: {}'.format( - dist.get_rank(), torch.cuda.is_available(), args.num_gpus)) + logger.info( + "Initialized the distributed environment: '{}' backend on {} nodes. ".format( + backend, dist.get_world_size() + ) + + "Current host rank is {}. Is cuda available: {}. Number of gpus: {}".format( + dist.get_rank(), torch.cuda.is_available(), args.num_gpus + ) + ) # set the seed for generating random numbers seed = 1 @@ -94,31 +115,39 @@ def train(args): if use_cuda: torch.cuda.manual_seed(seed) - train_sampler, train_loader = _get_train_data_loader(args.data_dir, is_distributed, args.batch_size, **kwargs) + train_sampler, train_loader = _get_train_data_loader( + args.data_dir, is_distributed, args.batch_size, **kwargs + ) test_loader = _get_test_data_loader(args.data_dir, **kwargs) - logger.debug('Processes {}/{} ({:.0f}%) of train data'.format( - len(train_loader.sampler), len(train_loader.dataset), - 100. * len(train_loader.sampler) / len(train_loader.dataset) - )) - - logger.debug('Processes {}/{} ({:.0f}%) of test data'.format( - len(test_loader.sampler), len(test_loader.dataset), - 100. 
* len(test_loader.sampler) / len(test_loader.dataset) - )) + logger.debug( + "Processes {}/{} ({:.0f}%) of train data".format( + len(train_loader.sampler), + len(train_loader.dataset), + 100.0 * len(train_loader.sampler) / len(train_loader.dataset), + ) + ) + + logger.debug( + "Processes {}/{} ({:.0f}%) of test data".format( + len(test_loader.sampler), + len(test_loader.dataset), + 100.0 * len(test_loader.sampler) / len(test_loader.dataset), + ) + ) model = Net().to(device) if is_distributed and use_cuda: # multi-machine multi-gpu case - logger.debug('Multi-machine multi-gpu: using DistributedDataParallel.') + logger.debug("Multi-machine multi-gpu: using DistributedDataParallel.") model = torch.nn.parallel.DistributedDataParallel(model) elif use_cuda: # single-machine multi-gpu case - logger.debug('Single-machine multi-gpu: using DataParallel().cuda().') + logger.debug("Single-machine multi-gpu: using DataParallel().cuda().") model = torch.nn.DataParallel(model) else: # single-machine or multi-machine cpu case - logger.debug('Single-machine/multi-machine cpu: using DataParallel.') + logger.debug("Single-machine/multi-machine cpu: using DataParallel.") model = torch.nn.DataParallel(model) optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.5) @@ -139,13 +168,19 @@ def train(args): _average_gradients(model) optimizer.step() if batch_idx % log_interval == 0: - logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.sampler), - 100. * batch_idx / len(train_loader), loss.item())) + logger.debug( + "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.sampler), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) accuracy = test(model, test_loader, device) save_model(model, args.model_dir) - logger.debug('Overall test accuracy: {}'.format(accuracy)) + logger.debug("Overall test accuracy: {}".format(accuracy)) def test(model, test_loader, device): @@ -161,39 +196,42 @@ def test(model, test_loader, device): correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - accuracy = 100. 
* correct / len(test_loader.dataset) + accuracy = 100.0 * correct / len(test_loader.dataset) - logger.debug('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), accuracy)) + logger.debug( + "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), accuracy + ) + ) return accuracy def model_fn(model_dir): model = torch.nn.DataParallel(Net()) - with open(os.path.join(model_dir, 'model.pth'), 'rb') as f: + with open(os.path.join(model_dir, "model.pth"), "rb") as f: model.load_state_dict(torch.load(f)) return model def save_model(model, model_dir): - logger.info('Saving the model.') - path = os.path.join(model_dir, 'model.pth') + logger.info("Saving the model.") + path = os.path.join(model_dir, "model.pth") # recommended way from http://pytorch.org/docs/master/notes/serialization.html torch.save(model.state_dict(), path) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--epochs', type=int, default=1, metavar='N') - parser.add_argument('--batch-size', type=int, default=64, metavar='N') + parser.add_argument("--epochs", type=int, default=1, metavar="N") + parser.add_argument("--batch-size", type=int, default=64, metavar="N") # Container environment - parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS'])) - parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST']) - parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) - parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING']) - parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS']) - parser.add_argument('--num-cpus', type=int, default=os.environ['SM_NUM_CPUS']) + parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"])) + parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"]) + parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"]) + parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"]) + parser.add_argument("--num-cpus", type=int, default=os.environ["SM_NUM_CPUS"]) train(parser.parse_args()) diff --git a/tests/data/pytorch_source_dirs/train.py b/tests/data/pytorch_source_dirs/train.py index e1f0667015..9b89b7d6a8 100644 --- a/tests/data/pytorch_source_dirs/train.py +++ b/tests/data/pytorch_source_dirs/train.py @@ -13,7 +13,7 @@ import alexa import json -MODEL = '/opt/ml/model/answer' +MODEL = "/opt/ml/model/answer" def model_fn(anything): @@ -25,6 +25,6 @@ def predict_fn(input_object, model): return input_object + model -if __name__ == '__main__': - with open(MODEL, 'w') as model: - json.dump(alexa.question('How many roads must a man walk down?'), model) +if __name__ == "__main__": + with open(MODEL, "w") as model: + json.dump(alexa.question("How many roads must a man walk down?"), model) diff --git a/tests/data/ray_cartpole/train_ray.py b/tests/data/ray_cartpole/train_ray.py index e5ccf9df95..aea02f621c 100644 --- a/tests/data/ray_cartpole/train_ray.py +++ b/tests/data/ray_cartpole/train_ray.py @@ -8,7 +8,7 @@ ray.init(log_to_driver=False) config = ppo.DEFAULT_CONFIG.copy() config["num_gpus"] = int(os.environ.get("SM_NUM_GPUS", 0)) -checkpoint_dir = os.environ.get("SM_MODEL_DIR", '/Users/nadzeya/gym') +checkpoint_dir = 
os.environ.get("SM_MODEL_DIR", "/Users/nadzeya/gym") config["num_workers"] = 1 agent = ppo.PPOAgent(config=config, env="CartPole-v0") diff --git a/tests/data/sagemaker_rl/coach_launcher.py b/tests/data/sagemaker_rl/coach_launcher.py index e513dfd8fb..4d69e582e1 100644 --- a/tests/data/sagemaker_rl/coach_launcher.py +++ b/tests/data/sagemaker_rl/coach_launcher.py @@ -5,7 +5,7 @@ from rl_coach.base_parameters import VisualizationParameters, TaskParameters, Frameworks from rl_coach.utils import short_dynamic_import from rl_coach.core_types import SelectedPhaseOnlyDumpFilter, MaxDumpFilter, RunPhase -import rl_coach.core_types +import rl_coach.core_types from rl_coach import logger from rl_coach.logger import screen import argparse @@ -24,7 +24,9 @@ try: from rl_coach.coach import CoachLauncher except ImportError: - raise RuntimeError("Please upgrade to coach-0.11.0. e.g. 388651196716.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-rl-beta:1.11.0-coach11-cpu-py3") + raise RuntimeError( + "Please upgrade to coach-0.11.0. e.g. 388651196716.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-rl-beta:1.11.0-coach11-cpu-py3" + ) screen.set_use_colors(False) # Simple text logging so it looks good in CloudWatch @@ -37,15 +39,14 @@ class CoachConfigurationList(ConfigurationList): # Being security-paranoid and not instantiating any arbitrary string the customer passes in ALLOWED_TYPES = { - 'Frames': rl_coach.core_types.Frames, - 'EnvironmentSteps': rl_coach.core_types.EnvironmentSteps, - 'EnvironmentEpisodes': rl_coach.core_types.EnvironmentEpisodes, - 'TrainingSteps': rl_coach.core_types.TrainingSteps, - 'Time': rl_coach.core_types.Time, + "Frames": rl_coach.core_types.Frames, + "EnvironmentSteps": rl_coach.core_types.EnvironmentSteps, + "EnvironmentEpisodes": rl_coach.core_types.EnvironmentEpisodes, + "TrainingSteps": rl_coach.core_types.TrainingSteps, + "Time": rl_coach.core_types.Time, } - class SageMakerCoachPresetLauncher(CoachLauncher): """Base class for training RL tasks using RL-Coach. Customers subclass this to define specific kinds of workloads, overriding these methods as needed. @@ -55,7 +56,6 @@ def __init__(self): super().__init__() self.hyperparams = None - def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace: """Overrides the default CLI parsing. Sets the configuration parameters for what a SageMaker run should do. @@ -68,15 +68,15 @@ def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace # Now fill in the args that we care about. 
         sagemaker_job_name = os.environ.get("sagemaker_job_name", "sagemaker-experiment")
         args.experiment_name = logger.get_experiment_name(sagemaker_job_name)
-        
+
         # Override experiment_path used for outputs
-        args.experiment_path = '/opt/ml/output/intermediate'
-        rl_coach.logger.experiment_path = '/opt/ml/output/intermediate' # for gifs
+        args.experiment_path = "/opt/ml/output/intermediate"
+        rl_coach.logger.experiment_path = "/opt/ml/output/intermediate"  # for gifs
 
-        args.checkpoint_save_dir = '/opt/ml/output/data/checkpoint'
-        args.checkpoint_save_secs = 10 # should avoid hardcoding
+        args.checkpoint_save_dir = "/opt/ml/output/data/checkpoint"
+        args.checkpoint_save_secs = 10  # should avoid hardcoding
         # onnx for deployment for mxnet (not tensorflow)
-        args.export_onnx_graph = os.getenv('COACH_BACKEND', 'tensorflow') == 'mxnet'
+        args.export_onnx_graph = os.getenv("COACH_BACKEND", "tensorflow") == "mxnet"
 
         args.no_summary = True
 
@@ -97,7 +97,7 @@ def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace
                 name = name[2:]
             else:
                 raise ValueError("Unknown command-line argument %s" % name)
-            val = unknown[i+1]
+            val = unknown[i + 1]
             self.map_hyperparameter(name, val)
 
         return args
@@ -112,29 +112,31 @@ def map_hyperparameter(self, name, value):
         else:
             raise ValueError("Unknown hyperparameter %s" % name)
 
-
     def apply_hyperparameter(self, name, value):
         """Save this hyperparameter to be applied to the graph_manager object
         when it's ready.
         """
-        print("Applying RL hyperparameter %s=%s" % (name,value))
+        print("Applying RL hyperparameter %s=%s" % (name, value))
         self.hyperparameters.store(name, value)
 
-
     def default_preset_name(self):
         """
         Sub-classes will typically return a single hard-coded string.
         """
         try:
-            #TODO: remove this after converting all samples.
+            # TODO: remove this after converting all samples.
             default_preset = self.DEFAULT_PRESET
-            screen.warning("Deprecated configuration of default preset. Please implement default_preset_name()")
+            screen.warning(
+                "Deprecated configuration of default preset. Please implement default_preset_name()"
+            )
             return default_preset
         except:
             pass
-        raise NotImplementedError("Sub-classes must specify the name of the default preset "+
-                                  "for this RL problem. This will be the name of a python "+
-                                  "file (without .py) that defines a graph_manager variable")
+        raise NotImplementedError(
+            "Sub-classes must specify the name of the default preset "
+            + "for this RL problem. This will be the name of a python "
+            + "file (without .py) that defines a graph_manager variable"
+        )
 
     def sagemaker_argparser(self) -> argparse.ArgumentParser:
         """
@@ -143,31 +145,42 @@ def sagemaker_argparser(self) -> argparse.ArgumentParser:
         parser = argparse.ArgumentParser()
 
         # Arguably this would be cleaner if we copied the config from the base class argparser.
-        parser.add_argument('-n', '--num_workers',
-                            help="(int) Number of workers for multi-process based agents, e.g. A3C",
-                            default=1,
-                            type=int)
-        parser.add_argument('-f', '--framework',
-                            help="(string) Neural network framework. Available values: tensorflow, mxnet",
-                            default=os.getenv('COACH_BACKEND', 'tensorflow'),
-                            type=str)
-        parser.add_argument('-p', '--RLCOACH_PRESET',
-                            help="(string) Name of the file with the RLCoach preset",
-                            default=self.default_preset_name(),
-                            type=str)
-        parser.add_argument('--save_model',
-                            help="(int) Flag to save model artifact after training finish",
-                            default=0,
-                            type=int)
+        parser.add_argument(
+            "-n",
+            "--num_workers",
+            help="(int) Number of workers for multi-process based agents, e.g. A3C",
+            default=1,
+            type=int,
+        )
+        parser.add_argument(
+            "-f",
+            "--framework",
+            help="(string) Neural network framework. Available values: tensorflow, mxnet",
+            default=os.getenv("COACH_BACKEND", "tensorflow"),
+            type=str,
+        )
+        parser.add_argument(
+            "-p",
+            "--RLCOACH_PRESET",
+            help="(string) Name of the file with the RLCoach preset",
+            default=self.default_preset_name(),
+            type=str,
+        )
+        parser.add_argument(
+            "--save_model",
+            help="(int) Flag to save model artifact after training finish",
+            default=0,
+            type=int,
+        )
         return parser
 
     def path_of_main_launcher(self):
         """
         A bit of python magic to find the path of the file that launched the current process.
         """
-        main_mod = sys.modules['__main__']
+        main_mod = sys.modules["__main__"]
         try:
-            launcher_file = os.path.abspath(sys.modules['__main__'].__file__)
+            launcher_file = os.path.abspath(sys.modules["__main__"].__file__)
             return os.path.dirname(launcher_file)
         except AttributeError:
             # If __main__.__file__ is missing, then we're probably in an interactive python shell
@@ -176,7 +189,7 @@ def path_of_main_launcher(self):
     def preset_from_name(self, preset_name):
         preset_path = self.path_of_main_launcher()
         print("Loading preset %s from %s" % (preset_name, preset_path))
-        preset_path = os.path.join(self.path_of_main_launcher(),preset_name) + '.py:graph_manager'
+        preset_path = os.path.join(self.path_of_main_launcher(), preset_name) + ".py:graph_manager"
         graph_manager = short_dynamic_import(preset_path, ignore_module_case=True)
         return graph_manager
 
@@ -187,10 +200,10 @@ def get_graph_manager_from_args(self, args):
         self.hyperparameters.apply_subset(graph_manager, "rl.")
         # Set framework
         # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
-        if hasattr(graph_manager, 'agent_params'):
+        if hasattr(graph_manager, "agent_params"):
             for network_parameters in graph_manager.agent_params.network_wrappers.values():
                 network_parameters.framework = args.framework
-        elif hasattr(graph_manager, 'agents_params'):
+        elif hasattr(graph_manager, "agents_params"):
             for ap in graph_manager.agents_params:
                 for network_parameters in ap.network_wrappers.values():
                     network_parameters.framework = args.framework
@@ -198,45 +211,54 @@ def get_graph_manager_from_args(self, args):
 
     def _save_tf_model(self):
         import tensorflow as tf
-        ckpt_dir = '/opt/ml/output/data/checkpoint'
-        model_dir = '/opt/ml/model'
+
+        ckpt_dir = "/opt/ml/output/data/checkpoint"
+        model_dir = "/opt/ml/model"
 
         # Re-Initialize from the checkpoint so that you will have the latest models up.
-        tf.train.init_from_checkpoint(ckpt_dir,
-                                 {'main_level/agent/online/network_0/': 'main_level/agent/online/network_0'})
-        tf.train.init_from_checkpoint(ckpt_dir,
-                                 {'main_level/agent/online/network_1/': 'main_level/agent/online/network_1'})
+        tf.train.init_from_checkpoint(
+            ckpt_dir, {"main_level/agent/online/network_0/": "main_level/agent/online/network_0"}
+        )
+        tf.train.init_from_checkpoint(
+            ckpt_dir, {"main_level/agent/online/network_1/": "main_level/agent/online/network_1"}
+        )

         # Create a new session with a new tf graph.
         sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
         sess.run(tf.global_variables_initializer())  # initialize the checkpoint.

         # This is the node that will accept the input.
-        input_nodes = tf.get_default_graph().get_tensor_by_name('main_level/agent/main/online/' + \
-                                                                'network_0/observation/observation:0')
+        input_nodes = tf.get_default_graph().get_tensor_by_name(
+            "main_level/agent/main/online/" + "network_0/observation/observation:0"
+        )

         # This is the node that will produce the output.
-        output_nodes = tf.get_default_graph().get_operation_by_name('main_level/agent/main/online/' + \
-                                                                    'network_1/ppo_head_0/policy')
+        output_nodes = tf.get_default_graph().get_operation_by_name(
+            "main_level/agent/main/online/" + "network_1/ppo_head_0/policy"
+        )

         # Save the model as a servable model.
-        tf.saved_model.simple_save(session=sess,
-                                   export_dir='model',
-                                   inputs={"observation": input_nodes},
-                                   outputs={"policy": output_nodes.outputs[0]})
+        tf.saved_model.simple_save(
+            session=sess,
+            export_dir="model",
+            inputs={"observation": input_nodes},
+            outputs={"policy": output_nodes.outputs[0]},
+        )
         # Move to the appropriate folder. Don't mind the directory, this just works.
         # rl-cart-pole is the name of the model. Remember it.
-        shutil.move('model/', model_dir + '/model/tf-model/00000001/')
+        shutil.move("model/", model_dir + "/model/tf-model/00000001/")

         # EASE will pick it up and upload to the right path.
         print("Success")

     def _save_onnx_model(self):
-        ckpt_dir = '/opt/ml/output/data/checkpoint'
-        model_dir = '/opt/ml/model'
+        ckpt_dir = "/opt/ml/output/data/checkpoint"
+        model_dir = "/opt/ml/model"
         # find latest onnx file
         # currently done by name, expected to be changed in future release of coach.
-        glob_pattern = os.path.join(ckpt_dir, '*.onnx')
+        glob_pattern = os.path.join(ckpt_dir, "*.onnx")
         onnx_files = [file for file in glob.iglob(glob_pattern, recursive=True)]
         if len(onnx_files) > 0:
-            extract_step = lambda string: int(re.search('/(\d*)_Step.*', string, re.IGNORECASE).group(1))
+            extract_step = lambda string: int(
+                re.search(r"/(\d*)_Step.*", string, re.IGNORECASE).group(1)
+            )
             onnx_files.sort(key=extract_step)
             latest_onnx_file = onnx_files[-1]
             # move to model directory
@@ -245,7 +267,7 @@ def _save_onnx_model(self):
             shutil.move(filepath_from, filepath_to)
         else:
             screen.warning("No ONNX files found in {}".format(ckpt_dir))
-
+
     @classmethod
     def train_main(cls):
         """Entrypoint for training.
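For context between these hunks, a minimal sketch of how this launcher is typically driven from a SageMaker entry-point script. Only default_preset_name() and train_main() come from the class being reformatted here; the import path, subclass name, and preset file name are illustrative assumptions:

    # hypothetical user entry point (e.g. train-coach.py); import path assumed
    from coach_launcher import SageMakerCoachPresetLauncher

    class MyLauncher(SageMakerCoachPresetLauncher):
        def default_preset_name(self):
            # name of a <preset>.py file (without .py) that defines a graph_manager variable
            return "preset_cartpole_clippedppo"

    if __name__ == "__main__":
        MyLauncher.train_main()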
@@ -258,10 +280,10 @@ def train_main(cls):
         parser = trainer.sagemaker_argparser()
         sage_args, unknown = parser.parse_known_args()
         if sage_args.save_model == 1:
-            backend = os.getenv('COACH_BACKEND', 'tensorflow')
-            if backend == 'tensorflow':
+            backend = os.getenv("COACH_BACKEND", "tensorflow")
+            if backend == "tensorflow":
                 trainer._save_tf_model()
-            if backend == 'mxnet':
+            if backend == "mxnet":
                 trainer._save_onnx_model()

@@ -273,10 +295,12 @@ class SageMakerCoachLauncher(SageMakerCoachPresetLauncher):
     def __init__(self):
         super().__init__()
         screen.warning("DEPRECATION WARNING: Please switch to SageMakerCoachPresetLauncher")
-        #TODO: Remove this whole class when nobody's using it any more.
+        # TODO: Remove this whole class when nobody's using it any more.

     def define_environment(self):
-        return NotImplementedEror("Sub-class must define environment e.g. GymVectorEnvironment(level='your_module:YourClass')")
+        raise NotImplementedError(
+            "Sub-class must define environment e.g. GymVectorEnvironment(level='your_module:YourClass')"
+        )

     def get_graph_manager_from_args(self, args):
         """Returns the GraphManager object for coach to use to train by calling improve()
@@ -314,11 +338,16 @@ def config_schedule(self, schedule_params):
         pass

     def define_agent(self):
-        raise NotImplementedError("Subclass must create define_agent() method which returns an AgentParameters object. e.g.\n" \
-            "    return rl_coach.agents.dqn_agent.DQNAgentParameters()");
+        raise NotImplementedError(
+            "Subclass must create define_agent() method which returns an AgentParameters object. e.g.\n"
+            "    return rl_coach.agents.dqn_agent.DQNAgentParameters()"
+        )

     def config_visualization(self, vis_params):
         vis_params.dump_gifs = True
-        vis_params.video_dump_methods = [SelectedPhaseOnlyDumpFilter(RunPhase.TEST), MaxDumpFilter()]
+        vis_params.video_dump_methods = [
+            SelectedPhaseOnlyDumpFilter(RunPhase.TEST),
+            MaxDumpFilter(),
+        ]
         vis_params.print_networks_summary = True
         return vis_params
diff --git a/tests/data/sagemaker_rl/configuration_list.py b/tests/data/sagemaker_rl/configuration_list.py
index 2676d72cb9..4728ba7b60 100644
--- a/tests/data/sagemaker_rl/configuration_list.py
+++ b/tests/data/sagemaker_rl/configuration_list.py
@@ -31,7 +31,7 @@ def apply_subset(self, config_object, prefix):
         for key, val in list(self.hp_dict.items()):
             if key.startswith(prefix):
                 logging.debug("Configuring %s with %s=%s" % (prefix, key, val))
-                subkey = key[ len(prefix): ]
+                subkey = key[len(prefix) :]
                 msg = "%s%s=%s" % (prefix, subkey, val)
                 try:
                     self._set_rl_property_value(config_object, subkey, val, prefix)
@@ -44,17 +44,17 @@ def _set_rl_property_value(self, obj, key, val, path=""):
         """Sets a property on obj to val, or to a sub-object within obj if key looks like "foo.bar"
         """
         if key.find(".") >= 0:
-            top_key, sub_keys = key_list = key.split(".",1)
+            top_key, sub_keys = key_list = key.split(".", 1)
             if top_key.startswith("__"):
                 raise ValueError("Attempting to set unsafe property name %s" % top_key)
-            if isinstance(obj,dict):
+            if isinstance(obj, dict):
                 sub_obj = obj[top_key]
             else:
                 sub_obj = obj.__dict__[top_key]
             # Recurse
-            return self._set_rl_property_value(sub_obj, sub_keys, val, "%s.%s" % (path,top_key) )
+            return self._set_rl_property_value(sub_obj, sub_keys, val, "%s.%s" % (path, top_key))
         else:
-            key, val = self._parse_type(key,val)
+            key, val = self._parse_type(key, val)
             if key.startswith("__"):
                 raise ValueError("Attempting to set unsafe property name %s" % key)
             if isinstance(obj, dict):
@@ -90,8 +90,9 @@ def _parse_type(self, key,
obj_type = key.split(":", 1) cls = self.ALLOWED_TYPES.get(obj_type) if not cls: - raise ValueError("Unrecognized object type %s. Allowed values are %s" % (obj_type, self.ALLOWED_TYPES.keys())) + raise ValueError( + "Unrecognized object type %s. Allowed values are %s" + % (obj_type, self.ALLOWED_TYPES.keys()) + ) val = cls(val) return key, val - - diff --git a/tests/data/sklearn_mnist/failure_script.py b/tests/data/sklearn_mnist/failure_script.py index 9ea99b6660..4b9abfbd16 100644 --- a/tests/data/sklearn_mnist/failure_script.py +++ b/tests/data/sklearn_mnist/failure_script.py @@ -1,3 +1,3 @@ -if __name__=='__main__': +if __name__ == "__main__": """For use with integration tests expecting failures.""" - raise Exception('This failure is expected.') + raise Exception("This failure is expected.") diff --git a/tests/data/sklearn_mnist/mnist.py b/tests/data/sklearn_mnist/mnist.py index 66a5967381..4c9213e40d 100644 --- a/tests/data/sklearn_mnist/mnist.py +++ b/tests/data/sklearn_mnist/mnist.py @@ -21,7 +21,7 @@ def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_format): - images = raw['x'] + images = raw["x"] if ndim == 2: images = images.reshape(-1, 28, 28) elif ndim == 3: @@ -30,44 +30,46 @@ def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb_ images = np.broadcast_to(images, (len(images), 3) + images.shape[2:]) elif ndim != 1: - raise ValueError('invalid ndim for MNIST dataset') + raise ValueError("invalid ndim for MNIST dataset") images = images.astype(image_dtype) - images *= scale / 255. + images *= scale / 255.0 if withlabel: - labels = raw['y'].astype(label_dtype) + labels = raw["y"].astype(label_dtype) return images, labels return images -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Data and model checkpoints directories - parser.add_argument('--epochs', type=int, default=-1) - parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR']) - parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR']) - parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN']) - parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST']) + parser.add_argument("--epochs", type=int, default=-1) + parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"]) + parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) + parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"]) + parser.add_argument("--test", type=str, default=os.environ["SM_CHANNEL_TEST"]) args = parser.parse_args() - train_file = np.load(os.path.join(args.train, 'train.npz')) - test_file = np.load(os.path.join(args.test, 'test.npz')) + train_file = np.load(os.path.join(args.train, "train.npz")) + test_file = np.load(os.path.join(args.test, "test.npz")) - preprocess_mnist_options = {'withlabel': True, - 'ndim': 1, - 'scale': 1., - 'image_dtype': np.float32, - 'label_dtype': np.int32, - 'rgb_format': False} + preprocess_mnist_options = { + "withlabel": True, + "ndim": 1, + "scale": 1.0, + "image_dtype": np.float32, + "label_dtype": np.int32, + "rgb_format": False, + } # Preprocess MNIST data train_images, train_labels = preprocess_mnist(train_file, **preprocess_mnist_options) test_images, test_labels = preprocess_mnist(test_file, **preprocess_mnist_options) # Set up a Support Vector Machine classifier to predict digit from images - clf = svm.SVC(gamma=0.001, C=100., 
max_iter=args.epochs) + clf = svm.SVC(gamma=0.001, C=100.0, max_iter=args.epochs) # Fit the SVM classifier with the images and the corresponding labels clf.fit(train_images, train_labels) diff --git a/tests/data/tensorflow_mnist/mnist.py b/tests/data/tensorflow_mnist/mnist.py index 62e045fc5b..5882011815 100644 --- a/tests/data/tensorflow_mnist/mnist.py +++ b/tests/data/tensorflow_mnist/mnist.py @@ -26,167 +26,158 @@ tf_logger = tf_logging._get_logger() tf_logger.handlers = [_handler] + def cnn_model_fn(features, labels, mode): - """Model function for CNN.""" - # Input Layer - # Reshape X to 4-D tensor: [batch_size, width, height, channels] - # MNIST images are 28x28 pixels, and have one color channel - input_layer = tf.reshape(features["x"], [-1, 28, 28, 1]) - - # Convolutional Layer #1 - # Computes 32 features using a 5x5 filter with ReLU activation. - # Padding is added to preserve width and height. - # Input Tensor Shape: [batch_size, 28, 28, 1] - # Output Tensor Shape: [batch_size, 28, 28, 32] - conv1 = tf.layers.conv2d( - inputs=input_layer, - filters=32, - kernel_size=[5, 5], - padding="same", - activation=tf.nn.relu) - - # Pooling Layer #1 - # First max pooling layer with a 2x2 filter and stride of 2 - # Input Tensor Shape: [batch_size, 28, 28, 32] - # Output Tensor Shape: [batch_size, 14, 14, 32] - pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) - - # Convolutional Layer #2 - # Computes 64 features using a 5x5 filter. - # Padding is added to preserve width and height. - # Input Tensor Shape: [batch_size, 14, 14, 32] - # Output Tensor Shape: [batch_size, 14, 14, 64] - conv2 = tf.layers.conv2d( - inputs=pool1, - filters=64, - kernel_size=[5, 5], - padding="same", - activation=tf.nn.relu) - - # Pooling Layer #2 - # Second max pooling layer with a 2x2 filter and stride of 2 - # Input Tensor Shape: [batch_size, 14, 14, 64] - # Output Tensor Shape: [batch_size, 7, 7, 64] - pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) - - # Flatten tensor into a batch of vectors - # Input Tensor Shape: [batch_size, 7, 7, 64] - # Output Tensor Shape: [batch_size, 7 * 7 * 64] - pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) - - # Dense Layer - # Densely connected layer with 1024 neurons - # Input Tensor Shape: [batch_size, 7 * 7 * 64] - # Output Tensor Shape: [batch_size, 1024] - dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu) - - # Add dropout operation; 0.6 probability that element will be kept - dropout = tf.layers.dropout( - inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN) - - # Logits layer - # Input Tensor Shape: [batch_size, 1024] - # Output Tensor Shape: [batch_size, 10] - logits = tf.layers.dense(inputs=dropout, units=10) - - predictions = { - # Generate predictions (for PREDICT and EVAL mode) - "classes": tf.argmax(input=logits, axis=1), - # Add `softmax_tensor` to the graph. It is used for PREDICT and by the - # `logging_hook`. 
- "probabilities": tf.nn.softmax(logits, name="softmax_tensor") - } - if mode == tf.estimator.ModeKeys.PREDICT: - return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) - - # Calculate Loss (for both TRAIN and EVAL modes) - loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) - - # Configure the Training Op (for TRAIN mode) - if mode == tf.estimator.ModeKeys.TRAIN: - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) - train_op = optimizer.minimize( - loss=loss, - global_step=tf.train.get_global_step()) - return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) - - # Add evaluation metrics (for EVAL mode) - eval_metric_ops = { - "accuracy": tf.metrics.accuracy( - labels=labels, predictions=predictions["classes"])} - return tf.estimator.EstimatorSpec( - mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) + """Model function for CNN.""" + # Input Layer + # Reshape X to 4-D tensor: [batch_size, width, height, channels] + # MNIST images are 28x28 pixels, and have one color channel + input_layer = tf.reshape(features["x"], [-1, 28, 28, 1]) + + # Convolutional Layer #1 + # Computes 32 features using a 5x5 filter with ReLU activation. + # Padding is added to preserve width and height. + # Input Tensor Shape: [batch_size, 28, 28, 1] + # Output Tensor Shape: [batch_size, 28, 28, 32] + conv1 = tf.layers.conv2d( + inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same", activation=tf.nn.relu + ) + + # Pooling Layer #1 + # First max pooling layer with a 2x2 filter and stride of 2 + # Input Tensor Shape: [batch_size, 28, 28, 32] + # Output Tensor Shape: [batch_size, 14, 14, 32] + pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) + + # Convolutional Layer #2 + # Computes 64 features using a 5x5 filter. + # Padding is added to preserve width and height. + # Input Tensor Shape: [batch_size, 14, 14, 32] + # Output Tensor Shape: [batch_size, 14, 14, 64] + conv2 = tf.layers.conv2d( + inputs=pool1, filters=64, kernel_size=[5, 5], padding="same", activation=tf.nn.relu + ) + + # Pooling Layer #2 + # Second max pooling layer with a 2x2 filter and stride of 2 + # Input Tensor Shape: [batch_size, 14, 14, 64] + # Output Tensor Shape: [batch_size, 7, 7, 64] + pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) + + # Flatten tensor into a batch of vectors + # Input Tensor Shape: [batch_size, 7, 7, 64] + # Output Tensor Shape: [batch_size, 7 * 7 * 64] + pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) + + # Dense Layer + # Densely connected layer with 1024 neurons + # Input Tensor Shape: [batch_size, 7 * 7 * 64] + # Output Tensor Shape: [batch_size, 1024] + dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu) + + # Add dropout operation; 0.6 probability that element will be kept + dropout = tf.layers.dropout( + inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN + ) + + # Logits layer + # Input Tensor Shape: [batch_size, 1024] + # Output Tensor Shape: [batch_size, 10] + logits = tf.layers.dense(inputs=dropout, units=10) + + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by the + # `logging_hook`. 
+ "probabilities": tf.nn.softmax(logits, name="softmax_tensor"), + } + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) + + # Calculate Loss (for both TRAIN and EVAL modes) + loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + + # Configure the Training Op (for TRAIN mode) + if mode == tf.estimator.ModeKeys.TRAIN: + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) + train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + + # Add evaluation metrics (for EVAL mode) + eval_metric_ops = { + "accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"]) + } + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) + def _load_training_data(base_dir): - x_train = np.load(os.path.join(base_dir, 'train_data.npy')) - y_train = np.load(os.path.join(base_dir, 'train_labels.npy')) + x_train = np.load(os.path.join(base_dir, "train_data.npy")) + y_train = np.load(os.path.join(base_dir, "train_labels.npy")) return x_train, y_train + def _load_testing_data(base_dir): - x_test = np.load(os.path.join(base_dir, 'eval_data.npy')) - y_test = np.load(os.path.join(base_dir, 'eval_labels.npy')) + x_test = np.load(os.path.join(base_dir, "eval_data.npy")) + y_test = np.load(os.path.join(base_dir, "eval_labels.npy")) return x_test, y_test + def _parse_args(): parser = argparse.ArgumentParser() # hyperparameters sent by the client are passed as command-line arguments to the script. - parser.add_argument('--epochs', type=int, default=1) + parser.add_argument("--epochs", type=int, default=1) # Data, model, and output directories - parser.add_argument('--output-data-dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR')) - parser.add_argument('--model_dir', type=str) - parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAINING')) - parser.add_argument('--hosts', type=list, default=json.loads(os.environ.get('SM_HOSTS'))) - parser.add_argument('--current-host', type=str, default=os.environ.get('SM_CURRENT_HOST')) + parser.add_argument("--output-data-dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR")) + parser.add_argument("--model_dir", type=str) + parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAINING")) + parser.add_argument("--hosts", type=list, default=json.loads(os.environ.get("SM_HOSTS"))) + parser.add_argument("--current-host", type=str, default=os.environ.get("SM_CURRENT_HOST")) return parser.parse_known_args() + def serving_input_fn(): - inputs = {'x': tf.placeholder(tf.float32, [None, 784])} + inputs = {"x": tf.placeholder(tf.float32, [None, 784])} return tf.estimator.export.ServingInputReceiver(inputs, inputs) + if __name__ == "__main__": args, unknown = _parse_args() - if args.model_dir.startswith('s3://'): - os.environ['S3_REGION'] = 'us-west-2' - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' - os.environ['S3_USE_HTTPS'] = '1' + if args.model_dir.startswith("s3://"): + os.environ["S3_REGION"] = "us-west-2" + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" + os.environ["S3_USE_HTTPS"] = "1" train_data, train_labels = _load_training_data(args.train) eval_data, eval_labels = _load_testing_data(args.train) # Create the Estimator - mnist_classifier = tf.estimator.Estimator( - model_fn=cnn_model_fn, model_dir=args.model_dir) + mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, 
model_dir=args.model_dir) # Set up logging for predictions # Log the values in the "Softmax" tensor with label "probabilities" tensors_to_log = {"probabilities": "softmax_tensor"} - logging_hook = tf.train.LoggingTensorHook( - tensors=tensors_to_log, every_n_iter=50) + logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=50) # Train the model train_input_fn = tf.estimator.inputs.numpy_input_fn( - x={"x": train_data}, - y=train_labels, - batch_size=50, - num_epochs=None, - shuffle=True) + x={"x": train_data}, y=train_labels, batch_size=50, num_epochs=None, shuffle=True + ) # Evaluate the model and print results eval_input_fn = tf.estimator.inputs.numpy_input_fn( - x={"x": eval_data}, - y=eval_labels, - num_epochs=1, - shuffle=False) + x={"x": eval_data}, y=eval_labels, num_epochs=1, shuffle=False + ) train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=1000) eval_spec = tf.estimator.EvalSpec(eval_input_fn) tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec) if args.current_host == args.hosts[0]: - mnist_classifier.export_savedmodel('/opt/ml/model', serving_input_fn) + mnist_classifier.export_savedmodel("/opt/ml/model", serving_input_fn) - tf_logger.info('====== Training finished =========') + tf_logger.info("====== Training finished =========") diff --git a/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py index 2fe2eb3327..0dfd36c46d 100644 --- a/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py +++ b/tests/data/tfs/tfs-test-entrypoint-and-dependencies/inference.py @@ -14,10 +14,11 @@ import dependency + def input_handler(data, context): - data = json.loads(data.read().decode('utf-8')) - new_values = [x + 1 for x in data['instances']] - dumps = json.dumps({'instances': new_values}) + data = json.loads(data.read().decode("utf-8")) + new_values = [x + 1 for x in data["instances"]] + dumps = json.dumps({"instances": new_values}) return dumps diff --git a/tests/data/tfs/tfs-test-entrypoint-with-handler/inference.py b/tests/data/tfs/tfs-test-entrypoint-with-handler/inference.py index a3929bca25..495c0dde24 100644 --- a/tests/data/tfs/tfs-test-entrypoint-with-handler/inference.py +++ b/tests/data/tfs/tfs-test-entrypoint-with-handler/inference.py @@ -22,13 +22,13 @@ def save_model(): - shutil.copytree('/opt/ml/code/123', '/opt/ml/model/123') + shutil.copytree("/opt/ml/code/123", "/opt/ml/model/123") def input_handler(data, context): - data = json.loads(data.read().decode('utf-8')) - new_values = [x + 1 for x in data['instances']] - dumps = json.dumps({'instances': new_values}) + data = json.loads(data.read().decode("utf-8")) + new_values = [x + 1 for x in data["instances"]] + dumps = json.dumps({"instances": new_values}) return dumps @@ -40,4 +40,3 @@ def output_handler(data, context): if __name__ == "__main__": save_model() - diff --git a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py index 2f691fea1d..c89138edc8 100644 --- a/tests/data/tfs/tfs-test-model-with-inference/code/inference.py +++ b/tests/data/tfs/tfs-test-model-with-inference/code/inference.py @@ -12,10 +12,11 @@ # language governing permissions and limitations under the License. 
import json + def input_handler(data, context): - data = json.loads(data.read().decode('utf-8')) - new_values = [x + 1 for x in data['instances']] - dumps = json.dumps({'instances': new_values}) + data = json.loads(data.read().decode("utf-8")) + new_values = [x + 1 for x in data["instances"]] + dumps = json.dumps({"instances": new_values}) return dumps diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 36c2a6e518..1a99c29601 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -18,25 +18,39 @@ import boto3 -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") TRAINING_DEFAULT_TIMEOUT_MINUTES = 20 TUNING_DEFAULT_TIMEOUT_MINUTES = 20 TRANSFORM_DEFAULT_TIMEOUT_MINUTES = 20 -PYTHON_VERSION = 'py' + str(sys.version_info.major) +PYTHON_VERSION = "py" + str(sys.version_info.major) # these regions have some p2 and p3 instances, but not enough for continuous testing -HOSTING_NO_P2_REGIONS = ['ca-central-1', 'eu-central-1', 'eu-west-2', 'us-west-1'] -HOSTING_NO_P3_REGIONS = ['ap-southeast-1', 'ap-southeast-2', 'ap-south-1', 'ca-central-1', - 'eu-central-1', 'eu-west-2', 'us-west-1'] -TRAINING_NO_P2_REGIONS = ['ap-southeast-1', 'ap-southeast-2'] +HOSTING_NO_P2_REGIONS = ["ca-central-1", "eu-central-1", "eu-west-2", "us-west-1"] +HOSTING_NO_P3_REGIONS = [ + "ap-southeast-1", + "ap-southeast-2", + "ap-south-1", + "ca-central-1", + "eu-central-1", + "eu-west-2", + "us-west-1", +] +TRAINING_NO_P2_REGIONS = ["ap-southeast-1", "ap-southeast-2"] # EI is currently only supported in the following regions # regions were derived from https://aws.amazon.com/machine-learning/elastic-inference/pricing/ -EI_SUPPORTED_REGIONS = ['us-east-1', 'us-east-2', 'us-west-2', 'eu-west-1', 'ap-northeast-1', 'ap-northeast-2'] +EI_SUPPORTED_REGIONS = [ + "us-east-1", + "us-east-2", + "us-west-2", + "eu-west-1", + "ap-northeast-1", + "ap-northeast-2", +] -logging.getLogger('boto3').setLevel(logging.INFO) -logging.getLogger('botocore').setLevel(logging.INFO) +logging.getLogger("boto3").setLevel(logging.INFO) +logging.getLogger("botocore").setLevel(logging.INFO) def test_region(): - return os.environ.get('TEST_AWS_REGION_NAME', boto3.session.Session().region_name) + return os.environ.get("TEST_AWS_REGION_NAME", boto3.session.Session().region_name) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 8d8092569e..a0c9f1cb2e 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -24,12 +24,12 @@ def create_sagemaker_local_network(): creates the network sagemaker-local beforehand, avoiding this issue in CI. 
""" - os.system('docker network create sagemaker-local') + os.system("docker network create sagemaker-local") create_sagemaker_local_network() -@pytest.fixture(scope='session', params=['local', 'ml.c4.xlarge']) +@pytest.fixture(scope="session", params=["local", "ml.c4.xlarge"]) def instance_type(request): return request.param diff --git a/tests/integ/kms_utils.py b/tests/integ/kms_utils.py index f5e5e0c5aa..dc7ad79097 100644 --- a/tests/integ/kms_utils.py +++ b/tests/integ/kms_utils.py @@ -17,13 +17,14 @@ from botocore import exceptions -PRINCIPAL_TEMPLATE = '["{account_id}", "{role_arn}", ' \ - '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] ' - -KEY_ALIAS = 'SageMakerTestKMSKey' -KMS_S3_ALIAS = 'SageMakerTestS3KMSKey' -POLICY_NAME = 'default' -KEY_POLICY = ''' +PRINCIPAL_TEMPLATE = ( + '["{account_id}", "{role_arn}", ' '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] ' +) + +KEY_ALIAS = "SageMakerTestKMSKey" +KMS_S3_ALIAS = "SageMakerTestS3KMSKey" +POLICY_NAME = "default" +KEY_POLICY = """ {{ "Version": "2012-10-17", "Id": "{id}", @@ -39,87 +40,82 @@ }} ] }} -''' +""" def _get_kms_key_arn(kms_client, alias): try: - response = kms_client.describe_key(KeyId='alias/' + alias) - return response['KeyMetadata']['Arn'] + response = kms_client.describe_key(KeyId="alias/" + alias) + return response["KeyMetadata"]["Arn"] except kms_client.exceptions.NotFoundException: return None def _get_kms_key_id(kms_client, alias): try: - response = kms_client.describe_key(KeyId='alias/' + alias) - return response['KeyMetadata']['KeyId'] + response = kms_client.describe_key(KeyId="alias/" + alias) + return response["KeyMetadata"]["KeyId"] except kms_client.exceptions.NotFoundException: return None -def _create_kms_key(kms_client, - account_id, - role_arn=None, - sagemaker_role='SageMakerRole', - alias=KEY_ALIAS): +def _create_kms_key( + kms_client, account_id, role_arn=None, sagemaker_role="SageMakerRole", alias=KEY_ALIAS +): if role_arn: - principal = PRINCIPAL_TEMPLATE.format(account_id=account_id, - role_arn=role_arn, - sagemaker_role=sagemaker_role) + principal = PRINCIPAL_TEMPLATE.format( + account_id=account_id, role_arn=role_arn, sagemaker_role=sagemaker_role + ) else: principal = '"{account_id}"'.format(account_id=account_id) response = kms_client.create_key( - Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role), - Description='KMS key for SageMaker Python SDK integ tests', + Policy=KEY_POLICY.format( + id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role + ), + Description="KMS key for SageMaker Python SDK integ tests", ) - key_arn = response['KeyMetadata']['Arn'] + key_arn = response["KeyMetadata"]["Arn"] if alias: - kms_client.create_alias(AliasName='alias/' + alias, TargetKeyId=key_arn) + kms_client.create_alias(AliasName="alias/" + alias, TargetKeyId=key_arn) return key_arn -def _add_role_to_policy(kms_client, - account_id, - role_arn, - alias=KEY_ALIAS, - sagemaker_role='SageMakerRole'): +def _add_role_to_policy( + kms_client, account_id, role_arn, alias=KEY_ALIAS, sagemaker_role="SageMakerRole" +): key_id = _get_kms_key_id(kms_client, alias) policy = kms_client.get_key_policy(KeyId=key_id, PolicyName=POLICY_NAME) - policy = json.loads(policy['Policy']) - principal = policy['Statement'][0]['Principal']['AWS'] + policy = json.loads(policy["Policy"]) + principal = policy["Statement"][0]["Principal"]["AWS"] if role_arn not in principal or sagemaker_role not in principal: - principal = PRINCIPAL_TEMPLATE.format(account_id=account_id, - 
role_arn=role_arn, - sagemaker_role=sagemaker_role) + principal = PRINCIPAL_TEMPLATE.format( + account_id=account_id, role_arn=role_arn, sagemaker_role=sagemaker_role + ) - kms_client.put_key_policy(KeyId=key_id, - PolicyName=POLICY_NAME, - Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal)) + kms_client.put_key_policy( + KeyId=key_id, + PolicyName=POLICY_NAME, + Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal), + ) -def get_or_create_kms_key(sagemaker_session, - role_arn=None, - alias=KEY_ALIAS, - sagemaker_role='SageMakerRole'): - kms_client = sagemaker_session.boto_session.client('kms') +def get_or_create_kms_key( + sagemaker_session, role_arn=None, alias=KEY_ALIAS, sagemaker_role="SageMakerRole" +): + kms_client = sagemaker_session.boto_session.client("kms") kms_key_arn = _get_kms_key_arn(kms_client, alias) - sts_client = sagemaker_session.boto_session.client('sts') - account_id = sts_client.get_caller_identity()['Account'] + sts_client = sagemaker_session.boto_session.client("sts") + account_id = sts_client.get_caller_identity()["Account"] if kms_key_arn is None: return _create_kms_key(kms_client, account_id, role_arn, sagemaker_role, alias) if role_arn: - _add_role_to_policy(kms_client, - account_id, - role_arn, - alias, - sagemaker_role) + _add_role_to_policy(kms_client, account_id, role_arn, alias, sagemaker_role) return kms_key_arn @@ -158,48 +154,46 @@ def get_or_create_kms_key(sagemaker_session, @contextlib.contextmanager def bucket_with_encryption(boto_session, sagemaker_role): - account = boto_session.client('sts').get_caller_identity()['Account'] - role_arn = boto_session.client('sts').get_caller_identity()['Arn'] + account = boto_session.client("sts").get_caller_identity()["Account"] + role_arn = boto_session.client("sts").get_caller_identity()["Arn"] - kms_client = boto_session.client('kms') + kms_client = boto_session.client("kms") kms_key_arn = _create_kms_key(kms_client, account, role_arn, sagemaker_role, None) region = boto_session.region_name - bucket_name = 'sagemaker-{}-{}-with-kms'.format(region, account) + bucket_name = "sagemaker-{}-{}-with-kms".format(region, account) - s3 = boto_session.client('s3') + s3 = boto_session.client("s3") try: # 'us-east-1' cannot be specified because it is the default region: # https://github.com/boto/boto3/issues/125 - if region == 'us-east-1': + if region == "us-east-1": s3.create_bucket(Bucket=bucket_name) else: - s3.create_bucket(Bucket=bucket_name, - CreateBucketConfiguration={'LocationConstraint': region}) + s3.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} + ) except exceptions.ClientError as e: - if e.response['Error']['Code'] != 'BucketAlreadyOwnedByYou': + if e.response["Error"]["Code"] != "BucketAlreadyOwnedByYou": raise s3.put_bucket_encryption( Bucket=bucket_name, ServerSideEncryptionConfiguration={ - 'Rules': [ + "Rules": [ { - 'ApplyServerSideEncryptionByDefault': { - 'SSEAlgorithm': 'aws:kms', - 'KMSMasterKeyID': kms_key_arn + "ApplyServerSideEncryptionByDefault": { + "SSEAlgorithm": "aws:kms", + "KMSMasterKeyID": kms_key_arn, } - }, + } ] - } + }, ) - s3.put_bucket_policy( - Bucket=bucket_name, - Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name) - ) + s3.put_bucket_policy(Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)) - yield 's3://' + bucket_name, kms_key_arn + yield "s3://" + bucket_name, kms_key_arn kms_client.schedule_key_deletion(KeyId=kms_key_arn, PendingWindowInDays=7) diff --git 
a/tests/integ/lock.py b/tests/integ/lock.py
index 07d651d2b1..b547d9b1e5 100644
--- a/tests/integ/lock.py
+++ b/tests/integ/lock.py
@@ -18,7 +18,7 @@
 import tempfile
 from contextlib import contextmanager

-DEFAULT_LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_lock')
+DEFAULT_LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_lock")

 @contextmanager
@@ -27,7 +27,7 @@ def lock(path=DEFAULT_LOCK_PATH):
     test operations need to limit concurrency to work reliably. Examples include
     local mode endpoint tests and vpc creation tests.
     """
-    f = open(path, 'w')
+    f = open(path, "w")
     fd = f.fileno()
     fcntl.lockf(fd, fcntl.LOCK_EX)
diff --git a/tests/integ/marketplace_utils.py b/tests/integ/marketplace_utils.py
index c830927a46..5bfb293914 100644
--- a/tests/integ/marketplace_utils.py
+++ b/tests/integ/marketplace_utils.py
@@ -13,17 +13,17 @@
 from __future__ import absolute_import

 REGION_ACCOUNT_MAP = {
-    'us-east-1': '865070037744',
-    'us-east-2': '057799348421',
-    'us-west-2': '594846645681',
-    'eu-west-1': '985815980388',
-    'eu-central-1': '446921602837',
-    'ap-northeast-1': '977537786026',
-    'ap-northeast-2': '745090734665',
-    'ap-southeast-2': '666831318237',
-    'ap-southeast-1': '192199979996',
-    'ap-south-1': '077584701553',
-    'ca-central-1': '470592106596',
-    'eu-west-2': '856760150666',
-    'us-west-1': '382657785993'
+    "us-east-1": "865070037744",
+    "us-east-2": "057799348421",
+    "us-west-2": "594846645681",
+    "eu-west-1": "985815980388",
+    "eu-central-1": "446921602837",
+    "ap-northeast-1": "977537786026",
+    "ap-northeast-2": "745090734665",
+    "ap-southeast-2": "666831318237",
+    "ap-southeast-1": "192199979996",
+    "ap-south-1": "077584701553",
+    "ca-central-1": "470592106596",
+    "eu-west-2": "856760150666",
+    "us-west-1": "382657785993",
 }
diff --git a/tests/integ/record_set.py b/tests/integ/record_set.py
index 77da668fbd..e71134f550 100644
--- a/tests/integ/record_set.py
+++ b/tests/integ/record_set.py
@@ -18,7 +18,9 @@
 from sagemaker.utils import sagemaker_timestamp

-def prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session):
+def prepare_record_set_from_local_files(
+    dir_path, destination, num_records, feature_dim, sagemaker_session
+):
     """Build a :class:`~RecordSet` by pointing to local files.

     Args:
@@ -31,7 +33,7 @@ def prepare_record_set_from_local_files(dir_path, destination, num_records, feat
         RecordSet: A RecordSet specified by S3Prefix to be used in training.
""" key_prefix = urlparse(destination).path - key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp()) - key_prefix = key_prefix.lstrip('/') + key_prefix = key_prefix + "{}-{}".format("testfiles", sagemaker_timestamp()) + key_prefix = key_prefix.lstrip("/") uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix) - return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix') + return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type="S3Prefix") diff --git a/tests/integ/test_byo_estimator.py b/tests/integ/test_byo_estimator.py index 94c3e61409..22b2318e36 100644 --- a/tests/integ/test_byo_estimator.py +++ b/tests/integ/test_byo_estimator.py @@ -28,15 +28,15 @@ from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def region(sagemaker_session): return sagemaker_session.boto_session.region_name def fm_serializer(data): - js = {'instances': []} + js = {"instances": []} for row in data: - js['instances'].append({'features': row.tolist()}) + js["instances"].append({"features": row.tolist()}) return json.dumps(js) @@ -53,94 +53,101 @@ def test_byo_estimator(sagemaker_session, region): """ image_name = registry(region) + "/factorization-machines:1" - training_data_path = os.path.join(DATA_DIR, 'dummy_tensor') - job_name = unique_name_from_base('byo') + training_data_path = os.path.join(DATA_DIR, "dummy_tensor") + job_name = unique_name_from_base("byo") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - prefix = 'test_byo_estimator' - key = 'recordio-pb-data' + prefix = "test_byo_estimator" + key = "recordio-pb-data" - s3_train_data = sagemaker_session.upload_data(path=training_data_path, - key_prefix=os.path.join(prefix, 'train', key)) + s3_train_data = sagemaker_session.upload_data( + path=training_data_path, key_prefix=os.path.join(prefix, "train", key) + ) - estimator = Estimator(image_name=image_name, - role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + estimator = Estimator( + image_name=image_name, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) - estimator.set_hyperparameters(num_factors=10, - feature_dim=784, - mini_batch_size=100, - predictor_type='binary_classifier') + estimator.set_hyperparameters( + num_factors=10, feature_dim=784, mini_batch_size=100, predictor_type="binary_classifier" + ) # training labels must be 'float32' - estimator.fit({'train': s3_train_data}, job_name=job_name) + estimator.fit({"train": s3_train_data}, job_name=job_name) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): model = estimator.create_model() - predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=job_name) + predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=job_name) predictor.serializer = fm_serializer - predictor.content_type = 'application/json' + predictor.content_type = "application/json" 
predictor.deserializer = sagemaker.predictor.json_deserializer result = predictor.predict(train_set[0][:10]) - assert len(result['predictions']) == 10 - for prediction in result['predictions']: - assert prediction['score'] is not None + assert len(result["predictions"]) == 10 + for prediction in result["predictions"]: + assert prediction["score"] is not None def test_async_byo_estimator(sagemaker_session, region): image_name = registry(region) + "/factorization-machines:1" - endpoint_name = unique_name_from_base('byo') - training_data_path = os.path.join(DATA_DIR, 'dummy_tensor') - job_name = unique_name_from_base('byo') + endpoint_name = unique_name_from_base("byo") + training_data_path = os.path.join(DATA_DIR, "dummy_tensor") + job_name = unique_name_from_base("byo") with timeout(minutes=5): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - prefix = 'test_byo_estimator' - key = 'recordio-pb-data' + prefix = "test_byo_estimator" + key = "recordio-pb-data" - s3_train_data = sagemaker_session.upload_data(path=training_data_path, - key_prefix=os.path.join(prefix, 'train', key)) + s3_train_data = sagemaker_session.upload_data( + path=training_data_path, key_prefix=os.path.join(prefix, "train", key) + ) - estimator = Estimator(image_name=image_name, - role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + estimator = Estimator( + image_name=image_name, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) - estimator.set_hyperparameters(num_factors=10, - feature_dim=784, - mini_batch_size=100, - predictor_type='binary_classifier') + estimator.set_hyperparameters( + num_factors=10, feature_dim=784, mini_batch_size=100, predictor_type="binary_classifier" + ) # training labels must be 'float32' - estimator.fit({'train': s3_train_data}, wait=False, job_name=job_name) + estimator.fit({"train": s3_train_data}, wait=False, job_name=job_name) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - estimator = Estimator.attach(training_job_name=job_name, - sagemaker_session=sagemaker_session) + estimator = Estimator.attach( + training_job_name=job_name, sagemaker_session=sagemaker_session + ) model = estimator.create_model() - predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) predictor.serializer = fm_serializer - predictor.content_type = 'application/json' + predictor.content_type = "application/json" predictor.deserializer = sagemaker.predictor.json_deserializer result = predictor.predict(train_set[0][:10]) - assert len(result['predictions']) == 10 - for prediction in result['predictions']: - assert prediction['score'] is not None + assert len(result["predictions"]) == 10 + for prediction in result["predictions"]: + assert prediction["score"] is not None assert estimator.train_image() == image_name diff --git a/tests/integ/test_chainer_train.py b/tests/integ/test_chainer_train.py index 515b7cc05a..ebfee9469c 100644 --- a/tests/integ/test_chainer_train.py +++ 
b/tests/integ/test_chainer_train.py @@ -27,7 +27,7 @@ from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def chainer_training_job(sagemaker_session, chainer_full_version): return _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 1, chainer_full_version) @@ -36,124 +36,160 @@ def test_distributed_cpu_training(sagemaker_session, chainer_full_version): _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 2, chainer_full_version) -@pytest.mark.skipif(tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS - or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, - reason='no ml.p2 instances in these regions') +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS + or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, + reason="no ml.p2 instances in these regions", +) def test_distributed_gpu_training(sagemaker_session, chainer_full_version): _run_mnist_training_job(sagemaker_session, "ml.p2.xlarge", 2, chainer_full_version) def test_training_with_additional_hyperparameters(sagemaker_session, chainer_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'mnist.py') - data_path = os.path.join(DATA_DIR, 'chainer_mnist') - - chainer = Chainer(entry_point=script_path, role='SageMakerRole', - train_instance_count=1, train_instance_type="ml.c4.xlarge", - framework_version=chainer_full_version, - py_version=PYTHON_VERSION, - sagemaker_session=sagemaker_session, hyperparameters={'epochs': 1}, - use_mpi=True, - num_processes=2, - process_slots_per_host=2, - additional_mpi_options="-x NCCL_DEBUG=INFO") - - train_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/chainer_mnist/train') - test_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/chainer_mnist/test') - - job_name = unique_name_from_base('test-chainer-training') - chainer.fit({'train': train_input, 'test': test_input}, job_name=job_name) + script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py") + data_path = os.path.join(DATA_DIR, "chainer_mnist") + + chainer = Chainer( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + framework_version=chainer_full_version, + py_version=PYTHON_VERSION, + sagemaker_session=sagemaker_session, + hyperparameters={"epochs": 1}, + use_mpi=True, + num_processes=2, + process_slots_per_host=2, + additional_mpi_options="-x NCCL_DEBUG=INFO", + ) + + train_input = chainer.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/chainer_mnist/train" + ) + test_input = chainer.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/chainer_mnist/test" + ) + + job_name = unique_name_from_base("test-chainer-training") + chainer.fit({"train": train_input, "test": test_input}, job_name=job_name) return chainer.latest_training_job.name @pytest.mark.canary_quick @pytest.mark.regional_testing def test_attach_deploy(chainer_training_job, sagemaker_session): - endpoint_name = unique_name_from_base('test-chainer-attach-deploy') + endpoint_name = unique_name_from_base("test-chainer-attach-deploy") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): estimator = 
Chainer.attach(chainer_training_job, sagemaker_session=sagemaker_session) - predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + predictor = estimator.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) def test_deploy_model(chainer_training_job, sagemaker_session): - endpoint_name = unique_name_from_base('test-chainer-deploy-model') + endpoint_name = unique_name_from_base("test-chainer-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=chainer_training_job) - model_data = desc['ModelArtifacts']['S3ModelArtifacts'] - script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'mnist.py') - model = ChainerModel(model_data, 'SageMakerRole', entry_point=script_path, sagemaker_session=sagemaker_session) + desc = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=chainer_training_job + ) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py") + model = ChainerModel( + model_data, + "SageMakerRole", + entry_point=script_path, + sagemaker_session=sagemaker_session, + ) predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) def test_async_fit(sagemaker_session): with timeout(minutes=5): - training_job_name = _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", 1, - chainer_full_version=CHAINER_VERSION, wait=False) + training_job_name = _run_mnist_training_job( + sagemaker_session, "ml.c4.xlarge", 1, chainer_full_version=CHAINER_VERSION, wait=False + ) print("Waiting to re-attach to the training job: %s" % training_job_name) time.sleep(20) - endpoint_name = unique_name_from_base('test-chainer-async-fit') + endpoint_name = unique_name_from_base("test-chainer-async-fit") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): print("Re-attaching now to: %s" % training_job_name) - estimator = Chainer.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session) + estimator = Chainer.attach( + training_job_name=training_job_name, sagemaker_session=sagemaker_session + ) predictor = estimator.deploy(1, "ml.c4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) def test_failed_training_job(sagemaker_session, chainer_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'failure_script.py') - - chainer = Chainer(entry_point=script_path, role='SageMakerRole', - framework_version=chainer_full_version, py_version=PYTHON_VERSION, - train_instance_count=1, train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + script_path = os.path.join(DATA_DIR, "chainer_mnist", "failure_script.py") + + chainer = Chainer( + entry_point=script_path, + role="SageMakerRole", + framework_version=chainer_full_version, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) with pytest.raises(ValueError) as e: - chainer.fit(job_name=unique_name_from_base('test-chainer-training')) - assert 'ExecuteUserScriptError' in str(e.value) + chainer.fit(job_name=unique_name_from_base("test-chainer-training")) + assert "ExecuteUserScriptError" in str(e.value) -def _run_mnist_training_job(sagemaker_session, instance_type, instance_count, - chainer_full_version, wait=True): +def 
_run_mnist_training_job(
+    sagemaker_session, instance_type, instance_count, chainer_full_version, wait=True
+):
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'mnist.py') if instance_type == 1 else \
-            os.path.join(DATA_DIR, 'chainer_mnist', 'distributed_mnist.py')
-
-        data_path = os.path.join(DATA_DIR, 'chainer_mnist')
-
-        chainer = Chainer(entry_point=script_path, role='SageMakerRole',
-                          framework_version=chainer_full_version, py_version=PYTHON_VERSION,
-                          train_instance_count=instance_count, train_instance_type=instance_type,
-                          sagemaker_session=sagemaker_session, hyperparameters={'epochs': 1})
-
-        train_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                            key_prefix='integ-test-data/chainer_mnist/train')
-        test_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                           key_prefix='integ-test-data/chainer_mnist/test')
-
-        job_name = unique_name_from_base('test-chainer-training')
-        chainer.fit({'train': train_input, 'test': test_input}, wait=wait, job_name=job_name)
+        script_path = (
+            os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
+            if instance_count == 1
+            else os.path.join(DATA_DIR, "chainer_mnist", "distributed_mnist.py")
+        )
+
+        data_path = os.path.join(DATA_DIR, "chainer_mnist")
+
+        chainer = Chainer(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=chainer_full_version,
+            py_version=PYTHON_VERSION,
+            train_instance_count=instance_count,
+            train_instance_type=instance_type,
+            sagemaker_session=sagemaker_session,
+            hyperparameters={"epochs": 1},
+        )
+
+        train_input = chainer.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/chainer_mnist/train"
+        )
+        test_input = chainer.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/chainer_mnist/test"
+        )
+
+        job_name = unique_name_from_base("test-chainer-training")
+        chainer.fit({"train": train_input, "test": test_input}, wait=wait, job_name=job_name)
         return chainer.latest_training_job.name

 def _predict_and_assert(predictor):
     batch_size = 100
-    data = numpy.zeros((batch_size, 784), dtype='float32')
+    data = numpy.zeros((batch_size, 784), dtype="float32")
     output = predictor.predict(data)
     assert len(output) == batch_size

-    data = numpy.zeros((batch_size, 1, 28, 28), dtype='float32')
+    data = numpy.zeros((batch_size, 1, 28, 28), dtype="float32")
     output = predictor.predict(data)
     assert len(output) == batch_size

-    data = numpy.zeros((batch_size, 28, 28), dtype='float32')
+    data = numpy.zeros((batch_size, 28, 28), dtype="float32")
     output = predictor.predict(data)
     assert len(output) == batch_size
diff --git a/tests/integ/test_data_upload.py b/tests/integ/test_data_upload.py
index c009f137e9..3499c46733 100755
--- a/tests/integ/test_data_upload.py
+++ b/tests/integ/test_data_upload.py
@@ -18,28 +18,28 @@
 from tests.integ import DATA_DIR

-AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'}
+AES_ENCRYPTION_ENABLED = {"ServerSideEncryption": "AES256"}

 def test_upload_data_absolute_file(sagemaker_session):
     """Test the method ``Session.upload_data`` can upload one encrypted file to S3 bucket"""
-    data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'file1.py')
+    data_path = os.path.join(DATA_DIR, "upload_data_tests", "file1.py")
     uploaded_file = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED)
     parsed_url = urlparse(uploaded_file)
-    s3_client =
sagemaker_session.boto_session.client('s3') - head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/')) - assert head['ServerSideEncryption'] == 'AES256' + s3_client = sagemaker_session.boto_session.client("s3") + head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip("/")) + assert head["ServerSideEncryption"] == "AES256" def test_upload_data_absolute_dir(sagemaker_session): """Test the method ``Session.upload_data`` can upload encrypted objects to S3 bucket""" - data_path = os.path.join(DATA_DIR, 'upload_data_tests', 'nested_dir') + data_path = os.path.join(DATA_DIR, "upload_data_tests", "nested_dir") uploaded_dir = sagemaker_session.upload_data(data_path, extra_args=AES_ENCRYPTION_ENABLED) parsed_url = urlparse(uploaded_dir) s3_bucket = parsed_url.netloc - s3_prefix = parsed_url.path.lstrip('/') - s3_client = sagemaker_session.boto_session.client('s3') + s3_prefix = parsed_url.path.lstrip("/") + s3_client = sagemaker_session.boto_session.client("s3") for file in os.listdir(data_path): - s3_key = '{}/{}'.format(s3_prefix, file) + s3_key = "{}/{}".format(s3_prefix, file) head = s3_client.head_object(Bucket=s3_bucket, Key=s3_key) - assert head['ServerSideEncryption'] == 'AES256' + assert head["ServerSideEncryption"] == "AES256" diff --git a/tests/integ/test_factorization_machines.py b/tests/integ/test_factorization_machines.py index 6456c24238..55a576b88a 100644 --- a/tests/integ/test_factorization_machines.py +++ b/tests/integ/test_factorization_machines.py @@ -25,30 +25,40 @@ def test_factorization_machines(sagemaker_session): - job_name = unique_name_from_base('fm') + job_name = unique_name_from_base("fm") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - fm = FactorizationMachines(role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - num_factors=10, predictor_type='regressor', - epochs=2, clip_gradient=1e2, eps=0.001, rescale_grad=1.0 / 100, - sagemaker_session=sagemaker_session) + fm = FactorizationMachines( + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + num_factors=10, + predictor_type="regressor", + epochs=2, + clip_gradient=1e2, + eps=0.001, + rescale_grad=1.0 / 100, + sagemaker_session=sagemaker_session, + ) # training labels must be 'float32' - fm.fit(fm.record_set(train_set[0][:200], train_set[1][:200].astype('float32')), - job_name=job_name) + fm.fit( + fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32")), + job_name=job_name, + ) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - model = FactorizationMachinesModel(fm.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + model = FactorizationMachinesModel( + fm.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][:10]) assert len(result) == 10 @@ -57,37 +67,48 @@ def 
test_factorization_machines(sagemaker_session): def test_async_factorization_machines(sagemaker_session): - job_name = unique_name_from_base('fm') + job_name = unique_name_from_base("fm") with timeout(minutes=5): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - fm = FactorizationMachines(role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - num_factors=10, predictor_type='regressor', - epochs=2, clip_gradient=1e2, eps=0.001, rescale_grad=1.0 / 100, - sagemaker_session=sagemaker_session) + fm = FactorizationMachines( + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + num_factors=10, + predictor_type="regressor", + epochs=2, + clip_gradient=1e2, + eps=0.001, + rescale_grad=1.0 / 100, + sagemaker_session=sagemaker_session, + ) # training labels must be 'float32' - fm.fit(fm.record_set(train_set[0][:200], train_set[1][:200].astype('float32')), - job_name=job_name, - wait=False) + fm.fit( + fm.record_set(train_set[0][:200], train_set[1][:200].astype("float32")), + job_name=job_name, + wait=False, + ) print("Detached from training job. Will re-attach in 20 seconds") time.sleep(20) print("attaching now...") with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - estimator = FactorizationMachines.attach(training_job_name=job_name, - sagemaker_session=sagemaker_session) - model = FactorizationMachinesModel(estimator.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + estimator = FactorizationMachines.attach( + training_job_name=job_name, sagemaker_session=sagemaker_session + ) + model = FactorizationMachinesModel( + estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][:10]) assert len(result) == 10 diff --git a/tests/integ/test_horovod.py b/tests/integ/test_horovod.py index 9a2bc8c737..76e50b15e6 100644 --- a/tests/integ/test_horovod.py +++ b/tests/integ/test_horovod.py @@ -25,31 +25,39 @@ from sagemaker.tensorflow import TensorFlow from tests.integ import test_region, timeout, HOSTING_NO_P3_REGIONS -horovod_dir = os.path.join(os.path.dirname(__file__), '..', 'data', 'horovod') - - -@pytest.fixture(scope='session', params=[ - 'ml.c4.xlarge', - pytest.param('ml.p3.2xlarge', - marks=pytest.mark.skipif( - test_region() in HOSTING_NO_P3_REGIONS, - reason='no ml.p3 instances in this region'))]) +horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod") + + +@pytest.fixture( + scope="session", + params=[ + "ml.c4.xlarge", + pytest.param( + "ml.p3.2xlarge", + marks=pytest.mark.skipif( + test_region() in HOSTING_NO_P3_REGIONS, reason="no ml.p3 instances in this region" + ), + ), + ], +) def instance_type(request): return request.param @pytest.mark.canary_quick def test_horovod(sagemaker_session, instance_type, tmpdir): - job_name = sagemaker.utils.unique_name_from_base('tf-horovod') - estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 
'test_hvd_basic.py'), - role='SageMakerRole', - train_instance_count=2, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - py_version=integ.PYTHON_VERSION, - script_mode=True, - framework_version='1.12', - distributions={'mpi': {'enabled': True}}) + job_name = sagemaker.utils.unique_name_from_base("tf-horovod") + estimator = TensorFlow( + entry_point=os.path.join(horovod_dir, "test_hvd_basic.py"), + role="SageMakerRole", + train_instance_count=2, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + py_version=integ.PYTHON_VERSION, + script_mode=True, + framework_version="1.12", + distributions={"mpi": {"enabled": True}}, + ) with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(job_name=job_name) @@ -58,43 +66,41 @@ def test_horovod(sagemaker_session, instance_type, tmpdir): extract_files_from_s3(estimator.model_data, tmp) for rank in range(2): - assert read_json('rank-%s' % rank, tmp)['rank'] == rank + assert read_json("rank-%s" % rank, tmp)["rank"] == rank @pytest.mark.local_mode -@pytest.mark.parametrize('instances, processes', [ - [1, 2], - (2, 1), - (2, 2)]) +@pytest.mark.parametrize("instances, processes", [[1, 2], (2, 1), (2, 2)]) def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdir): - output_path = 'file://%s' % tmpdir - job_name = sagemaker.utils.unique_name_from_base('tf-horovod') - estimator = TensorFlow(entry_point=os.path.join(horovod_dir, 'test_hvd_basic.py'), - role='SageMakerRole', - train_instance_count=2, - train_instance_type='local', - sagemaker_session=sagemaker_local_session, - py_version=integ.PYTHON_VERSION, - script_mode=True, - output_path=output_path, - framework_version='1.12', - distributions={'mpi': {'enabled': True, - 'processes_per_host': processes}}) + output_path = "file://%s" % tmpdir + job_name = sagemaker.utils.unique_name_from_base("tf-horovod") + estimator = TensorFlow( + entry_point=os.path.join(horovod_dir, "test_hvd_basic.py"), + role="SageMakerRole", + train_instance_count=2, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + py_version=integ.PYTHON_VERSION, + script_mode=True, + output_path=output_path, + framework_version="1.12", + distributions={"mpi": {"enabled": True, "processes_per_host": processes}}, + ) with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(job_name=job_name) tmp = str(tmpdir) - extract_files(output_path.replace('file://', ''), tmp) + extract_files(output_path.replace("file://", ""), tmp) size = instances * processes for rank in range(size): - assert read_json('rank-%s' % rank, tmp)['rank'] == rank + assert read_json("rank-%s" % rank, tmp)["rank"] == rank def extract_files(output_path, tmpdir): - with tarfile.open(os.path.join(output_path, 'model.tar.gz')) as tar: + with tarfile.open(os.path.join(output_path, "model.tar.gz")) as tar: tar.extractall(tmpdir) @@ -105,10 +111,10 @@ def read_json(file, tmp): def extract_files_from_s3(s3_url, tmpdir): parsed_url = urlparse(s3_url) - s3 = boto3.resource('s3') + s3 = boto3.resource("s3") - model = os.path.join(tmpdir, 'model') - s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip('/'), model) + model = os.path.join(tmpdir, "model") + s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model) - with tarfile.open(model, 'r') as tar_file: + with tarfile.open(model, "r") as tar_file: tar_file.extractall(tmpdir) diff --git a/tests/integ/test_inference_pipeline.py 
b/tests/integ/test_inference_pipeline.py
index bfd2f3bae2..b7bf489ce2 100644
--- a/tests/integ/test_inference_pipeline.py
+++ b/tests/integ/test_inference_pipeline.py
@@ -17,7 +17,10 @@
 import pytest
 
 from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
-from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout_and_delete_model_with_transformer
+from tests.integ.timeout import (
+    timeout_and_delete_endpoint_by_name,
+    timeout_and_delete_model_with_transformer,
+)
 
 from sagemaker.amazon.amazon_estimator import get_image_uri
 from sagemaker.content_types import CONTENT_TYPE_CSV
@@ -27,111 +30,117 @@
 from sagemaker.sparkml.model import SparkMLModel
 from sagemaker.utils import sagemaker_timestamp
 
-SPARKML_DATA_PATH = os.path.join(DATA_DIR, 'sparkml_model')
-XGBOOST_DATA_PATH = os.path.join(DATA_DIR, 'xgboost_model')
-SPARKML_XGBOOST_DATA_DIR = 'sparkml_xgboost_pipeline'
-VALID_DATA_PATH = os.path.join(DATA_DIR, SPARKML_XGBOOST_DATA_DIR, 'valid_input.csv')
-INVALID_DATA_PATH = os.path.join(DATA_DIR, SPARKML_XGBOOST_DATA_DIR, 'invalid_input.csv')
-SCHEMA = json.dumps({
-    "input": [
-        {
-            "name": "Pclass",
-            "type": "float"
-        },
-        {
-            "name": "Embarked",
-            "type": "string"
-        },
-        {
-            "name": "Age",
-            "type": "float"
-        },
-        {
-            "name": "Fare",
-            "type": "float"
-        },
-        {
-            "name": "SibSp",
-            "type": "float"
-        },
-        {
-            "name": "Sex",
-            "type": "string"
-        }
-    ],
-    "output": {
-        "name": "features",
-        "struct": "vector",
-        "type": "double"
+SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
+XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
+SPARKML_XGBOOST_DATA_DIR = "sparkml_xgboost_pipeline"
+VALID_DATA_PATH = os.path.join(DATA_DIR, SPARKML_XGBOOST_DATA_DIR, "valid_input.csv")
+INVALID_DATA_PATH = os.path.join(DATA_DIR, SPARKML_XGBOOST_DATA_DIR, "invalid_input.csv")
+SCHEMA = json.dumps(
+    {
+        "input": [
+            {"name": "Pclass", "type": "float"},
+            {"name": "Embarked", "type": "string"},
+            {"name": "Age", "type": "float"},
+            {"name": "Fare", "type": "float"},
+            {"name": "SibSp", "type": "float"},
+            {"name": "Sex", "type": "string"},
+        ],
+        "output": {"name": "features", "struct": "vector", "type": "double"},
     }
-})
+)
 
 
 @pytest.mark.continuous_testing
 @pytest.mark.regional_testing
 def test_inference_pipeline_batch_transform(sagemaker_session):
     sparkml_model_data = sagemaker_session.upload_data(
-        path=os.path.join(SPARKML_DATA_PATH, 'mleap_model.tar.gz'),
-        key_prefix='integ-test-data/sparkml/model')
+        path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"),
+        key_prefix="integ-test-data/sparkml/model",
+    )
     xgb_model_data = sagemaker_session.upload_data(
-        path=os.path.join(XGBOOST_DATA_PATH, 'xgb_model.tar.gz'),
-        key_prefix='integ-test-data/xgboost/model')
-    batch_job_name = 'test-inference-pipeline-batch-{}'.format(sagemaker_timestamp())
-    sparkml_model = SparkMLModel(model_data=sparkml_model_data,
-                                 env={'SAGEMAKER_SPARKML_SCHEMA': SCHEMA},
-                                 sagemaker_session=sagemaker_session)
-    xgb_image = get_image_uri(sagemaker_session.boto_region_name, 'xgboost')
-    xgb_model = Model(model_data=xgb_model_data, image=xgb_image,
-                      sagemaker_session=sagemaker_session)
-    model = PipelineModel(models=[sparkml_model, xgb_model], role='SageMakerRole',
-                          sagemaker_session=sagemaker_session, name=batch_job_name)
-    transformer = model.transformer(1, 'ml.m4.xlarge')
-    transform_input_key_prefix = 'integ-test-data/sparkml_xgboost/transform'
-    transform_input = transformer.sagemaker_session.upload_data(path=VALID_DATA_PATH, key_prefix=transform_input_key_prefix)
+
path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + batch_job_name = "test-inference-pipeline-batch-{}".format(sagemaker_timestamp()) + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost") + xgb_model = Model( + model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[sparkml_model, xgb_model], + role="SageMakerRole", + sagemaker_session=sagemaker_session, + name=batch_job_name, + ) + transformer = model.transformer(1, "ml.m4.xlarge") + transform_input_key_prefix = "integ-test-data/sparkml_xgboost/transform" + transform_input = transformer.sagemaker_session.upload_data( + path=VALID_DATA_PATH, key_prefix=transform_input_key_prefix + ) - with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, - minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): - transformer.transform(transform_input, content_type=CONTENT_TYPE_CSV, job_name=batch_job_name) + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): + transformer.transform( + transform_input, content_type=CONTENT_TYPE_CSV, job_name=batch_job_name + ) transformer.wait() @pytest.mark.canary_quick @pytest.mark.regional_testing def test_inference_pipeline_model_deploy(sagemaker_session): - sparkml_data_path = os.path.join(DATA_DIR, 'sparkml_model') - xgboost_data_path = os.path.join(DATA_DIR, 'xgboost_model') - endpoint_name = 'test-inference-pipeline-deploy-{}'.format(sagemaker_timestamp()) + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") + xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") + endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp()) sparkml_model_data = sagemaker_session.upload_data( - path=os.path.join(sparkml_data_path, 'mleap_model.tar.gz'), - key_prefix='integ-test-data/sparkml/model') + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), + key_prefix="integ-test-data/sparkml/model", + ) xgb_model_data = sagemaker_session.upload_data( - path=os.path.join(xgboost_data_path, 'xgb_model.tar.gz'), - key_prefix='integ-test-data/xgboost/model') + path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - sparkml_model = SparkMLModel(model_data=sparkml_model_data, - env={'SAGEMAKER_SPARKML_SCHEMA': SCHEMA}, - sagemaker_session=sagemaker_session) - xgb_image = get_image_uri(sagemaker_session.boto_region_name, 'xgboost') - xgb_model = Model(model_data=xgb_model_data, image=xgb_image, - sagemaker_session=sagemaker_session) - model = PipelineModel(models=[sparkml_model, xgb_model], role='SageMakerRole', - sagemaker_session=sagemaker_session, name=endpoint_name) - model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) - predictor = RealTimePredictor(endpoint=endpoint_name, sagemaker_session=sagemaker_session, - serializer=json_serializer, content_type=CONTENT_TYPE_CSV, - accept=CONTENT_TYPE_CSV) + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost") + xgb_model = Model( + model_data=xgb_model_data, 
image=xgb_image, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[sparkml_model, xgb_model], + role="SageMakerRole", + sagemaker_session=sagemaker_session, + name=endpoint_name, + ) + model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) + predictor = RealTimePredictor( + endpoint=endpoint_name, + sagemaker_session=sagemaker_session, + serializer=json_serializer, + content_type=CONTENT_TYPE_CSV, + accept=CONTENT_TYPE_CSV, + ) - with open(VALID_DATA_PATH, 'r') as f: + with open(VALID_DATA_PATH, "r") as f: valid_data = f.read() - assert predictor.predict(valid_data) == '0.714013934135' + assert predictor.predict(valid_data) == "0.714013934135" - with open(INVALID_DATA_PATH, 'r') as f: + with open(INVALID_DATA_PATH, "r") as f: invalid_data = f.read() - assert (predictor.predict(invalid_data) is None) + assert predictor.predict(invalid_data) is None model.delete_model() with pytest.raises(Exception) as exception: sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) - assert 'Could not find model' in str(exception.value) + assert "Could not find model" in str(exception.value) diff --git a/tests/integ/test_ipinsights.py b/tests/integ/test_ipinsights.py index 13ce103641..399f07ce61 100644 --- a/tests/integ/test_ipinsights.py +++ b/tests/integ/test_ipinsights.py @@ -25,35 +25,37 @@ def test_ipinsights(sagemaker_session): - job_name = unique_name_from_base('ipinsights') + job_name = unique_name_from_base("ipinsights") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'ipinsights') - data_filename = 'train.csv' + data_path = os.path.join(DATA_DIR, "ipinsights") + data_filename = "train.csv" - with open(os.path.join(data_path, data_filename), 'rb') as f: + with open(os.path.join(data_path, data_filename), "rb") as f: num_records = len(f.readlines()) ipinsights = IPInsights( - role='SageMakerRole', + role="SageMakerRole", train_instance_count=1, - train_instance_type='ml.c4.xlarge', + train_instance_type="ml.c4.xlarge", num_entity_vectors=10, vector_dim=100, - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) - record_set = prepare_record_set_from_local_files(data_path, ipinsights.data_location, - num_records, FEATURE_DIM, - sagemaker_session) + record_set = prepare_record_set_from_local_files( + data_path, ipinsights.data_location, num_records, FEATURE_DIM, sagemaker_session + ) ipinsights.fit(records=record_set, job_name=job_name) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - model = IPInsightsModel(ipinsights.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + model = IPInsightsModel( + ipinsights.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) assert isinstance(predictor, RealTimePredictor) - predict_input = [['user_1', '1.1.1.1']] + predict_input = [["user_1", "1.1.1.1"]] result = predictor.predict(predict_input) assert len(result) == 1 diff --git a/tests/integ/test_kmeans.py b/tests/integ/test_kmeans.py index 33148d813c..3323e961cd 100644 --- a/tests/integ/test_kmeans.py +++ b/tests/integ/test_kmeans.py @@ -27,24 +27,28 @@ def test_kmeans(sagemaker_session): - job_name = unique_name_from_base('kmeans') + job_name = unique_name_from_base("kmeans") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 
'mnist.pkl.gz')
-        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
         # Load the data into memory as numpy arrays
-        with gzip.open(data_path, 'rb') as f:
+        with gzip.open(data_path, "rb") as f:
             train_set, _, _ = pickle.load(f, **pickle_args)
 
-        kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
-                        train_instance_type='ml.c4.xlarge',
-                        k=10, sagemaker_session=sagemaker_session)
+        kmeans = KMeans(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            k=10,
+            sagemaker_session=sagemaker_session,
+        )
 
-        kmeans.init_method = 'random'
+        kmeans.init_method = "random"
         kmeans.max_iterations = 1
         kmeans.tol = 1
         kmeans.num_trials = 1
-        kmeans.local_init_method = 'kmeans++'
+        kmeans.local_init_method = "kmeans++"
         kmeans.half_life_time_size = 1
         kmeans.epochs = 1
         kmeans.center_factor = 1
@@ -59,15 +63,16 @@ def test_kmeans(sagemaker_session):
             epochs=str(kmeans.epochs),
             extra_center_factor=str(kmeans.center_factor),
             k=str(kmeans.k),
-            force_dense='True',
+            force_dense="True",
         )
 
         kmeans.fit(kmeans.record_set(train_set[0][:100]), job_name=job_name)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        model = KMeansModel(kmeans.model_data, role='SageMakerRole',
-                            sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name)
+        model = KMeansModel(
+            kmeans.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
+        )
+        predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name)
 
         result = predictor.predict(train_set[0][:10])
         assert len(result) == 10
@@ -78,29 +83,33 @@ def test_kmeans(sagemaker_session):
         predictor.delete_model()
         with pytest.raises(Exception) as exception:
             sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
-            assert 'Could not find model' in str(exception.value)
+            assert "Could not find model" in str(exception.value)
 
 
 def test_async_kmeans(sagemaker_session):
-    job_name = unique_name_from_base('kmeans')
+    job_name = unique_name_from_base("kmeans")
 
     with timeout(minutes=5):
-        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
-        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
         # Load the data into memory as numpy arrays
-        with gzip.open(data_path, 'rb') as f:
+        with gzip.open(data_path, "rb") as f:
             train_set, _, _ = pickle.load(f, **pickle_args)
 
-        kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
-                        train_instance_type='ml.c4.xlarge',
-                        k=10, sagemaker_session=sagemaker_session)
+        kmeans = KMeans(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            k=10,
+            sagemaker_session=sagemaker_session,
+        )
 
-        kmeans.init_method = 'random'
+        kmeans.init_method = "random"
         kmeans.max_iterations = 1
         kmeans.tol = 1
         kmeans.num_trials = 1
-        kmeans.local_init_method = 'kmeans++'
+        kmeans.local_init_method = "kmeans++"
         kmeans.half_life_time_size = 1
         kmeans.epochs = 1
         kmeans.center_factor = 1
@@ -115,7 +124,7 @@ def test_async_kmeans(sagemaker_session):
             epochs=str(kmeans.epochs),
             extra_center_factor=str(kmeans.center_factor),
             k=str(kmeans.k),
-            force_dense='True',
+            force_dense="True",
         )
 
         kmeans.fit(kmeans.record_set(train_set[0][:100]), wait=False, job_name=job_name)
@@
-126,9 +135,10 @@ def test_async_kmeans(sagemaker_session): with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): estimator = KMeans.attach(training_job_name=job_name, sagemaker_session=sagemaker_session) - model = KMeansModel(estimator.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + model = KMeansModel( + estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][:10]) assert len(result) == 10 diff --git a/tests/integ/test_knn.py b/tests/integ/test_knn.py index 85fa5e1778..beafa002d6 100644 --- a/tests/integ/test_knn.py +++ b/tests/integ/test_knn.py @@ -25,28 +25,35 @@ def test_knn_regressor(sagemaker_session): - job_name = unique_name_from_base('knn') + job_name = unique_name_from_base("knn") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - knn = KNN(role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - k=10, predictor_type='regressor', sample_size=500, - sagemaker_session=sagemaker_session) + knn = KNN( + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + k=10, + predictor_type="regressor", + sample_size=500, + sagemaker_session=sagemaker_session, + ) # training labels must be 'float32' - knn.fit(knn.record_set(train_set[0][:200], train_set[1][:200].astype('float32')), - job_name=job_name) + knn.fit( + knn.record_set(train_set[0][:200], train_set[1][:200].astype("float32")), + job_name=job_name, + ) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - model = KNNModel(knn.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + model = KNNModel(knn.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][:10]) assert len(result) == 10 @@ -55,36 +62,45 @@ def test_knn_regressor(sagemaker_session): def test_async_knn_classifier(sagemaker_session): - job_name = unique_name_from_base('knn') + job_name = unique_name_from_base("knn") with timeout(minutes=5): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - knn = KNN(role='SageMakerRole', - train_instance_count=1, train_instance_type='ml.c4.xlarge', - k=10, predictor_type='classifier', sample_size=500, - index_type='faiss.IVFFlat', index_metric='L2', - sagemaker_session=sagemaker_session) + knn = KNN( + role="SageMakerRole", + 
train_instance_count=1, + train_instance_type="ml.c4.xlarge", + k=10, + predictor_type="classifier", + sample_size=500, + index_type="faiss.IVFFlat", + index_metric="L2", + sagemaker_session=sagemaker_session, + ) # training labels must be 'float32' - knn.fit(knn.record_set(train_set[0][:200], train_set[1][:200].astype('float32')), - wait=False, job_name=job_name) + knn.fit( + knn.record_set(train_set[0][:200], train_set[1][:200].astype("float32")), + wait=False, + job_name=job_name, + ) print("Detached from training job. Will re-attach in 20 seconds") time.sleep(20) print("attaching now...") with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - estimator = KNN.attach(training_job_name=job_name, - sagemaker_session=sagemaker_session) - model = KNNModel(estimator.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + estimator = KNN.attach(training_job_name=job_name, sagemaker_session=sagemaker_session) + model = KNNModel( + estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][:10]) assert len(result) == 10 diff --git a/tests/integ/test_lda.py b/tests/integ/test_lda.py index cc91cd7d8b..bfc2da792f 100644 --- a/tests/integ/test_lda.py +++ b/tests/integ/test_lda.py @@ -25,29 +25,33 @@ def test_lda(sagemaker_session): - job_name = unique_name_from_base('lda') + job_name = unique_name_from_base("lda") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'lda') - data_filename = 'nips-train_1.pbr' + data_path = os.path.join(DATA_DIR, "lda") + data_filename = "nips-train_1.pbr" - with open(os.path.join(data_path, data_filename), 'rb') as f: + with open(os.path.join(data_path, data_filename), "rb") as f: all_records = read_records(f) # all records must be same - feature_num = int(all_records[0].features['values'].float32_tensor.shape[0]) - - lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10, - sagemaker_session=sagemaker_session) - - record_set = prepare_record_set_from_local_files(data_path, lda.data_location, - len(all_records), feature_num, - sagemaker_session) + feature_num = int(all_records[0].features["values"].float32_tensor.shape[0]) + + lda = LDA( + role="SageMakerRole", + train_instance_type="ml.c4.xlarge", + num_topics=10, + sagemaker_session=sagemaker_session, + ) + + record_set = prepare_record_set_from_local_files( + data_path, lda.data_location, len(all_records), feature_num, sagemaker_session + ) lda.fit(records=record_set, mini_batch_size=100, job_name=job_name) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - model = LDAModel(lda.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + model = LDAModel(lda.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) predict_input = np.random.rand(1, feature_num) result = predictor.predict(predict_input) diff --git a/tests/integ/test_linear_learner.py b/tests/integ/test_linear_learner.py index c22bbfc4a8..e754819e2c 100644 --- a/tests/integ/test_linear_learner.py +++ b/tests/integ/test_linear_learner.py @@ -29,23 +29,28 @@ @pytest.mark.canary_quick def test_linear_learner(sagemaker_session): - job_name = 
unique_name_from_base('linear-learner') + job_name = unique_name_from_base("linear-learner") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) train_set[1][:100] = 1 train_set[1][100:200] = 0 - train_set = train_set[0], train_set[1].astype(np.dtype('float32')) - - ll = LinearLearner('SageMakerRole', 1, 'ml.c4.2xlarge', - predictor_type='binary_classifier', sagemaker_session=sagemaker_session) - ll.binary_classifier_model_selection_criteria = 'accuracy' + train_set = train_set[0], train_set[1].astype(np.dtype("float32")) + + ll = LinearLearner( + "SageMakerRole", + 1, + "ml.c4.2xlarge", + predictor_type="binary_classifier", + sagemaker_session=sagemaker_session, + ) + ll.binary_classifier_model_selection_criteria = "accuracy" ll.target_recall = 0.5 ll.target_precision = 0.5 ll.positive_example_weight_mult = 0.1 @@ -53,12 +58,12 @@ def test_linear_learner(sagemaker_session): ll.use_bias = True ll.num_models = 1 ll.num_calibration_samples = 1 - ll.init_method = 'uniform' + ll.init_method = "uniform" ll.init_scale = 0.5 ll.init_sigma = 0.2 ll.init_bias = 5 - ll.optimizer = 'adam' - ll.loss = 'logistic' + ll.optimizer = "adam" + ll.loss = "logistic" ll.wd = 0.5 ll.l1 = 0.5 ll.momentum = 0.5 @@ -83,7 +88,7 @@ def test_linear_learner(sagemaker_session): ll.fit(ll.record_set(train_set[0][:200], train_set[1][:200]), job_name=job_name) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - predictor = ll.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + predictor = ll.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][0:100]) assert len(result) == 100 @@ -93,27 +98,32 @@ def test_linear_learner(sagemaker_session): def test_linear_learner_multiclass(sagemaker_session): - job_name = unique_name_from_base('linear-learner') + job_name = unique_name_from_base("linear-learner") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) - train_set = train_set[0], train_set[1].astype(np.dtype('float32')) + train_set = train_set[0], train_set[1].astype(np.dtype("float32")) - ll = LinearLearner('SageMakerRole', 1, 'ml.c4.2xlarge', - predictor_type='multiclass_classifier', num_classes=10, - sagemaker_session=sagemaker_session) + ll = LinearLearner( + "SageMakerRole", + 1, + "ml.c4.2xlarge", + predictor_type="multiclass_classifier", + num_classes=10, + sagemaker_session=sagemaker_session, + ) ll.epochs = 1 ll.fit(ll.record_set(train_set[0][:200], train_set[1][:200]), job_name=job_name) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - predictor = ll.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + predictor = 
ll.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][0:100]) assert len(result) == 100 @@ -123,23 +133,28 @@ def test_linear_learner_multiclass(sagemaker_session): def test_async_linear_learner(sagemaker_session): - job_name = unique_name_from_base('linear-learner') + job_name = unique_name_from_base("linear-learner") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} # Load the data into memory as numpy arrays - with gzip.open(data_path, 'rb') as f: + with gzip.open(data_path, "rb") as f: train_set, _, _ = pickle.load(f, **pickle_args) train_set[1][:100] = 1 train_set[1][100:200] = 0 - train_set = train_set[0], train_set[1].astype(np.dtype('float32')) - - ll = LinearLearner('SageMakerRole', 1, 'ml.c4.2xlarge', - predictor_type='binary_classifier', sagemaker_session=sagemaker_session) - ll.binary_classifier_model_selection_criteria = 'accuracy' + train_set = train_set[0], train_set[1].astype(np.dtype("float32")) + + ll = LinearLearner( + "SageMakerRole", + 1, + "ml.c4.2xlarge", + predictor_type="binary_classifier", + sagemaker_session=sagemaker_session, + ) + ll.binary_classifier_model_selection_criteria = "accuracy" ll.target_recall = 0.5 ll.target_precision = 0.5 ll.positive_example_weight_mult = 0.1 @@ -147,12 +162,12 @@ def test_async_linear_learner(sagemaker_session): ll.use_bias = True ll.num_models = 1 ll.num_calibration_samples = 1 - ll.init_method = 'uniform' + ll.init_method = "uniform" ll.init_scale = 0.5 ll.init_sigma = 0.2 ll.init_bias = 5 - ll.optimizer = 'adam' - ll.loss = 'logistic' + ll.optimizer = "adam" + ll.loss = "logistic" ll.wd = 0.5 ll.l1 = 0.5 ll.momentum = 0.5 @@ -180,11 +195,13 @@ def test_async_linear_learner(sagemaker_session): time.sleep(20) with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session): - estimator = LinearLearner.attach(training_job_name=job_name, - sagemaker_session=sagemaker_session) - model = LinearLearnerModel(estimator.model_data, role='SageMakerRole', - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name) + estimator = LinearLearner.attach( + training_job_name=job_name, sagemaker_session=sagemaker_session + ) + model = LinearLearnerModel( + estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session + ) + predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name) result = predictor.predict(train_set[0][0:100]) assert len(result) == 100 diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index e6b3ef2da8..f8f31c7516 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -29,9 +29,9 @@ from sagemaker.tensorflow import TensorFlow # endpoint tests all use the same port, so we use this lock to prevent concurrent execution -LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_local_mode_lock') -DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data') -DEFAULT_REGION = 'us-west-2' +LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_local_mode_lock") +DATA_PATH = os.path.join(DATA_DIR, "iris", "data") +DEFAULT_REGION = "us-west-2" class LocalNoS3Session(LocalSession): @@ -45,13 +45,7 @@ def __init__(self): def _initialize(self, boto_session, 
sagemaker_client, sagemaker_runtime_client): self.boto_session = boto3.Session(region_name=DEFAULT_REGION) if self.config is None: - self.config = { - 'local': - { - 'local_code': True, - 'region_name': DEFAULT_REGION - } - } + self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} self._region_name = DEFAULT_REGION self.sagemaker_client = LocalSagemakerClient(self) @@ -59,23 +53,30 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client): self.local_mode = True -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mxnet_model(sagemaker_local_session, mxnet_full_version): def _create_model(output_path): - script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py') - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - - mx = MXNet(entry_point=script_path, role='SageMakerRole', - train_instance_count=1, train_instance_type='local', - output_path=output_path, framework_version=mxnet_full_version, - sagemaker_session=sagemaker_local_session) - - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') - - mx.fit({'train': train_input, 'test': test_input}) + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="local", + output_path=output_path, + framework_version=mxnet_full_version, + sagemaker_session=sagemaker_local_session, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + + mx.fit({"train": train_input, "test": test_input}) model = mx.create_model(1) return model @@ -83,39 +84,42 @@ def _create_model(output_path): @pytest.mark.local_mode -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_local_mode(tf_full_version, sagemaker_local_session): with timeout(minutes=5): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - framework_version=tf_full_version, - training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='local', - base_job_name='test-tf', - sagemaker_session=sagemaker_local_session) - - inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, - key_prefix='integ-test-data/tf_iris') + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + framework_version=tf_full_version, + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="local", + base_job_name="test-tf", + sagemaker_session=sagemaker_local_session, + ) + + inputs = estimator.sagemaker_session.upload_data( + path=DATA_PATH, key_prefix="integ-test-data/tf_iris" + ) estimator.fit(inputs) - print('job succeeded: 
{}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) endpoint_name = estimator.latest_training_job.name with lock.lock(LOCK_PATH): try: - json_predictor = estimator.deploy(initial_instance_count=1, - instance_type='local', - endpoint_name=endpoint_name) + json_predictor = estimator.deploy( + initial_instance_count=1, instance_type="local", endpoint_name=endpoint_name + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result finally: @@ -123,38 +127,40 @@ def test_tf_local_mode(tf_full_version, sagemaker_local_session): @pytest.mark.local_mode -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_distributed_local_mode(sagemaker_local_session): with timeout(minutes=5): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=3, - train_instance_type='local', - base_job_name='test-tf', - sagemaker_session=sagemaker_local_session) - - inputs = 'file://' + DATA_PATH + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=3, + train_instance_type="local", + base_job_name="test-tf", + sagemaker_session=sagemaker_local_session, + ) + + inputs = "file://" + DATA_PATH estimator.fit(inputs) - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) endpoint_name = estimator.latest_training_job.name with lock.lock(LOCK_PATH): try: - json_predictor = estimator.deploy(initial_instance_count=1, - instance_type='local', - endpoint_name=endpoint_name) + json_predictor = estimator.deploy( + initial_instance_count=1, instance_type="local", endpoint_name=endpoint_name + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result finally: @@ -162,37 +168,39 @@ def test_tf_distributed_local_mode(sagemaker_local_session): @pytest.mark.local_mode -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_local_data(sagemaker_local_session): with timeout(minutes=5): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - 
training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='local', - base_job_name='test-tf', - sagemaker_session=sagemaker_local_session) - - inputs = 'file://' + DATA_PATH + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="local", + base_job_name="test-tf", + sagemaker_session=sagemaker_local_session, + ) + + inputs = "file://" + DATA_PATH estimator.fit(inputs) - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) endpoint_name = estimator.latest_training_job.name with lock.lock(LOCK_PATH): try: - json_predictor = estimator.deploy(initial_instance_count=1, - instance_type='local', - endpoint_name=endpoint_name) + json_predictor = estimator.deploy( + initial_instance_count=1, instance_type="local", endpoint_name=endpoint_name + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result finally: @@ -200,38 +208,40 @@ def test_tf_local_data(sagemaker_local_session): @pytest.mark.local_mode -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_local_data_local_script(): with timeout(minutes=5): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='local', - base_job_name='test-tf', - sagemaker_session=LocalNoS3Session()) - - inputs = 'file://' + DATA_PATH + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="local", + base_job_name="test-tf", + sagemaker_session=LocalNoS3Session(), + ) + + inputs = "file://" + DATA_PATH estimator.fit(inputs) - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) endpoint_name = estimator.latest_training_job.name with lock.lock(LOCK_PATH): try: - json_predictor = estimator.deploy(initial_instance_count=1, - instance_type='local', - endpoint_name=endpoint_name) + json_predictor = estimator.deploy( + initial_instance_count=1, instance_type="local", endpoint_name=endpoint_name + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - 
print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result finally: @@ -240,14 +250,14 @@ def test_tf_local_data_local_script(): @pytest.mark.local_mode def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model, mxnet_full_version): - path = 's3://%s' % sagemaker_local_session.default_bucket() + path = "s3://%s" % sagemaker_local_session.default_bucket() s3_model = mxnet_model(path) s3_model.sagemaker_session = sagemaker_local_session predictor = None with lock.lock(LOCK_PATH): try: - predictor = s3_model.deploy(initial_instance_count=1, instance_type='local') + predictor = s3_model.deploy(initial_instance_count=1, instance_type="local") data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) finally: @@ -261,10 +271,10 @@ def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mx with lock.lock(LOCK_PATH): try: - path = 'file://%s' % (str(tmpdir)) + path = "file://%s" % (str(tmpdir)) model = mxnet_model(path) model.sagemaker_session = sagemaker_local_session - predictor = model.deploy(initial_instance_count=1, instance_type='local') + predictor = model.deploy(initial_instance_count=1, instance_type="local") data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) finally: @@ -274,24 +284,32 @@ def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mx @pytest.mark.local_mode def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version): - script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py') - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - - mx = MXNet(entry_point=script_path, role='SageMakerRole', py_version=PYTHON_VERSION, - train_instance_count=1, train_instance_type='local', - sagemaker_session=sagemaker_local_session, framework_version=mxnet_full_version) - - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') - - mx.fit({'train': train_input, 'test': test_input}) + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + framework_version=mxnet_full_version, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + + mx.fit({"train": train_input, "test": test_input}) endpoint_name = mx.latest_training_job.name with lock.lock(LOCK_PATH): try: - predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name) + predictor = mx.deploy(1, "local", endpoint_name=endpoint_name) data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) finally: @@ -300,23 +318,27 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version): @pytest.mark.local_mode def test_mxnet_local_data_local_script(mxnet_full_version): - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - script_path = os.path.join(data_path, 'mnist.py') - - mx = MXNet(entry_point=script_path, role='SageMakerRole', - train_instance_count=1, 
train_instance_type='local', - framework_version=mxnet_full_version, - sagemaker_session=LocalNoS3Session()) - - train_input = 'file://' + os.path.join(data_path, 'train') - test_input = 'file://' + os.path.join(data_path, 'test') - - mx.fit({'train': train_input, 'test': test_input}) + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="local", + framework_version=mxnet_full_version, + sagemaker_session=LocalNoS3Session(), + ) + + train_input = "file://" + os.path.join(data_path, "train") + test_input = "file://" + os.path.join(data_path, "test") + + mx.fit({"train": train_input, "test": test_input}) endpoint_name = mx.latest_training_job.name with lock.lock(LOCK_PATH): try: - predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name) + predictor = mx.deploy(1, "local", endpoint_name=endpoint_name) data = numpy.zeros(shape=(1, 1, 28, 28)) predictor.predict(data) finally: @@ -325,52 +347,68 @@ def test_mxnet_local_data_local_script(mxnet_full_version): @pytest.mark.local_mode def test_mxnet_training_failure(sagemaker_local_session, mxnet_full_version, tmpdir): - script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py') - - mx = MXNet(entry_point=script_path, - role='SageMakerRole', - framework_version=mxnet_full_version, - py_version=PYTHON_VERSION, - train_instance_count=1, - train_instance_type='local', - sagemaker_session=sagemaker_local_session, - output_path='file://{}'.format(tmpdir)) + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + framework_version=mxnet_full_version, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + output_path="file://{}".format(tmpdir), + ) with pytest.raises(RuntimeError): mx.fit() - with tarfile.open(os.path.join(str(tmpdir), 'output.tar.gz')) as tar: - tar.getmember('failure') + with tarfile.open(os.path.join(str(tmpdir), "output.tar.gz")) as tar: + tar.getmember("failure") @pytest.mark.local_mode def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_version): - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - script_path = os.path.join(data_path, 'mnist.py') - - mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', framework_version=mxnet_full_version, - sagemaker_session=sagemaker_local_session) - - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + framework_version=mxnet_full_version, + sagemaker_session=sagemaker_local_session, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) with timeout(minutes=15): - mx.fit({'train': 
train_input, 'test': test_input}) - - transform_input_path = os.path.join(data_path, 'transform') - transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' - transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, - key_prefix=transform_input_key_prefix) - - output_path = 'file://%s' % (str(tmpdir)) - transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1, - strategy='SingleRecord', output_path=output_path) + mx.fit({"train": train_input, "test": test_input}) + + transform_input_path = os.path.join(data_path, "transform") + transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) + + output_path = "file://%s" % (str(tmpdir)) + transformer = mx.transformer( + 1, + "local", + assemble_with="Line", + max_payload=1, + strategy="SingleRecord", + output_path=output_path, + ) with lock.lock(LOCK_PATH): - transformer.transform(transform_input, content_type='text/csv', split_type='Line') + transformer.transform(transform_input, content_type="text/csv", split_type="Line") transformer.wait() - assert os.path.exists(os.path.join(str(tmpdir), 'data.csv.out')) + assert os.path.exists(os.path.join(str(tmpdir), "data.csv.out")) diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py index 3f43ff2596..2d7ad00871 100644 --- a/tests/integ/test_marketplace.py +++ b/tests/integ/test_marketplace.py @@ -37,37 +37,43 @@ # # Both are written by Amazon and are free to subscribe. -ALGORITHM_ARN = 'arn:aws:sagemaker:%s:%s:algorithm/scikit-decision-trees-' \ - '15423055-57b73412d2e93e9239e4e16f83298b8f' +ALGORITHM_ARN = ( + "arn:aws:sagemaker:%s:%s:algorithm/scikit-decision-trees-" + "15423055-57b73412d2e93e9239e4e16f83298b8f" +) -MODEL_PACKAGE_ARN = 'arn:aws:sagemaker:%s:%s:model-package/scikit-iris-detector-' \ - '154230595-8f00905c1f927a512b73ea29dd09ae30' +MODEL_PACKAGE_ARN = ( + "arn:aws:sagemaker:%s:%s:model-package/scikit-iris-detector-" + "154230595-8f00905c1f927a512b73ea29dd09ae30" +) @pytest.mark.canary_quick def test_marketplace_estimator(sagemaker_session): with timeout(minutes=15): - data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + data_path = os.path.join(DATA_DIR, "marketplace", "training") region = sagemaker_session.boto_region_name account = REGION_ACCOUNT_MAP[region] algorithm_arn = ALGORITHM_ARN % (region, account) algo = AlgorithmEstimator( algorithm_arn=algorithm_arn, - role='SageMakerRole', + role="SageMakerRole", train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) train_input = algo.sagemaker_session.upload_data( - path=data_path, key_prefix='integ-test-data/marketplace/train') + path=data_path, key_prefix="integ-test-data/marketplace/train" + ) - algo.fit({'training': train_input}) + algo.fit({"training": train_input}) - endpoint_name = 'test-marketplace-estimator{}'.format(sagemaker_timestamp()) + endpoint_name = "test-marketplace-estimator{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): - predictor = algo.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) - shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + predictor = algo.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) + shape = pandas.read_csv(os.path.join(data_path, 
"iris.csv"), header=None) a = [50 * i for i in range(3)] b = [40 + i for i in range(10)] @@ -76,41 +82,48 @@ def test_marketplace_estimator(sagemaker_session): test_data = shape.iloc[indices[:-1]] test_x = test_data.iloc[:, 1:] - print(predictor.predict(test_x.values).decode('utf-8')) + print(predictor.predict(test_x.values).decode("utf-8")) def test_marketplace_attach(sagemaker_session): with timeout(minutes=15): - data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + data_path = os.path.join(DATA_DIR, "marketplace", "training") region = sagemaker_session.boto_region_name account = REGION_ACCOUNT_MAP[region] algorithm_arn = ALGORITHM_ARN % (region, account) mktplace = AlgorithmEstimator( algorithm_arn=algorithm_arn, - role='SageMakerRole', + role="SageMakerRole", train_instance_count=1, - train_instance_type='ml.c4.xlarge', + train_instance_type="ml.c4.xlarge", sagemaker_session=sagemaker_session, - base_job_name='test-marketplace') + base_job_name="test-marketplace", + ) train_input = mktplace.sagemaker_session.upload_data( - path=data_path, key_prefix='integ-test-data/marketplace/train') + path=data_path, key_prefix="integ-test-data/marketplace/train" + ) - mktplace.fit({'training': train_input}, wait=False) + mktplace.fit({"training": train_input}, wait=False) training_job_name = mktplace.latest_training_job.name - print('Waiting to re-attach to the training job: %s' % training_job_name) + print("Waiting to re-attach to the training job: %s" % training_job_name) time.sleep(20) - endpoint_name = 'test-marketplace-estimator{}'.format(sagemaker_timestamp()) + endpoint_name = "test-marketplace-estimator{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): - print('Re-attaching now to: %s' % training_job_name) - estimator = AlgorithmEstimator.attach(training_job_name=training_job_name, - sagemaker_session=sagemaker_session) - predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, - serializer=sagemaker.predictor.csv_serializer) - shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + print("Re-attaching now to: %s" % training_job_name) + estimator = AlgorithmEstimator.attach( + training_job_name=training_job_name, sagemaker_session=sagemaker_session + ) + predictor = estimator.deploy( + 1, + "ml.m4.xlarge", + endpoint_name=endpoint_name, + serializer=sagemaker.predictor.csv_serializer, + ) + shape = pandas.read_csv(os.path.join(data_path, "iris.csv"), header=None) a = [50 * i for i in range(3)] b = [40 + i for i in range(10)] indices = [i + j for i, j in itertools.product(a, b)] @@ -118,7 +131,7 @@ def test_marketplace_attach(sagemaker_session): test_data = shape.iloc[indices[:-1]] test_x = test_data.iloc[:, 1:] - print(predictor.predict(test_x.values).decode('utf-8')) + print(predictor.predict(test_x.values).decode("utf-8")) @pytest.mark.canary_quick @@ -132,16 +145,18 @@ def predict_wrapper(endpoint, session): endpoint, session, serializer=sagemaker.predictor.csv_serializer ) - model = ModelPackage(role='SageMakerRole', - model_package_arn=model_package_arn, - sagemaker_session=sagemaker_session, - predictor_cls=predict_wrapper) + model = ModelPackage( + role="SageMakerRole", + model_package_arn=model_package_arn, + sagemaker_session=sagemaker_session, + predictor_cls=predict_wrapper, + ) - endpoint_name = 'test-marketplace-model-endpoint{}'.format(sagemaker_timestamp()) + endpoint_name = "test-marketplace-model-endpoint{}".format(sagemaker_timestamp()) with 
timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): - predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) - data_path = os.path.join(DATA_DIR, 'marketplace', 'training') - shape = pandas.read_csv(os.path.join(data_path, 'iris.csv'), header=None) + predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) + data_path = os.path.join(DATA_DIR, "marketplace", "training") + shape = pandas.read_csv(os.path.join(data_path, "iris.csv"), header=None) a = [50 * i for i in range(3)] b = [40 + i for i in range(10)] indices = [i + j for i, j in itertools.product(a, b)] @@ -149,90 +164,100 @@ def predict_wrapper(endpoint, session): test_data = shape.iloc[indices[:-1]] test_x = test_data.iloc[:, 1:] - print(predictor.predict(test_x.values).decode('utf-8')) + print(predictor.predict(test_x.values).decode("utf-8")) def test_marketplace_tuning_job(sagemaker_session): - data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + data_path = os.path.join(DATA_DIR, "marketplace", "training") region = sagemaker_session.boto_region_name account = REGION_ACCOUNT_MAP[region] algorithm_arn = ALGORITHM_ARN % (region, account) mktplace = AlgorithmEstimator( algorithm_arn=algorithm_arn, - role='SageMakerRole', + role="SageMakerRole", train_instance_count=1, - train_instance_type='ml.c4.xlarge', + train_instance_type="ml.c4.xlarge", sagemaker_session=sagemaker_session, - base_job_name='test-marketplace') + base_job_name="test-marketplace", + ) train_input = mktplace.sagemaker_session.upload_data( - path=data_path, key_prefix='integ-test-data/marketplace/train') + path=data_path, key_prefix="integ-test-data/marketplace/train" + ) mktplace.set_hyperparameters(max_leaf_nodes=10) - hyperparameter_ranges = {'max_leaf_nodes': IntegerParameter(1, 100000)} + hyperparameter_ranges = {"max_leaf_nodes": IntegerParameter(1, 100000)} - tuner = HyperparameterTuner(estimator=mktplace, base_tuning_job_name='byo', - objective_metric_name='validation:accuracy', - hyperparameter_ranges=hyperparameter_ranges, - max_jobs=2, max_parallel_jobs=2) + tuner = HyperparameterTuner( + estimator=mktplace, + base_tuning_job_name="byo", + objective_metric_name="validation:accuracy", + hyperparameter_ranges=hyperparameter_ranges, + max_jobs=2, + max_parallel_jobs=2, + ) - tuner.fit({'training': train_input}, include_cls_metadata=False) + tuner.fit({"training": train_input}, include_cls_metadata=False) time.sleep(15) tuner.wait() def test_marketplace_transform_job(sagemaker_session): - data_path = os.path.join(DATA_DIR, 'marketplace', 'training') + data_path = os.path.join(DATA_DIR, "marketplace", "training") region = sagemaker_session.boto_region_name account = REGION_ACCOUNT_MAP[region] algorithm_arn = ALGORITHM_ARN % (region, account) algo = AlgorithmEstimator( algorithm_arn=algorithm_arn, - role='SageMakerRole', + role="SageMakerRole", train_instance_count=1, - train_instance_type='ml.c4.xlarge', + train_instance_type="ml.c4.xlarge", sagemaker_session=sagemaker_session, - base_job_name='test-marketplace') + base_job_name="test-marketplace", + ) train_input = algo.sagemaker_session.upload_data( - path=data_path, key_prefix='integ-test-data/marketplace/train') + path=data_path, key_prefix="integ-test-data/marketplace/train" + ) - shape = pandas.read_csv(data_path + '/iris.csv', header=None).drop([0], axis=1) + shape = pandas.read_csv(data_path + "/iris.csv", header=None).drop([0], axis=1) - transform_workdir = DATA_DIR + '/marketplace/transform' - shape.to_csv(transform_workdir + 
-                 '/batchtransform_test.csv', index=False, header=False)
+    transform_workdir = DATA_DIR + "/marketplace/transform"
+    shape.to_csv(transform_workdir + "/batchtransform_test.csv", index=False, header=False)
 
     transform_input = algo.sagemaker_session.upload_data(
-        transform_workdir,
-        key_prefix='integ-test-data/marketplace/transform')
+        transform_workdir, key_prefix="integ-test-data/marketplace/transform"
+    )
 
-    algo.fit({'training': train_input})
+    algo.fit({"training": train_input})
 
-    transformer = algo.transformer(1, 'ml.m4.xlarge')
-    transformer.transform(transform_input, content_type='text/csv')
+    transformer = algo.transformer(1, "ml.m4.xlarge")
+    transformer.transform(transform_input, content_type="text/csv")
     transformer.wait()
 
 
 def test_marketplace_transform_job_from_model_package(sagemaker_session):
-    data_path = os.path.join(DATA_DIR, 'marketplace', 'training')
-    shape = pandas.read_csv(data_path + '/iris.csv', header=None).drop([0], axis=1)
+    data_path = os.path.join(DATA_DIR, "marketplace", "training")
+    shape = pandas.read_csv(data_path + "/iris.csv", header=None).drop([0], axis=1)
 
-    TRANSFORM_WORKDIR = DATA_DIR + '/marketplace/transform'
-    shape.to_csv(TRANSFORM_WORKDIR + '/batchtransform_test.csv', index=False, header=False)
+    TRANSFORM_WORKDIR = DATA_DIR + "/marketplace/transform"
+    shape.to_csv(TRANSFORM_WORKDIR + "/batchtransform_test.csv", index=False, header=False)
 
     transform_input = sagemaker_session.upload_data(
-        TRANSFORM_WORKDIR,
-        key_prefix='integ-test-data/marketplace/transform')
+        TRANSFORM_WORKDIR, key_prefix="integ-test-data/marketplace/transform"
+    )
 
     region = sagemaker_session.boto_region_name
     account = REGION_ACCOUNT_MAP[region]
     model_package_arn = MODEL_PACKAGE_ARN % (region, account)
 
-    model = ModelPackage(role='SageMakerRole',
-                         model_package_arn=model_package_arn,
-                         sagemaker_session=sagemaker_session)
+    model = ModelPackage(
+        role="SageMakerRole",
+        model_package_arn=model_package_arn,
+        sagemaker_session=sagemaker_session,
+    )
 
-    transformer = model.transformer(1, 'ml.m4.xlarge')
-    transformer.transform(transform_input, content_type='text/csv')
+    transformer = model.transformer(1, "ml.m4.xlarge")
+    transformer.transform(transform_input, content_type="text/csv")
     transformer.wait()
diff --git a/tests/integ/test_mxnet_train.py b/tests/integ/test_mxnet_train.py
index b801b9252f..e4f1a96983 100644
--- a/tests/integ/test_mxnet_train.py
+++ b/tests/integ/test_mxnet_train.py
@@ -27,49 +27,64 @@
 from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def mxnet_training_job(sagemaker_session, mxnet_full_version):
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-
-        mx = MXNet(entry_point=script_path, role='SageMakerRole', framework_version=mxnet_full_version,
-                   py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type='ml.c4.xlarge',
-                   sagemaker_session=sagemaker_session)
-
-        train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                       key_prefix='integ-test-data/mxnet_mnist/train')
-        test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                      key_prefix='integ-test-data/mxnet_mnist/test')
-
-        mx.fit({'train': train_input, 'test': test_input})
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_full_version,
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+        )
+
+        train_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+        )
+        test_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+        )
+
+        mx.fit({"train": train_input, "test": test_input})
         return mx.latest_training_job.name
 
 
 @pytest.mark.canary_quick
 @pytest.mark.regional_testing
 def test_attach_deploy(mxnet_training_job, sagemaker_session):
-    endpoint_name = 'test-mxnet-attach-deploy-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
-        predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
+        predictor = estimator.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
         data = numpy.zeros(shape=(1, 1, 28, 28))
         result = predictor.predict(data)
         assert result is not None
 
 
 def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version):
-    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
-        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
-                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
-                           framework_version=mxnet_full_version)
-        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
+        desc = sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        model = MXNetModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=script_path,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+            framework_version=mxnet_full_version,
+        )
+        predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
 
         data = numpy.zeros(shape=(1, 1, 28, 28))
         result = predictor.predict(data)
@@ -78,103 +93,153 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
     predictor.delete_model()
     with pytest.raises(Exception) as exception:
         sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
-        assert 'Could not find model' in str(exception.value)
+        assert "Could not find model" in str(exception.value)
 
 
 def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, mxnet_full_version):
-    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
-        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
-                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
-                           framework_version=mxnet_full_version)
-
-        tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
+        desc = sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        model = MXNetModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=script_path,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+            framework_version=mxnet_full_version,
+        )
+
+        tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
         kms_key_arn = get_or_create_kms_key(sagemaker_session)
 
-        model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags, kms_key=kms_key_arn)
+        model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name, tags=tags, kms_key=kms_key_arn)
 
         returned_model = sagemaker_session.describe_model(EndpointName=model.name)
-        returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags']
+        returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model["ModelArn"])[
+            "Tags"
+        ]
 
         endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
-        endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']
+        endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint["EndpointArn"])["Tags"]
 
-        endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName'])
-        endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']
+        endpoint_config = sagemaker_session.describe_endpoint_config(
+            EndpointConfigName=endpoint["EndpointConfigName"]
+        )
+        endpoint_config_tags = sagemaker_session.list_tags(
+            ResourceArn=endpoint_config["EndpointConfigArn"]
+        )["Tags"]
 
-        production_variants = endpoint_config['ProductionVariants']
+        production_variants = endpoint_config["ProductionVariants"]
 
         assert returned_model_tags == tags
         assert endpoint_config_tags == tags
         assert endpoint_tags == tags
-        assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
-        assert production_variants[0]['InitialInstanceCount'] == 1
-        assert endpoint_config['KmsKeyId'] == kms_key_arn
+        assert production_variants[0]["InstanceType"] == "ml.m4.xlarge"
+        assert production_variants[0]["InitialInstanceCount"] == 1
+        assert endpoint_config["KmsKeyId"] == kms_key_arn
 
 
-def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
-    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
+def test_deploy_model_with_update_endpoint(
+    mxnet_training_job, sagemaker_session, mxnet_full_version
+):
+    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
-        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
-                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
-                           framework_version=mxnet_full_version)
-        model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
+        desc = sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        model = MXNetModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=script_path,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+            framework_version=mxnet_full_version,
+        )
+        model.deploy(1, "ml.t2.medium", endpoint_name=endpoint_name)
         old_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
-        old_config_name = old_endpoint['EndpointConfigName']
+        old_config_name = old_endpoint["EndpointConfigName"]
 
-        model.deploy(1, 'ml.m4.xlarge', update_endpoint=True, endpoint_name=endpoint_name)
-        new_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)['ProductionVariants']
-        new_production_variants = new_endpoint['ProductionVariants']
-        new_config_name = new_endpoint['EndpointConfigName']
+        model.deploy(1, "ml.m4.xlarge", update_endpoint=True, endpoint_name=endpoint_name)
+        new_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)[
+            "ProductionVariants"
+        ]
+        new_production_variants = new_endpoint["ProductionVariants"]
+        new_config_name = new_endpoint["EndpointConfigName"]
 
         assert old_config_name != new_config_name
-        assert new_production_variants['InstanceType'] == 'ml.m4.xlarge'
-        assert new_production_variants['InitialInstanceCount'] == 1
-        assert new_production_variants['AcceleratorType'] is None
+        assert new_production_variants["InstanceType"] == "ml.m4.xlarge"
+        assert new_production_variants["InitialInstanceCount"] == 1
+        assert new_production_variants["AcceleratorType"] is None
 
 
-def test_deploy_model_with_update_non_existing_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
-    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
-    expected_error_message = 'Endpoint with name "{}" does not exist; ' \
-                             'please use an existing endpoint name'.format(endpoint_name)
+def test_deploy_model_with_update_non_existing_endpoint(
+    mxnet_training_job, sagemaker_session, mxnet_full_version
+):
+    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
+    expected_error_message = (
+        'Endpoint with name "{}" does not exist; '
+        "please use an existing endpoint name".format(endpoint_name)
+    )
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
-        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
-                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
-                           framework_version=mxnet_full_version)
-        model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
+        desc = sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        model = MXNetModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=script_path,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+            framework_version=mxnet_full_version,
+        )
+        model.deploy(1, "ml.t2.medium", endpoint_name=endpoint_name)
         sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
 
         with pytest.raises(ValueError, message=expected_error_message):
-            model.deploy(1, 'ml.m4.xlarge', update_endpoint=True, endpoint_name='non-existing-endpoint')
+            model.deploy(
+                1, "ml.m4.xlarge", update_endpoint=True, endpoint_name="non-existing-endpoint"
+            )
 
 
 @pytest.mark.canary_quick
 @pytest.mark.regional_testing
-@pytest.mark.skipif(tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,
-                    reason="EI isn't supported in that specific region.")
-def test_deploy_model_with_accelerator(mxnet_training_job, sagemaker_session, ei_mxnet_full_version):
-    endpoint_name = 'test-mxnet-deploy-model-ei-{}'.format(sagemaker_timestamp())
+@pytest.mark.skipif(
+    tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,
+    reason="EI isn't supported in that specific region.",
+)
+def test_deploy_model_with_accelerator(
+    mxnet_training_job, sagemaker_session, ei_mxnet_full_version
+):
+    endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
-        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
-                           framework_version=ei_mxnet_full_version, py_version=PYTHON_VERSION,
-                           sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, accelerator_type='ml.eia1.medium')
+        desc = sagemaker_session.sagemaker_client.describe_training_job(
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        model = MXNetModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=script_path,
+            framework_version=ei_mxnet_full_version,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+        )
+        predictor = model.deploy(
+            1, "ml.m4.xlarge", endpoint_name=endpoint_name, accelerator_type="ml.eia1.medium"
+        )
 
         data = numpy.zeros(shape=(1, 1, 28, 28))
         result = predictor.predict(data)
@@ -182,23 +247,31 @@ def test_deploy_model_with_accelerator(mxnet_training_job, sagemaker_session, ei
 
 def test_async_fit(sagemaker_session, mxnet_full_version):
-    endpoint_name = 'test-mxnet-attach-deploy-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
 
     with timeout(minutes=5):
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-
-        mx = MXNet(entry_point=script_path, role='SageMakerRole', py_version=PYTHON_VERSION,
-                   train_instance_count=1, train_instance_type='ml.c4.xlarge',
-                   sagemaker_session=sagemaker_session, framework_version=mxnet_full_version,
-                   distributions={'parameter_server': {'enabled': True}})
-
-        train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                       key_prefix='integ-test-data/mxnet_mnist/train')
-        test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                      key_prefix='integ-test-data/mxnet_mnist/test')
-
-        mx.fit({'train': train_input, 'test': test_input}, wait=False)
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+            framework_version=mxnet_full_version,
+            distributions={"parameter_server": {"enabled": True}},
+        )
+
+        train_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+        )
+        test_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+        )
+
+        mx.fit({"train": train_input, "test": test_input}, wait=False)
         training_job_name = mx.latest_training_job.name
 
         print("Waiting to re-attach to the training job: %s" % training_job_name)
@@ -206,8 +279,10 @@ def test_async_fit(sagemaker_session, mxnet_full_version):
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         print("Re-attaching now to: %s" % training_job_name)
-        estimator = MXNet.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
-        predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
+        estimator = MXNet.attach(
+            training_job_name=training_job_name, sagemaker_session=sagemaker_session
+        )
+        predictor = estimator.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
         data = numpy.zeros(shape=(1, 1, 28, 28))
         result = predictor.predict(data)
         assert result is not None
@@ -215,12 +290,18 @@ def test_async_fit(sagemaker_session, mxnet_full_version):
 
 def test_failed_training_job(sagemaker_session, mxnet_full_version):
     with timeout():
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
-
-        mx = MXNet(entry_point=script_path, role='SageMakerRole', framework_version=mxnet_full_version,
-                   py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type='ml.c4.xlarge',
-                   sagemaker_session=sagemaker_session)
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py")
+
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_full_version,
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+        )
 
         with pytest.raises(ValueError) as e:
             mx.fit()
-        assert 'ExecuteUserScriptError' in str(e.value)
+        assert "ExecuteUserScriptError" in str(e.value)
diff --git a/tests/integ/test_neo_mxnet.py b/tests/integ/test_neo_mxnet.py
index 57083c00fc..fed278dc6e 100644
--- a/tests/integ/test_neo_mxnet.py
+++ b/tests/integ/test_neo_mxnet.py
@@ -23,66 +23,87 @@
 import time
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def mxnet_training_job(sagemaker_session, mxnet_full_version):
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist_neo.py')
-        data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-
-        mx = MXNet(entry_point=script_path, role='SageMakerRole',
-                   framework_version=mxnet_full_version,
-                   py_version=PYTHON_VERSION, train_instance_count=1,
-                   train_instance_type='ml.c4.xlarge',
-                   sagemaker_session=sagemaker_session)
-
-        train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                       key_prefix='integ-test-data/mxnet_mnist/train')
-        test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                      key_prefix='integ-test-data/mxnet_mnist/test')
-
-        mx.fit({'train': train_input, 'test': test_input})
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_neo.py")
+        data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+
+        mx = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            framework_version=mxnet_full_version,
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+        )
+
+        train_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+        )
+        test_input = mx.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+        )
+
+        mx.fit({"train": train_input, "test": test_input})
         return mx.latest_training_job.name
 
 
 @pytest.mark.canary_quick
 @pytest.mark.regional_testing
-@pytest.mark.skip(reason="This should be enabled along with the Boto SDK release for Neo API changes")
+@pytest.mark.skip(
+    reason="This should be enabled along with the Boto SDK release for Neo API changes"
+)
 def test_attach_deploy(mxnet_training_job, sagemaker_session):
-    endpoint_name = 'test-mxnet-attach-deploy-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
 
-        estimator.compile_model(target_instance_family='ml_m4',
-                                input_shape={'data': [1, 1, 28, 28]},
-                                output_path=estimator.output_path)
+        estimator.compile_model(
+            target_instance_family="ml_m4",
+            input_shape={"data": [1, 1, 28, 28]},
+            output_path=estimator.output_path,
+        )
 
-        predictor = estimator.deploy(1, 'ml.m4.xlarge', use_compiled_model=True,
-                                     endpoint_name=endpoint_name)
+        predictor = estimator.deploy(
+            1, "ml.m4.xlarge", use_compiled_model=True, endpoint_name=endpoint_name
+        )
         predictor.content_type = "application/vnd+python.numpy+binary"
         data = numpy.zeros(shape=(1, 1, 28, 28))
         predictor.predict(data)
 
 
-@pytest.mark.skip(reason="This should be enabled along with the Boto SDK release for Neo API changes")
+@pytest.mark.skip(
+    reason="This should be enabled along with the Boto SDK release for Neo API changes"
+)
 def test_deploy_model(mxnet_training_job, sagemaker_session):
-    endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         desc = sagemaker_session.sagemaker_client.describe_training_job(
-            TrainingJobName=mxnet_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist_neo.py')
-        role = 'SageMakerRole'
-        model = MXNetModel(model_data, role, entry_point=script_path,
-                           py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
-
-        model.compile(target_instance_family='ml_m4',
-                      input_shape={'data': [1, 1, 28, 28]},
-                      role=role,
-                      job_name='test-deploy-model-compilation-job-{}'.format(int(time.time())),
-                      output_path='/'.join(model_data.split('/')[:-1]))
-        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
+            TrainingJobName=mxnet_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_neo.py")
+        role = "SageMakerRole"
+        model = MXNetModel(
+            model_data,
+            role,
+            entry_point=script_path,
+            py_version=PYTHON_VERSION,
+            sagemaker_session=sagemaker_session,
+        )
+
+        model.compile(
+            target_instance_family="ml_m4",
+            input_shape={"data": [1, 1, 28, 28]},
+            role=role,
+            job_name="test-deploy-model-compilation-job-{}".format(int(time.time())),
+            output_path="/".join(model_data.split("/")[:-1]),
+        )
+        predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
         predictor.content_type = "application/vnd+python.numpy+binary"
         data = numpy.zeros(shape=(1, 1, 28, 28))
diff --git a/tests/integ/test_ntm.py b/tests/integ/test_ntm.py
index 4081f6c5d5..928162a4cd 100644
--- a/tests/integ/test_ntm.py
+++ b/tests/integ/test_ntm.py
@@ -27,30 +27,34 @@
 @pytest.mark.canary_quick
 def test_ntm(sagemaker_session):
-    job_name = unique_name_from_base('ntm')
+    job_name = unique_name_from_base("ntm")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, 'ntm')
-        data_filename = 'nips-train_1.pbr'
+        data_path = os.path.join(DATA_DIR, "ntm")
+        data_filename = "nips-train_1.pbr"
 
-        with open(os.path.join(data_path, data_filename), 'rb') as f:
+        with open(os.path.join(data_path, data_filename), "rb") as f:
             all_records = read_records(f)
 
         # all records must be same
-        feature_num = int(all_records[0].features['values'].float32_tensor.shape[0])
+        feature_num = int(all_records[0].features["values"].float32_tensor.shape[0])
 
-        ntm = NTM(role='SageMakerRole', train_instance_count=1, train_instance_type='ml.c4.xlarge',
-                  num_topics=10,
-                  sagemaker_session=sagemaker_session)
+        ntm = NTM(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            num_topics=10,
+            sagemaker_session=sagemaker_session,
+        )
 
-        record_set = prepare_record_set_from_local_files(data_path, ntm.data_location,
-                                                         len(all_records), feature_num,
-                                                         sagemaker_session)
+        record_set = prepare_record_set_from_local_files(
+            data_path, ntm.data_location, len(all_records), feature_num, sagemaker_session
+        )
 
         ntm.fit(records=record_set, job_name=job_name)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        model = NTMModel(ntm.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name)
+        model = NTMModel(ntm.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session)
+        predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name)
 
         predict_input = np.random.rand(1, feature_num)
         result = predictor.predict(predict_input)
diff --git a/tests/integ/test_object2vec.py b/tests/integ/test_object2vec.py
index 5795159585..1f33fc50e5 100644
--- a/tests/integ/test_object2vec.py
+++ b/tests/integ/test_object2vec.py
@@ -25,43 +25,45 @@
 def test_object2vec(sagemaker_session):
-    job_name = unique_name_from_base('object2vec')
+    job_name = unique_name_from_base("object2vec")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, 'object2vec')
-        data_filename = 'train.jsonl'
+        data_path = os.path.join(DATA_DIR, "object2vec")
+        data_filename = "train.jsonl"
 
-        with open(os.path.join(data_path, data_filename), 'r') as f:
+        with open(os.path.join(data_path, data_filename), "r") as f:
             num_records = len(f.readlines())
 
         object2vec = Object2Vec(
-            role='SageMakerRole',
+            role="SageMakerRole",
             train_instance_count=1,
-            train_instance_type='ml.c4.xlarge',
+            train_instance_type="ml.c4.xlarge",
             epochs=3,
             enc0_max_seq_len=20,
             enc0_vocab_size=45000,
             enc_dim=16,
             num_classes=3,
             negative_sampling_rate=0,
-            comparator_list='hadamard,concat,abs_diff',
+            comparator_list="hadamard,concat,abs_diff",
             tied_token_embedding_weight=False,
-            token_embedding_storage_type='dense',
-            sagemaker_session=sagemaker_session)
+            token_embedding_storage_type="dense",
+            sagemaker_session=sagemaker_session,
+        )
 
-        record_set = prepare_record_set_from_local_files(data_path, object2vec.data_location,
-                                                         num_records, FEATURE_NUM,
-                                                         sagemaker_session)
+        record_set = prepare_record_set_from_local_files(
+            data_path, object2vec.data_location, num_records, FEATURE_NUM, sagemaker_session
+        )
 
         object2vec.fit(records=record_set, job_name=job_name)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        model = Object2VecModel(object2vec.model_data, role='SageMakerRole',
-                                sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name)
+        model = Object2VecModel(
+            object2vec.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
+        )
+        predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name)
         assert isinstance(predictor, RealTimePredictor)
 
-        predict_input = {'instances': [{"in0": [354, 623], "in1": [16]}]}
+        predict_input = {"instances": [{"in0": [354, 623], "in1": [16]}]}
         result = predictor.predict(predict_input)
diff --git a/tests/integ/test_pca.py b/tests/integ/test_pca.py
index 1f07da164b..15b8da1095 100644
--- a/tests/integ/test_pca.py
+++ b/tests/integ/test_pca.py
@@ -25,30 +25,36 @@
 def test_pca(sagemaker_session):
-    job_name = unique_name_from_base('pca')
+    job_name = unique_name_from_base("pca")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
-        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
         # Load the data into memory as numpy arrays
-        with gzip.open(data_path, 'rb') as f:
+        with gzip.open(data_path, "rb") as f:
             train_set, _, _ = pickle.load(f, **pickle_args)
 
-        pca = sagemaker.amazon.pca.PCA(role='SageMakerRole', train_instance_count=1,
-                                       train_instance_type='ml.m4.xlarge',
-                                       num_components=48, sagemaker_session=sagemaker_session)
+        pca = sagemaker.amazon.pca.PCA(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.m4.xlarge",
+            num_components=48,
+            sagemaker_session=sagemaker_session,
+        )
 
-        pca.algorithm_mode = 'randomized'
+        pca.algorithm_mode = "randomized"
         pca.subtract_mean = True
         pca.extra_components = 5
         pca.fit(pca.record_set(train_set[0][:100]), job_name=job_name)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        pca_model = sagemaker.amazon.pca.PCAModel(model_data=pca.model_data, role='SageMakerRole',
-                                                  sagemaker_session=sagemaker_session)
-        predictor = pca_model.deploy(initial_instance_count=1, instance_type="ml.c4.xlarge",
-                                     endpoint_name=job_name)
+        pca_model = sagemaker.amazon.pca.PCAModel(
+            model_data=pca.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
+        )
+        predictor = pca_model.deploy(
+            initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=job_name
+        )
 
         result = predictor.predict(train_set[0][:5])
@@ -58,22 +64,26 @@ def test_pca(sagemaker_session):
 
 def test_async_pca(sagemaker_session):
-    job_name = unique_name_from_base('pca')
+    job_name = unique_name_from_base("pca")
 
     with timeout(minutes=5):
-        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
-        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
         # Load the data into memory as numpy arrays
-        with gzip.open(data_path, 'rb') as f:
+        with gzip.open(data_path, "rb") as f:
             train_set, _, _ = pickle.load(f, **pickle_args)
 
-        pca = sagemaker.amazon.pca.PCA(role='SageMakerRole', train_instance_count=1,
-                                       train_instance_type='ml.m4.xlarge',
-                                       num_components=48, sagemaker_session=sagemaker_session,
-                                       base_job_name='test-pca')
+        pca = sagemaker.amazon.pca.PCA(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.m4.xlarge",
+            num_components=48,
+            sagemaker_session=sagemaker_session,
+            base_job_name="test-pca",
+        )
 
-        pca.algorithm_mode = 'randomized'
+        pca.algorithm_mode = "randomized"
         pca.subtract_mean = True
         pca.extra_components = 5
         pca.fit(pca.record_set(train_set[0][:100]), wait=False, job_name=job_name)
@@ -82,13 +92,16 @@ def test_async_pca(sagemaker_session):
         time.sleep(20)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        estimator = sagemaker.amazon.pca.PCA.attach(training_job_name=job_name,
-                                                    sagemaker_session=sagemaker_session)
-
-        model = sagemaker.amazon.pca.PCAModel(estimator.model_data, role='SageMakerRole',
-                                              sagemaker_session=sagemaker_session)
-        predictor = model.deploy(initial_instance_count=1, instance_type="ml.c4.xlarge",
-                                 endpoint_name=job_name)
+        estimator = sagemaker.amazon.pca.PCA.attach(
+            training_job_name=job_name, sagemaker_session=sagemaker_session
+        )
+
+        model = sagemaker.amazon.pca.PCAModel(
+            estimator.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
+        )
+        predictor = model.deploy(
+            initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=job_name
+        )
 
         result = predictor.predict(train_set[0][:5])
diff --git a/tests/integ/test_pytorch_train.py b/tests/integ/test_pytorch_train.py
index 847372d76d..4f9467d420 100644
--- a/tests/integ/test_pytorch_train.py
+++ b/tests/integ/test_pytorch_train.py
@@ -25,17 +25,17 @@
 from sagemaker.pytorch.model import PyTorchModel
 from sagemaker.utils import sagemaker_timestamp
 
-MNIST_DIR = os.path.join(DATA_DIR, 'pytorch_mnist')
-MNIST_SCRIPT = os.path.join(MNIST_DIR, 'mnist.py')
+MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist")
+MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py")
 
 
-@pytest.fixture(scope='module', name='pytorch_training_job')
+@pytest.fixture(scope="module", name="pytorch_training_job")
 def fixture_training_job(sagemaker_session, pytorch_full_version):
-    instance_type = 'ml.c4.xlarge'
+    instance_type = "ml.c4.xlarge"
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
         pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
 
-        pytorch.fit({'training': _upload_training_data(pytorch)})
+        pytorch.fit({"training": _upload_training_data(pytorch)})
         return pytorch.latest_training_job.name
 
@@ -43,10 +43,10 @@ def fixture_training_job(sagemaker_session, pytorch_full_version):
 @pytest.mark.canary_quick
 @pytest.mark.regional_testing
 def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
     # TODO: add tests against local mode when it's ready to be used
-    endpoint_name = 'test-pytorch-sync-fit-attach-deploy{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-pytorch-sync-fit-attach-deploy{}".format(sagemaker_timestamp())
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)
-        predictor = estimator.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
+        predictor = estimator.deploy(1, "ml.c4.xlarge", endpoint_name=endpoint_name)
         data = numpy.zeros(shape=(1, 1, 28, 28), dtype=numpy.float32)
         predictor.predict(data)
 
@@ -58,15 +58,20 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
 
 def test_deploy_model(pytorch_training_job, sagemaker_session):
-    endpoint_name = 'test-pytorch-deploy-model-{}'.format(sagemaker_timestamp())
+    endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
 
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
         desc = sagemaker_session.sagemaker_client.describe_training_job(
-            TrainingJobName=pytorch_training_job)
-        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
-        model = PyTorchModel(model_data, 'SageMakerRole', entry_point=MNIST_SCRIPT,
-                             sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
+            TrainingJobName=pytorch_training_job
+        )
+        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
+        model = PyTorchModel(
+            model_data,
+            "SageMakerRole",
+            entry_point=MNIST_SCRIPT,
+            sagemaker_session=sagemaker_session,
+        )
+        predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name)
 
         batch_size = 100
         data = numpy.random.rand(batch_size, 1, 28, 28).astype(numpy.float32)
@@ -75,30 +80,33 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
         assert output.shape == (batch_size, 10)
 
 
-@pytest.mark.skipif(tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS
-                    or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS,
-                    reason='no ml.p2 instances in these regions')
+@pytest.mark.skipif(
+    tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS
+    or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS,
+    reason="no ml.p2 instances in these regions",
+)
 def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
     training_job_name = ""
     # TODO: add tests against local mode when it's ready to be used
-    instance_type = 'ml.p2.xlarge'
+    instance_type = "ml.p2.xlarge"
 
     with timeout(minutes=10):
         pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
 
-        pytorch.fit({'training': _upload_training_data(pytorch)}, wait=False)
+        pytorch.fit({"training": _upload_training_data(pytorch)}, wait=False)
         training_job_name = pytorch.latest_training_job.name
 
         print("Waiting to re-attach to the training job: %s" % training_job_name)
         time.sleep(20)
 
     if not _is_local_mode(instance_type):
-        endpoint_name = 'test-pytorch-async-fit-attach-deploy-{}'.format(sagemaker_timestamp())
+        endpoint_name = "test-pytorch-async-fit-attach-deploy-{}".format(sagemaker_timestamp())
 
         with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
             print("Re-attaching now to: %s" % training_job_name)
-            estimator = PyTorch.attach(training_job_name=training_job_name,
-                                       sagemaker_session=sagemaker_session)
+            estimator = PyTorch.attach(
+                training_job_name=training_job_name, sagemaker_session=sagemaker_session
+            )
             predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)
 
             batch_size = 100
@@ -110,30 +118,38 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
 
 # TODO(nadiaya): Run against local mode when errors will be propagated
 def test_failed_training_job(sagemaker_session, pytorch_full_version):
-    script_path = os.path.join(MNIST_DIR, 'failure_script.py')
+    script_path = os.path.join(MNIST_DIR, "failure_script.py")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version,
-                                         entry_point=script_path)
+        pytorch = _get_pytorch_estimator(
+            sagemaker_session, pytorch_full_version, entry_point=script_path
+        )
 
         with pytest.raises(ValueError) as e:
             pytorch.fit()
-        assert 'ExecuteUserScriptError' in str(e.value)
+        assert "ExecuteUserScriptError" in str(e.value)
 
 
 def _upload_training_data(pytorch):
-    return pytorch.sagemaker_session.upload_data(path=os.path.join(MNIST_DIR, 'training'),
-                                                 key_prefix='integ-test-data/pytorch_mnist/training')
-
-
-def _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type='ml.c4.xlarge',
-                           entry_point=MNIST_SCRIPT):
-    return PyTorch(entry_point=entry_point, role='SageMakerRole',
-                   framework_version=pytorch_full_version,
-                   py_version=PYTHON_VERSION, train_instance_count=1,
-                   train_instance_type=instance_type,
-                   sagemaker_session=sagemaker_session)
+    return pytorch.sagemaker_session.upload_data(
+        path=os.path.join(MNIST_DIR, "training"),
+        key_prefix="integ-test-data/pytorch_mnist/training",
+    )
+
+
+def _get_pytorch_estimator(
+    sagemaker_session, pytorch_full_version, instance_type="ml.c4.xlarge", entry_point=MNIST_SCRIPT
+):
+    return PyTorch(
+        entry_point=entry_point,
+        role="SageMakerRole",
+        framework_version=pytorch_full_version,
+        py_version=PYTHON_VERSION,
+        train_instance_count=1,
+        train_instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+    )
 
 
 def _is_local_mode(instance_type):
-    return instance_type == 'local'
+    return instance_type == "local"
diff --git a/tests/integ/test_randomcutforest.py b/tests/integ/test_randomcutforest.py
index cdacf006b6..0c74251fd1 100644
--- a/tests/integ/test_randomcutforest.py
+++ b/tests/integ/test_randomcutforest.py
@@ -21,24 +21,29 @@
 def test_randomcutforest(sagemaker_session):
-    job_name = unique_name_from_base('randomcutforest')
+    job_name = unique_name_from_base("randomcutforest")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
         # Generate a thousand 14-dimensional datapoints.
         feature_num = 14
         train_input = np.random.rand(1000, feature_num)
 
-        rcf = RandomCutForest(role='SageMakerRole', train_instance_count=1,
-                              train_instance_type='ml.c4.xlarge',
-                              num_trees=50, num_samples_per_tree=20,
-                              sagemaker_session=sagemaker_session)
+        rcf = RandomCutForest(
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            num_trees=50,
+            num_samples_per_tree=20,
+            sagemaker_session=sagemaker_session,
+        )
 
         rcf.fit(records=rcf.record_set(train_input), job_name=job_name)
 
     with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
-        model = RandomCutForestModel(rcf.model_data, role='SageMakerRole',
-                                     sagemaker_session=sagemaker_session)
-        predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=job_name)
+        model = RandomCutForestModel(
+            rcf.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
+        )
+        predictor = model.deploy(1, "ml.c4.xlarge", endpoint_name=job_name)
 
         predict_input = np.random.rand(1, feature_num)
         result = predictor.predict(predict_input)
diff --git a/tests/integ/test_record_set.py b/tests/integ/test_record_set.py
index 96e2b84aa1..9dbeb51e2d 100644
--- a/tests/integ/test_record_set.py
+++ b/tests/integ/test_record_set.py
@@ -28,15 +28,19 @@ def test_record_set(sagemaker_session):
     In particular, test that the objects uploaded to the S3 bucket are encrypted.
""" - data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz') - pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'} - with gzip.open(data_path, 'rb') as file_object: + data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz") + pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"} + with gzip.open(data_path, "rb") as file_object: train_set, _, _ = pickle.load(file_object, **pickle_args) - kmeans = KMeans(role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', - k=10, sagemaker_session=sagemaker_session) + kmeans = KMeans( + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + k=10, + sagemaker_session=sagemaker_session, + ) record_set = kmeans.record_set(train_set[0][:100], encrypt=True) parsed_url = urlparse(record_set.s3_data) - s3_client = sagemaker_session.boto_session.client('s3') - head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip('/')) - assert head['ServerSideEncryption'] == 'AES256' + s3_client = sagemaker_session.boto_session.client("s3") + head = s3_client.head_object(Bucket=parsed_url.netloc, Key=parsed_url.path.lstrip("/")) + assert head["ServerSideEncryption"] == "AES256" diff --git a/tests/integ/test_rl.py b/tests/integ/test_rl.py index 94cc139797..9288f49b16 100644 --- a/tests/integ/test_rl.py +++ b/tests/integ/test_rl.py @@ -22,26 +22,28 @@ from tests.integ import DATA_DIR, PYTHON_VERSION from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name -CPU_INSTANCE = 'ml.m4.xlarge' +CPU_INSTANCE = "ml.m4.xlarge" @pytest.mark.canary_quick -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="RL images supports only Python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.") def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version): estimator = _test_coach(sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version) - job_name = unique_name_from_base('test-coach-mxnet') + job_name = unique_name_from_base("test-coach-mxnet") with timeout(minutes=15): - estimator.fit(wait='False', job_name=job_name) + estimator.fit(wait="False", job_name=job_name) - estimator = RLEstimator.attach(estimator.latest_training_job.name, - sagemaker_session=sagemaker_session) + estimator = RLEstimator.attach( + estimator.latest_training_job.name, sagemaker_session=sagemaker_session + ) - endpoint_name = 'test-mxnet-coach-deploy-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-mxnet-coach-deploy-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - predictor = estimator.deploy(1, CPU_INSTANCE, entry_point='mxnet_deploy.py', - endpoint_name=endpoint_name) + predictor = estimator.deploy( + 1, CPU_INSTANCE, entry_point="mxnet_deploy.py", endpoint_name=endpoint_name + ) observation = numpy.asarray([0, 0, 0, 0]) action = predictor.predict(observation) @@ -50,67 +52,71 @@ def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version): assert 0 < action[0][1] < 1 -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="RL images supports only Python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.") def test_coach_tf(sagemaker_session, rl_coach_tf_full_version): estimator = _test_coach(sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version) - job_name = unique_name_from_base('test-coach-tf') + job_name = unique_name_from_base("test-coach-tf") with 
timeout(minutes=15): estimator.fit(job_name=job_name) - endpoint_name = 'test-tf-coach-deploy-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-tf-coach-deploy-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): predictor = estimator.deploy(1, CPU_INSTANCE) observation = numpy.asarray([0, 0, 0, 0]) action = predictor.predict(observation) - assert action == {'predictions': [[0.5, 0.5]]} + assert action == {"predictions": [[0.5, 0.5]]} def _test_coach(sagemaker_session, rl_framework, rl_coach_version): - source_dir = os.path.join(DATA_DIR, 'coach_cartpole') - dependencies = [os.path.join(DATA_DIR, 'sagemaker_rl')] - cartpole = 'train_coach.py' - - return RLEstimator(toolkit=RLToolkit.COACH, - toolkit_version=rl_coach_version, - framework=rl_framework, - entry_point=cartpole, - source_dir=source_dir, - role='SageMakerRole', - train_instance_count=1, - train_instance_type=CPU_INSTANCE, - sagemaker_session=sagemaker_session, - dependencies=dependencies, - hyperparameters={ - "save_model": 1, - "RLCOACH_PRESET": "preset_cartpole_clippedppo", - "rl.agent_params.algorithm.discount": 0.9, - "rl.evaluation_steps:EnvironmentEpisodes": 1, - }) + source_dir = os.path.join(DATA_DIR, "coach_cartpole") + dependencies = [os.path.join(DATA_DIR, "sagemaker_rl")] + cartpole = "train_coach.py" + + return RLEstimator( + toolkit=RLToolkit.COACH, + toolkit_version=rl_coach_version, + framework=rl_framework, + entry_point=cartpole, + source_dir=source_dir, + role="SageMakerRole", + train_instance_count=1, + train_instance_type=CPU_INSTANCE, + sagemaker_session=sagemaker_session, + dependencies=dependencies, + hyperparameters={ + "save_model": 1, + "RLCOACH_PRESET": "preset_cartpole_clippedppo", + "rl.agent_params.algorithm.discount": 0.9, + "rl.evaluation_steps:EnvironmentEpisodes": 1, + }, + ) @pytest.mark.canary_quick -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="RL images supports only Python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="RL images supports only Python 3.") def test_ray_tf(sagemaker_session, rl_ray_full_version): - source_dir = os.path.join(DATA_DIR, 'ray_cartpole') - cartpole = 'train_ray.py' - - estimator = RLEstimator(entry_point=cartpole, - source_dir=source_dir, - toolkit=RLToolkit.RAY, - framework=RLFramework.TENSORFLOW, - toolkit_version=rl_ray_full_version, - sagemaker_session=sagemaker_session, - role='SageMakerRole', - train_instance_type=CPU_INSTANCE, - train_instance_count=1) - job_name = unique_name_from_base('test-ray-tf') + source_dir = os.path.join(DATA_DIR, "ray_cartpole") + cartpole = "train_ray.py" + + estimator = RLEstimator( + entry_point=cartpole, + source_dir=source_dir, + toolkit=RLToolkit.RAY, + framework=RLFramework.TENSORFLOW, + toolkit_version=rl_ray_full_version, + sagemaker_session=sagemaker_session, + role="SageMakerRole", + train_instance_type=CPU_INSTANCE, + train_instance_count=1, + ) + job_name = unique_name_from_base("test-ray-tf") with timeout(minutes=15): estimator.fit(job_name=job_name) with pytest.raises(NotImplementedError) as e: estimator.deploy(1, CPU_INSTANCE) - assert 'Automatic deployment of Ray models is not currently available' in str(e.value) + assert "Automatic deployment of Ray models is not currently available" in str(e.value) diff --git a/tests/integ/test_sklearn_train.py b/tests/integ/test_sklearn_train.py index c332dcc73a..e4191ad6e8 100644 --- a/tests/integ/test_sklearn_train.py +++ b/tests/integ/test_sklearn_train.py @@ -26,118 +26,143 @@ from 
tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def sklearn_training_job(sagemaker_session, sklearn_full_version): return _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", sklearn_full_version) -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_training_with_additional_hyperparameters(sagemaker_session, sklearn_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py') - data_path = os.path.join(DATA_DIR, 'sklearn_mnist') - - sklearn = SKLearn(entry_point=script_path, - role='SageMakerRole', - train_instance_type="ml.c4.xlarge", - framework_version=sklearn_full_version, - py_version=PYTHON_VERSION, - sagemaker_session=sagemaker_session, - hyperparameters={'epochs': 1}) - - train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/sklearn_mnist/train') - test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/sklearn_mnist/test') - job_name = unique_name_from_base('test-sklearn-hp') - - sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name) + script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + train_instance_type="ml.c4.xlarge", + framework_version=sklearn_full_version, + py_version=PYTHON_VERSION, + sagemaker_session=sagemaker_session, + hyperparameters={"epochs": 1}, + ) + + train_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train" + ) + test_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/sklearn_mnist/test" + ) + job_name = unique_name_from_base("test-sklearn-hp") + + sklearn.fit({"train": train_input, "test": test_input}, job_name=job_name) return sklearn.latest_training_job.name -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_training_with_network_isolation(sagemaker_session, sklearn_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py') - data_path = os.path.join(DATA_DIR, 'sklearn_mnist') - - sklearn = SKLearn(entry_point=script_path, - role='SageMakerRole', - train_instance_type="ml.c4.xlarge", - framework_version=sklearn_full_version, - py_version=PYTHON_VERSION, - sagemaker_session=sagemaker_session, - hyperparameters={'epochs': 1}, - enable_network_isolation=True) - - train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/sklearn_mnist/train') - test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/sklearn_mnist/test') - job_name = unique_name_from_base('test-sklearn-hp') - - sklearn.fit({'train': train_input, 'test': test_input}, job_name=job_name) - assert sagemaker_session.sagemaker_client \ - 
.describe_training_job(TrainingJobName=job_name)['EnableNetworkIsolation'] + script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + train_instance_type="ml.c4.xlarge", + framework_version=sklearn_full_version, + py_version=PYTHON_VERSION, + sagemaker_session=sagemaker_session, + hyperparameters={"epochs": 1}, + enable_network_isolation=True, + ) + + train_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train" + ) + test_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/sklearn_mnist/test" + ) + job_name = unique_name_from_base("test-sklearn-hp") + + sklearn.fit({"train": train_input, "test": test_input}, job_name=job_name) + assert sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)[ + "EnableNetworkIsolation" + ] return sklearn.latest_training_job.name @pytest.mark.canary_quick @pytest.mark.regional_testing -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_attach_deploy(sklearn_training_job, sagemaker_session): - endpoint_name = 'test-sklearn-attach-deploy-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): estimator = SKLearn.attach(sklearn_training_job, sagemaker_session=sagemaker_session) - predictor = estimator.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + predictor = estimator.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_deploy_model(sklearn_training_job, sagemaker_session): - endpoint_name = 'test-sklearn-deploy-model-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-sklearn-deploy-model-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=sklearn_training_job) - model_data = desc['ModelArtifacts']['S3ModelArtifacts'] - script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py') - model = SKLearnModel(model_data, 'SageMakerRole', entry_point=script_path, sagemaker_session=sagemaker_session) + desc = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=sklearn_training_job + ) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") + model = SKLearnModel( + model_data, + "SageMakerRole", + entry_point=script_path, + sagemaker_session=sagemaker_session, + ) predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_async_fit(sagemaker_session): - endpoint_name = 'test-sklearn-attach-deploy-{}'.format(sagemaker_timestamp()) + endpoint_name = 
"test-sklearn-attach-deploy-{}".format(sagemaker_timestamp()) with timeout(minutes=5): - training_job_name = _run_mnist_training_job(sagemaker_session, "ml.c4.xlarge", - sklearn_full_version=SKLEARN_VERSION, wait=False) + training_job_name = _run_mnist_training_job( + sagemaker_session, "ml.c4.xlarge", sklearn_full_version=SKLEARN_VERSION, wait=False + ) print("Waiting to re-attach to the training job: %s" % training_job_name) time.sleep(20) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): print("Re-attaching now to: %s" % training_job_name) - estimator = SKLearn.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session) + estimator = SKLearn.attach( + training_job_name=training_job_name, sagemaker_session=sagemaker_session + ) predictor = estimator.deploy(1, "ml.c4.xlarge", endpoint_name=endpoint_name) _predict_and_assert(predictor) -@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="Scikit-learn image supports only python 3.") +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") def test_failed_training_job(sagemaker_session, sklearn_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'failure_script.py') - data_path = os.path.join(DATA_DIR, 'sklearn_mnist') - - sklearn = SKLearn(entry_point=script_path, role='SageMakerRole', - framework_version=sklearn_full_version, py_version=PYTHON_VERSION, - train_instance_count=1, train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) - - train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/sklearn_mnist/train') - job_name = unique_name_from_base('test-sklearn-failed') + script_path = os.path.join(DATA_DIR, "sklearn_mnist", "failure_script.py") + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + framework_version=sklearn_full_version, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) + + train_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train" + ) + job_name = unique_name_from_base("test-sklearn-failed") with pytest.raises(ValueError): sklearn.fit(train_input, job_name=job_name) @@ -146,35 +171,42 @@ def test_failed_training_job(sagemaker_session, sklearn_full_version): def _run_mnist_training_job(sagemaker_session, instance_type, sklearn_full_version, wait=True): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'sklearn_mnist', 'mnist.py') - - data_path = os.path.join(DATA_DIR, 'sklearn_mnist') - - sklearn = SKLearn(entry_point=script_path, role='SageMakerRole', - framework_version=sklearn_full_version, py_version=PYTHON_VERSION, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, hyperparameters={'epochs': 1}) - - train_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/sklearn_mnist/train') - test_input = sklearn.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/sklearn_mnist/test') - job_name = unique_name_from_base('test-sklearn-mnist') - - sklearn.fit({'train': train_input, 'test': test_input}, wait=wait, job_name=job_name) + script_path = os.path.join(DATA_DIR, 
"sklearn_mnist", "mnist.py") + + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + framework_version=sklearn_full_version, + py_version=PYTHON_VERSION, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + hyperparameters={"epochs": 1}, + ) + + train_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train" + ) + test_input = sklearn.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/sklearn_mnist/test" + ) + job_name = unique_name_from_base("test-sklearn-mnist") + + sklearn.fit({"train": train_input, "test": test_input}, wait=wait, job_name=job_name) return sklearn.latest_training_job.name def _predict_and_assert(predictor): batch_size = 100 - data = numpy.zeros((batch_size, 784), dtype='float32') + data = numpy.zeros((batch_size, 784), dtype="float32") output = predictor.predict(data) assert len(output) == batch_size - data = numpy.zeros((batch_size, 1, 28, 28), dtype='float32') + data = numpy.zeros((batch_size, 1, 28, 28), dtype="float32") output = predictor.predict(data) assert len(output) == batch_size - data = numpy.zeros((batch_size, 28, 28), dtype='float32') + data = numpy.zeros((batch_size, 28, 28), dtype="float32") output = predictor.predict(data) assert len(output) == batch_size diff --git a/tests/integ/test_source_dirs.py b/tests/integ/test_source_dirs.py index c018c6ca9f..38f1e82a8e 100644 --- a/tests/integ/test_source_dirs.py +++ b/tests/integ/test_source_dirs.py @@ -24,23 +24,28 @@ @pytest.mark.local_mode def test_source_dirs(tmpdir, sagemaker_local_session): - source_dir = os.path.join(DATA_DIR, 'pytorch_source_dirs') - lib = os.path.join(str(tmpdir), 'alexa.py') - - with open(lib, 'w') as f: - f.write('def question(to_anything): return 42') - - estimator = PyTorch(entry_point='train.py', role='SageMakerRole', source_dir=source_dir, - dependencies=[lib], - py_version=PYTHON_VERSION, train_instance_count=1, - train_instance_type='local', - sagemaker_session=sagemaker_local_session) + source_dir = os.path.join(DATA_DIR, "pytorch_source_dirs") + lib = os.path.join(str(tmpdir), "alexa.py") + + with open(lib, "w") as f: + f.write("def question(to_anything): return 42") + + estimator = PyTorch( + entry_point="train.py", + role="SageMakerRole", + source_dir=source_dir, + dependencies=[lib], + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + ) estimator.fit() # endpoint tests all use the same port, so we use this lock to prevent concurrent execution with lock.lock(): try: - predictor = estimator.deploy(initial_instance_count=1, instance_type='local') + predictor = estimator.deploy(initial_instance_count=1, instance_type="local") predict_response = predictor.predict([7]) assert predict_response == [49] finally: diff --git a/tests/integ/test_sparkml_serving.py b/tests/integ/test_sparkml_serving.py index e1b35c02da..71c6a65e09 100644 --- a/tests/integ/test_sparkml_serving.py +++ b/tests/integ/test_sparkml_serving.py @@ -27,50 +27,36 @@ @pytest.mark.regional_testing def test_sparkml_model_deploy(sagemaker_session): # Uploads an MLeap serialized MLeap model to S3 and use that to deploy a SparkML model to perform inference - data_path = os.path.join(DATA_DIR, 'sparkml_model') - endpoint_name = 'test-sparkml-deploy-{}'.format(sagemaker_timestamp()) - model_data = 
sagemaker_session.upload_data(path=os.path.join(data_path, 'mleap_model.tar.gz'), - key_prefix='integ-test-data/sparkml/model') - schema = json.dumps({ - "input": [ - { - "name": "Pclass", - "type": "float" - }, - { - "name": "Embarked", - "type": "string" - }, - { - "name": "Age", - "type": "float" - }, - { - "name": "Fare", - "type": "float" - }, - { - "name": "SibSp", - "type": "float" - }, - { - "name": "Sex", - "type": "string" - } - ], - "output": { - "name": "features", - "struct": "vector", - "type": "double" + data_path = os.path.join(DATA_DIR, "sparkml_model") + endpoint_name = "test-sparkml-deploy-{}".format(sagemaker_timestamp()) + model_data = sagemaker_session.upload_data( + path=os.path.join(data_path, "mleap_model.tar.gz"), + key_prefix="integ-test-data/sparkml/model", + ) + schema = json.dumps( + { + "input": [ + {"name": "Pclass", "type": "float"}, + {"name": "Embarked", "type": "string"}, + {"name": "Age", "type": "float"}, + {"name": "Fare", "type": "float"}, + {"name": "SibSp", "type": "float"}, + {"name": "Sex", "type": "string"}, + ], + "output": {"name": "features", "struct": "vector", "type": "double"}, } - }) + ) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - model = SparkMLModel(model_data=model_data, role='SageMakerRole', sagemaker_session=sagemaker_session, - env={'SAGEMAKER_SPARKML_SCHEMA': schema}) - predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name) + model = SparkMLModel( + model_data=model_data, + role="SageMakerRole", + sagemaker_session=sagemaker_session, + env={"SAGEMAKER_SPARKML_SCHEMA": schema}, + ) + predictor = model.deploy(1, "ml.m4.xlarge", endpoint_name=endpoint_name) valid_data = "1.0,C,38.0,71.5,1.0,female" assert predictor.predict(valid_data) == "1.0,0.0,38.0,1.0,71.5,0.0,1.0" invalid_data = "1.0,28.0,C,38.0,71.5,1.0" - assert (predictor.predict(invalid_data) is None) + assert predictor.predict(invalid_data) is None diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index 27e8cb3c0c..6d04b5026e 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -22,198 +22,240 @@ from sagemaker.utils import sagemaker_timestamp, unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, PYTHON_VERSION from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout -from tests.integ.vpc_test_utils import get_or_create_vpc_resources, setup_security_group_for_encryption +from tests.integ.vpc_test_utils import ( + get_or_create_vpc_resources, + setup_security_group_for_encryption, +) -DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data') +DATA_PATH = os.path.join(DATA_DIR, "iris", "data") -@pytest.fixture(scope='module') -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.fixture(scope="module") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def tf_training_job(sagemaker_session, tf_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - framework_version=tf_full_version, - training_steps=1, - evaluation_steps=1, - checkpoint_path='/opt/ml/model', - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session, - base_job_name='test-tf') - - inputs = 
sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris') - job_name = unique_name_from_base('test-tf-train') + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + framework_version=tf_full_version, + training_steps=1, + evaluation_steps=1, + checkpoint_path="/opt/ml/model", + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + base_job_name="test-tf", + ) + + inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix="integ-test-data/tf_iris") + job_name = unique_name_from_base("test-tf-train") estimator.fit(inputs, job_name=job_name) - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) return estimator.latest_training_job.name @pytest.mark.canary_quick @pytest.mark.regional_testing -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_deploy_model(sagemaker_session, tf_training_job): - endpoint_name = 'test-tf-deploy-model-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-tf-deploy-model-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=tf_training_job) - model_data = desc['ModelArtifacts']['S3ModelArtifacts'] - - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - model = TensorFlowModel(model_data, 'SageMakerRole', entry_point=script_path, - sagemaker_session=sagemaker_session) - - json_predictor = model.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge', - endpoint_name=endpoint_name) + desc = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=tf_training_job + ) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + model = TensorFlowModel( + model_data, + "SageMakerRole", + entry_point=script_path, + sagemaker_session=sagemaker_session, + ) + + json_predictor = model.deploy( + initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result @pytest.mark.canary_quick @pytest.mark.regional_testing -@pytest.mark.skipif(tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS, - reason="EI isn't supported in that specific region.") -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif( + tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS, + reason="EI isn't supported in that specific region.", +) +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_deploy_model_with_accelerator(sagemaker_session, tf_training_job, ei_tf_full_version): - endpoint_name = 
'test-tf-deploy-model-ei-{}'.format(sagemaker_timestamp()) + endpoint_name = "test-tf-deploy-model-ei-{}".format(sagemaker_timestamp()) with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=tf_training_job) - model_data = desc['ModelArtifacts']['S3ModelArtifacts'] - - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - model = TensorFlowModel(model_data, 'SageMakerRole', entry_point=script_path, - framework_version=ei_tf_full_version, sagemaker_session=sagemaker_session) - - json_predictor = model.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge', - endpoint_name=endpoint_name, accelerator_type='ml.eia1.medium') + desc = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=tf_training_job + ) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + model = TensorFlowModel( + model_data, + "SageMakerRole", + entry_point=script_path, + framework_version=ei_tf_full_version, + sagemaker_session=sagemaker_session, + ) + + json_predictor = model.deploy( + initial_instance_count=1, + instance_type="ml.c4.xlarge", + endpoint_name=endpoint_name, + accelerator_type="ml.eia1.medium", + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_async(sagemaker_session): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') - - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - training_steps=1, - evaluation_steps=1, - checkpoint_path='/opt/ml/model', - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session, - base_job_name='test-tf') - - inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris') - job_name = unique_name_from_base('test-tf-async') + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") + + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + training_steps=1, + evaluation_steps=1, + checkpoint_path="/opt/ml/model", + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + base_job_name="test-tf", + ) + + inputs = estimator.sagemaker_session.upload_data( + path=DATA_PATH, key_prefix="integ-test-data/tf_iris" + ) + job_name = unique_name_from_base("test-tf-async") estimator.fit(inputs, wait=False, job_name=job_name) training_job_name = estimator.latest_training_job.name time.sleep(20) endpoint_name = training_job_name with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - estimator = TensorFlow.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session) - 
json_predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge', - endpoint_name=endpoint_name) + estimator = TensorFlow.attach( + training_job_name=training_job_name, sagemaker_session=sagemaker_session + ) + json_predictor = estimator.deploy( + initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name + ) result = json_predictor.predict([6.4, 3.2, 4.5, 1.5]) - print('predict result: {}'.format(result)) + print("predict result: {}".format(result)) -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_tf_vpc_multi(sagemaker_session, tf_full_version): """Test Tensorflow multi-instance using the same VpcConfig for training and inference""" - instance_type = 'ml.c4.xlarge' + instance_type = "ml.c4.xlarge" instance_count = 2 - train_input = sagemaker_session.upload_data(path=os.path.join(DATA_DIR, 'iris', 'data'), - key_prefix='integ-test-data/tf_iris') - script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py') + train_input = sagemaker_session.upload_data( + path=os.path.join(DATA_DIR, "iris", "data"), key_prefix="integ-test-data/tf_iris" + ) + script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py") - ec2_client = sagemaker_session.boto_session.client('ec2') - subnet_ids, security_group_id = get_or_create_vpc_resources(ec2_client, - sagemaker_session.boto_session.region_name) + ec2_client = sagemaker_session.boto_session.client("ec2") + subnet_ids, security_group_id = get_or_create_vpc_resources( + ec2_client, sagemaker_session.boto_session.region_name + ) setup_security_group_for_encryption(ec2_client, security_group_id) - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - framework_version=tf_full_version, - training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=instance_count, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - base_job_name='test-vpc-tf', - subnets=subnet_ids, - security_group_ids=[security_group_id], - encrypt_inter_container_traffic=True) - job_name = unique_name_from_base('test-tf-vpc-multi') + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + framework_version=tf_full_version, + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=instance_count, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + base_job_name="test-vpc-tf", + subnets=subnet_ids, + security_group_ids=[security_group_id], + encrypt_inter_container_traffic=True, + ) + job_name = unique_name_from_base("test-tf-vpc-multi") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(train_input, job_name=job_name) - print('training job succeeded: {}'.format(estimator.latest_training_job.name)) + print("training job succeeded: {}".format(estimator.latest_training_job.name)) job_desc = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=estimator.latest_training_job.name) - assert set(subnet_ids) == set(job_desc['VpcConfig']['Subnets']) - assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds'] - assert job_desc['EnableInterContainerTrafficEncryption'] is True + TrainingJobName=estimator.latest_training_job.name + ) + assert set(subnet_ids) == set(job_desc["VpcConfig"]["Subnets"]) + assert [security_group_id] == 
job_desc["VpcConfig"]["SecurityGroupIds"] + assert job_desc["EnableInterContainerTrafficEncryption"] is True endpoint_name = estimator.latest_training_job.name with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): model = estimator.create_model() - json_predictor = model.deploy(initial_instance_count=instance_count, instance_type='ml.c4.xlarge', - endpoint_name=endpoint_name) + json_predictor = model.deploy( + initial_instance_count=instance_count, + instance_type="ml.c4.xlarge", + endpoint_name=endpoint_name, + ) features = [6.4, 3.2, 4.5, 1.5] - dict_result = json_predictor.predict({'inputs': features}) - print('predict result: {}'.format(dict_result)) + dict_result = json_predictor.predict({"inputs": features}) + print("predict result: {}".format(dict_result)) list_result = json_predictor.predict(features) - print('predict result: {}'.format(list_result)) + print("predict result: {}".format(list_result)) assert dict_result == list_result model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=model.name) - assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets']) - assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds'] + assert set(subnet_ids) == set(model_desc["VpcConfig"]["Subnets"]) + assert [security_group_id] == model_desc["VpcConfig"]["SecurityGroupIds"] -@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.") +@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.") def test_failed_tf_training(sagemaker_session, tf_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py') - estimator = TensorFlow(entry_point=script_path, - role='SageMakerRole', - framework_version=tf_full_version, - training_steps=1, - evaluation_steps=1, - hyperparameters={'input_tensor_name': 'inputs'}, - train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) - job_name = unique_name_from_base('test-tf-fail') + script_path = os.path.join(DATA_DIR, "iris", "failure_script.py") + estimator = TensorFlow( + entry_point=script_path, + role="SageMakerRole", + framework_version=tf_full_version, + training_steps=1, + evaluation_steps=1, + hyperparameters={"input_tensor_name": "inputs"}, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + ) + job_name = unique_name_from_base("test-tf-fail") with pytest.raises(ValueError) as e: estimator.fit(job_name=job_name) - assert 'This failure is expected' in str(e.value) + assert "This failure is expected" in str(e.value) diff --git a/tests/integ/test_tf_cifar.py b/tests/integ/test_tf_cifar.py index 032b2cd962..ed7418b77d 100644 --- a/tests/integ/test_tf_cifar.py +++ b/tests/integ/test_tf_cifar.py @@ -24,7 +24,7 @@ from sagemaker.tensorflow import TensorFlow from sagemaker.utils import unique_name_from_base -PICKLE_CONTENT_TYPE = 'application/python-pickle' +PICKLE_CONTENT_TYPE = "application/python-pickle" class PickleSerializer(object): @@ -36,38 +36,48 @@ def __call__(self, data): @pytest.mark.canary_quick -@pytest.mark.skipif(tests.integ.PYTHON_VERSION != 'py2', - reason="TensorFlow image supports only python 2.") -@pytest.mark.skipif(tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS - or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, - reason='no ml.p2 instances in these regions') +@pytest.mark.skipif( + 
tests.integ.PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2." +) +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS + or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, + reason="no ml.p2 instances in these regions", +) def test_cifar(sagemaker_session, tf_full_version): with timeout(minutes=45): - script_path = os.path.join(tests.integ.DATA_DIR, 'cifar_10', 'source') - - dataset_path = os.path.join(tests.integ.DATA_DIR, 'cifar_10', 'data') - - estimator = TensorFlow(entry_point='resnet_cifar_10.py', source_dir=script_path, - role='SageMakerRole', - framework_version=tf_full_version, training_steps=500, - evaluation_steps=5, - train_instance_count=2, train_instance_type='ml.p2.xlarge', - sagemaker_session=sagemaker_session, train_max_run=45 * 60, - base_job_name='test-cifar') - - inputs = estimator.sagemaker_session.upload_data(path=dataset_path, - key_prefix='data/cifar10') - job_name = unique_name_from_base('test-tf-cifar') + script_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "source") + + dataset_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "data") + + estimator = TensorFlow( + entry_point="resnet_cifar_10.py", + source_dir=script_path, + role="SageMakerRole", + framework_version=tf_full_version, + training_steps=500, + evaluation_steps=5, + train_instance_count=2, + train_instance_type="ml.p2.xlarge", + sagemaker_session=sagemaker_session, + train_max_run=45 * 60, + base_job_name="test-cifar", + ) + + inputs = estimator.sagemaker_session.upload_data( + path=dataset_path, key_prefix="data/cifar10" + ) + job_name = unique_name_from_base("test-tf-cifar") estimator.fit(inputs, logs=False, job_name=job_name) - print('job succeeded: {}'.format(estimator.latest_training_job.name)) + print("job succeeded: {}".format(estimator.latest_training_job.name)) endpoint_name = estimator.latest_training_job.name with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge') + predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.p2.xlarge") predictor.serializer = PickleSerializer() predictor.content_type = PICKLE_CONTENT_TYPE data = np.random.randn(32, 32, 3) predict_response = predictor.predict(data) - assert len(predict_response['outputs']['probabilities']['floatVal']) == 10 + assert len(predict_response["outputs"]["probabilities"]["floatVal"]) == 10 diff --git a/tests/integ/test_tf_keras.py b/tests/integ/test_tf_keras.py index f5ba535a17..9fea4a7ca5 100644 --- a/tests/integ/test_tf_keras.py +++ b/tests/integ/test_tf_keras.py @@ -25,33 +25,42 @@ @pytest.mark.canary_quick -@pytest.mark.skipif(tests.integ.PYTHON_VERSION != 'py2', - reason="TensorFlow image supports only python 2.") -@pytest.mark.skipif(tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS, - reason='no ml.p2 instances in these regions') +@pytest.mark.skipif( + tests.integ.PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2." 
+) +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS, + reason="no ml.p2 instances in these regions", +) def test_keras(sagemaker_session, tf_full_version): - script_path = os.path.join(tests.integ.DATA_DIR, 'cifar_10', 'source') - dataset_path = os.path.join(tests.integ.DATA_DIR, 'cifar_10', 'data') + script_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "source") + dataset_path = os.path.join(tests.integ.DATA_DIR, "cifar_10", "data") with timeout(minutes=45): - estimator = TensorFlow(entry_point='keras_cnn_cifar_10.py', - source_dir=script_path, - role='SageMakerRole', sagemaker_session=sagemaker_session, - hyperparameters={'learning_rate': 1e-4, 'decay': 1e-6}, - training_steps=50, evaluation_steps=5, - train_instance_count=1, train_instance_type='ml.c4.xlarge', - train_max_run=45 * 60) - - inputs = estimator.sagemaker_session.upload_data(path=dataset_path, - key_prefix='data/cifar10') - job_name = unique_name_from_base('test-tf-keras') + estimator = TensorFlow( + entry_point="keras_cnn_cifar_10.py", + source_dir=script_path, + role="SageMakerRole", + sagemaker_session=sagemaker_session, + hyperparameters={"learning_rate": 1e-4, "decay": 1e-6}, + training_steps=50, + evaluation_steps=5, + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + train_max_run=45 * 60, + ) + + inputs = estimator.sagemaker_session.upload_data( + path=dataset_path, key_prefix="data/cifar10" + ) + job_name = unique_name_from_base("test-tf-keras") estimator.fit(inputs, job_name=job_name) endpoint_name = estimator.latest_training_job.name with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.p2.xlarge') + predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.p2.xlarge") data = np.random.randn(32, 32, 3) predict_response = predictor.predict(data) - assert len(predict_response['outputs']['probabilities']['floatVal']) == 10 + assert len(predict_response["outputs"]["probabilities"]["floatVal"]) == 10 diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index 29cd08cc74..49255fd9a3 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -26,47 +26,56 @@ import tests.integ from tests.integ import timeout -ROLE = 'SageMakerRole' - -RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data') -MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tensorflow_mnist') -TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tfs', 'tfs-test-entrypoint-with-handler') - -SCRIPT = os.path.join(MNIST_RESOURCE_PATH, 'mnist.py') -PARAMETER_SERVER_DISTRIBUTION = {'parameter_server': {'enabled': True}} -MPI_DISTRIBUTION = {'mpi': {'enabled': True}} -TAGS = [{'Key': 'some-key', 'Value': 'some-value'}] - - -@pytest.fixture(scope='session', params=[ - 'ml.c4.xlarge', - pytest.param('ml.p2.xlarge', - marks=pytest.mark.skipif( - tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS - or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, - reason='no ml.p2 instances in this region'))]) +ROLE = "SageMakerRole" + +RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "data") +MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tensorflow_mnist") +TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, "tfs", "tfs-test-entrypoint-with-handler") + +SCRIPT = os.path.join(MNIST_RESOURCE_PATH, "mnist.py") +PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}} 
+MPI_DISTRIBUTION = {"mpi": {"enabled": True}} +TAGS = [{"Key": "some-key", "Value": "some-value"}] + + +@pytest.fixture( + scope="session", + params=[ + "ml.c4.xlarge", + pytest.param( + "ml.p2.xlarge", + marks=pytest.mark.skipif( + tests.integ.test_region() in tests.integ.HOSTING_NO_P2_REGIONS + or tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS, + reason="no ml.p2 instances in this region", + ), + ), + ], +) def instance_type(request): return request.param def test_mnist(sagemaker_session, instance_type): - estimator = TensorFlow(entry_point=SCRIPT, - role='SageMakerRole', - train_instance_count=1, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - script_mode=True, - framework_version=TensorFlow.LATEST_VERSION, - metric_definitions=[ - {'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}]) + estimator = TensorFlow( + entry_point=SCRIPT, + role="SageMakerRole", + train_instance_count=1, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}], + ) inputs = estimator.sagemaker_session.upload_data( - path=os.path.join(MNIST_RESOURCE_PATH, 'data'), - key_prefix='scriptmode/mnist') + path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist" + ) with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): - estimator.fit(inputs=inputs, job_name=unique_name_from_base('test-tf-sm-mnist')) - _assert_s3_files_exist(estimator.model_dir, - ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']) + estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-mnist")) + _assert_s3_files_exist( + estimator.model_dir, ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"] + ) df = estimator.training_job_analytics.dataframe() assert df.size > 0 @@ -74,105 +83,123 @@ def test_mnist(sagemaker_session, instance_type): def test_server_side_encryption(sagemaker_session): boto_session = sagemaker_session.boto_session with tests.integ.kms_utils.bucket_with_encryption(boto_session, ROLE) as ( - bucket_with_kms, kms_key): - output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', - time.strftime('%y%m%d-%H%M')) - - estimator = TensorFlow(entry_point=SCRIPT, - role=ROLE, - train_instance_count=1, - train_instance_type='ml.c5.xlarge', - sagemaker_session=sagemaker_session, - script_mode=True, - framework_version=TensorFlow.LATEST_VERSION, - code_location=output_path, - output_path=output_path, - model_dir='/opt/ml/model', - output_kms_key=kms_key) + bucket_with_kms, + kms_key, + ): + output_path = os.path.join( + bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M") + ) + + estimator = TensorFlow( + entry_point=SCRIPT, + role=ROLE, + train_instance_count=1, + train_instance_type="ml.c5.xlarge", + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + code_location=output_path, + output_path=output_path, + model_dir="/opt/ml/model", + output_kms_key=kms_key, + ) inputs = estimator.sagemaker_session.upload_data( - path=os.path.join(MNIST_RESOURCE_PATH, 'data'), - key_prefix='scriptmode/mnist') + path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist" + ) with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): - estimator.fit(inputs=inputs, - 
job_name=unique_name_from_base('test-server-side-encryption')) + estimator.fit( + inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption") + ) @pytest.mark.canary_quick def test_mnist_distributed(sagemaker_session, instance_type): - estimator = TensorFlow(entry_point=SCRIPT, - role=ROLE, - train_instance_count=2, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - py_version=tests.integ.PYTHON_VERSION, - script_mode=True, - framework_version=TensorFlow.LATEST_VERSION, - distributions=PARAMETER_SERVER_DISTRIBUTION) + estimator = TensorFlow( + entry_point=SCRIPT, + role=ROLE, + train_instance_count=2, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + py_version=tests.integ.PYTHON_VERSION, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + distributions=PARAMETER_SERVER_DISTRIBUTION, + ) inputs = estimator.sagemaker_session.upload_data( - path=os.path.join(MNIST_RESOURCE_PATH, 'data'), - key_prefix='scriptmode/distributed_mnist') + path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/distributed_mnist" + ) with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): - estimator.fit(inputs=inputs, job_name=unique_name_from_base('test-tf-sm-distributed')) - _assert_s3_files_exist(estimator.model_dir, - ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']) + estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")) + _assert_s3_files_exist( + estimator.model_dir, ["graph.pbtxt", "model.ckpt-0.index", "model.ckpt-0.meta"] + ) def test_mnist_async(sagemaker_session): - estimator = TensorFlow(entry_point=SCRIPT, - role=ROLE, - train_instance_count=1, - train_instance_type='ml.c5.4xlarge', - sagemaker_session=sagemaker_session, - script_mode=True, - framework_version=TensorFlow.LATEST_VERSION, - tags=TAGS) + estimator = TensorFlow( + entry_point=SCRIPT, + role=ROLE, + train_instance_count=1, + train_instance_type="ml.c5.4xlarge", + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + tags=TAGS, + ) inputs = estimator.sagemaker_session.upload_data( - path=os.path.join(MNIST_RESOURCE_PATH, 'data'), - key_prefix='scriptmode/mnist') - estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base('test-tf-sm-async')) + path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist" + ) + estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base("test-tf-sm-async")) training_job_name = estimator.latest_training_job.name time.sleep(20) endpoint_name = training_job_name - _assert_training_job_tags_match(sagemaker_session.sagemaker_client, - estimator.latest_training_job.name, TAGS) + _assert_training_job_tags_match( + sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS + ) with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - estimator = TensorFlow.attach(training_job_name=training_job_name, - sagemaker_session=sagemaker_session) - predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge', - endpoint_name=endpoint_name) + estimator = TensorFlow.attach( + training_job_name=training_job_name, sagemaker_session=sagemaker_session + ) + predictor = estimator.deploy( + initial_instance_count=1, instance_type="ml.c4.xlarge", endpoint_name=endpoint_name + ) result = predictor.predict(np.zeros(784)) - print('predict result: {}'.format(result)) + print("predict 
result: {}".format(result)) _assert_endpoint_tags_match(sagemaker_session.sagemaker_client, predictor.endpoint, TAGS) - _assert_model_tags_match(sagemaker_session.sagemaker_client, - estimator.latest_training_job.name, TAGS) + _assert_model_tags_match( + sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS + ) def test_deploy_with_input_handlers(sagemaker_session, instance_type): - estimator = TensorFlow(entry_point='inference.py', - source_dir=TFS_RESOURCE_PATH, - role=ROLE, - train_instance_count=1, - train_instance_type=instance_type, - sagemaker_session=sagemaker_session, - script_mode=True, - framework_version=TensorFlow.LATEST_VERSION, - tags=TAGS) - - estimator.fit(job_name=unique_name_from_base('test-tf-tfs-deploy')) + estimator = TensorFlow( + entry_point="inference.py", + source_dir=TFS_RESOURCE_PATH, + role=ROLE, + train_instance_count=1, + train_instance_type=instance_type, + sagemaker_session=sagemaker_session, + script_mode=True, + framework_version=TensorFlow.LATEST_VERSION, + tags=TAGS, + ) + + estimator.fit(job_name=unique_name_from_base("test-tf-tfs-deploy")) endpoint_name = estimator.latest_training_job.name with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - predictor = estimator.deploy(initial_instance_count=1, instance_type=instance_type, - endpoint_name=endpoint_name) + predictor = estimator.deploy( + initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name + ) - input_data = {'instances': [1.0, 2.0, 5.0]} - expected_result = {'predictions': [4.0, 4.5, 6.0]} + input_data = {"instances": [1.0, 2.0, 5.0]} + expected_result = {"predictions": [4.0, 4.5, 6.0]} result = predictor.predict(input_data) assert expected_result == result @@ -180,31 +207,33 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type): def _assert_s3_files_exist(s3_url, files): parsed_url = urlparse(s3_url) - s3 = boto3.client('s3') - contents = s3.list_objects_v2(Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip('/'))[ - "Contents"] + s3 = boto3.client("s3") + contents = s3.list_objects_v2(Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip("/"))[ + "Contents" + ] for f in files: - found = [x['Key'] for x in contents if x['Key'].endswith(f)] + found = [x["Key"] for x in contents if x["Key"].endswith(f)] if not found: - raise ValueError('File {} is not found under {}'.format(f, s3_url)) + raise ValueError("File {} is not found under {}".format(f, s3_url)) def _assert_tags_match(sagemaker_client, resource_arn, tags): - actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)['Tags'] + actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"] assert actual_tags == tags def _assert_model_tags_match(sagemaker_client, model_name, tags): model_description = sagemaker_client.describe_model(ModelName=model_name) - _assert_tags_match(sagemaker_client, model_description['ModelArn'], tags) + _assert_tags_match(sagemaker_client, model_description["ModelArn"], tags) def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags): endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name) - _assert_tags_match(sagemaker_client, endpoint_description['EndpointArn'], tags) + _assert_tags_match(sagemaker_client, endpoint_description["EndpointArn"], tags) def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags): training_job_description = sagemaker_client.describe_training_job( - TrainingJobName=training_job_name) - 
_assert_tags_match(sagemaker_client, training_job_description['TrainingJobArn'], tags) + TrainingJobName=training_job_name + ) + _assert_tags_match(sagemaker_client, training_job_description["TrainingJobArn"], tags) diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 9c7edb5dfe..a53ed52c2a 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -26,53 +26,66 @@ from sagemaker.tensorflow.serving import Model, Predictor -@pytest.fixture(scope='session', params=[ - 'ml.c5.xlarge', - pytest.param('ml.p3.2xlarge', - marks=pytest.mark.skipif( - tests.integ.test_region() in tests.integ.HOSTING_NO_P3_REGIONS, - reason='no ml.p3 instances in this region'))]) +@pytest.fixture( + scope="session", + params=[ + "ml.c5.xlarge", + pytest.param( + "ml.p3.2xlarge", + marks=pytest.mark.skipif( + tests.integ.test_region() in tests.integ.HOSTING_NO_P3_REGIONS, + reason="no ml.p3 instances in this region", + ), + ), + ], +) def instance_type(request): return request.param -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tfs_predictor(instance_type, sagemaker_session, tf_full_version): - endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') + endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") model_data = sagemaker_session.upload_data( - path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), - key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, - sagemaker_session): - model = Model(model_data=model_data, role='SageMakerRole', - framework_version=tf_full_version, - sagemaker_session=sagemaker_session) + path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"), + key_prefix="tensorflow-serving/models", + ) + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + model = Model( + model_data=model_data, + role="SageMakerRole", + framework_version=tf_full_version, + sagemaker_session=sagemaker_session, + ) predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name) yield predictor def tar_dir(directory, tmpdir): - target = os.path.join(str(tmpdir), 'model.tar.gz') + target = os.path.join(str(tmpdir), "model.tar.gz") - with tarfile.open(target, mode='w:gz') as t: + with tarfile.open(target, mode="w:gz") as t: t.add(directory, arcname=os.path.sep) return target @pytest.fixture -def tfs_predictor_with_model_and_entry_point_same_tar(sagemaker_local_session, - tf_full_version, - tmpdir): - endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') +def tfs_predictor_with_model_and_entry_point_same_tar( + sagemaker_local_session, tf_full_version, tmpdir +): + endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") - model_tar = tar_dir(os.path.join(tests.integ.DATA_DIR, 'tfs/tfs-test-model-with-inference'), - tmpdir) + model_tar = tar_dir( + os.path.join(tests.integ.DATA_DIR, "tfs/tfs-test-model-with-inference"), tmpdir + ) - model = Model(model_data='file://' + model_tar, - role='SageMakerRole', - framework_version=tf_full_version, - sagemaker_session=sagemaker_local_session) - predictor = model.deploy(1, 'local', endpoint_name=endpoint_name) + model = Model( + model_data="file://" + model_tar, + role="SageMakerRole", + framework_version=tf_full_version, + sagemaker_session=sagemaker_local_session, + ) + predictor = model.deploy(1, "local", endpoint_name=endpoint_name) try: yield 
predictor @@ -80,27 +93,33 @@ def tfs_predictor_with_model_and_entry_point_same_tar(sagemaker_local_session, predictor.delete_endpoint() -@pytest.fixture(scope='module') -def tfs_predictor_with_model_and_entry_point_and_dependencies(sagemaker_local_session, - tf_full_version): - endpoint_name = sagemaker.utils.unique_name_from_base('sagemaker-tensorflow-serving') - - entry_point = os.path.join(tests.integ.DATA_DIR, - 'tfs/tfs-test-entrypoint-and-dependencies/inference.py') - dependencies = [os.path.join(tests.integ.DATA_DIR, - 'tfs/tfs-test-entrypoint-and-dependencies/dependency.py')] - - model_data = 'file://' + os.path.join(tests.integ.DATA_DIR, - 'tensorflow-serving-test-model.tar.gz') - - model = Model(entry_point=entry_point, - model_data=model_data, - role='SageMakerRole', - dependencies=dependencies, - framework_version=tf_full_version, - sagemaker_session=sagemaker_local_session) +@pytest.fixture(scope="module") +def tfs_predictor_with_model_and_entry_point_and_dependencies( + sagemaker_local_session, tf_full_version +): + endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") - predictor = model.deploy(1, 'local', endpoint_name=endpoint_name) + entry_point = os.path.join( + tests.integ.DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies/inference.py" + ) + dependencies = [ + os.path.join(tests.integ.DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies/dependency.py") + ] + + model_data = "file://" + os.path.join( + tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz" + ) + + model = Model( + entry_point=entry_point, + model_data=model_data, + role="SageMakerRole", + dependencies=dependencies, + framework_version=tf_full_version, + sagemaker_session=sagemaker_local_session, + ) + + predictor = model.deploy(1, "local", endpoint_name=endpoint_name) try: yield predictor @@ -108,56 +127,63 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(sagemaker_local_se predictor.delete_endpoint() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tfs_predictor_with_accelerator(sagemaker_session, tf_full_version): endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving") - instance_type = 'ml.c4.large' - accelerator_type = 'ml.eia1.medium' + instance_type = "ml.c4.large" + accelerator_type = "ml.eia1.medium" model_data = sagemaker_session.upload_data( - path=os.path.join(tests.integ.DATA_DIR, 'tensorflow-serving-test-model.tar.gz'), - key_prefix='tensorflow-serving/models') - with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, - sagemaker_session): - model = Model(model_data=model_data, role='SageMakerRole', - framework_version=tf_full_version, - sagemaker_session=sagemaker_session) - predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name, - accelerator_type=accelerator_type) + path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"), + key_prefix="tensorflow-serving/models", + ) + with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + model = Model( + model_data=model_data, + role="SageMakerRole", + framework_version=tf_full_version, + sagemaker_session=sagemaker_session, + ) + predictor = model.deploy( + 1, instance_type, endpoint_name=endpoint_name, accelerator_type=accelerator_type + ) yield predictor @pytest.mark.canary_quick def test_predict(tfs_predictor, instance_type): # pylint: disable=W0613 - input_data = {'instances': [1.0, 2.0, 5.0]} - expected_result = {'predictions': [3.5, 4.0, 
5.5]} + input_data = {"instances": [1.0, 2.0, 5.0]} + expected_result = {"predictions": [3.5, 4.0, 5.5]} result = tfs_predictor.predict(input_data) assert expected_result == result -@pytest.mark.skipif(tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS, - reason='EI is not supported in region {}'.format(tests.integ.test_region())) +@pytest.mark.skipif( + tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS, + reason="EI is not supported in region {}".format(tests.integ.test_region()), +) @pytest.mark.canary_quick def test_predict_with_accelerator(tfs_predictor_with_accelerator): - input_data = {'instances': [1.0, 2.0, 5.0]} - expected_result = {'predictions': [3.5, 4.0, 5.5]} + input_data = {"instances": [1.0, 2.0, 5.0]} + expected_result = {"predictions": [3.5, 4.0, 5.5]} result = tfs_predictor_with_accelerator.predict(input_data) assert expected_result == result def test_predict_with_entry_point(tfs_predictor_with_model_and_entry_point_same_tar): - input_data = {'instances': [1.0, 2.0, 5.0]} - expected_result = {'predictions': [4.0, 4.5, 6.0]} + input_data = {"instances": [1.0, 2.0, 5.0]} + expected_result = {"predictions": [4.0, 4.5, 6.0]} result = tfs_predictor_with_model_and_entry_point_same_tar.predict(input_data) assert expected_result == result def test_predict_with_model_and_entry_point_and_dependencies_separated( - tfs_predictor_with_model_and_entry_point_and_dependencies): - input_data = {'instances': [1.0, 2.0, 5.0]} - expected_result = {'predictions': [4.0, 4.5, 6.0]} + tfs_predictor_with_model_and_entry_point_and_dependencies +): + input_data = {"instances": [1.0, 2.0, 5.0]} + expected_result = {"predictions": [4.0, 4.5, 6.0]} result = tfs_predictor_with_model_and_entry_point_and_dependencies.predict(input_data) assert expected_result == result @@ -165,66 +191,78 @@ def test_predict_with_model_and_entry_point_and_dependencies_separated( def test_predict_generic_json(tfs_predictor): input_data = [[1.0, 2.0, 5.0], [1.0, 2.0, 5.0]] - expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} + expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} result = tfs_predictor.predict(input_data) assert expected_result == result def test_predict_jsons_json_content_type(tfs_predictor): - input_data = '[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]' - expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} - - predictor = sagemaker.RealTimePredictor(tfs_predictor.endpoint, - tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, - content_type='application/json', - accept='application/json') + input_data = "[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]" + expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} + + predictor = sagemaker.RealTimePredictor( + tfs_predictor.endpoint, + tfs_predictor.sagemaker_session, + serializer=None, + deserializer=sagemaker.predictor.json_deserializer, + content_type="application/json", + accept="application/json", + ) result = predictor.predict(input_data) assert expected_result == result def test_predict_jsons(tfs_predictor): - input_data = '[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]' - expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} - - predictor = sagemaker.RealTimePredictor(tfs_predictor.endpoint, - tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, - content_type='application/jsons', - accept='application/jsons') + input_data = "[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]" + expected_result = 
{"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} + + predictor = sagemaker.RealTimePredictor( + tfs_predictor.endpoint, + tfs_predictor.sagemaker_session, + serializer=None, + deserializer=sagemaker.predictor.json_deserializer, + content_type="application/jsons", + accept="application/jsons", + ) result = predictor.predict(input_data) assert expected_result == result def test_predict_jsonlines(tfs_predictor): - input_data = '[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]' - expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} - - predictor = sagemaker.RealTimePredictor(tfs_predictor.endpoint, - tfs_predictor.sagemaker_session, serializer=None, - deserializer=sagemaker.predictor.json_deserializer, - content_type='application/jsonlines', - accept='application/jsonlines') + input_data = "[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]" + expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} + + predictor = sagemaker.RealTimePredictor( + tfs_predictor.endpoint, + tfs_predictor.sagemaker_session, + serializer=None, + deserializer=sagemaker.predictor.json_deserializer, + content_type="application/jsonlines", + accept="application/jsonlines", + ) result = predictor.predict(input_data) assert expected_result == result def test_predict_csv(tfs_predictor): - input_data = '1.0,2.0,5.0\n1.0,2.0,5.0' - expected_result = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} + input_data = "1.0,2.0,5.0\n1.0,2.0,5.0" + expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} - predictor = Predictor(tfs_predictor.endpoint, tfs_predictor.sagemaker_session, - serializer=sagemaker.predictor.csv_serializer) + predictor = Predictor( + tfs_predictor.endpoint, + tfs_predictor.sagemaker_session, + serializer=sagemaker.predictor.csv_serializer, + ) result = predictor.predict(input_data) assert expected_result == result def test_predict_bad_input(tfs_predictor): - input_data = {'junk': 'data'} + input_data = {"junk": "data"} with pytest.raises(botocore.exceptions.ClientError): tfs_predictor.predict(input_data) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 02047def33..989b074380 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -24,7 +24,11 @@ from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator from sagemaker.utils import unique_name_from_base -from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES +from tests.integ import ( + DATA_DIR, + TRAINING_DEFAULT_TIMEOUT_MINUTES, + TRANSFORM_DEFAULT_TIMEOUT_MINUTES, +) from tests.integ.kms_utils import get_or_create_kms_key from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources @@ -32,213 +36,285 @@ @pytest.mark.canary_quick def test_transform_mxnet(sagemaker_session, mxnet_full_version): - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - script_path = os.path.join(data_path, 'mnist.py') - - mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, - framework_version=mxnet_full_version) - - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') - job_name = 
unique_name_from_base('test-mxnet-transform')
+    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+    script_path = os.path.join(data_path, "mnist.py")
+
+    mx = MXNet(
+        entry_point=script_path,
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version=mxnet_full_version,
+    )
+
+    train_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+    )
+    test_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+    )
+    job_name = unique_name_from_base("test-mxnet-transform")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
+        mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
 
-    transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
-    transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
-    transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
-                                                       key_prefix=transform_input_key_prefix)
+    transform_input_path = os.path.join(data_path, "transform", "data.csv")
+    transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
+    transform_input = mx.sagemaker_session.upload_data(
+        path=transform_input_path, key_prefix=transform_input_key_prefix
+    )
 
     kms_key_arn = get_or_create_kms_key(sagemaker_session)
     output_filter = "$"
 
-    transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn,
-                                                        input_filter=None, output_filter=output_filter,
-                                                        join_source=None)
-    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
-                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+    transformer = _create_transformer_and_transform_job(
+        mx,
+        transform_input,
+        kms_key_arn,
+        input_filter=None,
+        output_filter=output_filter,
+        join_source=None,
+    )
+    with timeout_and_delete_model_with_transformer(
+        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
+    ):
         transformer.wait()
 
         job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
-            TransformJobName=transformer.latest_transform_job.name)
-        assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']
-        assert output_filter == job_desc['DataProcessing']['OutputFilter']
+            TransformJobName=transformer.latest_transform_job.name
+        )
+        assert kms_key_arn == job_desc["TransformResources"]["VolumeKmsKeyId"]
+        assert output_filter == job_desc["DataProcessing"]["OutputFilter"]
 
 
 @pytest.mark.canary_quick
 def test_attach_transform_kmeans(sagemaker_session):
-    data_path = os.path.join(DATA_DIR, 'one_p_mnist')
-    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+    data_path = os.path.join(DATA_DIR, "one_p_mnist")
+    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
     # Load the data into memory as numpy arrays
-    train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
-    with gzip.open(train_set_path, 'rb') as f:
+    train_set_path = os.path.join(data_path, "mnist.pkl.gz")
+    with gzip.open(train_set_path, "rb") as f:
         train_set, _, _ = pickle.load(f, **pickle_args)
 
-    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
-                    train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
-                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
+    kmeans = KMeans(
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        k=10,
+        sagemaker_session=sagemaker_session,
+        output_path="s3://{}/".format(sagemaker_session.default_bucket()),
+    )
 
     # set kmeans specific hp
-    kmeans.init_method = 'random'
+    kmeans.init_method = "random"
     kmeans.max_iterators = 1
     kmeans.tol = 1
     kmeans.num_trials = 1
-    kmeans.local_init_method = 'kmeans++'
+    kmeans.local_init_method = "kmeans++"
     kmeans.half_life_time_size = 1
     kmeans.epochs = 1
 
     records = kmeans.record_set(train_set[0][:100])
 
-    job_name = unique_name_from_base('test-kmeans-attach')
+    job_name = unique_name_from_base("test-kmeans-attach")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
         kmeans.fit(records, job_name=job_name)
 
-    transform_input_path = os.path.join(data_path, 'transform_input.csv')
-    transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
-    transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
-                                                           key_prefix=transform_input_key_prefix)
+    transform_input_path = os.path.join(data_path, "transform_input.csv")
+    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
+    transform_input = kmeans.sagemaker_session.upload_data(
+        path=transform_input_path, key_prefix=transform_input_key_prefix
+    )
 
     transformer = _create_transformer_and_transform_job(kmeans, transform_input)
 
-    attached_transformer = Transformer.attach(transformer.latest_transform_job.name,
-                                              sagemaker_session=sagemaker_session)
-    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
-                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+    attached_transformer = Transformer.attach(
+        transformer.latest_transform_job.name, sagemaker_session=sagemaker_session
+    )
+    with timeout_and_delete_model_with_transformer(
+        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
+    ):
         attached_transformer.wait()
 
 
 def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
-    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-    script_path = os.path.join(data_path, 'mnist.py')
-
-    ec2_client = sagemaker_session.boto_session.client('ec2')
-    subnet_ids, security_group_id = get_or_create_vpc_resources(ec2_client,
-                                                                sagemaker_session.boto_session.region_name)
-
-    mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
-               train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
-               framework_version=mxnet_full_version, subnets=subnet_ids,
-               security_group_ids=[security_group_id])
-
-    train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                   key_prefix='integ-test-data/mxnet_mnist/train')
-    test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                  key_prefix='integ-test-data/mxnet_mnist/test')
-    job_name = unique_name_from_base('test-mxnet-vpc')
+    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+    script_path = os.path.join(data_path, "mnist.py")
+
+    ec2_client = sagemaker_session.boto_session.client("ec2")
+    subnet_ids, security_group_id = get_or_create_vpc_resources(
+        ec2_client, sagemaker_session.boto_session.region_name
+    )
+
+    mx = MXNet(
+        entry_point=script_path,
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version=mxnet_full_version,
+        subnets=subnet_ids,
+        security_group_ids=[security_group_id],
+    )
+
+    train_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+    )
+    test_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+    )
+    job_name = unique_name_from_base("test-mxnet-vpc")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
+        mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
 
-    job_desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mx.latest_training_job.name)
-    assert set(subnet_ids) == set(job_desc['VpcConfig']['Subnets'])
-    assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds']
+    job_desc = sagemaker_session.sagemaker_client.describe_training_job(
+        TrainingJobName=mx.latest_training_job.name
+    )
+    assert set(subnet_ids) == set(job_desc["VpcConfig"]["Subnets"])
+    assert [security_group_id] == job_desc["VpcConfig"]["SecurityGroupIds"]
 
-    transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
-    transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
-    transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
-                                                       key_prefix=transform_input_key_prefix)
+    transform_input_path = os.path.join(data_path, "transform", "data.csv")
+    transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
+    transform_input = mx.sagemaker_session.upload_data(
+        path=transform_input_path, key_prefix=transform_input_key_prefix
+    )
 
     transformer = _create_transformer_and_transform_job(mx, transform_input)
-    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
-                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+    with timeout_and_delete_model_with_transformer(
+        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
+    ):
         transformer.wait()
-        model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
-        assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets'])
-        assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
+        model_desc = sagemaker_session.sagemaker_client.describe_model(
+            ModelName=transformer.model_name
+        )
+        assert set(subnet_ids) == set(model_desc["VpcConfig"]["Subnets"])
+        assert [security_group_id] == model_desc["VpcConfig"]["SecurityGroupIds"]
 
 
 def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
-    data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-    script_path = os.path.join(data_path, 'mnist.py')
-    tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
-
-    mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
-               train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
-               framework_version=mxnet_full_version)
-
-    train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                   key_prefix='integ-test-data/mxnet_mnist/train')
-    test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                  key_prefix='integ-test-data/mxnet_mnist/test')
-    job_name = unique_name_from_base('test-mxnet-transform')
+    data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+    script_path = os.path.join(data_path, "mnist.py")
+    tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
+
+    mx = MXNet(
+        entry_point=script_path,
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        sagemaker_session=sagemaker_session,
+        framework_version=mxnet_full_version,
+    )
+
+    train_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+    )
+    test_input = mx.sagemaker_session.upload_data(
+        path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+    )
+    job_name = unique_name_from_base("test-mxnet-transform")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
-        mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
+        mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
 
-    transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
-    transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
-    transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
-                                                       key_prefix=transform_input_key_prefix)
+    transform_input_path = os.path.join(data_path, "transform", "data.csv")
+    transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
+    transform_input = mx.sagemaker_session.upload_data(
+        path=transform_input_path, key_prefix=transform_input_key_prefix
+    )
 
-    transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
-    transformer.transform(transform_input, content_type='text/csv')
+    transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags)
+    transformer.transform(transform_input, content_type="text/csv")
 
-    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
-                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+    with timeout_and_delete_model_with_transformer(
+        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
+    ):
         transformer.wait()
-        model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
-        model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
+        model_desc = sagemaker_session.sagemaker_client.describe_model(
+            ModelName=transformer.model_name
+        )
+        model_tags = sagemaker_session.sagemaker_client.list_tags(
+            ResourceArn=model_desc["ModelArn"]
+        )["Tags"]
         assert tags == model_tags
 
 
 def test_transform_byo_estimator(sagemaker_session):
-    data_path = os.path.join(DATA_DIR, 'one_p_mnist')
-    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
-    tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
+    data_path = os.path.join(DATA_DIR, "one_p_mnist")
+    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
+    tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
 
     # Load the data into memory as numpy arrays
-    train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
-    with gzip.open(train_set_path, 'rb') as f:
+    train_set_path = os.path.join(data_path, "mnist.pkl.gz")
+    with gzip.open(train_set_path, "rb") as f:
         train_set, _, _ = pickle.load(f, **pickle_args)
 
-    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
-                    train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
-                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
+    kmeans = KMeans(
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        k=10,
+        sagemaker_session=sagemaker_session,
+        output_path="s3://{}/".format(sagemaker_session.default_bucket()),
+    )
 
     # set kmeans specific hp
-    kmeans.init_method = 'random'
+    kmeans.init_method = "random"
     kmeans.max_iterators = 1
     kmeans.tol = 1
     kmeans.num_trials = 1
-    kmeans.local_init_method = 'kmeans++'
+    kmeans.local_init_method = "kmeans++"
     kmeans.half_life_time_size = 1
     kmeans.epochs = 1
 
     records = kmeans.record_set(train_set[0][:100])
 
-    job_name = unique_name_from_base('test-kmeans-attach')
+    job_name = unique_name_from_base("test-kmeans-attach")
 
     with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
         kmeans.fit(records, job_name=job_name)
 
-    transform_input_path = os.path.join(data_path, 'transform_input.csv')
-    transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
-    transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
-                                                           key_prefix=transform_input_key_prefix)
+    transform_input_path = os.path.join(data_path, "transform_input.csv")
+    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
+    transform_input = kmeans.sagemaker_session.upload_data(
+        path=transform_input_path, key_prefix=transform_input_key_prefix
+    )
 
-    estimator = Estimator.attach(training_job_name=job_name,
-                                 sagemaker_session=sagemaker_session)
+    estimator = Estimator.attach(training_job_name=job_name, sagemaker_session=sagemaker_session)
 
-    transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags)
-    transformer.transform(transform_input, content_type='text/csv')
+    transformer = estimator.transformer(1, "ml.m4.xlarge", tags=tags)
+    transformer.transform(transform_input, content_type="text/csv")
 
-    with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
-                                                   minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
+    with timeout_and_delete_model_with_transformer(
+        transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
+    ):
         transformer.wait()
-        model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
-        model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
+        model_desc = sagemaker_session.sagemaker_client.describe_model(
+            ModelName=transformer.model_name
+        )
+        model_tags = sagemaker_session.sagemaker_client.list_tags(
+            ResourceArn=model_desc["ModelArn"]
+        )["Tags"]
         assert tags == model_tags
 
 
-def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
-                                          input_filter=None, output_filter=None, join_source=None):
-    transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
-    transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
-                          output_filter=output_filter, join_source=join_source)
+def _create_transformer_and_transform_job(
+    estimator,
+    transform_input,
+    volume_kms_key=None,
+    input_filter=None,
+    output_filter=None,
+    join_source=None,
+):
+    transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key)
+    transformer.transform(
+        transform_input,
+        content_type="text/csv",
+        input_filter=input_filter,
+        output_filter=output_filter,
+        join_source=join_source,
+    )
     return transformer
diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py
index 0fd14121e5..a7269220fe 100644
--- a/tests/integ/test_tuner.py
+++ b/tests/integ/test_tuner.py
@@ -36,58 +36,81 @@
 from sagemaker.predictor import json_deserializer
 from sagemaker.pytorch import PyTorch
 from sagemaker.tensorflow import TensorFlow
-from sagemaker.tuner import IntegerParameter, ContinuousParameter, CategoricalParameter, \
-    HyperparameterTuner, \
-    WarmStartConfig, WarmStartTypes, create_transfer_learning_tuner, \
-    create_identical_dataset_and_algorithm_tuner
+from sagemaker.tuner import (
+    IntegerParameter,
+    ContinuousParameter,
+    CategoricalParameter,
+    HyperparameterTuner,
+    WarmStartConfig,
+    WarmStartTypes,
+    create_transfer_learning_tuner,
+    create_identical_dataset_and_algorithm_tuner,
+)
 from sagemaker.utils import unique_name_from_base
 
-DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
+DATA_PATH = os.path.join(DATA_DIR, "iris", "data")
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def kmeans_train_set(sagemaker_session):
-    data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
-    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+    data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+    pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
     # Load the data into memory as numpy arrays
-    with gzip.open(data_path, 'rb') as f:
+    with gzip.open(data_path, "rb") as f:
         train_set, _, _ = pickle.load(f, **pickle_args)
 
     return train_set
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def kmeans_estimator(sagemaker_session):
-    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
-                    train_instance_type='ml.c4.xlarge',
-                    k=10, sagemaker_session=sagemaker_session,
-                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
+    kmeans = KMeans(
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        k=10,
+        sagemaker_session=sagemaker_session,
+        output_path="s3://{}/".format(sagemaker_session.default_bucket()),
+    )
 
     # set kmeans specific hp
-    kmeans.init_method = 'random'
+    kmeans.init_method = "random"
     kmeans.max_iterators = 1
     kmeans.tol = 1
     kmeans.num_trials = 1
-    kmeans.local_init_method = 'kmeans++'
+    kmeans.local_init_method = "kmeans++"
     kmeans.half_life_time_size = 1
     kmeans.epochs = 1
 
     return kmeans
 
 
-@pytest.fixture(scope='module')
+@pytest.fixture(scope="module")
 def hyperparameter_ranges():
-    return {'extra_center_factor': IntegerParameter(1, 10),
-            'mini_batch_size': IntegerParameter(10, 100),
-            'epochs': IntegerParameter(1, 2),
-            'init_method': CategoricalParameter(['kmeans++', 'random'])}
-
-
-def _tune_and_deploy(kmeans_estimator, kmeans_train_set, sagemaker_session,
-                     hyperparameter_ranges=None, job_name=None,
-                     warm_start_config=None, early_stopping_type='Off'):
-    tuner = _tune(kmeans_estimator, kmeans_train_set,
-                  hyperparameter_ranges=hyperparameter_ranges, warm_start_config=warm_start_config,
-                  job_name=job_name, early_stopping_type=early_stopping_type)
+    return {
+        "extra_center_factor": IntegerParameter(1, 10),
+        "mini_batch_size": IntegerParameter(10, 100),
+        "epochs": IntegerParameter(1, 2),
+        "init_method": CategoricalParameter(["kmeans++", "random"]),
+    }
+
+
+def _tune_and_deploy(
+    kmeans_estimator,
+    kmeans_train_set,
+    sagemaker_session,
+    hyperparameter_ranges=None,
+    job_name=None,
+    warm_start_config=None,
+    early_stopping_type="Off",
+):
+    tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        hyperparameter_ranges=hyperparameter_ranges,
+        warm_start_config=warm_start_config,
+        job_name=job_name,
+        early_stopping_type=early_stopping_type,
+    )
     _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type)
@@ -95,36 +118,47 @@ def _deploy(kmeans_train_set, sagemaker_session, tuner, early_stopping_type):
     best_training_job = tuner.best_training_job()
     assert tuner.early_stopping_type == early_stopping_type
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = tuner.deploy(1, "ml.c4.xlarge")
 
         result = predictor.predict(kmeans_train_set[0][:10])
         assert len(result) == 10
         for record in result:
-            assert record.label['closest_cluster'] is not None
-            assert record.label['distance_to_cluster'] is not None
-
-
-def _tune(kmeans_estimator, kmeans_train_set, tuner=None,
-          hyperparameter_ranges=None, job_name=None, warm_start_config=None,
-          wait_till_terminal=True, max_jobs=2, max_parallel_jobs=2, early_stopping_type='Off'):
+            assert record.label["closest_cluster"] is not None
+            assert record.label["distance_to_cluster"] is not None
+
+
+def _tune(
+    kmeans_estimator,
+    kmeans_train_set,
+    tuner=None,
+    hyperparameter_ranges=None,
+    job_name=None,
+    warm_start_config=None,
+    wait_till_terminal=True,
+    max_jobs=2,
+    max_parallel_jobs=2,
+    early_stopping_type="Off",
+):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
         if not tuner:
-            tuner = HyperparameterTuner(estimator=kmeans_estimator,
-                                        objective_metric_name='test:msd',
-                                        hyperparameter_ranges=hyperparameter_ranges,
-                                        objective_type='Minimize',
-                                        max_jobs=max_jobs,
-                                        max_parallel_jobs=max_parallel_jobs,
-                                        warm_start_config=warm_start_config,
-                                        early_stopping_type=early_stopping_type)
+            tuner = HyperparameterTuner(
+                estimator=kmeans_estimator,
+                objective_metric_name="test:msd",
+                hyperparameter_ranges=hyperparameter_ranges,
+                objective_type="Minimize",
+                max_jobs=max_jobs,
+                max_parallel_jobs=max_parallel_jobs,
+                warm_start_config=warm_start_config,
+                early_stopping_type=early_stopping_type,
+            )
 
         records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
-        test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel='test')
+        test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel="test")
 
         tuner.fit([records, test_record_set], job_name=job_name)
-        print('Started hyperparameter tuning job with name:' + tuner.latest_tuning_job.name)
+        print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)
 
         if wait_till_terminal:
             tuner.wait()
@@ -133,440 +167,565 @@ def _tune(kmeans_estimator, kmeans_train_set, tuner=None,
 
 
 @pytest.mark.canary_quick
-def test_tuning_kmeans(sagemaker_session,
-                       kmeans_train_set,
-                       kmeans_estimator,
-                       hyperparameter_ranges):
-    job_name = unique_name_from_base('test-tune-kmeans')
-    _tune_and_deploy(kmeans_estimator, kmeans_train_set, sagemaker_session,
-                     hyperparameter_ranges=hyperparameter_ranges, job_name=job_name)
-
-
-def test_tuning_kmeans_identical_dataset_algorithm_tuner_raw(sagemaker_session,
-                                                             kmeans_train_set,
-                                                             kmeans_estimator,
-                                                             hyperparameter_ranges):
+def test_tuning_kmeans(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
+    job_name = unique_name_from_base("test-tune-kmeans")
+    _tune_and_deploy(
+        kmeans_estimator,
+        kmeans_train_set,
+        sagemaker_session,
+        hyperparameter_ranges=hyperparameter_ranges,
+        job_name=job_name,
+    )
+
+
+def test_tuning_kmeans_identical_dataset_algorithm_tuner_raw(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     parent_tuning_job_name = unique_name_from_base("kmeans-identical", max_length=32)
     child_tuning_job_name = unique_name_from_base("c-kmeans-identical", max_length=32)
-    _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name,
-          hyperparameter_ranges=hyperparameter_ranges, max_parallel_jobs=1, max_jobs=1)
-    child_tuner = _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name,
-                        hyperparameter_ranges=hyperparameter_ranges,
-                        warm_start_config=WarmStartConfig(
-                            warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
-                            parents=[parent_tuning_job_name]), max_parallel_jobs=1,
-                        max_jobs=1)
+    _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
+    child_tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=child_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        warm_start_config=WarmStartConfig(
+            warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
+            parents=[parent_tuning_job_name],
+        ),
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_warm_start_config_response = WarmStartConfig.from_job_desc(
         sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-            HyperParameterTuningJobName=child_tuning_job_name)["WarmStartConfig"])
+            HyperParameterTuningJobName=child_tuning_job_name
+        )["WarmStartConfig"]
+    )
 
     assert child_warm_start_config_response.type == child_tuner.warm_start_config.type
     assert child_warm_start_config_response.parents == child_tuner.warm_start_config.parents
 
 
-def test_tuning_kmeans_identical_dataset_algorithm_tuner(sagemaker_session,
-                                                         kmeans_train_set,
-                                                         kmeans_estimator,
-                                                         hyperparameter_ranges):
+def test_tuning_kmeans_identical_dataset_algorithm_tuner(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     """Tests Identical dataset and algorithm use case with one parent and child job launched with
     .identical_dataset_and_algorithm_tuner() """
 
     parent_tuning_job_name = unique_name_from_base("km-iden1-parent", max_length=32)
     child_tuning_job_name = unique_name_from_base("km-iden1-child", max_length=32)
 
-    parent_tuner = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name,
-                         hyperparameter_ranges=hyperparameter_ranges)
+    parent_tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+    )
 
     child_tuner = parent_tuner.identical_dataset_and_algorithm_tuner()
-    _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name, tuner=child_tuner,
-          max_parallel_jobs=1,
-          max_jobs=1)
+    _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=child_tuning_job_name,
+        tuner=child_tuner,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_warm_start_config_response = WarmStartConfig.from_job_desc(
         sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-            HyperParameterTuningJobName=child_tuning_job_name)["WarmStartConfig"])
+            HyperParameterTuningJobName=child_tuning_job_name
+        )["WarmStartConfig"]
+    )
 
     assert child_warm_start_config_response.type == child_tuner.warm_start_config.type
     assert child_warm_start_config_response.parents == child_tuner.warm_start_config.parents
 
 
-def test_create_tuning_kmeans_identical_dataset_algorithm_tuner(sagemaker_session,
-                                                                kmeans_train_set,
-                                                                kmeans_estimator,
-                                                                hyperparameter_ranges):
+def test_create_tuning_kmeans_identical_dataset_algorithm_tuner(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     """Tests Identical dataset and algorithm use case with one parent and child job launched with
     .create_identical_dataset_and_algorithm_tuner() """
 
     parent_tuning_job_name = unique_name_from_base("km-iden2-parent", max_length=32)
     child_tuning_job_name = unique_name_from_base("km-iden2-child", max_length=32)
 
-    parent_tuner = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name,
-                         hyperparameter_ranges=hyperparameter_ranges, max_parallel_jobs=1,
-                         max_jobs=1)
+    parent_tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_tuner = create_identical_dataset_and_algorithm_tuner(
-        parent=parent_tuner.latest_tuning_job.name,
-        sagemaker_session=sagemaker_session)
-
-    _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name, tuner=child_tuner,
-          max_parallel_jobs=1,
-          max_jobs=1)
+        parent=parent_tuner.latest_tuning_job.name, sagemaker_session=sagemaker_session
+    )
+
+    _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=child_tuning_job_name,
+        tuner=child_tuner,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_warm_start_config_response = WarmStartConfig.from_job_desc(
         sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-            HyperParameterTuningJobName=child_tuning_job_name)["WarmStartConfig"])
+            HyperParameterTuningJobName=child_tuning_job_name
+        )["WarmStartConfig"]
+    )
 
     assert child_warm_start_config_response.type == child_tuner.warm_start_config.type
     assert child_warm_start_config_response.parents == child_tuner.warm_start_config.parents
 
 
-def test_transfer_learning_tuner(sagemaker_session,
-                                 kmeans_train_set,
-                                 kmeans_estimator,
-                                 hyperparameter_ranges):
+def test_transfer_learning_tuner(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     """Tests Transfer learning use case with one parent and child job launched with
     .transfer_learning_tuner() """
 
     parent_tuning_job_name = unique_name_from_base("km-tran1-parent", max_length=32)
     child_tuning_job_name = unique_name_from_base("km-tran1-child", max_length=32)
 
-    parent_tuner = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name,
-                         hyperparameter_ranges=hyperparameter_ranges, max_jobs=1,
-                         max_parallel_jobs=1)
+    parent_tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        max_jobs=1,
+        max_parallel_jobs=1,
+    )
 
     child_tuner = parent_tuner.transfer_learning_tuner()
-    _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name, tuner=child_tuner,
-          max_parallel_jobs=1,
-          max_jobs=1)
+    _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=child_tuning_job_name,
+        tuner=child_tuner,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_warm_start_config_response = WarmStartConfig.from_job_desc(
         sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-            HyperParameterTuningJobName=child_tuning_job_name)["WarmStartConfig"])
+            HyperParameterTuningJobName=child_tuning_job_name
+        )["WarmStartConfig"]
+    )
 
     assert child_warm_start_config_response.type == child_tuner.warm_start_config.type
     assert child_warm_start_config_response.parents == child_tuner.warm_start_config.parents
 
 
-def test_create_transfer_learning_tuner(sagemaker_session,
-                                        kmeans_train_set,
-                                        kmeans_estimator,
-                                        hyperparameter_ranges):
+def test_create_transfer_learning_tuner(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     """Tests Transfer learning use case with two parents and child job launched with
     create_transfer_learning_tuner() """
 
     parent_tuning_job_name_1 = unique_name_from_base("km-tran2-parent1", max_length=32)
     parent_tuning_job_name_2 = unique_name_from_base("km-tran2-parent2", max_length=32)
     child_tuning_job_name = unique_name_from_base("km-tran2-child", max_length=32)
 
-    parent_tuner_1 = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name_1,
-                           hyperparameter_ranges=hyperparameter_ranges, max_parallel_jobs=1,
-                           max_jobs=1)
-
-    parent_tuner_2 = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name_2,
-                           hyperparameter_ranges=hyperparameter_ranges, max_parallel_jobs=1,
-                           max_jobs=1)
+    parent_tuner_1 = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name_1,
+        hyperparameter_ranges=hyperparameter_ranges,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
+
+    parent_tuner_2 = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name_2,
+        hyperparameter_ranges=hyperparameter_ranges,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_tuner = create_transfer_learning_tuner(
         parent=parent_tuner_1.latest_tuning_job.name,
         sagemaker_session=sagemaker_session,
         estimator=kmeans_estimator,
-        additional_parents={parent_tuner_2.latest_tuning_job.name})
+        additional_parents={parent_tuner_2.latest_tuning_job.name},
+    )
 
     _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name, tuner=child_tuner)
 
     child_warm_start_config_response = WarmStartConfig.from_job_desc(
         sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-            HyperParameterTuningJobName=child_tuning_job_name)["WarmStartConfig"])
+            HyperParameterTuningJobName=child_tuning_job_name
+        )["WarmStartConfig"]
+    )
 
     assert child_warm_start_config_response.type == child_tuner.warm_start_config.type
     assert child_warm_start_config_response.parents == child_tuner.warm_start_config.parents
 
 
-def test_tuning_kmeans_identical_dataset_algorithm_tuner_from_non_terminal_parent(sagemaker_session,
-                                                                                  kmeans_train_set,
-                                                                                  kmeans_estimator,
-                                                                                  hyperparameter_ranges):
+def test_tuning_kmeans_identical_dataset_algorithm_tuner_from_non_terminal_parent(
+    sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges
+):
     """Tests Identical dataset and algorithm use case with one non terminal parent and child job launched with
     .identical_dataset_and_algorithm_tuner() """
 
     parent_tuning_job_name = unique_name_from_base("km-non-term", max_length=32)
     child_tuning_job_name = unique_name_from_base("km-non-term-child", max_length=32)
 
-    parent_tuner = _tune(kmeans_estimator, kmeans_train_set, job_name=parent_tuning_job_name,
-                         hyperparameter_ranges=hyperparameter_ranges, wait_till_terminal=False,
-                         max_parallel_jobs=1,
-                         max_jobs=1)
+    parent_tuner = _tune(
+        kmeans_estimator,
+        kmeans_train_set,
+        job_name=parent_tuning_job_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        wait_till_terminal=False,
+        max_parallel_jobs=1,
+        max_jobs=1,
+    )
 
     child_tuner = parent_tuner.identical_dataset_and_algorithm_tuner()
     with pytest.raises(ClientError):
-        _tune(kmeans_estimator, kmeans_train_set, job_name=child_tuning_job_name, tuner=child_tuner,
-              max_parallel_jobs=1, max_jobs=1)
+        _tune(
+            kmeans_estimator,
+            kmeans_train_set,
+            job_name=child_tuning_job_name,
+            tuner=child_tuner,
+            max_parallel_jobs=1,
+            max_jobs=1,
+        )
 
 
 def test_tuning_lda(sagemaker_session):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, 'lda')
-        data_filename = 'nips-train_1.pbr'
+        data_path = os.path.join(DATA_DIR, "lda")
+        data_filename = "nips-train_1.pbr"
 
-        with open(os.path.join(data_path, data_filename), 'rb') as f:
+        with open(os.path.join(data_path, data_filename), "rb") as f:
             all_records = read_records(f)
 
         # all records must be same
-        feature_num = int(all_records[0].features['values'].float32_tensor.shape[0])
-
-        lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10,
-                  sagemaker_session=sagemaker_session)
-
-        record_set = prepare_record_set_from_local_files(data_path, lda.data_location,
-                                                         len(all_records), feature_num,
-                                                         sagemaker_session)
-        test_record_set = prepare_record_set_from_local_files(data_path, lda.data_location,
                                                              len(all_records), feature_num,
-                                                              sagemaker_session)
-        test_record_set.channel = 'test'
+        feature_num = int(all_records[0].features["values"].float32_tensor.shape[0])
+
+        lda = LDA(
+            role="SageMakerRole",
+            train_instance_type="ml.c4.xlarge",
+            num_topics=10,
+            sagemaker_session=sagemaker_session,
+        )
+
+        record_set = prepare_record_set_from_local_files(
+            data_path, lda.data_location, len(all_records), feature_num, sagemaker_session
+        )
+        test_record_set = prepare_record_set_from_local_files(
+            data_path, lda.data_location, len(all_records), feature_num, sagemaker_session
+        )
+        test_record_set.channel = "test"
 
         # specify which hp you want to optimize over
-        hyperparameter_ranges = {'alpha0': ContinuousParameter(1, 10),
-                                 'num_topics': IntegerParameter(1, 2)}
-        objective_metric_name = 'test:pwll'
-
-        tuner = HyperparameterTuner(estimator=lda, objective_metric_name=objective_metric_name,
-                                    hyperparameter_ranges=hyperparameter_ranges,
-                                    objective_type='Maximize', max_jobs=2,
-                                    max_parallel_jobs=2,
-                                    early_stopping_type='Auto')
-
-        tuning_job_name = unique_name_from_base('test-lda', max_length=32)
+        hyperparameter_ranges = {
+            "alpha0": ContinuousParameter(1, 10),
+            "num_topics": IntegerParameter(1, 2),
+        }
+        objective_metric_name = "test:pwll"
+
+        tuner = HyperparameterTuner(
+            estimator=lda,
+            objective_metric_name=objective_metric_name,
+            hyperparameter_ranges=hyperparameter_ranges,
+            objective_type="Maximize",
+            max_jobs=2,
+            max_parallel_jobs=2,
+            early_stopping_type="Auto",
+        )
+
+        tuning_job_name = unique_name_from_base("test-lda", max_length=32)
         tuner.fit([record_set, test_record_set], mini_batch_size=1, job_name=tuning_job_name)
 
         latest_tuning_job_name = tuner.latest_tuning_job.name
-        print('Started hyperparameter tuning job with name:' + latest_tuning_job_name)
+        print("Started hyperparameter tuning job with name:" + latest_tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
-    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client \
-        .describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=latest_tuning_job_name)
-    assert desc['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] == 'Auto'
+    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
+        HyperParameterTuningJobName=latest_tuning_job_name
+    )
+    assert desc["HyperParameterTuningJobConfig"]["TrainingJobEarlyStoppingType"] == "Auto"
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = tuner.deploy(1, "ml.c4.xlarge")
         predict_input = np.random.rand(1, feature_num)
         result = predictor.predict(predict_input)
 
         assert len(result) == 1
         for record in result:
-            assert record.label['topic_mixture'] is not None
+            assert record.label["topic_mixture"] is not None
 
 
 def test_stop_tuning_job(sagemaker_session):
     feature_num = 14
     train_input = np.random.rand(1000, feature_num)
 
-    rcf = RandomCutForest(role='SageMakerRole', train_instance_count=1,
-                          train_instance_type='ml.c4.xlarge',
-                          num_trees=50, num_samples_per_tree=20,
-                          sagemaker_session=sagemaker_session)
+    rcf = RandomCutForest(
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.c4.xlarge",
+        num_trees=50,
+        num_samples_per_tree=20,
+        sagemaker_session=sagemaker_session,
+    )
 
     records = rcf.record_set(train_input)
-    records.distribution = 'FullyReplicated'
-
-    test_records = rcf.record_set(train_input, channel='test')
-    test_records.distribution = 'FullyReplicated'
-
-    hyperparameter_ranges = {'num_trees': IntegerParameter(50, 100),
-                             'num_samples_per_tree': IntegerParameter(1, 2)}
-
-    objective_metric_name = 'test:f1'
-    tuner = HyperparameterTuner(estimator=rcf, objective_metric_name=objective_metric_name,
-                                hyperparameter_ranges=hyperparameter_ranges,
-                                objective_type='Maximize', max_jobs=2,
-                                max_parallel_jobs=2)
-
-    tuning_job_name = unique_name_from_base('test-randomcutforest', max_length=32)
+    records.distribution = "FullyReplicated"
+
+    test_records = rcf.record_set(train_input, channel="test")
+    test_records.distribution = "FullyReplicated"
+
+    hyperparameter_ranges = {
+        "num_trees": IntegerParameter(50, 100),
+        "num_samples_per_tree": IntegerParameter(1, 2),
+    }
+
+    objective_metric_name = "test:f1"
+    tuner = HyperparameterTuner(
+        estimator=rcf,
+        objective_metric_name=objective_metric_name,
+        hyperparameter_ranges=hyperparameter_ranges,
+        objective_type="Maximize",
+        max_jobs=2,
+        max_parallel_jobs=2,
+    )
+
+    tuning_job_name = unique_name_from_base("test-randomcutforest", max_length=32)
    tuner.fit([records, test_records], tuning_job_name)
 
     time.sleep(15)
 
     latest_tuning_job_name = tuner.latest_tuning_job.name
 
-    print('Attempting to stop {}'.format(latest_tuning_job_name))
+    print("Attempting to stop {}".format(latest_tuning_job_name))
 
     tuner.stop_tuning_job()
 
-    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client \
-        .describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=latest_tuning_job_name)
-    assert desc['HyperParameterTuningJobStatus'] == 'Stopping'
+    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
+        HyperParameterTuningJobName=latest_tuning_job_name
+    )
+    assert desc["HyperParameterTuningJobStatus"] == "Stopping"
 
 
 @pytest.mark.canary_quick
 def test_tuning_mxnet(sagemaker_session, mxnet_full_version):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
-        data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
-
-        estimator = MXNet(entry_point=script_path,
-                          role='SageMakerRole',
-                          py_version=PYTHON_VERSION,
-                          train_instance_count=1,
-                          train_instance_type='ml.m4.xlarge',
-                          framework_version=mxnet_full_version,
-                          sagemaker_session=sagemaker_session)
-
-        hyperparameter_ranges = {'learning-rate': ContinuousParameter(0.01, 0.2)}
-        objective_metric_name = 'Validation-accuracy'
+        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
+        data_path = os.path.join(DATA_DIR, "mxnet_mnist")
+
+        estimator = MXNet(
+            entry_point=script_path,
+            role="SageMakerRole",
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.m4.xlarge",
+            framework_version=mxnet_full_version,
+            sagemaker_session=sagemaker_session,
+        )
+
+        hyperparameter_ranges = {"learning-rate": ContinuousParameter(0.01, 0.2)}
+        objective_metric_name = "Validation-accuracy"
         metric_definitions = [
-            {'Name': 'Validation-accuracy', 'Regex': 'Validation-accuracy=([0-9\\.]+)'}]
-        tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
-                                    metric_definitions,
-                                    max_jobs=4, max_parallel_jobs=2)
-
-        train_input = estimator.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                              key_prefix='integ-test-data/mxnet_mnist/train')
-        test_input = estimator.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                             key_prefix='integ-test-data/mxnet_mnist/test')
-
-        tuning_job_name = unique_name_from_base('tune-mxnet', max_length=32)
-        tuner.fit({'train': train_input, 'test': test_input}, job_name=tuning_job_name)
-
-        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+            {"Name": "Validation-accuracy", "Regex": "Validation-accuracy=([0-9\\.]+)"}
+        ]
+        tuner = HyperparameterTuner(
+            estimator,
+            objective_metric_name,
+            hyperparameter_ranges,
+            metric_definitions,
+            max_jobs=4,
+            max_parallel_jobs=2,
+        )
+
+        train_input = estimator.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
+        )
+        test_input = estimator.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
+        )
+
+        tuning_job_name = unique_name_from_base("tune-mxnet", max_length=32)
+        tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)
+
+        print("Started hyperparameter tuning job with name:" + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = tuner.deploy(1, "ml.c4.xlarge")
         data = np.zeros(shape=(1, 1, 28, 28))
         predictor.predict(data)
 
 
 @pytest.mark.canary_quick
 def test_tuning_tf_script_mode(sagemaker_session):
-    resource_path = os.path.join(DATA_DIR, 'tensorflow_mnist')
-    script_path = os.path.join(resource_path, 'mnist.py')
-
-    estimator = TensorFlow(entry_point=script_path,
-                           role='SageMakerRole',
-                           train_instance_count=1,
-                           train_instance_type='ml.m4.xlarge',
-                           script_mode=True,
-                           sagemaker_session=sagemaker_session,
-                           py_version=PYTHON_VERSION,
-                           framework_version=TensorFlow.LATEST_VERSION)
-
-    hyperparameter_ranges = {'epochs': IntegerParameter(1, 2)}
-    objective_metric_name = 'accuracy'
-    metric_definitions = [{'Name': objective_metric_name, 'Regex': 'accuracy = ([0-9\\.]+)'}]
-
-    tuner = HyperparameterTuner(estimator,
-                                objective_metric_name,
-                                hyperparameter_ranges,
-                                metric_definitions,
-                                max_jobs=2,
-                                max_parallel_jobs=2)
+    resource_path = os.path.join(DATA_DIR, "tensorflow_mnist")
+    script_path = os.path.join(resource_path, "mnist.py")
+
+    estimator = TensorFlow(
+        entry_point=script_path,
+        role="SageMakerRole",
+        train_instance_count=1,
+        train_instance_type="ml.m4.xlarge",
+        script_mode=True,
+        sagemaker_session=sagemaker_session,
+        py_version=PYTHON_VERSION,
+        framework_version=TensorFlow.LATEST_VERSION,
+    )
+
+    hyperparameter_ranges = {"epochs": IntegerParameter(1, 2)}
+    objective_metric_name = "accuracy"
+    metric_definitions = [{"Name": objective_metric_name, "Regex": "accuracy = ([0-9\\.]+)"}]
+
+    tuner = HyperparameterTuner(
+        estimator,
+        objective_metric_name,
+        hyperparameter_ranges,
+        metric_definitions,
+        max_jobs=2,
+        max_parallel_jobs=2,
+    )
 
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        inputs = estimator.sagemaker_session.upload_data(path=os.path.join(resource_path, 'data'),
-                                                         key_prefix='scriptmode/mnist')
+        inputs = estimator.sagemaker_session.upload_data(
+            path=os.path.join(resource_path, "data"), key_prefix="scriptmode/mnist"
+        )
 
-        tuning_job_name = unique_name_from_base('tune-tf-script-mode', max_length=32)
+        tuning_job_name = unique_name_from_base("tune-tf-script-mode", max_length=32)
         tuner.fit(inputs, job_name=tuning_job_name)
 
-        print('Started hyperparameter tuning job with name: ' + tuning_job_name)
+        print("Started hyperparameter tuning job with name: " + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
 
 @pytest.mark.canary_quick
-@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.")
+@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.")
 def test_tuning_tf(sagemaker_session):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py')
-
-        estimator = TensorFlow(entry_point=script_path,
-                               role='SageMakerRole',
-                               training_steps=1,
-                               evaluation_steps=1,
-                               hyperparameters={'input_tensor_name': 'inputs'},
-                               train_instance_count=1,
-                               train_instance_type='ml.c4.xlarge',
-                               sagemaker_session=sagemaker_session)
-
-        inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris')
-        hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.05, 0.2)}
-
-        objective_metric_name = 'loss'
-        metric_definitions = [{'Name': 'loss', 'Regex': 'loss = ([0-9\\.]+)'}]
-
-        tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
-                                    metric_definitions,
-                                    objective_type='Minimize', max_jobs=2, max_parallel_jobs=2)
-
-        tuning_job_name = unique_name_from_base('tune-tf', max_length=32)
+        script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py")
+
+        estimator = TensorFlow(
+            entry_point=script_path,
+            role="SageMakerRole",
+            training_steps=1,
+            evaluation_steps=1,
+            hyperparameters={"input_tensor_name": "inputs"},
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+        )
+
+        inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix="integ-test-data/tf_iris")
+        hyperparameter_ranges = {"learning_rate": ContinuousParameter(0.05, 0.2)}
+
+        objective_metric_name = "loss"
+        metric_definitions = [{"Name": "loss", "Regex": "loss = ([0-9\\.]+)"}]
+
+        tuner = HyperparameterTuner(
+            estimator,
+            objective_metric_name,
+            hyperparameter_ranges,
+            metric_definitions,
+            objective_type="Minimize",
+            max_jobs=2,
+            max_parallel_jobs=2,
+        )
+
+        tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
         tuner.fit(inputs, job_name=tuning_job_name)
 
-        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+        print("Started hyperparameter tuning job with name:" + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = tuner.deploy(1, "ml.c4.xlarge")
 
         features = [6.4, 3.2, 4.5, 1.5]
-        dict_result = predictor.predict({'inputs': features})
-        print('predict result: {}'.format(dict_result))
+        dict_result = predictor.predict({"inputs": features})
+        print("predict result: {}".format(dict_result))
         list_result = predictor.predict(features)
-        print('predict result: {}'.format(list_result))
+        print("predict result: {}".format(list_result))
 
         assert dict_result == list_result
 
 
-@pytest.mark.skipif(PYTHON_VERSION != 'py2', reason="TensorFlow image supports only python 2.")
+@pytest.mark.skipif(PYTHON_VERSION != "py2", reason="TensorFlow image supports only python 2.")
 def test_tuning_tf_vpc_multi(sagemaker_session):
     """Test Tensorflow multi-instance using the same VpcConfig for training and inference"""
-    instance_type = 'ml.c4.xlarge'
+    instance_type = "ml.c4.xlarge"
     instance_count = 2
 
-    script_path = os.path.join(DATA_DIR, 'iris', 'iris-dnn-classifier.py')
+    script_path = os.path.join(DATA_DIR, "iris", "iris-dnn-classifier.py")
 
-    ec2_client = sagemaker_session.boto_session.client('ec2')
-    subnet_ids, security_group_id = vpc_test_utils.get_or_create_vpc_resources(ec2_client,
-                                                                               sagemaker_session.boto_region_name)
+    ec2_client = sagemaker_session.boto_session.client("ec2")
+    subnet_ids, security_group_id = vpc_test_utils.get_or_create_vpc_resources(
+        ec2_client, sagemaker_session.boto_region_name
+    )
 
     vpc_test_utils.setup_security_group_for_encryption(ec2_client, security_group_id)
 
-    estimator = TensorFlow(entry_point=script_path,
-                           role='SageMakerRole',
-                           training_steps=1,
-                           evaluation_steps=1,
-                           hyperparameters={'input_tensor_name': 'inputs'},
-                           train_instance_count=instance_count,
-                           train_instance_type=instance_type,
-                           sagemaker_session=sagemaker_session,
-                           base_job_name='test-vpc-tf',
-                           subnets=subnet_ids,
-                           security_group_ids=[security_group_id],
-                           encrypt_inter_container_traffic=True)
-
-    inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf_iris')
-    hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.05, 0.2)}
-
-    objective_metric_name = 'loss'
-    metric_definitions = [{'Name': 'loss', 'Regex': 'loss = ([0-9\\.]+)'}]
-
-    tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
-                                metric_definitions,
-                                objective_type='Minimize', max_jobs=2, max_parallel_jobs=2)
-
-    tuning_job_name = unique_name_from_base('tune-tf', max_length=32)
+    estimator = TensorFlow(
+        entry_point=script_path,
+        role="SageMakerRole",
+        training_steps=1,
+        evaluation_steps=1,
+        hyperparameters={"input_tensor_name": "inputs"},
+        train_instance_count=instance_count,
+        train_instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        base_job_name="test-vpc-tf",
+        subnets=subnet_ids,
+        security_group_ids=[security_group_id],
+        encrypt_inter_container_traffic=True,
+    )
+
+    inputs = sagemaker_session.upload_data(path=DATA_PATH, key_prefix="integ-test-data/tf_iris")
+    hyperparameter_ranges = {"learning_rate": ContinuousParameter(0.05, 0.2)}
+
+    objective_metric_name = "loss"
+    metric_definitions = [{"Name": "loss", "Regex": "loss = ([0-9\\.]+)"}]
+
+    tuner = HyperparameterTuner(
+        estimator,
+        objective_metric_name,
+        hyperparameter_ranges,
+        metric_definitions,
+        objective_type="Minimize",
+        max_jobs=2,
+        max_parallel_jobs=2,
+    )
+
+    tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
         tuner.fit(inputs, job_name=tuning_job_name)
 
-        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+        print("Started hyperparameter tuning job with name:" + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
@@ -575,98 +734,123 @@ def test_tuning_tf_vpc_multi(sagemaker_session):
 
 
 @pytest.mark.canary_quick
 def test_tuning_chainer(sagemaker_session):
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'mnist.py')
-        data_path = os.path.join(DATA_DIR, 'chainer_mnist')
-
-        estimator = Chainer(entry_point=script_path,
-                            role='SageMakerRole',
-                            py_version=PYTHON_VERSION,
-                            train_instance_count=1,
-                            train_instance_type='ml.c4.xlarge',
-                            sagemaker_session=sagemaker_session,
-                            hyperparameters={'epochs': 1})
-
-        train_input = estimator.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
-                                                              key_prefix='integ-test-data/chainer_mnist/train')
-        test_input = estimator.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
-                                                             key_prefix='integ-test-data/chainer_mnist/test')
-
-        hyperparameter_ranges = {'alpha': ContinuousParameter(0.001, 0.005)}
-
-        objective_metric_name = 'Validation-accuracy'
+        script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
+        data_path = os.path.join(DATA_DIR, "chainer_mnist")
+
+        estimator = Chainer(
+            entry_point=script_path,
+            role="SageMakerRole",
+            py_version=PYTHON_VERSION,
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+            hyperparameters={"epochs": 1},
+        )
+
+        train_input = estimator.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/chainer_mnist/train"
+        )
+        test_input = estimator.sagemaker_session.upload_data(
+            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/chainer_mnist/test"
+        )
+
+        hyperparameter_ranges = {"alpha": ContinuousParameter(0.001, 0.005)}
+
+        objective_metric_name = "Validation-accuracy"
         metric_definitions = [
-            {'Name': 'Validation-accuracy',
-             'Regex': r'\[J1\s+\d\.\d+\s+\d\.\d+\s+\d\.\d+\s+(\d\.\d+)'}]
-
-        tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
-                                    metric_definitions,
-                                    max_jobs=2, max_parallel_jobs=2)
-
-        tuning_job_name = unique_name_from_base('chainer', max_length=32)
-        tuner.fit({'train': train_input, 'test': test_input}, job_name=tuning_job_name)
-
-        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+            {
+                "Name": "Validation-accuracy",
+                "Regex": r"\[J1\s+\d\.\d+\s+\d\.\d+\s+\d\.\d+\s+(\d\.\d+)",
+            }
+        ]
+
+        tuner = HyperparameterTuner(
+            estimator,
+            objective_metric_name,
+            hyperparameter_ranges,
+            metric_definitions,
+            max_jobs=2,
+            max_parallel_jobs=2,
+        )
+
+        tuning_job_name = unique_name_from_base("chainer", max_length=32)
+        tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)
+
+        print("Started hyperparameter tuning job with name:" + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = tuner.deploy(1, "ml.c4.xlarge")
 
         batch_size = 100
-        data = np.zeros((batch_size, 784), dtype='float32')
+        data = np.zeros((batch_size, 784), dtype="float32")
        output = predictor.predict(data)
         assert len(output) == batch_size
 
-        data = np.zeros((batch_size, 1, 28, 28), dtype='float32')
+        data = np.zeros((batch_size, 1, 28, 28), dtype="float32")
         output = predictor.predict(data)
         assert len(output) == batch_size
 
-        data = np.zeros((batch_size, 28, 28), dtype='float32')
+        data = np.zeros((batch_size, 28, 28), dtype="float32")
         output = predictor.predict(data)
         assert len(output) == batch_size
 
 
 @pytest.mark.canary_quick
 def test_attach_tuning_pytorch(sagemaker_session):
-    mnist_dir = os.path.join(DATA_DIR, 'pytorch_mnist')
-    mnist_script = os.path.join(mnist_dir, 'mnist.py')
-
-    estimator = PyTorch(entry_point=mnist_script, role='SageMakerRole',
-                        train_instance_count=1, py_version=PYTHON_VERSION,
-                        train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session)
+    mnist_dir = os.path.join(DATA_DIR, "pytorch_mnist")
+    mnist_script = os.path.join(mnist_dir, "mnist.py")
+
+    estimator = PyTorch(
+        entry_point=mnist_script,
+        role="SageMakerRole",
+        train_instance_count=1,
+        py_version=PYTHON_VERSION,
+        train_instance_type="ml.c4.xlarge",
+        sagemaker_session=sagemaker_session,
+    )
 
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        objective_metric_name = 'evaluation-accuracy'
+        objective_metric_name = "evaluation-accuracy"
         metric_definitions = [
-            {'Name': 'evaluation-accuracy', 'Regex': r'Overall test accuracy: (\d+)'}]
-        hyperparameter_ranges = {'batch-size': IntegerParameter(50, 100)}
-
-        tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges,
-                                    metric_definitions,
-                                    max_jobs=2, max_parallel_jobs=2,
-                                    early_stopping_type='Auto')
+            {"Name": "evaluation-accuracy", "Regex": r"Overall test accuracy: (\d+)"}
+        ]
+        hyperparameter_ranges = {"batch-size": IntegerParameter(50, 100)}
+
+        tuner = HyperparameterTuner(
+            estimator,
+            objective_metric_name,
+            hyperparameter_ranges,
+            metric_definitions,
+            max_jobs=2,
+            max_parallel_jobs=2,
+            early_stopping_type="Auto",
+        )
 
         training_data = estimator.sagemaker_session.upload_data(
-            path=os.path.join(mnist_dir, 'training'),
-            key_prefix='integ-test-data/pytorch_mnist/training')
+            path=os.path.join(mnist_dir, "training"),
+            key_prefix="integ-test-data/pytorch_mnist/training",
+        )
 
-        tuning_job_name = unique_name_from_base('pytorch', max_length=32)
-        tuner.fit({'training': training_data}, job_name=tuning_job_name)
+        tuning_job_name = unique_name_from_base("pytorch", max_length=32)
+        tuner.fit({"training": training_data}, job_name=tuning_job_name)
 
-        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+        print("Started hyperparameter tuning job with name:" + tuning_job_name)
 
         time.sleep(15)
         tuner.wait()
 
-    attached_tuner = HyperparameterTuner.attach(tuning_job_name,
-                                                sagemaker_session=sagemaker_session)
-    assert attached_tuner.early_stopping_type == 'Auto'
+    attached_tuner = HyperparameterTuner.attach(
+        tuning_job_name, sagemaker_session=sagemaker_session
+    )
+    assert attached_tuner.early_stopping_type == "Auto"
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = attached_tuner.deploy(1, 'ml.c4.xlarge')
+        predictor = attached_tuner.deploy(1, "ml.c4.xlarge")
         data = np.zeros(shape=(1, 1, 28, 28), dtype=np.float32)
         predictor.predict(data)
@@ -688,64 +872,72 @@ def test_tuning_byo_estimator(sagemaker_session):
     Later the trained model is deployed and prediction is called against the endpoint.
     Default predictor is updated with json serializer and deserializer.
     """
-    image_name = registry(sagemaker_session.boto_session.region_name) + '/factorization-machines:1'
-    training_data_path = os.path.join(DATA_DIR, 'dummy_tensor')
+    image_name = registry(sagemaker_session.boto_session.region_name) + "/factorization-machines:1"
+    training_data_path = os.path.join(DATA_DIR, "dummy_tensor")
 
     with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
-        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
-        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+        data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
+        pickle_args = {} if sys.version_info.major == 2 else {"encoding": "latin1"}
 
-        with gzip.open(data_path, 'rb') as f:
+        with gzip.open(data_path, "rb") as f:
            train_set, _, _ = pickle.load(f, **pickle_args)
 
-        prefix = 'test_byo_estimator'
-        key = 'recordio-pb-data'
-        s3_train_data = sagemaker_session.upload_data(path=training_data_path,
-                                                      key_prefix=os.path.join(prefix, 'train', key))
-
-        estimator = Estimator(image_name=image_name,
-                              role='SageMakerRole', train_instance_count=1,
-                              train_instance_type='ml.c4.xlarge',
-                              sagemaker_session=sagemaker_session)
-
-        estimator.set_hyperparameters(num_factors=10,
-                                      feature_dim=784,
-                                      mini_batch_size=100,
-                                      predictor_type='binary_classifier')
-
-        hyperparameter_ranges = {'mini_batch_size': IntegerParameter(100, 200)}
-
-        tuner = HyperparameterTuner(estimator=estimator,
-                                    objective_metric_name='test:binary_classification_accuracy',
-                                    hyperparameter_ranges=hyperparameter_ranges,
-                                    max_jobs=2, max_parallel_jobs=2)
-
-        tuner.fit({'train': s3_train_data, 'test': s3_train_data},
-                  include_cls_metadata=False,
-                  job_name=unique_name_from_base('byo', 32))
-
-        print('Started hyperparameter tuning job with name:' + tuner.latest_tuning_job.name)
+        prefix = "test_byo_estimator"
+        key = "recordio-pb-data"
+        s3_train_data = sagemaker_session.upload_data(
+            path=training_data_path, key_prefix=os.path.join(prefix, "train", key)
+        )
+
+        estimator = Estimator(
+            image_name=image_name,
+            role="SageMakerRole",
+            train_instance_count=1,
+            train_instance_type="ml.c4.xlarge",
+            sagemaker_session=sagemaker_session,
+        )
+
+        estimator.set_hyperparameters(
+            num_factors=10, feature_dim=784, mini_batch_size=100, predictor_type="binary_classifier"
+        )
+
+        hyperparameter_ranges = {"mini_batch_size": IntegerParameter(100, 200)}
+
+        tuner = HyperparameterTuner(
+            estimator=estimator,
+            objective_metric_name="test:binary_classification_accuracy",
+            hyperparameter_ranges=hyperparameter_ranges,
+            max_jobs=2,
+            max_parallel_jobs=2,
+        )
+
+        tuner.fit(
+            {"train": s3_train_data, "test": s3_train_data},
+            include_cls_metadata=False,
+            job_name=unique_name_from_base("byo", 32),
+        )
+
+        print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)
 
         time.sleep(15)
         tuner.wait()
 
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
-        predictor = tuner.deploy(1, 'ml.m4.xlarge', endpoint_name=best_training_job)
+        predictor = tuner.deploy(1, "ml.m4.xlarge", endpoint_name=best_training_job)
         predictor.serializer = _fm_serializer
-        predictor.content_type = 'application/json'
+        predictor.content_type = "application/json"
         predictor.deserializer = json_deserializer
 
         result = predictor.predict(train_set[0][:10])
 
-        assert len(result['predictions']) == 10
-        for prediction in result['predictions']:
-            assert prediction['score'] is not None
+        assert len(result["predictions"]) == 10
+        for prediction in result["predictions"]:
+            assert prediction["score"] is not None
 
 
 # Serializer for the Factorization Machines predictor (for BYO example)
 def _fm_serializer(data):
-    js = {'instances': []}
+    js = {"instances": []}
     for row in data:
-        js['instances'].append({'features': row.tolist()})
+        js["instances"].append({"features": row.tolist()})
     return json.dumps(js)
diff --git a/tests/integ/timeout.py b/tests/integ/timeout.py
index 684b83fd9c..d3955d27f7 100644
--- a/tests/integ/timeout.py
+++ b/tests/integ/timeout.py
@@ -20,7 +20,7 @@
 from awslogs.core import AWSLogs
 from botocore.exceptions import ClientError
 
-LOGGER = logging.getLogger('timeout')
+LOGGER = logging.getLogger("timeout")
 
 
 class TimeoutError(Exception):
@@ -44,7 +44,7 @@ def timeout(seconds=0, minutes=0, hours=0):
     limit = seconds + 60 * minutes + 3600 * hours
 
     def handler(signum, frame):
-        raise TimeoutError('timed out after {} seconds'.format(limit))
+        raise TimeoutError("timed out after {} seconds".format(limit))
 
     try:
         signal.signal(signal.SIGALRM, handler)
@@ -56,7 +56,9 @@ def handler(signum, frame):
 
 
 @contextmanager
-def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, seconds=0, minutes=45, hours=0):
+def timeout_and_delete_endpoint_by_name(
+    endpoint_name, sagemaker_session, seconds=0, minutes=45, hours=0
+):
     with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
         no_errors = False
         try:
@@ -69,14 +71,14 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
             attempts -= 1
             try:
                 sagemaker_session.delete_endpoint(endpoint_name)
-                LOGGER.info('deleted endpoint {}'.format(endpoint_name))
+                LOGGER.info("deleted endpoint {}".format(endpoint_name))
 
-                _show_logs(endpoint_name, 'Endpoints', sagemaker_session)
+                _show_logs(endpoint_name, "Endpoints", sagemaker_session)
                 if no_errors:
-                    _cleanup_logs(endpoint_name, 'Endpoints', sagemaker_session)
+                    _cleanup_logs(endpoint_name, "Endpoints", sagemaker_session)
                 return
             except ClientError as ce:
-                if ce.response['Error']['Code'] == 'ValidationException':
+                if ce.response["Error"]["Code"] == "ValidationException":
                     # avoids the inner exception to be overwritten
                     pass
             # trying to delete the resource again in 10 seconds
@@ -84,7 +86,9 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
 
 
 @contextmanager
-def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, seconds=0, minutes=0, hours=0):
+def timeout_and_delete_model_with_transformer(
+    transformer, sagemaker_session, seconds=0, minutes=0, hours=0
+):
     with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
         no_errors = False
         try:
@@ -97,39 +101,49 @@ def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, se
             attempts -= 1
             try:
                 transformer.delete_model()
-                LOGGER.info('deleted SageMaker model {}'.format(transformer.model_name))
+                LOGGER.info("deleted SageMaker model {}".format(transformer.model_name))
 
-                _show_logs(transformer.model_name, 'Models', sagemaker_session)
+                _show_logs(transformer.model_name, "Models", sagemaker_session)
                 if no_errors:
-                    _cleanup_logs(transformer.model_name, 'Models', sagemaker_session)
+                    _cleanup_logs(transformer.model_name, "Models", sagemaker_session)
                 return
             except ClientError as ce:
-                if ce.response['Error']['Code'] == 'ValidationException':
+                if ce.response["Error"]["Code"] == "ValidationException":
                     pass
             sleep(10)
 
 
 def _show_logs(resource_name, resource_type, sagemaker_session):
-    log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
+    log_group = "/aws/sagemaker/{}/{}".format(resource_type, resource_name)
     try:
         # print out logs before deletion for debuggability
-        LOGGER.info('cloudwatch logs for log group {}:'.format(log_group))
-        logs = AWSLogs(log_group_name=log_group, log_stream_name='ALL', start='1d',
-                       aws_region=sagemaker_session.boto_session.region_name)
+        LOGGER.info("cloudwatch logs for log group {}:".format(log_group))
+        logs = AWSLogs(
+            log_group_name=log_group,
+            log_stream_name="ALL",
+            start="1d",
+            aws_region=sagemaker_session.boto_session.region_name,
+        )
         logs.list_logs()
     except Exception:
-        LOGGER.exception('Failure occurred while listing cloudwatch log group %s. Swallowing exception but printing '
-                         'stacktrace for debugging.', log_group)
+        LOGGER.exception(
+            "Failure occurred while listing cloudwatch log group %s. Swallowing exception but printing "
+            "stacktrace for debugging.",
+            log_group,
+        )
 
 
 def _cleanup_logs(resource_name, resource_type, sagemaker_session):
-    log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
+    log_group = "/aws/sagemaker/{}/{}".format(resource_type, resource_name)
     try:
         # print out logs before deletion for debuggability
-        LOGGER.info('deleting cloudwatch log group {}:'.format(log_group))
-        cwl_client = sagemaker_session.boto_session.client('logs')
+        LOGGER.info("deleting cloudwatch log group {}:".format(log_group))
+        cwl_client = sagemaker_session.boto_session.client("logs")
         cwl_client.delete_log_group(logGroupName=log_group)
-        LOGGER.info('deleted cloudwatch log group: {}'.format(log_group))
+        LOGGER.info("deleted cloudwatch log group: {}".format(log_group))
     except Exception:
-        LOGGER.exception('Failure occurred while cleaning up cloudwatch log group %s. '
-                         'Swallowing exception but printing stacktrace for debugging.', log_group)
+        LOGGER.exception(
+            "Failure occurred while cleaning up cloudwatch log group %s. "
+            "Swallowing exception but printing stacktrace for debugging.",
+            log_group,
+        )
diff --git a/tests/integ/vpc_test_utils.py b/tests/integ/vpc_test_utils.py
index f381e2b068..82014796a0 100644
--- a/tests/integ/vpc_test_utils.py
+++ b/tests/integ/vpc_test_utils.py
@@ -17,77 +17,82 @@
 import tests.integ.lock as lock
 
-VPC_NAME = 'sagemaker-python-sdk-test-vpc'
-LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_vpc_lock')
+VPC_NAME = "sagemaker-python-sdk-test-vpc"
+LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_vpc_lock")
 
 
 def _get_subnet_ids_by_name(ec2_client, name):
-    desc = ec2_client.describe_subnets(Filters=[
-        {'Name': 'tag-value', 'Values': [name]}
-    ])
-    if len(desc['Subnets']) == 0:
+    desc = ec2_client.describe_subnets(Filters=[{"Name": "tag-value", "Values": [name]}])
+    if len(desc["Subnets"]) == 0:
         return None
     else:
-        return [subnet['SubnetId'] for subnet in desc['Subnets']]
+        return [subnet["SubnetId"] for subnet in desc["Subnets"]]
 
 
 def _get_security_id_by_name(ec2_client, name):
-    desc = ec2_client.describe_security_groups(Filters=[
-        {'Name': 'tag-value', 'Values': [name]}
-    ])
-    if len(desc['SecurityGroups']) == 0:
+    desc = ec2_client.describe_security_groups(Filters=[{"Name": "tag-value", "Values": [name]}])
+    if len(desc["SecurityGroups"]) == 0:
         return None
     else:
-        return desc['SecurityGroups'][0]['GroupId']
+        return desc["SecurityGroups"][0]["GroupId"]
 
 
 def _vpc_exists(ec2_client, name):
-    desc = ec2_client.describe_vpcs(Filters=[
-        {'Name': 'tag-value', 'Values': [name]}
-    ])
-    return len(desc['Vpcs']) > 0
+    desc = ec2_client.describe_vpcs(Filters=[{"Name": "tag-value", "Values": [name]}])
+    return len(desc["Vpcs"]) > 0
 
 
 def _get_route_table_id(ec2_client, vpc_id):
-    desc = ec2_client.describe_route_tables(Filters=[
-        {'Name': 'vpc-id', 'Values': [vpc_id]}
-    ])
-    return desc['RouteTables'][0]['RouteTableId']
+    desc = ec2_client.describe_route_tables(Filters=[{"Name": "vpc-id", "Values": [vpc_id]}])
+    return desc["RouteTables"][0]["RouteTableId"]
 
 
 def _create_vpc_with_name(ec2_client, region, name):
-    vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId']
-    print('created vpc: {}'.format(vpc_id))
+    vpc_id = ec2_client.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"]
+    print("created vpc: {}".format(vpc_id))
 
     # sagemaker endpoints require subnets in at least 2 different AZs for vpc mode
-    subnet_id_a = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id,
-                                           AvailabilityZone=(region + 'a'))['Subnet']['SubnetId']
-    print('created subnet: {}'.format(subnet_id_a))
-    subnet_id_b = ec2_client.create_subnet(CidrBlock='10.0.1.0/24', VpcId=vpc_id,
-                                           AvailabilityZone=(region + 'b'))['Subnet']['SubnetId']
-    print('created subnet: {}'.format(subnet_id_b))
-
-    s3_service = \
-        [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if
-         s.endswith('s3')][0]
-    ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
-                                   RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
-    print('created s3 vpc endpoint')
-
-    security_group_id = \
-        ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
-    print('created security group: {}'.format(security_group_id))
+    subnet_id_a = ec2_client.create_subnet(
+        CidrBlock="10.0.0.0/24", VpcId=vpc_id, AvailabilityZone=(region + "a")
+    )["Subnet"]["SubnetId"]
+    print("created subnet: {}".format(subnet_id_a))
+    subnet_id_b = ec2_client.create_subnet(
+        CidrBlock="10.0.1.0/24", VpcId=vpc_id, AvailabilityZone=(region + "b")
+
)["Subnet"]["SubnetId"] + print("created subnet: {}".format(subnet_id_b)) + + s3_service = [ + s for s in ec2_client.describe_vpc_endpoint_services()["ServiceNames"] if s.endswith("s3") + ][0] + ec2_client.create_vpc_endpoint( + VpcId=vpc_id, + ServiceName=s3_service, + RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)], + ) + print("created s3 vpc endpoint") + + security_group_id = ec2_client.create_security_group( + VpcId=vpc_id, GroupName=name, Description=name + )["GroupId"] + print("created security group: {}".format(security_group_id)) # multi-host vpc jobs require communication among hosts - ec2_client.authorize_security_group_ingress(GroupId=security_group_id, - IpPermissions=[{'IpProtocol': 'tcp', - 'FromPort': 0, - 'ToPort': 65535, - 'UserIdGroupPairs': [{ - 'GroupId': security_group_id}]}]) - - ec2_client.create_tags(Resources=[vpc_id, subnet_id_a, subnet_id_b, security_group_id], - Tags=[{'Key': 'Name', 'Value': name}]) + ec2_client.authorize_security_group_ingress( + GroupId=security_group_id, + IpPermissions=[ + { + "IpProtocol": "tcp", + "FromPort": 0, + "ToPort": 65535, + "UserIdGroupPairs": [{"GroupId": security_group_id}], + } + ], + ) + + ec2_client.create_tags( + Resources=[vpc_id, subnet_id_a, subnet_id_b, security_group_id], + Tags=[{"Key": "Name", "Value": name}], + ) return [subnet_id_a, subnet_id_b], security_group_id @@ -96,25 +101,29 @@ def get_or_create_vpc_resources(ec2_client, region, name=VPC_NAME): # use lock to prevent race condition when tests are running concurrently with lock.lock(LOCK_PATH): if _vpc_exists(ec2_client, name): - print('using existing vpc: {}'.format(name)) - return _get_subnet_ids_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, - name) + print("using existing vpc: {}".format(name)) + return ( + _get_subnet_ids_by_name(ec2_client, name), + _get_security_id_by_name(ec2_client, name), + ) else: - print('creating new vpc: {}'.format(name)) + print("creating new vpc: {}".format(name)) return _create_vpc_with_name(ec2_client, region, name) def setup_security_group_for_encryption(ec2_client, security_group_id): sg_desc = ec2_client.describe_security_groups(GroupIds=[security_group_id]) - ingress_perms = sg_desc['SecurityGroups'][0]['IpPermissions'] + ingress_perms = sg_desc["SecurityGroups"][0]["IpPermissions"] if len(ingress_perms) == 1: - ec2_client. 
\ - authorize_security_group_ingress(GroupId=security_group_id, - IpPermissions=[{'IpProtocol': '50', - 'UserIdGroupPairs': [ - {'GroupId': security_group_id}]}, - {'IpProtocol': 'udp', - 'FromPort': 500, - 'ToPort': 500, - 'UserIdGroupPairs': [ - {'GroupId': security_group_id}]}]) + ec2_client.authorize_security_group_ingress( + GroupId=security_group_id, + IpPermissions=[ + {"IpProtocol": "50", "UserIdGroupPairs": [{"GroupId": security_group_id}]}, + { + "IpProtocol": "udp", + "FromPort": 500, + "ToPort": 500, + "UserIdGroupPairs": [{"GroupId": security_group_id}], + }, + ], + ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 2d945650f0..fda3436100 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -14,4 +14,4 @@ import os -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") diff --git a/tests/unit/test_airflow.py b/tests/unit/test_airflow.py index d038415f4a..2cd6bbc9af 100644 --- a/tests/unit/test_airflow.py +++ b/tests/unit/test_airflow.py @@ -22,73 +22,70 @@ from sagemaker.amazon import knn, ntm, pca -REGION = 'us-west-2' -BUCKET_NAME = 'output' -TIME_STAMP = '1111' +REGION = "us-west-2" +BUCKET_NAME = "output" +TIME_STAMP = "1111" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session._default_bucket = BUCKET_NAME return session -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_byo_training_config_required_args(sagemaker_session): byo = estimator.Estimator( image_name="byo", role="{{ role }}", train_instance_count="{{ instance_count }}", train_instance_type="ml.c4.2xlarge", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) - byo.set_hyperparameters(epochs=32, - feature_dim=1024, - mini_batch_size=256) + byo.set_hyperparameters(epochs=32, feature_dim=1024, mini_batch_size=256) - data = {'train': "{{ training_data }}"} + data = {"train": "{{ training_data }}"} config = airflow.training_config(byo, data) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': 'byo', - 'TrainingInputMode': 'File' - }, - 'OutputDataConfig': { - 'S3OutputPath': 's3://output/' - }, - 'TrainingJobName': "byo-%s" % TIME_STAMP, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 86400 - }, - 'ResourceConfig': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': 30 + "AlgorithmSpecification": {"TrainingImage": "byo", "TrainingInputMode": "File"}, + "OutputDataConfig": {"S3OutputPath": "s3://output/"}, + "TrainingJobName": "byo-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": 30, }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', 
- 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ training_data }}' - } - }, 'ChannelName': 'train' - }], - 'HyperParameters': { - 'epochs': '32', - 'feature_dim': '1024', - 'mini_batch_size': '256'} + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "{{ training_data }}", + } + }, + "ChannelName": "train", + } + ], + "HyperParameters": {"epochs": "32", "feature_dim": "1024", "mini_batch_size": "256"}, } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_byo_training_config_all_args(sagemaker_session): byo = estimator.Estimator( image_name="byo", @@ -98,7 +95,7 @@ def test_byo_training_config_all_args(sagemaker_session): train_volume_size="{{ train_volume_size }}", train_volume_kms_key="{{ train_volume_kms_key }}", train_max_run="{{ train_max_run }}", - input_mode='Pipe', + input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", base_job_name="{{ base_job_name }}", @@ -107,138 +104,133 @@ def test_byo_training_config_all_args(sagemaker_session): security_group_ids=["{{ security_group_ids }}"], model_uri="{{ model_uri }}", model_channel_name="{{ model_chanel }}", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) - byo.set_hyperparameters(epochs=32, - feature_dim=1024, - mini_batch_size=256) + byo.set_hyperparameters(epochs=32, feature_dim=1024, mini_batch_size=256) - data = {'train': "{{ training_data }}"} + data = {"train": "{{ training_data }}"} config = airflow.training_config(byo, data) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': 'byo', - 'TrainingInputMode': 'Pipe' - }, - 'OutputDataConfig': { - 'S3OutputPath': '{{ output_path }}', - 'KmsKeyId': '{{ output_volume_kms_key }}' - }, - 'TrainingJobName': "{{ base_job_name }}-%s" % TIME_STAMP, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': '{{ train_max_run }}' + "AlgorithmSpecification": {"TrainingImage": "byo", "TrainingInputMode": "Pipe"}, + "OutputDataConfig": { + "S3OutputPath": "{{ output_path }}", + "KmsKeyId": "{{ output_volume_kms_key }}", }, - 'ResourceConfig': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': '{{ train_volume_size }}', - 'VolumeKmsKeyId': '{{ train_volume_kms_key }}' + "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": "{{ train_max_run }}"}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": "{{ train_volume_size }}", + "VolumeKmsKeyId": "{{ train_volume_kms_key }}", }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [ + "RoleArn": "{{ role }}", + "InputDataConfig": [ { - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ training_data }}' + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "{{ training_data }}", } }, - 'ChannelName': 'train' + "ChannelName": "train", }, { - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ model_uri }}' + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + 
"S3Uri": "{{ model_uri }}", } }, - 'ContentType': 'application/x-sagemaker-model', - 'InputMode': 'File', - 'ChannelName': '{{ model_chanel }}' - } + "ContentType": "application/x-sagemaker-model", + "InputMode": "File", + "ChannelName": "{{ model_chanel }}", + }, ], - 'VpcConfig': { - 'Subnets': ['{{ subnet }}'], - 'SecurityGroupIds': ['{{ security_group_ids }}'] + "VpcConfig": { + "Subnets": ["{{ subnet }}"], + "SecurityGroupIds": ["{{ security_group_ids }}"], }, - 'HyperParameters': { - 'epochs': '32', - 'feature_dim': '1024', - 'mini_batch_size': '256'}, - 'Tags': [{'{{ key }}': '{{ value }}'}] + "HyperParameters": {"epochs": "32", "feature_dim": "1024", "mini_batch_size": "256"}, + "Tags": [{"{{ key }}": "{{ value }}"}], } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_framework_training_config_required_args(sagemaker_session): tf = tensorflow.TensorFlow( entry_point="{{ entry_point }}", - framework_version='1.10.0', + framework_version="1.10.0", training_steps=1000, evaluation_steps=100, role="{{ role }}", train_instance_count="{{ instance_count }}", train_instance_type="ml.c4.2xlarge", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) data = "{{ training_data }}" config = airflow.training_config(tf, data) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2', - 'TrainingInputMode': 'File' - }, - 'OutputDataConfig': { - 'S3OutputPath': 's3://output/' + "AlgorithmSpecification": { + "TrainingImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2", + "TrainingInputMode": "File", }, - 'TrainingJobName': "sagemaker-tensorflow-%s" % TIME_STAMP, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 86400 + "OutputDataConfig": {"S3OutputPath": "s3://output/"}, + "TrainingJobName": "sagemaker-tensorflow-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": 30, }, - 'ResourceConfig': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': 30 + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "{{ training_data }}", + } + }, + "ChannelName": "training", + } + ], + "HyperParameters": { + "sagemaker_submit_directory": '"s3://output/sagemaker-tensorflow-%s/source/sourcedir.tar.gz"' + % TIME_STAMP, + "sagemaker_program": '"{{ entry_point }}"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": "20", + "sagemaker_job_name": '"sagemaker-tensorflow-%s"' % TIME_STAMP, + "sagemaker_region": '"us-west-2"', + "checkpoint_path": '"s3://output/sagemaker-tensorflow-%s/checkpoints"' % TIME_STAMP, + "training_steps": "1000", + "evaluation_steps": "100", + "sagemaker_requirements": '""', }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ training_data }}' + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ entry_point }}", + "Bucket": "output", + "Key": "sagemaker-tensorflow-%s/source/sourcedir.tar.gz" % TIME_STAMP, + "Tar": 
True, } - }, - 'ChannelName': 'training' - }], - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://output/sagemaker-tensorflow-%s/source/sourcedir.tar.gz"' % TIME_STAMP, - 'sagemaker_program': '"{{ entry_point }}"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '20', - 'sagemaker_job_name': '"sagemaker-tensorflow-%s"' % TIME_STAMP, - 'sagemaker_region': '"us-west-2"', - 'checkpoint_path': '"s3://output/sagemaker-tensorflow-%s/checkpoints"' % TIME_STAMP, - 'training_steps': '1000', - 'evaluation_steps': '100', - 'sagemaker_requirements': '""'}, - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ entry_point }}', - 'Bucket': 'output', - 'Key': "sagemaker-tensorflow-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'Tar': True}] - } + ] + }, } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_framework_training_config_all_args(sagemaker_session): tf = tensorflow.TensorFlow( entry_point="{{ entry_point }}", @@ -249,8 +241,8 @@ def test_framework_training_config_all_args(sagemaker_session): training_steps=1000, evaluation_steps=100, checkpoint_path="{{ checkpoint_path }}", - py_version='py2', - framework_version='1.10.0', + py_version="py2", + framework_version="1.10.0", requirements_file="", role="{{ role }}", train_instance_count="{{ instance_count }}", @@ -258,128 +250,134 @@ def test_framework_training_config_all_args(sagemaker_session): train_volume_size="{{ train_volume_size }}", train_volume_kms_key="{{ train_volume_kms_key }}", train_max_run="{{ train_max_run }}", - input_mode='Pipe', + input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", base_job_name="{{ base_job_name }}", tags=[{"{{ key }}": "{{ value }}"}], subnets=["{{ subnet }}"], security_group_ids=["{{ security_group_ids }}"], - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) data = "{{ training_data }}" config = airflow.training_config(tf, data) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2', - 'TrainingInputMode': 'Pipe' + "AlgorithmSpecification": { + "TrainingImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2", + "TrainingInputMode": "Pipe", + }, + "OutputDataConfig": { + "S3OutputPath": "{{ output_path }}", + "KmsKeyId": "{{ output_volume_kms_key }}", }, - 'OutputDataConfig': { - 'S3OutputPath': '{{ output_path }}', - 'KmsKeyId': '{{ output_volume_kms_key }}' + "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": "{{ train_max_run }}"}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": "{{ train_volume_size }}", + "VolumeKmsKeyId": "{{ train_volume_kms_key }}", }, - 'TrainingJobName': "{{ base_job_name }}-%s" % TIME_STAMP, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': '{{ train_max_run }}' + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "{{ training_data }}", + } + }, + "ChannelName": "training", + } + ], + "VpcConfig": { + "Subnets": ["{{ subnet }}"], + "SecurityGroupIds": ["{{ security_group_ids }}"], }, - 'ResourceConfig': { - 'InstanceCount': '{{ 
instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': '{{ train_volume_size }}', - 'VolumeKmsKeyId': '{{ train_volume_kms_key }}' + "HyperParameters": { + "sagemaker_submit_directory": '"s3://{{ bucket_name }}/{{ prefix }}/{{ base_job_name }}-%s/' + 'source/sourcedir.tar.gz"' % TIME_STAMP, + "sagemaker_program": '"{{ entry_point }}"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"{{ log_level }}"', + "sagemaker_job_name": '"{{ base_job_name }}-%s"' % TIME_STAMP, + "sagemaker_region": '"us-west-2"', + "checkpoint_path": '"{{ checkpoint_path }}"', + "training_steps": "1000", + "evaluation_steps": "100", + "sagemaker_requirements": '""', }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ training_data }}' + "Tags": [{"{{ key }}": "{{ value }}"}], + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ source_dir }}", + "Bucket": "{{ bucket_name }}", + "Key": "{{ prefix }}/{{ base_job_name }}-%s/source/sourcedir.tar.gz" + % TIME_STAMP, + "Tar": True, } - }, - 'ChannelName': 'training' - }], - 'VpcConfig': { - 'Subnets': ['{{ subnet }}'], - 'SecurityGroupIds': ['{{ security_group_ids }}'] - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://{{ bucket_name }}/{{ prefix }}/{{ base_job_name }}-%s/' - 'source/sourcedir.tar.gz"' % TIME_STAMP, - 'sagemaker_program': '"{{ entry_point }}"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"{{ log_level }}"', - 'sagemaker_job_name': '"{{ base_job_name }}-%s"' % TIME_STAMP, - 'sagemaker_region': '"us-west-2"', - 'checkpoint_path': '"{{ checkpoint_path }}"', - 'training_steps': '1000', - 'evaluation_steps': '100', - 'sagemaker_requirements': '""' + ] }, - 'Tags': [{'{{ key }}': '{{ value }}'}], - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ source_dir }}', - 'Bucket': '{{ bucket_name }}', - 'Key': "{{ prefix }}/{{ base_job_name }}-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'Tar': True}] - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_amazon_alg_training_config_required_args(sagemaker_session): ntm_estimator = ntm.NTM( role="{{ role }}", num_topics=10, train_instance_count="{{ instance_count }}", train_instance_type="ml.c4.2xlarge", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) ntm_estimator.epochs = 32 - record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix") config = airflow.training_config(ntm_estimator, record, mini_batch_size=256) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1', - 'TrainingInputMode': 'File' + "AlgorithmSpecification": { + "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1", + "TrainingInputMode": "File", }, - 'OutputDataConfig': { - 'S3OutputPath': 's3://output/' + "OutputDataConfig": {"S3OutputPath": "s3://output/"}, + "TrainingJobName": "ntm-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": 30, }, - 'TrainingJobName': "ntm-%s" % TIME_STAMP, - 'StoppingCondition': 
{'MaxRuntimeInSeconds': 86400}, - 'ResourceConfig': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': 30 + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "ShardedByS3Key", + "S3DataType": "S3Prefix", + "S3Uri": "{{ record }}", + } + }, + "ChannelName": "train", + } + ], + "HyperParameters": { + "num_topics": "10", + "epochs": "32", + "mini_batch_size": "256", + "feature_dim": "100", }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'ShardedByS3Key', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ record }}' - } - }, - 'ChannelName': 'train' - }], - 'HyperParameters': { - 'num_topics': '10', - 'epochs': '32', - 'mini_batch_size': '256', - 'feature_dim': '100' - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_amazon_alg_training_config_all_args(sagemaker_session): ntm_estimator = ntm.NTM( role="{{ role }}", @@ -389,181 +387,188 @@ def test_amazon_alg_training_config_all_args(sagemaker_session): train_volume_size="{{ train_volume_size }}", train_volume_kms_key="{{ train_volume_kms_key }}", train_max_run="{{ train_max_run }}", - input_mode='Pipe', + input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", base_job_name="{{ base_job_name }}", tags=[{"{{ key }}": "{{ value }}"}], subnets=["{{ subnet }}"], security_group_ids=["{{ security_group_ids }}"], - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) ntm_estimator.epochs = 32 - record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix") config = airflow.training_config(ntm_estimator, record, mini_batch_size=256) expected_config = { - 'AlgorithmSpecification': { - 'TrainingImage': '174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1', - 'TrainingInputMode': 'Pipe' - }, - 'OutputDataConfig': { - 'S3OutputPath': '{{ output_path }}', - 'KmsKeyId': '{{ output_volume_kms_key }}' + "AlgorithmSpecification": { + "TrainingImage": "174872318107.dkr.ecr.us-west-2.amazonaws.com/ntm:1", + "TrainingInputMode": "Pipe", }, - 'TrainingJobName': "{{ base_job_name }}-%s" % TIME_STAMP, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': '{{ train_max_run }}' + "OutputDataConfig": { + "S3OutputPath": "{{ output_path }}", + "KmsKeyId": "{{ output_volume_kms_key }}", }, - 'ResourceConfig': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.c4.2xlarge', - 'VolumeSizeInGB': '{{ train_volume_size }}', - 'VolumeKmsKeyId': '{{ train_volume_kms_key }}' + "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, + "StoppingCondition": {"MaxRuntimeInSeconds": "{{ train_max_run }}"}, + "ResourceConfig": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": "{{ train_volume_size }}", + "VolumeKmsKeyId": "{{ train_volume_kms_key }}", }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'ShardedByS3Key', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ record }}' - } - }, - 'ChannelName': 'train' - }], - 'VpcConfig': { - 'Subnets': ['{{ subnet }}'], - 'SecurityGroupIds': ['{{ security_group_ids }}'] + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { 
+ "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "ShardedByS3Key", + "S3DataType": "S3Prefix", + "S3Uri": "{{ record }}", + } + }, + "ChannelName": "train", + } + ], + "VpcConfig": { + "Subnets": ["{{ subnet }}"], + "SecurityGroupIds": ["{{ security_group_ids }}"], }, - 'HyperParameters': { - 'num_topics': '10', - 'epochs': '32', - 'mini_batch_size': '256', - 'feature_dim': '100' + "HyperParameters": { + "num_topics": "10", + "epochs": "32", + "mini_batch_size": "256", + "feature_dim": "100", }, - 'Tags': [{'{{ key }}': '{{ value }}'}] + "Tags": [{"{{ key }}": "{{ value }}"}], } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) -@patch('sagemaker.utils.sagemaker_short_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_short_timestamp", MagicMock(return_value=TIME_STAMP)) def test_framework_tuning_config(sagemaker_session): mxnet_estimator = mxnet.MXNet( entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", - py_version='py3', - framework_version='1.3.0', + py_version="py3", + framework_version="1.3.0", role="{{ role }}", train_instance_count=1, - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", sagemaker_session=sagemaker_session, base_job_name="{{ base_job_name }}", - hyperparameters={'batch_size': 100}) + hyperparameters={"batch_size": 100}, + ) - hyperparameter_ranges = {'optimizer': tuner.CategoricalParameter(['sgd', 'Adam']), - 'learning_rate': tuner.ContinuousParameter(0.01, 0.2), - 'num_epoch': tuner.IntegerParameter(10, 50)} - objective_metric_name = 'Validation-accuracy' - metric_definitions = [{'Name': 'Validation-accuracy', - 'Regex': 'Validation-accuracy=([0-9\\.]+)'}] + hyperparameter_ranges = { + "optimizer": tuner.CategoricalParameter(["sgd", "Adam"]), + "learning_rate": tuner.ContinuousParameter(0.01, 0.2), + "num_epoch": tuner.IntegerParameter(10, 50), + } + objective_metric_name = "Validation-accuracy" + metric_definitions = [ + {"Name": "Validation-accuracy", "Regex": "Validation-accuracy=([0-9\\.]+)"} + ] mxnet_tuner = tuner.HyperparameterTuner( estimator=mxnet_estimator, objective_metric_name=objective_metric_name, hyperparameter_ranges=hyperparameter_ranges, metric_definitions=metric_definitions, - strategy='Bayesian', - objective_type='Maximize', + strategy="Bayesian", + objective_type="Maximize", max_jobs="{{ max_job }}", max_parallel_jobs="{{ max_parallel_job }}", - tags=[{'{{ key }}': '{{ value }}'}], - base_tuning_job_name="{{ base_job_name }}") + tags=[{"{{ key }}": "{{ value }}"}], + base_tuning_job_name="{{ base_job_name }}", + ) data = "{{ training_data }}" config = airflow.tuning_config(mxnet_tuner, data) expected_config = { - 'HyperParameterTuningJobName': "{{ base_job_name }}-%s" % TIME_STAMP, - 'HyperParameterTuningJobConfig': { - 'Strategy': 'Bayesian', - 'HyperParameterTuningJobObjective': { - 'Type': 'Maximize', - 'MetricName': 'Validation-accuracy' + "HyperParameterTuningJobName": "{{ base_job_name }}-%s" % TIME_STAMP, + "HyperParameterTuningJobConfig": { + "Strategy": "Bayesian", + "HyperParameterTuningJobObjective": { + "Type": "Maximize", + "MetricName": "Validation-accuracy", }, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': '{{ max_job }}', - 'MaxParallelTrainingJobs': '{{ max_parallel_job }}' + "ResourceLimits": { + "MaxNumberOfTrainingJobs": "{{ max_job }}", + "MaxParallelTrainingJobs": "{{ max_parallel_job }}", 
}, - 'ParameterRanges': { - 'ContinuousParameterRanges': [{ - 'Name': 'learning_rate', - 'MinValue': '0.01', - 'MaxValue': '0.2', - 'ScalingType': 'Auto'}], - 'CategoricalParameterRanges': [{ - 'Name': 'optimizer', - 'Values': ['"sgd"', '"Adam"'] - }], - 'IntegerParameterRanges': [{ - 'Name': 'num_epoch', - 'MinValue': '10', - 'MaxValue': '50', - 'ScalingType': 'Auto' - }] - }}, - 'TrainingJobDefinition': { - 'AlgorithmSpecification': { - 'TrainingImage': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3', - 'TrainingInputMode': 'File', - 'MetricDefinitions': [{ - 'Name': 'Validation-accuracy', - 'Regex': 'Validation-accuracy=([0-9\\.]+)' - }] + "ParameterRanges": { + "ContinuousParameterRanges": [ + { + "Name": "learning_rate", + "MinValue": "0.01", + "MaxValue": "0.2", + "ScalingType": "Auto", + } + ], + "CategoricalParameterRanges": [ + {"Name": "optimizer", "Values": ['"sgd"', '"Adam"']} + ], + "IntegerParameterRanges": [ + {"Name": "num_epoch", "MinValue": "10", "MaxValue": "50", "ScalingType": "Auto"} + ], }, - 'OutputDataConfig': { - 'S3OutputPath': 's3://output/' + }, + "TrainingJobDefinition": { + "AlgorithmSpecification": { + "TrainingImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3", + "TrainingInputMode": "File", + "MetricDefinitions": [ + {"Name": "Validation-accuracy", "Regex": "Validation-accuracy=([0-9\\.]+)"} + ], }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 86400 + "OutputDataConfig": {"S3OutputPath": "s3://output/"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m4.xlarge", + "VolumeSizeInGB": 30, }, - 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.m4.xlarge', - 'VolumeSizeInGB': 30 + "RoleArn": "{{ role }}", + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "{{ training_data }}", + } + }, + "ChannelName": "training", + } + ], + "StaticHyperParameters": { + "batch_size": "100", + "sagemaker_submit_directory": '"s3://output/{{ base_job_name }}-%s/source/sourcedir.tar.gz"' + % TIME_STAMP, + "sagemaker_program": '"{{ entry_point }}"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": "20", + "sagemaker_job_name": '"{{ base_job_name }}-%s"' % TIME_STAMP, + "sagemaker_region": '"us-west-2"', }, - 'RoleArn': '{{ role }}', - 'InputDataConfig': [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ training_data }}' - } - }, - 'ChannelName': 'training' - }], - 'StaticHyperParameters': { - 'batch_size': '100', - 'sagemaker_submit_directory': '"s3://output/{{ base_job_name }}-%s/source/sourcedir.tar.gz"' - % TIME_STAMP, - 'sagemaker_program': '"{{ entry_point }}"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '20', - 'sagemaker_job_name': '"{{ base_job_name }}-%s"' % TIME_STAMP, - 'sagemaker_region': '"us-west-2"'}}, - 'Tags': [{'{{ key }}': '{{ value }}'}], - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ source_dir }}', - 'Bucket': 'output', - 'Key': "{{ base_job_name }}-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'Tar': True - }] - } + }, + "Tags": [{"{{ key }}": "{{ value }}"}], + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ source_dir }}", + "Bucket": "output", + "Key": "{{ base_job_name }}-%s/source/sourcedir.tar.gz" % TIME_STAMP, + "Tar": True, + } + ] + }, } 
assert config == expected_config @@ -575,18 +580,19 @@ def test_byo_model_config(sagemaker_session): image="{{ image }}", role="{{ role }}", env={"{{ key }}": "{{ value }}"}, - name='model', - sagemaker_session=sagemaker_session) + name="model", + sagemaker_session=sagemaker_session, + ) - config = airflow.model_config(instance_type='ml.c4.xlarge', model=byo_model) + config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model) expected_config = { - 'ModelName': 'model', - 'PrimaryContainer': { - 'Image': '{{ image }}', - 'Environment': {'{{ key }}': '{{ value }}'}, - 'ModelDataUrl': '{{ model_data }}' + "ModelName": "model", + "PrimaryContainer": { + "Image": "{{ image }}", + "Environment": {"{{ key }}": "{{ value }}"}, + "ModelDataUrl": "{{ model_data }}", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", } assert config == expected_config @@ -600,38 +606,42 @@ def test_byo_framework_model_config(sagemaker_session): entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", env={"{{ key }}": "{{ value }}"}, - name='model', - sagemaker_session=sagemaker_session) + name="model", + sagemaker_session=sagemaker_session, + ) - config = airflow.model_config(instance_type='ml.c4.xlarge', model=byo_model) + config = airflow.model_config(instance_type="ml.c4.xlarge", model=byo_model) expected_config = { - 'ModelName': 'model', - 'PrimaryContainer': { - 'Image': '{{ image }}', - 'Environment': { - '{{ key }}': '{{ value }}', - 'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://output/model/source/sourcedir.tar.gz', - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2' + "ModelName": "model", + "PrimaryContainer": { + "Image": "{{ image }}", + "Environment": { + "{{ key }}": "{{ value }}", + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/model/source/sourcedir.tar.gz", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", }, - 'ModelDataUrl': '{{ model_data }}'}, - 'ExecutionRoleArn': '{{ role }}', - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ source_dir }}', - 'Bucket': 'output', - 'Key': 'model/source/sourcedir.tar.gz', - 'Tar': True - }] - } + "ModelDataUrl": "{{ model_data }}", + }, + "ExecutionRoleArn": "{{ role }}", + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ source_dir }}", + "Bucket": "output", + "Key": "model/source/sourcedir.tar.gz", + "Tar": True, + } + ] + }, } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_framework_model_config(sagemaker_session): chainer_model = chainer.ChainerModel( model_data="{{ model_data }}", @@ -639,146 +649,153 @@ def test_framework_model_config(sagemaker_session): entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", image=None, - py_version='py3', - framework_version='5.0.0', + py_version="py3", + framework_version="5.0.0", model_server_workers="{{ model_server_worker }}", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) - config = airflow.model_config(instance_type='ml.c4.xlarge', model=chainer_model) + config = airflow.model_config(instance_type="ml.c4.xlarge", model=chainer_model) expected_config = { - 'ModelName': "sagemaker-chainer-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': 
'520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3', - 'Environment': { - 'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/sagemaker-chainer-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_MODEL_SERVER_WORKERS': '{{ model_server_worker }}' + "ModelName": "sagemaker-chainer-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3", + "Environment": { + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/sagemaker-chainer-%s/source/sourcedir.tar.gz" + % TIME_STAMP, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_MODEL_SERVER_WORKERS": "{{ model_server_worker }}", }, - 'ModelDataUrl': '{{ model_data }}' + "ModelDataUrl": "{{ model_data }}", + }, + "ExecutionRoleArn": "{{ role }}", + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ source_dir }}", + "Bucket": "output", + "Key": "sagemaker-chainer-%s/source/sourcedir.tar.gz" % TIME_STAMP, + "Tar": True, + } + ] }, - 'ExecutionRoleArn': '{{ role }}', - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ source_dir }}', - 'Bucket': 'output', - 'Key': "sagemaker-chainer-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'Tar': True}] - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_amazon_alg_model_config(sagemaker_session): pca_model = pca.PCAModel( - model_data="{{ model_data }}", - role="{{ role }}", - sagemaker_session=sagemaker_session) + model_data="{{ model_data }}", role="{{ role }}", sagemaker_session=sagemaker_session + ) - config = airflow.model_config(instance_type='ml.c4.xlarge', model=pca_model) + config = airflow.model_config(instance_type="ml.c4.xlarge", model=pca_model) expected_config = { - 'ModelName': "pca-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1', - 'Environment': {}, - 'ModelDataUrl': '{{ model_data }}' + "ModelName": "pca-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1", + "Environment": {}, + "ModelDataUrl": "{{ model_data }}", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_model_config_from_framework_estimator(sagemaker_session): mxnet_estimator = mxnet.MXNet( entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", - py_version='py3', - framework_version='1.3.0', + py_version="py3", + framework_version="1.3.0", role="{{ role }}", train_instance_count=1, - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", sagemaker_session=sagemaker_session, base_job_name="{{ base_job_name }}", - hyperparameters={'batch_size': 100}) + hyperparameters={"batch_size": 100}, + ) data = "{{ training_data }}" # simulate training airflow.training_config(mxnet_estimator, data) - config = airflow.model_config_from_estimator(instance_type='ml.c4.xlarge', - estimator=mxnet_estimator, - task_id='task_id', - 
task_type='training') + config = airflow.model_config_from_estimator( + instance_type="ml.c4.xlarge", + estimator=mxnet_estimator, + task_id="task_id", + task_type="training", + ) expected_config = { - 'ModelName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3', - 'Environment': { - 'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']" - "['TrainingJobName'] }}/source/sourcedir.tar.gz", - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2' + "ModelName": "sagemaker-mxnet-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3", + "Environment": { + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']" + "['TrainingJobName'] }}/source/sourcedir.tar.gz", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", }, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" - "/output/model.tar.gz" + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" + "/output/model.tar.gz", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_model_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", train_instance_count="{{ instance_count }}", - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", k=16, sample_size=128, - predictor_type='regressor', - sagemaker_session=sagemaker_session) + predictor_type="regressor", + sagemaker_session=sagemaker_session, + ) - record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix") # simulate training airflow.training_config(knn_estimator, record, mini_batch_size=256) - config = airflow.model_config_from_estimator(instance_type='ml.c4.xlarge', - estimator=knn_estimator, - task_id='task_id', - task_type='tuning') + config = airflow.model_config_from_estimator( + instance_type="ml.c4.xlarge", estimator=knn_estimator, task_id="task_id", task_type="tuning" + ) expected_config = { - 'ModelName': "knn-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1', - 'Environment': {}, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" - "['TrainingJobName'] }}/output/model.tar.gz" + "ModelName": "knn-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", + "Environment": {}, + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" + "['TrainingJobName'] }}/output/model.tar.gz", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def 
test_transform_config(sagemaker_session): tf_transformer = transformer.Transformer( model_name="tensorflow-model", instance_count="{{ instance_count }}", instance_type="ml.p2.xlarge", strategy="SingleRecord", - assemble_with='Line', + assemble_with="Line", output_path="{{ output_path }}", output_kms_key="{{ kms_key }}", accept="{{ accept }}", @@ -788,59 +805,65 @@ def test_transform_config(sagemaker_session): env={"{{ key }}": "{{ value }}"}, base_transform_job_name="tensorflow-transform", sagemaker_session=sagemaker_session, - volume_kms_key="{{ kms_key }}") + volume_kms_key="{{ kms_key }}", + ) data = "{{ transform_data }}" - config = airflow.transform_config(tf_transformer, data, data_type='S3Prefix', content_type="{{ content_type }}", - compression_type="{{ compression_type }}", split_type="{{ split_type }}") + config = airflow.transform_config( + tf_transformer, + data, + data_type="S3Prefix", + content_type="{{ content_type }}", + compression_type="{{ compression_type }}", + split_type="{{ split_type }}", + ) expected_config = { - 'TransformJobName': "tensorflow-transform-%s" % TIME_STAMP, - 'ModelName': 'tensorflow-model', - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ transform_data }}' - } + "TransformJobName": "tensorflow-transform-%s" % TIME_STAMP, + "ModelName": "tensorflow-model", + "TransformInput": { + "DataSource": { + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "{{ transform_data }}"} }, - 'ContentType': '{{ content_type }}', - 'CompressionType': '{{ compression_type }}', - 'SplitType': '{{ split_type }}'}, - 'TransformOutput': { - 'S3OutputPath': '{{ output_path }}', - 'KmsKeyId': '{{ kms_key }}', - 'AssembleWith': 'Line', - 'Accept': '{{ accept }}' + "ContentType": "{{ content_type }}", + "CompressionType": "{{ compression_type }}", + "SplitType": "{{ split_type }}", + }, + "TransformOutput": { + "S3OutputPath": "{{ output_path }}", + "KmsKeyId": "{{ kms_key }}", + "AssembleWith": "Line", + "Accept": "{{ accept }}", }, - 'TransformResources': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.p2.xlarge', - 'VolumeKmsKeyId': '{{ kms_key }}' + "TransformResources": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.p2.xlarge", + "VolumeKmsKeyId": "{{ kms_key }}", }, - 'BatchStrategy': 'SingleRecord', - 'MaxConcurrentTransforms': '{{ max_parallel_job }}', - 'MaxPayloadInMB': '{{ max_payload }}', - 'Environment': {'{{ key }}': '{{ value }}'}, - 'Tags': [{'{{ key }}': '{{ value }}'}] + "BatchStrategy": "SingleRecord", + "MaxConcurrentTransforms": "{{ max_parallel_job }}", + "MaxPayloadInMB": "{{ max_payload }}", + "Environment": {"{{ key }}": "{{ value }}"}, + "Tags": [{"{{ key }}": "{{ value }}"}], } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_transform_config_from_framework_estimator(sagemaker_session): mxnet_estimator = mxnet.MXNet( entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", - py_version='py3', - framework_version='1.3.0', + py_version="py3", + framework_version="1.3.0", role="{{ role }}", train_instance_count=1, - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", sagemaker_session=sagemaker_session, base_job_name="{{ base_job_name }}", - hyperparameters={'batch_size': 100}) + hyperparameters={"batch_size": 100}, + ) train_data = "{{ train_data }}" transform_data = "{{ 
transform_data }}" @@ -850,66 +873,64 @@ def test_transform_config_from_framework_estimator(sagemaker_session): config = airflow.transform_config_from_estimator( estimator=mxnet_estimator, - task_id='task_id', - task_type='training', + task_id="task_id", + task_type="training", instance_count="{{ instance_count }}", instance_type="ml.p2.xlarge", - data=transform_data) + data=transform_data, + ) expected_config = { - 'Model': { - 'ModelName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-gpu-py3', - 'Environment': {'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/{{ ti.xcom_pull(task_ids='task_id')" - "['Training']['TrainingJobName'] }}" - "/source/sourcedir.tar.gz", - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2' - }, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" - "/output/model.tar.gz" + "Model": { + "ModelName": "sagemaker-mxnet-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-gpu-py3", + "Environment": { + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/{{ ti.xcom_pull(task_ids='task_id')" + "['Training']['TrainingJobName'] }}" + "/source/sourcedir.tar.gz", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + }, + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" + "/output/model.tar.gz", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", }, - 'Transform': { - 'TransformJobName': "{{ base_job_name }}-%s" % TIME_STAMP, - 'ModelName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ transform_data }}' - } + "Transform": { + "TransformJobName": "{{ base_job_name }}-%s" % TIME_STAMP, + "ModelName": "sagemaker-mxnet-%s" % TIME_STAMP, + "TransformInput": { + "DataSource": { + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "{{ transform_data }}"} } }, - 'TransformOutput': { - 'S3OutputPath': "s3://output/{{ base_job_name }}-%s" % TIME_STAMP - }, - 'TransformResources': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.p2.xlarge' + "TransformOutput": {"S3OutputPath": "s3://output/{{ base_job_name }}-%s" % TIME_STAMP}, + "TransformResources": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.p2.xlarge", }, - 'Environment': {} - } + "Environment": {}, + }, } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_transform_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", train_instance_count="{{ instance_count }}", - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", k=16, sample_size=128, - predictor_type='regressor', - sagemaker_session=sagemaker_session) + predictor_type="regressor", + sagemaker_session=sagemaker_session, + ) - record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix") transform_data = "{{ transform_data }}" # simulate 
training @@ -917,44 +938,43 @@ def test_transform_config_from_amazon_alg_estimator(sagemaker_session): config = airflow.transform_config_from_estimator( estimator=knn_estimator, - task_id='task_id', - task_type='training', + task_id="task_id", + task_type="training", instance_count="{{ instance_count }}", instance_type="ml.p2.xlarge", - data=transform_data) + data=transform_data, + ) expected_config = { - 'Model': { - 'ModelName': "knn-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1', - 'Environment': {}, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" - "/output/model.tar.gz" + "Model": { + "ModelName": "knn-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", + "Environment": {}, + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" + "/output/model.tar.gz", }, - 'ExecutionRoleArn': '{{ role }}'}, - 'Transform': { - 'TransformJobName': "knn-%s" % TIME_STAMP, - 'ModelName': "knn-%s" % TIME_STAMP, - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': '{{ transform_data }}'} + "ExecutionRoleArn": "{{ role }}", + }, + "Transform": { + "TransformJobName": "knn-%s" % TIME_STAMP, + "ModelName": "knn-%s" % TIME_STAMP, + "TransformInput": { + "DataSource": { + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "{{ transform_data }}"} } }, - 'TransformOutput': { - 'S3OutputPath': "s3://output/knn-%s" % TIME_STAMP + "TransformOutput": {"S3OutputPath": "s3://output/knn-%s" % TIME_STAMP}, + "TransformResources": { + "InstanceCount": "{{ instance_count }}", + "InstanceType": "ml.p2.xlarge", }, - 'TransformResources': { - 'InstanceCount': '{{ instance_count }}', - 'InstanceType': 'ml.p2.xlarge'} - } + }, } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_deploy_framework_model_config(sagemaker_session): chainer_model = chainer.ChainerModel( model_data="{{ model_data }}", @@ -962,199 +982,222 @@ def test_deploy_framework_model_config(sagemaker_session): entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", image=None, - py_version='py3', - framework_version='5.0.0', + py_version="py3", + framework_version="5.0.0", model_server_workers="{{ model_server_worker }}", - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + ) - config = airflow.deploy_config(chainer_model, - initial_instance_count="{{ instance_count }}", - instance_type="ml.m4.xlarge") + config = airflow.deploy_config( + chainer_model, initial_instance_count="{{ instance_count }}", instance_type="ml.m4.xlarge" + ) expected_config = { - 'Model': { - 'ModelName': "sagemaker-chainer-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3', - 'Environment': { - 'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/sagemaker-chainer-%s/source/sourcedir.tar.gz" - % TIME_STAMP, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_MODEL_SERVER_WORKERS': '{{ model_server_worker }}' + "Model": { + "ModelName": "sagemaker-chainer-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": 
"520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:5.0.0-cpu-py3", + "Environment": { + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/sagemaker-chainer-%s/source/sourcedir.tar.gz" + % TIME_STAMP, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_MODEL_SERVER_WORKERS": "{{ model_server_worker }}", }, - 'ModelDataUrl': '{{ model_data }}'}, - 'ExecutionRoleArn': '{{ role }}' + "ModelDataUrl": "{{ model_data }}", + }, + "ExecutionRoleArn": "{{ role }}", }, - 'EndpointConfig': { - 'EndpointConfigName': "sagemaker-chainer-%s" % TIME_STAMP, - 'ProductionVariants': [{ - 'InstanceType': 'ml.m4.xlarge', - 'InitialInstanceCount': '{{ instance_count }}', - 'ModelName': "sagemaker-chainer-%s" % TIME_STAMP, - 'VariantName': 'AllTraffic', - 'InitialVariantWeight': 1 - }] + "EndpointConfig": { + "EndpointConfigName": "sagemaker-chainer-%s" % TIME_STAMP, + "ProductionVariants": [ + { + "InstanceType": "ml.m4.xlarge", + "InitialInstanceCount": "{{ instance_count }}", + "ModelName": "sagemaker-chainer-%s" % TIME_STAMP, + "VariantName": "AllTraffic", + "InitialVariantWeight": 1, + } + ], }, - 'Endpoint': { - 'EndpointName': "sagemaker-chainer-%s" % TIME_STAMP, - 'EndpointConfigName': "sagemaker-chainer-%s" % TIME_STAMP + "Endpoint": { + "EndpointName": "sagemaker-chainer-%s" % TIME_STAMP, + "EndpointConfigName": "sagemaker-chainer-%s" % TIME_STAMP, + }, + "S3Operations": { + "S3Upload": [ + { + "Path": "{{ source_dir }}", + "Bucket": "output", + "Key": "sagemaker-chainer-%s/source/sourcedir.tar.gz" % TIME_STAMP, + "Tar": True, + } + ] }, - 'S3Operations': { - 'S3Upload': [{ - 'Path': '{{ source_dir }}', - 'Bucket': 'output', - 'Key': "sagemaker-chainer-%s/source/sourcedir.tar.gz" % TIME_STAMP, - 'Tar': True - }] - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_deploy_amazon_alg_model_config(sagemaker_session): pca_model = pca.PCAModel( - model_data="{{ model_data }}", - role="{{ role }}", - sagemaker_session=sagemaker_session) + model_data="{{ model_data }}", role="{{ role }}", sagemaker_session=sagemaker_session + ) - config = airflow.deploy_config(pca_model, - initial_instance_count="{{ instance_count }}", - instance_type='ml.c4.xlarge') + config = airflow.deploy_config( + pca_model, initial_instance_count="{{ instance_count }}", instance_type="ml.c4.xlarge" + ) expected_config = { - 'Model': { - 'ModelName': "pca-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1', - 'Environment': {}, - 'ModelDataUrl': '{{ model_data }}'}, - 'ExecutionRoleArn': '{{ role }}'}, - 'EndpointConfig': { - 'EndpointConfigName': "pca-%s" % TIME_STAMP, - 'ProductionVariants': [{ - 'InstanceType': 'ml.c4.xlarge', - 'InitialInstanceCount': '{{ instance_count }}', - 'ModelName': "pca-%s" % TIME_STAMP, - 'VariantName': 'AllTraffic', - 'InitialVariantWeight': 1 - }] + "Model": { + "ModelName": "pca-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1", + "Environment": {}, + "ModelDataUrl": "{{ model_data }}", + }, + "ExecutionRoleArn": "{{ role }}", + }, + "EndpointConfig": { + "EndpointConfigName": "pca-%s" % TIME_STAMP, + "ProductionVariants": [ + { + "InstanceType": "ml.c4.xlarge", + "InitialInstanceCount": "{{ instance_count 
}}", + "ModelName": "pca-%s" % TIME_STAMP, + "VariantName": "AllTraffic", + "InitialVariantWeight": 1, + } + ], + }, + "Endpoint": { + "EndpointName": "pca-%s" % TIME_STAMP, + "EndpointConfigName": "pca-%s" % TIME_STAMP, }, - 'Endpoint': { - 'EndpointName': "pca-%s" % TIME_STAMP, - 'EndpointConfigName': "pca-%s" % TIME_STAMP - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_deploy_config_from_framework_estimator(sagemaker_session): mxnet_estimator = mxnet.MXNet( entry_point="{{ entry_point }}", source_dir="{{ source_dir }}", - py_version='py3', - framework_version='1.3.0', + py_version="py3", + framework_version="1.3.0", role="{{ role }}", train_instance_count=1, - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", sagemaker_session=sagemaker_session, base_job_name="{{ base_job_name }}", - hyperparameters={'batch_size': 100}) + hyperparameters={"batch_size": 100}, + ) train_data = "{{ train_data }}" # simulate training airflow.training_config(mxnet_estimator, train_data) - config = airflow.deploy_config_from_estimator(estimator=mxnet_estimator, - task_id='task_id', - task_type='training', - initial_instance_count="{{ instance_count}}", - instance_type="ml.c4.large", - endpoint_name="mxnet-endpoint") + config = airflow.deploy_config_from_estimator( + estimator=mxnet_estimator, + task_id="task_id", + task_type="training", + initial_instance_count="{{ instance_count}}", + instance_type="ml.c4.large", + endpoint_name="mxnet-endpoint", + ) expected_config = { - 'Model': { - 'ModelName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3', - 'Environment': { - 'SAGEMAKER_PROGRAM': '{{ entry_point }}', - 'SAGEMAKER_SUBMIT_DIRECTORY': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']" - "['TrainingJobName'] }}/source/sourcedir.tar.gz", - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2'}, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" - "/output/model.tar.gz" + "Model": { + "ModelName": "sagemaker-mxnet-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.3.0-cpu-py3", + "Environment": { + "SAGEMAKER_PROGRAM": "{{ entry_point }}", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']" + "['TrainingJobName'] }}/source/sourcedir.tar.gz", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + }, + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" + "/output/model.tar.gz", }, - 'ExecutionRoleArn': '{{ role }}' + "ExecutionRoleArn": "{{ role }}", }, - 'EndpointConfig': { - 'EndpointConfigName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'ProductionVariants': [{ - 'InstanceType': 'ml.c4.large', - 'InitialInstanceCount': '{{ instance_count}}', - 'ModelName': "sagemaker-mxnet-%s" % TIME_STAMP, - 'VariantName': 'AllTraffic', - 'InitialVariantWeight': 1 - }] + "EndpointConfig": { + "EndpointConfigName": "sagemaker-mxnet-%s" % TIME_STAMP, + "ProductionVariants": [ + { + "InstanceType": "ml.c4.large", + "InitialInstanceCount": "{{ instance_count}}", + "ModelName": 
"sagemaker-mxnet-%s" % TIME_STAMP, + "VariantName": "AllTraffic", + "InitialVariantWeight": 1, + } + ], + }, + "Endpoint": { + "EndpointName": "mxnet-endpoint", + "EndpointConfigName": "sagemaker-mxnet-%s" % TIME_STAMP, }, - 'Endpoint': { - 'EndpointName': 'mxnet-endpoint', - 'EndpointConfigName': "sagemaker-mxnet-%s" % TIME_STAMP - } } assert config == expected_config -@patch('sagemaker.utils.sagemaker_timestamp', MagicMock(return_value=TIME_STAMP)) +@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP)) def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", train_instance_count="{{ instance_count }}", - train_instance_type='ml.m4.xlarge', + train_instance_type="ml.m4.xlarge", k=16, sample_size=128, - predictor_type='regressor', - sagemaker_session=sagemaker_session) + predictor_type="regressor", + sagemaker_session=sagemaker_session, + ) - record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, 'S3Prefix') + record = amazon_estimator.RecordSet("{{ record }}", 10000, 100, "S3Prefix") # simulate training airflow.training_config(knn_estimator, record, mini_batch_size=256) - config = airflow.deploy_config_from_estimator(estimator=knn_estimator, - task_id='task_id', - task_type='tuning', - initial_instance_count="{{ instance_count }}", - instance_type="ml.p2.xlarge") + config = airflow.deploy_config_from_estimator( + estimator=knn_estimator, + task_id="task_id", + task_type="tuning", + initial_instance_count="{{ instance_count }}", + instance_type="ml.p2.xlarge", + ) expected_config = { - 'Model': { - 'ModelName': "knn-%s" % TIME_STAMP, - 'PrimaryContainer': { - 'Image': '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1', - 'Environment': {}, - 'ModelDataUrl': "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" - "['TrainingJobName'] }}/output/model.tar.gz"}, - 'ExecutionRoleArn': '{{ role }}'}, - 'EndpointConfig': { - 'EndpointConfigName': "knn-%s" % TIME_STAMP, - 'ProductionVariants': [{ - 'InstanceType': 'ml.p2.xlarge', - 'InitialInstanceCount': '{{ instance_count }}', - 'ModelName': "knn-%s" % TIME_STAMP, - 'VariantName': 'AllTraffic', 'InitialVariantWeight': 1 - }] + "Model": { + "ModelName": "knn-%s" % TIME_STAMP, + "PrimaryContainer": { + "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", + "Environment": {}, + "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" + "['TrainingJobName'] }}/output/model.tar.gz", + }, + "ExecutionRoleArn": "{{ role }}", + }, + "EndpointConfig": { + "EndpointConfigName": "knn-%s" % TIME_STAMP, + "ProductionVariants": [ + { + "InstanceType": "ml.p2.xlarge", + "InitialInstanceCount": "{{ instance_count }}", + "ModelName": "knn-%s" % TIME_STAMP, + "VariantName": "AllTraffic", + "InitialVariantWeight": 1, + } + ], + }, + "Endpoint": { + "EndpointName": "knn-%s" % TIME_STAMP, + "EndpointConfigName": "knn-%s" % TIME_STAMP, }, - 'Endpoint': { - 'EndpointName': "knn-%s" % TIME_STAMP, - 'EndpointConfigName': "knn-%s" % TIME_STAMP - } } assert config == expected_config diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index c00254eae0..071b317b44 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -23,159 +23,159 @@ from sagemaker.transformer import Transformer DESCRIBE_ALGORITHM_RESPONSE = { - 'AlgorithmName': 'scikit-decision-trees', - 'AlgorithmArn': 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - 
'AlgorithmDescription': 'Decision trees using Scikit', - 'CreationTime': datetime.datetime(2018, 8, 3, 22, 44, 54, 437000), - 'TrainingSpecification': { - 'TrainingImage': '123.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample@sha256:12345', - 'TrainingImageDigest': 'sha256:206854b6ea2f0020d216311da732010515169820b898ec29720bcf1d2b46806a', - 'SupportedHyperParameters': [ + "AlgorithmName": "scikit-decision-trees", + "AlgorithmArn": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + "AlgorithmDescription": "Decision trees using Scikit", + "CreationTime": datetime.datetime(2018, 8, 3, 22, 44, 54, 437000), + "TrainingSpecification": { + "TrainingImage": "123.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample@sha256:12345", + "TrainingImageDigest": "sha256:206854b6ea2f0020d216311da732010515169820b898ec29720bcf1d2b46806a", + "SupportedHyperParameters": [ { - 'Name': 'max_leaf_nodes', - 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', - 'Type': 'Integer', - 'Range': { - 'IntegerParameterRangeSpecification': {'MinValue': '1', 'MaxValue': '100000'} + "Name": "max_leaf_nodes", + "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", + "Type": "Integer", + "Range": { + "IntegerParameterRangeSpecification": {"MinValue": "1", "MaxValue": "100000"} }, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", }, { - 'Name': 'free_text_hp1', - 'Description': 'You can write anything here', - 'Type': 'FreeText', - 'IsTunable': False, - 'IsRequired': True - } + "Name": "free_text_hp1", + "Description": "You can write anything here", + "Type": "FreeText", + "IsTunable": False, + "IsRequired": True, + }, ], - 'SupportedTrainingInstanceTypes': ['ml.m4.xlarge', 'ml.m4.2xlarge', 'ml.m4.4xlarge'], - 'SupportsDistributedTraining': False, - 'MetricDefinitions': [ - {'Name': 'validation:accuracy', 'Regex': 'validation-accuracy: (\\S+)'} + "SupportedTrainingInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge", "ml.m4.4xlarge"], + "SupportsDistributedTraining": False, + "MetricDefinitions": [ + {"Name": "validation:accuracy", "Regex": "validation-accuracy: (\\S+)"} ], - 'TrainingChannels': [ + "TrainingChannels": [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], } ], - 'SupportedTuningJobObjectiveMetrics': [ - {'Type': 'Maximize', 'MetricName': 'validation:accuracy'} + "SupportedTuningJobObjectiveMetrics": [ + {"Type": "Maximize", "MetricName": "validation:accuracy"} ], }, - 'InferenceSpecification': { - 'InferenceImage': '123.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample@sha256:123', - 'SupportedTransformInstanceTypes': ['ml.m4.xlarge', 'ml.m4.2xlarge'], - 'SupportedContentTypes': ['text/csv'], - 'SupportedResponseMIMETypes': ['text'], + "InferenceSpecification": { + "InferenceImage": "123.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample@sha256:123", + "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + "SupportedContentTypes": ["text/csv"], + "SupportedResponseMIMETypes": ["text"], }, - 'ValidationSpecification': { - 'ValidationRole': 
'arn:aws:iam::764419575721:role/SageMakerRole', - 'ValidationProfiles': [ + "ValidationSpecification": { + "ValidationRole": "arn:aws:iam::764419575721:role/SageMakerRole", + "ValidationProfiles": [ { - 'ProfileName': 'ValidationProfile1', - 'TrainingJobDefinition': { - 'TrainingInputMode': 'File', - 'HyperParameters': {}, - 'InputDataConfig': [ + "ProfileName": "ValidationProfile1", + "TrainingJobDefinition": { + "TrainingInputMode": "File", + "HyperParameters": {}, + "InputDataConfig": [ { - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://sagemaker-us-east-2-7123/-scikit-byo-iris/training-input-data', - 'S3DataDistributionType': 'FullyReplicated', + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-us-east-2-7123/-scikit-byo-iris/training-input-data", + "S3DataDistributionType": "FullyReplicated", } }, - 'ContentType': 'text/csv', - 'CompressionType': 'None', - 'RecordWrapperType': 'None', + "ContentType": "text/csv", + "CompressionType": "None", + "RecordWrapperType": "None", } ], - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://sagemaker-us-east-2-764419575721/DEMO-scikit-byo-iris/training-output', + "OutputDataConfig": { + "KmsKeyId": "", + "S3OutputPath": "s3://sagemaker-us-east-2-764419575721/DEMO-scikit-byo-iris/training-output", }, - 'ResourceConfig': { - 'InstanceType': 'ml.c4.xlarge', - 'InstanceCount': 1, - 'VolumeSizeInGB': 10, + "ResourceConfig": { + "InstanceType": "ml.c4.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 10, }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 3600}, + "StoppingCondition": {"MaxRuntimeInSeconds": 3600}, }, - 'TransformJobDefinition': { - 'MaxConcurrentTransforms': 0, - 'MaxPayloadInMB': 0, - 'TransformInput': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://sagemaker-us-east-2/scikit-byo-iris/batch-inference/transform_test.csv', + "TransformJobDefinition": { + "MaxConcurrentTransforms": 0, + "MaxPayloadInMB": 0, + "TransformInput": { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-us-east-2/scikit-byo-iris/batch-inference/transform_test.csv", } }, - 'ContentType': 'text/csv', - 'CompressionType': 'None', - 'SplitType': 'Line', + "ContentType": "text/csv", + "CompressionType": "None", + "SplitType": "Line", }, - 'TransformOutput': { - 'S3OutputPath': 's3://sagemaker-us-east-2-764419575721/scikit-byo-iris/batch-transform-output', - 'Accept': 'text/csv', - 'AssembleWith': 'Line', - 'KmsKeyId': '', + "TransformOutput": { + "S3OutputPath": "s3://sagemaker-us-east-2-764419575721/scikit-byo-iris/batch-transform-output", + "Accept": "text/csv", + "AssembleWith": "Line", + "KmsKeyId": "", }, - 'TransformResources': {'InstanceType': 'ml.c4.xlarge', 'InstanceCount': 1}, + "TransformResources": {"InstanceType": "ml.c4.xlarge", "InstanceCount": 1}, }, } ], - 'ValidationOutputS3Prefix': 's3://sagemaker-us-east-2-764419575721/DEMO-scikit-byo-iris/validation-output', - 'ValidateForMarketplace': True, + "ValidationOutputS3Prefix": "s3://sagemaker-us-east-2-764419575721/DEMO-scikit-byo-iris/validation-output", + "ValidateForMarketplace": True, }, - 'AlgorithmStatus': 'Completed', - 'AlgorithmStatusDetails': { - 'ValidationStatuses': [{'ProfileName': 'ValidationProfile1', 'Status': 'Completed'}] + "AlgorithmStatus": "Completed", + "AlgorithmStatusDetails": { + "ValidationStatuses": [{"ProfileName": "ValidationProfile1", "Status": 
"Completed"}] }, - 'ResponseMetadata': { - 'RequestId': 'e04bc28b-61b6-4486-9106-0edf07f5649c', - 'HTTPStatusCode': 200, - 'HTTPHeaders': { - 'x-amzn-requestid': 'e04bc28b-61b6-4486-9106-0edf07f5649c', - 'content-type': 'application/x-amz-json-1.1', - 'content-length': '3949', - 'date': 'Fri, 03 Aug 2018 23:08:43 GMT', + "ResponseMetadata": { + "RequestId": "e04bc28b-61b6-4486-9106-0edf07f5649c", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "e04bc28b-61b6-4486-9106-0edf07f5649c", + "content-type": "application/x-amz-json-1.1", + "content-length": "3949", + "date": "Fri, 03 Aug 2018 23:08:43 GMT", }, - 'RetryAttempts': 0, + "RetryAttempts": 0, }, } -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_supported_input_mode_with_valid_input_types(session): # verify that the Estimator verifies the # input mode that an Algorithm supports. file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - file_mode_algo['TrainingSpecification']['TrainingChannels'] = [ + file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, ] @@ -183,30 +183,30 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): # Creating a File mode Estimator with a File mode algorithm should work AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - pipe_mode_algo['TrainingSpecification']['TrainingChannels'] = [ + pipe_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['Pipe'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["Pipe"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + 
"SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, ] @@ -214,31 +214,31 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): # Creating a Pipe mode Estimator with a Pipe mode algorithm should work. AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, - input_mode='Pipe', + input_mode="Pipe", sagemaker_session=session, ) any_input_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - any_input_algo['TrainingSpecification']['TrainingChannels'] = [ + any_input_algo["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, ] @@ -247,36 +247,36 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): # Creating a File mode Estimator with an algorithm that supports both input modes # should work. 
AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_supported_input_mode_with_bad_input_types(session): # verify that the Estimator raises exceptions when # attempting to train with an incorrect input type file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - file_mode_algo['TrainingSpecification']['TrainingChannels'] = [ + file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, ] @@ -285,31 +285,31 @@ def test_algorithm_supported_input_mode_with_bad_input_types(session): # Creating a Pipe mode Estimator with a File mode algorithm should fail. 
with pytest.raises(ValueError): AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, - input_mode='Pipe', + input_mode="Pipe", sagemaker_session=session, ) pipe_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - pipe_mode_algo['TrainingSpecification']['TrainingChannels'] = [ + pipe_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['Pipe'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["Pipe"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File', 'Pipe'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File", "Pipe"], }, ] @@ -318,175 +318,171 @@ def test_algorithm_supported_input_mode_with_bad_input_types(session): # Creating a File mode Estimator with a Pipe mode algorithm should fail. with pytest.raises(ValueError): AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) -@patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -@patch('sagemaker.Session') +@patch("sagemaker.estimator.EstimatorBase.fit", Mock()) +@patch("sagemaker.Session") def test_algorithm_trainining_channels_with_expected_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - training_channels['TrainingSpecification']['TrainingChannels'] = [ + training_channels["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, ] session.sagemaker_client.describe_algorithm = 
Mock(return_value=training_channels) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) # Pass training and validation channels. This should work. - estimator.fit({'training': 's3://some/place', 'validation': 's3://some/other'}) + estimator.fit({"training": "s3://some/place", "validation": "s3://some/other"}) # Passing only the training channel. Validation is optional so this should also work. - estimator.fit({'training': 's3://some/place'}) + estimator.fit({"training": "s3://some/place"}) -@patch('sagemaker.estimator.EstimatorBase.fit', Mock()) -@patch('sagemaker.Session') +@patch("sagemaker.estimator.EstimatorBase.fit", Mock()) +@patch("sagemaker.Session") def test_algorithm_trainining_channels_with_invalid_channels(session): training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - training_channels['TrainingSpecification']['TrainingChannels'] = [ + training_channels["TrainingSpecification"]["TrainingChannels"] = [ { - 'Name': 'training', - 'Description': 'Input channel that provides training data', - 'IsRequired': True, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "training", + "Description": "Input channel that provides training data", + "IsRequired": True, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, { - 'Name': 'validation', - 'Description': 'Input channel that provides validation data', - 'IsRequired': False, - 'SupportedContentTypes': ['text/csv'], - 'SupportedCompressionTypes': ['None'], - 'SupportedInputModes': ['File'], + "Name": "validation", + "Description": "Input channel that provides validation data", + "IsRequired": False, + "SupportedContentTypes": ["text/csv"], + "SupportedCompressionTypes": ["None"], + "SupportedInputModes": ["File"], }, ] session.sagemaker_client.describe_algorithm = Mock(return_value=training_channels) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) # Passing only validation should fail as training is required. with pytest.raises(ValueError): - estimator.fit({'validation': 's3://some/thing'}) + estimator.fit({"validation": "s3://some/thing"}) # Passing an unknown channel should fail. 
with pytest.raises(ValueError): - estimator.fit({'training': 's3://some/data', 'training2': 's3://some/other/data'}) + estimator.fit({"training": "s3://some/data", "training2": "s3://some/other/data"}) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_train_instance_types_valid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] + train_instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] - describe_algo_response['TrainingSpecification'][ - 'SupportedTrainingInstanceTypes' + describe_algo_response["TrainingSpecification"][ + "SupportedTrainingInstanceTypes" ] = train_instance_types - session.sagemaker_client.describe_algorithm = Mock( - return_value=describe_algo_response - ) + session.sagemaker_client.describe_algorithm = Mock(return_value=describe_algo_response) AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m5.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m5.2xlarge", train_instance_count=1, sagemaker_session=session, ) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_train_instance_types_invalid_instance_types(session): describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - train_instance_types = ['ml.m4.xlarge', 'ml.m5.2xlarge'] + train_instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] - describe_algo_response['TrainingSpecification'][ - 'SupportedTrainingInstanceTypes' + describe_algo_response["TrainingSpecification"][ + "SupportedTrainingInstanceTypes" ] = train_instance_types - session.sagemaker_client.describe_algorithm = Mock( - return_value=describe_algo_response - ) + session.sagemaker_client.describe_algorithm = Mock(return_value=describe_algo_response) # invalid instance type, should fail with pytest.raises(ValueError): AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.8xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.8xlarge", train_instance_count=1, sagemaker_session=session, ) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_distributed_training_validation(session): distributed_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - distributed_algo['TrainingSpecification']['SupportsDistributedTraining'] = True + distributed_algo["TrainingSpecification"]["SupportsDistributedTraining"] = True single_instance_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - single_instance_algo['TrainingSpecification']['SupportsDistributedTraining'] = False + single_instance_algo["TrainingSpecification"]["SupportsDistributedTraining"] = False session.sagemaker_client.describe_algorithm = Mock(return_value=distributed_algo) # Distributed training should work for Distributed and Single instance. 
AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=2, sagemaker_session=session, ) @@ -496,39 +492,39 @@ def test_algorithm_distributed_training_validation(session): # distributed training on a single instance algorithm should fail. with pytest.raises(ValueError): AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m5.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m5.2xlarge", train_instance_count=2, sagemaker_session=session, ) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_hyperparameter_integer_range_valid_range(session): hyperparameters = [ { - 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', - 'Type': 'Integer', - 'Name': 'max_leaf_nodes', - 'Range': { - 'IntegerParameterRangeSpecification': {'MinValue': '1', 'MaxValue': '100000'} + "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", + "Type": "Integer", + "Name": "max_leaf_nodes", + "Range": { + "IntegerParameterRangeSpecification": {"MinValue": "1", "MaxValue": "100000"} }, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", } ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) @@ -537,31 +533,31 @@ def test_algorithm_hyperparameter_integer_range_valid_range(session): estimator.set_hyperparameters(max_leaf_nodes=100000) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_hyperparameter_integer_range_invalid_range(session): hyperparameters = [ { - 'Description': 'Grow a tree with max_leaf_nodes in best-first fashion.', - 'Type': 'Integer', - 'Name': 'max_leaf_nodes', - 'Range': { - 'IntegerParameterRangeSpecification': {'MinValue': '1', 'MaxValue': '100000'} + "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", + "Type": "Integer", + "Name": "max_leaf_nodes", + "Range": { + "IntegerParameterRangeSpecification": {"MinValue": "1", "MaxValue": "100000"} }, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", 
} ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) @@ -573,31 +569,31 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(session): estimator.set_hyperparameters(max_leaf_nodes=100001) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_hyperparameter_continuous_range_valid_range(session): hyperparameters = [ { - 'Description': 'A continuous hyperparameter', - 'Type': 'Continuous', - 'Name': 'max_leaf_nodes', - 'Range': { - 'ContinuousParameterRangeSpecification': {'MinValue': '0.0', 'MaxValue': '1.0'} + "Description": "A continuous hyperparameter", + "Type": "Continuous", + "Name": "max_leaf_nodes", + "Range": { + "ContinuousParameterRangeSpecification": {"MinValue": "0.0", "MaxValue": "1.0"} }, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", } ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) @@ -608,31 +604,31 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(session): estimator.set_hyperparameters(max_leaf_nodes=1) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_hyperparameter_continuous_range_invalid_range(session): hyperparameters = [ { - 'Description': 'A continuous hyperparameter', - 'Type': 'Continuous', - 'Name': 'max_leaf_nodes', - 'Range': { - 'ContinuousParameterRangeSpecification': {'MinValue': '0.0', 'MaxValue': '1.0'} + "Description": "A continuous hyperparameter", + "Type": "Continuous", + "Name": "max_leaf_nodes", + "Range": { + "ContinuousParameterRangeSpecification": {"MinValue": "0.0", "MaxValue": "1.0"} }, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", } ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + 
algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) @@ -644,159 +640,159 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(session): estimator.set_hyperparameters(max_leaf_nodes=-0.1) -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_hyperparameter_categorical_range(session): hyperparameters = [ { - 'Description': 'A continuous hyperparameter', - 'Type': 'Categorical', - 'Name': 'hp1', - 'Range': {'CategoricalParameterRangeSpecification': {'Values': ['TF', 'MXNet']}}, - 'IsTunable': True, - 'IsRequired': False, - 'DefaultValue': '100', + "Description": "A continuous hyperparameter", + "Type": "Categorical", + "Name": "hp1", + "Range": {"CategoricalParameterRangeSpecification": {"Values": ["TF", "MXNet"]}}, + "IsTunable": True, + "IsRequired": False, + "DefaultValue": "100", } ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) - estimator.set_hyperparameters(hp1='MXNet') - estimator.set_hyperparameters(hp1='TF') + estimator.set_hyperparameters(hp1="MXNet") + estimator.set_hyperparameters(hp1="TF") with pytest.raises(ValueError): - estimator.set_hyperparameters(hp1='Chainer') + estimator.set_hyperparameters(hp1="Chainer") with pytest.raises(ValueError): - estimator.set_hyperparameters(hp1='MxNET') + estimator.set_hyperparameters(hp1="MxNET") -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_required_hyperparameters_not_provided(session): hyperparameters = [ { - 'Description': 'A continuous hyperparameter', - 'Type': 'Categorical', - 'Name': 'hp1', - 'Range': {'CategoricalParameterRangeSpecification': {'Values': ['TF', 'MXNet']}}, - 'IsTunable': True, - 'IsRequired': True, + "Description": "A continuous hyperparameter", + "Type": "Categorical", + "Name": "hp1", + "Range": {"CategoricalParameterRangeSpecification": {"Values": ["TF", "MXNet"]}}, + "IsTunable": True, + "IsRequired": True, }, { - 'Name': 'hp2', - 'Description': 'A continuous hyperparameter', - 'Type': 'Categorical', - 'IsTunable': False, - 'IsRequired': True - } + "Name": "hp2", + "Description": "A continuous hyperparameter", + "Type": "Categorical", + "IsTunable": False, + "IsRequired": True, + }, ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + 
train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) # hp1 is required and was not provided with pytest.raises(ValueError): - estimator.set_hyperparameters(hp2='TF2') + estimator.set_hyperparameters(hp2="TF2") # Calling fit with unset required hyperparameters should fail # this covers the use case of not calling set_hyperparameters() explicitly with pytest.raises(ValueError): - estimator.fit({'training': 's3://some/place'}) + estimator.fit({"training": "s3://some/place"}) -@patch('sagemaker.Session') -@patch('sagemaker.estimator.EstimatorBase.fit', Mock()) +@patch("sagemaker.Session") +@patch("sagemaker.estimator.EstimatorBase.fit", Mock()) def test_algorithm_required_hyperparameters_are_provided(session): hyperparameters = [ { - 'Description': 'A categorical hyperparameter', - 'Type': 'Categorical', - 'Name': 'hp1', - 'Range': {'CategoricalParameterRangeSpecification': {'Values': ['TF', 'MXNet']}}, - 'IsTunable': True, - 'IsRequired': True, + "Description": "A categorical hyperparameter", + "Type": "Categorical", + "Name": "hp1", + "Range": {"CategoricalParameterRangeSpecification": {"Values": ["TF", "MXNet"]}}, + "IsTunable": True, + "IsRequired": True, }, { - 'Name': 'hp2', - 'Description': 'A categorical hyperparameter', - 'Type': 'Categorical', - 'IsTunable': False, - 'IsRequired': True + "Name": "hp2", + "Description": "A categorical hyperparameter", + "Type": "Categorical", + "IsTunable": False, + "IsRequired": True, }, { - 'Name': 'free_text_hp1', - 'Description': 'You can write anything here', - 'Type': 'FreeText', - 'IsTunable': False, - 'IsRequired': True - } + "Name": "free_text_hp1", + "Description": "You can write anything here", + "Type": "FreeText", + "IsTunable": False, + "IsRequired": True, + }, ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) # All 3 Hyperparameters are provided - estimator.set_hyperparameters(hp1='TF', hp2='TF2', free_text_hp1='Hello!') + estimator.set_hyperparameters(hp1="TF", hp2="TF2", free_text_hp1="Hello!") -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_required_free_text_hyperparameter_not_provided(session): hyperparameters = [ { - 'Name': 'free_text_hp1', - 'Description': 'You can write anything here', - 'Type': 'FreeText', - 'IsTunable': False, - 'IsRequired': True + "Name": "free_text_hp1", + "Description": "You can write anything here", + "Type": "FreeText", + "IsTunable": False, + "IsRequired": True, }, { - 'Name': 'free_text_hp2', - 'Description': 'You can write anything here', - 'Type': 'FreeText', - 'IsTunable': False, - 'IsRequired': False - } + "Name": "free_text_hp2", + "Description": "You can write anything here", + "Type": "FreeText", + "IsTunable": False, + "IsRequired": False, + }, ] some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) - some_algo['TrainingSpecification']['SupportedHyperParameters'] = hyperparameters + 
some_algo["TrainingSpecification"]["SupportedHyperParameters"] = hyperparameters session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.2xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.2xlarge", train_instance_count=1, sagemaker_session=session, ) @@ -804,91 +800,87 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(session): # Calling fit with unset required hyperparameters should fail # this covers the use case of not calling set_hyperparameters() explicitly with pytest.raises(ValueError): - estimator.fit({'training': 's3://some/place'}) + estimator.fit({"training": "s3://some/place"}) # hp1 is required and was not provided with pytest.raises(ValueError): - estimator.set_hyperparameters(free_text_hp2='some text') + estimator.set_hyperparameters(free_text_hp2="some text") -@patch('sagemaker.Session') -@patch('sagemaker.algorithm.AlgorithmEstimator.create_model') +@patch("sagemaker.Session") +@patch("sagemaker.algorithm.AlgorithmEstimator.create_model") def test_algorithm_create_transformer(create_model, session): - session.sagemaker_client.describe_algorithm = Mock( - return_value=DESCRIBE_ALGORITHM_RESPONSE) + session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) - estimator.latest_training_job = _TrainingJob(session, 'some-job-name') + estimator.latest_training_job = _TrainingJob(session, "some-job-name") model = Mock() - model.name = 'my-model' + model.name = "my-model" create_model.return_value = model - transformer = estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge') + transformer = estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge") assert isinstance(transformer, Transformer) create_model.assert_called() - assert transformer.model_name == 'my-model' + assert transformer.model_name == "my-model" -@patch('sagemaker.Session') +@patch("sagemaker.Session") def test_algorithm_create_transformer_without_completed_training_job(session): - session.sagemaker_client.describe_algorithm = Mock( - return_value=DESCRIBE_ALGORITHM_RESPONSE) + session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - role='SageMakerRole', - train_instance_type='ml.m4.xlarge', + algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + role="SageMakerRole", + train_instance_type="ml.m4.xlarge", train_instance_count=1, sagemaker_session=session, ) with pytest.raises(RuntimeError) as error: - estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge') - assert 'No finished training job found associated with this estimator' in str(error) + estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge") + assert "No finished training job found associated with this estimator" in 
str(error)


-@patch('sagemaker.algorithm.AlgorithmEstimator.create_model')
-@patch('sagemaker.Session')
+@patch("sagemaker.algorithm.AlgorithmEstimator.create_model")
+@patch("sagemaker.Session")
 def test_algorithm_create_transformer_with_product_id(create_model, session):
     response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
-    response['ProductId'] = 'some-product-id'
-    session.sagemaker_client.describe_algorithm = Mock(
-        return_value=response)
+    response["ProductId"] = "some-product-id"
+    session.sagemaker_client.describe_algorithm = Mock(return_value=response)

     estimator = AlgorithmEstimator(
-        algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
-        role='SageMakerRole',
-        train_instance_type='ml.m4.xlarge',
+        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
+        role="SageMakerRole",
+        train_instance_type="ml.m4.xlarge",
         train_instance_count=1,
         sagemaker_session=session,
     )

-    estimator.latest_training_job = _TrainingJob(session, 'some-job-name')
+    estimator.latest_training_job = _TrainingJob(session, "some-job-name")
     model = Mock()
-    model.name = 'my-model'
+    model.name = "my-model"
     create_model.return_value = model

-    transformer = estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')
+    transformer = estimator.transformer(instance_count=1, instance_type="ml.m4.xlarge")
     assert transformer.env is None


-@patch('sagemaker.Session')
+@patch("sagemaker.Session")
 def test_algorithm_enable_network_isolation_no_product_id(session):
-    session.sagemaker_client.describe_algorithm = Mock(
-        return_value=DESCRIBE_ALGORITHM_RESPONSE)
+    session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE)

     estimator = AlgorithmEstimator(
-        algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
-        role='SageMakerRole',
-        train_instance_type='ml.m4.xlarge',
+        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
+        role="SageMakerRole",
+        train_instance_type="ml.m4.xlarge",
         train_instance_count=1,
         sagemaker_session=session,
     )
@@ -897,17 +889,16 @@ def test_algorithm_enable_network_isolation_no_product_id(session):
     assert network_isolation is False


-@patch('sagemaker.Session')
+@patch("sagemaker.Session")
 def test_algorithm_enable_network_isolation_with_product_id(session):
     response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
-    response['ProductId'] = 'some-product-id'
-    session.sagemaker_client.describe_algorithm = Mock(
-        return_value=response)
+    response["ProductId"] = "some-product-id"
+    session.sagemaker_client.describe_algorithm = Mock(return_value=response)

     estimator = AlgorithmEstimator(
-        algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
-        role='SageMakerRole',
-        train_instance_type='ml.m4.xlarge',
+        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
+        role="SageMakerRole",
+        train_instance_type="ml.m4.xlarge",
         train_instance_count=1,
         sagemaker_session=session,
     )
@@ -916,30 +907,29 @@ def test_algorithm_enable_network_isolation_with_product_id(session):
     assert network_isolation is True


-@patch('sagemaker.Session')
+@patch("sagemaker.Session")
 def test_algorithm_encrypt_inter_container_traffic(session):
     response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
-    response['encrypt_inter_container_traffic'] = True
-    session.sagemaker_client.describe_algorithm = Mock(
-        return_value=response)
+    response["encrypt_inter_container_traffic"] = True
+    session.sagemaker_client.describe_algorithm = Mock(return_value=response)

     estimator = AlgorithmEstimator(
-        algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
-        role='SageMakerRole',
-        train_instance_type='ml.m4.xlarge',
+        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
+        role="SageMakerRole",
+        train_instance_type="ml.m4.xlarge",
         train_instance_count=1,
         sagemaker_session=session,
-        encrypt_inter_container_traffic=True
+        encrypt_inter_container_traffic=True,
     )

     encrypt_inter_container_traffic = estimator.encrypt_inter_container_traffic

     assert encrypt_inter_container_traffic is True


-@patch('sagemaker.Session')
+@patch("sagemaker.Session")
 def test_algorithm_no_required_hyperparameters(session):
     some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE)
-    del some_algo['TrainingSpecification']['SupportedHyperParameters']
+    del some_algo["TrainingSpecification"]["SupportedHyperParameters"]

     session.sagemaker_client.describe_algorithm = Mock(return_value=some_algo)

@@ -947,9 +937,9 @@ def test_algorithm_no_required_hyperparameters(session):
     # should fail if they are required.
     # Pass training and hyperparameters channels. This should work
     assert AlgorithmEstimator(
-        algorithm_arn='arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees',
-        role='SageMakerRole',
-        train_instance_type='ml.m4.2xlarge',
+        algorithm_arn="arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
+        role="SageMakerRole",
+        train_instance_type="ml.m4.2xlarge",
         train_instance_count=1,
         sagemaker_session=session,
     )
diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py
index 168d90d3e7..43a7a71c8b 100644
--- a/tests/unit/test_amazon_estimator.py
+++ b/tests/unit/test_amazon_estimator.py
@@ -18,55 +18,75 @@
 # Use PCA as a test implementation of AmazonAlgorithmEstimator
 from sagemaker.amazon.pca import PCA
-from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry, get_image_uri
+from sagemaker.amazon.amazon_estimator import (
+    upload_numpy_to_s3_shards,
+    _build_shards,
+    registry,
+    get_image_uri,
+)

-COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'}
+COMMON_ARGS = {"role": "myrole", "train_instance_count": 1, "train_instance_type": "ml.c4.xlarge"}

 REGION = "us-west-2"
 BUCKET_NAME = "Some-Bucket"
-TIMESTAMP = '2017-11-06-14:14:15.671'
+TIMESTAMP = "2017-11-06-14:14:15.671"


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               region_name=REGION, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
     sms.boto_region_name = REGION
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    returned_job_description = {'AlgorithmSpecification': {'TrainingInputMode': 'File',
-                                                           'TrainingImage': registry("us-west-2") + "/pca:1"},
-                                'ModelArtifacts': {'S3ModelArtifacts': "s3://some-bucket/model.tar.gz"},
-                                'HyperParameters':
-                                    {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                     'checkpoint_path': '"s3://other/1508872349"',
-                                     'sagemaker_program': '"iris-dnn-classifier.py"',
-                                     'sagemaker_enable_cloudwatch_metrics': 'false',
-                                     'sagemaker_container_log_level': '"logging.INFO"',
-                                     'sagemaker_job_name': '"neo"',
-                                     'training_steps': '100'},
-                                'RoleArn': 'arn:aws:iam::366:role/IMRole',
-                                'ResourceConfig':
-                                    {'VolumeSizeInGB': 30,
-                                     'InstanceCount': 1,
-                                     'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=returned_job_description)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    returned_job_description = {
+        "AlgorithmSpecification": {
+            "TrainingInputMode": "File",
+            "TrainingImage": registry("us-west-2") + "/pca:1",
+        },
+        "ModelArtifacts": {"S3ModelArtifacts": "s3://some-bucket/model.tar.gz"},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "checkpoint_path": '"s3://other/1508872349"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "training_steps": "100",
+        },
+        "RoleArn": "arn:aws:iam::366:role/IMRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sms.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
     return sms


 def test_gov_ecr_uri():
-    assert get_image_uri('us-gov-west-1', 'kmeans', 'latest') == \
-        '226302683700.dkr.ecr.us-gov-west-1.amazonaws.com/kmeans:latest'
+    assert (
+        get_image_uri("us-gov-west-1", "kmeans", "latest")
+        == "226302683700.dkr.ecr.us-gov-west-1.amazonaws.com/kmeans:latest"
+    )

-    assert get_image_uri('us-iso-east-1', 'kmeans', 'latest') == \
-        '490574956308.dkr.ecr.us-iso-east-1.c2s.ic.gov/kmeans:latest'
+    assert (
+        get_image_uri("us-iso-east-1", "kmeans", "latest")
+        == "490574956308.dkr.ecr.us-iso-east-1.c2s.ic.gov/kmeans:latest"
+    )


 def test_init(sagemaker_session):
@@ -75,22 +95,32 @@ def test_init(sagemaker_session):


 def test_init_all_pca_hyperparameters(sagemaker_session):
-    pca = PCA(num_components=55, algorithm_mode='randomized',
-              subtract_mean=True, extra_components=33, sagemaker_session=sagemaker_session,
-              **COMMON_ARGS)
+    pca = PCA(
+        num_components=55,
+        algorithm_mode="randomized",
+        subtract_mean=True,
+        extra_components=33,
+        sagemaker_session=sagemaker_session,
+        **COMMON_ARGS
+    )
     assert pca.num_components == 55
-    assert pca.algorithm_mode == 'randomized'
+    assert pca.algorithm_mode == "randomized"
     assert pca.extra_components == 33


 def test_init_estimator_args(sagemaker_session):
-    pca = PCA(num_components=1, train_max_run=1234, sagemaker_session=sagemaker_session,
-              data_location='s3://some-bucket/some-key/', **COMMON_ARGS)
-    assert pca.train_instance_type == COMMON_ARGS['train_instance_type']
-    assert pca.train_instance_count == COMMON_ARGS['train_instance_count']
-    assert pca.role == COMMON_ARGS['role']
+    pca = PCA(
+        num_components=1,
+        train_max_run=1234,
+        sagemaker_session=sagemaker_session,
+        data_location="s3://some-bucket/some-key/",
+        **COMMON_ARGS
+    )
+    assert pca.train_instance_type == COMMON_ARGS["train_instance_type"]
+    assert pca.train_instance_count == COMMON_ARGS["train_instance_count"]
+    assert pca.role == COMMON_ARGS["role"]
     assert pca.train_max_run == 1234
-    assert pca.data_location == 's3://some-bucket/some-key/'
+    assert pca.data_location == "s3://some-bucket/some-key/"


 def test_data_location_validation(sagemaker_session):
@@ -101,7 +131,12 @@ def test_data_location_validation(sagemaker_session):


 def test_data_location_does_not_call_default_bucket(sagemaker_session):
     data_location = "s3://my-bucket/path/"
-    pca = PCA(num_components=2, sagemaker_session=sagemaker_session, data_location=data_location, **COMMON_ARGS)
+    pca = PCA(
+        num_components=2,
+        sagemaker_session=sagemaker_session,
+        data_location=data_location,
+        **COMMON_ARGS
+    )
     assert pca.data_location == data_location
     assert not sagemaker_session.default_bucket.called

@@ -135,12 +170,12 @@ def test_prepare_for_training_list_no_train_channel(sagemaker_session):
     train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
     labels = [99, 85, 87, 2]
-    records = [pca.record_set(np.array(train), np.array(labels), 'test')]
+    records = [pca.record_set(np.array(train), np.array(labels), "test")]

     with pytest.raises(ValueError) as ex:
         pca._prepare_for_training(records, mini_batch_size=1)

-    assert 'Must provide train channel.' in str(ex)
+    assert "Must provide train channel." in str(ex)


 def test_prepare_for_training_encrypt(sagemaker_session):
@@ -148,8 +183,9 @@ def test_prepare_for_training_encrypt(sagemaker_session):
     train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
     labels = [99, 85, 87, 2]

-    with patch('sagemaker.amazon.amazon_estimator.upload_numpy_to_s3_shards',
-               return_value='manfiest_file') as mock_upload:
+    with patch(
+        "sagemaker.amazon.amazon_estimator.upload_numpy_to_s3_shards", return_value="manfiest_file"
+    ) as mock_upload:
         pca.record_set(np.array(train), np.array(labels))
         pca.record_set(np.array(train), np.array(labels), encrypt=True)

@@ -159,27 +195,35 @@ def make_upload_call(encrypt):
     mock_upload.assert_has_calls([make_upload_call(False), make_upload_call(True)])


-@patch('time.strftime', return_value=TIMESTAMP)
+@patch("time.strftime", return_value=TIMESTAMP)
 def test_fit_ndarray(time, sagemaker_session):
     mock_s3 = Mock()
     mock_object = Mock()
     mock_s3.Object = Mock(return_value=mock_object)
     sagemaker_session.boto_session.resource = Mock(return_value=mock_s3)
     kwargs = dict(COMMON_ARGS)
-    kwargs['train_instance_count'] = 3
-    pca = PCA(num_components=55, sagemaker_session=sagemaker_session,
-              data_location='s3://{}/key-prefix/'.format(BUCKET_NAME), **kwargs)
+    kwargs["train_instance_count"] = 3
+    pca = PCA(
+        num_components=55,
+        sagemaker_session=sagemaker_session,
+        data_location="s3://{}/key-prefix/".format(BUCKET_NAME),
+        **kwargs
+    )
     train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
     labels = [99, 85, 87, 2]
     pca.fit(pca.record_set(np.array(train), np.array(labels)))
     mock_s3.Object.assert_any_call(
-        BUCKET_NAME, 'key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr'.format(TIMESTAMP))
+        BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_0.pbr".format(TIMESTAMP)
+    )
     mock_s3.Object.assert_any_call(
-        BUCKET_NAME, 'key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr'.format(TIMESTAMP))
+        BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_1.pbr".format(TIMESTAMP)
+    )
     mock_s3.Object.assert_any_call(
-        BUCKET_NAME, 'key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr'.format(TIMESTAMP))
+        BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/matrix_2.pbr".format(TIMESTAMP)
+    )
     mock_s3.Object.assert_any_call(
-        BUCKET_NAME, 'key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest'.format(TIMESTAMP))
+        BUCKET_NAME, "key-prefix/PCA-2017-11-06-14:14:15.671/.amazon.manifest".format(TIMESTAMP)
+    )

     assert mock_object.put.call_count == 4

@@ -211,11 +255,11 @@ def make_all_put_calls(**kwargs):
         return [call(Body=ANY, **kwargs) for i in range(num_objects)]

     upload_numpy_to_s3_shards(num_shards, mock_s3, BUCKET_NAME, "key-prefix", array, labels)
-    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_0.pbr')])
-    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_1.pbr')])
-    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, 'key-prefix/matrix_2.pbr')])
+    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, "key-prefix/matrix_0.pbr")])
+    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, "key-prefix/matrix_1.pbr")])
+    mock_s3.Object.assert_has_calls([call(BUCKET_NAME, "key-prefix/matrix_2.pbr")])
     mock_put.assert_has_calls(make_all_put_calls())

     mock_put.reset()
     upload_numpy_to_s3_shards(3, mock_s3, BUCKET_NAME, "key-prefix", array, labels, encrypt=True)
-    mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption='AES256'))
+    mock_put.assert_has_calls(make_all_put_calls(ServerSideEncryption="AES256"))
diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py
index df432a9dd5..65a43af9b9 100644
--- a/tests/unit/test_analytics.py
+++ b/tests/unit/test_analytics.py
@@ -19,10 +19,14 @@
 import pytest
 from mock import Mock

-from sagemaker.analytics import AnalyticsMetricsBase, HyperparameterTuningJobAnalytics, TrainingJobAnalytics
+from sagemaker.analytics import (
+    AnalyticsMetricsBase,
+    HyperparameterTuningJobAnalytics,
+    TrainingJobAnalytics,
+)

-BUCKET_NAME = 'mybucket'
-REGION = 'us-west-2'
+BUCKET_NAME = "mybucket"
+REGION = "us-west-2"


 @pytest.fixture()
@@ -30,43 +34,53 @@ def sagemaker_session():
     return create_sagemaker_session()


-def create_sagemaker_session(describe_training_result=None, list_training_results=None, metric_stats_results=None,
-                             describe_tuning_result=None):
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               boto_region_name=REGION, config=None, local_mode=False)
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_hyper_parameter_tuning_job',
-                                                                    return_value=describe_tuning_result)
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=describe_training_result)
+def create_sagemaker_session(
+    describe_training_result=None,
+    list_training_results=None,
+    metric_stats_results=None,
+    describe_tuning_result=None,
+):
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        boto_region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
+        name="describe_hyper_parameter_tuning_job", return_value=describe_tuning_result
+    )
+    sms.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=describe_training_result
+    )
     sms.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job = Mock(
-        name='list_training_jobs_for_hyper_parameter_tuning_job',
-        return_value=list_training_results,
+        name="list_training_jobs_for_hyper_parameter_tuning_job", return_value=list_training_results
     )
-    cwm_mock = Mock(name='cloudwatch_client')
+    cwm_mock = Mock(name="cloudwatch_client")
     boto_mock.client = Mock(return_value=cwm_mock)
-    cwm_mock.get_metric_statistics = Mock(
-        name='get_metric_statistics'
-    )
+    cwm_mock.get_metric_statistics = Mock(name="get_metric_statistics")
     cwm_mock.get_metric_statistics.side_effect = cw_request_side_effect
     return sms


-def cw_request_side_effect(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
+def cw_request_side_effect(
+    Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics
+):
     if _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
         return _metric_stats_results()


 def _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Period, Statistics):
     could_watch_request = {
-        'Namespace': Namespace,
-        'MetricName': MetricName,
-        'Dimensions': Dimensions,
-        'StartTime': StartTime,
-        'EndTime': EndTime,
-        'Period': Period,
-        'Statistics': Statistics,
+        "Namespace": Namespace,
+        "MetricName": MetricName,
+        "Dimensions": Dimensions,
+        "StartTime": StartTime,
+        "EndTime": EndTime,
+        "Period": Period,
+        "Statistics": Statistics,
     }
     print(could_watch_request)
     return could_watch_request == cw_request()
@@ -75,18 +89,13 @@ def _is_valid_request(Namespace, MetricName, Dimensions, StartTime, EndTime, Per
 def cw_request():
     describe_training_result = _describe_training_result()
     return {
-        'Namespace': '/aws/sagemaker/TrainingJobs',
-        'MetricName': 'train:acc',
-        'Dimensions': [
-            {
-                'Name': 'TrainingJobName',
-                'Value': 'my-training-job'
-            }
-        ],
-        'StartTime': describe_training_result['TrainingStartTime'],
-        'EndTime': describe_training_result['TrainingEndTime'] + datetime.timedelta(minutes=1),
-        'Period': 60,
-        'Statistics': ['Average'],
+        "Namespace": "/aws/sagemaker/TrainingJobs",
+        "MetricName": "train:acc",
+        "Dimensions": [{"Name": "TrainingJobName", "Value": "my-training-job"}],
+        "StartTime": describe_training_result["TrainingStartTime"],
+        "EndTime": describe_training_result["TrainingEndTime"] + datetime.timedelta(minutes=1),
+        "Period": 60,
+        "Statistics": ["Average"],
     }


@@ -107,46 +116,52 @@ def mock_summary(name="job-name", value=0.9):
         return {
             "TrainingJobName": name,
             "TrainingJobStatus": "Completed",
-            "FinalHyperParameterTuningJobObjectiveMetric": {
-                "Name": "awesomeness",
-                "Value": value,
-            },
+            "FinalHyperParameterTuningJobObjectiveMetric": {"Name": "awesomeness", "Value": value},
             "TrainingStartTime": datetime.datetime(2018, 5, 16, 1, 2, 3),
             "TrainingEndTime": datetime.datetime(2018, 5, 16, 5, 6, 7),
-            "TunedHyperParameters": {
-                "learning_rate": 0.1,
-                "layers": 137,
-            },
+            "TunedHyperParameters": {"learning_rate": 0.1, "layers": 137},
         }
-    session = create_sagemaker_session(list_training_results={
-        "TrainingJobSummaries": [
-            mock_summary(),
-            mock_summary(),
-            mock_summary(),
-            mock_summary(),
-            mock_summary(),
-        ]
-    })
+
+    session = create_sagemaker_session(
+        list_training_results={
+            "TrainingJobSummaries": [
+                mock_summary(),
+                mock_summary(),
+                mock_summary(),
+                mock_summary(),
+                mock_summary(),
+            ]
+        }
+    )
     tuner = HyperparameterTuningJobAnalytics("my-tuning-job", sagemaker_session=session)
     df = tuner.dataframe()
     assert df is not None
     assert len(df) == 5
-    assert len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls) == 1
+    assert (
+        len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls)
+        == 1
+    )

     # Clear the cache, check that it calls the service again.
     tuner.clear_cache()
     df = tuner.dataframe()
-    assert len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls) == 2
+    assert (
+        len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls)
+        == 2
+    )
     df = tuner.dataframe(force_refresh=True)
-    assert len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls) == 3
+    assert (
+        len(session.sagemaker_client.list_training_jobs_for_hyper_parameter_tuning_job.mock_calls)
+        == 3
+    )

     # check that the hyperparameter is in the dataframe
-    assert len(df['layers']) == 5
-    assert min(df['layers']) == 137
+    assert len(df["layers"]) == 5
+    assert min(df["layers"]) == 137

     # Check that the training time calculation is returning something sane.
-    assert min(df['TrainingElapsedTimeSeconds']) > 5
-    assert max(df['TrainingElapsedTimeSeconds']) < 86400
+    assert min(df["TrainingElapsedTimeSeconds"]) > 5
+    assert max(df["TrainingElapsedTimeSeconds"]) < 86400

     # Export to CSV and check that file exists
     tmp_name = "/tmp/unit-test-%s.csv" % uuid.uuid4()
@@ -157,27 +172,29 @@


 def test_description():
-    session = create_sagemaker_session(describe_tuning_result={
-        'HyperParameterTuningJobConfig': {
-            'ParameterRanges': {
-                'CategoricalParameterRanges': [],
-                'ContinuousParameterRanges': [
-                    {'MaxValue': '1', 'MinValue': '0', 'Name': 'eta'},
-                    {'MaxValue': '10', 'MinValue': '0', 'Name': 'gamma'},
-                ],
-                'IntegerParameterRanges': [
-                    {'MaxValue': '30', 'MinValue': '5', 'Name': 'num_layers'},
-                    {'MaxValue': '100', 'MinValue': '50', 'Name': 'iterations'},
-                ],
-            },
-        },
-    })
+    session = create_sagemaker_session(
+        describe_tuning_result={
+            "HyperParameterTuningJobConfig": {
+                "ParameterRanges": {
+                    "CategoricalParameterRanges": [],
+                    "ContinuousParameterRanges": [
+                        {"MaxValue": "1", "MinValue": "0", "Name": "eta"},
+                        {"MaxValue": "10", "MinValue": "0", "Name": "gamma"},
+                    ],
+                    "IntegerParameterRanges": [
+                        {"MaxValue": "30", "MinValue": "5", "Name": "num_layers"},
+                        {"MaxValue": "100", "MinValue": "50", "Name": "iterations"},
+                    ],
+                }
+            }
+        }
+    )
     tuner = HyperparameterTuningJobAnalytics("my-tuning-job", sagemaker_session=session)

     d = tuner.description()
     assert len(session.sagemaker_client.describe_hyper_parameter_tuning_job.mock_calls) == 1
     assert d is not None
-    assert d['HyperParameterTuningJobConfig'] is not None
+    assert d["HyperParameterTuningJobConfig"] is not None
     tuner.clear_cache()
     d = tuner.description()
     assert len(session.sagemaker_client.describe_hyper_parameter_tuning_job.mock_calls) == 2
@@ -193,8 +210,8 @@ def test_description():


 def test_trainer_name():
     describe_training_result = {
-        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
-        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
+        "TrainingStartTime": datetime.datetime(2018, 5, 16, 1, 2, 3),
+        "TrainingEndTime": datetime.datetime(2018, 5, 16, 5, 6, 7),
     }
     session = create_sagemaker_session(describe_training_result)
     trainer = TrainingJobAnalytics("my-training-job", ["metric"], sagemaker_session=session)
@@ -204,40 +221,33 @@ def test_trainer_name():


 def _describe_training_result():
     return {
-        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
-        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
+        "TrainingStartTime": datetime.datetime(2018, 5, 16, 1, 2, 3),
+        "TrainingEndTime": datetime.datetime(2018, 5, 16, 5, 6, 7),
     }


 def _metric_stats_results():
     return {
-        'Datapoints': [
-            {
-                'Average': 77.1,
-                'Timestamp': datetime.datetime(2018, 5, 16, 1, 3, 3),
-            },
-            {
-                'Average': 87.1,
-                'Timestamp': datetime.datetime(2018, 5, 16, 1, 8, 3),
-            },
-            {
-                'Average': 97.1,
-                'Timestamp': datetime.datetime(2018, 5, 16, 2, 3, 3),
-            },
+        "Datapoints": [
+            {"Average": 77.1, "Timestamp": datetime.datetime(2018, 5, 16, 1, 3, 3)},
+            {"Average": 87.1, "Timestamp": datetime.datetime(2018, 5, 16, 1, 8, 3)},
+            {"Average": 97.1, "Timestamp": datetime.datetime(2018, 5, 16, 2, 3, 3)},
         ]
     }


 def test_trainer_dataframe():
-    session = create_sagemaker_session(describe_training_result=_describe_training_result(),
-                                       metric_stats_results=_metric_stats_results())
+    session = create_sagemaker_session(
+        describe_training_result=_describe_training_result(),
+        metric_stats_results=_metric_stats_results(),
+    )
     trainer = TrainingJobAnalytics("my-training-job", ["train:acc"], sagemaker_session=session)

     df = trainer.dataframe()
     assert df is not None
     assert len(df) == 3
-    assert min(df['value']) == 77.1
-    assert max(df['value']) == 97.1
+    assert min(df["value"]) == 77.1
+    assert max(df["value"]) == 97.1

     # Export to CSV and check that file exists
     tmp_name = "/tmp/unit-test-%s.csv" % uuid.uuid4()
@@ -249,16 +259,22 @@ def test_trainer_dataframe():


 def test_start_time_end_time_and_period_specified():
     describe_training_result = {
-        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
-        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
+        "TrainingStartTime": datetime.datetime(2018, 5, 16, 1, 2, 3),
+        "TrainingEndTime": datetime.datetime(2018, 5, 16, 5, 6, 7),
     }
     session = create_sagemaker_session(describe_training_result)
     start_time = datetime.datetime(2018, 5, 16, 1, 3, 4)
     end_time = datetime.datetime(2018, 5, 16, 5, 1, 1)
     period = 300
-    trainer = TrainingJobAnalytics('my-training-job', ['metric'],
-                                   sagemaker_session=session, start_time=start_time, end_time=end_time, period=period)
+    trainer = TrainingJobAnalytics(
+        "my-training-job",
+        ["metric"],
+        sagemaker_session=session,
+        start_time=start_time,
+        end_time=end_time,
+        period=period,
+    )

-    assert trainer._time_interval['start_time'] == start_time
-    assert trainer._time_interval['end_time'] == end_time
+    assert trainer._time_interval["start_time"] == start_time
+    assert trainer._time_interval["end_time"] == end_time
     assert trainer._period == period
diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py
index 7e548ac3a6..8357f4c6b4 100644
--- a/tests/unit/test_chainer.py
+++ b/tests/unit/test_chainer.py
@@ -27,218 +27,233 @@
 from sagemaker.chainer import ChainerPredictor, ChainerModel

-DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
-SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py')
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
 MODEL_DATA = "s3://some/data.tar.gz"
-TIMESTAMP = '2017-11-06-14:14:15.672'
+TIMESTAMP = "2017-11-06-14:14:15.672"
 TIME = 1507167947
-BUCKET_NAME = 'mybucket'
+BUCKET_NAME = "mybucket"
 INSTANCE_COUNT = 1
-INSTANCE_TYPE = 'ml.c4.4xlarge'
-ACCELERATOR_TYPE = 'ml.eia.medium'
-PYTHON_VERSION = 'py' + str(sys.version_info.major)
-IMAGE_NAME = 'sagemaker-chainer'
-JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
+INSTANCE_TYPE = "ml.c4.4xlarge"
+ACCELERATOR_TYPE = "ml.eia.medium"
+PYTHON_VERSION = "py" + str(sys.version_info.major)
+IMAGE_NAME = "sagemaker-chainer"
+JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP)
 IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}"
-ROLE = 'Dummy'
-REGION = 'us-west-2'
-GPU = 'ml.p2.xlarge'
-CPU = 'ml.c4.xlarge'
+ROLE = "Dummy"
+REGION = "us-west-2"
+GPU = "ml.p2.xlarge"
+CPU = "ml.c4.xlarge"

-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}

-LIST_TAGS_RESULT = {
-    'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
-}
+LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    session = Mock(name='sagemaker_session', boto_session=boto_mock,
-                   boto_region_name=REGION, config=None, local_mode=False)
-
-    describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    session = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        boto_region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
+
+    describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
     session.sagemaker_client.describe_training_job = Mock(return_value=describe)
     session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
     session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
-    session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
+    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     session.expand_role = Mock(name="expand_role", return_value=ROLE)
     return session


 def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION):
-    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'cpu', py_version)
+    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", py_version)


 def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION):
-    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'gpu', py_version)
+    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "gpu", py_version)


 def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
-    return _get_full_cpu_image_uri(version, py_version=py_version) + '-eia'
-
-
-def _chainer_estimator(sagemaker_session, framework_version=defaults.CHAINER_VERSION, train_instance_type=None,
-                       base_job_name=None, use_mpi=None, num_processes=None,
-                       process_slots_per_host=None, additional_mpi_options=None, **kwargs):
-    return Chainer(entry_point=SCRIPT_PATH,
-                   framework_version=framework_version,
-                   role=ROLE,
-                   sagemaker_session=sagemaker_session,
-                   train_instance_count=INSTANCE_COUNT,
-                   train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
-                   base_job_name=base_job_name,
-                   use_mpi=use_mpi,
-                   num_processes=num_processes,
-                   process_slots_per_host=process_slots_per_host,
-                   additional_mpi_options=additional_mpi_options,
-                   py_version=PYTHON_VERSION,
-                   **kwargs)
+    return _get_full_cpu_image_uri(version, py_version=py_version) + "-eia"
+
+
+def _chainer_estimator(
+    sagemaker_session,
+    framework_version=defaults.CHAINER_VERSION,
+    train_instance_type=None,
+    base_job_name=None,
+    use_mpi=None,
+    num_processes=None,
+    process_slots_per_host=None,
+    additional_mpi_options=None,
+    **kwargs
+):
+    return Chainer(
+        entry_point=SCRIPT_PATH,
+        framework_version=framework_version,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
+        base_job_name=base_job_name,
+        use_mpi=use_mpi,
+        num_processes=num_processes,
+        process_slots_per_host=process_slots_per_host,
+        additional_mpi_options=additional_mpi_options,
+        py_version=PYTHON_VERSION,
+        **kwargs
+    )


 def _create_train_job(version):
     return {
-        'image': _get_full_cpu_image_uri(version),
-        'input_mode': 'File',
-        'input_config': [
+        "image": _get_full_cpu_image_uri(version),
+        "input_mode": "File",
+        "input_config": [
             {
-                'ChannelName': 'training',
-                'DataSource': {
-                    'S3DataSource': {
-                        'S3DataDistributionType': 'FullyReplicated',
-                        'S3DataType': 'S3Prefix'
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataDistributionType": "FullyReplicated",
+                        "S3DataType": "S3Prefix",
                     }
-                }
+                },
             }
         ],
-        'role': ROLE,
-        'job_name': JOB_NAME,
-        'output_config': {
-            'S3OutputPath': 's3://{}/'.format(BUCKET_NAME),
-        },
-        'resource_config': {
-            'InstanceType': 'ml.c4.4xlarge',
-            'InstanceCount': 1,
-            'VolumeSizeInGB': 30,
-        },
-        'hyperparameters': {
-            'sagemaker_program': json.dumps('dummy_script.py'),
-            'sagemaker_enable_cloudwatch_metrics': 'false',
-            'sagemaker_container_log_level': str(logging.INFO),
-            'sagemaker_job_name': json.dumps(JOB_NAME),
-            'sagemaker_submit_directory':
-                json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)),
-            'sagemaker_region': '"us-west-2"'
+        "role": ROLE,
+        "job_name": JOB_NAME,
+        "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+        "resource_config": {
+            "InstanceType": "ml.c4.4xlarge",
+            "InstanceCount": 1,
+            "VolumeSizeInGB": 30,
         },
-        'stop_condition': {
-            'MaxRuntimeInSeconds': 24 * 60 * 60
+        "hyperparameters": {
+            "sagemaker_program": json.dumps("dummy_script.py"),
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": str(logging.INFO),
+            "sagemaker_job_name": json.dumps(JOB_NAME),
+            "sagemaker_submit_directory": json.dumps(
+                "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME)
+            ),
+            "sagemaker_region": '"us-west-2"',
         },
-        'tags': None,
-        'vpc_config': None,
-        'metric_definitions': None
+        "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "tags": None,
+        "vpc_config": None,
+        "metric_definitions": None,
     }


 def _create_train_job_with_additional_hyperparameters(version):
     return {
-        'image': _get_full_cpu_image_uri(version),
-        'input_mode': 'File',
-        'input_config': [
+        "image": _get_full_cpu_image_uri(version),
+        "input_mode": "File",
+        "input_config": [
             {
-                'ChannelName': 'training',
-                'DataSource': {
-                    'S3DataSource': {
-                        'S3DataDistributionType': 'FullyReplicated',
-                        'S3DataType': 'S3Prefix'
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataDistributionType": "FullyReplicated",
+                        "S3DataType": "S3Prefix",
                     }
-                }
+                },
             }
         ],
-        'role': ROLE,
-        'job_name': JOB_NAME,
-        'output_config': {
-            'S3OutputPath': 's3://{}/'.format(BUCKET_NAME),
+        "role": ROLE,
+        "job_name": JOB_NAME,
+        "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+        "resource_config": {
+            "InstanceType": "ml.c4.4xlarge",
+            "InstanceCount": 1,
+            "VolumeSizeInGB": 30,
         },
-        'resource_config': {
-            'InstanceType': 'ml.c4.4xlarge',
-            'InstanceCount': 1,
-            'VolumeSizeInGB': 30,
+        "hyperparameters": {
+            "sagemaker_program": json.dumps("dummy_script.py"),
+            "sagemaker_container_log_level": str(logging.INFO),
+            "sagemaker_job_name": json.dumps(JOB_NAME),
+            "sagemaker_submit_directory": json.dumps(
+                "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME)
+            ),
+            "sagemaker_region": '"us-west-2"',
+            "sagemaker_num_processes": "4",
+            "sagemaker_additional_mpi_options": '"-x MY_ENVIRONMENT_VARIABLE"',
+            "sagemaker_process_slots_per_host": "10",
+            "sagemaker_use_mpi": "true",
         },
-        'hyperparameters': {
-            'sagemaker_program': json.dumps('dummy_script.py'),
-            'sagemaker_container_log_level': str(logging.INFO),
-            'sagemaker_job_name': json.dumps(JOB_NAME),
-            'sagemaker_submit_directory':
-                json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)),
-            'sagemaker_region': '"us-west-2"',
-            'sagemaker_num_processes': '4',
-            'sagemaker_additional_mpi_options': '"-x MY_ENVIRONMENT_VARIABLE"',
-            'sagemaker_process_slots_per_host': '10',
-            'sagemaker_use_mpi': 'true'
-        },
-        'stop_condition': {
-            'MaxRuntimeInSeconds': 24 * 60 * 60
-        }
+        "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
     }


 def test_additional_hyperparameters(sagemaker_session):
-    chainer = _chainer_estimator(sagemaker_session, use_mpi=True, num_processes=4, process_slots_per_host=10,
-                                 additional_mpi_options="-x MY_ENVIRONMENT_VARIABLE")
-    assert bool(strtobool(chainer.hyperparameters()['sagemaker_use_mpi']))
-    assert int(chainer.hyperparameters()['sagemaker_num_processes']) == 4
-    assert int(chainer.hyperparameters()['sagemaker_process_slots_per_host']) == 10
-    assert str(chainer.hyperparameters()['sagemaker_additional_mpi_options']) == '\"-x MY_ENVIRONMENT_VARIABLE\"'
+    chainer = _chainer_estimator(
+        sagemaker_session,
+        use_mpi=True,
+        num_processes=4,
+        process_slots_per_host=10,
+        additional_mpi_options="-x MY_ENVIRONMENT_VARIABLE",
+    )
+    assert bool(strtobool(chainer.hyperparameters()["sagemaker_use_mpi"]))
+    assert int(chainer.hyperparameters()["sagemaker_num_processes"]) == 4
+    assert int(chainer.hyperparameters()["sagemaker_process_slots_per_host"]) == 10
+    assert (
+        str(chainer.hyperparameters()["sagemaker_additional_mpi_options"])
+        == '"-x MY_ENVIRONMENT_VARIABLE"'
+    )


 def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_version):
-    training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}'.format(chainer_version,
-                                                                                            PYTHON_VERSION)
-    returned_job_description = {'AlgorithmSpecification':
-                                {'TrainingInputMode': 'File',
-                                 'TrainingImage': training_image},
-                                'HyperParameters':
-                                {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                 'sagemaker_program': '"iris-dnn-classifier.py"',
-                                 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"',
-                                 'sagemaker_enable_cloudwatch_metrics': 'false',
-                                 'sagemaker_container_log_level': '"logging.INFO"',
-                                 'sagemaker_job_name': '"neo"',
-                                 'sagemaker_region': '"us-west-2"',
-                                 'sagemaker_num_processes': '4',
-                                 'sagemaker_additional_mpi_options': '"-x MY_ENVIRONMENT_VARIABLE"',
-                                 'sagemaker_process_slots_per_host': '10',
-                                 'sagemaker_use_mpi': 'true'
-                                 },
-                                'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-                                'ResourceConfig':
-                                {'VolumeSizeInGB': 30,
-                                 'InstanceCount': 1,
-                                 'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                                    return_value=returned_job_description)
-
-    estimator = Chainer.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
-    assert bool(estimator.hyperparameters()['sagemaker_use_mpi'])
-    assert int(estimator.hyperparameters()['sagemaker_num_processes']) == 4
-    assert int(estimator.hyperparameters()['sagemaker_process_slots_per_host']) == 10
-    assert str(estimator.hyperparameters()['sagemaker_additional_mpi_options']) == '\"-x MY_ENVIRONMENT_VARIABLE\"'
+    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}".format(
+        chainer_version, PYTHON_VERSION
+    )
+    returned_job_description = {
+        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "sagemaker_region": '"us-west-2"',
+            "sagemaker_num_processes": "4",
+            "sagemaker_additional_mpi_options": '"-x MY_ENVIRONMENT_VARIABLE"',
+            "sagemaker_process_slots_per_host": "10",
+            "sagemaker_use_mpi": "true",
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
+
+    estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
+    assert bool(estimator.hyperparameters()["sagemaker_use_mpi"])
+    assert int(estimator.hyperparameters()["sagemaker_num_processes"]) == 4
+    assert int(estimator.hyperparameters()["sagemaker_process_slots_per_host"]) == 10
+    assert (
+        str(estimator.hyperparameters()["sagemaker_additional_mpi_options"])
+        == '"-x MY_ENVIRONMENT_VARIABLE"'
+    )
     assert estimator.use_mpi
     assert estimator.num_processes == 4
     assert estimator.process_slots_per_host == 10
@@ -247,14 +262,22 @@ def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_versi

 def test_create_model(sagemaker_session, chainer_version):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                      framework_version=chainer_version, container_log_level=container_log_level,
-                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)
-
-    job_name = 'new_name'
-    chainer.fit(inputs='s3://mybucket/train', job_name=job_name)
+    source_dir = "s3://mybucket/source"
+    chainer = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        framework_version=chainer_version,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+    )
+
+    job_name = "new_name"
+    chainer.fit(inputs="s3://mybucket/train", job_name=job_name)
     model = chainer.create_model()

     assert model.sagemaker_session == sagemaker_session
@@ -270,20 +293,29 @@ def test_create_model(sagemaker_session, chainer_version):

 def test_create_model_with_optional_params(sagemaker_session):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
-    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                      container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name='job',
-                      source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics)
-
-    chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
-
-    new_role = 'role'
+    source_dir = "s3://mybucket/source"
+    enable_cloudwatch_metrics = "true"
+    chainer = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
+    )
+
+    chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
+
+    new_role = "role"
     model_server_workers = 2
-    vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
-    model = chainer.create_model(role=new_role, model_server_workers=model_server_workers,
-                                 vpc_config_override=vpc_config)
+    vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
+    model = chainer.create_model(
+        role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
+    )

     assert model.role == new_role
     assert model.model_server_workers == model_server_workers
@@ -292,217 +324,270 @@ def test_create_model_with_optional_params(sagemaker_session):

 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    custom_image = 'ubuntu:latest'
-    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                      image_name=custom_image, container_log_level=container_log_level,
-                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir)
-
-    chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
+    source_dir = "s3://mybucket/source"
+    custom_image = "ubuntu:latest"
+    chainer = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        image_name=custom_image,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+    )
+
+    chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
     model = chainer.create_model()

     assert model.image == custom_image


-@patch('sagemaker.utils.create_tar_file', MagicMock())
-@patch('time.strftime', return_value=TIMESTAMP)
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("time.strftime", return_value=TIMESTAMP)
 def test_chainer(strftime, sagemaker_session, chainer_version):
-    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION,
-                      framework_version=chainer_version)
-
-    inputs = 's3://mybucket/train'
+    chainer = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        py_version=PYTHON_VERSION,
+        framework_version=chainer_version,
+    )
+
+    inputs = "s3://mybucket/train"

     chainer.fit(inputs=inputs)

     sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
-    assert sagemaker_call_names == ['train', 'logs_for_job']
+    assert sagemaker_call_names == ["train", "logs_for_job"]

     boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
-    assert boto_call_names == ['resource']
+    assert boto_call_names == ["resource"]

     expected_train_args = _create_train_job(chainer_version)
-    expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs

     actual_train_args = sagemaker_session.method_calls[0][2]
     assert actual_train_args == expected_train_args

     model = chainer.create_model()

-    expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}'
-    assert {'Environment':
-            {'SAGEMAKER_SUBMIT_DIRECTORY':
-             's3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
-             'SAGEMAKER_PROGRAM': 'dummy_script.py',
-             'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
-             'SAGEMAKER_REGION': 'us-west-2',
-             'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'},
-            'Image': expected_image_base.format(chainer_version, PYTHON_VERSION),
-            'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU)
-
-    assert 'cpu' in model.prepare_container_def(CPU)['Image']
+    expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-gpu-{}"
+    assert {
+        "Environment": {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-chainer-{}/source/sourcedir.tar.gz".format(
+                TIMESTAMP
+            ),
+            "SAGEMAKER_PROGRAM": "dummy_script.py",
+            "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
+            "SAGEMAKER_REGION": "us-west-2",
+            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
+        },
+        "Image": expected_image_base.format(chainer_version, PYTHON_VERSION),
+        "ModelDataUrl": "s3://m/m.tar.gz",
+    } == model.prepare_container_def(GPU)
+
+    assert "cpu" in model.prepare_container_def(CPU)["Image"]

     predictor = chainer.deploy(1, GPU)
     assert isinstance(predictor, ChainerPredictor)


-@patch('sagemaker.utils.create_tar_file', MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
 def test_model(sagemaker_session):
-    model = ChainerModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
-                         sagemaker_session=sagemaker_session)
+    model = ChainerModel(
+        "s3://some/data.tar.gz",
+        role=ROLE,
+        entry_point=SCRIPT_PATH,
+        sagemaker_session=sagemaker_session,
+    )
     predictor = model.deploy(1, GPU)
     assert isinstance(predictor, ChainerPredictor)


-@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
+@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
 def test_model_prepare_container_def_accelerator_error(sagemaker_session):
-    model = ChainerModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH,
-                         sagemaker_session=sagemaker_session)
+    model = ChainerModel(
+        MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
+    )
     with pytest.raises(ValueError):
         model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)


 def test_train_image_default(sagemaker_session):
-    chainer = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION)
+    chainer = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        py_version=PYTHON_VERSION,
+    )

     assert _get_full_cpu_image_uri(defaults.CHAINER_VERSION) in chainer.train_image()


 def test_train_image_cpu_instances(sagemaker_session, chainer_version):
-    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type='ml.c2.2xlarge')
+    chainer = _chainer_estimator(
+        sagemaker_session, chainer_version, train_instance_type="ml.c2.2xlarge"
+    )
     assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)

-    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type='ml.c4.2xlarge')
+    chainer = _chainer_estimator(
+        sagemaker_session, chainer_version, train_instance_type="ml.c4.2xlarge"
+    )
     assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)

-    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type='ml.m16')
+    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type="ml.m16")
     assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)


 def test_train_image_gpu_instances(sagemaker_session, chainer_version):
-    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type='ml.g2.2xlarge')
+    chainer = _chainer_estimator(
+        sagemaker_session, chainer_version, train_instance_type="ml.g2.2xlarge"
+    )
     assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)

-    chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type='ml.p2.2xlarge')
+    chainer = _chainer_estimator(
+        sagemaker_session, chainer_version, train_instance_type="ml.p2.2xlarge"
+    )
     assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)


 def test_attach(sagemaker_session, chainer_version):
-    training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}'.format(chainer_version,
-                                                                                            PYTHON_VERSION)
-    returned_job_description = {'AlgorithmSpecification':
-                                {'TrainingInputMode': 'File',
-                                 'TrainingImage': training_image},
-                                'HyperParameters':
-                                {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                 'sagemaker_program': '"iris-dnn-classifier.py"',
-                                 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"',
-                                 'sagemaker_enable_cloudwatch_metrics': 'false',
-                                 'sagemaker_container_log_level': '"logging.INFO"',
-                                 'sagemaker_job_name': '"neo"',
-                                 'training_steps': '100',
-                                 'sagemaker_region': '"us-west-2"'},
-                                'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-                                'ResourceConfig':
-                                {'VolumeSizeInGB': 30,
-                                 'InstanceCount': 1,
-                                 'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                                    return_value=returned_job_description)
-
-    estimator = Chainer.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
-    assert estimator.latest_training_job.job_name == 'neo'
+    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}".format(
+        chainer_version, PYTHON_VERSION
+    )
+    returned_job_description = {
+        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
+
+    estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
+    assert estimator.latest_training_job.job_name == "neo"
     assert estimator.py_version == PYTHON_VERSION
     assert estimator.framework_version == chainer_version
-    assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
+    assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
     assert estimator.train_instance_count == 1
     assert estimator.train_max_run == 24 * 60 * 60
-    assert estimator.input_mode == 'File'
-    assert estimator.base_job_name == 'neo'
-    assert estimator.output_path == 's3://place/output/neo'
-    assert estimator.output_kms_key == ''
-    assert estimator.hyperparameters()['training_steps'] == '100'
-    assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
-    assert estimator.entry_point == 'iris-dnn-classifier.py'
+    assert estimator.input_mode == "File"
+    assert estimator.base_job_name == "neo"
+    assert estimator.output_path == "s3://place/output/neo"
+    assert estimator.output_kms_key == ""
+    assert estimator.hyperparameters()["training_steps"] == "100"
+    assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
+    assert estimator.entry_point == "iris-dnn-classifier.py"


 def test_attach_wrong_framework(sagemaker_session):
-    rjd = {'AlgorithmSpecification':
-           {'TrainingInputMode': 'File',
-            'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4'},
-           'HyperParameters':
-           {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-            'checkpoint_path': '"s3://other/1508872349"',
-            'sagemaker_program': '"iris-dnn-classifier.py"',
-            'sagemaker_enable_cloudwatch_metrics': 'false',
-            'sagemaker_container_log_level': '"logging.INFO"',
-            'training_steps': '100',
-            'sagemaker_region': '"us-west-2"'},
-           'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-           'ResourceConfig':
-           {'VolumeSizeInGB': 30,
-            'InstanceCount': 1,
-            'InstanceType': 'ml.c4.xlarge'},
-           'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-           'TrainingJobName': 'neo',
-           'TrainingJobStatus': 'Completed',
-           'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-           'OutputDataConfig': {'KmsKeyId': '',
-                                'S3OutputPath': 's3://place/output/neo'},
-           'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
+    rjd = {
+        "AlgorithmSpecification": {
+            "TrainingInputMode": "File",
+            "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4",
+        },
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "checkpoint_path": '"s3://other/1508872349"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=rjd
+    )

     with pytest.raises(ValueError) as error:
-        Chainer.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
+        Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
     assert "didn't use image for requested framework" in str(error)


 def test_attach_custom_image(sagemaker_session):
-    training_image = '1.dkr.ecr.us-west-2.amazonaws.com/my_custom_chainer_image:latest'
-    returned_job_description = {'AlgorithmSpecification':
-                                {'TrainingInputMode': 'File',
-                                 'TrainingImage': training_image},
-                                'HyperParameters':
-                                {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                 'sagemaker_program': '"iris-dnn-classifier.py"',
-                                 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"',
-                                 'sagemaker_enable_cloudwatch_metrics': 'false',
-                                 'sagemaker_container_log_level': '"logging.INFO"',
-                                 'sagemaker_job_name': '"neo"',
-                                 'training_steps': '100',
-                                 'sagemaker_region': '"us-west-2"'},
-                                'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-                                'ResourceConfig':
-                                {'VolumeSizeInGB': 30,
-                                 'InstanceCount': 1,
-                                 'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                                    return_value=returned_job_description)
-
-    estimator = Chainer.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
+    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/my_custom_chainer_image:latest"
+    returned_job_description = {
+        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
+
+    estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
     assert estimator.image_name == training_image
     assert estimator.train_image() == training_image


-@patch('sagemaker.chainer.estimator.empty_framework_version_warning')
+@patch("sagemaker.chainer.estimator.empty_framework_version_warning")
 def test_empty_framework_version(warning, sagemaker_session):
-    estimator = Chainer(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                        train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                        framework_version=None)
+    estimator = Chainer(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        framework_version=None,
+    )

     assert estimator.framework_version == defaults.CHAINER_VERSION
     warning.assert_called_with(defaults.CHAINER_VERSION, Chainer.LATEST_VERSION)
diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py
index 36b5eacbdb..cfad410796 100644
--- a/tests/unit/test_cli.py
+++ b/tests/unit/test_cli.py
@@ -16,47 +16,49 @@
 import sagemaker.cli.main as cli
 from mock import patch

-COMMON_ARGS = '--role-name myrole --data mydata --script myscript --job-name myjob --bucket-name mybucket ' + \
-              '--python py3 --instance-type myinstance --instance-count 2'
+COMMON_ARGS = (
+    "--role-name myrole --data mydata --script myscript --job-name myjob --bucket-name mybucket "
+    + "--python py3 --instance-type myinstance --instance-count 2"
+)

-TRAIN_ARGS = '--hyperparameters myhyperparameters.json'
+TRAIN_ARGS = "--hyperparameters myhyperparameters.json"

-LOG_ARGS = '--log-level debug --botocore-log-level debug'
+LOG_ARGS = "--log-level debug --botocore-log-level debug"

-HOST_ARGS = '--env ENV1=env1 ENV2=env2'
+HOST_ARGS = "--env ENV1=env1 ENV2=env2"


 def assert_common_defaults(args):
-    assert args.data == './data'
-    assert args.script == './script.py'
+    assert args.data == "./data"
+    assert args.script == "./script.py"
     assert args.job_name is None
     assert args.bucket_name is None
-    assert args.python == 'py2'
-    assert args.instance_type == 'ml.m4.xlarge'
+    assert args.python == "py2"
+    assert args.instance_type == "ml.m4.xlarge"
     assert args.instance_count == 1
-    assert args.log_level == 'info'
-    assert args.botocore_log_level == 'warning'
+    assert args.log_level == "info"
+    assert args.botocore_log_level == "warning"


 def assert_common_non_defaults(args):
-    assert args.data == 'mydata'
-    assert args.script == 'myscript'
-    assert args.job_name == 'myjob'
-    assert args.bucket_name == 'mybucket'
-    assert args.role_name == 'myrole'
-    assert args.python == 'py3'
-    assert args.instance_type == 'myinstance'
+    assert args.data == "mydata"
+    assert args.script == "myscript"
+    assert args.job_name == "myjob"
+    assert args.bucket_name == "mybucket"
+    assert args.role_name == "myrole"
+    assert args.python == "py3"
+    assert args.instance_type == "myinstance"
     assert args.instance_count == 2
-    assert args.log_level == 'debug'
-    assert args.botocore_log_level == 'debug'
+    assert args.log_level == "debug"
+    assert args.botocore_log_level == "debug"


 def assert_train_defaults(args):
-    assert args.hyperparameters == './hyperparameters.json'
+    assert args.hyperparameters == "./hyperparameters.json"


 def assert_train_non_defaults(args):
-    assert args.hyperparameters == 'myhyperparameters.json'
+    assert args.hyperparameters == "myhyperparameters.json"


 def assert_host_defaults(args):
@@ -64,130 +66,132 @@
 def assert_host_non_defaults(args):
-    assert args.env == ['ENV1=env1', 'ENV2=env2']
+    assert args.env == ["ENV1=env1", "ENV2=env2"]


 def test_args_mxnet_train_defaults():
-    args = cli.parse_arguments('mxnet train --role-name role'.split())
+    args = cli.parse_arguments("mxnet train --role-name role".split())

     assert_common_defaults(args)
     assert_train_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.mxnet'
-    assert args.func.__name__ == 'train'
+    assert args.func.__module__ == "sagemaker.cli.mxnet"
+    assert args.func.__name__ == "train"


 def test_args_mxnet_train_non_defaults():
-    args = cli.parse_arguments('{} mxnet train --role-name role {} {}'
-                               .format(LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
-                               .split())
+    args = cli.parse_arguments(
+        "{} mxnet train --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, TRAIN_ARGS).split()
+    )

     assert_common_non_defaults(args)
     assert_train_non_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.mxnet'
-    assert args.func.__name__ == 'train'
+    assert args.func.__module__ == "sagemaker.cli.mxnet"
+    assert args.func.__name__ == "train"


 def test_args_mxnet_host_defaults():
-    args = cli.parse_arguments('mxnet host --role-name role'.split())
+    args = cli.parse_arguments("mxnet host --role-name role".split())

     assert_common_defaults(args)
     assert_host_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.mxnet'
-    assert args.func.__name__ == 'host'
+    assert args.func.__module__ == "sagemaker.cli.mxnet"
+    assert args.func.__name__ == "host"


 def test_args_mxnet_host_non_defaults():
-    args = cli.parse_arguments('{} mxnet host --role-name role {} {}'
-                               .format(LOG_ARGS, COMMON_ARGS, HOST_ARGS)
-                               .split())
+    args = cli.parse_arguments(
+        "{} mxnet host --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, HOST_ARGS).split()
+    )

     assert_common_non_defaults(args)
     assert_host_non_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.mxnet'
-    assert args.func.__name__ == 'host'
+    assert args.func.__module__ == "sagemaker.cli.mxnet"
+    assert args.func.__name__ == "host"


 def test_args_tensorflow_train_defaults():
-    args = cli.parse_arguments('tensorflow train --role-name role'.split())
+    args = cli.parse_arguments("tensorflow train --role-name role".split())

     assert_common_defaults(args)
     assert_train_defaults(args)
     assert args.training_steps is None
     assert args.evaluation_steps is None
-    assert args.func.__module__ == 'sagemaker.cli.tensorflow'
-    assert args.func.__name__ == 'train'
+    assert args.func.__module__ == "sagemaker.cli.tensorflow"
+    assert args.func.__name__ == "train"


 def test_args_tensorflow_train_non_defaults():
-    args = cli.parse_arguments('{} tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {}'
-                               .format(LOG_ARGS, COMMON_ARGS, TRAIN_ARGS)
-                               .split())
+    args = cli.parse_arguments(
+        "{} tensorflow train --role-name role --training-steps 10 --evaluation-steps 5 {} {}".format(
+            LOG_ARGS, COMMON_ARGS, TRAIN_ARGS
+        ).split()
+    )

     assert_common_non_defaults(args)
     assert_train_non_defaults(args)
     assert args.training_steps == 10
     assert args.evaluation_steps == 5
-    assert args.func.__module__ == 'sagemaker.cli.tensorflow'
-    assert args.func.__name__ == 'train'
+    assert args.func.__module__ == "sagemaker.cli.tensorflow"
+    assert args.func.__name__ == "train"


 def test_args_tensorflow_host_defaults():
-    args = cli.parse_arguments('tensorflow host --role-name role'.split())
+    args = cli.parse_arguments("tensorflow host --role-name role".split())

     assert_common_defaults(args)
     assert_host_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.tensorflow'
-    assert args.func.__name__ == 'host'
+    assert args.func.__module__ == "sagemaker.cli.tensorflow"
+    assert args.func.__name__ == "host"


 def test_args_tensorflow_host_non_defaults():
-    args = cli.parse_arguments('{} tensorflow host --role-name role {} {}'
-                               .format(LOG_ARGS, COMMON_ARGS, HOST_ARGS)
-                               .split())
+    args = cli.parse_arguments(
+        "{} tensorflow host --role-name role {} {}".format(LOG_ARGS, COMMON_ARGS, HOST_ARGS).split()
+    )

     assert_common_non_defaults(args)
     assert_host_non_defaults(args)
-    assert args.func.__module__ == 'sagemaker.cli.tensorflow'
-    assert args.func.__name__ == 'host'
+    assert args.func.__module__ == "sagemaker.cli.tensorflow"
+    assert args.func.__name__ == "host"


 def test_args_invalid_framework():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('fakeframework train --role-name role'.split())
+        cli.parse_arguments("fakeframework train --role-name role".split())


 def test_args_invalid_subcommand():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('mxnet drain'.split())
+        cli.parse_arguments("mxnet drain".split())


 def test_args_invalid_args():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('tensorflow train --role-name role --notdata foo'.split())
+        cli.parse_arguments("tensorflow train --role-name role --notdata foo".split())


 def test_args_invalid_mxnet_python():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('mxnet train --role-name role nython py2'.split())
+        cli.parse_arguments("mxnet train --role-name role nython py2".split())


 def test_args_invalid_host_args_in_train():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('mxnet train --role-name role --env FOO=bar'.split())
+        cli.parse_arguments("mxnet train --role-name role --env FOO=bar".split())


 def test_args_invalid_train_args_in_host():
     with pytest.raises(SystemExit):
-        cli.parse_arguments('tensorflow host --role-name role --hyperparameters foo.json'.split())
+        cli.parse_arguments("tensorflow host --role-name role --hyperparameters foo.json".split())


-@patch('sagemaker.mxnet.estimator.MXNet')
-@patch('sagemaker.Session')
+@patch("sagemaker.mxnet.estimator.MXNet")
+@patch("sagemaker.Session")
 def test_mxnet_train(session, estimator):
-    args = cli.parse_arguments('mxnet train --role-name role'.split())
+    args = cli.parse_arguments("mxnet train --role-name role".split())
     args.func(args)

     session.return_value.upload_data.assert_called()
     estimator.assert_called()
     estimator.return_value.fit.assert_called()


-@patch('sagemaker.mxnet.model.MXNetModel')
-@patch('sagemaker.cli.common.HostCommand.upload_model')
-@patch('sagemaker.Session')
+@patch("sagemaker.mxnet.model.MXNetModel")
+@patch("sagemaker.cli.common.HostCommand.upload_model")
+@patch("sagemaker.Session")
 def test_mxnet_host(session, upload_model, model):
-    args = cli.parse_arguments('mxnet host --role-name role'.split())
+    args = cli.parse_arguments("mxnet host --role-name role".split())
     args.func(args)

     session.assert_called()
     upload_model.assert_called()
diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py
index d27ffd0890..bb460d59c7 100644
--- a/tests/unit/test_common.py
+++ b/tests/unit/test_common.py
@@ -17,8 +17,13 @@
 import pytest
 import itertools
 from scipy.sparse import coo_matrix
-from sagemaker.amazon.common import (record_deserializer, write_numpy_to_dense_tensor, read_recordio,
-                                     numpy_to_record_serializer, write_spmatrix_to_sparse_tensor)
+from sagemaker.amazon.common import (
+    record_deserializer,
+    write_numpy_to_dense_tensor,
+    read_recordio,
+    numpy_to_record_serializer,
+    write_spmatrix_to_sparse_tensor,
+)
 from sagemaker.amazon.record_pb2 import Record


@@ -47,7 +52,7 @@ def test_deserializer():
     s = numpy_to_record_serializer()
     buf = s(np.array(array_data))
     d = record_deserializer()
-    for record, expected in zip(d(buf, 'who cares'), array_data):
+    for record, expected in zip(d(buf, "who cares"), array_data):
         assert record.features["values"].float64_tensor.values == expected


@@ -65,7 +70,7 @@ def test_float_write_numpy_to_dense_tensor():

 def test_float32_write_numpy_to_dense_tensor():
     array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
-    array = np.array(array_data).astype(np.dtype('float32'))
+    array = np.array(array_data).astype(np.dtype("float32"))
     with tempfile.TemporaryFile() as f:
         write_numpy_to_dense_tensor(f, array)
         f.seek(0)
@@ -104,7 +109,7 @@ def test_int_label():
 def test_float32_label():
     array_data = [[1, 2, 3], [10, 20, 3]]
     array = np.array(array_data)
-    label_data = np.array([99, 98, 97]).astype(np.dtype('float32'))
+    label_data = np.array([99, 98, 97]).astype(np.dtype("float32"))
     with tempfile.TemporaryFile() as f:
         write_numpy_to_dense_tensor(f, array, label_data)
         f.seek(0)
@@ -118,7 +123,7 @@ def test_float32_label():
 def test_float_label():
     array_data = [[1, 2, 3], [10, 20, 3]]
     array = np.array(array_data)
-    label_data = np.array([99, 98, 97]).astype(np.dtype('float64'))
+    label_data = np.array([99, 98, 97]).astype(np.dtype("float64"))
     with tempfile.TemporaryFile() as f:
         write_numpy_to_dense_tensor(f, array, label_data)
         f.seek(0)
@@ -132,7 +137,7 @@ def test_float_label():
 def test_invalid_array():
     array_data = [[[1, 2, 3], [10, 20, 3]], [[1, 2, 3], [10, 20, 3]]]
     array = np.array(array_data)
-    label_data = np.array([99, 98, 97]).astype(np.dtype('float64'))
+    label_data = np.array([99, 98, 97]).astype(np.dtype("float64"))
     with tempfile.TemporaryFile() as f:
         with pytest.raises(ValueError):
             write_numpy_to_dense_tensor(f, array, label_data)
@@ -141,7 +146,7 @@ def test_invalid_array():
 def test_invalid_label():
     array_data = [[1, 2, 3], [10, 20, 3]]
     array = np.array(array_data)
-    label_data = np.array([99, 98, 97, 1000]).astype(np.dtype('float64'))
+    label_data = np.array([99, 98, 97, 1000]).astype(np.dtype("float64"))
     with tempfile.TemporaryFile() as f:
         with pytest.raises(ValueError):
             write_numpy_to_dense_tensor(f, array, label_data)
@@ -154,7 +159,9 @@ def test_invalid_label():
     with tempfile.TemporaryFile() as f:
         write_spmatrix_to_sparse_tensor(f, array)
         f.seek(0)
-        for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data):
+        for record_data, expected_data, expected_keys in zip(
+            read_recordio(f), array_data, keys_data
+        ):
             record = Record()
             record.ParseFromString(record_data)
             assert 
record.features["values"].float64_tensor.values == expected_data @@ -165,11 +172,13 @@ def test_dense_float_write_spmatrix_to_sparse_tensor(): def test_dense_float32_write_spmatrix_to_sparse_tensor(): array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]] keys_data = [[0, 1, 2], [0, 1, 2]] - array = coo_matrix(np.array(array_data).astype(np.dtype('float32'))) + array = coo_matrix(np.array(array_data).astype(np.dtype("float32"))) with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array) f.seek(0) - for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data): + for record_data, expected_data, expected_keys in zip( + read_recordio(f), array_data, keys_data + ): record = Record() record.ParseFromString(record_data) assert record.features["values"].float32_tensor.values == expected_data @@ -180,11 +189,13 @@ def test_dense_float32_write_spmatrix_to_sparse_tensor(): def test_dense_int_write_spmatrix_to_sparse_tensor(): array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]] keys_data = [[0, 1, 2], [0, 1, 2]] - array = coo_matrix(np.array(array_data).astype(np.dtype('int'))) + array = coo_matrix(np.array(array_data).astype(np.dtype("int"))) with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array) f.seek(0) - for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data): + for record_data, expected_data, expected_keys in zip( + read_recordio(f), array_data, keys_data + ): record = Record() record.ParseFromString(record_data) assert record.features["values"].int32_tensor.values == expected_data @@ -201,10 +212,7 @@ def test_dense_int_spmatrix_to_sparse_label(): write_spmatrix_to_sparse_tensor(f, array, label_data) f.seek(0) for record_data, expected_data, expected_keys, label in zip( - read_recordio(f), - array_data, - keys_data, - label_data + read_recordio(f), array_data, keys_data, label_data ): record = Record() record.ParseFromString(record_data) @@ -217,16 +225,13 @@ def test_dense_int_spmatrix_to_sparse_label(): def test_dense_float32_spmatrix_to_sparse_label(): array_data = [[1, 2, 3], [10, 20, 3]] keys_data = [[0, 1, 2], [0, 1, 2]] - array = coo_matrix(np.array(array_data).astype('float32')) + array = coo_matrix(np.array(array_data).astype("float32")) label_data = np.array([99, 98, 97]) with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array, label_data) f.seek(0) for record_data, expected_data, expected_keys, label in zip( - read_recordio(f), - array_data, - keys_data, - label_data + read_recordio(f), array_data, keys_data, label_data ): record = Record() record.ParseFromString(record_data) @@ -239,16 +244,13 @@ def test_dense_float32_spmatrix_to_sparse_label(): def test_dense_float64_spmatrix_to_sparse_label(): array_data = [[1, 2, 3], [10, 20, 3]] keys_data = [[0, 1, 2], [0, 1, 2]] - array = coo_matrix(np.array(array_data).astype('float64')) + array = coo_matrix(np.array(array_data).astype("float64")) label_data = np.array([99, 98, 97]) with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array, label_data) f.seek(0) for record_data, expected_data, expected_keys, label in zip( - read_recordio(f), - array_data, - keys_data, - label_data + read_recordio(f), array_data, keys_data, label_data ): record = Record() record.ParseFromString(record_data) @@ -261,7 +263,7 @@ def test_dense_float64_spmatrix_to_sparse_label(): def test_invalid_sparse_label(): array_data = [[1, 2, 3], [10, 20, 3]] array = coo_matrix(np.array(array_data)) - label_data = np.array([99, 
98, 97, 1000]).astype(np.dtype('float64')) + label_data = np.array([99, 98, 97, 1000]).astype(np.dtype("float64")) with tempfile.TemporaryFile() as f: with pytest.raises(ValueError): write_spmatrix_to_sparse_tensor(f, array, label_data) @@ -277,11 +279,13 @@ def test_sparse_float_write_spmatrix_to_sparse_tensor(): x_indices = [[i] * len(keys_data[i]) for i in range(len(keys_data))] x_indices = list(itertools.chain.from_iterable(x_indices)) - array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype='float64') + array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype="float64") with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array) f.seek(0) - for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data): + for record_data, expected_data, expected_keys in zip( + read_recordio(f), array_data, keys_data + ): record = Record() record.ParseFromString(record_data) assert record.features["values"].float64_tensor.values == expected_data @@ -299,11 +303,13 @@ def test_sparse_float32_write_spmatrix_to_sparse_tensor(): x_indices = [[i] * len(keys_data[i]) for i in range(len(keys_data))] x_indices = list(itertools.chain.from_iterable(x_indices)) - array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype='float32') + array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype="float32") with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array) f.seek(0) - for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data): + for record_data, expected_data, expected_keys in zip( + read_recordio(f), array_data, keys_data + ): record = Record() record.ParseFromString(record_data) assert record.features["values"].float32_tensor.values == expected_data @@ -321,11 +327,13 @@ def test_sparse_int_write_spmatrix_to_sparse_tensor(): x_indices = [[i] * len(keys_data[i]) for i in range(len(keys_data))] x_indices = list(itertools.chain.from_iterable(x_indices)) - array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype='int') + array = coo_matrix((flatten_data, (x_indices, y_indices)), dtype="int") with tempfile.TemporaryFile() as f: write_spmatrix_to_sparse_tensor(f, array) f.seek(0) - for record_data, expected_data, expected_keys in zip(read_recordio(f), array_data, keys_data): + for record_data, expected_data, expected_keys in zip( + read_recordio(f), array_data, keys_data + ): record = Record() record.ParseFromString(record_data) assert record.features["values"].int32_tensor.values == expected_data @@ -336,7 +344,7 @@ def test_sparse_int_write_spmatrix_to_sparse_tensor(): def test_dense_to_sparse(): array_data = [[1, 2, 3], [10, 20, 3]] array = np.array(array_data) - label_data = np.array([99, 98, 97]).astype(np.dtype('float64')) + label_data = np.array([99, 98, 97]).astype(np.dtype("float64")) with tempfile.TemporaryFile() as f: with pytest.raises(TypeError): write_spmatrix_to_sparse_tensor(f, array, label_data) diff --git a/tests/unit/test_create_deploy_entities.py b/tests/unit/test_create_deploy_entities.py index cad572ffe2..cb1b6eafb7 100644 --- a/tests/unit/test_create_deploy_entities.py +++ b/tests/unit/test_create_deploy_entities.py @@ -17,93 +17,116 @@ import sagemaker -MODEL_NAME = 'mymodelname' -ENDPOINT_CONFIG_NAME = 'myendpointconfigname' -ENDPOINT_NAME = 'myendpointname' -ROLE = 'myimrole' -EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole' -IMAGE = 'myimage' -FULL_CONTAINER_DEF = {'Environment': {}, 'Image': IMAGE, 'ModelDataUrl': 
's3://mybucket/mymodel'} -VPC_CONFIG = {'Subnets': ['subnet-foo'], 'SecurityGroups': ['sg-foo']} +MODEL_NAME = "mymodelname" +ENDPOINT_CONFIG_NAME = "myendpointconfigname" +ENDPOINT_NAME = "myendpointname" +ROLE = "myimrole" +EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" +IMAGE = "myimage" +FULL_CONTAINER_DEF = {"Environment": {}, "Image": IMAGE, "ModelDataUrl": "s3://mybucket/mymodel"} +VPC_CONFIG = {"Subnets": ["subnet-foo"], "SecurityGroups": ["sg-foo"]} INITIAL_INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -REGION = 'us-west-2' +INSTANCE_TYPE = "ml.c4.xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +REGION = "us-west-2" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) + boto_mock = Mock(name="boto_session", region_name=REGION) ims = sagemaker.Session(boto_session=boto_mock) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims def test_create_model(sagemaker_session): - returned_name = sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, container_defs=FULL_CONTAINER_DEF, - vpc_config=VPC_CONFIG) + returned_name = sagemaker_session.create_model( + name=MODEL_NAME, role=ROLE, container_defs=FULL_CONTAINER_DEF, vpc_config=VPC_CONFIG + ) assert returned_name == MODEL_NAME sagemaker_session.sagemaker_client.create_model.assert_called_once_with( ModelName=MODEL_NAME, PrimaryContainer=FULL_CONTAINER_DEF, ExecutionRoleArn=EXPANDED_ROLE, - VpcConfig=VPC_CONFIG) + VpcConfig=VPC_CONFIG, + ) def test_create_model_expand_primary_container(sagemaker_session): sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, container_defs=IMAGE) _1, _2, create_model_kwargs = sagemaker_session.sagemaker_client.create_model.mock_calls[0] - assert create_model_kwargs['PrimaryContainer'] == {'Environment': {}, 'Image': IMAGE} + assert create_model_kwargs["PrimaryContainer"] == {"Environment": {}, "Image": IMAGE} def test_create_endpoint_config(sagemaker_session): - returned_name = sagemaker_session.create_endpoint_config(name=ENDPOINT_CONFIG_NAME, model_name=MODEL_NAME, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE) + returned_name = sagemaker_session.create_endpoint_config( + name=ENDPOINT_CONFIG_NAME, + model_name=MODEL_NAME, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) assert returned_name == ENDPOINT_CONFIG_NAME - expected_pvs = [{'ModelName': MODEL_NAME, - 'InitialInstanceCount': INITIAL_INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'InitialVariantWeight': 1, - 'VariantName': 'AllTraffic'}] + expected_pvs = [ + { + "ModelName": MODEL_NAME, + "InitialInstanceCount": INITIAL_INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "InitialVariantWeight": 1, + "VariantName": "AllTraffic", + } + ] sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_once_with( - EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[]) + EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[] + ) def test_create_endpoint_config_with_accelerator(sagemaker_session): - returned_name = sagemaker_session.create_endpoint_config(name=ENDPOINT_CONFIG_NAME, model_name=MODEL_NAME, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE) + returned_name = sagemaker_session.create_endpoint_config( + name=ENDPOINT_CONFIG_NAME, + model_name=MODEL_NAME, + initial_instance_count=INITIAL_INSTANCE_COUNT, + 
instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + ) assert returned_name == ENDPOINT_CONFIG_NAME - expected_pvs = [{'ModelName': MODEL_NAME, - 'InitialInstanceCount': INITIAL_INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'InitialVariantWeight': 1, - 'VariantName': 'AllTraffic', - 'AcceleratorType': ACCELERATOR_TYPE}] + expected_pvs = [ + { + "ModelName": MODEL_NAME, + "InitialInstanceCount": INITIAL_INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "InitialVariantWeight": 1, + "VariantName": "AllTraffic", + "AcceleratorType": ACCELERATOR_TYPE, + } + ] sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_once_with( - EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[]) + EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[] + ) def test_create_endpoint_no_wait(sagemaker_session): returned_name = sagemaker_session.create_endpoint( - endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_CONFIG_NAME, wait=False) + endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_CONFIG_NAME, wait=False + ) assert returned_name == ENDPOINT_NAME sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with( - EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[]) + EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[] + ) def test_create_endpoint_wait(sagemaker_session): sagemaker_session.wait_for_endpoint = Mock() - returned_name = sagemaker_session.create_endpoint(endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_CONFIG_NAME) + returned_name = sagemaker_session.create_endpoint( + endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_CONFIG_NAME + ) assert returned_name == ENDPOINT_NAME sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with( - EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[]) + EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[] + ) sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME) diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index e7478cf956..dfcdb83378 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -17,15 +17,15 @@ from mock import Mock import sagemaker -ACCOUNT_ID = '123' -REGION = 'us-west-2' -DEFAULT_BUCKET_NAME = 'sagemaker-{}-{}'.format(REGION, ACCOUNT_ID) +ACCOUNT_ID = "123" +REGION = "us-west-2" +DEFAULT_BUCKET_NAME = "sagemaker-{}-{}".format(REGION, ACCOUNT_ID) @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - boto_mock.client('sts').get_caller_identity.return_value = {'Account': ACCOUNT_ID} + boto_mock = Mock(name="boto_session", region_name=REGION) + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} ims = sagemaker.Session(boto_session=boto_mock) return ims @@ -37,12 +37,15 @@ def test_default_bucket_s3_create_call(sagemaker_session): _1, _2, create_kwargs = create_calls[0] assert bucket_name == DEFAULT_BUCKET_NAME assert len(create_calls) == 1 - assert create_kwargs == {'CreateBucketConfiguration': {'LocationConstraint': 'us-west-2'}, 'Bucket': bucket_name} + assert create_kwargs == { + "CreateBucketConfiguration": {"LocationConstraint": "us-west-2"}, + "Bucket": bucket_name, + } assert sagemaker_session._default_bucket == bucket_name def test_default_already_cached(sagemaker_session): - existing_default = 'mydefaultbucket' + existing_default = "mydefaultbucket" 
sagemaker_session._default_bucket = existing_default bucket_name = sagemaker_session.default_bucket() @@ -53,8 +56,10 @@ def test_default_already_cached(sagemaker_session): def test_default_bucket_exists(sagemaker_session): - error = ClientError(error_response={'Error': {'Code': 'BucketAlreadyOwnedByYou', 'Message': 'message'}}, - operation_name='foo') + error = ClientError( + error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": "message"}}, + operation_name="foo", + ) sagemaker_session.boto_session.resource().create_bucket.side_effect = error bucket_name = sagemaker_session.default_bucket() @@ -62,9 +67,11 @@ def test_default_bucket_exists(sagemaker_session): def test_concurrent_bucket_modification(sagemaker_session): - message = 'A conflicting conditional operation is currently in progress against this resource. Please try again' - error = ClientError(error_response={'Error': {'Code': 'BucketAlreadyOwnedByYou', 'Message': message}}, - operation_name='foo') + message = "A conflicting conditional operation is currently in progress against this resource. Please try again" + error = ClientError( + error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": message}}, + operation_name="foo", + ) sagemaker_session.boto_session.resource().create_bucket.side_effect = error bucket_name = sagemaker_session.default_bucket() @@ -73,8 +80,10 @@ def test_concurrent_bucket_modification(sagemaker_session): def test_bucket_creation_client_error(sagemaker_session): with pytest.raises(ClientError): - error = ClientError(error_response={'Error': {'Code': 'SomethingWrong', 'Message': 'message'}}, - operation_name='foo') + error = ClientError( + error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}}, + operation_name="foo", + ) sagemaker_session.boto_session.resource().create_bucket.side_effect = error sagemaker_session.default_bucket() diff --git a/tests/unit/test_endpoint_from_job.py b/tests/unit/test_endpoint_from_job.py index db603a8798..0a969ba8f9 100644 --- a/tests/unit/test_endpoint_from_job.py +++ b/tests/unit/test_endpoint_from_job.py @@ -17,76 +17,87 @@ import sagemaker -JOB_NAME = 'myjob' +JOB_NAME = "myjob" INITIAL_INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -IMAGE = 'myimage' -S3_MODEL_ARTIFACTS = 's3://mybucket/mymodel' -TRAIN_ROLE = 'mytrainrole' -VPC_CONFIG = {'Subnets': ['subnet-foo'], 'SecurityGroupIds': ['sg-foo']} +INSTANCE_TYPE = "ml.c4.xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +IMAGE = "myimage" +S3_MODEL_ARTIFACTS = "s3://mybucket/mymodel" +TRAIN_ROLE = "mytrainrole" +VPC_CONFIG = {"Subnets": ["subnet-foo"], "SecurityGroupIds": ["sg-foo"]} TRAINING_JOB_RESPONSE = { - "AlgorithmSpecification": { - "TrainingImage": IMAGE - }, - "ModelArtifacts": { - "S3ModelArtifacts": S3_MODEL_ARTIFACTS - }, + "AlgorithmSpecification": {"TrainingImage": IMAGE}, + "ModelArtifacts": {"S3ModelArtifacts": S3_MODEL_ARTIFACTS}, "RoleArn": TRAIN_ROLE, - "VpcConfig": VPC_CONFIG + "VpcConfig": VPC_CONFIG, } -FULL_CONTAINER_DEF = {'Environment': {}, 'Image': IMAGE, 'ModelDataUrl': S3_MODEL_ARTIFACTS} -DEPLOY_IMAGE = 'mydeployimage' -DEPLOY_ROLE = 'mydeployrole' -NEW_ENTITY_NAME = 'mynewendpoint' -ENV_VARS = {'PYTHONUNBUFFERED': 'TRUE', 'some': 'nonsense'} -ENDPOINT_FROM_MODEL_RETURNED_NAME = 'endpointfrommodelname' -REGION = 'us-west-2' +FULL_CONTAINER_DEF = {"Environment": {}, "Image": IMAGE, "ModelDataUrl": S3_MODEL_ARTIFACTS} +DEPLOY_IMAGE = "mydeployimage" +DEPLOY_ROLE = "mydeployrole" +NEW_ENTITY_NAME = 
"mynewendpoint" +ENV_VARS = {"PYTHONUNBUFFERED": "TRUE", "some": "nonsense"} +ENDPOINT_FROM_MODEL_RETURNED_NAME = "endpointfrommodelname" +REGION = "us-west-2" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - ims = sagemaker.Session(sagemaker_client=Mock(name='sagemaker_client'), boto_session=boto_mock) - ims.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=TRAINING_JOB_RESPONSE) + boto_mock = Mock(name="boto_session", region_name=REGION) + ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock) + ims.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=TRAINING_JOB_RESPONSE + ) - ims.endpoint_from_model_data = Mock('endpoint_from_model_data', return_value=ENDPOINT_FROM_MODEL_RETURNED_NAME) + ims.endpoint_from_model_data = Mock( + "endpoint_from_model_data", return_value=ENDPOINT_FROM_MODEL_RETURNED_NAME + ) return ims def test_all_defaults_no_existing_entities(sagemaker_session): - original_args = {'job_name': JOB_NAME, 'initial_instance_count': INITIAL_INSTANCE_COUNT, - 'instance_type': INSTANCE_TYPE, 'wait': False} + original_args = { + "job_name": JOB_NAME, + "initial_instance_count": INITIAL_INSTANCE_COUNT, + "instance_type": INSTANCE_TYPE, + "wait": False, + } returned_name = sagemaker_session.endpoint_from_job(**original_args) expected_args = original_args.copy() - expected_args.pop('job_name') - expected_args['model_s3_location'] = S3_MODEL_ARTIFACTS - expected_args['deployment_image'] = IMAGE - expected_args['role'] = TRAIN_ROLE - expected_args['name'] = JOB_NAME - expected_args['model_environment_vars'] = None - expected_args['model_vpc_config'] = VPC_CONFIG - expected_args['accelerator_type'] = None + expected_args.pop("job_name") + expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS + expected_args["deployment_image"] = IMAGE + expected_args["role"] = TRAIN_ROLE + expected_args["name"] = JOB_NAME + expected_args["model_environment_vars"] = None + expected_args["model_vpc_config"] = VPC_CONFIG + expected_args["accelerator_type"] = None sagemaker_session.endpoint_from_model_data.assert_called_once_with(**expected_args) assert returned_name == ENDPOINT_FROM_MODEL_RETURNED_NAME def test_no_defaults_no_existing_entities(sagemaker_session): - vpc_config_override = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']} + vpc_config_override = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} - original_args = {'job_name': JOB_NAME, 'initial_instance_count': INITIAL_INSTANCE_COUNT, - 'instance_type': INSTANCE_TYPE, 'deployment_image': DEPLOY_IMAGE, 'role': DEPLOY_ROLE, - 'name': NEW_ENTITY_NAME, 'model_environment_vars': ENV_VARS, - 'vpc_config_override': vpc_config_override, 'accelerator_type': ACCELERATOR_TYPE, - 'wait': False} + original_args = { + "job_name": JOB_NAME, + "initial_instance_count": INITIAL_INSTANCE_COUNT, + "instance_type": INSTANCE_TYPE, + "deployment_image": DEPLOY_IMAGE, + "role": DEPLOY_ROLE, + "name": NEW_ENTITY_NAME, + "model_environment_vars": ENV_VARS, + "vpc_config_override": vpc_config_override, + "accelerator_type": ACCELERATOR_TYPE, + "wait": False, + } returned_name = sagemaker_session.endpoint_from_job(**original_args) expected_args = original_args.copy() - expected_args.pop('job_name') - expected_args['model_s3_location'] = S3_MODEL_ARTIFACTS - expected_args['model_vpc_config'] = expected_args.pop('vpc_config_override') + expected_args.pop("job_name") + 
expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS + expected_args["model_vpc_config"] = expected_args.pop("vpc_config_override") sagemaker_session.endpoint_from_model_data.assert_called_once_with(**expected_args) assert returned_name == ENDPOINT_FROM_MODEL_RETURNED_NAME diff --git a/tests/unit/test_endpoint_from_model_data.py b/tests/unit/test_endpoint_from_model_data.py index 672087a02c..f87e5b0973 100644 --- a/tests/unit/test_endpoint_from_model_data.py +++ b/tests/unit/test_endpoint_from_model_data.py @@ -19,109 +19,141 @@ import sagemaker -ENDPOINT_NAME = 'myendpoint' +ENDPOINT_NAME = "myendpoint" INITIAL_INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -S3_MODEL_ARTIFACTS = 's3://mybucket/mymodel' -DEPLOY_IMAGE = 'mydeployimage' -CONTAINER_DEF = {'Environment': {}, 'Image': DEPLOY_IMAGE, 'ModelDataUrl': S3_MODEL_ARTIFACTS} -VPC_CONFIG = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} -DEPLOY_ROLE = 'mydeployrole' -ENV_VARS = {'PYTHONUNBUFFERED': 'TRUE', 'some': 'nonsense'} -NAME_FROM_IMAGE = 'namefromimage' -REGION = 'us-west-2' +INSTANCE_TYPE = "ml.c4.xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +S3_MODEL_ARTIFACTS = "s3://mybucket/mymodel" +DEPLOY_IMAGE = "mydeployimage" +CONTAINER_DEF = {"Environment": {}, "Image": DEPLOY_IMAGE, "ModelDataUrl": S3_MODEL_ARTIFACTS} +VPC_CONFIG = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} +DEPLOY_ROLE = "mydeployrole" +ENV_VARS = {"PYTHONUNBUFFERED": "TRUE", "some": "nonsense"} +NAME_FROM_IMAGE = "namefromimage" +REGION = "us-west-2" @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - ims = sagemaker.Session(sagemaker_client=Mock(name='sagemaker_client'), boto_session=boto_mock) - ims.sagemaker_client.describe_model = Mock(name='describe_model', side_effect=_raise_does_not_exist_client_error) - ims.sagemaker_client.describe_endpoint_config = Mock(name='describe_endpoint_config', - side_effect=_raise_does_not_exist_client_error) - ims.sagemaker_client.describe_endpoint = Mock(name='describe_endpoint', - side_effect=_raise_does_not_exist_client_error) - ims.create_model = Mock(name='create_model') - ims.create_endpoint_config = Mock(name='create_endpoint_config') - ims.create_endpoint = Mock(name='create_endpoint') + boto_mock = Mock(name="boto_session", region_name=REGION) + ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock) + ims.sagemaker_client.describe_model = Mock( + name="describe_model", side_effect=_raise_does_not_exist_client_error + ) + ims.sagemaker_client.describe_endpoint_config = Mock( + name="describe_endpoint_config", side_effect=_raise_does_not_exist_client_error + ) + ims.sagemaker_client.describe_endpoint = Mock( + name="describe_endpoint", side_effect=_raise_does_not_exist_client_error + ) + ims.create_model = Mock(name="create_model") + ims.create_endpoint_config = Mock(name="create_endpoint_config") + ims.create_endpoint = Mock(name="create_endpoint") return ims -@patch('sagemaker.session.name_from_image', return_value=NAME_FROM_IMAGE) +@patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE) def test_all_defaults_no_existing_entities(name_from_image_mock, sagemaker_session): - returned_name = sagemaker_session.endpoint_from_model_data(model_s3_location=S3_MODEL_ARTIFACTS, - deployment_image=DEPLOY_IMAGE, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, role=DEPLOY_ROLE, - wait=False) - - 
sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with(EndpointName=NAME_FROM_IMAGE) - sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=NAME_FROM_IMAGE) + returned_name = sagemaker_session.endpoint_from_model_data( + model_s3_location=S3_MODEL_ARTIFACTS, + deployment_image=DEPLOY_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + role=DEPLOY_ROLE, + wait=False, + ) + + sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with( + EndpointName=NAME_FROM_IMAGE + ) + sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( + ModelName=NAME_FROM_IMAGE + ) sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_once_with( - EndpointConfigName=NAME_FROM_IMAGE) - sagemaker_session.create_model.assert_called_once_with(name=NAME_FROM_IMAGE, - role=DEPLOY_ROLE, - container_defs=CONTAINER_DEF, - vpc_config=None) - sagemaker_session.create_endpoint_config.assert_called_once_with(name=NAME_FROM_IMAGE, - model_name=NAME_FROM_IMAGE, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=None) - sagemaker_session.create_endpoint.assert_called_once_with(endpoint_name=NAME_FROM_IMAGE, - config_name=NAME_FROM_IMAGE, - wait=False) + EndpointConfigName=NAME_FROM_IMAGE + ) + sagemaker_session.create_model.assert_called_once_with( + name=NAME_FROM_IMAGE, role=DEPLOY_ROLE, container_defs=CONTAINER_DEF, vpc_config=None + ) + sagemaker_session.create_endpoint_config.assert_called_once_with( + name=NAME_FROM_IMAGE, + model_name=NAME_FROM_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + ) + sagemaker_session.create_endpoint.assert_called_once_with( + endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False + ) assert returned_name == NAME_FROM_IMAGE -@patch('sagemaker.session.name_from_image', return_value=NAME_FROM_IMAGE) +@patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE) def test_no_defaults_no_existing_entities(name_from_image_mock, sagemaker_session): container_def_with_env = CONTAINER_DEF.copy() - container_def_with_env.update({'Environment': ENV_VARS}) - - returned_name = sagemaker_session.endpoint_from_model_data(model_s3_location=S3_MODEL_ARTIFACTS, - deployment_image=DEPLOY_IMAGE, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, role=DEPLOY_ROLE, - wait=False, name=ENDPOINT_NAME, - model_environment_vars=ENV_VARS, - model_vpc_config=VPC_CONFIG, - accelerator_type=ACCELERATOR_TYPE) - - sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with(EndpointName=ENDPOINT_NAME) - sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=ENDPOINT_NAME) + container_def_with_env.update({"Environment": ENV_VARS}) + + returned_name = sagemaker_session.endpoint_from_model_data( + model_s3_location=S3_MODEL_ARTIFACTS, + deployment_image=DEPLOY_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + role=DEPLOY_ROLE, + wait=False, + name=ENDPOINT_NAME, + model_environment_vars=ENV_VARS, + model_vpc_config=VPC_CONFIG, + accelerator_type=ACCELERATOR_TYPE, + ) + + sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with( + EndpointName=ENDPOINT_NAME + ) + sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( + ModelName=ENDPOINT_NAME + ) 
sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_once_with( - EndpointConfigName=ENDPOINT_NAME) - sagemaker_session.create_model.assert_called_once_with(name=ENDPOINT_NAME, - role=DEPLOY_ROLE, - container_defs=container_def_with_env, - vpc_config=VPC_CONFIG) - sagemaker_session.create_endpoint_config.assert_called_once_with(name=ENDPOINT_NAME, - model_name=ENDPOINT_NAME, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE) - sagemaker_session.create_endpoint.assert_called_once_with(endpoint_name=ENDPOINT_NAME, - config_name=ENDPOINT_NAME, - wait=False) + EndpointConfigName=ENDPOINT_NAME + ) + sagemaker_session.create_model.assert_called_once_with( + name=ENDPOINT_NAME, + role=DEPLOY_ROLE, + container_defs=container_def_with_env, + vpc_config=VPC_CONFIG, + ) + sagemaker_session.create_endpoint_config.assert_called_once_with( + name=ENDPOINT_NAME, + model_name=ENDPOINT_NAME, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + ) + sagemaker_session.create_endpoint.assert_called_once_with( + endpoint_name=ENDPOINT_NAME, config_name=ENDPOINT_NAME, wait=False + ) assert returned_name == ENDPOINT_NAME -@patch('sagemaker.session.name_from_image', return_value=NAME_FROM_IMAGE) +@patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE) def test_model_and_endpoint_config_exist(name_from_image_mock, sagemaker_session): - sagemaker_session.sagemaker_client.describe_model = Mock(name='describe_model') - sagemaker_session.sagemaker_client.describe_endpoint_config = Mock(name='describe_endpoint_config') - - sagemaker_session.endpoint_from_model_data(model_s3_location=S3_MODEL_ARTIFACTS, deployment_image=DEPLOY_IMAGE, - initial_instance_count=INITIAL_INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, wait=False) + sagemaker_session.sagemaker_client.describe_model = Mock(name="describe_model") + sagemaker_session.sagemaker_client.describe_endpoint_config = Mock( + name="describe_endpoint_config" + ) + + sagemaker_session.endpoint_from_model_data( + model_s3_location=S3_MODEL_ARTIFACTS, + deployment_image=DEPLOY_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + wait=False, + ) sagemaker_session.create_model.assert_not_called() sagemaker_session.create_endpoint_config.assert_not_called() - sagemaker_session.create_endpoint.assert_called_once_with(endpoint_name=NAME_FROM_IMAGE, - config_name=NAME_FROM_IMAGE, - wait=False) + sagemaker_session.create_endpoint.assert_called_once_with( + endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False + ) def test_entity_exists(): @@ -134,13 +166,15 @@ def test_entity_doesnt_exist(): def test_describe_failure(): def _raise_unexpected_client_error(): - response = {'Error': {'Code': 'ValidationException', 'Message': 'Name does not satisfy expression.'}} - raise ClientError(error_response=response, operation_name='foo') + response = { + "Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."} + } + raise ClientError(error_response=response, operation_name="foo") with pytest.raises(ClientError): sagemaker.session._deployment_entity_exists(_raise_unexpected_client_error) def _raise_does_not_exist_client_error(**kwargs): - response = {'Error': {'Code': 'ValidationException', 'Message': 'Could not find entity.'}} - raise ClientError(error_response=response, operation_name='foo') + response = {"Error": {"Code": 
"ValidationException", "Message": "Could not find entity."}} + raise ClientError(error_response=response, operation_name="foo") diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 9fc0b6580c..924c494cd3 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -33,92 +33,69 @@ MODEL_IMAGE = "mi" ENTRY_POINT = "blah.py" -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_NAME = 'dummy_script.py' +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_NAME = "dummy_script.py" SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME) -TIMESTAMP = '2017-11-06-14:14:15.671' -BUCKET_NAME = 'mybucket' +TIMESTAMP = "2017-11-06-14:14:15.671" +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -ROLE = 'DummyRole' -IMAGE_NAME = 'fakeimage' -REGION = 'us-west-2' -JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) -TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] -OUTPUT_PATH = 's3://bucket/prefix' - -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': MODEL_DATA - } -} +INSTANCE_TYPE = "c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +ROLE = "DummyRole" +IMAGE_NAME = "fakeimage" +REGION = "us-west-2" +JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) +TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] +OUTPUT_PATH = "s3://bucket/prefix" -RETURNED_JOB_DESCRIPTION = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4' - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - }, +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo' +RETURNED_JOB_DESCRIPTION = { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4", }, - 'TrainingJobOutput': { - 'S3TrainingJobOutput': 's3://here/output.tar.gz' + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", }, - 'EnableInterContainerTrafficEncryption': False + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": 
"s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + "EnableInterContainerTrafficEncryption": False, } MODEL_CONTAINER_DEF = { - 'Environment': { - 'SAGEMAKER_PROGRAM': ENTRY_POINT, - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false' + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", }, - 'Image': MODEL_IMAGE, - 'ModelDataUrl': MODEL_DATA, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, } -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} -LIST_TAGS_RESULT = { - 'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] -} +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} class DummyFramework(Framework): - __framework_name__ = 'dummy' + __framework_name__ = "dummy" def train_image(self): return IMAGE_NAME @@ -129,15 +106,23 @@ def create_model(self, role=None, model_server_workers=None): @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description( - job_details, model_channel_name) + job_details, model_channel_name + ) init_params.pop("image", None) return init_params class DummyFrameworkModel(FrameworkModel): def __init__(self, sagemaker_session, **kwargs): - super(DummyFrameworkModel, self).__init__(MODEL_DATA, MODEL_IMAGE, INSTANCE_TYPE, ROLE, ENTRY_POINT, - sagemaker_session=sagemaker_session, **kwargs) + super(DummyFrameworkModel, self).__init__( + MODEL_DATA, + MODEL_IMAGE, + INSTANCE_TYPE, + ROLE, + ENTRY_POINT, + sagemaker_session=sagemaker_session, + **kwargs + ) def create_predictor(self, endpoint_name): return None @@ -148,18 +133,24 @@ def prepare_container_def(self, instance_type): @pytest.fixture(autouse=True) def mock_create_tar_file(): - with patch('sagemaker.utils.create_tar_file', MagicMock()) as create_tar_file: + with patch("sagemaker.utils.create_tar_file", MagicMock()) as create_tar_file: yield create_tar_file @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) 
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) sms.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) @@ -167,407 +158,568 @@ def sagemaker_session(): def test_framework_all_init_args(sagemaker_session): - f = DummyFramework('my_script.py', role='DummyRole', train_instance_count=3, train_instance_type='ml.m4.xlarge', - sagemaker_session=sagemaker_session, train_volume_size=123, train_volume_kms_key='volumekms', - train_max_run=456, input_mode='inputmode', output_path='outputpath', output_kms_key='outputkms', - base_job_name='basejobname', tags=[{'foo': 'bar'}], subnets=['123', '456'], - security_group_ids=['789', '012'], - metric_definitions=[{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}], - encrypt_inter_container_traffic=True) - _TrainingJob.start_new(f, 's3://mydata') + f = DummyFramework( + "my_script.py", + role="DummyRole", + train_instance_count=3, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + train_volume_size=123, + train_volume_kms_key="volumekms", + train_max_run=456, + input_mode="inputmode", + output_path="outputpath", + output_kms_key="outputkms", + base_job_name="basejobname", + tags=[{"foo": "bar"}], + subnets=["123", "456"], + security_group_ids=["789", "012"], + metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], + encrypt_inter_container_traffic=True, + ) + _TrainingJob.start_new(f, "s3://mydata") sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args - assert args == {'input_mode': 'inputmode', 'tags': [{'foo': 'bar'}], 'hyperparameters': {}, 'image': 'fakeimage', - 'input_config': [{'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': {'S3DataType': 'S3Prefix', - 'S3DataDistributionType': 'FullyReplicated', - 'S3Uri': 's3://mydata'}}}], - 'output_config': {'KmsKeyId': 'outputkms', 'S3OutputPath': 'outputpath'}, - 'vpc_config': {'Subnets': ['123', '456'], 'SecurityGroupIds': ['789', '012']}, - 'stop_condition': {'MaxRuntimeInSeconds': 456}, - 'role': sagemaker_session.expand_role(), 'job_name': None, - 'resource_config': {'VolumeSizeInGB': 123, 'InstanceCount': 3, 'VolumeKmsKeyId': 'volumekms', - 'InstanceType': 'ml.m4.xlarge'}, - 'metric_definitions': [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}], - 'encrypt_inter_container_traffic': True} + assert args == { + "input_mode": "inputmode", + "tags": [{"foo": "bar"}], + "hyperparameters": {}, + "image": "fakeimage", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + "S3Uri": "s3://mydata", + } + }, + } + ], + "output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"}, + "vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]}, + "stop_condition": {"MaxRuntimeInSeconds": 456}, + "role": sagemaker_session.expand_role(), + "job_name": None, + "resource_config": { + "VolumeSizeInGB": 123, + "InstanceCount": 3, + "VolumeKmsKeyId": "volumekms", + "InstanceType": "ml.m4.xlarge", + }, + "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], + "encrypt_inter_container_traffic": True, + } def test_framework_init_s3_entry_point_invalid(sagemaker_session): with pytest.raises(ValueError) as error: - DummyFramework('s3://remote-script-because-im-mistaken', role=ROLE, - sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, - 
train_instance_type=INSTANCE_TYPE) - assert 'Must be a path to a local file' in str(error) + DummyFramework( + "s3://remote-script-because-im-mistaken", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + assert "Must be a path to a local file" in str(error) def test_sagemaker_s3_uri_invalid(sagemaker_session): with pytest.raises(ValueError) as error: - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) - t.fit('thisdoesntstartwiths3') - assert 'must be a valid S3 or FILE URI' in str(error) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + t.fit("thisdoesntstartwiths3") + assert "must be a valid S3 or FILE URI" in str(error) def test_sagemaker_model_s3_uri_invalid(sagemaker_session): with pytest.raises(ValueError) as error: - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - model_uri='thisdoesntstartwiths3either.tar.gz') - t.fit('s3://mydata') - assert 'must be a valid S3 or FILE URI' in str(error) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + model_uri="thisdoesntstartwiths3either.tar.gz", + ) + t.fit("s3://mydata") + assert "must be a valid S3 or FILE URI" in str(error) def test_sagemaker_model_file_uri_invalid(sagemaker_session): with pytest.raises(ValueError) as error: - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - model_uri='file://notins3.tar.gz') - t.fit('s3://mydata') - assert 'File URIs are supported in local mode only' in str(error) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + model_uri="file://notins3.tar.gz", + ) + t.fit("s3://mydata") + assert "File URIs are supported in local mode only" in str(error) def test_sagemaker_model_default_channel_name(sagemaker_session): - f = DummyFramework(entry_point='my_script.py', role='DummyRole', train_instance_count=3, - train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session, - model_uri='s3://model-bucket/prefix/model.tar.gz') + f = DummyFramework( + entry_point="my_script.py", + role="DummyRole", + train_instance_count=3, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + model_uri="s3://model-bucket/prefix/model.tar.gz", + ) _TrainingJob.start_new(f, {}) sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args - assert args['input_config'] == [{'ChannelName': 'model', - 'InputMode': 'File', - 'ContentType': 'application/x-sagemaker-model', - 'DataSource': { - 'S3DataSource': {'S3DataType': 'S3Prefix', - 'S3DataDistributionType': 'FullyReplicated', - 'S3Uri': 's3://model-bucket/prefix/model.tar.gz'}}}] + assert args["input_config"] == [ + { + "ChannelName": "model", + "InputMode": "File", + "ContentType": "application/x-sagemaker-model", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3DataDistributionType": 
"FullyReplicated", + "S3Uri": "s3://model-bucket/prefix/model.tar.gz", + } + }, + } + ] def test_sagemaker_model_custom_channel_name(sagemaker_session): - f = DummyFramework(entry_point='my_script.py', role='DummyRole', train_instance_count=3, - train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session, - model_uri='s3://model-bucket/prefix/model.tar.gz', model_channel_name='testModelChannel') + f = DummyFramework( + entry_point="my_script.py", + role="DummyRole", + train_instance_count=3, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + model_uri="s3://model-bucket/prefix/model.tar.gz", + model_channel_name="testModelChannel", + ) _TrainingJob.start_new(f, {}) sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args - assert args['input_config'] == [{'ChannelName': 'testModelChannel', - 'InputMode': 'File', - 'ContentType': 'application/x-sagemaker-model', - 'DataSource': { - 'S3DataSource': {'S3DataType': 'S3Prefix', - 'S3DataDistributionType': 'FullyReplicated', - 'S3Uri': 's3://model-bucket/prefix/model.tar.gz'}}}] + assert args["input_config"] == [ + { + "ChannelName": "testModelChannel", + "InputMode": "File", + "ContentType": "application/x-sagemaker-model", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + "S3Uri": "s3://model-bucket/prefix/model.tar.gz", + } + }, + } + ] -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket(time, sagemaker_session): - code_bucket = 'codebucket' - prefix = 'someprefix' - code_location = 's3://{}/{}'.format(code_bucket, prefix) - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - code_location=code_location) - t.fit('s3://bucket/mydata') - - expected_key = '{}/{}/source/sourcedir.tar.gz'.format(prefix, JOB_NAME) - _, s3_args, _ = sagemaker_session.boto_session.resource('s3').Object.mock_calls[0] + code_bucket = "codebucket" + prefix = "someprefix" + code_location = "s3://{}/{}".format(code_bucket, prefix) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + code_location=code_location, + ) + t.fit("s3://bucket/mydata") + + expected_key = "{}/{}/source/sourcedir.tar.gz".format(prefix, JOB_NAME) + _, s3_args, _ = sagemaker_session.boto_session.resource("s3").Object.mock_calls[0] assert s3_args == (code_bucket, expected_key) - expected_submit_dir = 's3://{}/{}'.format(code_bucket, expected_key) + expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir) + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( + expected_submit_dir + ) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_custom_code_bucket_without_prefix(time, sagemaker_session): - code_bucket = 'codebucket' - code_location = 's3://{}'.format(code_bucket) - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - code_location=code_location) - t.fit('s3://bucket/mydata') - - 
expected_key = '{}/source/sourcedir.tar.gz'.format(JOB_NAME) - _, s3_args, _ = sagemaker_session.boto_session.resource('s3').Object.mock_calls[0] + code_bucket = "codebucket" + code_location = "s3://{}".format(code_bucket) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + code_location=code_location, + ) + t.fit("s3://bucket/mydata") + + expected_key = "{}/source/sourcedir.tar.gz".format(JOB_NAME) + _, s3_args, _ = sagemaker_session.boto_session.resource("s3").Object.mock_calls[0] assert s3_args == (code_bucket, expected_key) - expected_submit_dir = 's3://{}/{}'.format(code_bucket, expected_key) + expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir) + assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps( + expected_submit_dir + ) def test_invalid_custom_code_bucket(sagemaker_session): - code_location = 'thisllworkright?' - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - code_location=code_location) + code_location = "thisllworkright?" + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + code_location=code_location, + ) with pytest.raises(ValueError) as error: - t.fit('s3://bucket/mydata') + t.fit("s3://bucket/mydata") assert "Expecting 's3' scheme" in str(error) def test_augmented_manifest(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) - fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile', - attribute_names=['foo', 'bar'])) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit( + inputs=s3_input( + "s3://mybucket/train_manifest", + s3_data_type="AugmentedManifestFile", + attribute_names=["foo", "bar"], + ) + ) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource'] - assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest' - assert s3_data_source['S3DataType'] == 'AugmentedManifestFile' - assert s3_data_source['AttributeNames'] == ['foo', 'bar'] + s3_data_source = train_kwargs["input_config"][0]["DataSource"]["S3DataSource"] + assert s3_data_source["S3Uri"] == "s3://mybucket/train_manifest" + assert s3_data_source["S3DataType"] == "AugmentedManifestFile" + assert s3_data_source["AttributeNames"] == ["foo", "bar"] def test_s3_input_mode(sagemaker_session): - expected_input_mode = 'Pipe' - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) - fw.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode)) - - actual_input_mode = 
sagemaker_session.method_calls[1][2]['input_mode'] + expected_input_mode = "Pipe" + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit(inputs=s3_input("s3://mybucket/train_manifest", input_mode=expected_input_mode)) + + actual_input_mode = sagemaker_session.method_calls[1][2]["input_mode"] assert actual_input_mode == expected_input_mode def test_shuffle_config(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) - fw.fit(inputs=s3_input('s3://mybucket/train_manifest', shuffle_config=ShuffleConfig(100))) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit(inputs=s3_input("s3://mybucket/train_manifest", shuffle_config=ShuffleConfig(100))) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - channel = train_kwargs['input_config'][0] - assert channel['ShuffleConfig']['Seed'] == 100 + channel = train_kwargs["input_config"][0] + assert channel["ShuffleConfig"]["Seed"] == 100 BASE_HP = { - 'sagemaker_program': json.dumps(SCRIPT_NAME), - 'sagemaker_submit_directory': json.dumps('s3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME)), - 'sagemaker_job_name': json.dumps(JOB_NAME) + "sagemaker_program": json.dumps(SCRIPT_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME) + ), + "sagemaker_job_name": json.dumps(JOB_NAME), } def test_local_code_location(): - config = { - 'local': { - 'local_code': True, - 'region': 'us-west-2' - } - } - sms = Mock(name='sagemaker_session', boto_session=None, - boto_region_name=REGION, config=config, local_mode=True) - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sms, - train_instance_count=1, train_instance_type='local', - base_job_name=IMAGE_NAME, hyperparameters={123: [456], 'learning_rate': 0.1}) - - t.fit('file:///data/file') + config = {"local": {"local_code": True, "region": "us-west-2"}} + sms = Mock( + name="sagemaker_session", + boto_session=None, + boto_region_name=REGION, + config=config, + local_mode=True, + ) + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sms, + train_instance_count=1, + train_instance_type="local", + base_job_name=IMAGE_NAME, + hyperparameters={123: [456], "learning_rate": 0.1}, + ) + + t.fit("file:///data/file") assert t.source_dir == DATA_DIR - assert t.entry_point == 'dummy_script.py' + assert t.entry_point == "dummy_script.py" -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session): - uri = 'bucket/mydata' - - t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - base_job_name=IMAGE_NAME, hyperparameters={123: [456], 'learning_rate': 0.1}) - t.fit('s3://{}'.format(uri)) + uri = "bucket/mydata" + + t = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + 
train_instance_type=INSTANCE_TYPE,
+        base_job_name=IMAGE_NAME,
+        hyperparameters={123: [456], "learning_rate": 0.1},
+    )
+    t.fit("s3://{}".format(uri))

     expected_hyperparameters = BASE_HP.copy()
-    expected_hyperparameters['sagemaker_enable_cloudwatch_metrics'] = 'false'
-    expected_hyperparameters['sagemaker_container_log_level'] = str(logging.INFO)
-    expected_hyperparameters['learning_rate'] = json.dumps(0.1)
-    expected_hyperparameters['123'] = json.dumps([456])
-    expected_hyperparameters['sagemaker_region'] = '"us-west-2"'
+    expected_hyperparameters["sagemaker_enable_cloudwatch_metrics"] = "false"
+    expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
+    expected_hyperparameters["learning_rate"] = json.dumps(0.1)
+    expected_hyperparameters["123"] = json.dumps([456])
+    expected_hyperparameters["sagemaker_region"] = '"us-west-2"'

-    actual_hyperparameter = sagemaker_session.method_calls[1][2]['hyperparameters']
+    actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"]
     assert actual_hyperparameter == expected_hyperparameters


-@patch('time.strftime', return_value=TIMESTAMP)
+@patch("time.strftime", return_value=TIMESTAMP)
 def test_start_new_wait_called(strftime, sagemaker_session):
-    uri = 'bucket/mydata'
+    uri = "bucket/mydata"

-    t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
+    t = DummyFramework(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+    )

-    t.fit('s3://{}'.format(uri))
+    t.fit("s3://{}".format(uri))

     expected_hyperparameters = BASE_HP.copy()
-    expected_hyperparameters['sagemaker_enable_cloudwatch_metrics'] = 'false'
-    expected_hyperparameters['sagemaker_container_log_level'] = str(logging.INFO)
-    expected_hyperparameters['sagemaker_region'] = '"us-west-2"'
+    expected_hyperparameters["sagemaker_enable_cloudwatch_metrics"] = "false"
+    expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
+    expected_hyperparameters["sagemaker_region"] = '"us-west-2"'

-    actual_hyperparameter = sagemaker_session.method_calls[1][2]['hyperparameters']
+    actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"]
     assert actual_hyperparameter == expected_hyperparameters
-    assert sagemaker_session.wait_for_job.assert_called_once
+    sagemaker_session.wait_for_job.assert_called_once()


 def test_delete_endpoint(sagemaker_session):
-    t = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
-                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                       container_log_level=logging.INFO)
+    t = DummyFramework(
+        entry_point=SCRIPT_PATH,
+        role="DummyRole",
+        sagemaker_session=sagemaker_session,
+        train_instance_count=INSTANCE_COUNT,
+        train_instance_type=INSTANCE_TYPE,
+        container_log_level=logging.INFO,
+    )

     class tj(object):
         @property
         def name(self):
-            return 'myjob'
+            return "myjob"

     t.latest_training_job = tj()

     t.delete_endpoint()

-    sagemaker_session.delete_endpoint.assert_called_with('myjob')
+    sagemaker_session.delete_endpoint.assert_called_with("myjob")


 def test_delete_endpoint_without_endpoint(sagemaker_session):
-    t = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
-                       train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
+    t = DummyFramework(
+        entry_point=SCRIPT_PATH,
+        role="DummyRole",
+        sagemaker_session=sagemaker_session,
+        
train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) with pytest.raises(ValueError) as error: t.delete_endpoint() - assert 'Endpoint was not created yet' in str(error) + assert "Endpoint was not created yet" in str(error) def test_enable_cloudwatch_metrics(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) - fw.fit(inputs=s3_input('s3://mybucket/train')) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit(inputs=s3_input("s3://mybucket/train")) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics'] + assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"] def test_attach_framework(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description['VpcConfig'] = { - 'Subnets': ['foo'], - 'SecurityGroupIds': ['bar'] - } - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) - - framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert framework_estimator._current_job_name == 'neo' - assert framework_estimator.latest_training_job.job_name == 'neo' - assert framework_estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + framework_estimator = DummyFramework.attach( + training_job_name="neo", sagemaker_session=sagemaker_session + ) + assert framework_estimator._current_job_name == "neo" + assert framework_estimator.latest_training_job.job_name == "neo" + assert framework_estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert framework_estimator.train_instance_count == 1 assert framework_estimator.train_max_run == 24 * 60 * 60 - assert framework_estimator.input_mode == 'File' - assert framework_estimator.base_job_name == 'neo' - assert framework_estimator.output_path == 's3://place/output/neo' - assert framework_estimator.output_kms_key == '' - assert framework_estimator.hyperparameters()['training_steps'] == '100' - assert framework_estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert framework_estimator.entry_point == 'iris-dnn-classifier.py' - assert framework_estimator.subnets == ['foo'] - assert framework_estimator.security_group_ids == ['bar'] + assert framework_estimator.input_mode == "File" + assert framework_estimator.base_job_name == "neo" + assert framework_estimator.output_path == "s3://place/output/neo" + assert framework_estimator.output_kms_key == "" + assert framework_estimator.hyperparameters()["training_steps"] == "100" + assert framework_estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert framework_estimator.entry_point == "iris-dnn-classifier.py" + assert framework_estimator.subnets == ["foo"] + assert framework_estimator.security_group_ids == ["bar"] assert framework_estimator.encrypt_inter_container_traffic is False - assert framework_estimator.tags == 
LIST_TAGS_RESULT['Tags'] + assert framework_estimator.tags == LIST_TAGS_RESULT["Tags"] def test_attach_without_hyperparameters(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - del returned_job_description['HyperParameters'] + del returned_job_description["HyperParameters"] - mock_describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + mock_describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job - estimator = Estimator.attach(training_job_name='job', - sagemaker_session=sagemaker_session) + estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) assert estimator.hyperparameters() == {} def test_attach_framework_with_tuning(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description['HyperParameters']['_tuning_objective_metric'] = 'Validation-accuracy' + returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy" - mock_describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + mock_describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job - framework_estimator = DummyFramework.attach(training_job_name='neo', - sagemaker_session=sagemaker_session) - assert framework_estimator.latest_training_job.job_name == 'neo' - assert framework_estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + framework_estimator = DummyFramework.attach( + training_job_name="neo", sagemaker_session=sagemaker_session + ) + assert framework_estimator.latest_training_job.job_name == "neo" + assert framework_estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert framework_estimator.train_instance_count == 1 assert framework_estimator.train_max_run == 24 * 60 * 60 - assert framework_estimator.input_mode == 'File' - assert framework_estimator.base_job_name == 'neo' - assert framework_estimator.output_path == 's3://place/output/neo' - assert framework_estimator.output_kms_key == '' + assert framework_estimator.input_mode == "File" + assert framework_estimator.base_job_name == "neo" + assert framework_estimator.output_path == "s3://place/output/neo" + assert framework_estimator.output_kms_key == "" hyper_params = framework_estimator.hyperparameters() - assert hyper_params['training_steps'] == '100' - assert hyper_params['_tuning_objective_metric'] == '"Validation-accuracy"' - assert framework_estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert framework_estimator.entry_point == 'iris-dnn-classifier.py' + assert hyper_params["training_steps"] == "100" + assert hyper_params["_tuning_objective_metric"] == '"Validation-accuracy"' + assert framework_estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert framework_estimator.entry_point == "iris-dnn-classifier.py" assert framework_estimator.encrypt_inter_container_traffic is False def test_attach_framework_with_model_channel(sagemaker_session): - s3_uri = 's3://some/s3/path/model.tar.gz' + s3_uri = "s3://some/s3/path/model.tar.gz" returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description['InputDataConfig'] = [ + returned_job_description["InputDataConfig"] = [ { - 'ChannelName': 'model', - 'InputMode': 'File', - 'DataSource': { - 
'S3DataSource': { - 'S3Uri': s3_uri - } - } + "ChannelName": "model", + "InputMode": "File", + "DataSource": {"S3DataSource": {"S3Uri": s3_uri}}, } ] - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) - framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + framework_estimator = DummyFramework.attach( + training_job_name="neo", sagemaker_session=sagemaker_session + ) assert framework_estimator.model_uri is s3_uri assert framework_estimator.encrypt_inter_container_traffic is False def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session): returned_job_description = RETURNED_JOB_DESCRIPTION.copy() - returned_job_description['EnableInterContainerTrafficEncryption'] = True + returned_job_description["EnableInterContainerTrafficEncryption"] = True - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) - framework_estimator = DummyFramework.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + framework_estimator = DummyFramework.attach( + training_job_name="neo", sagemaker_session=sagemaker_session + ) assert framework_estimator.encrypt_inter_container_traffic is True -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_fit_verify_job_name(strftime, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True, tags=TAGS, - encrypt_inter_container_traffic=True) - fw.fit(inputs=s3_input('s3://mybucket/train')) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + tags=TAGS, + encrypt_inter_container_traffic=True, + ) + fw.fit(inputs=s3_input("s3://mybucket/train")) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs['hyperparameters']['sagemaker_enable_cloudwatch_metrics'] - assert train_kwargs['image'] == IMAGE_NAME - assert train_kwargs['input_mode'] == 'File' - assert train_kwargs['tags'] == TAGS - assert train_kwargs['job_name'] == JOB_NAME - assert train_kwargs['encrypt_inter_container_traffic'] is True + assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"] + assert train_kwargs["image"] == IMAGE_NAME + assert train_kwargs["input_mode"] == "File" + assert train_kwargs["tags"] == TAGS + assert train_kwargs["job_name"] == JOB_NAME + assert train_kwargs["encrypt_inter_container_traffic"] is True assert fw.latest_training_job.name == JOB_NAME def test_prepare_for_training_unique_job_name_generation(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=True) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + 
train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) fw._prepare_for_training() first_job_name = fw._current_job_name @@ -579,52 +731,76 @@ def test_prepare_for_training_unique_job_name_generation(sagemaker_session): def test_prepare_for_training_force_name(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - base_job_name='some', enable_cloudwatch_metrics=True) - fw._prepare_for_training(job_name='use_it') - assert 'use_it' == fw._current_job_name - - -@patch('time.strftime', return_value=TIMESTAMP) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + base_job_name="some", + enable_cloudwatch_metrics=True, + ) + fw._prepare_for_training(job_name="use_it") + assert "use_it" == fw._current_job_name + + +@patch("time.strftime", return_value=TIMESTAMP) def test_prepare_for_training_force_name_generation(strftime, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - base_job_name='some', enable_cloudwatch_metrics=True) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + base_job_name="some", + enable_cloudwatch_metrics=True, + ) fw.base_job_name = None fw._prepare_for_training() assert JOB_NAME == fw._current_job_name -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_init_with_source_dir_s3(strftime, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, source_dir='s3://location', role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - enable_cloudwatch_metrics=False) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + source_dir="s3://location", + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=False, + ) fw._prepare_for_training() expected_hyperparameters = { - 'sagemaker_program': SCRIPT_NAME, - 'sagemaker_job_name': JOB_NAME, - 'sagemaker_enable_cloudwatch_metrics': False, - 'sagemaker_container_log_level': logging.INFO, - 'sagemaker_submit_directory': 's3://location', - 'sagemaker_region': 'us-west-2', + "sagemaker_program": SCRIPT_NAME, + "sagemaker_job_name": JOB_NAME, + "sagemaker_enable_cloudwatch_metrics": False, + "sagemaker_container_log_level": logging.INFO, + "sagemaker_submit_directory": "s3://location", + "sagemaker_region": "us-west-2", } assert fw._hyperparameters == expected_hyperparameters -@patch('sagemaker.estimator.name_from_image', return_value=MODEL_IMAGE) +@patch("sagemaker.estimator.name_from_image", return_value=MODEL_IMAGE) def test_framework_transformer_creation(name_from_image, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) 
fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE) name_from_image.assert_called_with(MODEL_IMAGE) - sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None) + sagemaker_session.create_model.assert_called_with( + MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None + ) assert isinstance(transformer, Transformer) assert transformer.sagemaker_session == sagemaker_session @@ -635,31 +811,51 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session): assert transformer.env == {} -@patch('sagemaker.estimator.name_from_image', return_value=MODEL_IMAGE) +@patch("sagemaker.estimator.name_from_image", return_value=MODEL_IMAGE) def test_framework_transformer_creation_with_optional_params(name_from_image, sagemaker_session): - base_name = 'foo' - vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session, - base_job_name=base_name, - subnets=vpc_config['Subnets'], security_group_ids=vpc_config['SecurityGroupIds']) + base_name = "foo" + vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + base_job_name=base_name, + subnets=vpc_config["Subnets"], + security_group_ids=vpc_config["SecurityGroupIds"], + ) fw.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) - strategy = 'MultiRecord' - assemble_with = 'Line' - kms_key = 'key' - accept = 'text/csv' + strategy = "MultiRecord" + assemble_with = "Line" + kms_key = "key" + accept = "text/csv" max_concurrent_transforms = 1 max_payload = 6 - env = {'FOO': 'BAR'} - new_role = 'dummy-model-role' - - transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with, - output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1) + env = {"FOO": "BAR"} + new_role = "dummy-model-role" + + transformer = fw.transformer( + INSTANCE_COUNT, + INSTANCE_TYPE, + strategy=strategy, + assemble_with=assemble_with, + output_path=OUTPUT_PATH, + output_kms_key=kms_key, + accept=accept, + tags=TAGS, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + volume_kms_key=kms_key, + env=env, + role=new_role, + model_server_workers=1, + ) - sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS) + sagemaker_session.create_model.assert_called_with( + MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS + ) assert transformer.strategy == strategy assert transformer.assemble_with == assemble_with assert transformer.output_path == OUTPUT_PATH @@ -674,25 +870,40 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa def test_ensure_latest_training_job(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) - fw.latest_training_job = Mock(name='training_job') + fw = DummyFramework( + entry_point=SCRIPT_PATH, + 
role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) + fw.latest_training_job = Mock(name="training_job") fw._ensure_latest_training_job() def test_ensure_latest_training_job_failure(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) with pytest.raises(ValueError) as e: fw._ensure_latest_training_job() - assert 'Estimator is not associated with a training job' in str(e) + assert "Estimator is not associated with a training job" in str(e) def test_estimator_transformer_creation(sagemaker_session): - estimator = Estimator(image_name=IMAGE_NAME, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session) + estimator = Estimator( + image_name=IMAGE_NAME, + role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) sagemaker_session.create_model_from_job.return_value = JOB_NAME @@ -708,25 +919,40 @@ def test_estimator_transformer_creation(sagemaker_session): def test_estimator_transformer_creation_with_optional_params(sagemaker_session): - base_name = 'foo' - estimator = Estimator(image_name=IMAGE_NAME, role=ROLE, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, sagemaker_session=sagemaker_session, - base_job_name=base_name) + base_name = "foo" + estimator = Estimator( + image_name=IMAGE_NAME, + role=ROLE, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + base_job_name=base_name, + ) estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME) sagemaker_session.create_model_from_job.return_value = JOB_NAME - strategy = 'MultiRecord' - assemble_with = 'Line' - kms_key = 'key' - accept = 'text/csv' + strategy = "MultiRecord" + assemble_with = "Line" + kms_key = "key" + accept = "text/csv" max_concurrent_transforms = 1 max_payload = 6 - env = {'FOO': 'BAR'} - - transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with, - output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS, - max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=env, role=ROLE) + env = {"FOO": "BAR"} + + transformer = estimator.transformer( + INSTANCE_COUNT, + INSTANCE_TYPE, + strategy=strategy, + assemble_with=assemble_with, + output_path=OUTPUT_PATH, + output_kms_key=kms_key, + accept=accept, + tags=TAGS, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + env=env, + role=ROLE, + ) sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS) assert transformer.strategy == strategy @@ -744,97 +970,122 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session): # _TrainingJob 'utils' def test_start_new(sagemaker_session): training_job = _TrainingJob(sagemaker_session, JOB_NAME) - hyperparameters = {'mock': 'hyperparameters'} - inputs = 's3://mybucket/train' - - estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - 
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session, - hyperparameters=hyperparameters) + hyperparameters = {"mock": "hyperparameters"} + inputs = "s3://mybucket/train" + + estimator = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + hyperparameters=hyperparameters, + ) started_training_job = training_job.start_new(estimator, inputs) called_args = sagemaker_session.train.call_args assert started_training_job.sagemaker_session == sagemaker_session - assert called_args[1]['hyperparameters'] == hyperparameters + assert called_args[1]["hyperparameters"] == hyperparameters sagemaker_session.train.assert_called_once() def test_start_new_not_local_mode_error(sagemaker_session): training_job = _TrainingJob(sagemaker_session, JOB_NAME) - inputs = 'file://mybucket/train' - - estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) + inputs = "file://mybucket/train" + + estimator = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) with pytest.raises(ValueError) as error: training_job.start_new(estimator, inputs) - assert 'File URIs are supported in local mode only. Please use a S3 URI instead.' == str(error) + assert "File URIs are supported in local mode only. Please use a S3 URI instead." == str( + error + ) def test_container_log_level(sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - container_log_level=logging.DEBUG) - fw.fit(inputs=s3_input('s3://mybucket/train')) + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + container_log_level=logging.DEBUG, + ) + fw.fit(inputs=s3_input("s3://mybucket/train")) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] - assert train_kwargs['hyperparameters']['sagemaker_container_log_level'] == '10' + assert train_kwargs["hyperparameters"]["sagemaker_container_log_level"] == "10" -@patch('sagemaker.utils') +@patch("sagemaker.utils") def test_same_code_location_keeps_kms_key(utils, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, - role='DummyRole', - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - output_kms_key='kms-key') + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + output_kms_key="kms-key", + ) fw.fit(wait=False) - extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'} - obj = sagemaker_session.boto_session.resource('s3').Object + extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "kms-key"} + obj = sagemaker_session.boto_session.resource("s3").Object - obj.assert_called_with('mybucket', '%s/source/sourcedir.tar.gz' % fw._current_job_name) + obj.assert_called_with("mybucket", "%s/source/sourcedir.tar.gz" % fw._current_job_name) obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) -@patch('sagemaker.utils') +@patch("sagemaker.utils") def test_different_code_location_kms_key(utils, sagemaker_session): - fw = 
DummyFramework(entry_point=SCRIPT_PATH, - role='DummyRole', - sagemaker_session=sagemaker_session, - code_location='s3://another-location', - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - output_kms_key='kms-key') + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + code_location="s3://another-location", + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + output_kms_key="kms-key", + ) fw.fit(wait=False) - obj = sagemaker_session.boto_session.resource('s3').Object + obj = sagemaker_session.boto_session.resource("s3").Object - obj.assert_called_with('another-location', '%s/source/sourcedir.tar.gz' % fw._current_job_name) + obj.assert_called_with("another-location", "%s/source/sourcedir.tar.gz" % fw._current_job_name) obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None) -@patch('sagemaker.utils') +@patch("sagemaker.utils") def test_default_code_location_uses_output_path(utils, sagemaker_session): - fw = DummyFramework(entry_point=SCRIPT_PATH, - role='DummyRole', - sagemaker_session=sagemaker_session, - output_path='s3://output_path', - train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, - output_kms_key='kms-key') + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + output_path="s3://output_path", + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + output_kms_key="kms-key", + ) fw.fit(wait=False) - obj = sagemaker_session.boto_session.resource('s3').Object + obj = sagemaker_session.boto_session.resource("s3").Object - obj.assert_called_with('output_path', '%s/source/sourcedir.tar.gz' % fw._current_job_name) + obj.assert_called_with("output_path", "%s/source/sourcedir.tar.gz" % fw._current_job_name) - extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'} + extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": "kms-key"} obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) @@ -858,163 +1109,203 @@ def test_wait_with_logs(sagemaker_session): def test_unsupported_type_in_dict(): with pytest.raises(ValueError): - _TrainingJob._format_inputs_to_input_config({'a': 66}) + _TrainingJob._format_inputs_to_input_config({"a": 66}) ################################################################################# # Tests for the generic Estimator class NO_INPUT_TRAIN_CALL = { - 'hyperparameters': {}, - 'image': IMAGE_NAME, - 'input_config': None, - 'input_mode': 'File', - 'output_config': {'S3OutputPath': OUTPUT_PATH}, - 'resource_config': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeSizeInGB': 30 + "hyperparameters": {}, + "image": IMAGE_NAME, + "input_config": None, + "input_mode": "File", + "output_config": {"S3OutputPath": OUTPUT_PATH}, + "resource_config": { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": 30, }, - 'stop_condition': {'MaxRuntimeInSeconds': 86400}, - 'tags': None, - 'vpc_config': None, - 'metric_definitions': None + "stop_condition": {"MaxRuntimeInSeconds": 86400}, + "tags": None, + "vpc_config": None, + "metric_definitions": None, } -INPUT_CONFIG = [{ - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://bucket/training-prefix' - } - }, - 'ChannelName': 'train' -}] +INPUT_CONFIG = [ + { + "DataSource": { + "S3DataSource": { + 
"S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/training-prefix", + } + }, + "ChannelName": "train", + } +] BASE_TRAIN_CALL = dict(NO_INPUT_TRAIN_CALL) -BASE_TRAIN_CALL.update({'input_config': INPUT_CONFIG}) +BASE_TRAIN_CALL.update({"input_config": INPUT_CONFIG}) -HYPERPARAMS = {'x': 1, 'y': 'hello'} +HYPERPARAMS = {"x": 1, "y": "hello"} STRINGIFIED_HYPERPARAMS = dict([(x, str(y)) for x, y in HYPERPARAMS.items()]) HP_TRAIN_CALL = dict(BASE_TRAIN_CALL) -HP_TRAIN_CALL.update({'hyperparameters': STRINGIFIED_HYPERPARAMS}) +HP_TRAIN_CALL.update({"hyperparameters": STRINGIFIED_HYPERPARAMS}) def test_fit_deploy_keep_tags(sagemaker_session): - tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] - estimator = Estimator(IMAGE_NAME, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - tags=tags, - sagemaker_session=sagemaker_session) + tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}] + estimator = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + tags=tags, + sagemaker_session=sagemaker_session, + ) estimator.fit() estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE) - variant = [{'InstanceType': 'c4.4xlarge', 'VariantName': 'AllTraffic', - 'ModelName': ANY, 'InitialVariantWeight': 1, - 'InitialInstanceCount': 1}] + variant = [ + { + "InstanceType": "c4.4xlarge", + "VariantName": "AllTraffic", + "ModelName": ANY, + "InitialVariantWeight": 1, + "InitialInstanceCount": 1, + } + ] job_name = estimator._current_job_name - sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name, - variant, - tags, - None, - True) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + job_name, variant, tags, None, True + ) sagemaker_session.create_model.assert_called_with( ANY, - 'DummyRole', - {'ModelDataUrl': 's3://bucket/model.tar.gz', 'Environment': {}, 'Image': 'fakeimage'}, + "DummyRole", + {"ModelDataUrl": "s3://bucket/model.tar.gz", "Environment": {}, "Image": "fakeimage"}, enable_network_isolation=False, vpc_config=None, - tags=tags) + tags=tags, + ) def test_generic_to_fit_no_input(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) e.fit() sagemaker_session.train.assert_called_once() assert len(sagemaker_session.train.call_args[0]) == 0 args = sagemaker_session.train.call_args[1] - assert args['job_name'].startswith(IMAGE_NAME) + assert args["job_name"].startswith(IMAGE_NAME) - args.pop('job_name') - args.pop('role') + args.pop("job_name") + args.pop("role") assert args == NO_INPUT_TRAIN_CALL def test_generic_to_fit_no_hps(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) - e.fit({'train': 's3://bucket/training-prefix'}) + e.fit({"train": "s3://bucket/training-prefix"}) sagemaker_session.train.assert_called_once() assert len(sagemaker_session.train.call_args[0]) == 0 args = sagemaker_session.train.call_args[1] - assert args['job_name'].startswith(IMAGE_NAME) + assert args["job_name"].startswith(IMAGE_NAME) - args.pop('job_name') - args.pop('role') + args.pop("job_name") + args.pop("role") assert args == BASE_TRAIN_CALL def 
test_generic_to_fit_with_hps(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) e.set_hyperparameters(**HYPERPARAMS) - e.fit({'train': 's3://bucket/training-prefix'}) + e.fit({"train": "s3://bucket/training-prefix"}) sagemaker_session.train.assert_called_once() assert len(sagemaker_session.train.call_args[0]) == 0 args = sagemaker_session.train.call_args[1] - assert args['job_name'].startswith(IMAGE_NAME) + assert args["job_name"].startswith(IMAGE_NAME) - args.pop('job_name') - args.pop('role') + args.pop("job_name") + args.pop("role") assert args == HP_TRAIN_CALL def test_generic_to_fit_with_encrypt_inter_container_traffic_flag(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session, encrypt_inter_container_traffic=True) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + encrypt_inter_container_traffic=True, + ) e.fit() sagemaker_session.train.assert_called_once() args = sagemaker_session.train.call_args[1] - assert args['encrypt_inter_container_traffic'] is True + assert args["encrypt_inter_container_traffic"] is True def test_generic_to_deploy(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) e.set_hyperparameters(**HYPERPARAMS) - e.fit({'train': 's3://bucket/training-prefix'}) + e.fit({"train": "s3://bucket/training-prefix"}) predictor = e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) sagemaker_session.train.assert_called_once() assert len(sagemaker_session.train.call_args[0]) == 0 args = sagemaker_session.train.call_args[1] - assert args['job_name'].startswith(IMAGE_NAME) + assert args["job_name"].startswith(IMAGE_NAME) - args.pop('job_name') - args.pop('role') + args.pop("job_name") + args.pop("role") assert args == HP_TRAIN_CALL @@ -1022,9 +1313,9 @@ def test_generic_to_deploy(sagemaker_session): args, kwargs = sagemaker_session.create_model.call_args assert args[0].startswith(IMAGE_NAME) assert args[1] == ROLE - assert args[2]['Image'] == IMAGE_NAME - assert args[2]['ModelDataUrl'] == MODEL_DATA - assert kwargs['vpc_config'] is None + assert args[2]["Image"] == IMAGE_NAME + assert args[2]["ModelDataUrl"] == MODEL_DATA + assert kwargs["vpc_config"] is None assert isinstance(predictor, RealTimePredictor) assert predictor.endpoint.startswith(IMAGE_NAME) @@ -1032,27 +1323,30 @@ def test_generic_to_deploy(sagemaker_session): def test_generic_training_job_analytics(sagemaker_session): - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={ - 'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner', - 'TrainingStartTime': 1530562991.299, - "AlgorithmSpecification": { - "TrainingImage": "some-image-url", - "TrainingInputMode": "File", - "MetricDefinitions": [ - { - "Name": "train:loss", - "Regex": "train_loss=([0-9]+\\.[0-9]+)" - }, - { - "Name": "validation:loss", - "Regex": "valid_loss=([0-9]+\\.[0-9]+)" - } - ] + 
sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", + return_value={ + "TuningJobArn": "arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner", + "TrainingStartTime": 1530562991.299, + "AlgorithmSpecification": { + "TrainingImage": "some-image-url", + "TrainingInputMode": "File", + "MetricDefinitions": [ + {"Name": "train:loss", "Regex": "train_loss=([0-9]+\\.[0-9]+)"}, + {"Name": "validation:loss", "Regex": "valid_loss=([0-9]+\\.[0-9]+)"}, + ], + }, }, - }) + ) - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + e = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) with pytest.raises(ValueError) as err: # noqa: F841 # No training job yet @@ -1060,78 +1354,89 @@ def test_generic_training_job_analytics(sagemaker_session): assert a is not None # This line is never reached e.set_hyperparameters(**HYPERPARAMS) - e.fit({'train': 's3://bucket/training-prefix'}) + e.fit({"train": "s3://bucket/training-prefix"}) a = e.training_job_analytics assert a is not None def test_generic_create_model_vpc_config_override(sagemaker_session): - vpc_config_a = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - vpc_config_b = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']} + vpc_config_a = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - sagemaker_session=sagemaker_session) - e.fit({'train': 's3://bucket/training-prefix'}) + e = Estimator( + IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + ) + e.fit({"train": "s3://bucket/training-prefix"}) assert e.get_vpc_config() is None assert e.create_model().vpc_config is None assert e.create_model(vpc_config_override=vpc_config_a).vpc_config == vpc_config_a assert e.create_model(vpc_config_override=None).vpc_config is None - e.subnets = vpc_config_a['Subnets'] - e.security_group_ids = vpc_config_a['SecurityGroupIds'] + e.subnets = vpc_config_a["Subnets"] + e.security_group_ids = vpc_config_a["SecurityGroupIds"] assert e.get_vpc_config() == vpc_config_a assert e.create_model().vpc_config == vpc_config_a assert e.create_model(vpc_config_override=vpc_config_b).vpc_config == vpc_config_b assert e.create_model(vpc_config_override=None).vpc_config is None with pytest.raises(ValueError): - e.get_vpc_config(vpc_config_override={'invalid'}) + e.get_vpc_config(vpc_config_override={"invalid"}) with pytest.raises(ValueError): - e.create_model(vpc_config_override={'invalid'}) + e.create_model(vpc_config_override={"invalid"}) def test_generic_deploy_vpc_config_override(sagemaker_session): - vpc_config_a = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - vpc_config_b = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']} + vpc_config_a = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - sagemaker_session=sagemaker_session) - e.fit({'train': 's3://bucket/training-prefix'}) + e = Estimator( + IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + ) + e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) - assert 
sagemaker_session.create_model.call_args_list[0][1]['vpc_config'] is None + assert sagemaker_session.create_model.call_args_list[0][1]["vpc_config"] is None - e.subnets = vpc_config_a['Subnets'] - e.security_group_ids = vpc_config_a['SecurityGroupIds'] + e.subnets = vpc_config_a["Subnets"] + e.security_group_ids = vpc_config_a["SecurityGroupIds"] e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) - assert sagemaker_session.create_model.call_args_list[1][1]['vpc_config'] == vpc_config_a + assert sagemaker_session.create_model.call_args_list[1][1]["vpc_config"] == vpc_config_a e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override=vpc_config_b) - assert sagemaker_session.create_model.call_args_list[2][1]['vpc_config'] == vpc_config_b + assert sagemaker_session.create_model.call_args_list[2][1]["vpc_config"] == vpc_config_b e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, vpc_config_override=None) - assert sagemaker_session.create_model.call_args_list[3][1]['vpc_config'] is None + assert sagemaker_session.create_model.call_args_list[3][1]["vpc_config"] is None def test_generic_deploy_accelerator_type(sagemaker_session): - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, - sagemaker_session=sagemaker_session) - e.fit({'train': 's3://bucket/training-prefix'}) + e = Estimator( + IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + ) + e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, ACCELERATOR_TYPE) args = e.sagemaker_session.endpoint_from_production_variants.call_args[0] assert args[0].startswith(IMAGE_NAME) - assert args[1][0]['AcceleratorType'] == ACCELERATOR_TYPE - assert args[1][0]['InitialInstanceCount'] == INSTANCE_COUNT - assert args[1][0]['InstanceType'] == INSTANCE_TYPE + assert args[1][0]["AcceleratorType"] == ACCELERATOR_TYPE + assert args[1][0]["InitialInstanceCount"] == INSTANCE_COUNT + assert args[1][0]["InstanceType"] == INSTANCE_TYPE def test_deploy_with_update_endpoint(sagemaker_session): - estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH, - sagemaker_session=sagemaker_session) + estimator = Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) estimator.set_hyperparameters(**HYPERPARAMS) - estimator.fit({'train': 's3://bucket/training-prefix'}) - endpoint_name = 'endpoint-name' - estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name, update_endpoint=True) + estimator.fit({"train": "s3://bucket/training-prefix"}) + endpoint_name = "endpoint-name" + estimator.deploy( + INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name, update_endpoint=True + ) update_endpoint_args = sagemaker_session.update_endpoint.call_args[0] assert update_endpoint_args[0] == endpoint_name @@ -1140,8 +1445,8 @@ def test_deploy_with_update_endpoint(sagemaker_session): sagemaker_session.create_endpoint.assert_not_called() -@patch('sagemaker.estimator.LocalSession') -@patch('sagemaker.estimator.Session') +@patch("sagemaker.estimator.LocalSession") +@patch("sagemaker.estimator.Session") def test_local_mode(session_class, local_session_class): local_session = Mock() local_session.local_mode = True @@ -1152,128 +1457,158 @@ def test_local_mode(session_class, local_session_class): local_session_class.return_value = local_session session_class.return_value = session - e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local') + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, "local") 
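+    # the "local" / "local_gpu" instance types should hand back the mocked LocalSession,
+    # while a real instance type should fall through to the regular (mocked) Session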
-    print(e.sagemaker_session.local_mode)
     assert e.sagemaker_session.local_mode is True

-    e2 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local_gpu')
+    e2 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, "local_gpu")
     assert e2.sagemaker_session.local_mode is True

     e3 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE)
     assert e3.sagemaker_session.local_mode is False


-@patch('sagemaker.estimator.LocalSession')
+@patch("sagemaker.estimator.LocalSession")
 def test_distributed_gpu_local_mode(LocalSession):
     with pytest.raises(RuntimeError):
-        Estimator(IMAGE_NAME, ROLE, 3, 'local_gpu', output_path=OUTPUT_PATH)
+        Estimator(IMAGE_NAME, ROLE, 3, "local_gpu", output_path=OUTPUT_PATH)


-@patch('sagemaker.estimator.LocalSession')
+@patch("sagemaker.estimator.LocalSession")
 def test_local_mode_file_output_path(local_session_class):
     local_session = Mock()
     local_session.local_mode = True
     local_session_class.return_value = local_session

-    e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local', output_path='file:///tmp/model/')
-    assert e.output_path == 'file:///tmp/model/'
+    e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, "local", output_path="file:///tmp/model/")
+    assert e.output_path == "file:///tmp/model/"


-@patch('sagemaker.estimator.Session')
+@patch("sagemaker.estimator.Session")
 def test_file_output_path_not_supported_outside_local_mode(session_class):
     session = Mock()
     session.local_mode = False
     session_class.return_value = session

     with pytest.raises(RuntimeError):
-        Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='file:///tmp/model')
+        Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path="file:///tmp/model")


 def test_prepare_init_params_from_job_description_with_image_training_job():
     init_params = EstimatorBase._prepare_init_params_from_job_description(
-        job_details=RETURNED_JOB_DESCRIPTION)
+        job_details=RETURNED_JOB_DESCRIPTION
+    )

-    assert init_params['role'] == 'arn:aws:iam::366:role/SageMakerRole'
-    assert init_params['train_instance_count'] == 1
-    assert init_params['image'] == '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4'
+    assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
+    assert init_params["train_instance_count"] == 1
+    assert init_params["image"] == "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other-py2-cpu:1.0.4"


 def test_prepare_init_params_from_job_description_with_algorithm_training_job():
     algorithm_job_description = RETURNED_JOB_DESCRIPTION.copy()
-    algorithm_job_description['AlgorithmSpecification'] = {
-        'TrainingInputMode': 'File',
-        'AlgorithmName': 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees'
+    algorithm_job_description["AlgorithmSpecification"] = {
+        "TrainingInputMode": "File",
+        "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
     }

     init_params = EstimatorBase._prepare_init_params_from_job_description(
-        job_details=algorithm_job_description)
+        job_details=algorithm_job_description
+    )

-    assert init_params['role'] == 'arn:aws:iam::366:role/SageMakerRole'
-    assert init_params['train_instance_count'] == 1
-    assert init_params['algorithm_arn'] == 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees'
+    assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
+    assert init_params["train_instance_count"] == 1
+    assert (
+        init_params["algorithm_arn"]
+        == "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees"
+    )


 def test_prepare_init_params_from_job_description_with_invalid_training_job():
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy() - invalid_job_description['AlgorithmSpecification'] = { - 'TrainingInputMode': 'File', - } + invalid_job_description["AlgorithmSpecification"] = {"TrainingInputMode": "File"} with pytest.raises(RuntimeError) as error: EstimatorBase._prepare_init_params_from_job_description(job_details=invalid_job_description) - assert 'Invalid AlgorithmSpecification' in str(error) + assert "Invalid AlgorithmSpecification" in str(error) def test_prepare_for_training_with_base_name(sagemaker_session): - estimator = Estimator(image_name='some-image', role='some_image', train_instance_count=1, - train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session, - base_job_name='base_job_name') + estimator = Estimator( + image_name="some-image", + role="some_image", + train_instance_count=1, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) estimator._prepare_for_training() - assert 'base_job_name' in estimator._current_job_name + assert "base_job_name" in estimator._current_job_name def test_prepare_for_training_with_name_based_on_image(sagemaker_session): - estimator = Estimator(image_name='some-image', role='some_image', train_instance_count=1, - train_instance_type='ml.m4.xlarge', sagemaker_session=sagemaker_session) + estimator = Estimator( + image_name="some-image", + role="some_image", + train_instance_count=1, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + ) estimator._prepare_for_training() - assert 'some-image' in estimator._current_job_name + assert "some-image" in estimator._current_job_name -@patch('sagemaker.algorithm.AlgorithmEstimator.validate_train_spec', Mock()) -@patch('sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters', Mock(return_value={})) +@patch("sagemaker.algorithm.AlgorithmEstimator.validate_train_spec", Mock()) +@patch("sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", Mock(return_value={})) def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): estimator = AlgorithmEstimator( - algorithm_arn='arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022', - role='some_image', train_instance_count=1, train_instance_type='ml.m4.xlarge', - sagemaker_session=sagemaker_session) + algorithm_arn="arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022", + role="some_image", + train_instance_count=1, + train_instance_type="ml.m4.xlarge", + sagemaker_session=sagemaker_session, + ) estimator._prepare_for_training() - assert 'scikit-decision-trees-1542410022' in estimator._current_job_name - - -@patch('sagemaker.estimator.Estimator.fit', - Mock(side_effect=ClientError(error_response={"Error": { - "Code": 403, - "Message": - '"EnableInterContainerTrafficEncryption" and ' - '"VpcConfig" must be provided together'}}, - operation_name='Unit Test'))) + assert "scikit-decision-trees-1542410022" in estimator._current_job_name + + +@patch( + "sagemaker.estimator.Estimator.fit", + Mock( + side_effect=ClientError( + error_response={ + "Error": { + "Code": 403, + "Message": '"EnableInterContainerTrafficEncryption" and ' + '"VpcConfig" must be provided together', + } + }, + operation_name="Unit Test", + ) + ), +) def test_encryption_flag_in_non_vpc_mode_invalid(sagemaker_session): image_name = registry("us-west-2") + "/factorization-machines:1" with pytest.raises(ClientError) as error: - estimator = Estimator(image_name=image_name, - role='SageMakerRole', 
train_instance_count=1, - train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session, - base_job_name='test-non-vpc-encryption', - encrypt_inter_container_traffic=True) + estimator = Estimator( + image_name=image_name, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + base_job_name="test-non-vpc-encryption", + encrypt_inter_container_traffic=True, + ) estimator.fit() - assert '"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together' in str(error) + assert ( + '"EnableInterContainerTrafficEncryption" and "VpcConfig" must be provided together' + in str(error) + ) ################################################################################# diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index d420301378..1f892dc444 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -15,119 +15,141 @@ import pytest from mock import Mock, patch -from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesPredictor +from sagemaker.amazon.factorization_machines import ( + FactorizationMachines, + FactorizationMachinesPredictor, +) from sagemaker.amazon.amazon_estimator import registry, RecordSet -ROLE = 'myrole' +ROLE = "myrole" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" NUM_FACTORS = 3 -PREDICTOR_TYPE = 'regressor' +PREDICTOR_TYPE = "regressor" -COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, - 'train_instance_type': TRAIN_INSTANCE_TYPE} -ALL_REQ_ARGS = dict({'num_factors': NUM_FACTORS, 'predictor_type': PREDICTOR_TYPE}, **COMMON_TRAIN_ARGS) +COMMON_TRAIN_ARGS = { + "role": ROLE, + "train_instance_count": TRAIN_INSTANCE_COUNT, + "train_instance_type": TRAIN_INSTANCE_TYPE, +} +ALL_REQ_ARGS = dict( + {"num_factors": NUM_FACTORS, "predictor_type": PREDICTOR_TYPE}, **COMMON_TRAIN_ARGS +) -REGION = 'us-west-2' -BUCKET_NAME = 'Some-Bucket' +REGION = "us-west-2" +BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': 's3://bucket/model.tar.gz' - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) 
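+    # the describe_endpoint / describe_endpoint_config stubs return the canned
+    # ENDPOINT_DESC and ENDPOINT_CONFIG_DESC defined above, so predictor lookups
+    # against this session resolve endpoint and model names without any AWS calls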
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return sms def test_init_required_positional(sagemaker_session): - fm = FactorizationMachines('myrole', 1, 'ml.c4.xlarge', 3, 'regressor', - sagemaker_session=sagemaker_session) - assert fm.role == 'myrole' + fm = FactorizationMachines( + "myrole", 1, "ml.c4.xlarge", 3, "regressor", sagemaker_session=sagemaker_session + ) + assert fm.role == "myrole" assert fm.train_instance_count == 1 - assert fm.train_instance_type == 'ml.c4.xlarge' + assert fm.train_instance_type == "ml.c4.xlarge" assert fm.num_factors == 3 - assert fm.predictor_type == 'regressor' + assert fm.predictor_type == "regressor" def test_init_required_named(sagemaker_session): fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert fm.role == COMMON_TRAIN_ARGS['role'] - assert fm.train_instance_count == COMMON_TRAIN_ARGS['train_instance_count'] - assert fm.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert fm.num_factors == ALL_REQ_ARGS['num_factors'] - assert fm.predictor_type == ALL_REQ_ARGS['predictor_type'] + assert fm.role == COMMON_TRAIN_ARGS["role"] + assert fm.train_instance_count == COMMON_TRAIN_ARGS["train_instance_count"] + assert fm.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert fm.num_factors == ALL_REQ_ARGS["num_factors"] + assert fm.predictor_type == ALL_REQ_ARGS["predictor_type"] def test_all_hyperparameters(sagemaker_session): - fm = FactorizationMachines(sagemaker_session=sagemaker_session, - epochs=2, clip_gradient=1e2, eps=0.001, rescale_grad=2.2, - bias_lr=0.01, linear_lr=0.002, factors_lr=0.0003, - bias_wd=0.0004, linear_wd=1.01, factors_wd=1.002, - bias_init_method='uniform', bias_init_scale=0.1, bias_init_sigma=0.05, - bias_init_value=2.002, linear_init_method='constant', linear_init_scale=0.02, - linear_init_sigma=0.003, linear_init_value=1.0, factors_init_method='normal', - factors_init_scale=1.101, factors_init_sigma=1.202, factors_init_value=1.303, - **ALL_REQ_ARGS) + fm = FactorizationMachines( + sagemaker_session=sagemaker_session, + epochs=2, + clip_gradient=1e2, + eps=0.001, + rescale_grad=2.2, + bias_lr=0.01, + linear_lr=0.002, + factors_lr=0.0003, + bias_wd=0.0004, + linear_wd=1.01, + factors_wd=1.002, + bias_init_method="uniform", + bias_init_scale=0.1, + bias_init_sigma=0.05, + bias_init_value=2.002, + linear_init_method="constant", + linear_init_scale=0.02, + linear_init_sigma=0.003, + linear_init_value=1.0, + factors_init_method="normal", + factors_init_scale=1.101, + factors_init_sigma=1.202, + factors_init_value=1.303, + **ALL_REQ_ARGS + ) assert fm.hyperparameters() == dict( - num_factors=str(ALL_REQ_ARGS['num_factors']), - predictor_type=ALL_REQ_ARGS['predictor_type'], - epochs='2', - clip_gradient='100.0', - eps='0.001', - rescale_grad='2.2', - bias_lr='0.01', - linear_lr='0.002', - factors_lr='0.0003', - bias_wd='0.0004', - linear_wd='1.01', - factors_wd='1.002', - bias_init_method='uniform', - bias_init_scale='0.1', - bias_init_sigma='0.05', - bias_init_value='2.002', - linear_init_method='constant', - linear_init_scale='0.02', - linear_init_sigma='0.003', - linear_init_value='1.0', - factors_init_method='normal', - factors_init_scale='1.101', - factors_init_sigma='1.202', - factors_init_value='1.303', + num_factors=str(ALL_REQ_ARGS["num_factors"]), + predictor_type=ALL_REQ_ARGS["predictor_type"], + epochs="2", + clip_gradient="100.0", + eps="0.001", + rescale_grad="2.2", + bias_lr="0.01", + linear_lr="0.002", + 
factors_lr="0.0003", + bias_wd="0.0004", + linear_wd="1.01", + factors_wd="1.002", + bias_init_method="uniform", + bias_init_scale="0.1", + bias_init_sigma="0.05", + bias_init_value="2.002", + linear_init_method="constant", + linear_init_scale="0.02", + linear_init_sigma="0.003", + linear_init_value="1.0", + factors_init_method="normal", + factors_init_scale="1.101", + factors_init_sigma="1.202", + factors_init_value="1.303", ) def test_image(sagemaker_session): fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert fm.train_image() == registry(REGION) + '/factorization-machines:1' + assert fm.train_image() == registry(REGION) + "/factorization-machines:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_factors', 'string'), - ('predictor_type', 0) -]) +@pytest.mark.parametrize( + "required_hyper_parameters, value", [("num_factors", "string"), ("predictor_type", 0)] +) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -135,10 +157,9 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame FactorizationMachines(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_factors', 0), - ('predictor_type', 'string') -]) +@pytest.mark.parametrize( + "required_hyper_parameters, value", [("num_factors", 0), ("predictor_type", "string")] +) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -146,30 +167,33 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param FactorizationMachines(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 'string'), - ('clip_gradient', 'string'), - ('eps', 'string'), - ('rescale_grad', 'string'), - ('bias_lr', 'string'), - ('linear_lr', 'string'), - ('factors_lr', 'string'), - ('bias_wd', 'string'), - ('linear_wd', 'string'), - ('factors_wd', 'string'), - ('bias_init_method', 0), - ('bias_init_scale', 'string'), - ('bias_init_sigma', 'string'), - ('bias_init_value', 'string'), - ('linear_init_method', 0), - ('linear_init_scale', 'string'), - ('linear_init_sigma', 'string'), - ('linear_init_value', 'string'), - ('factors_init_method', 0), - ('factors_init_scale', 'string'), - ('factors_init_sigma', 'string'), - ('factors_init_value', 'string') -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("epochs", "string"), + ("clip_gradient", "string"), + ("eps", "string"), + ("rescale_grad", "string"), + ("bias_lr", "string"), + ("linear_lr", "string"), + ("factors_lr", "string"), + ("bias_wd", "string"), + ("linear_wd", "string"), + ("factors_wd", "string"), + ("bias_init_method", 0), + ("bias_init_scale", "string"), + ("bias_init_sigma", "string"), + ("bias_init_value", "string"), + ("linear_init_method", 0), + ("linear_init_scale", "string"), + ("linear_init_sigma", "string"), + ("linear_init_value", "string"), + ("factors_init_method", 0), + ("factors_init_scale", "string"), + ("factors_init_sigma", "string"), + ("factors_init_value", "string"), + ], +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -177,24 +201,27 @@ def 
test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame FactorizationMachines(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 0), - ('bias_lr', -1), - ('linear_lr', -1), - ('factors_lr', -1), - ('bias_wd', -1), - ('linear_wd', -1), - ('factors_wd', -1), - ('bias_init_method', 'string'), - ('bias_init_scale', -1), - ('bias_init_sigma', -1), - ('linear_init_method', 'string'), - ('linear_init_scale', -1), - ('linear_init_sigma', -1), - ('factors_init_method', 'string'), - ('factors_init_scale', -1), - ('factors_init_sigma', -1) -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("epochs", 0), + ("bias_lr", -1), + ("linear_lr", -1), + ("factors_lr", -1), + ("bias_wd", -1), + ("linear_wd", -1), + ("factors_wd", -1), + ("bias_init_method", "string"), + ("bias_init_scale", -1), + ("bias_init_sigma", -1), + ("linear_init_method", "string"), + ("linear_init_scale", -1), + ("linear_init_sigma", -1), + ("factors_init_method", "string"), + ("factors_init_scale", -1), + ("factors_init_sigma", -1), + ], +) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -202,16 +229,23 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param FactorizationMachines(sagemaker_session=sagemaker_session, **test_params) -PREFIX = 'prefix' +PREFIX = "prefix" FEATURE_DIM = 10 MINI_BATCH_SIZE = 200 -@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit') +@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") def test_call_fit(base_fit, sagemaker_session): - fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + fm = FactorizationMachines( + base_job_name="fm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) fm.fit(data, MINI_BATCH_SIZE) @@ -222,44 +256,72 @@ def test_call_fit(base_fit, sagemaker_session): def test_prepare_for_training_no_mini_batch_size(sagemaker_session): - fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + fm = FactorizationMachines( + base_job_name="fm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) fm._prepare_for_training(data) def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): - fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + fm = FactorizationMachines( + base_job_name="fm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises((TypeError, ValueError)): - fm._prepare_for_training(data, 'some') + fm._prepare_for_training(data, "some") def 
test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session): - fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + fm = FactorizationMachines( + base_job_name="fm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises(ValueError): fm._prepare_for_training(data, 0) def test_model_image(sagemaker_session): fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) fm.fit(data, MINI_BATCH_SIZE) model = fm.create_model() - assert model.image == registry(REGION, 'factorization-machines') + '/factorization-machines:1' + assert model.image == registry(REGION, "factorization-machines") + "/factorization-machines:1" def test_predictor_type(sagemaker_session): fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) fm.fit(data, MINI_BATCH_SIZE) model = fm.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_fw_registry.py b/tests/unit/test_fw_registry.py index 1bdd042164..1b1b47884f 100644 --- a/tests/unit/test_fw_registry.py +++ b/tests/unit/test_fw_registry.py @@ -22,66 +22,148 @@ def test_registry_sparkml_serving(): - assert registry('us-west-1', 'sparkml-serving') == "746614075791.dkr.ecr.us-west-1.amazonaws.com" - assert registry('us-west-2', 'sparkml-serving') == "246618743249.dkr.ecr.us-west-2.amazonaws.com" - assert registry('us-east-1', 'sparkml-serving') == "683313688378.dkr.ecr.us-east-1.amazonaws.com" - assert registry('us-east-2', 'sparkml-serving') == "257758044811.dkr.ecr.us-east-2.amazonaws.com" - assert registry('ap-northeast-1', 'sparkml-serving') == "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com" - assert registry('ap-northeast-2', 'sparkml-serving') == "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com" - assert registry('ap-southeast-1', 'sparkml-serving') == "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com" - assert registry('ap-southeast-2', 'sparkml-serving') == "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com" - assert registry('ap-south-1', 'sparkml-serving') == "720646828776.dkr.ecr.ap-south-1.amazonaws.com" - assert registry('eu-west-1', 'sparkml-serving') == "141502667606.dkr.ecr.eu-west-1.amazonaws.com" - assert registry('eu-west-2', 'sparkml-serving') == "764974769150.dkr.ecr.eu-west-2.amazonaws.com" - assert registry('eu-central-1', 'sparkml-serving') == "492215442770.dkr.ecr.eu-central-1.amazonaws.com" - assert registry('ca-central-1', 'sparkml-serving') == "341280168497.dkr.ecr.ca-central-1.amazonaws.com" - assert registry('us-gov-west-1', 'sparkml-serving') == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com" - assert registry('us-iso-east-1', 'sparkml-serving') == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov" + assert ( + registry("us-west-1", 
"sparkml-serving") == "746614075791.dkr.ecr.us-west-1.amazonaws.com" + ) + assert ( + registry("us-west-2", "sparkml-serving") == "246618743249.dkr.ecr.us-west-2.amazonaws.com" + ) + assert ( + registry("us-east-1", "sparkml-serving") == "683313688378.dkr.ecr.us-east-1.amazonaws.com" + ) + assert ( + registry("us-east-2", "sparkml-serving") == "257758044811.dkr.ecr.us-east-2.amazonaws.com" + ) + assert ( + registry("ap-northeast-1", "sparkml-serving") + == "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com" + ) + assert ( + registry("ap-northeast-2", "sparkml-serving") + == "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com" + ) + assert ( + registry("ap-southeast-1", "sparkml-serving") + == "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com" + ) + assert ( + registry("ap-southeast-2", "sparkml-serving") + == "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com" + ) + assert ( + registry("ap-south-1", "sparkml-serving") == "720646828776.dkr.ecr.ap-south-1.amazonaws.com" + ) + assert ( + registry("eu-west-1", "sparkml-serving") == "141502667606.dkr.ecr.eu-west-1.amazonaws.com" + ) + assert ( + registry("eu-west-2", "sparkml-serving") == "764974769150.dkr.ecr.eu-west-2.amazonaws.com" + ) + assert ( + registry("eu-central-1", "sparkml-serving") + == "492215442770.dkr.ecr.eu-central-1.amazonaws.com" + ) + assert ( + registry("ca-central-1", "sparkml-serving") + == "341280168497.dkr.ecr.ca-central-1.amazonaws.com" + ) + assert ( + registry("us-gov-west-1", "sparkml-serving") + == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com" + ) + assert ( + registry("us-iso-east-1", "sparkml-serving") + == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov" + ) def test_registry_sklearn(): - assert registry('us-west-1', scikit_learn_framework_name) == "746614075791.dkr.ecr.us-west-1.amazonaws.com" - assert registry('us-west-2', scikit_learn_framework_name) == "246618743249.dkr.ecr.us-west-2.amazonaws.com" - assert registry('us-east-1', scikit_learn_framework_name) == "683313688378.dkr.ecr.us-east-1.amazonaws.com" - assert registry('us-east-2', scikit_learn_framework_name) == "257758044811.dkr.ecr.us-east-2.amazonaws.com" - assert registry('ap-northeast-1', scikit_learn_framework_name) == \ - "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com" - assert registry('ap-northeast-2', scikit_learn_framework_name) == \ - "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com" - assert registry('ap-southeast-1', scikit_learn_framework_name) == \ - "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com" - assert registry('ap-southeast-2', scikit_learn_framework_name) == \ - "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com" - assert registry('ap-south-1', scikit_learn_framework_name) == "720646828776.dkr.ecr.ap-south-1.amazonaws.com" - assert registry('eu-west-1', scikit_learn_framework_name) == "141502667606.dkr.ecr.eu-west-1.amazonaws.com" - assert registry('eu-west-2', scikit_learn_framework_name) == "764974769150.dkr.ecr.eu-west-2.amazonaws.com" - assert registry('eu-central-1', scikit_learn_framework_name) == "492215442770.dkr.ecr.eu-central-1.amazonaws.com" - assert registry('ca-central-1', scikit_learn_framework_name) == "341280168497.dkr.ecr.ca-central-1.amazonaws.com" - assert registry('us-gov-west-1', scikit_learn_framework_name) == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com" - assert registry('us-iso-east-1', scikit_learn_framework_name) == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov" + assert ( + registry("us-west-1", scikit_learn_framework_name) + == "746614075791.dkr.ecr.us-west-1.amazonaws.com" + ) + 
assert ( + registry("us-west-2", scikit_learn_framework_name) + == "246618743249.dkr.ecr.us-west-2.amazonaws.com" + ) + assert ( + registry("us-east-1", scikit_learn_framework_name) + == "683313688378.dkr.ecr.us-east-1.amazonaws.com" + ) + assert ( + registry("us-east-2", scikit_learn_framework_name) + == "257758044811.dkr.ecr.us-east-2.amazonaws.com" + ) + assert ( + registry("ap-northeast-1", scikit_learn_framework_name) + == "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com" + ) + assert ( + registry("ap-northeast-2", scikit_learn_framework_name) + == "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com" + ) + assert ( + registry("ap-southeast-1", scikit_learn_framework_name) + == "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com" + ) + assert ( + registry("ap-southeast-2", scikit_learn_framework_name) + == "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com" + ) + assert ( + registry("ap-south-1", scikit_learn_framework_name) + == "720646828776.dkr.ecr.ap-south-1.amazonaws.com" + ) + assert ( + registry("eu-west-1", scikit_learn_framework_name) + == "141502667606.dkr.ecr.eu-west-1.amazonaws.com" + ) + assert ( + registry("eu-west-2", scikit_learn_framework_name) + == "764974769150.dkr.ecr.eu-west-2.amazonaws.com" + ) + assert ( + registry("eu-central-1", scikit_learn_framework_name) + == "492215442770.dkr.ecr.eu-central-1.amazonaws.com" + ) + assert ( + registry("ca-central-1", scikit_learn_framework_name) + == "341280168497.dkr.ecr.ca-central-1.amazonaws.com" + ) + assert ( + registry("us-gov-west-1", scikit_learn_framework_name) + == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com" + ) + assert ( + registry("us-iso-east-1", scikit_learn_framework_name) + == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov" + ) def test_default_sklearn_image_uri(): - image_tag = '0.20.0-cpu-py3' - sklearn_image_uri = default_framework_uri(scikit_learn_framework_name, 'us-west-1', image_tag) - assert sklearn_image_uri == "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3" + image_tag = "0.20.0-cpu-py3" + sklearn_image_uri = default_framework_uri(scikit_learn_framework_name, "us-west-1", image_tag) + assert ( + sklearn_image_uri + == "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3" + ) def test_framework_invalid(): with pytest.raises(KeyError): - registry('us-west-2', 'dummy-value') + registry("us-west-2", "dummy-value") def test_framework_none(): with pytest.raises(KeyError): - registry('us-west-2', None) + registry("us-west-2", None) def test_region_invalid(): with pytest.raises(KeyError): - registry('us-west-5', 'scikit-learn') + registry("us-west-5", "scikit-learn") def test_region_none(): with pytest.raises(KeyError): - registry(None, 'scikit-learn') + registry(None, "scikit-learn") diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 0a7e82058a..05aca78e38 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -23,16 +23,16 @@ from sagemaker import fw_utils from sagemaker.utils import name_from_image -DATA_DIR = 'data_dir' -BUCKET_NAME = 'mybucket' -ROLE = 'Sagemaker' -REGION = 'us-west-2' -SCRIPT_PATH = 'script.py' -TIMESTAMP = '2017-10-10-14-14-15' +DATA_DIR = "data_dir" +BUCKET_NAME = "mybucket" +ROLE = "Sagemaker" +REGION = "us-west-2" +SCRIPT_PATH = "script.py" +TIMESTAMP = "2017-10-10-14-14-15" -MOCK_FRAMEWORK = 'mlfw' -MOCK_REGION = 'mars-south-3' -MOCK_ACCELERATOR = 'eia1.medium' +MOCK_FRAMEWORK = "mlfw" +MOCK_REGION = "mars-south-3" +MOCK_ACCELERATOR = "eia1.medium" 
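+# NOTE: the MOCK_* values are deliberately fictitious ("mars-south-3" is not a
+# real region), so URI construction is exercised without touching a real
+# registry; MOCK_ACCELERATOR omits the required "ml." prefix on purpose and
+# feeds the validation-failure test below.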
@contextmanager @@ -45,107 +45,189 @@ def cd(path): @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session_mock = Mock(name='sagemaker_session', boto_session=boto_mock) - session_mock.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + session_mock = Mock(name="sagemaker_session", boto_session=boto_mock) + session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session_mock.expand_role = Mock(name="expand_role", return_value=ROLE) - session_mock.sagemaker_client.describe_training_job = \ - Mock(return_value={'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}) + session_mock.sagemaker_client.describe_training_job = Mock( + return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + ) return session_mock def test_create_image_uri_cpu(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23" + ) + assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2" - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'local', '1.0rc', 'py2', '23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "local", "1.0rc", "py2", "23" + ) + assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2" - image_uri = fw_utils.create_image_uri('us-gov-west-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23') - assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2' + image_uri = fw_utils.create_image_uri( + "us-gov-west-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23" + ) + assert ( + image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2" + ) - image_uri = fw_utils.create_image_uri('us-iso-east-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23') - assert image_uri == '744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2' + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py2", "23" + ) + assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2" def test_create_image_uri_no_python(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', account='23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", account="23" + ) + assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu" def test_create_image_uri_bad_python(): with pytest.raises(ValueError): - fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py0') + fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "ml.c4.large", "1.0rc", "py0") def test_create_image_uri_gpu(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.p3.2xlarge', '1.0rc', 'py3', '23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, 
MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3", "23" + ) + assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3" - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'local_gpu', '1.0rc', 'py3', '23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "local_gpu", "1.0rc", "py3", "23" + ) + assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3" def test_create_image_uri_accelerator_tfs(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, 'tensorflow-serving', 'ml.c4.large', '1.1.0', - accelerator_type='ml.eia1.large', account='23') - assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-serving-eia:1.1.0-cpu' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, + "tensorflow-serving", + "ml.c4.large", + "1.1.0", + accelerator_type="ml.eia1.large", + account="23", + ) + assert ( + image_uri + == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-serving-eia:1.1.0-cpu" + ) def test_create_image_uri_default_account(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.p3.2xlarge', '1.0rc', 'py3') - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3" + ) + assert ( + image_uri == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3" + ) def test_create_image_uri_gov_cloud(): - image_uri = fw_utils.create_image_uri('us-gov-west-1', MOCK_FRAMEWORK, 'ml.p3.2xlarge', '1.0rc', 'py3') - assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + "us-gov-west-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3" + ) + assert ( + image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3" + ) def test_create_image_uri_accelerator_tf(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, 'tensorflow', 'ml.p3.2xlarge', '1.0rc', 'py3', - accelerator_type='ml.eia1.medium') - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, + "tensorflow", + "ml.p3.2xlarge", + "1.0rc", + "py3", + accelerator_type="ml.eia1.medium", + ) + assert ( + image_uri + == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3" + ) def test_create_image_uri_accelerator_mxnet_serving(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, 'mxnet-serving', 'ml.p3.2xlarge', '1.0rc', 'py3', - accelerator_type='ml.eia1.medium') - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, + "mxnet-serving", + "ml.p3.2xlarge", + "1.0rc", + "py3", + accelerator_type="ml.eia1.medium", + ) + assert ( + image_uri + == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3" + ) def test_create_image_uri_local_sagemaker_notebook_accelerator(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, 'mxnet', 'ml.p3.2xlarge', '1.0rc', 'py3', - accelerator_type='local_sagemaker_notebook') - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3' + image_uri = 
fw_utils.create_image_uri(
+        MOCK_REGION,
+        "mxnet",
+        "ml.p3.2xlarge",
+        "1.0rc",
+        "py3",
+        accelerator_type="local_sagemaker_notebook",
+    )
+    assert (
+        image_uri
+        == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3"
+    )
 
 
 def test_invalid_accelerator():
-    error_message = '{} is not a valid SageMaker Elastic Inference accelerator type.'.format(MOCK_ACCELERATOR)
+    error_message = "{} is not a valid SageMaker Elastic Inference accelerator type.".format(
+        MOCK_ACCELERATOR
+    )
 
     # accelerator type is missing 'ml.' prefix
     with pytest.raises(ValueError) as error:
-        fw_utils.create_image_uri(MOCK_REGION, 'tensorflow', 'ml.p3.2xlarge', '1.0.0', 'py3',
-                                  accelerator_type=MOCK_ACCELERATOR)
+        fw_utils.create_image_uri(
+            MOCK_REGION,
+            "tensorflow",
+            "ml.p3.2xlarge",
+            "1.0.0",
+            "py3",
+            accelerator_type=MOCK_ACCELERATOR,
+        )
 
     assert error_message in str(error)
 
 
 def test_invalid_framework_accelerator():
-    error_message = '{} is not supported with Amazon Elastic Inference.'.format(MOCK_FRAMEWORK)
+    error_message = "{} is not supported with Amazon Elastic Inference.".format(MOCK_FRAMEWORK)
 
     # accelerator was chosen for unsupported framework
     with pytest.raises(ValueError) as error:
-        fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.p3.2xlarge', '1.0.0', 'py3',
-                                  accelerator_type='ml.eia1.medium')
+        fw_utils.create_image_uri(
+            MOCK_REGION,
+            MOCK_FRAMEWORK,
+            "ml.p3.2xlarge",
+            "1.0.0",
+            "py3",
+            accelerator_type="ml.eia1.medium",
+        )
 
     assert error_message in str(error)
 
 
 def test_invalid_framework_accelerator_with_neo():
-    error_message = 'Neo does not support Amazon Elastic Inference.'.format(MOCK_FRAMEWORK)
+    error_message = "Neo does not support Amazon Elastic Inference."
 
-    # accelerator was chosen for unsupported framework
+    # Neo-optimized images do not support Elastic Inference accelerators
     with pytest.raises(ValueError) as error:
-        fw_utils.create_image_uri(MOCK_REGION, 'tensorflow', 'ml.p3.2xlarge', '1.0.0', 'py3',
-                                  accelerator_type='ml.eia1.medium', optimized_families=['c5', 'p3'])
+        fw_utils.create_image_uri(
+            MOCK_REGION,
+            "tensorflow",
+            "ml.p3.2xlarge",
+            "1.0.0",
+            "py3",
+            accelerator_type="ml.eia1.medium",
+            optimized_families=["c5", "p3"],
+        )
 
     assert error_message in str(error)
 
 
@@ -153,81 +235,104 @@ def test_invalid_framework_accelerator_with_neo():
 def test_invalid_instance_type():
     # instance type is missing 'ml.'
prefix with pytest.raises(ValueError): - fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'p3.2xlarge', '1.0.0', 'py3') + fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, "p3.2xlarge", "1.0.0", "py3") def test_optimized_family(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.p3.2xlarge', '1.0.0', 'py3', - optimized_families=['c5', 'p3']) - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-p3-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, + MOCK_FRAMEWORK, + "ml.p3.2xlarge", + "1.0.0", + "py3", + optimized_families=["c5", "p3"], + ) + assert ( + image_uri == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-p3-py3" + ) def test_unoptimized_cpu_family(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.m4.xlarge', '1.0.0', 'py3', - optimized_families=['c5', 'p3']) - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-cpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "ml.m4.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"] + ) + assert ( + image_uri == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-cpu-py3" + ) def test_unoptimized_gpu_family(): - image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.p2.xlarge', '1.0.0', 'py3', - optimized_families=['c5', 'p3']) - assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-gpu-py3' + image_uri = fw_utils.create_image_uri( + MOCK_REGION, MOCK_FRAMEWORK, "ml.p2.xlarge", "1.0.0", "py3", optimized_families=["c5", "p3"] + ) + assert ( + image_uri == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-gpu-py3" + ) def test_tar_and_upload_dir_s3(sagemaker_session): - bucket = 'mybucket' - s3_key_prefix = 'something/source' - script = 'mnist.py' - directory = 's3://m' - result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory) + bucket = "mybucket" + s3_key_prefix = "something/source" + script = "mnist.py" + directory = "s3://m" + result = fw_utils.tar_and_upload_dir( + sagemaker_session, bucket, s3_key_prefix, script, directory + ) - assert result == fw_utils.UploadedCode('s3://m', 'mnist.py') + assert result == fw_utils.UploadedCode("s3://m", "mnist.py") -@patch('sagemaker.utils') +@patch("sagemaker.utils") def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session): - bucket = 'mybucket' - s3_key_prefix = 'something/source' - script = 'mnist.py' - kms_key = 'kms-key' - result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key) - - assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script) - - extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key} - obj = sagemaker_session.resource('s3').Object('', '') + bucket = "mybucket" + s3_key_prefix = "something/source" + script = "mnist.py" + kms_key = "kms-key" + result = fw_utils.tar_and_upload_dir( + sagemaker_session, bucket, s3_key_prefix, script, kms_key=kms_key + ) + + assert result == fw_utils.UploadedCode( + "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script + ) + + extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key} + obj = sagemaker_session.resource("s3").Object("", "") obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args) def 
test_validate_source_dir_does_not_exits(sagemaker_session): - script = 'mnist.py' - directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()' + script = "mnist.py" + directory = " !@#$%^&*()path probably in not there.!@#$%^&*()" with pytest.raises(ValueError): fw_utils.validate_source_dir(script, directory) def test_validate_source_dir_is_not_directory(sagemaker_session): - script = 'mnist.py' + script = "mnist.py" directory = inspect.getfile(inspect.currentframe()) with pytest.raises(ValueError): fw_utils.validate_source_dir(script, directory) def test_validate_source_dir_file_not_in_dir(): - script = ' !@#$%^&*() .myscript. !@#$%^&*() ' - directory = '.' + script = " !@#$%^&*() .myscript. !@#$%^&*() " + directory = "." with pytest.raises(ValueError): fw_utils.validate_source_dir(script, directory) def test_tar_and_upload_dir_not_s3(sagemaker_session): - bucket = 'mybucket' - s3_key_prefix = 'something/source' + bucket = "mybucket" + s3_key_prefix = "something/source" script = os.path.basename(__file__) directory = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) - result = fw_utils.tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory) - assert result == fw_utils.UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), - script) + result = fw_utils.tar_and_upload_dir( + sagemaker_session, bucket, s3_key_prefix, script, directory + ) + assert result == fw_utils.UploadedCode( + "s3://{}/{}/sourcedir.tar.gz".format(bucket, s3_key_prefix), script + ) def file_tree(tmpdir, files=None, folders=None): @@ -243,185 +348,233 @@ def file_tree(tmpdir, files=None, folders=None): def test_tar_and_upload_dir_no_directory(sagemaker_session, tmpdir): - source_dir = file_tree(tmpdir, ['train.py']) - entrypoint = os.path.join(source_dir, 'train.py') + source_dir = file_tree(tmpdir, ["train.py"]) + entrypoint = os.path.join(source_dir, "train.py") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", entrypoint, None + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir) def test_tar_and_upload_dir_no_directory_only_entrypoint(sagemaker_session, tmpdir): - source_dir = file_tree(tmpdir, ['train.py', 'not_me.py']) - entrypoint = os.path.join(source_dir, 'train.py') + source_dir = file_tree(tmpdir, ["train.py", "not_me.py"]) + entrypoint = os.path.join(source_dir, "train.py") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", entrypoint, None + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir) def 
test_tar_and_upload_dir_no_directory_bare_filename(sagemaker_session, tmpdir): - source_dir = file_tree(tmpdir, ['train.py']) - entrypoint = 'train.py' + source_dir = file_tree(tmpdir, ["train.py"]) + entrypoint = "train.py" - with patch('shutil.rmtree'): + with patch("shutil.rmtree"): with cd(source_dir): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', entrypoint, None) + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", entrypoint, None + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir) def test_tar_and_upload_dir_with_directory(sagemaker_session, tmpdir): - file_tree(tmpdir, ['src-dir/train.py']) - source_dir = os.path.join(str(tmpdir), 'src-dir') + file_tree(tmpdir, ["src-dir/train.py"]) + source_dir = os.path.join(str(tmpdir), "src-dir") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', source_dir) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", "train.py", source_dir + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/train.py"} == list_source_dir_files(sagemaker_session, tmpdir) def test_tar_and_upload_dir_with_subdirectory(sagemaker_session, tmpdir): - file_tree(tmpdir, ['src-dir/sub/train.py']) - source_dir = os.path.join(str(tmpdir), 'src-dir') + file_tree(tmpdir, ["src-dir/sub/train.py"]) + source_dir = os.path.join(str(tmpdir), "src-dir") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', source_dir) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", "train.py", source_dir + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/sub/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/sub/train.py"} == list_source_dir_files(sagemaker_session, tmpdir) def test_tar_and_upload_dir_with_directory_and_files(sagemaker_session, tmpdir): - file_tree(tmpdir, ['src-dir/train.py', 'src-dir/laucher', 'src-dir/module/__init__.py']) - source_dir = os.path.join(str(tmpdir), 'src-dir') + file_tree(tmpdir, ["src-dir/train.py", "src-dir/laucher", "src-dir/module/__init__.py"]) + source_dir = os.path.join(str(tmpdir), "src-dir") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'train.py', source_dir) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", "train.py", source_dir + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='train.py') + assert result == fw_utils.UploadedCode( + 
s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="train.py" + ) - assert {'/laucher', '/module/__init__.py', '/train.py'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/laucher", "/module/__init__.py", "/train.py"} == list_source_dir_files( + sagemaker_session, tmpdir + ) def test_tar_and_upload_dir_with_directories_and_files(sagemaker_session, tmpdir): - file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'src-dir/x/y', 'src-dir/x/y2', 'src-dir/z']) - source_dir = os.path.join(str(tmpdir), 'src-dir') + file_tree(tmpdir, ["src-dir/a/b", "src-dir/a/b2", "src-dir/x/y", "src-dir/x/y2", "src-dir/z"]) + source_dir = os.path.join(str(tmpdir), "src-dir") - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'a/b', source_dir) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", "a/b", source_dir + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='a/b') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="a/b" + ) - assert {'/a/b', '/a/b2', '/x/y', '/x/y2', '/z'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/a/b", "/a/b2", "/x/y", "/x/y2", "/z"} == list_source_dir_files( + sagemaker_session, tmpdir + ) def test_tar_and_upload_dir_with_many_folders(sagemaker_session, tmpdir): - file_tree(tmpdir, ['src-dir/a/b', 'src-dir/a/b2', 'common/x/y', 'common/x/y2', 't/y/z']) - source_dir = os.path.join(str(tmpdir), 'src-dir') - dependencies = [os.path.join(str(tmpdir), 'common'), os.path.join(str(tmpdir), 't', 'y', 'z')] + file_tree(tmpdir, ["src-dir/a/b", "src-dir/a/b2", "common/x/y", "common/x/y2", "t/y/z"]) + source_dir = os.path.join(str(tmpdir), "src-dir") + dependencies = [os.path.join(str(tmpdir), "common"), os.path.join(str(tmpdir), "t", "y", "z")] - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', - 'pipeline.py', source_dir, dependencies) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, "bucket", "prefix", "pipeline.py", source_dir, dependencies + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='pipeline.py') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="pipeline.py" + ) - assert {'/a/b', '/a/b2', '/common/x/y', '/common/x/y2', '/z'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/a/b", "/a/b2", "/common/x/y", "/common/x/y2", "/z"} == list_source_dir_files( + sagemaker_session, tmpdir + ) def test_test_tar_and_upload_dir_with_subfolders(sagemaker_session, tmpdir): - file_tree(tmpdir, ['a/b/c', 'a/b/c2']) - root = file_tree(tmpdir, ['x/y/z', 'x/y/z2']) + file_tree(tmpdir, ["a/b/c", "a/b/c2"]) + root = file_tree(tmpdir, ["x/y/z", "x/y/z2"]) - with patch('shutil.rmtree'): - result = fw_utils.tar_and_upload_dir(sagemaker_session, 'bucket', 'prefix', 'b/c', - os.path.join(root, 'a'), [os.path.join(root, 'x')]) + with patch("shutil.rmtree"): + result = fw_utils.tar_and_upload_dir( + sagemaker_session, + "bucket", + "prefix", + "b/c", + os.path.join(root, "a"), + [os.path.join(root, "x")], + ) - assert result == fw_utils.UploadedCode(s3_prefix='s3://bucket/prefix/sourcedir.tar.gz', - script_name='b/c') + assert result == fw_utils.UploadedCode( + s3_prefix="s3://bucket/prefix/sourcedir.tar.gz", script_name="b/c" 
+ ) - assert {'/b/c', '/b/c2', '/x/y/z', '/x/y/z2'} == list_source_dir_files(sagemaker_session, tmpdir) + assert {"/b/c", "/b/c2", "/x/y/z", "/x/y/z2"} == list_source_dir_files( + sagemaker_session, tmpdir + ) def list_source_dir_files(sagemaker_session, tmpdir): - source_dir_tar = sagemaker_session.resource('s3').Object().upload_file.call_args[0][0] + source_dir_tar = sagemaker_session.resource("s3").Object().upload_file.call_args[0][0] - source_dir_files = list_tar_files('/opt/ml/code/', source_dir_tar, tmpdir) + source_dir_files = list_tar_files("/opt/ml/code/", source_dir_tar, tmpdir) return source_dir_files def list_tar_files(folder, tar_ball, tmpdir): startpath = str(tmpdir.ensure(folder, dir=True)) - with tarfile.open(name=tar_ball, mode='r:gz') as t: + with tarfile.open(name=tar_ball, mode="r:gz") as t: t.extractall(path=startpath) def walk(): for root, dirs, files in os.walk(startpath): - path = root.replace(startpath, '') + path = root.replace(startpath, "") for f in files: - yield '%s/%s' % (path, f) + yield "%s/%s" % (path, f) result = set(walk()) return result if result else {} def test_framework_name_from_image_mxnet(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3' - assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3" + assert ("mxnet", "py3", "1.1-gpu-py3", None) == fw_utils.framework_name_from_image(image_name) def test_framework_name_from_image_mxnet_in_gov(): - image_name = '123.dkr.ecr.region-name.c2s.ic.gov/sagemaker-mxnet:1.1-gpu-py3' - assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.region-name.c2s.ic.gov/sagemaker-mxnet:1.1-gpu-py3" + assert ("mxnet", "py3", "1.1-gpu-py3", None) == fw_utils.framework_name_from_image(image_name) def test_framework_name_from_image_tf(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2' - assert ('tensorflow', 'py2', '1.6-cpu-py2', None) == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2" + assert ("tensorflow", "py2", "1.6-cpu-py2", None) == fw_utils.framework_name_from_image( + image_name + ) def test_framework_name_from_image_tf_scriptmode(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-cpu-py3' - assert ('tensorflow', 'py3', '1.12-cpu-py3', 'scriptmode') == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-cpu-py3" + assert ( + "tensorflow", + "py3", + "1.12-cpu-py3", + "scriptmode", + ) == fw_utils.framework_name_from_image(image_name) def test_framework_name_from_image_rl(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3' - assert ('mxnet', 'py3', 'toolkit1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3" + assert ("mxnet", "py3", "toolkit1.1-gpu-py3", None) == fw_utils.framework_name_from_image( + image_name + ) def test_legacy_name_from_framework_image(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2' + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2" framework, py_ver, tag, _ = 
fw_utils.framework_name_from_image(image_name) - assert framework == 'mxnet' - assert py_ver == 'py3' - assert tag == '2.5.6-gpu-py2' + assert framework == "mxnet" + assert py_ver == "py3" + assert tag == "2.5.6-gpu-py2" def test_legacy_name_from_wrong_framework(): framework, py_ver, tag, _ = fw_utils.framework_name_from_image( - '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1') + "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1" + ) assert framework is None assert py_ver is None assert tag is None @@ -429,7 +582,8 @@ def test_legacy_name_from_wrong_framework(): def test_legacy_name_from_wrong_python(): framework, py_ver, tag, _ = fw_utils.framework_name_from_image( - '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1') + "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1" + ) assert framework is None assert py_ver is None assert tag is None @@ -437,71 +591,72 @@ def test_legacy_name_from_wrong_python(): def test_legacy_name_from_wrong_device(): framework, py_ver, tag, _ = fw_utils.framework_name_from_image( - '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1') + "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1" + ) assert framework is None assert py_ver is None assert tag is None def test_legacy_name_from_image_any_tag(): - image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:any-tag' + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:any-tag" framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_name) - assert framework == 'tensorflow' - assert py_ver == 'py2' - assert tag == 'any-tag' + assert framework == "tensorflow" + assert py_ver == "py2" + assert tag == "any-tag" def test_framework_version_from_tag(): - version = fw_utils.framework_version_from_tag('1.5rc-keras-gpu-py2') - assert version == '1.5rc-keras' + version = fw_utils.framework_version_from_tag("1.5rc-keras-gpu-py2") + assert version == "1.5rc-keras" def test_framework_version_from_tag_other(): - version = fw_utils.framework_version_from_tag('weird-tag-py2') + version = fw_utils.framework_version_from_tag("weird-tag-py2") assert version is None def test_parse_s3_url(): - bucket, key_prefix = fw_utils.parse_s3_url('s3://bucket/code_location') - assert 'bucket' == bucket - assert 'code_location' == key_prefix + bucket, key_prefix = fw_utils.parse_s3_url("s3://bucket/code_location") + assert "bucket" == bucket + assert "code_location" == key_prefix def test_parse_s3_url_fail(): with pytest.raises(ValueError) as error: - fw_utils.parse_s3_url('t3://code_location') - assert 'Expecting \'s3\' scheme' in str(error) + fw_utils.parse_s3_url("t3://code_location") + assert "Expecting 's3' scheme" in str(error) def test_model_code_key_prefix_with_all_values_present(): - key_prefix = fw_utils.model_code_key_prefix('prefix', 'model_name', 'image_name') - assert key_prefix == 'prefix/model_name' + key_prefix = fw_utils.model_code_key_prefix("prefix", "model_name", "image_name") + assert key_prefix == "prefix/model_name" def test_model_code_key_prefix_with_no_prefix_and_all_other_values_present(): - key_prefix = fw_utils.model_code_key_prefix(None, 'model_name', 'image_name') - assert key_prefix == 'model_name' + key_prefix = fw_utils.model_code_key_prefix(None, "model_name", "image_name") + assert key_prefix == "model_name" -@patch('time.strftime', return_value=TIMESTAMP) +@patch("time.strftime", return_value=TIMESTAMP) def test_model_code_key_prefix_with_only_image_present(time): - 
    key_prefix = fw_utils.model_code_key_prefix(None, None, 'image_name')
-    assert key_prefix == name_from_image('image_name')
+    key_prefix = fw_utils.model_code_key_prefix(None, None, "image_name")
+    assert key_prefix == name_from_image("image_name")


-@patch('time.strftime', return_value=TIMESTAMP)
+@patch("time.strftime", return_value=TIMESTAMP)
 def test_model_code_key_prefix_and_image_present(time):
-    key_prefix = fw_utils.model_code_key_prefix('prefix', None, 'image_name')
-    assert key_prefix == 'prefix/' + name_from_image('image_name')
+    key_prefix = fw_utils.model_code_key_prefix("prefix", None, "image_name")
+    assert key_prefix == "prefix/" + name_from_image("image_name")


 def test_model_code_key_prefix_with_prefix_present_and_others_none_fail():
     with pytest.raises(TypeError) as error:
-        fw_utils.model_code_key_prefix('prefix', None, None)
-    assert 'expected string' in str(error)
+        fw_utils.model_code_key_prefix("prefix", None, None)
+    assert "expected string" in str(error)


 def test_model_code_key_prefix_with_all_none_fail():
     with pytest.raises(TypeError) as error:
         fw_utils.model_code_key_prefix(None, None, None)
-    assert 'expected string' in str(error)
+    assert "expected string" in str(error)
diff --git a/tests/unit/test_hyperparameter.py b/tests/unit/test_hyperparameter.py
index ecccee254a..3fee3e8025 100644
--- a/tests/unit/test_hyperparameter.py
+++ b/tests/unit/test_hyperparameter.py
@@ -18,9 +18,9 @@


 class Test(object):
-    blank = Hyperparameter(name='some-name', data_type=int)
-    elizabeth = Hyperparameter(name='elizabeth')
-    validated = Hyperparameter(name='validated', validate=lambda value: value > 55, data_type=int)
+    blank = Hyperparameter(name="some-name", data_type=int)
+    elizabeth = Hyperparameter(name="elizabeth")
+    validated = Hyperparameter(name="validated", validate=lambda value: value > 55, data_type=int)


 def test_blank_access():
@@ -40,7 +40,7 @@ def test_delete():
     x = Test()
     x.blank = 97
     assert x.blank == 97
-    del(x.blank)
+    del x.blank
     with pytest.raises(AttributeError):
         x.blank

@@ -49,7 +49,7 @@ def test_name():
     x = Test()
     with pytest.raises(AttributeError) as excinfo:
         x.elizabeth
-    assert 'elizabeth' in excinfo
+    assert "elizabeth" in excinfo


 def test_validated():
@@ -62,7 +62,7 @@ def test_validated():
 def test_data_type():
     x = Test()
     x.validated = 66
-    assert type(x.validated) == Test.__dict__['validated'].data_type
+    assert type(x.validated) == Test.__dict__["validated"].data_type


 def test_from_string():
diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py
index 710f85ceb2..d37be22a4b 100644
--- a/tests/unit/test_image.py
+++ b/tests/unit/test_image.py
@@ -31,56 +31,58 @@
 import sagemaker
 from sagemaker.local.image import _SageMakerContainer, _aws_credentials

-REGION = 'us-west-2'
-BUCKET_NAME = 'mybucket'
-EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
-TRAINING_JOB_NAME = 'my-job'
+REGION = "us-west-2"
+BUCKET_NAME = "mybucket"
+EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole"
+TRAINING_JOB_NAME = "my-job"

 INPUT_DATA_CONFIG = [
     {
-        'ChannelName': 'a',
-        'DataUri': 'file:///tmp/source1',
-        'DataSource': {
-            'FileDataSource': {
-                'FileDataDistributionType': 'FullyReplicated',
-                'FileUri': 'file:///tmp/source1'
+        "ChannelName": "a",
+        "DataUri": "file:///tmp/source1",
+        "DataSource": {
+            "FileDataSource": {
+                "FileDataDistributionType": "FullyReplicated",
+                "FileUri": "file:///tmp/source1",
             }
-        }
+        },
     },
     {
-        'ChannelName': 'b',
-        'DataUri': 's3://my-own-bucket/prefix',
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataDistributionType': 'FullyReplicated',
-                'S3DataType': 'S3Prefix',
-                'S3Uri': 's3://my-own-bucket/prefix'
+        "ChannelName": "b",
+        "DataUri": "s3://my-own-bucket/prefix",
+        "DataSource": {
+            "S3DataSource": {
+                "S3DataDistributionType": "FullyReplicated",
+                "S3DataType": "S3Prefix",
+                "S3Uri": "s3://my-own-bucket/prefix",
             }
-        }
-    }
+        },
+    },
 ]

-OUTPUT_DATA_CONFIG = {
-    'S3OutputPath': ''
-}
+OUTPUT_DATA_CONFIG = {"S3OutputPath": ""}

-HYPERPARAMETERS = {'a': 1,
-                   'b': json.dumps('bee'),
-                   'sagemaker_submit_directory': json.dumps('s3://my_bucket/code')}
+HYPERPARAMETERS = {
+    "a": 1,
+    "b": json.dumps("bee"),
+    "sagemaker_submit_directory": json.dumps("s3://my_bucket/code"),
+}

-LOCAL_CODE_HYPERPARAMETERS = {'a': 1,
-                              'b': 2,
-                              'sagemaker_submit_directory': json.dumps('file:///tmp/code')}
+LOCAL_CODE_HYPERPARAMETERS = {
+    "a": 1,
+    "b": 2,
+    "sagemaker_submit_directory": json.dumps("file:///tmp/code"),
+}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'}
-    boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = []
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}
+    boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.return_value = []

     sms = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     sms.expand_role = Mock(return_value=EXPANDED_ROLE)

     return sms
@@ -93,30 +95,30 @@ def assert_all_lowercase(hosts):
     for host in hosts:
         assert host.lower() == host

-    sagemaker_container = _SageMakerContainer('local', 2, 'my-image', sagemaker_session=Mock())
+    sagemaker_container = _SageMakerContainer("local", 2, "my-image", sagemaker_session=Mock())
     assert_all_lowercase(sagemaker_container.hosts)

-    sagemaker_container = _SageMakerContainer('local', 10, 'my-image', sagemaker_session=Mock())
+    sagemaker_container = _SageMakerContainer("local", 10, "my-image", sagemaker_session=Mock())
     assert_all_lowercase(sagemaker_container.hosts)

-    sagemaker_container = _SageMakerContainer('local', 1, 'my-image', sagemaker_session=Mock())
+    sagemaker_container = _SageMakerContainer("local", 1, "my-image", sagemaker_session=Mock())
     assert_all_lowercase(sagemaker_container.hosts)


-@patch('sagemaker.local.local_session.LocalSession')
+@patch("sagemaker.local.local_session.LocalSession")
 def test_write_config_file(LocalSession, tmpdir):
-    sagemaker_container = _SageMakerContainer('local', 2, 'my-image')
-    sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
+    sagemaker_container = _SageMakerContainer("local", 2, "my-image")
+    sagemaker_container.container_root = str(tmpdir.mkdir("container-root"))
     host = "algo-1"

     sagemaker.local.image._create_config_file_directories(sagemaker_container.container_root, host)

     container_root = sagemaker_container.container_root
-    config_file_root = os.path.join(container_root, host, 'input', 'config')
+    config_file_root = os.path.join(container_root, host, "input", "config")

-    hyperparameters_file = os.path.join(config_file_root, 'hyperparameters.json')
-    resource_config_file = os.path.join(config_file_root, 'resourceconfig.json')
-    input_data_config_file = os.path.join(config_file_root, 'inputdataconfig.json')
+    hyperparameters_file = os.path.join(config_file_root, "hyperparameters.json")
+    resource_config_file = os.path.join(config_file_root, "resourceconfig.json")
+    input_data_config_file = os.path.join(config_file_root, "inputdataconfig.json")

     # write the config files, and then lets check they exist and have the right content.
     sagemaker_container.write_config_files(host, HYPERPARAMETERS, INPUT_DATA_CONFIG)
@@ -135,51 +137,51 @@ def test_write_config_file(LocalSession, tmpdir):
         assert hyperparameters_data[k] == v

     # Validate Resource Config
-    assert resource_config_data['current_host'] == host
-    assert resource_config_data['hosts'] == sagemaker_container.hosts
+    assert resource_config_data["current_host"] == host
+    assert resource_config_data["hosts"] == sagemaker_container.hosts

     # Validate Input Data Config
     for channel in INPUT_DATA_CONFIG:
-        assert channel['ChannelName'] in input_data_config_data
+        assert channel["ChannelName"] in input_data_config_data


-@patch('sagemaker.local.local_session.LocalSession')
+@patch("sagemaker.local.local_session.LocalSession")
 def test_write_config_files_input_content_type(LocalSession, tmpdir):
-    sagemaker_container = _SageMakerContainer('local', 1, 'my-image')
-    sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
-    host = 'algo-1'
+    sagemaker_container = _SageMakerContainer("local", 1, "my-image")
+    sagemaker_container.container_root = str(tmpdir.mkdir("container-root"))
+    host = "algo-1"

     sagemaker.local.image._create_config_file_directories(sagemaker_container.container_root, host)

     container_root = sagemaker_container.container_root
-    config_file_root = os.path.join(container_root, host, 'input', 'config')
+    config_file_root = os.path.join(container_root, host, "input", "config")

-    input_data_config_file = os.path.join(config_file_root, 'inputdataconfig.json')
+    input_data_config_file = os.path.join(config_file_root, "inputdataconfig.json")

     # write the config files, and then lets check they exist and have the right content.
     input_data_config = [
         {
-            'ChannelName': 'channel_a',
-            'DataUri': 'file:///tmp/source1',
-            'ContentType': 'text/csv',
-            'DataSource': {
-                'FileDataSource': {
-                    'FileDataDistributionType': 'FullyReplicated',
-                    'FileUri': 'file:///tmp/source1'
+            "ChannelName": "channel_a",
+            "DataUri": "file:///tmp/source1",
+            "ContentType": "text/csv",
+            "DataSource": {
+                "FileDataSource": {
+                    "FileDataDistributionType": "FullyReplicated",
+                    "FileUri": "file:///tmp/source1",
                 }
-            }
+            },
         },
         {
-            'ChannelName': 'channel_b',
-            'DataUri': 's3://my-own-bucket/prefix',
-            'DataSource': {
-                'S3DataSource': {
-                    'S3DataDistributionType': 'FullyReplicated',
-                    'S3DataType': 'S3Prefix',
-                    'S3Uri': 's3://my-own-bucket/prefix'
+            "ChannelName": "channel_b",
+            "DataUri": "s3://my-own-bucket/prefix",
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://my-own-bucket/prefix",
                 }
-            }
-        }
+            },
+        },
     ]
     sagemaker_container.write_config_files(host, HYPERPARAMETERS, input_data_config)
@@ -187,87 +189,111 @@ def test_write_config_files_input_content_type(LocalSession, tmpdir):
     parsed_input_config = json.load(open(input_data_config_file))
     # Validate Input Data Config
     for channel in input_data_config:
-        assert channel['ChannelName'] in parsed_input_config
+        assert channel["ChannelName"] in parsed_input_config

     # Channel A has a content type
-    assert 'ContentType' in parsed_input_config['channel_a']
-    assert parsed_input_config['channel_a']['ContentType'] == 'text/csv'
+    assert "ContentType" in parsed_input_config["channel_a"]
+    assert parsed_input_config["channel_a"]["ContentType"] == "text/csv"

     # Channel B does not have content type
-    assert 'ContentType' not in parsed_input_config['channel_b']
+    assert "ContentType" not in parsed_input_config["channel_b"]


-@patch('sagemaker.local.local_session.LocalSession')
+@patch("sagemaker.local.local_session.LocalSession")
 def test_retrieve_artifacts(LocalSession, tmpdir):
-    sagemaker_container = _SageMakerContainer('local', 2, 'my-image')
-    sagemaker_container.hosts = ['algo-1', 'algo-2']  # avoid any randomness
-    sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
+    sagemaker_container = _SageMakerContainer("local", 2, "my-image")
+    sagemaker_container.hosts = ["algo-1", "algo-2"]  # avoid any randomness
+    sagemaker_container.container_root = str(tmpdir.mkdir("container-root"))

-    volume1 = os.path.join(sagemaker_container.container_root, 'algo-1')
-    volume2 = os.path.join(sagemaker_container.container_root, 'algo-2')
+    volume1 = os.path.join(sagemaker_container.container_root, "algo-1")
+    volume2 = os.path.join(sagemaker_container.container_root, "algo-2")
     os.mkdir(volume1)
     os.mkdir(volume2)

     compose_data = {
-        'services': {
-            'algo-1': {
-                'volumes': ['%s:/opt/ml/model' % os.path.join(volume1, 'model'),
-                            '%s:/opt/ml/output' % os.path.join(volume1, 'output')]
+        "services": {
+            "algo-1": {
+                "volumes": [
+                    "%s:/opt/ml/model" % os.path.join(volume1, "model"),
+                    "%s:/opt/ml/output" % os.path.join(volume1, "output"),
+                ]
+            },
+            "algo-2": {
+                "volumes": [
+                    "%s:/opt/ml/model" % os.path.join(volume2, "model"),
+                    "%s:/opt/ml/output" % os.path.join(volume2, "output"),
+                ]
             },
-            'algo-2': {
-                'volumes': ['%s:/opt/ml/model' % os.path.join(volume2, 'model'),
-                            '%s:/opt/ml/output' % os.path.join(volume2, 'output')]
-            }
         }
     }

     dirs = [
-        ('model', volume1), ('model/data', volume1),
-        ('model', volume2), ('model/data', volume2), ('model/tmp', volume2),
-        ('output', volume1), ('output/data', volume1),
-        ('output', volume2), ('output/data', volume2), ('output/log', volume2)
+        ("model", volume1),
+        ("model/data", volume1),
+        ("model", volume2),
+        ("model/data", volume2),
+        ("model/tmp", volume2),
+        ("output", volume1),
+        ("output/data", volume1),
+        ("output", volume2),
+        ("output/data", volume2),
+        ("output/log", volume2),
     ]

     files = [
-        ('model/data/model.json', volume1), ('model/data/variables.csv', volume1),
-        ('model/data/model.json', volume2), ('model/data/variables2.csv', volume2),
-        ('model/tmp/something-else.json', volume2),
-        ('output/data/loss.json', volume1), ('output/data/accuracy.json', volume1),
-        ('output/data/loss.json', volume2), ('output/data/accuracy2.json', volume2),
-        ('output/log/warnings.txt', volume2)
+        ("model/data/model.json", volume1),
+        ("model/data/variables.csv", volume1),
+        ("model/data/model.json", volume2),
+        ("model/data/variables2.csv", volume2),
+        ("model/tmp/something-else.json", volume2),
+        ("output/data/loss.json", volume1),
+        ("output/data/accuracy.json", volume1),
+        ("output/data/loss.json", volume2),
+        ("output/data/accuracy2.json", volume2),
+        ("output/log/warnings.txt", volume2),
     ]

-    expected_model = ['data', 'data/model.json', 'data/variables.csv',
-                      'data/variables2.csv', 'tmp/something-else.json']
-    expected_output = ['data', 'log', 'data/loss.json', 'data/accuracy.json', 'data/accuracy2.json',
-                       'log/warnings.txt']
+    expected_model = [
+        "data",
+        "data/model.json",
+        "data/variables.csv",
+        "data/variables2.csv",
+        "tmp/something-else.json",
+    ]
+    expected_output = [
+        "data",
+        "log",
+        "data/loss.json",
+        "data/accuracy.json",
+        "data/accuracy2.json",
+        "log/warnings.txt",
+    ]

     for d, volume in dirs:
         os.mkdir(os.path.join(volume, d))

     # create all the files
     for f, volume in files:
-        open(os.path.join(volume, f), 'a').close()
+        open(os.path.join(volume, f), "a").close()

-    output_path = str(tmpdir.mkdir('exported_files'))
-    output_data_config = {
-        'S3OutputPath': 'file://%s' % output_path
-    }
+    output_path = str(tmpdir.mkdir("exported_files"))
+    output_data_config = {"S3OutputPath": "file://%s" % output_path}

     model_artifacts = sagemaker_container.retrieve_artifacts(
-        compose_data, output_data_config, sagemaker_session).replace('file://', '')
+        compose_data, output_data_config, sagemaker_session
+    ).replace("file://", "")
     artifacts = os.path.dirname(model_artifacts)

     # we have both the tar files
-    assert set(os.listdir(artifacts)) == {'model.tar.gz', 'output.tar.gz'}
+    assert set(os.listdir(artifacts)) == {"model.tar.gz", "output.tar.gz"}

     # check that the tar files contain what we expect
-    tar = tarfile.open(os.path.join(output_path, 'model.tar.gz'))
+    tar = tarfile.open(os.path.join(output_path, "model.tar.gz"))
     model_tar_files = [m.name for m in tar.getmembers()]
     for f in expected_model:
         assert f in model_tar_files

-    tar = tarfile.open(os.path.join(output_path, 'output.tar.gz'))
+    tar = tarfile.open(os.path.join(output_path, "output.tar.gz"))
     output_tar_files = [m.name for m in tar.getmembers()]
     for f in expected_output:
         assert f in output_tar_files
@@ -276,364 +302,430 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
 def test_stream_output():
     # it should raise an exception if the command fails
     with pytest.raises(RuntimeError):
-        p = subprocess.Popen(['ls', '/some/unknown/path'],
-                             stdout=subprocess.PIPE,
-                             stderr=subprocess.PIPE)
+        p = subprocess.Popen(
+            ["ls", "/some/unknown/path"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
         sagemaker.local.image._stream_output(p)

-    p = subprocess.Popen(['echo', 'hello'],
-                         stdout=subprocess.PIPE,
-                         stderr=subprocess.PIPE)
+    p = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     exit_code = sagemaker.local.image._stream_output(p)
     assert exit_code == 0


 def test_check_output():
     with pytest.raises(Exception):
-        sagemaker.local.image._check_output(['ls', '/some/unknown/path'])
+        sagemaker.local.image._check_output(["ls", "/some/unknown/path"])

-    msg = 'hello!'
+    msg = "hello!"

-    output = sagemaker.local.image._check_output(['echo', msg]).strip()
+    output = sagemaker.local.image._check_output(["echo", msg]).strip()
     assert output == msg

     output = sagemaker.local.image._check_output("echo %s" % msg).strip()
     assert output == msg


-@patch('sagemaker.local.local_session.LocalSession', Mock())
-@patch('sagemaker.local.image._stream_output', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._cleanup')
-@patch('sagemaker.local.image._SageMakerContainer.retrieve_artifacts')
-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('subprocess.Popen')
-def test_train(popen, get_data_source_instance, retrieve_artifacts, cleanup, tmpdir, sagemaker_session):
+@patch("sagemaker.local.local_session.LocalSession", Mock())
+@patch("sagemaker.local.image._stream_output", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._cleanup")
+@patch("sagemaker.local.image._SageMakerContainer.retrieve_artifacts")
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("subprocess.Popen")
+def test_train(
+    popen, get_data_source_instance, retrieve_artifacts, cleanup, tmpdir, sagemaker_session
+):
     data_source = Mock()
-    data_source.get_root_dir.return_value = 'foo'
+    data_source.get_root_dir.return_value = "foo"
     get_data_source_instance.return_value = data_source

-    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               side_effect=directories):
+    directories = [str(tmpdir.mkdir("container-root")), str(tmpdir.mkdir("data"))]
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder", side_effect=directories
+    ):
         instance_count = 2
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", instance_count, image, sagemaker_session=sagemaker_session
+        )
         sagemaker_container.train(
-            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)
+            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
+        )

-        docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )

         call_args = popen.call_args[0][0]
         assert call_args is not None

-        expected = ['docker-compose', '-f', docker_compose_file, 'up', '--build', '--abort-on-container-exit']
+        expected = [
+            "docker-compose",
+            "-f",
+            docker_compose_file,
+            "up",
+            "--build",
+            "--abort-on-container-exit",
+        ]
         for i, v in enumerate(expected):
             assert call_args[i] == v

-        with open(docker_compose_file, 'r') as f:
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)
-            assert len(config['services']) == instance_count
+            assert len(config["services"]) == instance_count
             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'train'
-                assert 'AWS_REGION={}'.format(REGION) in config['services'][h]['environment']
-                assert 'TRAINING_JOB_NAME={}'.format(TRAINING_JOB_NAME) in config['services'][h]['environment']
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "train"
+                assert "AWS_REGION={}".format(REGION) in config["services"][h]["environment"]
+                assert (
+                    "TRAINING_JOB_NAME={}".format(TRAINING_JOB_NAME)
+                    in config["services"][h]["environment"]
+                )

         # assert that expected by sagemaker container output directories exist
-        assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output'))
-        assert os.path.exists(os.path.join(sagemaker_container.container_root, 'output/data'))
+        assert os.path.exists(os.path.join(sagemaker_container.container_root, "output"))
+        assert os.path.exists(os.path.join(sagemaker_container.container_root, "output/data"))

     retrieve_artifacts.assert_called_once()
     cleanup.assert_called_once()


-@patch('sagemaker.local.local_session.LocalSession', Mock())
-@patch('sagemaker.local.image._stream_output', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
-@patch('sagemaker.local.data.get_data_source_instance')
-def test_train_with_hyperparameters_without_job_name(get_data_source_instance, tmpdir, sagemaker_session):
+@patch("sagemaker.local.local_session.LocalSession", Mock())
+@patch("sagemaker.local.image._stream_output", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._cleanup", Mock())
+@patch("sagemaker.local.data.get_data_source_instance")
+def test_train_with_hyperparameters_without_job_name(
+    get_data_source_instance, tmpdir, sagemaker_session
+):
     data_source = Mock()
-    data_source.get_root_dir.return_value = 'foo'
+    data_source.get_root_dir.return_value = "foo"
     get_data_source_instance.return_value = data_source

-    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               side_effect=directories):
+    directories = [str(tmpdir.mkdir("container-root")), str(tmpdir.mkdir("data"))]
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder", side_effect=directories
+    ):
         instance_count = 2
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", instance_count, image, sagemaker_session=sagemaker_session
+        )
         sagemaker_container.train(
-            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)
+            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
+        )

-        docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )

-        with open(docker_compose_file, 'r') as f:
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)

             for h in sagemaker_container.hosts:
-                assert 'TRAINING_JOB_NAME={}'.format(TRAINING_JOB_NAME) in config['services'][h]['environment']
-
-
-@patch('sagemaker.local.local_session.LocalSession', Mock())
-@patch('sagemaker.local.image._stream_output', side_effect=RuntimeError('this is expected'))
-@patch('sagemaker.local.image._SageMakerContainer._cleanup')
-@patch('sagemaker.local.image._SageMakerContainer.retrieve_artifacts')
-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('subprocess.Popen', Mock())
-def test_train_error(get_data_source_instance, retrieve_artifacts, cleanup, _stream_output, tmpdir,
-                     sagemaker_session):
+                assert (
+                    "TRAINING_JOB_NAME={}".format(TRAINING_JOB_NAME)
+                    in config["services"][h]["environment"]
+                )
+
+
+@patch("sagemaker.local.local_session.LocalSession", Mock())
+@patch("sagemaker.local.image._stream_output", side_effect=RuntimeError("this is expected"))
+@patch("sagemaker.local.image._SageMakerContainer._cleanup")
+@patch("sagemaker.local.image._SageMakerContainer.retrieve_artifacts")
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("subprocess.Popen", Mock())
+def test_train_error(
+    get_data_source_instance, retrieve_artifacts, cleanup, _stream_output, tmpdir, sagemaker_session
+):
     data_source = Mock()
-    data_source.get_root_dir.return_value = 'foo'
+    data_source.get_root_dir.return_value = "foo"
     get_data_source_instance.return_value = data_source

-    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder', side_effect=directories):
+    directories = [str(tmpdir.mkdir("container-root")), str(tmpdir.mkdir("data"))]
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder", side_effect=directories
+    ):
         instance_count = 2
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', instance_count, image, sagemaker_session=sagemaker_session)
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", instance_count, image, sagemaker_session=sagemaker_session
+        )

         with pytest.raises(RuntimeError) as e:
             sagemaker_container.train(
-                INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME)
+                INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, HYPERPARAMETERS, TRAINING_JOB_NAME
+            )

-        assert 'this is expected' in str(e)
+        assert "this is expected" in str(e)

     retrieve_artifacts.assert_called_once()
     cleanup.assert_called_once()


-@patch('sagemaker.local.local_session.LocalSession', Mock())
-@patch('sagemaker.local.image._stream_output', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('subprocess.Popen', Mock())
+@patch("sagemaker.local.local_session.LocalSession", Mock())
+@patch("sagemaker.local.image._stream_output", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._cleanup", Mock())
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("subprocess.Popen", Mock())
 def test_train_local_code(get_data_source_instance, tmpdir, sagemaker_session):
     data_source = Mock()
-    data_source.get_root_dir.return_value = 'foo'
+    data_source.get_root_dir.return_value = "foo"
     get_data_source_instance.return_value = data_source

-    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               side_effect=directories):
+    directories = [str(tmpdir.mkdir("container-root")), str(tmpdir.mkdir("data"))]
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder", side_effect=directories
+    ):
         instance_count = 2
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', instance_count, image,
-                                                  sagemaker_session=sagemaker_session)
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", instance_count, image, sagemaker_session=sagemaker_session
+        )
         sagemaker_container.train(
-            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME)
+            INPUT_DATA_CONFIG, OUTPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS, TRAINING_JOB_NAME
+        )

-        docker_compose_file = os.path.join(sagemaker_container.container_root,
-                                           'docker-compose.yaml')
-        shared_folder_path = os.path.join(sagemaker_container.container_root, 'shared')
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )
+        shared_folder_path = os.path.join(sagemaker_container.container_root, "shared")

-        with open(docker_compose_file, 'r') as f:
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)
-            assert len(config['services']) == instance_count
+            assert len(config["services"]) == instance_count

             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'train'
-                volumes = config['services'][h]['volumes']
-                assert '%s:/opt/ml/code' % '/tmp/code' in volumes
-                assert '%s:/opt/ml/shared' % shared_folder_path in volumes
-
-                config_file_root = os.path.join(sagemaker_container.container_root, h, 'input', 'config')
-                hyperparameters_file = os.path.join(config_file_root, 'hyperparameters.json')
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "train"
+                volumes = config["services"][h]["volumes"]
+                assert "%s:/opt/ml/code" % "/tmp/code" in volumes
+                assert "%s:/opt/ml/shared" % shared_folder_path in volumes
+
+                config_file_root = os.path.join(
+                    sagemaker_container.container_root, h, "input", "config"
+                )
+                hyperparameters_file = os.path.join(config_file_root, "hyperparameters.json")
                 hyperparameters_data = json.load(open(hyperparameters_file))
-                assert hyperparameters_data['sagemaker_submit_directory'] == json.dumps('/opt/ml/code')
+                assert hyperparameters_data["sagemaker_submit_directory"] == json.dumps("/opt/ml/code")


-@patch('sagemaker.local.local_session.LocalSession', Mock())
-@patch('sagemaker.local.image._stream_output', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._cleanup', Mock())
-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('subprocess.Popen', Mock())
+@patch("sagemaker.local.local_session.LocalSession", Mock())
+@patch("sagemaker.local.image._stream_output", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._cleanup", Mock())
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("subprocess.Popen", Mock())
 def test_train_local_intermediate_output(get_data_source_instance, tmpdir, sagemaker_session):
     data_source = Mock()
-    data_source.get_root_dir.return_value = 'foo'
+    data_source.get_root_dir.return_value = "foo"
     get_data_source_instance.return_value = data_source

-    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               side_effect=directories):
+    directories = [str(tmpdir.mkdir("container-root")), str(tmpdir.mkdir("data"))]
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder", side_effect=directories
+    ):
         instance_count = 2
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', instance_count, image,
-                                                  sagemaker_session=sagemaker_session)
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", instance_count, image, sagemaker_session=sagemaker_session
+        )

-        output_path = str(tmpdir.mkdir('customer_intermediate_output'))
-        output_data_config = {'S3OutputPath': 'file://%s' % output_path}
-        hyperparameters = {'sagemaker_s3_output': output_path}
+        output_path = str(tmpdir.mkdir("customer_intermediate_output"))
+        output_data_config = {"S3OutputPath": "file://%s" % output_path}
+        hyperparameters = {"sagemaker_s3_output": output_path}

         sagemaker_container.train(
-            INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME)
+            INPUT_DATA_CONFIG, output_data_config, hyperparameters, TRAINING_JOB_NAME
+        )

-        docker_compose_file = os.path.join(sagemaker_container.container_root,
-                                           'docker-compose.yaml')
-        intermediate_folder_path = os.path.join(output_path, 'output/intermediate')
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )
+        intermediate_folder_path = os.path.join(output_path, "output/intermediate")

-        with open(docker_compose_file, 'r') as f:
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)
-            assert len(config['services']) == instance_count
+            assert len(config["services"]) == instance_count
             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'train'
-                volumes = config['services'][h]['volumes']
-                assert '%s:/opt/ml/output/intermediate' % intermediate_folder_path in volumes
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "train"
+                volumes = config["services"][h]["volumes"]
+                assert "%s:/opt/ml/output/intermediate" % intermediate_folder_path in volumes


 def test_container_has_gpu_support(tmpdir, sagemaker_session):
     instance_count = 1
-    image = 'my-image'
-    sagemaker_container = _SageMakerContainer('local_gpu', instance_count, image,
-                                              sagemaker_session=sagemaker_session)
+    image = "my-image"
+    sagemaker_container = _SageMakerContainer(
+        "local_gpu", instance_count, image, sagemaker_session=sagemaker_session
+    )

-    docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
-    assert 'runtime' in docker_host
-    assert docker_host['runtime'] == 'nvidia'
+    docker_host = sagemaker_container._create_docker_host("host-1", {}, set(), "train", [])
+    assert "runtime" in docker_host
+    assert docker_host["runtime"] == "nvidia"


 def test_container_does_not_enable_nvidia_docker_for_cpu_containers(sagemaker_session):
     instance_count = 1
-    image = 'my-image'
-    sagemaker_container = _SageMakerContainer('local', instance_count, image,
-                                              sagemaker_session=sagemaker_session)
+    image = "my-image"
+    sagemaker_container = _SageMakerContainer(
+        "local", instance_count, image, sagemaker_session=sagemaker_session
+    )

-    docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
-    assert 'runtime' not in docker_host
+    docker_host = sagemaker_container._create_docker_host("host-1", {}, set(), "train", [])
+    assert "runtime" not in docker_host


-@patch('sagemaker.local.image._HostingContainer.run', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._prepare_serving_volumes', Mock(return_value=[]))
-@patch('shutil.copy', Mock())
-@patch('shutil.copytree', Mock())
+@patch("sagemaker.local.image._HostingContainer.run", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._prepare_serving_volumes", Mock(return_value=[]))
+@patch("shutil.copy", Mock())
+@patch("shutil.copytree", Mock())
 def test_serve(tmpdir, sagemaker_session):
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               return_value=str(tmpdir.mkdir('container-root'))):
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', 1, image, sagemaker_session=sagemaker_session)
-        environment = {
-            'env1': 1,
-            'env2': 'b',
-            'SAGEMAKER_SUBMIT_DIRECTORY': 's3://some/path'
-        }
-
-        sagemaker_container.serve('/some/model/path', environment)
-        docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')
-
-        with open(docker_compose_file, 'r') as f:
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder",
+        return_value=str(tmpdir.mkdir("container-root")),
+    ):
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", 1, image, sagemaker_session=sagemaker_session
+        )
+        environment = {"env1": 1, "env2": "b", "SAGEMAKER_SUBMIT_DIRECTORY": "s3://some/path"}
+
+        sagemaker_container.serve("/some/model/path", environment)
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )
+
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)

             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'serve'
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "serve"


-@patch('sagemaker.local.image._HostingContainer.run', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._prepare_serving_volumes', Mock(return_value=[]))
-@patch('shutil.copy', Mock())
-@patch('shutil.copytree', Mock())
+@patch("sagemaker.local.image._HostingContainer.run", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._prepare_serving_volumes", Mock(return_value=[]))
+@patch("shutil.copy", Mock())
+@patch("shutil.copytree", Mock())
 def test_serve_local_code(tmpdir, sagemaker_session):
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               return_value=str(tmpdir.mkdir('container-root'))):
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', 1, image, sagemaker_session=sagemaker_session)
-        environment = {
-            'env1': 1,
-            'env2': 'b',
-            'SAGEMAKER_SUBMIT_DIRECTORY': 'file:///tmp/code'
-        }
-
-        sagemaker_container.serve('/some/model/path', environment)
-        docker_compose_file = os.path.join(sagemaker_container.container_root,
-                                           'docker-compose.yaml')
-
-        with open(docker_compose_file, 'r') as f:
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder",
+        return_value=str(tmpdir.mkdir("container-root")),
+    ):
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", 1, image, sagemaker_session=sagemaker_session
+        )
+        environment = {"env1": 1, "env2": "b", "SAGEMAKER_SUBMIT_DIRECTORY": "file:///tmp/code"}
+
+        sagemaker_container.serve("/some/model/path", environment)
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )
+
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)

             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'serve'
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "serve"

-                volumes = config['services'][h]['volumes']
-                assert '%s:/opt/ml/code' % '/tmp/code' in volumes
-                assert 'SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/code' in config['services'][h]['environment']
+                volumes = config["services"][h]["volumes"]
+                assert "%s:/opt/ml/code" % "/tmp/code" in volumes
+                assert (
+                    "SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/code"
+                    in config["services"][h]["environment"]
+                )


-@patch('sagemaker.local.image._HostingContainer.run', Mock())
-@patch('sagemaker.local.image._SageMakerContainer._prepare_serving_volumes', Mock(return_value=[]))
-@patch('shutil.copy', Mock())
-@patch('shutil.copytree', Mock())
+@patch("sagemaker.local.image._HostingContainer.run", Mock())
+@patch("sagemaker.local.image._SageMakerContainer._prepare_serving_volumes", Mock(return_value=[]))
+@patch("shutil.copy", Mock())
+@patch("shutil.copytree", Mock())
 def test_serve_local_code_no_env(tmpdir, sagemaker_session):
-    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
-               return_value=str(tmpdir.mkdir('container-root'))):
-        image = 'my-image'
-        sagemaker_container = _SageMakerContainer('local', 1, image, sagemaker_session=sagemaker_session)
-        sagemaker_container.serve('/some/model/path', {})
-        docker_compose_file = os.path.join(sagemaker_container.container_root,
-                                           'docker-compose.yaml')
-
-        with open(docker_compose_file, 'r') as f:
+    with patch(
+        "sagemaker.local.image._SageMakerContainer._create_tmp_folder",
+        return_value=str(tmpdir.mkdir("container-root")),
+    ):
+        image = "my-image"
+        sagemaker_container = _SageMakerContainer(
+            "local", 1, image, sagemaker_session=sagemaker_session
+        )
+        sagemaker_container.serve("/some/model/path", {})
+        docker_compose_file = os.path.join(
+            sagemaker_container.container_root, "docker-compose.yaml"
+        )
+
+        with open(docker_compose_file, "r") as f:
             config = yaml.load(f)

             for h in sagemaker_container.hosts:
-                assert config['services'][h]['image'] == image
-                assert config['services'][h]['command'] == 'serve'
-
-
-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('tarfile.is_tarfile')
-@patch('tarfile.open', MagicMock())
-@patch('os.makedirs', Mock())
-def test_prepare_serving_volumes_with_s3_model(is_tarfile, get_data_source_instance, sagemaker_session):
-    sagemaker_container = _SageMakerContainer('local', 1, 'some-image', sagemaker_session=sagemaker_session)
-    sagemaker_container.container_root = '/tmp/container_root'
+                assert config["services"][h]["image"] == image
+                assert config["services"][h]["command"] == "serve"
+
+
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("tarfile.is_tarfile")
+@patch("tarfile.open", MagicMock())
+@patch("os.makedirs", Mock())
+def test_prepare_serving_volumes_with_s3_model(
+    is_tarfile, get_data_source_instance, sagemaker_session
+):
+    sagemaker_container = _SageMakerContainer(
+        "local", 1, "some-image", sagemaker_session=sagemaker_session
+    )
+    sagemaker_container.container_root = "/tmp/container_root"

     s3_data_source = Mock()
-    s3_data_source.get_root_dir.return_value = '/tmp/downloaded/data/'
-    s3_data_source.get_file_list.return_value = ['/tmp/downloaded/data/my_model.tar.gz']
+    s3_data_source.get_root_dir.return_value = "/tmp/downloaded/data/"
+    s3_data_source.get_file_list.return_value = ["/tmp/downloaded/data/my_model.tar.gz"]
     get_data_source_instance.return_value = s3_data_source
     is_tarfile.return_value = True

-    volumes = sagemaker_container._prepare_serving_volumes('s3://bucket/my_model.tar.gz')
-    is_tarfile.assert_called_with('/tmp/downloaded/data/my_model.tar.gz')
+    volumes = sagemaker_container._prepare_serving_volumes("s3://bucket/my_model.tar.gz")
+    is_tarfile.assert_called_with("/tmp/downloaded/data/my_model.tar.gz")

     assert len(volumes) == 1
-    assert volumes[0].container_dir == '/opt/ml/model'
-    assert volumes[0].host_dir == '/tmp/downloaded/data/'
+    assert volumes[0].container_dir == "/opt/ml/model"
+    assert volumes[0].host_dir == "/tmp/downloaded/data/"


-@patch('sagemaker.local.data.get_data_source_instance')
-@patch('tarfile.is_tarfile', Mock(return_value=False))
-@patch('os.makedirs', Mock())
+@patch("sagemaker.local.data.get_data_source_instance")
+@patch("tarfile.is_tarfile", Mock(return_value=False))
+@patch("os.makedirs", Mock())
 def test_prepare_serving_volumes_with_local_model(get_data_source_instance, sagemaker_session):
-    sagemaker_container = _SageMakerContainer('local', 1, 'some-image', sagemaker_session=sagemaker_session)
-    sagemaker_container.container_root = '/tmp/container_root'
+    sagemaker_container = _SageMakerContainer(
+        "local", 1, "some-image", sagemaker_session=sagemaker_session
+    )
+    sagemaker_container.container_root = "/tmp/container_root"

     local_file_data_source = Mock()
-    local_file_data_source.get_root_dir.return_value = '/path/to/my_model'
-    local_file_data_source.get_file_list.return_value = ['/path/to/my_model/model']
+    local_file_data_source.get_root_dir.return_value = "/path/to/my_model"
+    local_file_data_source.get_file_list.return_value = ["/path/to/my_model/model"]
     get_data_source_instance.return_value = local_file_data_source

-    volumes = sagemaker_container._prepare_serving_volumes('file:///path/to/my_model')
+    volumes = sagemaker_container._prepare_serving_volumes("file:///path/to/my_model")

     assert len(volumes) == 1
-    assert volumes[0].container_dir == '/opt/ml/model'
-    assert volumes[0].host_dir == '/path/to/my_model'
+    assert volumes[0].container_dir == "/opt/ml/model"
+    assert volumes[0].host_dir == "/path/to/my_model"


 def test_ecr_login_non_ecr():
     session_mock = Mock()
-    result = sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')
+    result = sagemaker.local.image._ecr_login_if_needed(session_mock, "ubuntu")
     session_mock.assert_not_called()
     assert result is False


-@patch('sagemaker.local.image._check_output', return_value='123451324')
-@pytest.mark.parametrize('image', [
-    '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0',
-    '520713654638.dkr.ecr.us-iso-east-1.c2s.ic.gov/image-i-have:1.0'
-])
+@patch("sagemaker.local.image._check_output", return_value="123451324")
+@pytest.mark.parametrize(
+    "image",
+    [
+        "520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0",
+        "520713654638.dkr.ecr.us-iso-east-1.c2s.ic.gov/image-i-have:1.0",
+    ],
+)
 def test_ecr_login_image_exists(_check_output, image):
     session_mock = Mock()

@@ -644,47 +736,49 @@ def test_ecr_login_image_exists(_check_output, image):
     assert result is False


-@patch('subprocess.check_output', return_value=''.encode('utf-8'))
+@patch("subprocess.check_output", return_value="".encode("utf-8"))
 def test_ecr_login_needed(check_output):
     session_mock = Mock()

-    token = 'very-secure-token'
-    token_response = 'AWS:%s' % token
-    b64_token = base64.b64encode(token_response.encode('utf-8'))
+    token = "very-secure-token"
+    token_response = "AWS:%s" % token
+    b64_token = base64.b64encode(token_response.encode("utf-8"))
     response = {
-        u'authorizationData':
-            [
-                {
-                    u'authorizationToken': b64_token,
-                    u'proxyEndpoint': u'https://520713654638.dkr.ecr.us-east-1.amazonaws.com'
-                }
-            ],
-        'ResponseMetadata':
+        u"authorizationData": [
             {
-                'RetryAttempts': 0,
-                'HTTPStatusCode': 200,
-                'RequestId': '25b2ac63-36bf-11e8-ab6a-e5dc597d2ad9',
+                u"authorizationToken": b64_token,
+                u"proxyEndpoint": u"https://520713654638.dkr.ecr.us-east-1.amazonaws.com",
             }
+        ],
+        "ResponseMetadata": {
+            "RetryAttempts": 0,
+            "HTTPStatusCode": 200,
+            "RequestId": "25b2ac63-36bf-11e8-ab6a-e5dc597d2ad9",
+        },
     }
-    session_mock.client('ecr').get_authorization_token.return_value = response
-    image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
+    session_mock.client("ecr").get_authorization_token.return_value = response
+    image = "520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1"

     result = sagemaker.local.image._ecr_login_if_needed(session_mock, image)

-    expected_command = 'docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com' % token
+    expected_command = (
+        "docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com" % token
+    )

     check_output.assert_called_with(expected_command, shell=True)
-    session_mock.client('ecr').get_authorization_token.assert_called_with(registryIds=['520713654638'])
+    session_mock.client("ecr").get_authorization_token.assert_called_with(
+        registryIds=["520713654638"]
+    )

     assert result is True


-@patch('subprocess.check_output', return_value=''.encode('utf-8'))
+@patch("subprocess.check_output", return_value="".encode("utf-8"))
 def test_pull_image(check_output):
-    image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1'
+    image = "520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1"
     sagemaker.local.image._pull_image(image)

-    expected_command = 'docker pull %s' % image
+    expected_command = "docker pull %s" % image

     check_output.assert_called_once_with(expected_command, shell=True)

@@ -697,14 +791,18 @@ def test__aws_credentials_with_long_lived_credentials():
     aws_credentials = _aws_credentials(session)

     assert aws_credentials == [
-        'AWS_ACCESS_KEY_ID=%s' % credentials.access_key,
-        'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key
+        "AWS_ACCESS_KEY_ID=%s" % credentials.access_key,
+        "AWS_SECRET_ACCESS_KEY=%s" % credentials.secret_key,
     ]


-@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service')
-def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_credentials(mock):
-    credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string())
+@patch("sagemaker.local.image._aws_credentials_available_in_metadata_service")
+def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_credentials(
+    mock
+):
+    credentials = Credentials(
+        access_key=_random_string(), secret_key=_random_string(), token=_random_string()
+    )
     session = Mock()
     session.get_credentials.return_value = credentials
     mock.return_value = True
@@ -713,20 +811,24 @@ def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_
     assert aws_credentials is None


-@patch('sagemaker.local.image._aws_credentials_available_in_metadata_service')
-def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_no_credentials(mock):
-    credentials = Credentials(access_key=_random_string(), secret_key=_random_string(), token=_random_string())
+@patch("sagemaker.local.image._aws_credentials_available_in_metadata_service")
+def test__aws_credentials_with_short_lived_credentials_and_ec2_metadata_service_having_no_credentials(
+    mock
+):
+    credentials = Credentials(
+        access_key=_random_string(), secret_key=_random_string(), token=_random_string()
+    )
     session = Mock()
     session.get_credentials.return_value = credentials
     mock.return_value = False

     aws_credentials = _aws_credentials(session)

     assert aws_credentials == [
-        'AWS_ACCESS_KEY_ID=%s' % credentials.access_key,
-        'AWS_SECRET_ACCESS_KEY=%s' % credentials.secret_key,
-        'AWS_SESSION_TOKEN=%s' % credentials.token
+        "AWS_ACCESS_KEY_ID=%s" % credentials.access_key,
+        "AWS_SECRET_ACCESS_KEY=%s" % credentials.secret_key,
+        "AWS_SESSION_TOKEN=%s" % credentials.token,
     ]


 def _random_string(size=6, chars=string.ascii_uppercase):
-    return ''.join(random.choice(chars) for x in range(size))
+    return "".join(random.choice(chars) for x in range(size))
diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py
index 402383348b..2a99f45fab 100644
--- a/tests/unit/test_ipinsights.py
+++ b/tests/unit/test_ipinsights.py
@@ -19,45 +19,47 @@
 from sagemaker.amazon.amazon_estimator import registry, RecordSet

 # Mocked training config
-ROLE = 'myrole'
+ROLE = "myrole"
 TRAIN_INSTANCE_COUNT = 1
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"

 # Required algorithm hyperparameters
 NUM_ENTITY_VECTORS = 10000
 VECTOR_DIM = 128

-COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT,
-                     'train_instance_type': TRAIN_INSTANCE_TYPE}
-ALL_REQ_ARGS = dict({'num_entity_vectors': NUM_ENTITY_VECTORS, 'vector_dim': VECTOR_DIM}, **COMMON_TRAIN_ARGS)
+COMMON_TRAIN_ARGS = {
+    "role": ROLE,
+    "train_instance_count": TRAIN_INSTANCE_COUNT,
+    "train_instance_type": TRAIN_INSTANCE_TYPE,
+}
+ALL_REQ_ARGS = dict(
+    {"num_entity_vectors": NUM_ENTITY_VECTORS, "vector_dim": VECTOR_DIM}, **COMMON_TRAIN_ARGS
+)

 REGION = "us-west-2"
 BUCKET_NAME = "Some-Bucket"

-DESCRIBE_TRAINING_JOB_RESULT = {
-    'ModelArtifacts': {
-        'S3ModelArtifacts': "s3://bucket/model.tar.gz"
-    }
-}
+DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}}

-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               region_name=REGION, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
     sms.boto_region_name = REGION
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=DESCRIBE_TRAINING_JOB_RESULT)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    sms.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT
+    )
     sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)

@@ -66,9 +68,13 @@ def sagemaker_session():

 def test_init_required_positional(sagemaker_session):
     ipinsights = IPInsights(
-        ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE,
-        NUM_ENTITY_VECTORS, VECTOR_DIM,
-        sagemaker_session=sagemaker_session)
+        ROLE,
+        TRAIN_INSTANCE_COUNT,
+        TRAIN_INSTANCE_TYPE,
+        NUM_ENTITY_VECTORS,
+        VECTOR_DIM,
+        sagemaker_session=sagemaker_session,
+    )
     assert ipinsights.role == ROLE
     assert ipinsights.train_instance_count == TRAIN_INSTANCE_COUNT
     assert ipinsights.train_instance_type == TRAIN_INSTANCE_TYPE
@@ -79,9 +85,9 @@ def test_init_required_positional(sagemaker_session):

 def test_init_required_named(sagemaker_session):
     ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    assert ipinsights.role == COMMON_TRAIN_ARGS['role']
+    assert ipinsights.role == COMMON_TRAIN_ARGS["role"]
     assert ipinsights.train_instance_count == TRAIN_INSTANCE_COUNT
-    assert ipinsights.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type']
+    assert ipinsights.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"]
     assert ipinsights.num_entity_vectors == NUM_ENTITY_VECTORS
     assert ipinsights.vector_dim == VECTOR_DIM

@@ -89,31 +95,36 @@ def test_init_required_named(sagemaker_session):
 def test_all_hyperparameters(sagemaker_session):
     ipinsights = IPInsights(
         sagemaker_session=sagemaker_session,
-        batch_metrics_publish_interval=100, epochs=10, learning_rate=0.001, num_ip_encoder_layers=3,
-        random_negative_sampling_rate=5, shuffled_negative_sampling_rate=5, weight_decay=5.0,
-        **ALL_REQ_ARGS)
+        batch_metrics_publish_interval=100,
+        epochs=10,
+        learning_rate=0.001,
+        num_ip_encoder_layers=3,
+        random_negative_sampling_rate=5,
+        shuffled_negative_sampling_rate=5,
+        weight_decay=5.0,
+        **ALL_REQ_ARGS
+    )
     assert ipinsights.hyperparameters() == dict(
-        num_entity_vectors=str(ALL_REQ_ARGS['num_entity_vectors']),
-        vector_dim=str(ALL_REQ_ARGS['vector_dim']),
-        batch_metrics_publish_interval='100',
-        epochs='10',
-        learning_rate='0.001',
-        num_ip_encoder_layers='3',
-        random_negative_sampling_rate='5',
-        shuffled_negative_sampling_rate='5',
-        weight_decay='5.0'
+        num_entity_vectors=str(ALL_REQ_ARGS["num_entity_vectors"]),
+        vector_dim=str(ALL_REQ_ARGS["vector_dim"]),
+        batch_metrics_publish_interval="100",
+        epochs="10",
+        learning_rate="0.001",
+        num_ip_encoder_layers="3",
+        random_negative_sampling_rate="5",
+        shuffled_negative_sampling_rate="5",
+        weight_decay="5.0",
     )


 def test_image(sagemaker_session):
     ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    assert ipinsights.train_image() == registry(REGION, "ipinsights") + '/ipinsights:1'
+    assert ipinsights.train_image() == registry(REGION, "ipinsights") + "/ipinsights:1"


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('num_entity_vectors', 'string'),
-    ('vector_dim', 'string')
-])
+@pytest.mark.parametrize(
+    "required_hyper_parameters, value", [("num_entity_vectors", "string"), ("vector_dim", "string")]
+)
 def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -121,12 +132,15 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame
         IPInsights(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('num_entity_vectors', 0),
-    ('num_entity_vectors', 500000001),
-    ('vector_dim', 3),
-    ('vector_dim', 4097)
-])
+@pytest.mark.parametrize(
+    "required_hyper_parameters, value",
+    [
+        ("num_entity_vectors", 0),
+        ("num_entity_vectors", 500000001),
+        ("vector_dim", 3),
+        ("vector_dim", 4097),
+    ],
+)
 def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -134,15 +148,18 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
         IPInsights(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('batch_metrics_publish_interval', 'string'),
-    ('epochs', 'string'),
-    ('learning_rate', 'string'),
-    ('num_ip_encoder_layers', 'string'),
-    ('random_negative_sampling_rate', 'string'),
-    ('shuffled_negative_sampling_rate', 'string'),
-    ('weight_decay', 'string'),
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [
+        ("batch_metrics_publish_interval", "string"),
+        ("epochs", "string"),
+        ("learning_rate", "string"),
+        ("num_ip_encoder_layers", "string"),
+        ("random_negative_sampling_rate", "string"),
+        ("shuffled_negative_sampling_rate", "string"),
+        ("weight_decay", "string"),
+    ],
+)
 def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -150,20 +167,23 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame
         IPInsights(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('batch_metrics_publish_interval', 0),
-    ('epochs', 0),
-    ('learning_rate', 0),
-    ('learning_rate', 11),
-    ('num_ip_encoder_layers', -1),
-    ('num_ip_encoder_layers', 101),
-    ('random_negative_sampling_rate', -1),
-    ('random_negative_sampling_rate', 501),
-    ('shuffled_negative_sampling_rate', -1),
-    ('shuffled_negative_sampling_rate', 501),
-    ('weight_decay', -1),
-    ('weight_decay', 11),
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [
+        ("batch_metrics_publish_interval", 0),
+        ("epochs", 0),
+        ("learning_rate", 0),
+        ("learning_rate", 11),
+        ("num_ip_encoder_layers", -1),
+        ("num_ip_encoder_layers", 101),
+        ("random_negative_sampling_rate", -1),
+        ("random_negative_sampling_rate", 501),
+        ("shuffled_negative_sampling_rate", -1),
+        ("shuffled_negative_sampling_rate", 501),
+        ("weight_decay", -1),
+        ("weight_decay", 11),
+    ],
+)
 def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -178,9 +198,16 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param

 @patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit")
 def test_call_fit(base_fit, sagemaker_session):
-    ipinsights = IPInsights(base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    ipinsights = IPInsights(
+        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
+    )

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     ipinsights.fit(data, MINI_BATCH_SIZE)

@@ -191,53 +218,87 @@


 def test_call_fit_none_mini_batch_size(sagemaker_session):
-    ipinsights = IPInsights(base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    ipinsights = IPInsights(
+        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
+    )

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     ipinsights.fit(data)


 def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
-    ipinsights = IPInsights(base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    ipinsights = IPInsights(
+        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
+    )

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises((TypeError, ValueError)):
         ipinsights._prepare_for_training(data, "some")


 def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_session):
-    ipinsights = IPInsights(base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    ipinsights = IPInsights(
+        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
+    )

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises(ValueError):
         ipinsights._prepare_for_training(data, 0)


 def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session):
-    ipinsights = IPInsights(base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    ipinsights = IPInsights(
+        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
+    )

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises(ValueError):
         ipinsights._prepare_for_training(data, 500001)


 def test_model_image(sagemaker_session):
     ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     ipinsights.fit(data, MINI_BATCH_SIZE)

     model = ipinsights.create_model()

-    assert model.image == registry(REGION, "ipinsights") + '/ipinsights:1'
+    assert model.image == registry(REGION, "ipinsights") + "/ipinsights:1"


 def test_predictor_type(sagemaker_session):
     ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     ipinsights.fit(data, MINI_BATCH_SIZE)
     model = ipinsights.create_model()
     predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)
diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py
index a3669a2524..954c9e39f8 100644
--- a/tests/unit/test_job.py
+++ b/tests/unit/test_job.py
@@ -22,56 +22,64 @@
 from sagemaker.model import FrameworkModel
 from sagemaker.session import s3_input

-BUCKET_NAME = 's3://mybucket/train'
-S3_OUTPUT_PATH = 's3://bucket/prefix'
-LOCAL_FILE_NAME = 'file://local/file'
+BUCKET_NAME = "s3://mybucket/train"
+S3_OUTPUT_PATH = "s3://bucket/prefix"
+LOCAL_FILE_NAME = "file://local/file"
 INSTANCE_COUNT = 1
-INSTANCE_TYPE = 'c4.4xlarge'
+INSTANCE_TYPE = "c4.4xlarge"
 VOLUME_SIZE = 1
 MAX_RUNTIME = 1
-ROLE = 'DummyRole'
-REGION = 'us-west-2'
-IMAGE_NAME = 'fakeimage'
-SCRIPT_NAME = 'script.py'
-JOB_NAME = 'fakejob'
-VOLUME_KMS_KEY = 'volkmskey'
-MODEL_CHANNEL_NAME = 'testModelChannel'
-MODEL_URI = 's3://bucket/prefix/model.tar.gz'
-LOCAL_MODEL_NAME = 'file://local/file.tar.gz'
-CODE_CHANNEL_NAME = 'testCodeChannel'
-CODE_URI = 's3://bucket/prefix/code.py'
-DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
+ROLE = "DummyRole"
+REGION = "us-west-2"
+IMAGE_NAME = "fakeimage"
+SCRIPT_NAME = "script.py"
+JOB_NAME = "fakejob" +VOLUME_KMS_KEY = "volkmskey" +MODEL_CHANNEL_NAME = "testModelChannel" +MODEL_URI = "s3://bucket/prefix/model.tar.gz" +LOCAL_MODEL_NAME = "file://local/file.tar.gz" +CODE_CHANNEL_NAME = "testCodeChannel" +CODE_URI = "s3://bucket/prefix/code.py" +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME) MODEL_CONTAINER_DEF = { - 'Environment': { - 'SAGEMAKER_PROGRAM': SCRIPT_NAME, - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false' + "Environment": { + "SAGEMAKER_PROGRAM": SCRIPT_NAME, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", }, - 'Image': IMAGE_NAME, - 'ModelDataUrl': MODEL_URI, + "Image": IMAGE_NAME, + "ModelDataUrl": MODEL_URI, } @pytest.fixture() def estimator(sagemaker_session): - return Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, train_volume_size=VOLUME_SIZE, - train_max_run=MAX_RUNTIME, output_path=S3_OUTPUT_PATH, sagemaker_session=sagemaker_session) + return Estimator( + IMAGE_NAME, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + train_volume_size=VOLUME_SIZE, + train_max_run=MAX_RUNTIME, + output_path=S3_OUTPUT_PATH, + sagemaker_session=sagemaker_session, + ) @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session') - mock_session = Mock(name='sagemaker_session', boto_session=boto_mock) - mock_session.expand_role = Mock(name='expand_role', return_value=ROLE) + boto_mock = Mock(name="boto_session") + mock_session = Mock(name="sagemaker_session", boto_session=boto_mock) + mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) return mock_session class DummyFramework(Framework): - __framework_name__ = 'dummy' + __framework_name__ = "dummy" def train_image(self): return IMAGE_NAME @@ -82,15 +90,23 @@ def create_model(self, role=None, model_server_workers=None): @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description( - job_details, model_channel_name) + job_details, model_channel_name + ) init_params.pop("image", None) return init_params class DummyFrameworkModel(FrameworkModel): def __init__(self, sagemaker_session, **kwargs): - super(DummyFrameworkModel, self).__init__(MODEL_URI, IMAGE_NAME, INSTANCE_TYPE, ROLE, SCRIPT_NAME, - sagemaker_session=sagemaker_session, **kwargs) + super(DummyFrameworkModel, self).__init__( + MODEL_URI, + IMAGE_NAME, + INSTANCE_TYPE, + ROLE, + SCRIPT_NAME, + sagemaker_session=sagemaker_session, + **kwargs + ) def prepare_container_def(self, instance_type, accelerator_type=None): return MODEL_CONTAINER_DEF @@ -98,9 +114,14 @@ def prepare_container_def(self, instance_type, accelerator_type=None): @pytest.fixture() def framework(sagemaker_session): - return DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - output_path=S3_OUTPUT_PATH, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE) + return DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + output_path=S3_OUTPUT_PATH, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) def 
@@ -108,14 +129,14 @@ def test_load_config(estimator):

     config = _Job._load_config(inputs, estimator)

-    assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
-    assert config['role'] == ROLE
-    assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
-    assert 'KmsKeyId' not in config['output_config']
-    assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
-    assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
-    assert config['resource_config']['VolumeSizeInGB'] == VOLUME_SIZE
-    assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
+    assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
+    assert config["role"] == ROLE
+    assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH
+    assert "KmsKeyId" not in config["output_config"]
+    assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT
+    assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE
+    assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE
+    assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME


 def test_load_config_with_model_channel(estimator):
@@ -126,16 +147,16 @@ def test_load_config_with_model_channel(estimator):

     config = _Job._load_config(inputs, estimator)

-    assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME
-    assert config['input_config'][1]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
-    assert config['input_config'][1]['ChannelName'] == MODEL_CHANNEL_NAME
-    assert config['role'] == ROLE
-    assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
-    assert 'KmsKeyId' not in config['output_config']
-    assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
-    assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
-    assert config['resource_config']['VolumeSizeInGB'] == VOLUME_SIZE
-    assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
+    assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
+    assert config["input_config"][1]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI
+    assert config["input_config"][1]["ChannelName"] == MODEL_CHANNEL_NAME
+    assert config["role"] == ROLE
+    assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH
+    assert "KmsKeyId" not in config["output_config"]
+    assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT
+    assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE
+    assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE
+    assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME


 def test_load_config_with_model_channel_no_inputs(estimator):
@@ -144,15 +165,15 @@ def test_load_config_with_model_channel_no_inputs(estimator):

     config = _Job._load_config(inputs=None, estimator=estimator)

-    assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
-    assert config['input_config'][0]['ChannelName'] == MODEL_CHANNEL_NAME
-    assert config['role'] == ROLE
-    assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
-    assert 'KmsKeyId' not in config['output_config']
-    assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
-    assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
-    assert config['resource_config']['VolumeSizeInGB'] == VOLUME_SIZE
-    assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
+    assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI
config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI + assert config["input_config"][0]["ChannelName"] == MODEL_CHANNEL_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE + assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE + assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME def test_load_config_with_code_channel(framework): @@ -164,15 +185,15 @@ def test_load_config_with_code_channel(framework): framework._enable_network_isolation = True config = _Job._load_config(inputs, framework) - assert len(config['input_config']) == 3 - assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME - assert config['input_config'][2]['DataSource']['S3DataSource']['S3Uri'] == CODE_URI - assert config['input_config'][2]['ChannelName'] == framework.code_channel_name - assert config['role'] == ROLE - assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH - assert 'KmsKeyId' not in config['output_config'] - assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT - assert config['resource_config']['InstanceType'] == INSTANCE_TYPE + assert len(config["input_config"]) == 3 + assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME + assert config["input_config"][2]["DataSource"]["S3DataSource"]["S3Uri"] == CODE_URI + assert config["input_config"][2]["ChannelName"] == framework.code_channel_name + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE def test_load_config_with_code_channel_no_code_uri(framework): @@ -183,13 +204,13 @@ def test_load_config_with_code_channel_no_code_uri(framework): framework._enable_network_isolation = True config = _Job._load_config(inputs, framework) - assert len(config['input_config']) == 2 - assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == BUCKET_NAME - assert config['role'] == ROLE - assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH - assert 'KmsKeyId' not in config['output_config'] - assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT - assert config['resource_config']['InstanceType'] == INSTANCE_TYPE + assert len(config["input_config"]) == 2 + assert config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE def test_format_inputs_none(): @@ -203,7 +224,7 @@ def test_format_inputs_to_input_config_string(): channels = _Job._format_inputs_to_input_config(inputs) - assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs + assert channels[0]["DataSource"]["S3DataSource"]["S3Uri"] == inputs def test_format_inputs_to_input_config_s3_input(): @@ -211,16 +232,18 @@ def test_format_inputs_to_input_config_s3_input(): channels = _Job._format_inputs_to_input_config(inputs) - assert 
-        'S3DataSource']['S3Uri']
+    assert (
+        channels[0]["DataSource"]["S3DataSource"]["S3Uri"]
+        == inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
+    )


 def test_format_inputs_to_input_config_dict():
-    inputs = {'train': BUCKET_NAME}
+    inputs = {"train": BUCKET_NAME}

     channels = _Job._format_inputs_to_input_config(inputs)

-    assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs['train']
+    assert channels[0]["DataSource"]["S3DataSource"]["S3Uri"] == inputs["train"]


 def test_format_inputs_to_input_config_record_set():
@@ -228,8 +251,8 @@

     channels = _Job._format_inputs_to_input_config(inputs)

-    assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs.s3_data
-    assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == inputs.s3_data_type
+    assert channels[0]["DataSource"]["S3DataSource"]["S3Uri"] == inputs.s3_data
+    assert channels[0]["DataSource"]["S3DataSource"]["S3DataType"] == inputs.s3_data_type


 def test_format_inputs_to_input_config_list():
@@ -238,52 +261,60 @@

     channels = _Job._format_inputs_to_input_config(inputs)

-    assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == records.s3_data
-    assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == records.s3_data_type
+    assert channels[0]["DataSource"]["S3DataSource"]["S3Uri"] == records.s3_data
+    assert channels[0]["DataSource"]["S3DataSource"]["S3DataType"] == records.s3_data_type


-@pytest.mark.parametrize('channel_uri, channel_name, content_type, input_mode',
-                         [[MODEL_URI, MODEL_CHANNEL_NAME, 'application/x-sagemaker-model', 'File'],
-                          [CODE_URI, CODE_CHANNEL_NAME, None, None]])
+@pytest.mark.parametrize(
+    "channel_uri, channel_name, content_type, input_mode",
+    [
+        [MODEL_URI, MODEL_CHANNEL_NAME, "application/x-sagemaker-model", "File"],
+        [CODE_URI, CODE_CHANNEL_NAME, None, None],
+    ],
+)
 def test_prepare_channel(channel_uri, channel_name, content_type, input_mode):
-    channel = _Job._prepare_channel([], channel_uri, channel_name, content_type=content_type, input_mode=input_mode)
+    channel = _Job._prepare_channel(
+        [], channel_uri, channel_name, content_type=content_type, input_mode=input_mode
+    )

-    assert channel['DataSource']['S3DataSource']['S3Uri'] == channel_uri
-    assert channel['DataSource']['S3DataSource']['S3DataDistributionType'] == 'FullyReplicated'
-    assert channel['DataSource']['S3DataSource']['S3DataType'] == 'S3Prefix'
-    assert channel['ChannelName'] == channel_name
-    assert 'CompressionType' not in channel
-    assert 'RecordWrapperType' not in channel
+    assert channel["DataSource"]["S3DataSource"]["S3Uri"] == channel_uri
+    assert channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] == "FullyReplicated"
+    assert channel["DataSource"]["S3DataSource"]["S3DataType"] == "S3Prefix"
+    assert channel["ChannelName"] == channel_name
+    assert "CompressionType" not in channel
+    assert "RecordWrapperType" not in channel

     # The model channel should use all the defaults except InputMode and ContentType
     if channel_name == MODEL_CHANNEL_NAME:
-        assert channel['ContentType'] == 'application/x-sagemaker-model'
-        assert channel['InputMode'] == 'File'
+        assert channel["ContentType"] == "application/x-sagemaker-model"
+        assert channel["InputMode"] == "File"


 def test_prepare_channel_duplicate():
-    channels = [{
-        'ChannelName': MODEL_CHANNEL_NAME,
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataDistributionType': 'FullyReplicated',
-                'S3DataType': 'S3Prefix',
-                'S3Uri': 's3://blah/blah'
-            }
+    channels = [
+        {
+            "ChannelName": MODEL_CHANNEL_NAME,
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://blah/blah",
+                }
+            },
         }
-    }]
+    ]

     with pytest.raises(ValueError) as error:
         _Job._prepare_channel(channels, MODEL_URI, MODEL_CHANNEL_NAME)

-    assert 'Duplicate channel {} not allowed.'.format(MODEL_CHANNEL_NAME) in str(error)
+    assert "Duplicate channel {} not allowed.".format(MODEL_CHANNEL_NAME) in str(error)


 def test_prepare_channel_with_missing_name():
     with pytest.raises(ValueError) as ex:
         _Job._prepare_channel([], channel_uri=MODEL_URI, channel_name=None)

-    assert 'Expected a channel name if a channel URI {} is specified'.format(MODEL_URI) in str(ex)
+    assert "Expected a channel name if a channel URI {} is specified".format(MODEL_URI) in str(ex)


 def test_prepare_channel_with_missing_uri():
@@ -292,12 +323,12 @@

 def test_format_inputs_to_input_config_list_not_all_records():
     records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
-    inputs = [records, 'mock']
+    inputs = [records, "mock"]

     with pytest.raises(ValueError) as ex:
         _Job._format_inputs_to_input_config(inputs)

-    assert 'List compatible only with RecordSets.' in str(ex)
+    assert "List compatible only with RecordSets." in str(ex)


 def test_format_inputs_to_input_config_list_duplicate_channel():
@@ -307,94 +338,111 @@

     with pytest.raises(ValueError) as ex:
         _Job._format_inputs_to_input_config(inputs)

-    assert 'Duplicate channels not allowed.' in str(ex)
+    assert "Duplicate channels not allowed." in str(ex)

 def test_format_input_single_unamed_channel():
-    input_dict = _Job._format_inputs_to_input_config('s3://blah/blah')
-    assert input_dict == [{
-        'ChannelName': 'training',
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataDistributionType': 'FullyReplicated',
-                'S3DataType': 'S3Prefix',
-                'S3Uri': 's3://blah/blah'
-            }
+    input_dict = _Job._format_inputs_to_input_config("s3://blah/blah")
+    assert input_dict == [
+        {
+            "ChannelName": "training",
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://blah/blah",
+                }
+            },
         }
-    }]
+    ]


 def test_format_input_multiple_channels():
-    input_list = _Job._format_inputs_to_input_config({'a': 's3://blah/blah', 'b': 's3://foo/bar'})
-    expected = [{
-        'ChannelName': 'a',
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataDistributionType': 'FullyReplicated',
-                'S3DataType': 'S3Prefix',
-                'S3Uri': 's3://blah/blah'
-            }
-        }
-    },
+    input_list = _Job._format_inputs_to_input_config({"a": "s3://blah/blah", "b": "s3://foo/bar"})
+    expected = [
         {
-            'ChannelName': 'b',
-            'DataSource': {
-                'S3DataSource': {
-                    'S3DataDistributionType': 'FullyReplicated',
-                    'S3DataType': 'S3Prefix',
-                    'S3Uri': 's3://foo/bar'
+            "ChannelName": "a",
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://blah/blah",
                 }
-            }
-        }]
+            },
+        },
+        {
+            "ChannelName": "b",
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://foo/bar",
+                }
+            },
+        },
+    ]

     # convert back into map for comparison so list order (which is arbitrary) is ignored
-    assert {c['ChannelName']: c for c in input_list} == {c['ChannelName']: c for c in expected}
+    assert {c["ChannelName"]: c for c in input_list} == {c["ChannelName"]: c for c in expected}


 def test_format_input_s3_input():
-    input_dict = _Job._format_inputs_to_input_config(s3_input('s3://foo/bar', distribution='ShardedByS3Key',
-                                                              compression='gzip', content_type='whizz',
-                                                              record_wrapping='bang'))
-    assert input_dict == [{
-        'CompressionType': 'gzip',
-        'ChannelName': 'training',
-        'ContentType': 'whizz',
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataType': 'S3Prefix',
-                'S3DataDistributionType': 'ShardedByS3Key',
-                'S3Uri': 's3://foo/bar'}},
-        'RecordWrapperType': 'bang'}]
+    input_dict = _Job._format_inputs_to_input_config(
+        s3_input(
+            "s3://foo/bar",
+            distribution="ShardedByS3Key",
+            compression="gzip",
+            content_type="whizz",
+            record_wrapping="bang",
+        )
+    )
+    assert input_dict == [
+        {
+            "CompressionType": "gzip",
+            "ChannelName": "training",
+            "ContentType": "whizz",
+            "DataSource": {
+                "S3DataSource": {
+                    "S3DataType": "S3Prefix",
+                    "S3DataDistributionType": "ShardedByS3Key",
+                    "S3Uri": "s3://foo/bar",
+                }
+            },
+            "RecordWrapperType": "bang",
+        }
+    ]


 def test_dict_of_mixed_input_types():
-    input_list = _Job._format_inputs_to_input_config({
-        'a': 's3://foo/bar',
-        'b': s3_input('s3://whizz/bang')})
+    input_list = _Job._format_inputs_to_input_config(
+        {"a": "s3://foo/bar", "b": s3_input("s3://whizz/bang")}
+    )

     expected = [
-        {'ChannelName': 'a',
-         'DataSource': {
-             'S3DataSource': {
-                 'S3DataDistributionType': 'FullyReplicated',
-                 'S3DataType': 'S3Prefix',
-                 'S3Uri': 's3://foo/bar'
-             }
-         }
-         }, {
-            'ChannelName': 'b',
-            'DataSource': {
-                'S3DataSource': {
-                    'S3DataDistributionType': 'FullyReplicated',
-                    'S3DataType': 'S3Prefix',
-                    'S3Uri': 's3://whizz/bang'
+        {
+            "ChannelName": "a",
+            "DataSource": {
+                "S3DataSource": {
"S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "s3://foo/bar", + } + }, + }, + { + "ChannelName": "b", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": "s3://whizz/bang", } - } - }] + }, + }, + ] # convert back into map for comparison so list order (which is arbitrary) is ignored - assert {c['ChannelName']: c for c in input_list} == {c['ChannelName']: c for c in expected} + assert {c["ChannelName"]: c for c in input_list} == {c["ChannelName"]: c for c in expected} def test_format_inputs_to_input_config_exception(): @@ -406,7 +454,7 @@ def test_format_inputs_to_input_config_exception(): def test_unsupported_type_in_dict(): with pytest.raises(ValueError): - _Job._format_inputs_to_input_config({'a': 66}) + _Job._format_inputs_to_input_config({"a": 66}) def test_format_string_uri_input_string(): @@ -414,11 +462,11 @@ def test_format_string_uri_input_string(): s3_uri_input = _Job._format_string_uri_input(inputs) - assert s3_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == inputs + assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs def test_format_string_uri_input_string_exception(): - inputs = 'mybucket/train' + inputs = "mybucket/train" with pytest.raises(ValueError): _Job._format_string_uri_input(inputs) @@ -427,7 +475,7 @@ def test_format_string_uri_input_string_exception(): def test_format_string_uri_input_local_file(): file_uri_input = _Job._format_string_uri_input(LOCAL_FILE_NAME) - assert file_uri_input.config['DataSource']['FileDataSource']['FileUri'] == LOCAL_FILE_NAME + assert file_uri_input.config["DataSource"]["FileDataSource"]["FileUri"] == LOCAL_FILE_NAME def test_format_string_uri_input(): @@ -435,8 +483,10 @@ def test_format_string_uri_input(): s3_uri_input = _Job._format_string_uri_input(inputs) - assert s3_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == inputs.config[ - 'DataSource']['S3DataSource']['S3Uri'] + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] + == inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ) def test_format_string_uri_input_exception(): @@ -451,13 +501,13 @@ def test_format_model_uri_input_string(): model_uri_input = _Job._format_model_uri_input(model_uri) - assert model_uri_input.config['DataSource']['S3DataSource']['S3Uri'] == model_uri + assert model_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == model_uri def test_format_model_uri_input_local_file(): model_uri_input = _Job._format_model_uri_input(LOCAL_MODEL_NAME) - assert model_uri_input.config['DataSource']['FileDataSource']['FileUri'] == LOCAL_MODEL_NAME + assert model_uri_input.config["DataSource"]["FileDataSource"]["FileUri"] == LOCAL_MODEL_NAME def test_format_model_uri_input_exception(): @@ -468,12 +518,12 @@ def test_format_model_uri_input_exception(): def test_prepare_output_config(): - kms_key_id = 'kms_key' + kms_key_id = "kms_key" config = _Job._prepare_output_config(BUCKET_NAME, kms_key_id) - assert config['S3OutputPath'] == BUCKET_NAME - assert config['KmsKeyId'] == kms_key_id + assert config["S3OutputPath"] == BUCKET_NAME + assert config["KmsKeyId"] == kms_key_id def test_prepare_output_config_kms_key_none(): @@ -482,28 +532,32 @@ def test_prepare_output_config_kms_key_none(): config = _Job._prepare_output_config(s3_path, kms_key_id) - assert config['S3OutputPath'] == s3_path - assert 'KmsKeyId' not in config + assert config["S3OutputPath"] == s3_path + assert "KmsKeyId" not in config def 
-    resource_config = _Job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, None)
+    resource_config = _Job._prepare_resource_config(
+        INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, None
+    )

     assert resource_config == {
-        'InstanceCount': INSTANCE_COUNT,
-        'InstanceType': INSTANCE_TYPE,
-        'VolumeSizeInGB': VOLUME_SIZE
+        "InstanceCount": INSTANCE_COUNT,
+        "InstanceType": INSTANCE_TYPE,
+        "VolumeSizeInGB": VOLUME_SIZE,
     }


 def test_prepare_resource_config_with_volume_kms():
-    resource_config = _Job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, VOLUME_KMS_KEY)
+    resource_config = _Job._prepare_resource_config(
+        INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, VOLUME_KMS_KEY
+    )

     assert resource_config == {
-        'InstanceCount': INSTANCE_COUNT,
-        'InstanceType': INSTANCE_TYPE,
-        'VolumeSizeInGB': VOLUME_SIZE,
-        'VolumeKmsKeyId': VOLUME_KMS_KEY
+        "InstanceCount": INSTANCE_COUNT,
+        "InstanceType": INSTANCE_TYPE,
+        "VolumeSizeInGB": VOLUME_SIZE,
+        "VolumeKmsKeyId": VOLUME_KMS_KEY,
     }
@@ -512,7 +566,7 @@ def test_prepare_stop_condition():

     stop_condition = _Job._prepare_stop_condition(max_run)

-    assert stop_condition['MaxRuntimeInSeconds'] == max_run
+    assert stop_condition["MaxRuntimeInSeconds"] == max_run


 def test_name(sagemaker_session):
diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py
index dc87a38ec1..156920e787 100644
--- a/tests/unit/test_kmeans.py
+++ b/tests/unit/test_kmeans.py
@@ -18,43 +18,43 @@
 from sagemaker.amazon.kmeans import KMeans, KMeansPredictor
 from sagemaker.amazon.amazon_estimator import registry, RecordSet

-ROLE = 'myrole'
+ROLE = "myrole"
 TRAIN_INSTANCE_COUNT = 1
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"
 K = 2

-COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT,
-                     'train_instance_type': TRAIN_INSTANCE_TYPE}
-ALL_REQ_ARGS = dict({'k': K}, **COMMON_TRAIN_ARGS)
+COMMON_TRAIN_ARGS = {
+    "role": ROLE,
+    "train_instance_count": TRAIN_INSTANCE_COUNT,
+    "train_instance_type": TRAIN_INSTANCE_TYPE,
+}
+ALL_REQ_ARGS = dict({"k": K}, **COMMON_TRAIN_ARGS)

-REGION = 'us-west-2'
-BUCKET_NAME = 'Some-Bucket'
+REGION = "us-west-2"
+BUCKET_NAME = "Some-Bucket"

-DESCRIBE_TRAINING_JOB_RESULT = {
-    'ModelArtifacts': {
-        'S3ModelArtifacts': 's3://bucket/model.tar.gz'
-    }
-}
+DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}}

-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               region_name=REGION, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
     sms.boto_region_name = REGION
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=DESCRIBE_TRAINING_JOB_RESULT)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    sms.sagemaker_client.describe_training_job = Mock(
name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -62,7 +62,9 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - kmeans = KMeans(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, K, sagemaker_session=sagemaker_session) + kmeans = KMeans( + ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, K, sagemaker_session=sagemaker_session + ) assert kmeans.role == ROLE assert kmeans.train_instance_count == TRAIN_INSTANCE_COUNT assert kmeans.train_instance_type == TRAIN_INSTANCE_TYPE @@ -72,39 +74,47 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert kmeans.role == COMMON_TRAIN_ARGS['role'] + assert kmeans.role == COMMON_TRAIN_ARGS["role"] assert kmeans.train_instance_count == TRAIN_INSTANCE_COUNT - assert kmeans.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert kmeans.k == ALL_REQ_ARGS['k'] + assert kmeans.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert kmeans.k == ALL_REQ_ARGS["k"] def test_all_hyperparameters(sagemaker_session): - kmeans = KMeans(sagemaker_session=sagemaker_session, init_method='random', max_iterations=3, tol=0.5, - num_trials=5, local_init_method='kmeans++', half_life_time_size=0, epochs=10, center_factor=2, - eval_metrics=['msd', 'ssd'], **ALL_REQ_ARGS) + kmeans = KMeans( + sagemaker_session=sagemaker_session, + init_method="random", + max_iterations=3, + tol=0.5, + num_trials=5, + local_init_method="kmeans++", + half_life_time_size=0, + epochs=10, + center_factor=2, + eval_metrics=["msd", "ssd"], + **ALL_REQ_ARGS + ) assert kmeans.hyperparameters() == dict( - k=str(ALL_REQ_ARGS['k']), - init_method='random', - local_lloyd_max_iter='3', - local_lloyd_tol='0.5', - local_lloyd_num_trials='5', - local_lloyd_init_method='kmeans++', - half_life_time_size='0', - epochs='10', - extra_center_factor='2', - eval_metrics='[\'msd\', \'ssd\']', - force_dense='True', + k=str(ALL_REQ_ARGS["k"]), + init_method="random", + local_lloyd_max_iter="3", + local_lloyd_tol="0.5", + local_lloyd_num_trials="5", + local_lloyd_init_method="kmeans++", + half_life_time_size="0", + epochs="10", + extra_center_factor="2", + eval_metrics="['msd', 'ssd']", + force_dense="True", ) def test_image(sagemaker_session): kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert kmeans.train_image() == registry(REGION, 'kmeans') + '/kmeans:1' + assert kmeans.train_image() == registry(REGION, "kmeans") + "/kmeans:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('k', 'string') -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("k", "string")]) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -112,9 +122,7 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame KMeans(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('k', 0) -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("k", 0)]) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ 
@@ -122,9 +130,7 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
     KMeans(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('iterable_hyper_parameters, value', [
-    ('eval_metrics', 0)
-])
+@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])
 def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
     with pytest.raises(TypeError):
         test_params = ALL_REQ_ARGS.copy()
@@ -132,16 +138,19 @@ def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parame
     KMeans(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('init_method', 0),
-    ('max_iterations', 'string'),
-    ('tol', 'string'),
-    ('num_trials', 'string'),
-    ('local_init_method', 0),
-    ('half_life_time_size', 'string'),
-    ('epochs', 'string'),
-    ('center_factor', 'string')
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [
+        ("init_method", 0),
+        ("max_iterations", "string"),
+        ("tol", "string"),
+        ("num_trials", "string"),
+        ("local_init_method", 0),
+        ("half_life_time_size", "string"),
+        ("epochs", "string"),
+        ("center_factor", "string"),
+    ],
+)
 def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -149,17 +158,20 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame
     KMeans(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('init_method', 'string'),
-    ('max_iterations', 0),
-    ('tol', -0.1),
-    ('tol', 1.1),
-    ('num_trials', 0),
-    ('local_init_method', 'string'),
-    ('half_life_time_size', -1),
-    ('epochs', 0),
-    ('center_factor', 0)
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [
+        ("init_method", "string"),
+        ("max_iterations", 0),
+        ("tol", -0.1),
+        ("tol", 1.1),
+        ("num_trials", 0),
+        ("local_init_method", "string"),
+        ("half_life_time_size", -1),
+        ("epochs", 0),
+        ("center_factor", 0),
+    ],
+)
 def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -167,16 +179,21 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param
     KMeans(sagemaker_session=sagemaker_session, **test_params)


-PREFIX = 'prefix'
+PREFIX = "prefix"
 FEATURE_DIM = 10
 MINI_BATCH_SIZE = 200


-@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')
+@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit")
 def test_call_fit(base_fit, sagemaker_session):
-    kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    kmeans = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     kmeans.fit(data, MINI_BATCH_SIZE)
@@ -187,46 +204,68 @@ def test_call_fit(base_fit, sagemaker_session):


 def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
-    kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    kmeans = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     kmeans._prepare_for_training(data)

     assert kmeans.mini_batch_size == 5000


 def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
-    kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    kmeans = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises((TypeError, ValueError)):
-        kmeans._prepare_for_training(data, 'some')
+        kmeans._prepare_for_training(data, "some")


 def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
-    kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    kmeans = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     with pytest.raises(ValueError):
         kmeans._prepare_for_training(data, 0)


 def test_model_image(sagemaker_session):
     kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     kmeans.fit(data, MINI_BATCH_SIZE)

     model = kmeans.create_model()
-    assert model.image == registry(REGION, 'kmeans') + '/kmeans:1'
+    assert model.image == registry(REGION, "kmeans") + "/kmeans:1"


 def test_predictor_type(sagemaker_session):
     kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     kmeans.fit(data, MINI_BATCH_SIZE)
     model = kmeans.create_model()
     predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)
diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py
index fad5ae64b1..547e47735f 100644
--- a/tests/unit/test_knn.py
+++ b/tests/unit/test_knn.py
@@ -18,47 +18,49 @@
 from sagemaker.amazon.knn import KNN, KNNPredictor
 from sagemaker.amazon.amazon_estimator import registry, RecordSet

-ROLE = 'myrole'
+ROLE = "myrole"
 TRAIN_INSTANCE_COUNT = 1
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"
 K = 5
 SAMPLE_SIZE = 1000
-PREDICTOR_TYPE_REGRESSOR = 'regressor'
-PREDICTOR_TYPE_CLASSIFIER = 'classifier'
+PREDICTOR_TYPE_REGRESSOR = "regressor"
+PREDICTOR_TYPE_CLASSIFIER = "classifier"

-COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT,
-                     'train_instance_type': TRAIN_INSTANCE_TYPE}
-ALL_REQ_ARGS = dict({'k': K, 'sample_size': SAMPLE_SIZE,
-                     'predictor_type': PREDICTOR_TYPE_REGRESSOR}, **COMMON_TRAIN_ARGS)
+COMMON_TRAIN_ARGS = {
+    "role": ROLE,
+    "train_instance_count": TRAIN_INSTANCE_COUNT,
"train_instance_type": TRAIN_INSTANCE_TYPE, +} +ALL_REQ_ARGS = dict( + {"k": K, "sample_size": SAMPLE_SIZE, "predictor_type": PREDICTOR_TYPE_REGRESSOR}, + **COMMON_TRAIN_ARGS +) REGION = "us-west-2" BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': "s3://bucket/model.tar.gz" - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -66,9 +68,15 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - knn = KNN(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, - K, SAMPLE_SIZE, PREDICTOR_TYPE_REGRESSOR, - sagemaker_session=sagemaker_session) + knn = KNN( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + K, + SAMPLE_SIZE, + PREDICTOR_TYPE_REGRESSOR, + sagemaker_session=sagemaker_session, + ) assert knn.role == ROLE assert knn.train_instance_count == TRAIN_INSTANCE_COUNT assert knn.train_instance_type == TRAIN_INSTANCE_TYPE @@ -78,58 +86,70 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert knn.role == COMMON_TRAIN_ARGS['role'] + assert knn.role == COMMON_TRAIN_ARGS["role"] assert knn.train_instance_count == TRAIN_INSTANCE_COUNT - assert knn.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert knn.k == ALL_REQ_ARGS['k'] + assert knn.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert knn.k == ALL_REQ_ARGS["k"] def test_all_hyperparameters_regressor(sagemaker_session): - knn = KNN(sagemaker_session=sagemaker_session, - dimension_reduction_type='sign', dimension_reduction_target='2', index_type='faiss.Flat', - index_metric='COSINE', faiss_index_ivf_nlists='auto', faiss_index_pq_m=1, **ALL_REQ_ARGS) + knn = KNN( + sagemaker_session=sagemaker_session, + dimension_reduction_type="sign", + dimension_reduction_target="2", + index_type="faiss.Flat", + index_metric="COSINE", + faiss_index_ivf_nlists="auto", + faiss_index_pq_m=1, + **ALL_REQ_ARGS + ) assert knn.hyperparameters() == dict( - k=str(ALL_REQ_ARGS['k']), - sample_size=str(ALL_REQ_ARGS['sample_size']), - 
-        predictor_type=str(ALL_REQ_ARGS['predictor_type']),
-        dimension_reduction_type='sign',
-        dimension_reduction_target='2',
-        index_type='faiss.Flat',
-        index_metric='COSINE',
-        faiss_index_ivf_nlists='auto',
-        faiss_index_pq_m='1'
+        k=str(ALL_REQ_ARGS["k"]),
+        sample_size=str(ALL_REQ_ARGS["sample_size"]),
+        predictor_type=str(ALL_REQ_ARGS["predictor_type"]),
+        dimension_reduction_type="sign",
+        dimension_reduction_target="2",
+        index_type="faiss.Flat",
+        index_metric="COSINE",
+        faiss_index_ivf_nlists="auto",
+        faiss_index_pq_m="1",
     )


 def test_all_hyperparameters_classifier(sagemaker_session):
     test_params = ALL_REQ_ARGS.copy()
-    test_params['predictor_type'] = PREDICTOR_TYPE_CLASSIFIER
-
-    knn = KNN(sagemaker_session=sagemaker_session,
-              dimension_reduction_type='fjlt', dimension_reduction_target='2', index_type='faiss.IVFFlat',
-              index_metric='L2', faiss_index_ivf_nlists='20', **test_params)
+    test_params["predictor_type"] = PREDICTOR_TYPE_CLASSIFIER
+
+    knn = KNN(
+        sagemaker_session=sagemaker_session,
+        dimension_reduction_type="fjlt",
+        dimension_reduction_target="2",
+        index_type="faiss.IVFFlat",
+        index_metric="L2",
+        faiss_index_ivf_nlists="20",
+        **test_params
+    )
     assert knn.hyperparameters() == dict(
-        k=str(ALL_REQ_ARGS['k']),
-        sample_size=str(ALL_REQ_ARGS['sample_size']),
+        k=str(ALL_REQ_ARGS["k"]),
+        sample_size=str(ALL_REQ_ARGS["sample_size"]),
         predictor_type=str(PREDICTOR_TYPE_CLASSIFIER),
-        dimension_reduction_type='fjlt',
-        dimension_reduction_target='2',
-        index_type='faiss.IVFFlat',
-        index_metric='L2',
-        faiss_index_ivf_nlists='20'
+        dimension_reduction_type="fjlt",
+        dimension_reduction_target="2",
+        index_type="faiss.IVFFlat",
+        index_metric="L2",
+        faiss_index_ivf_nlists="20",
     )


 def test_image(sagemaker_session):
     knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    assert knn.train_image() == registry(REGION, "knn") + '/knn:1'
+    assert knn.train_image() == registry(REGION, "knn") + "/knn:1"


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('k', 'string'),
-    ('sample_size', 'string'),
-    ('predictor_type', 1)
-])
+@pytest.mark.parametrize(
+    "required_hyper_parameters, value",
+    [("k", "string"), ("sample_size", "string"), ("predictor_type", 1)],
+)
 def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -137,9 +157,7 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame
     KNN(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('predictor_type', 'random_string')
-])
+@pytest.mark.parametrize("required_hyper_parameters, value", [("predictor_type", "random_string")])
 def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -147,10 +165,9 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
     KNN(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('iterable_hyper_parameters, value', [
-    ('index_type', 1),
-    ('index_metric', 'string')
-])
+@pytest.mark.parametrize(
+    "iterable_hyper_parameters, value", [("index_type", 1), ("index_metric", "string")]
+)
 def test_error_optional_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -158,11 +175,10 @@ def test_error_optional_hyper_parameters_type(sagemaker_session, iterable_hyper_
     KNN(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('index_type', 'faiss.random'),
-    ('index_metric', 'randomstring'),
-    ('faiss_index_pq_m', -1)
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [("index_type", "faiss.random"), ("index_metric", "randomstring"), ("faiss_index_pq_m", -1)],
+)
 def test_error_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -170,13 +186,16 @@ def test_error_optional_hyper_parameters_value(sagemaker_session, optional_hyper
     KNN(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('conditional_hyper_parameters', [
-    {'dimension_reduction_type': 'sign'},  # errors due to missing dimension_reduction_target
-    {'dimension_reduction_type': 'sign', 'dimension_reduction_target': -2},
-    {'dimension_reduction_type': 'sign', 'dimension_reduction_target': 'string'},
-    {'dimension_reduction_type': 2, 'dimension_reduction_target': 20},
-    {'dimension_reduction_type': 'randomstring', 'dimension_reduction_target': 20}
-])
+@pytest.mark.parametrize(
+    "conditional_hyper_parameters",
+    [
+        {"dimension_reduction_type": "sign"},  # errors due to missing dimension_reduction_target
+        {"dimension_reduction_type": "sign", "dimension_reduction_target": -2},
+        {"dimension_reduction_type": "sign", "dimension_reduction_target": "string"},
+        {"dimension_reduction_type": 2, "dimension_reduction_target": 20},
+        {"dimension_reduction_type": "randomstring", "dimension_reduction_target": 20},
+    ],
+)
 def test_error_conditional_hyper_parameters_value(sagemaker_session, conditional_hyper_parameters):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -193,7 +212,12 @@ def test_error_conditional_hyper_parameters_value(sagemaker_session, conditional
 def test_call_fit(base_fit, sagemaker_session):
     knn = KNN(base_job_name="knn", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     knn.fit(data, MINI_BATCH_SIZE)
@@ -206,16 +230,24 @@ def test_call_fit(base_fit, sagemaker_session):


 def test_call_fit_none_mini_batch_size(sagemaker_session):
     knn = KNN(base_job_name="knn", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     knn.fit(data)


 def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
     knn = KNN(base_job_name="knn", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises((TypeError, ValueError)):
         knn._prepare_for_training(data, "some")
@@ -224,24 +256,38 @@ def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):


 def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_session):
     knn = KNN(base_job_name="knn", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises(ValueError):
         knn._prepare_for_training(data, 0)


 def test_model_image(sagemaker_session):
     knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     knn.fit(data, MINI_BATCH_SIZE)

     model = knn.create_model()
-    assert model.image == registry(REGION, "knn") + '/knn:1'
+    assert model.image == registry(REGION, "knn") + "/knn:1"


 def test_predictor_type(sagemaker_session):
     knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     knn.fit(data, MINI_BATCH_SIZE)
     model = knn.create_model()
     predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)
diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py
index eed9902292..dc975a7fa3 100644
--- a/tests/unit/test_lda.py
+++ b/tests/unit/test_lda.py
@@ -18,41 +18,33 @@
 from sagemaker.amazon.lda import LDA, LDAPredictor
 from sagemaker.amazon.amazon_estimator import registry, RecordSet

-ROLE = 'myrole'
+ROLE = "myrole"
 TRAIN_INSTANCE_COUNT = 1
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"
 NUM_TOPICS = 3

-COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_type': TRAIN_INSTANCE_TYPE}
-ALL_REQ_ARGS = dict({'num_topics': NUM_TOPICS}, **COMMON_TRAIN_ARGS)
+COMMON_TRAIN_ARGS = {"role": ROLE, "train_instance_type": TRAIN_INSTANCE_TYPE}
+ALL_REQ_ARGS = dict({"num_topics": NUM_TOPICS}, **COMMON_TRAIN_ARGS)

-REGION = 'us-west-2'
-BUCKET_NAME = 'Some-Bucket'
+REGION = "us-west-2"
+BUCKET_NAME = "Some-Bucket"

-DESCRIBE_TRAINING_JOB_RESULT = {
-    'ModelArtifacts': {
-        'S3ModelArtifacts': 's3://bucket/model.tar.gz'
-    }
-}
+DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}}

-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(name="sagemaker_session", boto_session=boto_mock, config=None, local_mode=False)
     sms.boto_region_name = REGION
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=DESCRIBE_TRAINING_JOB_RESULT)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    sms.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT
+    )
     sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
@@ -70,33 +62,36 @@ def test_init_required_positional(sagemaker_session):


 def test_init_required_named(sagemaker_session):
     lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    assert lda.role == COMMON_TRAIN_ARGS['role']
+    assert lda.role == COMMON_TRAIN_ARGS["role"]
     assert lda.train_instance_count == TRAIN_INSTANCE_COUNT
-    assert lda.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type']
-    assert lda.num_topics == ALL_REQ_ARGS['num_topics']
+    assert lda.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"]
+    assert lda.num_topics == ALL_REQ_ARGS["num_topics"]


 def test_all_hyperparameters(sagemaker_session):
-    lda = LDA(sagemaker_session=sagemaker_session,
-              alpha0=2.2, max_restarts=3, max_iterations=10, tol=3.3,
-              **ALL_REQ_ARGS)
+    lda = LDA(
+        sagemaker_session=sagemaker_session,
+        alpha0=2.2,
+        max_restarts=3,
+        max_iterations=10,
+        tol=3.3,
+        **ALL_REQ_ARGS
+    )
     assert lda.hyperparameters() == dict(
-        num_topics=str(ALL_REQ_ARGS['num_topics']),
-        alpha0='2.2',
-        max_restarts='3',
-        max_iterations='10',
-        tol='3.3',
+        num_topics=str(ALL_REQ_ARGS["num_topics"]),
+        alpha0="2.2",
+        max_restarts="3",
+        max_iterations="10",
+        tol="3.3",
     )


 def test_image(sagemaker_session):
     lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    assert lda.train_image() == registry(REGION, 'lda') + '/lda:1'
+    assert lda.train_image() == registry(REGION, "lda") + "/lda:1"


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('num_topics', 'string')
-])
+@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")])
 def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -104,9 +99,7 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame
     LDA(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('required_hyper_parameters, value', [
-    ('num_topics', 0)
-])
+@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", 0)])
 def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -114,12 +107,15 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
     LDA(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('alpha0', 'string'),
-    ('max_restarts', 'string'),
-    ('max_iterations', 'string'),
-    ('tol', 'string')
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value",
+    [
+        ("alpha0", "string"),
+        ("max_restarts", "string"),
+        ("max_iterations", "string"),
+        ("tol", "string"),
+    ],
+)
 def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -127,11 +123,9 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame
     LDA(sagemaker_session=sagemaker_session, **test_params)


-@pytest.mark.parametrize('optional_hyper_parameters, value', [
-    ('max_restarts', 0),
-    ('max_iterations', 0),
-    ('tol', 0)
-])
+@pytest.mark.parametrize(
+    "optional_hyper_parameters, value", [("max_restarts", 0), ("max_iterations", 0), ("tol", 0)]
+)
 def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value):
     with pytest.raises(ValueError):
         test_params = ALL_REQ_ARGS.copy()
@@ -139,16 +133,21 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param
     LDA(sagemaker_session=sagemaker_session, **test_params)


-PREFIX = 'prefix'
+PREFIX = "prefix"
 FEATURE_DIM = 10
 MINI_BATCH_SZIE = 200


-@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')
+@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit")
 def test_call_fit(base_fit, sagemaker_session):
-    lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     lda.fit(data, MINI_BATCH_SZIE)
@@ -159,45 +158,67 @@ def test_call_fit(base_fit, sagemaker_session):


 def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
-    lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises(ValueError):
         lda._prepare_for_training(data, None)


 def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
-    lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )

     with pytest.raises(ValueError):
-        lda._prepare_for_training(data, 'some')
+        lda._prepare_for_training(data, "some")


 def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
-    lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
+    lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)

-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
-                     channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     with pytest.raises(ValueError):
         lda._prepare_for_training(data, 0)


 def test_model_image(sagemaker_session):
     lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     lda.fit(data, MINI_BATCH_SZIE)

     model = lda.create_model()
-    assert model.image == registry(REGION, 'lda') + '/lda:1'
+    assert model.image == registry(REGION, "lda") + "/lda:1"


 def test_predictor_type(sagemaker_session):
     lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
-    data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
+    data = RecordSet(
+        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
+        num_records=1,
+        feature_dim=FEATURE_DIM,
+        channel="train",
+    )
     lda.fit(data, MINI_BATCH_SZIE)
     model = lda.create_model()
     predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)
diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py
index 1e5fa64f0f..3767d861cc 100644
--- a/tests/unit/test_linear_learner.py
+++ b/tests/unit/test_linear_learner.py
@@ -18,44 +18,44 @@
 from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor
 from sagemaker.amazon.amazon_estimator import registry, RecordSet

-ROLE = 'myrole'
+ROLE = "myrole"
 TRAIN_INSTANCE_COUNT = 1
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"

-PREDICTOR_TYPE = 'binary_classifier'
+PREDICTOR_TYPE = "binary_classifier"

-COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT,
-                     'train_instance_type': TRAIN_INSTANCE_TYPE}
-ALL_REQ_ARGS = dict({'predictor_type': PREDICTOR_TYPE}, **COMMON_TRAIN_ARGS)
+COMMON_TRAIN_ARGS = {
+    "role": ROLE,
+    "train_instance_count": TRAIN_INSTANCE_COUNT,
+    "train_instance_type": TRAIN_INSTANCE_TYPE,
+}
+ALL_REQ_ARGS = dict({"predictor_type": PREDICTOR_TYPE}, **COMMON_TRAIN_ARGS)

-REGION = 'us-west-2'
-BUCKET_NAME = 'Some-Bucket'
+REGION = "us-west-2"
+BUCKET_NAME = "Some-Bucket"

-DESCRIBE_TRAINING_JOB_RESULT = {
-    'ModelArtifacts': {
-        'S3ModelArtifacts': 's3://bucket/model.tar.gz'
-    }
-}
+DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}}

-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}


 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               region_name=REGION, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
     sms.boto_region_name = REGION
-    sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
-    sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                      return_value=DESCRIBE_TRAINING_JOB_RESULT)
+    sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+    sms.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT
+    )
     sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
@@ -63,8 +63,13 @@ def sagemaker_session():


 def test_init_required_positional(sagemaker_session):
-    lr = LinearLearner(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, PREDICTOR_TYPE,
-                       sagemaker_session=sagemaker_session)
+    lr = LinearLearner(
+        ROLE,
+        TRAIN_INSTANCE_COUNT,
+        TRAIN_INSTANCE_TYPE,
+        PREDICTOR_TYPE,
+        sagemaker_session=sagemaker_session,
+    )
     assert lr.role == ROLE
     assert lr.train_instance_count == TRAIN_INSTANCE_COUNT
     assert lr.train_instance_type == TRAIN_INSTANCE_TYPE
@@ -74,50 +79,111 @@
 def test_init_required_named(sagemaker_session):
test_init_required_named(sagemaker_session): lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert lr.role == ALL_REQ_ARGS['role'] - assert lr.train_instance_count == ALL_REQ_ARGS['train_instance_count'] - assert lr.train_instance_type == ALL_REQ_ARGS['train_instance_type'] - assert lr.predictor_type == ALL_REQ_ARGS['predictor_type'] + assert lr.role == ALL_REQ_ARGS["role"] + assert lr.train_instance_count == ALL_REQ_ARGS["train_instance_count"] + assert lr.train_instance_type == ALL_REQ_ARGS["train_instance_type"] + assert lr.predictor_type == ALL_REQ_ARGS["predictor_type"] def test_all_hyperparameters(sagemaker_session): - lr = LinearLearner(sagemaker_session=sagemaker_session, - binary_classifier_model_selection_criteria='accuracy', - target_recall=0.5, target_precision=0.6, - positive_example_weight_mult=0.1, epochs=1, use_bias=True, num_models=5, - num_calibration_samples=6, init_method='uniform', init_scale=0.1, init_sigma=0.001, - init_bias=0, optimizer='sgd', loss='logistic', wd=0.4, l1=0.04, momentum=0.1, - learning_rate=0.001, beta_1=0.2, beta_2=0.03, bias_lr_mult=5.5, bias_wd_mult=6.6, - use_lr_scheduler=False, lr_scheduler_step=2, lr_scheduler_factor=0.03, - lr_scheduler_minimum_lr=0.001, normalize_data=False, normalize_label=True, - unbias_data=True, unbias_label=False, num_point_for_scaler=3, margin=1.0, - quantile=0.5, loss_insensitivity=0.1, huber_delta=0.1, early_stopping_patience=3, - early_stopping_tolerance=0.001, num_classes=1, accuracy_top_k=3, f_beta=1.0, - balance_multiclass_weights=False, **ALL_REQ_ARGS) + lr = LinearLearner( + sagemaker_session=sagemaker_session, + binary_classifier_model_selection_criteria="accuracy", + target_recall=0.5, + target_precision=0.6, + positive_example_weight_mult=0.1, + epochs=1, + use_bias=True, + num_models=5, + num_calibration_samples=6, + init_method="uniform", + init_scale=0.1, + init_sigma=0.001, + init_bias=0, + optimizer="sgd", + loss="logistic", + wd=0.4, + l1=0.04, + momentum=0.1, + learning_rate=0.001, + beta_1=0.2, + beta_2=0.03, + bias_lr_mult=5.5, + bias_wd_mult=6.6, + use_lr_scheduler=False, + lr_scheduler_step=2, + lr_scheduler_factor=0.03, + lr_scheduler_minimum_lr=0.001, + normalize_data=False, + normalize_label=True, + unbias_data=True, + unbias_label=False, + num_point_for_scaler=3, + margin=1.0, + quantile=0.5, + loss_insensitivity=0.1, + huber_delta=0.1, + early_stopping_patience=3, + early_stopping_tolerance=0.001, + num_classes=1, + accuracy_top_k=3, + f_beta=1.0, + balance_multiclass_weights=False, + **ALL_REQ_ARGS + ) assert lr.hyperparameters() == dict( - predictor_type='binary_classifier', binary_classifier_model_selection_criteria='accuracy', - target_recall='0.5', target_precision='0.6', positive_example_weight_mult='0.1', epochs='1', - use_bias='True', num_models='5', num_calibration_samples='6', init_method='uniform', - init_scale='0.1', init_sigma='0.001', init_bias='0.0', optimizer='sgd', loss='logistic', - wd='0.4', l1='0.04', momentum='0.1', learning_rate='0.001', beta_1='0.2', beta_2='0.03', - bias_lr_mult='5.5', bias_wd_mult='6.6', use_lr_scheduler='False', lr_scheduler_step='2', - lr_scheduler_factor='0.03', lr_scheduler_minimum_lr='0.001', normalize_data='False', - normalize_label='True', unbias_data='True', unbias_label='False', num_point_for_scaler='3', margin='1.0', - quantile='0.5', loss_insensitivity='0.1', huber_delta='0.1', early_stopping_patience='3', - early_stopping_tolerance='0.001', num_classes='1', accuracy_top_k='3', f_beta='1.0', - 
balance_multiclass_weights='False', + predictor_type="binary_classifier", + binary_classifier_model_selection_criteria="accuracy", + target_recall="0.5", + target_precision="0.6", + positive_example_weight_mult="0.1", + epochs="1", + use_bias="True", + num_models="5", + num_calibration_samples="6", + init_method="uniform", + init_scale="0.1", + init_sigma="0.001", + init_bias="0.0", + optimizer="sgd", + loss="logistic", + wd="0.4", + l1="0.04", + momentum="0.1", + learning_rate="0.001", + beta_1="0.2", + beta_2="0.03", + bias_lr_mult="5.5", + bias_wd_mult="6.6", + use_lr_scheduler="False", + lr_scheduler_step="2", + lr_scheduler_factor="0.03", + lr_scheduler_minimum_lr="0.001", + normalize_data="False", + normalize_label="True", + unbias_data="True", + unbias_label="False", + num_point_for_scaler="3", + margin="1.0", + quantile="0.5", + loss_insensitivity="0.1", + huber_delta="0.1", + early_stopping_patience="3", + early_stopping_tolerance="0.001", + num_classes="1", + accuracy_top_k="3", + f_beta="1.0", + balance_multiclass_weights="False", ) def test_image(sagemaker_session): lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert lr.train_image() == registry(REGION, 'linear-learner') + '/linear-learner:1' + assert lr.train_image() == registry(REGION, "linear-learner") + "/linear-learner:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('predictor_type', 0) -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("predictor_type", 0)]) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -125,9 +191,7 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame LinearLearner(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('predictor_type', 'string') -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("predictor_type", "string")]) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -138,15 +202,15 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param def test_num_classes_is_required_for_multiclass_classifier(sagemaker_session): with pytest.raises(ValueError) as excinfo: test_params = ALL_REQ_ARGS.copy() - test_params["predictor_type"] = 'multiclass_classifier' + test_params["predictor_type"] = "multiclass_classifier" LinearLearner(sagemaker_session=sagemaker_session, **test_params) - assert "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2." in str( - excinfo.value) + assert ( + "For predictor_type 'multiclass_classifier', 'num_classes' should be set to a value greater than 2." 
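
One detail worth noting in the expected dict just above: every hyperparameter comes back stringified, and `init_bias=0` surfaces as `"0.0"` rather than `"0"`, because each value is first cast to the hyperparameter's declared Python type. A minimal sketch of that contract; `serialize` is a hypothetical helper, not the SDK's `Hyperparameter` descriptor:

```python
# Hypothetical sketch of the stringification the assertions above rely on.

def serialize(value, declared_type):
    """Cast to the hyperparameter's declared type, then stringify."""
    return str(declared_type(value))

# Integers declared as float come back with a trailing ".0" ...
assert serialize(0, float) == "0.0"    # matches init_bias="0.0"
# ... while booleans and ints round-trip through str() unchanged.
assert serialize(True, bool) == "True"  # matches use_bias="True"
assert serialize(5, int) == "5"         # matches num_models="5"
```
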
+ in str(excinfo.value) + ) -@pytest.mark.parametrize('iterable_hyper_parameters, value', [ - ('eval_metrics', 0) -]) +@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)]) def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value): with pytest.raises(TypeError): test_params = ALL_REQ_ARGS.copy() @@ -154,41 +218,44 @@ def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parame LinearLearner(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('binary_classifier_model_selection_criteria', 0), - ('target_recall', 'string'), - ('target_precision', 'string'), - ('epochs', 'string'), - ('num_models', 'string'), - ('num_calibration_samples', 'string'), - ('init_method', 0), - ('init_scale', 'string'), - ('init_sigma', 'string'), - ('init_bias', 'string'), - ('optimizer', 0), - ('loss', 0), - ('wd', 'string'), - ('l1', 'string'), - ('momentum', 'string'), - ('learning_rate', 'string'), - ('beta_1', 'string'), - ('beta_2', 'string'), - ('bias_lr_mult', 'string'), - ('bias_wd_mult', 'string'), - ('lr_scheduler_step', 'string'), - ('lr_scheduler_factor', 'string'), - ('lr_scheduler_minimum_lr', 'string'), - ('num_point_for_scaler', 'string'), - ('margin', 'string'), - ('quantile', 'string'), - ('loss_insensitivity', 'string'), - ('huber_delta', 'string'), - ('early_stopping_patience', 'string'), - ('early_stopping_tolerance', 'string'), - ('num_classes', 'string'), - ('accuracy_top_k', 'string'), - ('f_beta', 'string'), -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("binary_classifier_model_selection_criteria", 0), + ("target_recall", "string"), + ("target_precision", "string"), + ("epochs", "string"), + ("num_models", "string"), + ("num_calibration_samples", "string"), + ("init_method", 0), + ("init_scale", "string"), + ("init_sigma", "string"), + ("init_bias", "string"), + ("optimizer", 0), + ("loss", 0), + ("wd", "string"), + ("l1", "string"), + ("momentum", "string"), + ("learning_rate", "string"), + ("beta_1", "string"), + ("beta_2", "string"), + ("bias_lr_mult", "string"), + ("bias_wd_mult", "string"), + ("lr_scheduler_step", "string"), + ("lr_scheduler_factor", "string"), + ("lr_scheduler_minimum_lr", "string"), + ("num_point_for_scaler", "string"), + ("margin", "string"), + ("quantile", "string"), + ("loss_insensitivity", "string"), + ("huber_delta", "string"), + ("early_stopping_patience", "string"), + ("early_stopping_tolerance", "string"), + ("num_classes", "string"), + ("accuracy_top_k", "string"), + ("f_beta", "string"), + ], +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -196,45 +263,47 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame LinearLearner(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('binary_classifier_model_selection_criteria', 'string'), - ('target_recall', 0), - ('target_recall', 1), - ('target_precision', 0), - ('target_precision', 1), - ('epochs', 0), - ('num_models', 0), - ('num_calibration_samples', 0), - ('init_method', 'string'), - ('init_scale', 0), - ('init_sigma', 0), - ('optimizer', 'string'), - ('loss', 'string'), - ('wd', -1), - ('l1', -1), - ('momentum', 1), - ('learning_rate', 0), - ('beta_1', 1), - ('beta_2', 1), - ('bias_lr_mult', 0), - ('bias_wd_mult', -1), 
- ('lr_scheduler_step', 0), - ('lr_scheduler_factor', 0), - ('lr_scheduler_factor', 1), - ('lr_scheduler_minimum_lr', 0), - ('num_point_for_scaler', 0), - ('margin', -1), - ('quantile', 0), - ('quantile', 1), - ('loss_insensitivity', 0), - ('huber_delta', -1), - ('early_stopping_patience', 0), - ('early_stopping_tolerance', 0), - ('num_classes', 0), - ('accuracy_top_k', 0), - ('f_beta', -1.0), - -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("binary_classifier_model_selection_criteria", "string"), + ("target_recall", 0), + ("target_recall", 1), + ("target_precision", 0), + ("target_precision", 1), + ("epochs", 0), + ("num_models", 0), + ("num_calibration_samples", 0), + ("init_method", "string"), + ("init_scale", 0), + ("init_sigma", 0), + ("optimizer", "string"), + ("loss", "string"), + ("wd", -1), + ("l1", -1), + ("momentum", 1), + ("learning_rate", 0), + ("beta_1", 1), + ("beta_2", 1), + ("bias_lr_mult", 0), + ("bias_wd_mult", -1), + ("lr_scheduler_step", 0), + ("lr_scheduler_factor", 0), + ("lr_scheduler_factor", 1), + ("lr_scheduler_minimum_lr", 0), + ("num_point_for_scaler", 0), + ("margin", -1), + ("quantile", 0), + ("quantile", 1), + ("loss_insensitivity", 0), + ("huber_delta", -1), + ("early_stopping_patience", 0), + ("early_stopping_tolerance", 0), + ("num_classes", 0), + ("accuracy_top_k", 0), + ("f_beta", -1.0), + ], +) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -242,15 +311,20 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param LinearLearner(sagemaker_session=sagemaker_session, **test_params) -PREFIX = 'prefix' +PREFIX = "prefix" FEATURE_DIM = 10 DEFAULT_MINI_BATCH_SIZE = 1000 def test_prepare_for_training_calculate_batch_size_1(sagemaker_session): - lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) lr._prepare_for_training(data) @@ -258,12 +332,14 @@ def test_prepare_for_training_calculate_batch_size_1(sagemaker_session): def test_prepare_for_training_calculate_batch_size_2(sagemaker_session): - lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), - num_records=10000, - feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=10000, + feature_dim=FEATURE_DIM, + channel="train", + ) lr._prepare_for_training(data) @@ -271,12 +347,14 @@ def test_prepare_for_training_calculate_batch_size_2(sagemaker_session): def test_prepare_for_training_multiple_channel(sagemaker_session): - lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), - num_records=10000, - feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=10000, + 
feature_dim=FEATURE_DIM, + channel="train", + ) lr._prepare_for_training([data, data]) @@ -284,27 +362,31 @@ def test_prepare_for_training_multiple_channel(sagemaker_session): def test_prepare_for_training_multiple_channel_no_train(sagemaker_session): - lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), - num_records=10000, - feature_dim=FEATURE_DIM, - channel='mock') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=10000, + feature_dim=FEATURE_DIM, + channel="mock", + ) with pytest.raises(ValueError) as ex: lr._prepare_for_training([data, data]) - assert 'Must provide train channel.' in str(ex) + assert "Must provide train channel." in str(ex) -@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit') +@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") def test_call_fit_pass_batch_size(base_fit, sagemaker_session): - lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = LinearLearner(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), - num_records=10000, - feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=10000, + feature_dim=FEATURE_DIM, + channel="train", + ) lr.fit(data, 10) @@ -316,16 +398,26 @@ def test_call_fit_pass_batch_size(base_fit, sagemaker_session): def test_model_image(sagemaker_session): lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) lr.fit(data) model = lr.create_model() - assert model.image == registry(REGION, 'linear-learner') + '/linear-learner:1' + assert model.image == registry(REGION, "linear-learner") + "/linear-learner:1" def test_predictor_type(sagemaker_session): lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) lr.fit(data) model = lr.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_local_data.py b/tests/unit/test_local_data.py index 10f124ee6d..356ee70d33 100644 --- a/tests/unit/test_local_data.py +++ b/tests/unit/test_local_data.py @@ -21,72 +21,71 @@ import sagemaker.local.data -@patch('sagemaker.local.data.LocalFileDataSource') +@patch("sagemaker.local.data.LocalFileDataSource") def test_get_data_source_instance_with_file(LocalFileDataSource, sagemaker_local_session): # file - data_source = sagemaker.local.data.get_data_source_instance('file:///my/file', sagemaker_local_session) - LocalFileDataSource.assert_called_with('/my/file') + data_source = sagemaker.local.data.get_data_source_instance( + "file:///my/file", sagemaker_local_session + ) + LocalFileDataSource.assert_called_with("/my/file") assert data_source is not None - data_source = sagemaker.local.data.get_data_source_instance('file://relative/path', 
sagemaker_local_session) - LocalFileDataSource.assert_called_with('relative/path') + data_source = sagemaker.local.data.get_data_source_instance( + "file://relative/path", sagemaker_local_session + ) + LocalFileDataSource.assert_called_with("relative/path") assert data_source is not None -@patch('sagemaker.local.data.S3DataSource') +@patch("sagemaker.local.data.S3DataSource") def test_get_data_source_instance_with_s3(S3DataSource, sagemaker_local_session): - data_source = sagemaker.local.data.get_data_source_instance('s3://bucket/path', sagemaker_local_session) - S3DataSource.assert_called_with('bucket', '/path', sagemaker_local_session) + data_source = sagemaker.local.data.get_data_source_instance( + "s3://bucket/path", sagemaker_local_session + ) + S3DataSource.assert_called_with("bucket", "/path", sagemaker_local_session) assert data_source is not None -@patch('os.path.exists', Mock(return_value=True)) -@patch('os.path.abspath', lambda x: x) -@patch('os.path.isdir', lambda x: x[-1] == '/') -@patch('os.path.isfile', lambda x: x[-1] != '/') -@patch('os.listdir') +@patch("os.path.exists", Mock(return_value=True)) +@patch("os.path.abspath", lambda x: x) +@patch("os.path.isdir", lambda x: x[-1] == "/") +@patch("os.path.isfile", lambda x: x[-1] != "/") +@patch("os.listdir") def test_file_data_source_get_file_list_with_folder(listdir): - data_source = sagemaker.local.data.LocalFileDataSource('/some/path/') - listdir.return_value = [ - '/some/path/a', - '/some/path/b', - '/some/path/c/', - '/some/path/c/a' - ] - expected = [ - '/some/path/a', - '/some/path/b', - '/some/path/c/a' - ] + data_source = sagemaker.local.data.LocalFileDataSource("/some/path/") + listdir.return_value = ["/some/path/a", "/some/path/b", "/some/path/c/", "/some/path/c/a"] + expected = ["/some/path/a", "/some/path/b", "/some/path/c/a"] result = data_source.get_file_list() assert result == expected -@patch('os.path.exists', Mock(return_value=True)) -@patch('os.path.abspath', lambda x: x) -@patch('os.path.isdir', lambda x: x[-1] == '/') -@patch('os.path.isfile', lambda x: x[-1] != '/') +@patch("os.path.exists", Mock(return_value=True)) +@patch("os.path.abspath", lambda x: x) +@patch("os.path.isdir", lambda x: x[-1] == "/") +@patch("os.path.isfile", lambda x: x[-1] != "/") def test_file_data_source_get_file_list_with_single_file(): - data_source = sagemaker.local.data.LocalFileDataSource('/some/batch/file.csv') - assert data_source.get_file_list() == ['/some/batch/file.csv'] + data_source = sagemaker.local.data.LocalFileDataSource("/some/batch/file.csv") + assert data_source.get_file_list() == ["/some/batch/file.csv"] -@patch('os.path.exists', Mock(return_value=True)) -@patch('os.path.abspath', lambda x: x) -@patch('os.path.isdir', lambda x: x[-1] == '/') +@patch("os.path.exists", Mock(return_value=True)) +@patch("os.path.abspath", lambda x: x) +@patch("os.path.isdir", lambda x: x[-1] == "/") def test_file_data_source_get_root(): - data_source = sagemaker.local.data.LocalFileDataSource('/some/path/') - assert data_source.get_root_dir() == '/some/path/' + data_source = sagemaker.local.data.LocalFileDataSource("/some/path/") + assert data_source.get_root_dir() == "/some/path/" - data_source = sagemaker.local.data.LocalFileDataSource('/some/path/my_file.csv') - assert data_source.get_root_dir() == '/some/path' + data_source = sagemaker.local.data.LocalFileDataSource("/some/path/my_file.csv") + assert data_source.get_root_dir() == "/some/path" -@patch('sagemaker.local.data.LocalFileDataSource') 
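
The three file-based tests above fix the observable behavior of a local file data source: a directory flattens to the files underneath it, a single file lists only itself, and the root dir is the path itself for a directory but its dirname for a file. A self-contained stand-in with the same behavior, not the real `sagemaker.local.data.LocalFileDataSource`:

```python
import os
import tempfile


class FileDataSourceSketch:
    """Stand-in mirroring the behavior the tests above assert."""

    def __init__(self, path):
        self.path = os.path.abspath(path)

    def get_file_list(self):
        if os.path.isfile(self.path):
            return [self.path]
        # Directories flatten to all files underneath, nested ones included.
        return sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(self.path)
            for name in files
        )

    def get_root_dir(self):
        return self.path if os.path.isdir(self.path) else os.path.dirname(self.path)


with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "a"), "w").close()
    os.mkdir(os.path.join(tmp, "c"))
    open(os.path.join(tmp, "c", "a"), "w").close()

    src = FileDataSourceSketch(tmp)
    assert src.get_file_list() == [os.path.join(tmp, "a"), os.path.join(tmp, "c", "a")]
    assert src.get_root_dir() == tmp
    assert FileDataSourceSketch(os.path.join(tmp, "a")).get_root_dir() == tmp
```
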
-@patch('sagemaker.utils.download_folder') -@patch('tempfile.mkdtemp', lambda dir: '/tmp/working_dir') +@patch("sagemaker.local.data.LocalFileDataSource") +@patch("sagemaker.utils.download_folder") +@patch("tempfile.mkdtemp", lambda dir: "/tmp/working_dir") def test_s3_data_source(download_folder, LocalFileDataSource, sagemaker_local_session): - data_source = sagemaker.local.data.S3DataSource('my_bucket', '/transform/data', sagemaker_local_session) + data_source = sagemaker.local.data.S3DataSource( + "my_bucket", "/transform/data", sagemaker_local_session + ) download_folder.assert_called() data_source.get_file_list() LocalFileDataSource().get_file_list.assert_called() @@ -98,48 +97,48 @@ def test_get_splitter_instance_with_valid_types(): splitter = sagemaker.local.data.get_splitter_instance(None) assert isinstance(splitter, sagemaker.local.data.NoneSplitter) - splitter = sagemaker.local.data.get_splitter_instance('Line') + splitter = sagemaker.local.data.get_splitter_instance("Line") assert isinstance(splitter, sagemaker.local.data.LineSplitter) - splitter = sagemaker.local.data.get_splitter_instance('RecordIO') + splitter = sagemaker.local.data.get_splitter_instance("RecordIO") assert isinstance(splitter, sagemaker.local.data.RecordIOSplitter) def test_get_splitter_instance_with_invalid_types(): with pytest.raises(ValueError): - sagemaker.local.data.get_splitter_instance('SomethingInvalid') + sagemaker.local.data.get_splitter_instance("SomethingInvalid") def test_none_splitter(tmpdir): - test_file_path = tmpdir.join('none_test.txt') + test_file_path = tmpdir.join("none_test.txt") - with test_file_path.open('w') as f: - f.write('this\nis\na\ntest') + with test_file_path.open("w") as f: + f.write("this\nis\na\ntest") splitter = sagemaker.local.data.NoneSplitter() data = [x for x in splitter.split(str(test_file_path))] - assert data == ['this\nis\na\ntest'] + assert data == ["this\nis\na\ntest"] def test_line_splitter(tmpdir): - test_file_path = tmpdir.join('line_test.txt') + test_file_path = tmpdir.join("line_test.txt") - with test_file_path.open('w') as f: + with test_file_path.open("w") as f: for i in range(10): - f.write('%s\n' % i) + f.write("%s\n" % i) splitter = sagemaker.local.data.LineSplitter() data = [x for x in splitter.split(str(test_file_path))] assert len(data) == 10 for i in range(10): - assert data[i] == '%s\n' % str(i) + assert data[i] == "%s\n" % str(i) def test_recordio_splitter(tmpdir): - test_file_path = tmpdir.join('recordio_test.txt') - with test_file_path.open('wb') as f: + test_file_path = tmpdir.join("recordio_test.txt") + with test_file_path.open("wb") as f: for i in range(10): - data = str(i).encode('utf-8') + data = str(i).encode("utf-8") sagemaker.amazon.common._write_recordio(f, data) splitter = sagemaker.local.data.RecordIOSplitter() @@ -150,29 +149,29 @@ def test_recordio_splitter(tmpdir): def test_get_batch_strategy_instance_with_valid_type(): # Single Record - strategy = sagemaker.local.data.get_batch_strategy_instance('SingleRecord', None) + strategy = sagemaker.local.data.get_batch_strategy_instance("SingleRecord", None) assert isinstance(strategy, sagemaker.local.data.SingleRecordStrategy) # Multi Record - strategy = sagemaker.local.data.get_batch_strategy_instance('MultiRecord', None) + strategy = sagemaker.local.data.get_batch_strategy_instance("MultiRecord", None) assert isinstance(strategy, sagemaker.local.data.MultiRecordStrategy) def test_get_batch_strategy_instance_with_invalid_type(): with pytest.raises(ValueError): # something invalid - 
sagemaker.local.data.get_batch_strategy_instance('NiceRecord', None)
+        sagemaker.local.data.get_batch_strategy_instance("NiceRecord", None)
 
 
 def test_single_record_strategy_with_small_records():
     splitter = Mock()
     single_record = sagemaker.local.data.SingleRecordStrategy(splitter)
-    data = ['123', '456', '789']
+    data = ["123", "456", "789"]
     splitter.split.return_value = data
 
     # given 3 small records the output should be the same 3 records
-    batch_records = [r for r in single_record.pad('some_file', 6)]
+    batch_records = [r for r in single_record.pad("some_file", 6)]
 
     assert data == batch_records
@@ -183,14 +182,14 @@ def test_single_record_strategy_with_large_records():
     single_record = sagemaker.local.data.SingleRecordStrategy(splitter)
     # We will construct a huge record greater than 1MB and we expect an exception
     # since there is no way to fit this with the payload size.
-    buffer = ''
+    buffer = ""
     while sys.getsizeof(buffer) < 2 * mb:
-        buffer += '1' * 100
+        buffer += "1" * 100
 
     data = [buffer]
     with pytest.raises(RuntimeError):
         splitter.split.return_value = data
-        batch_records = [r for r in single_record.pad('some_file', 1)]
+        batch_records = [r for r in single_record.pad("some_file", 1)]
         print(batch_records)
@@ -200,13 +199,13 @@ def test_single_record_strategy_with_no_payload_limit():
     splitter = Mock()
 
     mb = 1024 * 1024
-    buffer = ''
+    buffer = ""
     while sys.getsizeof(buffer) < 2 * mb:
-        buffer += '1' * 100
+        buffer += "1" * 100
     splitter.split.return_value = [buffer]
 
     single_record = sagemaker.local.data.SingleRecordStrategy(splitter)
-    batch_records = [r for r in single_record.pad('some_file', 0)]
+    batch_records = [r for r in single_record.pad("some_file", 0)]
 
     assert len(batch_records) == 1
@@ -214,13 +213,13 @@ def test_multi_record_strategy_with_small_records():
     splitter = Mock()
 
     multi_record = sagemaker.local.data.MultiRecordStrategy(splitter)
-    data = ['123', '456', '789']
+    data = ["123", "456", "789"]
     splitter.split.return_value = data
 
     # given 3 small records, the output should be 1 single record with the data from all 3 combined
-    batch_records = [r for r in multi_record.pad('some_file', 6)]
+    batch_records = [r for r in multi_record.pad("some_file", 6)]
 
     assert len(batch_records) == 1
-    assert batch_records[0] == '123456789'
+    assert batch_records[0] == "123456789"
 
 
 def test_multi_record_strategy_with_large_records():
@@ -229,9 +228,9 @@ def test_multi_record_strategy_with_large_records():
     multi_record = sagemaker.local.data.MultiRecordStrategy(splitter)
 
     # we will construct several large records and we expect them to be merged into <1MB ones
-    buffer = ''
+    buffer = ""
     while sys.getsizeof(buffer) < 0.5 * mb:
-        buffer += '1' * 100
+        buffer += "1" * 100
 
     # buffer should be approx 0.5 MB. We will make the data total 10 MB made out of 0.5MB records
     # with a max_payload size of 1MB the expectation is to have ~10 output records.
@@ -239,6 +238,6 @@ def test_multi_record_strategy_with_large_records():
     data = [buffer for _ in range(10)]
     splitter.split.return_value = data
-    batch_records = [r for r in multi_record.pad('some_file', 1)]
+    batch_records = [r for r in multi_record.pad("some_file", 1)]
 
     # check with 11 because there may be a bit of leftover.
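
The padding tests here distinguish the two batch strategies: SingleRecord emits records one at a time (and fails if a single record alone exceeds the payload limit), while MultiRecord greedily merges records up to the limit, which is why ten ~0.5 MB records with a 1 MB limit come out as roughly ten batches. A rough sketch of the MultiRecord merge loop under those assumptions; the real `MultiRecordStrategy` pulls records through a splitter and treats a limit of 0 as unlimited:

```python
import sys

MB = 1024 * 1024


def multi_record_pad(records, max_payload_mb):
    """Merge small records into batches no larger than max_payload_mb (sketch only)."""
    batch = ""
    for record in records:
        # Flush the current batch if appending would push it over the limit.
        if batch and sys.getsizeof(batch + record) > max_payload_mb * MB:
            yield batch
            batch = ""
        batch += record
    if batch:
        yield batch


# Three tiny records merge into one batch, as the small-records test expects.
assert list(multi_record_pad(["123", "456", "789"], 6)) == ["123456789"]
```
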
assert len(batch_records) <= 11 diff --git a/tests/unit/test_local_entities.py b/tests/unit/test_local_entities.py index 6261031bdf..419a35aa8b 100644 --- a/tests/unit/test_local_entities.py +++ b/tests/unit/test_local_entities.py @@ -20,137 +20,137 @@ import sagemaker.local -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def local_transform_job(sagemaker_local_session): - with patch('sagemaker.local.local_session.LocalSagemakerClient.describe_model') as describe_model: - describe_model.return_value = {'PrimaryContainer': {'Environment': {}, 'Image': 'some-image:1.0'}} - job = sagemaker.local.entities._LocalTransformJob('my-transform-job', 'some-model', sagemaker_local_session) + with patch( + "sagemaker.local.local_session.LocalSagemakerClient.describe_model" + ) as describe_model: + describe_model.return_value = { + "PrimaryContainer": {"Environment": {}, "Image": "some-image:1.0"} + } + job = sagemaker.local.entities._LocalTransformJob( + "my-transform-job", "some-model", sagemaker_local_session + ) return job -@patch('sagemaker.local.local_session.LocalSagemakerClient.describe_model', Mock(return_value={'PrimaryContainer': {}})) +@patch( + "sagemaker.local.local_session.LocalSagemakerClient.describe_model", + Mock(return_value={"PrimaryContainer": {}}), +) def test_local_transform_job_init(sagemaker_local_session): - job = sagemaker.local.entities._LocalTransformJob('my-transform-job', 'some-model', sagemaker_local_session) - assert job.name == 'my-transform-job' + job = sagemaker.local.entities._LocalTransformJob( + "my-transform-job", "some-model", sagemaker_local_session + ) + assert job.name == "my-transform-job" assert job.state == sagemaker.local.entities._LocalTransformJob._CREATING def test_local_transform_job_container_environment(local_transform_job): - transform_kwargs = { - 'MaxPayloadInMB': 3, - 'BatchStrategy': 'MultiRecord', - } + transform_kwargs = {"MaxPayloadInMB": 3, "BatchStrategy": "MultiRecord"} container_env = local_transform_job._get_container_environment(**transform_kwargs) - assert 'SAGEMAKER_BATCH' in container_env - assert 'SAGEMAKER_MAX_PAYLOAD_IN_MB' in container_env - assert 'SAGEMAKER_BATCH_STRATEGY' in container_env - assert 'SAGEMAKER_MAX_CONCURRENT_TRANSFORMS' in container_env + assert "SAGEMAKER_BATCH" in container_env + assert "SAGEMAKER_MAX_PAYLOAD_IN_MB" in container_env + assert "SAGEMAKER_BATCH_STRATEGY" in container_env + assert "SAGEMAKER_MAX_CONCURRENT_TRANSFORMS" in container_env - transform_kwargs = { - 'BatchStrategy': 'SingleRecord', - } + transform_kwargs = {"BatchStrategy": "SingleRecord"} container_env = local_transform_job._get_container_environment(**transform_kwargs) - assert 'SAGEMAKER_BATCH' in container_env - assert 'SAGEMAKER_BATCH_STRATEGY' in container_env - assert 'SAGEMAKER_MAX_CONCURRENT_TRANSFORMS' in container_env + assert "SAGEMAKER_BATCH" in container_env + assert "SAGEMAKER_BATCH_STRATEGY" in container_env + assert "SAGEMAKER_MAX_CONCURRENT_TRANSFORMS" in container_env - transform_kwargs = { - 'Environment': {'MY_ENV': 3} - } + transform_kwargs = {"Environment": {"MY_ENV": 3}} container_env = local_transform_job._get_container_environment(**transform_kwargs) - assert 'SAGEMAKER_BATCH' in container_env - assert 'SAGEMAKER_MAX_PAYLOAD_IN_MB' not in container_env - assert 'SAGEMAKER_BATCH_STRATEGY' not in container_env - assert 'SAGEMAKER_MAX_CONCURRENT_TRANSFORMS' in container_env - assert 'MY_ENV' in container_env + assert "SAGEMAKER_BATCH" in container_env + assert "SAGEMAKER_MAX_PAYLOAD_IN_MB" not 
in container_env + assert "SAGEMAKER_BATCH_STRATEGY" not in container_env + assert "SAGEMAKER_MAX_CONCURRENT_TRANSFORMS" in container_env + assert "MY_ENV" in container_env def test_local_transform_job_defaults_with_empty_args(local_transform_job): transform_kwargs = {} defaults = local_transform_job._get_required_defaults(**transform_kwargs) - assert 'BatchStrategy' in defaults - assert 'MaxPayloadInMB' in defaults + assert "BatchStrategy" in defaults + assert "MaxPayloadInMB" in defaults def test_local_transform_job_defaults_with_batch_strategy(local_transform_job): - transform_kwargs = {'BatchStrategy': 'my-own'} + transform_kwargs = {"BatchStrategy": "my-own"} defaults = local_transform_job._get_required_defaults(**transform_kwargs) - assert 'BatchStrategy' not in defaults - assert 'MaxPayloadInMB' in defaults + assert "BatchStrategy" not in defaults + assert "MaxPayloadInMB" in defaults def test_local_transform_job_defaults_with_max_payload(local_transform_job): - transform_kwargs = {'MaxPayloadInMB': 322} + transform_kwargs = {"MaxPayloadInMB": 322} defaults = local_transform_job._get_required_defaults(**transform_kwargs) - assert 'BatchStrategy' in defaults - assert 'MaxPayloadInMB' not in defaults + assert "BatchStrategy" in defaults + assert "MaxPayloadInMB" not in defaults -@patch('sagemaker.local.entities._SageMakerContainer', Mock()) -@patch('sagemaker.local.entities._wait_for_serving_container', Mock()) -@patch('sagemaker.local.entities._perform_request') -@patch('sagemaker.local.entities._LocalTransformJob._perform_batch_inference') +@patch("sagemaker.local.entities._SageMakerContainer", Mock()) +@patch("sagemaker.local.entities._wait_for_serving_container", Mock()) +@patch("sagemaker.local.entities._perform_request") +@patch("sagemaker.local.entities._LocalTransformJob._perform_batch_inference") def test_start_local_transform_job(_perform_batch_inference, _perform_request, local_transform_job): input_data = {} output_data = {} - transform_resources = {'InstanceType': 'local'} + transform_resources = {"InstanceType": "local"} response = Mock() _perform_request.return_value = (response, 200) response.read.return_value = '{"BatchStrategy": "SingleRecord"}' - local_transform_job.primary_container['ModelDataUrl'] = 'file:///some/model' + local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model" local_transform_job.start(input_data, output_data, transform_resources, Environment={}) _perform_batch_inference.assert_called() response = local_transform_job.describe() - assert response['TransformJobStatus'] == 'Completed' - - -@patch('sagemaker.local.data.get_batch_strategy_instance') -@patch('sagemaker.local.data.get_data_source_instance') -@patch('sagemaker.local.entities.move_to_destination') -@patch('sagemaker.local.entities.get_config_value') -def test_local_transform_job_perform_batch_inference(get_config_value, move_to_destination, get_data_source_instance, - get_batch_strategy_instance, local_transform_job, tmpdir): + assert response["TransformJobStatus"] == "Completed" + + +@patch("sagemaker.local.data.get_batch_strategy_instance") +@patch("sagemaker.local.data.get_data_source_instance") +@patch("sagemaker.local.entities.move_to_destination") +@patch("sagemaker.local.entities.get_config_value") +def test_local_transform_job_perform_batch_inference( + get_config_value, + move_to_destination, + get_data_source_instance, + get_batch_strategy_instance, + local_transform_job, + tmpdir, +): input_data = { - 'DataSource': { - 'S3DataSource': { - 'S3Uri': 
's3://some_bucket/nice/data' - } - }, - 'ContentType': 'text/csv' + "DataSource": {"S3DataSource": {"S3Uri": "s3://some_bucket/nice/data"}}, + "ContentType": "text/csv", } - output_data = { - 'S3OutputPath': 's3://bucket/output', - 'AssembleWith': 'Line' - } + output_data = {"S3OutputPath": "s3://bucket/output", "AssembleWith": "Line"} - transform_kwargs = { - 'MaxPayloadInMB': 3, - 'BatchStrategy': 'MultiRecord', - } + transform_kwargs = {"MaxPayloadInMB": 3, "BatchStrategy": "MultiRecord"} data_source = Mock() - data_source.get_file_list.return_value = ['/tmp/file1', '/tmp/file2'] - data_source.get_root_dir.return_value = '/tmp' + data_source.get_file_list.return_value = ["/tmp/file1", "/tmp/file2"] + data_source.get_root_dir.return_value = "/tmp" get_data_source_instance.return_value = data_source batch_strategy = Mock() - batch_strategy.pad.return_value = 'some data' + batch_strategy.pad.return_value = "some data" get_batch_strategy_instance.return_value = batch_strategy get_config_value.return_value = str(tmpdir) runtime_client = Mock() response_object = Mock() - response_object.read.return_value = b'data' - runtime_client.invoke_endpoint.return_value = {'Body': response_object} + response_object.read.return_value = b"data" + runtime_client.invoke_endpoint.return_value = {"Body": response_object} local_transform_job.local_session.sagemaker_runtime_client = runtime_client local_transform_job.container = Mock() @@ -158,8 +158,8 @@ def test_local_transform_job_perform_batch_inference(get_config_value, move_to_d local_transform_job._perform_batch_inference(input_data, output_data, **transform_kwargs) dir, output, job_name, session = move_to_destination.call_args[0] - assert output == 's3://bucket/output' + assert output == "s3://bucket/output" output_files = os.listdir(dir) assert len(output_files) == 2 - assert 'file1.out' in output_files - assert 'file2.out' in output_files + assert "file1.out" in output_files + assert "file2.out" in output_files diff --git a/tests/unit/test_local_session.py b/tests/unit/test_local_session.py index 0cd2f431bf..c733111206 100644 --- a/tests/unit/test_local_session.py +++ b/tests/unit/test_local_session.py @@ -27,131 +27,158 @@ BAD_RESPONSE = urllib3.HTTPResponse() BAD_RESPONSE.status = 502 -ENDPOINT_CONFIG_NAME = 'test-endpoint-config' -PRODUCTION_VARIANTS = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}] +ENDPOINT_CONFIG_NAME = "test-endpoint-config" +PRODUCTION_VARIANTS = [{"InstanceType": "ml.c4.99xlarge", "InitialInstanceCount": 10}] -MODEL_NAME = 'test-model' -PRIMARY_CONTAINER = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}} +MODEL_NAME = "test-model" +PRIMARY_CONTAINER = {"ModelDataUrl": "/some/model/path", "Environment": {"env1": 1, "env2": "b"}} -@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model") -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") +@patch("sagemaker.local.local_session.LocalSession") def test_create_training_job(train, LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 image = "my-docker-image:1.0" - algo_spec = {'TrainingImage': image} + algo_spec = {"TrainingImage": image} input_data_config = [ { - 'ChannelName': 'a', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3Uri': 's3://my_bucket/tmp/source1' + "ChannelName": "a", + "DataSource": { + 
"S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3Uri": "s3://my_bucket/tmp/source1", } - } + }, }, { - 'ChannelName': 'b', - 'DataSource': { - 'FileDataSource': { - 'FileDataDistributionType': 'FullyReplicated', - 'FileUri': 'file:///tmp/source1' + "ChannelName": "b", + "DataSource": { + "FileDataSource": { + "FileDataDistributionType": "FullyReplicated", + "FileUri": "file:///tmp/source1", } - } - } + }, + }, ] output_data_config = {} - resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count} - hyperparameters = {'a': 1, 'b': 'bee'} - - local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, - InputDataConfig=input_data_config, HyperParameters=hyperparameters) + resource_config = {"InstanceType": "local", "InstanceCount": instance_count} + hyperparameters = {"a": 1, "b": "bee"} + + local_sagemaker_client.create_training_job( + "my-training-job", + algo_spec, + output_data_config, + resource_config, + InputDataConfig=input_data_config, + HyperParameters=hyperparameters, + ) expected = { - 'ResourceConfig': {'InstanceCount': instance_count}, - 'TrainingJobStatus': 'Completed', - 'ModelArtifacts': {'S3ModelArtifacts': "/some/path/to/model"} + "ResourceConfig": {"InstanceCount": instance_count}, + "TrainingJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": "/some/path/to/model"}, } response = local_sagemaker_client.describe_training_job("my-training-job") - assert response['TrainingJobStatus'] == expected['TrainingJobStatus'] - assert response['ResourceConfig']['InstanceCount'] == expected['ResourceConfig']['InstanceCount'] - assert response['ModelArtifacts']['S3ModelArtifacts'] == expected['ModelArtifacts']['S3ModelArtifacts'] + assert response["TrainingJobStatus"] == expected["TrainingJobStatus"] + assert ( + response["ResourceConfig"]["InstanceCount"] == expected["ResourceConfig"]["InstanceCount"] + ) + assert ( + response["ModelArtifacts"]["S3ModelArtifacts"] + == expected["ModelArtifacts"]["S3ModelArtifacts"] + ) -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_describe_invalid_training_job(*args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): - local_sagemaker_client.describe_training_job('i-havent-created-this-job') + local_sagemaker_client.describe_training_job("i-havent-created-this-job") -@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model") -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") +@patch("sagemaker.local.local_session.LocalSession") def test_create_training_job_invalid_data_source(train, LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 image = "my-docker-image:1.0" - algo_spec = {'TrainingImage': image} + algo_spec = {"TrainingImage": image} # InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only # valid Data Sources. We expect a ValueError if we pass this input data config. 
- input_data_config = [{ - 'ChannelName': 'a', - 'DataSource': { - 'InvalidDataSource': { - 'FileDataDistributionType': 'FullyReplicated', - 'FileUri': 'ftp://myserver.com/tmp/source1' - } + input_data_config = [ + { + "ChannelName": "a", + "DataSource": { + "InvalidDataSource": { + "FileDataDistributionType": "FullyReplicated", + "FileUri": "ftp://myserver.com/tmp/source1", + } + }, } - }] + ] output_data_config = {} - resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count} - hyperparameters = {'a': 1, 'b': 'bee'} + resource_config = {"InstanceType": "local", "InstanceCount": instance_count} + hyperparameters = {"a": 1, "b": "bee"} with pytest.raises(ValueError): - local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, - InputDataConfig=input_data_config, HyperParameters=hyperparameters) - - -@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model") -@patch('sagemaker.local.local_session.LocalSession') + local_sagemaker_client.create_training_job( + "my-training-job", + algo_spec, + output_data_config, + resource_config, + InputDataConfig=input_data_config, + HyperParameters=hyperparameters, + ) + + +@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") +@patch("sagemaker.local.local_session.LocalSession") def test_create_training_job_not_fully_replicated(train, LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 image = "my-docker-image:1.0" - algo_spec = {'TrainingImage': image} + algo_spec = {"TrainingImage": image} # Local Mode only supports FullyReplicated as Data Distribution type. - input_data_config = [{ - 'ChannelName': 'a', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'ShardedByS3Key', - 'S3Uri': 's3://my_bucket/tmp/source1' - } + input_data_config = [ + { + "ChannelName": "a", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "ShardedByS3Key", + "S3Uri": "s3://my_bucket/tmp/source1", + } + }, } - }] + ] output_data_config = {} - resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count} - hyperparameters = {'a': 1, 'b': 'bee'} + resource_config = {"InstanceType": "local", "InstanceCount": instance_count} + hyperparameters = {"a": 1, "b": "bee"} with pytest.raises(RuntimeError): - local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, - InputDataConfig=input_data_config, HyperParameters=hyperparameters) + local_sagemaker_client.create_training_job( + "my-training-job", + algo_spec, + output_data_config, + resource_config, + InputDataConfig=input_data_config, + HyperParameters=hyperparameters, + ) -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_create_model(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() @@ -160,7 +187,7 @@ def test_create_model(LocalSession): assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_delete_model(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() @@ -171,281 +198,280 @@ def test_delete_model(LocalSession): assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models -@patch('sagemaker.local.local_session.LocalSession') 
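
Several of these tests stack @patch decorators. unittest.mock applies them bottom-up, so the mock for the decorator closest to the function arrives as the first argument; the parameter names are only labels and are never checked against the patched target. A minimal standalone illustration, patching stdlib functions rather than SageMaker internals:

```python
import os
from unittest.mock import patch


@patch("os.getcwd")        # outer decorator -> applied last -> second argument
@patch("os.path.exists")   # inner decorator -> applied first -> first argument
def demo(mock_exists, mock_getcwd):
    mock_exists.return_value = True
    mock_getcwd.return_value = "/fake"
    assert os.path.exists("anything")
    assert os.getcwd() == "/fake"


demo()
```
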
+@patch("sagemaker.local.local_session.LocalSession") def test_describe_model(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): - local_sagemaker_client.describe_model('model-does-not-exist') + local_sagemaker_client.describe_model("model-does-not-exist") local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) response = local_sagemaker_client.describe_model(MODEL_NAME) - assert response['ModelName'] == 'test-model' - assert response['PrimaryContainer']['ModelDataUrl'] == '/some/model/path' + assert response["ModelName"] == "test-model" + assert response["PrimaryContainer"]["ModelDataUrl"] == "/some/model/path" -@patch('sagemaker.local.local_session._LocalTransformJob') -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session._LocalTransformJob") +@patch("sagemaker.local.local_session.LocalSession") def test_create_transform_job(LocalSession, _LocalTransformJob): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() - local_sagemaker_client.create_transform_job('transform-job', 'some-model', None, None, None) + local_sagemaker_client.create_transform_job("transform-job", "some-model", None, None, None) _LocalTransformJob().start.assert_called_with(None, None, None) - local_sagemaker_client.describe_transform_job('transform-job') + local_sagemaker_client.describe_transform_job("transform-job") _LocalTransformJob().describe.assert_called() -@patch('sagemaker.local.local_session._LocalTransformJob') -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session._LocalTransformJob") +@patch("sagemaker.local.local_session.LocalSession") def test_describe_transform_job_does_not_exist(LocalSession, _LocalTransformJob): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): - local_sagemaker_client.describe_transform_job('transform-job-does-not-exist') + local_sagemaker_client.describe_transform_job("transform-job-does-not-exist") -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_describe_endpoint_config(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() # No Endpoint Config Created with pytest.raises(ClientError): - local_sagemaker_client.describe_endpoint_config('some-endpoint-config') + local_sagemaker_client.describe_endpoint_config("some-endpoint-config") - production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}] - local_sagemaker_client.create_endpoint_config('test-endpoint-config', production_variants) + production_variants = [{"InstanceType": "ml.c4.99xlarge", "InitialInstanceCount": 10}] + local_sagemaker_client.create_endpoint_config("test-endpoint-config", production_variants) - response = local_sagemaker_client.describe_endpoint_config('test-endpoint-config') - assert response['EndpointConfigName'] == 'test-endpoint-config' - assert response['ProductionVariants'][0]['InstanceType'] == 'ml.c4.99xlarge' + response = local_sagemaker_client.describe_endpoint_config("test-endpoint-config") + assert response["EndpointConfigName"] == "test-endpoint-config" + assert response["ProductionVariants"][0]["InstanceType"] == "ml.c4.99xlarge" -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_create_endpoint_config(LocalSession): local_sagemaker_client = 
sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) - assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + assert ( + ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + ) -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_delete_endpoint_config(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) - assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + assert ( + ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + ) local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME) - assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + assert ( + ENDPOINT_CONFIG_NAME + not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs + ) -@patch('sagemaker.local.image._SageMakerContainer.serve') -@patch('sagemaker.local.local_session.LocalSession') -@patch('urllib3.PoolManager.request') -@patch('sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config') -@patch('sagemaker.local.local_session.LocalSagemakerClient.describe_model') +@patch("sagemaker.local.image._SageMakerContainer.serve") +@patch("sagemaker.local.local_session.LocalSession") +@patch("urllib3.PoolManager.request") +@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") +@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") def test_describe_endpoint(describe_model, describe_endpoint_config, request, *args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE describe_endpoint_config.return_value = { - 'EndpointConfigName': 'name', - 'EndpointConfigArn': 'local:arn-does-not-matter', - 'CreationTime': '00:00:00', - 'ProductionVariants': [ + "EndpointConfigName": "name", + "EndpointConfigArn": "local:arn-does-not-matter", + "CreationTime": "00:00:00", + "ProductionVariants": [ { - 'InitialVariantWeight': 1.0, - 'ModelName': 'my-model', - 'VariantName': 'AllTraffic', - 'InitialInstanceCount': 1, - 'InstanceType': 'local' - + "InitialVariantWeight": 1.0, + "ModelName": "my-model", + "VariantName": "AllTraffic", + "InitialInstanceCount": 1, + "InstanceType": "local", } - ] + ], } describe_model.return_value = { - 'ModelName': 'my-model', - 'CreationTime': '00:00;00', - 'ExecutionRoleArn': 'local:arn-does-not-matter', - 'ModelArn': 'local:arn-does-not-matter', - 'PrimaryContainer': { - 'Environment': { - 'SAGEMAKER_REGION': 'us-west-2' - }, - 'Image': '123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0', - 'ModelDataUrl': 's3://sagemaker-us-west-2/some/model.tar.gz' - } + "ModelName": "my-model", + "CreationTime": "00:00;00", + "ExecutionRoleArn": "local:arn-does-not-matter", + "ModelArn": "local:arn-does-not-matter", + "PrimaryContainer": { + "Environment": {"SAGEMAKER_REGION": "us-west-2"}, + "Image": "123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0", + "ModelDataUrl": "s3://sagemaker-us-west-2/some/model.tar.gz", + }, } with pytest.raises(ClientError): - local_sagemaker_client.describe_endpoint('non-existing-endpoint') + 
local_sagemaker_client.describe_endpoint("non-existing-endpoint") - local_sagemaker_client.create_endpoint('test-endpoint', 'some-endpoint-config') - response = local_sagemaker_client.describe_endpoint('test-endpoint') + local_sagemaker_client.create_endpoint("test-endpoint", "some-endpoint-config") + response = local_sagemaker_client.describe_endpoint("test-endpoint") - assert response['EndpointName'] == 'test-endpoint' + assert response["EndpointName"] == "test-endpoint" -@patch('sagemaker.local.image._SageMakerContainer.serve') -@patch('sagemaker.local.local_session.LocalSession') -@patch('urllib3.PoolManager.request') -@patch('sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config') -@patch('sagemaker.local.local_session.LocalSagemakerClient.describe_model') +@patch("sagemaker.local.image._SageMakerContainer.serve") +@patch("sagemaker.local.local_session.LocalSession") +@patch("urllib3.PoolManager.request") +@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") +@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") def test_create_endpoint(describe_model, describe_endpoint_config, request, *args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE describe_endpoint_config.return_value = { - 'EndpointConfigName': 'name', - 'EndpointConfigArn': 'local:arn-does-not-matter', - 'CreationTime': '00:00:00', - 'ProductionVariants': [ + "EndpointConfigName": "name", + "EndpointConfigArn": "local:arn-does-not-matter", + "CreationTime": "00:00:00", + "ProductionVariants": [ { - 'InitialVariantWeight': 1.0, - 'ModelName': 'my-model', - 'VariantName': 'AllTraffic', - 'InitialInstanceCount': 1, - 'InstanceType': 'local' - + "InitialVariantWeight": 1.0, + "ModelName": "my-model", + "VariantName": "AllTraffic", + "InitialInstanceCount": 1, + "InstanceType": "local", } - ] + ], } describe_model.return_value = { - 'ModelName': 'my-model', - 'CreationTime': '00:00;00', - 'ExecutionRoleArn': 'local:arn-does-not-matter', - 'ModelArn': 'local:arn-does-not-matter', - 'PrimaryContainer': { - 'Environment': { - 'SAGEMAKER_REGION': 'us-west-2' - }, - 'Image': '123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0', - 'ModelDataUrl': 's3://sagemaker-us-west-2/some/model.tar.gz' - } + "ModelName": "my-model", + "CreationTime": "00:00;00", + "ExecutionRoleArn": "local:arn-does-not-matter", + "ModelArn": "local:arn-does-not-matter", + "PrimaryContainer": { + "Environment": {"SAGEMAKER_REGION": "us-west-2"}, + "Image": "123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0", + "ModelDataUrl": "s3://sagemaker-us-west-2/some/model.tar.gz", + }, } - local_sagemaker_client.create_endpoint('my-endpoint', 'some-endpoint-config') + local_sagemaker_client.create_endpoint("my-endpoint", "some-endpoint-config") - assert 'my-endpoint' in sagemaker.local.local_session.LocalSagemakerClient._endpoints + assert "my-endpoint" in sagemaker.local.local_session.LocalSagemakerClient._endpoints -@patch('sagemaker.local.local_session.LocalSession') +@patch("sagemaker.local.local_session.LocalSession") def test_update_endpoint(LocalSession): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() - endpoint_name = 'my-endpoint' - endpoint_config = 'my-endpoint-config' - expected_error_message = 'Update endpoint name is not supported in local session.' 
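
test_update_endpoint here leans on `pytest.raises(..., match=...)`. The `match` argument is treated as a regular expression and applied with `re.search` against the string form of the raised exception, so a literal message containing metacharacters such as `.` is safest passed through `re.escape`. A small sketch reusing the same message:

```python
import re

import pytest


def test_match_is_a_regex():
    # match= is re.search'd against str(exc); escape literal text that
    # contains regex metacharacters such as '.' to be safe.
    expected = "Update endpoint name is not supported in local session."
    with pytest.raises(NotImplementedError, match=re.escape(expected)):
        raise NotImplementedError(expected)
```
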
+ endpoint_name = "my-endpoint" + endpoint_config = "my-endpoint-config" + expected_error_message = "Update endpoint name is not supported in local session." with pytest.raises(NotImplementedError, match=expected_error_message): local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config) -@patch('sagemaker.local.image._SageMakerContainer.serve') -@patch('urllib3.PoolManager.request') +@patch("sagemaker.local.image._SageMakerContainer.serve") +@patch("urllib3.PoolManager.request") def test_serve_endpoint_with_correct_accelerator(request, *args): - mock_session = Mock(name='sagemaker_session') - mock_session.return_value.sagemaker_client = Mock(name='sagemaker_client') + mock_session = Mock(name="sagemaker_session") + mock_session.return_value.sagemaker_client = Mock(name="sagemaker_client") mock_session.config = None request.return_value = OK_RESPONSE mock_session.sagemaker_client.describe_endpoint_config.return_value = { - 'ProductionVariants': [ + "ProductionVariants": [ { - 'ModelName': 'my-model', - 'InitialInstanceCount': 1, - 'InstanceType': 'local', - 'AcceleratorType': 'local_sagemaker_notebook' + "ModelName": "my-model", + "InitialInstanceCount": 1, + "InstanceType": "local", + "AcceleratorType": "local_sagemaker_notebook", } ] } mock_session.sagemaker_client.describe_model.return_value = { - 'PrimaryContainer': { - 'Environment': { - }, - 'Image': '123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0', - 'ModelDataUrl': 's3://sagemaker-us-west-2/some/model.tar.gz' + "PrimaryContainer": { + "Environment": {}, + "Image": "123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0", + "ModelDataUrl": "s3://sagemaker-us-west-2/some/model.tar.gz", } } - endpoint = sagemaker.local.local_session._LocalEndpoint('my-endpoint', 'some-endpoint-config', - local_session=mock_session) + endpoint = sagemaker.local.local_session._LocalEndpoint( + "my-endpoint", "some-endpoint-config", local_session=mock_session + ) endpoint.serve() - assert endpoint.primary_container['Environment']['SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'] == 'true' + assert ( + endpoint.primary_container["Environment"]["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] + == "true" + ) -@patch('sagemaker.local.image._SageMakerContainer.serve') -@patch('urllib3.PoolManager.request') +@patch("sagemaker.local.image._SageMakerContainer.serve") +@patch("urllib3.PoolManager.request") def test_serve_endpoint_with_incorrect_accelerator(request, *args): - mock_session = Mock(name='sagemaker_session') - mock_session.return_value.sagemaker_client = Mock(name='sagemaker_client') + mock_session = Mock(name="sagemaker_session") + mock_session.return_value.sagemaker_client = Mock(name="sagemaker_client") mock_session.config = None request.return_value = OK_RESPONSE mock_session.sagemaker_client.describe_endpoint_config.return_value = { - 'ProductionVariants': [ + "ProductionVariants": [ { - 'ModelName': 'my-model', - 'InitialInstanceCount': 1, - 'InstanceType': 'local', - 'AcceleratorType': 'local' + "ModelName": "my-model", + "InitialInstanceCount": 1, + "InstanceType": "local", + "AcceleratorType": "local", } ] } mock_session.sagemaker_client.describe_model.return_value = { - 'PrimaryContainer': { - 'Environment': { - }, - 'Image': '123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0', - 'ModelDataUrl': 's3://sagemaker-us-west-2/some/model.tar.gz' + "PrimaryContainer": { + "Environment": {}, + "Image": "123.dkr.ecr-us-west-2.amazonaws.com/sagemaker-container:1.0", + "ModelDataUrl": 
"s3://sagemaker-us-west-2/some/model.tar.gz", } } - endpoint = sagemaker.local.local_session._LocalEndpoint('my-endpoint', 'some-endpoint-config', - local_session=mock_session) + endpoint = sagemaker.local.local_session._LocalEndpoint( + "my-endpoint", "some-endpoint-config", local_session=mock_session + ) endpoint.serve() with pytest.raises(KeyError): - assert endpoint.primary_container['Environment']['SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'] == 'true' + assert ( + endpoint.primary_container["Environment"]["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] + == "true" + ) def test_file_input_all_defaults(): - prefix = 'pre' + prefix = "pre" actual = sagemaker.local.local_session.file_input(fileUri=prefix) - expected = \ - { - 'DataSource': { - 'FileDataSource': { - 'FileDataDistributionType': 'FullyReplicated', - 'FileUri': prefix - } - } + expected = { + "DataSource": { + "FileDataSource": {"FileDataDistributionType": "FullyReplicated", "FileUri": prefix} } + } assert actual.config == expected def test_file_input_content_type(): - prefix = 'pre' - actual = sagemaker.local.local_session.file_input(fileUri=prefix, content_type='text/csv') - expected = \ - { - 'DataSource': { - 'FileDataSource': { - 'FileDataDistributionType': 'FullyReplicated', - 'FileUri': prefix - } - }, - 'ContentType': 'text/csv' - } + prefix = "pre" + actual = sagemaker.local.local_session.file_input(fileUri=prefix, content_type="text/csv") + expected = { + "DataSource": { + "FileDataSource": {"FileDataDistributionType": "FullyReplicated", "FileUri": prefix} + }, + "ContentType": "text/csv", + } assert actual.config == expected def test_local_session_is_set_to_local_mode(): - boto_session = Mock(region_name='us-west-2') + boto_session = Mock(region_name="us-west-2") local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session) assert local_session.local_mode diff --git a/tests/unit/test_local_utils.py b/tests/unit/test_local_utils.py index 3d4ca02250..bcacd8f0ff 100644 --- a/tests/unit/test_local_utils.py +++ b/tests/unit/test_local_utils.py @@ -18,19 +18,19 @@ import sagemaker.local.utils -@patch('shutil.rmtree', Mock()) -@patch('sagemaker.local.utils.recursive_copy') +@patch("shutil.rmtree", Mock()) +@patch("sagemaker.local.utils.recursive_copy") def test_move_to_destination(recursive_copy): # local files will just be recursively copied - sagemaker.local.utils.move_to_destination('/tmp/data', 'file:///target/dir/', 'job', None) - recursive_copy.assert_called_with('/tmp/data', '/target/dir/') + sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None) + recursive_copy.assert_called_with("/tmp/data", "/target/dir/") # s3 destination will upload to S3 sms = Mock() - sagemaker.local.utils.move_to_destination('/tmp/data', 's3://bucket/path', 'job', sms) + sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms) sms.upload_data.assert_called() def test_move_to_destination_illegal_destination(): with pytest.raises(ValueError): - sagemaker.local.utils.move_to_destination('/tmp/data', 'ftp://ftp/in/2018', 'job', None) + sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 6811e6f3db..7838266516 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -22,85 +22,73 @@ import pytest from mock import MagicMock, Mock, patch -MODEL_DATA = 's3://bucket/model.tar.gz' -MODEL_IMAGE = 'mi' -ENTRY_POINT = 'blah.py' -INSTANCE_TYPE = 
'p2.xlarge' -ROLE = 'some-role' - -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_NAME = 'dummy_script.py' +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "mi" +ENTRY_POINT = "blah.py" +INSTANCE_TYPE = "p2.xlarge" +ROLE = "some-role" + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_NAME = "dummy_script.py" SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME) -TIMESTAMP = '2017-10-10-14-14-15' -BUCKET_NAME = 'mybucket' +TIMESTAMP = "2017-10-10-14-14-15" +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -IMAGE_NAME = 'fakeimage' -REGION = 'us-west-2' -MODEL_NAME = '{}-{}'.format(MODEL_IMAGE, TIMESTAMP) +INSTANCE_TYPE = "c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +IMAGE_NAME = "fakeimage" +REGION = "us-west-2" +MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) DESCRIBE_MODEL_PACKAGE_RESPONSE = { - 'InferenceSpecification': { - 'SupportedResponseMIMETypes': [ - 'text' - ], - 'SupportedContentTypes': [ - 'text/csv' - ], - 'SupportedTransformInstanceTypes': [ - 'ml.m4.xlarge', - 'ml.m4.2xlarge' - ], - 'Containers': [ + "InferenceSpecification": { + "SupportedResponseMIMETypes": ["text"], + "SupportedContentTypes": ["text/csv"], + "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + "Containers": [ { - 'Image': '1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest', - 'ImageDigest': 'sha256:1234556789', - 'ModelDataUrl': 's3://bucket/output/model.tar.gz' + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", + "ImageDigest": "sha256:1234556789", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", } ], - 'SupportedRealtimeInferenceInstanceTypes': [ - 'ml.m4.xlarge', - 'ml.m4.2xlarge', - - ] - }, - 'ModelPackageDescription': 'Model Package created from training with ' - 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees', - 'CreationTime': 1542752036.687, - 'ModelPackageArn': 'arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees', - 'ModelPackageStatusDetails': { - 'ValidationStatuses': [], - 'ImageScanStatuses': [] + "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], }, - 'SourceAlgorithmSpecification': { - 'SourceAlgorithms': [ + "ModelPackageDescription": "Model Package created from training with " + "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + "CreationTime": 1542752036.687, + "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", + "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ { - 'ModelDataUrl': 's3://bucket/output/model.tar.gz', - 'AlgorithmName': 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees' + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", } ] }, - - 'ModelPackageStatus': 'Completed', - 'ModelPackageName': 'mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502', - 'CertifyForMarketplace': False + "ModelPackageStatus": "Completed", + "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", + "CertifyForMarketplace": False, } DESCRIBE_COMPILATION_JOB_RESPONSE = { - 'CompilationJobStatus': "Completed", - 'ModelArtifacts': { - 'S3ModelArtifacts': 's3://output-path/model.tar.gz' - } + "CompilationJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": 
"s3://output-path/model.tar.gz"}, } class DummyFrameworkModel(FrameworkModel): - def __init__(self, sagemaker_session, **kwargs): - super(DummyFrameworkModel, self).__init__(MODEL_DATA, MODEL_IMAGE, ROLE, ENTRY_POINT, - sagemaker_session=sagemaker_session, **kwargs) + super(DummyFrameworkModel, self).__init__( + MODEL_DATA, + MODEL_IMAGE, + ROLE, + ENTRY_POINT, + sagemaker_session=sagemaker_session, + **kwargs + ) def create_predictor(self, endpoint_name): return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) @@ -108,165 +96,213 @@ def create_predictor(self, endpoint_name): @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms -@patch('shutil.rmtree', MagicMock()) -@patch('tarfile.open', MagicMock()) -@patch('os.listdir', MagicMock(return_value=['blah.py'])) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", return_value=TIMESTAMP) def test_prepare_container_def(time, sagemaker_session): model = DummyFrameworkModel(sagemaker_session) assert model.prepare_container_def(INSTANCE_TYPE) == { - 'Environment': {'SAGEMAKER_PROGRAM': ENTRY_POINT, - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false'}, - 'Image': MODEL_IMAGE, - 'ModelDataUrl': MODEL_DATA} - - -@patch('shutil.rmtree', MagicMock()) -@patch('tarfile.open', MagicMock()) -@patch('os.path.exists', MagicMock(return_value=True)) -@patch('os.path.isdir', MagicMock(return_value=True)) -@patch('os.listdir', MagicMock(return_value=['blah.py'])) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.path.exists", MagicMock(return_value=True)) +@patch("os.path.isdir", MagicMock(return_value=True)) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_create_no_defaults(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir='sd', env={"a": "a"}, name="name", - enable_cloudwatch_metrics=True, container_log_level=55, - code_location="s3://cb/cp") + model = DummyFrameworkModel( + sagemaker_session, + source_dir="sd", + env={"a": "a"}, + name="name", + enable_cloudwatch_metrics=True, + container_log_level=55, + code_location="s3://cb/cp", + ) assert model.prepare_container_def(INSTANCE_TYPE) == { - 'Environment': {'SAGEMAKER_PROGRAM': ENTRY_POINT, - 
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://cb/cp/name/sourcedir.tar.gz', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '55', - 'SAGEMAKER_REGION': REGION, - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'true', - 'a': 'a'}, - 'Image': MODEL_IMAGE, - 'ModelDataUrl': MODEL_DATA} - - -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://cb/cp/name/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "55", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "true", + "a": "a", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_deploy(sagemaker_session, tmpdir): model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) sagemaker_session.endpoint_from_production_variants.assert_called_with( MODEL_NAME, - [{'InitialVariantWeight': 1, - 'ModelName': MODEL_NAME, - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + [ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], None, None, - True) + True, + ) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_deploy_endpoint_name(sagemaker_session, tmpdir): model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.deploy(endpoint_name='blah', instance_type=INSTANCE_TYPE, initial_instance_count=55) + model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55) sagemaker_session.endpoint_from_production_variants.assert_called_with( - 'blah', - [{'InitialVariantWeight': 1, - 'ModelName': MODEL_NAME, - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 55, - 'VariantName': 'AllTraffic'}], + "blah", + [ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 55, + "VariantName": "AllTraffic", + } + ], None, None, - True) + True, + ) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_deploy_tags(sagemaker_session, tmpdir): model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - tags = [{'ModelName': 'TestModel'}] + tags = [{"ModelName": "TestModel"}] model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) sagemaker_session.endpoint_from_production_variants.assert_called_with( MODEL_NAME, - [{'InitialVariantWeight': 1, - 'ModelName': MODEL_NAME, - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + [ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], tags, None, - True) + True, + ) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) 
+@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_deploy_accelerator_type(tfo, time, sagemaker_session): model = DummyFrameworkModel(sagemaker_session) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE) + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE + ) sagemaker_session.endpoint_from_production_variants.assert_called_with( MODEL_NAME, - [{'InitialVariantWeight': 1, - 'ModelName': MODEL_NAME, - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic', - 'AcceleratorType': ACCELERATOR_TYPE}], + [ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + "AcceleratorType": ACCELERATOR_TYPE, + } + ], None, None, - True) + True, + ) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_deploy_kms_key(tfo, time, sagemaker_session): - key = 'some-key-arn' + key = "some-key-arn" model = DummyFrameworkModel(sagemaker_session) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key) sagemaker_session.endpoint_from_production_variants.assert_called_with( MODEL_NAME, - [{'InitialVariantWeight': 1, - 'ModelName': MODEL_NAME, - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + [ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], None, key, - True) + True, + ) -@patch('sagemaker.session.Session') -@patch('sagemaker.local.LocalSession') -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) +@patch("sagemaker.session.Session") +@patch("sagemaker.local.LocalSession") +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_deploy_creates_correct_session(local_session, session, tmpdir): # We expect a LocalSession when deploying to instance_type = 'local' model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) - model.deploy(endpoint_name='blah', instance_type='local', initial_instance_count=1) + model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) assert model.sagemaker_session == local_session.return_value # We expect a real Session when deploying to instance_type != local/local_gpu model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) - model.deploy(endpoint_name='remote_endpoint', instance_type='ml.m4.4xlarge', initial_instance_count=2) + model.deploy( + endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2 + ) assert model.sagemaker_session == session.return_value -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_deploy_update_endpoint(sagemaker_session, tmpdir): model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) - endpoint_name = 'endpoint-name' - model.deploy(instance_type=INSTANCE_TYPE, - initial_instance_count=1, - endpoint_name=endpoint_name, - update_endpoint=True, - accelerator_type=ACCELERATOR_TYPE) + endpoint_name = "endpoint-name" + 
model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=1, + endpoint_name=endpoint_name, + update_endpoint=True, + accelerator_type=ACCELERATOR_TYPE, + ) sagemaker_session.create_endpoint_config.assert_called_with( name=model.name, model_name=model.name, @@ -281,7 +317,7 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir): model_name=model.name, initial_instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE + accelerator_type=ACCELERATOR_TYPE, ) sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name) sagemaker_session.create_endpoint.assert_not_called() @@ -292,89 +328,101 @@ def test_model_enable_network_isolation(sagemaker_session): assert model.enable_network_isolation() is False -@patch('sagemaker.model.Model._create_sagemaker_model', Mock()) +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) def test_model_create_transformer(sagemaker_session): sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE) + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) model = DummyFrameworkModel(sagemaker_session=sagemaker_session) - model.name = 'auto-generated-model' - transformer = model.transformer(instance_count=1, instance_type='ml.m4.xlarge', - env={'test': True}) + model.name = "auto-generated-model" + transformer = model.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == 'auto-generated-model' - assert transformer.instance_type == 'ml.m4.xlarge' - assert transformer.env == {'test': True} + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env == {"test": True} def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE) + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) - model_package = ModelPackage(role='role', model_package_arn='my-model-package', - sagemaker_session=sagemaker_session) + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) assert model_package.enable_network_isolation() is False def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response['InferenceSpecification']['Containers'].append( + model_package_response["InferenceSpecification"]["Containers"].append( { - 'Image': '1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest', - 'ModelDataUrl': 's3://bucket/output/model.tar.gz', - 'ProductId': 'some-product-id' + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", } ) sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response) + return_value=model_package_response + ) - model_package = ModelPackage(role='role', model_package_arn='my-model-package', - sagemaker_session=sagemaker_session) + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) assert model_package.enable_network_isolation() is True -@patch('sagemaker.model.ModelPackage._create_sagemaker_model', 
Mock()) +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) def test_model_package_create_transformer(sagemaker_session): sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE) + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) - model_package = ModelPackage(role='role', model_package_arn='my-model-package', - sagemaker_session=sagemaker_session) - model_package.name = 'auto-generated-model' - transformer = model_package.transformer(instance_count=1, instance_type='ml.m4.xlarge', - env={'test': True}) + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == 'auto-generated-model' - assert transformer.instance_type == 'ml.m4.xlarge' - assert transformer.env == {'test': True} + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env == {"test": True} -@patch('sagemaker.model.ModelPackage._create_sagemaker_model', Mock()) +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) def test_model_package_create_transformer_with_product_id(sagemaker_session): model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response['InferenceSpecification']['Containers'].append( + model_package_response["InferenceSpecification"]["Containers"].append( { - 'Image': '1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest', - 'ModelDataUrl': 's3://bucket/output/model.tar.gz', - 'ProductId': 'some-product-id' + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", } ) sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response) + return_value=model_package_response + ) - model_package = ModelPackage(role='role', model_package_arn='my-model-package', - sagemaker_session=sagemaker_session) - model_package.name = 'auto-generated-model' - transformer = model_package.transformer(instance_count=1, instance_type='ml.m4.xlarge', - env={'test': True}) + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == 'auto-generated-model' - assert transformer.instance_type == 'ml.m4.xlarge' + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" assert transformer.env is None -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_model_delete_model(sagemaker_session, tmpdir): model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) @@ -385,37 +433,74 @@ def test_model_delete_model(sagemaker_session, tmpdir): 
def test_delete_non_deployed_model(sagemaker_session): model = DummyFrameworkModel(sagemaker_session) - with pytest.raises(ValueError, match='The SageMaker model must be created first before attempting to delete.'): + with pytest.raises( + ValueError, match="The SageMaker model must be created first before attempting to delete." + ): model.delete_model() def test_compile_model_for_edge_device(sagemaker_session, tmpdir): sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile(target_instance_family='deeplens', input_shape={'data': [1, 3, 1024, 1024]}, - output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model") + model.compile( + target_instance_family="deeplens", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) assert model._is_compiled_model is False def test_compile_model_for_cloud(sagemaker_session, tmpdir): sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]}, - output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model") + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) assert model._is_compiled_model is True def test_check_neo_region(sagemaker_session, tmpdir): sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - ec2_region_list = ['us-east-2', 'us-east-1', 'us-west-1', 'us-west-2', 'ap-east-1', 'ap-south-1', - 'ap-northeast-3', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1', - 'ca-central-1', 'cn-north-1', 'cn-northwest-1', 'eu-central-1', ' eu-west-1', 'eu-west-2', - 'eu-west-3', 'eu-north-1', 'sa-east-1', 'us-gov-east-1', 'us-gov-west-1'] - neo_support_region = ['us-west-2', 'eu-west-1', 'us-east-1', 'us-east-2', 'ap-northeast-1'] + ec2_region_list = [ + "us-east-2", + "us-east-1", + "us-west-1", + "us-west-2", + "ap-east-1", + "ap-south-1", + "ap-northeast-3", + "ap-northeast-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-northeast-1", + "ca-central-1", + "cn-north-1", + "cn-northwest-1", + "eu-central-1", + " eu-west-1", + "eu-west-2", + "eu-west-3", + "eu-north-1", + "sa-east-1", + "us-gov-east-1", + "us-gov-west-1", + ] + neo_support_region = ["us-west-2", "eu-west-1", "us-east-1", "us-east-2", "ap-northeast-1"] for region_name in ec2_region_list: if region_name in neo_support_region: assert model.check_neo_region(region_name) is True diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 0f45afe51f..3a3a82b494 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -26,56 +26,54 @@ from sagemaker.mxnet import MXNet from sagemaker.mxnet import MXNetPredictor, MXNetModel -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py') -MODEL_DATA = 
's3://mybucket/model' -TIMESTAMP = '2017-11-06-14:14:15.672' +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +MODEL_DATA = "s3://mybucket/model" +TIMESTAMP = "2017-11-06-14:14:15.672" TIME = 1507167947 -BUCKET_NAME = 'mybucket' +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -IMAGE_REPO_NAME = 'sagemaker-mxnet' -IMAGE_REPO_SERVING_NAME = 'sagemaker-mxnet-serving' -JOB_NAME = '{}-{}'.format(IMAGE_REPO_NAME, TIMESTAMP) -COMPILATION_JOB_NAME = '{}-{}'.format('compilation-sagemaker-mxnet', TIMESTAMP) -FRAMEWORK = 'mxnet' -FULL_IMAGE_URI = '520713654638.dkr.ecr.us-west-2.amazonaws.com/{}:{}-{}-{}' -ROLE = 'Dummy' -REGION = 'us-west-2' -GPU = 'ml.p2.xlarge' -CPU = 'ml.c4.xlarge' -CPU_C5 = 'ml.c5.xlarge' -LAUNCH_PS_DISTRIBUTIONS_DICT = {'parameter_server': {'enabled': True}} - -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} - -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} - -LIST_TAGS_RESULT = { - 'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] -} +INSTANCE_TYPE = "ml.c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +IMAGE_REPO_NAME = "sagemaker-mxnet" +IMAGE_REPO_SERVING_NAME = "sagemaker-mxnet-serving" +JOB_NAME = "{}-{}".format(IMAGE_REPO_NAME, TIMESTAMP) +COMPILATION_JOB_NAME = "{}-{}".format("compilation-sagemaker-mxnet", TIMESTAMP) +FRAMEWORK = "mxnet" +FULL_IMAGE_URI = "520713654638.dkr.ecr.us-west-2.amazonaws.com/{}:{}-{}-{}" +ROLE = "Dummy" +REGION = "us-west-2" +GPU = "ml.p2.xlarge" +CPU = "ml.c4.xlarge" +CPU_C5 = "ml.c5.xlarge" +LAUNCH_PS_DISTRIBUTIONS_DICT = {"parameter_server": {"enabled": True}} + +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} + +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) - describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} - describe_compilation = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/model_c5.tar.gz'}} + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + describe_compilation = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/model_c5.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) session.wait_for_compilation_job = Mock(return_value=describe_compilation) - session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session @@ -83,104 +81,102 @@ def sagemaker_session(): @pytest.fixture() def skip_if_mms_version(mxnet_version): if parse_version(MXNetModel._LOWEST_MMS_VERSION) <= 
parse_version(mxnet_version): - pytest.skip('Skipping because this version uses MMS') + pytest.skip("Skipping because this version uses MMS") @pytest.fixture() def skip_if_not_mms_version(mxnet_version): if parse_version(MXNetModel._LOWEST_MMS_VERSION) > parse_version(mxnet_version): - pytest.skip('Skipping because this version does not use MMS') + pytest.skip("Skipping because this version does not use MMS") -def _get_full_image_uri(version, repo=IMAGE_REPO_NAME, processor='cpu', py_version='py2'): +def _get_full_image_uri(version, repo=IMAGE_REPO_NAME, processor="cpu", py_version="py2"): return FULL_IMAGE_URI.format(repo, version, processor, py_version) -def _get_full_image_uri_with_ei(version, repo=IMAGE_REPO_NAME, processor='cpu', py_version='py2'): - return FULL_IMAGE_URI.format('{}-eia'.format(repo), version, processor, py_version) +def _get_full_image_uri_with_ei(version, repo=IMAGE_REPO_NAME, processor="cpu", py_version="py2"): + return FULL_IMAGE_URI.format("{}-eia".format(repo), version, processor, py_version) def _create_train_job(version): return { - 'image': _get_full_image_uri(version), - 'input_mode': 'File', - 'input_config': [ + "image": _get_full_image_uri(version), + "input_mode": "File", + "input_config": [ { - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix' + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", } - } + }, } ], - 'role': ROLE, - 'job_name': JOB_NAME, - 'output_config': { - 'S3OutputPath': 's3://{}/'.format(BUCKET_NAME), + "role": ROLE, + "job_name": JOB_NAME, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": "ml.c4.4xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, }, - 'resource_config': { - 'InstanceType': 'ml.c4.4xlarge', - 'InstanceCount': 1, - 'VolumeSizeInGB': 30, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + ), + "sagemaker_region": '"us-west-2"', }, - 'hyperparameters': { - 'sagemaker_program': json.dumps('dummy_script.py'), - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': str(logging.INFO), - 'sagemaker_job_name': json.dumps(JOB_NAME), - 'sagemaker_submit_directory': - json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)), - 'sagemaker_region': '"us-west-2"' - }, - 'stop_condition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'tags': None, - 'vpc_config': None, - 'metric_definitions': None + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "tags": None, + "vpc_config": None, + "metric_definitions": None, } def _create_compilation_job(input_shape, output_location): return { - 'input_model_config': { - 'DataInputConfig': input_shape, - 'Framework': FRAMEWORK.upper(), - 'S3Uri': 's3://m/m.tar.gz' - }, - 'job_name': COMPILATION_JOB_NAME, - 'output_model_config': { - 'S3OutputLocation': output_location, - 'TargetDevice': 'ml_c4' + "input_model_config": { + "DataInputConfig": input_shape, + "Framework": FRAMEWORK.upper(), + "S3Uri": "s3://m/m.tar.gz", }, - 'role': ROLE, - 'stop_condition': { - 'MaxRuntimeInSeconds': 300 - }, - 'tags': None + 
"job_name": COMPILATION_JOB_NAME, + "output_model_config": {"S3OutputLocation": output_location, "TargetDevice": "ml_c4"}, + "role": ROLE, + "stop_condition": {"MaxRuntimeInSeconds": 300}, + "tags": None, } def _neo_inference_image(mxnet_version): return "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-neo-{}:{}-cpu-py3".format( - FRAMEWORK.lower(), - mxnet_version + FRAMEWORK.lower(), mxnet_version ) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_create_model(sagemaker_session, mxnet_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version, container_log_level=container_log_level, - base_job_name='job', source_dir=source_dir) - - job_name = 'new_name' - mx.fit(inputs='s3://mybucket/train', job_name=job_name) + source_dir = "s3://mybucket/source" + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=mxnet_version, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "new_name" + mx.fit(inputs="s3://mybucket/train", job_name=job_name) model = mx.create_model() assert model.sagemaker_session == sagemaker_session @@ -197,20 +193,28 @@ def test_create_model(sagemaker_session, mxnet_version): def test_create_model_with_optional_params(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - enable_cloudwatch_metrics = 'true' - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - container_log_level=container_log_level, base_job_name='job', source_dir=source_dir, - enable_cloudwatch_metrics=enable_cloudwatch_metrics) + source_dir = "s3://mybucket/source" + enable_cloudwatch_metrics = "true" + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + ) - mx.fit(inputs='s3://mybucket/train', job_name='new_name') + mx.fit(inputs="s3://mybucket/train", job_name="new_name") - new_role = 'role' + new_role = "role" model_server_workers = 2 - vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - model = mx.create_model(role=new_role, model_server_workers=model_server_workers, - vpc_config_override=vpc_config) + vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + model = mx.create_model( + role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config + ) assert model.role == new_role assert model.model_server_workers == model_server_workers @@ -219,15 +223,22 @@ def test_create_model_with_optional_params(sagemaker_session): def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - custom_image = 'mxnet:2.0' - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - 
image_name=custom_image, container_log_level=container_log_level, - base_job_name='job', source_dir=source_dir) - - job_name = 'new_name' - mx.fit(inputs='s3://mybucket/train', job_name='new_name') + source_dir = "s3://mybucket/source" + custom_image = "mxnet:2.0" + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + image_name=custom_image, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "new_name" + mx.fit(inputs="s3://mybucket/train", job_name="new_name") model = mx.create_model() assert model.sagemaker_session == sagemaker_session @@ -239,107 +250,137 @@ def test_create_model_with_custom_image(sagemaker_session): assert model.source_dir == source_dir -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) def test_mxnet(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version) + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=mxnet_version, + ) - inputs = 's3://mybucket/train' + inputs = "s3://mybucket/train" mx.fit(inputs=inputs) sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job'] + assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ['resource'] + assert boto_call_names == ["resource"] expected_train_args = _create_train_job(mxnet_version) - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args model = mx.create_model() - expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2' + expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2" environment = { - 'Environment': { - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP), - 'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20' + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-mxnet-{}/source/sourcedir.tar.gz".format( + TIMESTAMP + ), + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - 'Image': expected_image_base.format(mxnet_version), 'ModelDataUrl': 's3://m/m.tar.gz' + "Image": expected_image_base.format(mxnet_version), + "ModelDataUrl": "s3://m/m.tar.gz", } assert environment == model.prepare_container_def(GPU) - assert 'cpu' in model.prepare_container_def(CPU)['Image'] + assert "cpu" in model.prepare_container_def(CPU)["Image"] predictor = mx.deploy(1, GPU) assert 
isinstance(predictor, MXNetPredictor) -@patch('sagemaker.utils.repack_model') -@patch('time.strftime', return_value=TIMESTAMP) -def test_mxnet_mms_version(strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version) +@patch("sagemaker.utils.repack_model") +@patch("time.strftime", return_value=TIMESTAMP) +def test_mxnet_mms_version( + strftime, repack_model, sagemaker_session, mxnet_version, skip_if_not_mms_version +): + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=mxnet_version, + ) - inputs = 's3://mybucket/train' + inputs = "s3://mybucket/train" mx.fit(inputs=inputs) sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job'] + assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ['resource'] + assert boto_call_names == ["resource"] expected_train_args = _create_train_job(mxnet_version) - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args model = mx.create_model() - expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, 'gpu') + expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, "gpu") environment = { - 'Environment': { - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz', - 'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_REGION': 'us-west-2', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20' + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz", + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - 'Image': expected_image_base.format(mxnet_version), - 'ModelDataUrl': 's3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz' + "Image": expected_image_base.format(mxnet_version), + "ModelDataUrl": "s3://mybucket/sagemaker-mxnet-2017-11-06-14:14:15.672/model.tar.gz", } assert environment == model.prepare_container_def(GPU) - assert 'cpu' in model.prepare_container_def(CPU)['Image'] + assert "cpu" in model.prepare_container_def(CPU)["Image"] predictor = mx.deploy(1, GPU) assert isinstance(predictor, MXNetPredictor) -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) def test_mxnet_neo(strftime, sagemaker_session, mxnet_version, skip_if_mms_version): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=mxnet_version) + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + 
train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=mxnet_version, + ) - inputs = 's3://mybucket/train' + inputs = "s3://mybucket/train" mx.fit(inputs=inputs) - input_shape = {'data': [100, 1, 28, 28]} - output_location = 's3://neo-sdk-test' + input_shape = {"data": [100, 1, 28, 28]} + output_location = "s3://neo-sdk-test" - compiled_model = mx.compile_model(target_instance_family='ml_c4', input_shape=input_shape, - output_path=output_location) + compiled_model = mx.compile_model( + target_instance_family="ml_c4", input_shape=input_shape, output_path=output_location + ) sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job', 'sagemaker_client.describe_training_job', - 'compile_model', 'wait_for_compilation_job'] + assert sagemaker_call_names == [ + "train", + "logs_for_job", + "sagemaker_client.describe_training_job", + "compile_model", + "wait_for_compilation_job", + ] expected_compile_model_args = _create_compilation_job(json.dumps(input_shape), output_location) actual_compile_model_args = sagemaker_session.method_calls[3][2] @@ -352,253 +393,309 @@ def test_mxnet_neo(strftime, sagemaker_session, mxnet_version, skip_if_mms_versi with pytest.raises(Exception) as wrong_target: mx.deploy(1, CPU_C5, use_compiled_model=True) - assert str(wrong_target.value).startswith('No compiled model for') + assert str(wrong_target.value).startswith("No compiled model for") # deploy without sagemaker Neo should continue to work mx.deploy(1, CPU) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_model(sagemaker_session): - model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = MXNetModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) predictor = model.deploy(1, GPU) assert isinstance(predictor, MXNetPredictor) -@patch('sagemaker.utils.repack_model') +@patch("sagemaker.utils.repack_model") def test_model_mms_version(repack_model, sagemaker_session): - model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - framework_version=MXNetModel._LOWEST_MMS_VERSION, - sagemaker_session=sagemaker_session, name='test-mxnet-model') + model = MXNetModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=MXNetModel._LOWEST_MMS_VERSION, + sagemaker_session=sagemaker_session, + name="test-mxnet-model", + ) predictor = model.deploy(1, GPU) - repack_model.assert_called_once_with(inference_script=SCRIPT_PATH, - source_directory=None, - dependencies=[], - model_uri=MODEL_DATA, - repacked_model_uri='s3://mybucket/test-mxnet-model/model.tar.gz', - sagemaker_session=sagemaker_session) + repack_model.assert_called_once_with( + inference_script=SCRIPT_PATH, + source_directory=None, + dependencies=[], + model_uri=MODEL_DATA, + repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz", + sagemaker_session=sagemaker_session, + ) assert model.model_data == MODEL_DATA - assert model.repacked_model_data == 's3://mybucket/test-mxnet-model/model.tar.gz' - assert model.uploaded_code == UploadedCode(s3_prefix='s3://mybucket/test-mxnet-model/model.tar.gz', - script_name=os.path.basename(SCRIPT_PATH)) + assert model.repacked_model_data == "s3://mybucket/test-mxnet-model/model.tar.gz" + assert model.uploaded_code == UploadedCode( + s3_prefix="s3://mybucket/test-mxnet-model/model.tar.gz", + 
script_name=os.path.basename(SCRIPT_PATH), + ) assert isinstance(predictor, MXNetPredictor) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_model_image_accelerator(sagemaker_session): - model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = MXNetModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert container_def['Image'] == _get_full_image_uri_with_ei(defaults.MXNET_VERSION) + assert container_def["Image"] == _get_full_image_uri_with_ei(defaults.MXNET_VERSION) -@patch('sagemaker.utils.repack_model', MagicMock()) +@patch("sagemaker.utils.repack_model", MagicMock()) def test_model_image_accelerator_mms_version(sagemaker_session): - model = MXNetModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - framework_version=MXNetModel._LOWEST_MMS_VERSION, - sagemaker_session=sagemaker_session) + model = MXNetModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=MXNetModel._LOWEST_MMS_VERSION, + sagemaker_session=sagemaker_session, + ) container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert container_def['Image'] == _get_full_image_uri_with_ei(MXNetModel._LOWEST_MMS_VERSION, - IMAGE_REPO_SERVING_NAME) + assert container_def["Image"] == _get_full_image_uri_with_ei( + MXNetModel._LOWEST_MMS_VERSION, IMAGE_REPO_SERVING_NAME + ) def test_train_image_default(sagemaker_session): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) assert _get_full_image_uri(defaults.MXNET_VERSION) in mx.train_image() def test_attach(sagemaker_session, mxnet_version): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:{}-cpu-py2'.format(mxnet_version) + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:{}-cpu-py2".format( + mxnet_version + ) returned_job_description = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"' + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + "RoleArn": 
"arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo' - }, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'} + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, } - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) - estimator = MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py2' + estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py2" assert estimator.framework_version == mxnet_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' - assert estimator.tags == LIST_TAGS_RESULT['Tags'] + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + assert estimator.tags == LIST_TAGS_RESULT["Tags"] def test_attach_old_container(sagemaker_session): - returned_job_description = {'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0'}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 
'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) - - estimator = MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py2' - assert estimator.framework_version == '0.12' - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + returned_job_description = { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0", + }, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py2" + assert estimator.framework_version == "0.12" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" def test_attach_wrong_framework(sagemaker_session): rjd = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0.4'}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'training_steps': '100', - 
'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0.4", + }, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) with pytest.raises(ValueError) as error: - MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert "didn't use image for requested framework" in str(error) def test_attach_custom_image(sagemaker_session): - training_image = 'ubuntu:latest' - returned_job_description = {'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) - - estimator = MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + training_image = "ubuntu:latest" + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", 
"TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = MXNet.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.image_name == training_image assert estimator.train_image() == training_image def test_estimator_script_mode_launch_parameter_server(sagemaker_session): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.3.0') - assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == 'true' + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, + framework_version="1.3.0", + ) + assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == "true" def test_estimator_script_mode_dont_launch_parameter_server(sagemaker_session): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - distributions={'parameter_server': {'enabled': False}}, framework_version='1.3.0') - assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == 'false' + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + distributions={"parameter_server": {"enabled": False}}, + framework_version="1.3.0", + ) + assert mx.hyperparameters().get(MXNet.LAUNCH_PS_ENV_NAME) == "false" def test_estimator_wrong_version_launch_parameter_server(sagemaker_session): with pytest.raises(ValueError) as e: - MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, framework_version='1.2.1') - assert 'The distributions option is valid for only versions 1.3 and higher' in str(e) - - -@patch('sagemaker.mxnet.estimator.empty_framework_version_warning') + MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + distributions=LAUNCH_PS_DISTRIBUTIONS_DICT, + framework_version="1.2.1", + ) + assert "The distributions option is valid for only versions 1.3 and higher" in str(e) + + 
+@patch("sagemaker.mxnet.estimator.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): - mx = MXNet(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=None) + mx = MXNet( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=None, + ) assert mx.framework_version == defaults.MXNET_VERSION warning.assert_called_with(defaults.MXNET_VERSION, mx.LATEST_VERSION) diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index c72e8cea70..963acc73d2 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -18,43 +18,43 @@ from sagemaker.amazon.ntm import NTM, NTMPredictor from sagemaker.amazon.amazon_estimator import registry, RecordSet -ROLE = 'myrole' +ROLE = "myrole" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" NUM_TOPICS = 5 -COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, - 'train_instance_type': TRAIN_INSTANCE_TYPE} -ALL_REQ_ARGS = dict({'num_topics': NUM_TOPICS}, **COMMON_TRAIN_ARGS) +COMMON_TRAIN_ARGS = { + "role": ROLE, + "train_instance_count": TRAIN_INSTANCE_COUNT, + "train_instance_type": TRAIN_INSTANCE_TYPE, +} +ALL_REQ_ARGS = dict({"num_topics": NUM_TOPICS}, **COMMON_TRAIN_ARGS) REGION = "us-west-2" BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': "s3://bucket/model.tar.gz" - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -62,7 +62,13 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - ntm = NTM(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_TOPICS, sagemaker_session=sagemaker_session) + ntm = NTM( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_TOPICS, + sagemaker_session=sagemaker_session, + ) assert ntm.role == ROLE assert ntm.train_instance_count == TRAIN_INSTANCE_COUNT assert ntm.train_instance_type == TRAIN_INSTANCE_TYPE 
@@ -72,41 +78,50 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert ntm.role == COMMON_TRAIN_ARGS['role'] + assert ntm.role == COMMON_TRAIN_ARGS["role"] assert ntm.train_instance_count == TRAIN_INSTANCE_COUNT - assert ntm.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert ntm.num_topics == ALL_REQ_ARGS['num_topics'] + assert ntm.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert ntm.num_topics == ALL_REQ_ARGS["num_topics"] def test_all_hyperparameters(sagemaker_session): - ntm = NTM(sagemaker_session=sagemaker_session, - encoder_layers=[1, 2, 3], epochs=3, encoder_layers_activation='tanh', optimizer='sgd', - tolerance=0.05, num_patience_epochs=2, batch_norm=False, rescale_gradient=0.5, clip_gradient=0.5, - weight_decay=0.5, learning_rate=0.5, **ALL_REQ_ARGS) + ntm = NTM( + sagemaker_session=sagemaker_session, + encoder_layers=[1, 2, 3], + epochs=3, + encoder_layers_activation="tanh", + optimizer="sgd", + tolerance=0.05, + num_patience_epochs=2, + batch_norm=False, + rescale_gradient=0.5, + clip_gradient=0.5, + weight_decay=0.5, + learning_rate=0.5, + **ALL_REQ_ARGS + ) assert ntm.hyperparameters() == dict( - num_topics=str(ALL_REQ_ARGS['num_topics']), - encoder_layers='[1, 2, 3]', - epochs='3', - encoder_layers_activation='tanh', - optimizer='sgd', - tolerance='0.05', - num_patience_epochs='2', - batch_norm='False', - rescale_gradient='0.5', - clip_gradient='0.5', - weight_decay='0.5', - learning_rate='0.5' + num_topics=str(ALL_REQ_ARGS["num_topics"]), + encoder_layers="[1, 2, 3]", + epochs="3", + encoder_layers_activation="tanh", + optimizer="sgd", + tolerance="0.05", + num_patience_epochs="2", + batch_norm="False", + rescale_gradient="0.5", + clip_gradient="0.5", + weight_decay="0.5", + learning_rate="0.5", ) def test_image(sagemaker_session): ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert ntm.train_image() == registry(REGION, "ntm") + '/ntm:1' + assert ntm.train_image() == registry(REGION, "ntm") + "/ntm:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_topics', 'string') -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("num_topics", "string")]) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -114,10 +129,9 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame NTM(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_topics', 0), - ('num_topics', 10000) -]) +@pytest.mark.parametrize( + "required_hyper_parameters, value", [("num_topics", 0), ("num_topics", 10000)] +) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -125,9 +139,7 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param NTM(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('iterable_hyper_parameters, value', [ - ('encoder_layers', 0) -]) +@pytest.mark.parametrize("iterable_hyper_parameters, value", [("encoder_layers", 0)]) def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value): with pytest.raises(TypeError): test_params = ALL_REQ_ARGS.copy() @@ -135,17 +147,20 @@ def 
test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parame NTM(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 'string'), - ('encoder_layers_activation', 0), - ('optimizer', 0), - ('tolerance', 'string'), - ('num_patience_epochs', 'string'), - ('rescale_gradient', 'string'), - ('clip_gradient', 'string'), - ('weight_decay', 'string'), - ('learning_rate', 'string') -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("epochs", "string"), + ("encoder_layers_activation", 0), + ("optimizer", 0), + ("tolerance", "string"), + ("num_patience_epochs", "string"), + ("rescale_gradient", "string"), + ("clip_gradient", "string"), + ("weight_decay", "string"), + ("learning_rate", "string"), + ], +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -153,23 +168,26 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame NTM(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 0), - ('epochs', 1000), - ('encoder_layers_activation', 'string'), - ('optimizer', 'string'), - ('tolerance', 0), - ('tolerance', 0.5), - ('num_patience_epochs', 0), - ('num_patience_epochs', 100), - ('rescale_gradient', 0), - ('rescale_gradient', 10), - ('clip_gradient', 0), - ('weight_decay', -1), - ('weight_decay', 2), - ('learning_rate', 0), - ('learning_rate', 2) -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("epochs", 0), + ("epochs", 1000), + ("encoder_layers_activation", "string"), + ("optimizer", "string"), + ("tolerance", 0), + ("tolerance", 0.5), + ("num_patience_epochs", 0), + ("num_patience_epochs", 100), + ("rescale_gradient", 0), + ("rescale_gradient", 10), + ("clip_gradient", 0), + ("weight_decay", -1), + ("weight_decay", 2), + ("learning_rate", 0), + ("learning_rate", 2), + ], +) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -186,7 +204,12 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param def test_call_fit(base_fit, sagemaker_session): ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) ntm.fit(data, MINI_BATCH_SIZE) @@ -199,16 +222,24 @@ def test_call_fit(base_fit, sagemaker_session): def test_call_fit_none_mini_batch_size(sagemaker_session): ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) ntm.fit(data) def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + 
num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises((TypeError, ValueError)): ntm._prepare_for_training(data, "some") @@ -217,8 +248,12 @@ def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_session): ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises(ValueError): ntm._prepare_for_training(data, 0) @@ -226,24 +261,38 @@ def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_sessio def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session): ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises(ValueError): ntm._prepare_for_training(data, 10001) def test_model_image(sagemaker_session): ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) ntm.fit(data, MINI_BATCH_SIZE) model = ntm.create_model() - assert model.image == registry(REGION, "ntm") + '/ntm:1' + assert model.image == registry(REGION, "ntm") + "/ntm:1" def test_predictor_type(sagemaker_session): ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) ntm.fit(data, MINI_BATCH_SIZE) model = ntm.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 0d5936ecef..9338098761 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -19,51 +19,50 @@ from sagemaker.predictor import RealTimePredictor from sagemaker.amazon.amazon_estimator import registry, RecordSet -ROLE = 'myrole' +ROLE = "myrole" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" EPOCHS = 5 ENC0_MAX_SEQ_LEN = 100 ENC0_VOCAB_SIZE = 500 MINI_BATCH_SIZE = 32 -COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, - 'train_instance_type': TRAIN_INSTANCE_TYPE} -ALL_REQ_ARGS = dict({ - 'epochs': EPOCHS, - 'enc0_max_seq_len': ENC0_MAX_SEQ_LEN, - 'enc0_vocab_size': ENC0_VOCAB_SIZE, -}, **COMMON_TRAIN_ARGS) +COMMON_TRAIN_ARGS = { + "role": ROLE, + "train_instance_count": TRAIN_INSTANCE_COUNT, + "train_instance_type": TRAIN_INSTANCE_TYPE, +} +ALL_REQ_ARGS = dict( + {"epochs": EPOCHS, "enc0_max_seq_len": ENC0_MAX_SEQ_LEN, "enc0_vocab_size": ENC0_VOCAB_SIZE}, + **COMMON_TRAIN_ARGS +) REGION = "us-west-2" BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': 
"s3://bucket/model.tar.gz" - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -72,9 +71,14 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): object2vec = Object2Vec( - ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, - EPOCHS, ENC0_MAX_SEQ_LEN, ENC0_VOCAB_SIZE, - sagemaker_session=sagemaker_session) + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + EPOCHS, + ENC0_MAX_SEQ_LEN, + ENC0_VOCAB_SIZE, + sagemaker_session=sagemaker_session, + ) assert object2vec.role == ROLE assert object2vec.train_instance_count == TRAIN_INSTANCE_COUNT assert object2vec.train_instance_type == TRAIN_INSTANCE_TYPE @@ -86,12 +90,12 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert object2vec.role == COMMON_TRAIN_ARGS['role'] + assert object2vec.role == COMMON_TRAIN_ARGS["role"] assert object2vec.train_instance_count == TRAIN_INSTANCE_COUNT - assert object2vec.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert object2vec.epochs == ALL_REQ_ARGS['epochs'] - assert object2vec.enc0_max_seq_len == ALL_REQ_ARGS['enc0_max_seq_len'] - assert object2vec.enc0_vocab_size == ALL_REQ_ARGS['enc0_vocab_size'] + assert object2vec.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert object2vec.epochs == ALL_REQ_ARGS["epochs"] + assert object2vec.enc0_max_seq_len == ALL_REQ_ARGS["enc0_max_seq_len"] + assert object2vec.enc0_vocab_size == ALL_REQ_ARGS["enc0_vocab_size"] def test_all_hyperparameters(sagemaker_session): @@ -107,16 +111,16 @@ def test_all_hyperparameters(sagemaker_session): num_classes=5, mlp_layers=3, mlp_dim=1024, - mlp_activation='tanh', - output_layer='softmax', - optimizer='adam', + mlp_activation="tanh", + output_layer="softmax", + optimizer="adam", learning_rate=0.0001, negative_sampling_rate=1, - comparator_list='hadamard, abs_diff', + comparator_list="hadamard, abs_diff", tied_token_embedding_weight=True, - token_embedding_storage_type='row_sparse', - enc0_network='bilstm', - enc1_network='hcnn', + 
token_embedding_storage_type="row_sparse", + enc0_network="bilstm", + enc1_network="hcnn", enc0_cnn_filter_width=3, enc1_cnn_filter_width=3, enc1_max_seq_len=300, @@ -127,21 +131,20 @@ def test_all_hyperparameters(sagemaker_session): enc1_layers=3, enc0_freeze_pretrained_embedding=True, enc1_freeze_pretrained_embedding=False, - **ALL_REQ_ARGS) + **ALL_REQ_ARGS + ) hp = object2vec.hyperparameters() - assert hp['epochs'] == str(EPOCHS) - assert hp['mlp_activation'] == 'tanh' + assert hp["epochs"] == str(EPOCHS) + assert hp["mlp_activation"] == "tanh" def test_image(sagemaker_session): object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert object2vec.train_image() == registry(REGION, "object2vec") + '/object2vec:1' + assert object2vec.train_image() == registry(REGION, "object2vec") + "/object2vec:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('epochs', 'string') -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("epochs", "string")]) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -149,10 +152,9 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame Object2Vec(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('enc0_vocab_size', 0), - ('enc0_vocab_size', 1000000000) -]) +@pytest.mark.parametrize( + "required_hyper_parameters, value", [("enc0_vocab_size", 0), ("enc0_vocab_size", 1000000000)] +) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -160,17 +162,20 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param Object2Vec(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 'string'), - ('optimizer', 0), - ('enc0_cnn_filter_width', 'string'), - ('weight_decay', 'string'), - ('learning_rate', 'string'), - ('negative_sampling_rate', 'some_string'), - ('comparator_list', 0), - ('comparator_list', ['foobar']), - ('token_embedding_storage_type', 123), -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("epochs", "string"), + ("optimizer", 0), + ("enc0_cnn_filter_width", "string"), + ("weight_decay", "string"), + ("learning_rate", "string"), + ("negative_sampling_rate", "some_string"), + ("comparator_list", 0), + ("comparator_list", ["foobar"]), + ("token_embedding_storage_type", 123), + ], +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -178,23 +183,26 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame Object2Vec(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('epochs', 0), - ('epochs', 1000), - ('optimizer', 'string'), - ('early_stopping_tolerance', 0), - ('early_stopping_tolerance', 0.5), - ('early_stopping_patience', 0), - ('early_stopping_patience', 100), - ('weight_decay', -1), - ('weight_decay', 200000), - ('enc0_cnn_filter_width', 2000), - ('learning_rate', 0), - ('learning_rate', 2), - ('negative_sampling_rate', -1), - ('comparator_list', 'hadamard,foobar'), - ('token_embedding_storage_type', 'foobar'), -]) +@pytest.mark.parametrize( + 
"optional_hyper_parameters, value", + [ + ("epochs", 0), + ("epochs", 1000), + ("optimizer", "string"), + ("early_stopping_tolerance", 0), + ("early_stopping_tolerance", 0.5), + ("early_stopping_patience", 0), + ("early_stopping_patience", 100), + ("weight_decay", -1), + ("weight_decay", 200000), + ("enc0_cnn_filter_width", 2000), + ("learning_rate", 0), + ("learning_rate", 2), + ("negative_sampling_rate", -1), + ("comparator_list", "hadamard,foobar"), + ("token_embedding_storage_type", "foobar"), + ], +) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -208,9 +216,16 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param @patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") def test_call_fit(base_fit, sagemaker_session): - object2vec = Object2Vec(base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + object2vec = Object2Vec( + base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) object2vec.fit(data, MINI_BATCH_SIZE) @@ -221,53 +236,87 @@ def test_call_fit(base_fit, sagemaker_session): def test_call_fit_none_mini_batch_size(sagemaker_session): - object2vec = Object2Vec(base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + object2vec = Object2Vec( + base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) + + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) object2vec.fit(data) def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): - object2vec = Object2Vec(base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + object2vec = Object2Vec( + base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises((TypeError, ValueError)): object2vec._prepare_for_training(data, "some") def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_session): - object2vec = Object2Vec(base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + object2vec = Object2Vec( + base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) + + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises(ValueError): object2vec._prepare_for_training(data, 0) def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session): - object2vec = Object2Vec(base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, 
PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + object2vec = Object2Vec( + base_job_name="object2vec", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) + + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises(ValueError): object2vec._prepare_for_training(data, 10001) def test_model_image(sagemaker_session): object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) object2vec.fit(data, MINI_BATCH_SIZE) model = object2vec.create_model() - assert model.image == registry(REGION, "object2vec") + '/object2vec:1' + assert model.image == registry(REGION, "object2vec") + "/object2vec:1" def test_predictor_type(sagemaker_session): object2vec = Object2Vec(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) object2vec.fit(data, MINI_BATCH_SIZE) model = object2vec.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 0748d90ca6..cb83423444 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -18,43 +18,43 @@ from sagemaker.amazon.pca import PCA, PCAPredictor from sagemaker.amazon.amazon_estimator import registry, RecordSet -ROLE = 'myrole' +ROLE = "myrole" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" NUM_COMPONENTS = 5 -COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, - 'train_instance_type': TRAIN_INSTANCE_TYPE} -ALL_REQ_ARGS = dict({'num_components': NUM_COMPONENTS}, **COMMON_TRAIN_ARGS) +COMMON_TRAIN_ARGS = { + "role": ROLE, + "train_instance_count": TRAIN_INSTANCE_COUNT, + "train_instance_type": TRAIN_INSTANCE_TYPE, +} +ALL_REQ_ARGS = dict({"num_components": NUM_COMPONENTS}, **COMMON_TRAIN_ARGS) -REGION = 'us-west-2' -BUCKET_NAME = 'Some-Bucket' +REGION = "us-west-2" +BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': 's3://bucket/model.tar.gz' - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = 
Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -62,7 +62,13 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, sagemaker_session=sagemaker_session) + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + sagemaker_session=sagemaker_session, + ) assert pca.role == ROLE assert pca.train_instance_count == TRAIN_INSTANCE_COUNT assert pca.train_instance_type == TRAIN_INSTANCE_TYPE @@ -72,31 +78,34 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert pca.role == COMMON_TRAIN_ARGS['role'] + assert pca.role == COMMON_TRAIN_ARGS["role"] assert pca.train_instance_count == TRAIN_INSTANCE_COUNT - assert pca.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] - assert pca.num_components == ALL_REQ_ARGS['num_components'] + assert pca.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] + assert pca.num_components == ALL_REQ_ARGS["num_components"] def test_all_hyperparameters(sagemaker_session): - pca = PCA(sagemaker_session=sagemaker_session, - algorithm_mode='regular', subtract_mean='True', extra_components=1, **ALL_REQ_ARGS) + pca = PCA( + sagemaker_session=sagemaker_session, + algorithm_mode="regular", + subtract_mean="True", + extra_components=1, + **ALL_REQ_ARGS + ) assert pca.hyperparameters() == dict( - num_components=str(ALL_REQ_ARGS['num_components']), - algorithm_mode='regular', - subtract_mean='True', - extra_components='1' + num_components=str(ALL_REQ_ARGS["num_components"]), + algorithm_mode="regular", + subtract_mean="True", + extra_components="1", ) def test_image(sagemaker_session): pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert pca.train_image() == registry(REGION, 'pca') + '/pca:1' + assert pca.train_image() == registry(REGION, "pca") + "/pca:1" -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_components', 'string') -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("num_components", "string")]) def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -104,9 +113,7 @@ def test_required_hyper_parameters_type(sagemaker_session, required_hyper_parame PCA(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('required_hyper_parameters, value', [ - ('num_components', 0) -]) +@pytest.mark.parametrize("required_hyper_parameters, value", [("num_components", 0)]) def test_required_hyper_parameters_value(sagemaker_session, required_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -114,10 +121,9 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param PCA(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('algorithm_mode', 0), - ('extra_components', 'string') -]) 
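The decorator rewrites in this file and its neighbors only re-flow the parameter lists; pytest generates exactly the same test cases before and after. For reference, a self-contained example of the pattern (the validator is a hypothetical stand-in for the estimators' hyperparameter checks, not SDK code):

import pytest


def validate_num_components(value):
    # Hypothetical check mirroring the shape of the PCA tests in this file.
    if not isinstance(value, int) or value < 1:
        raise ValueError("num_components must be a positive integer")


@pytest.mark.parametrize("bad_value", ["string", 0, -1])
def test_num_components_rejected(bad_value):
    # One test case is generated per parameter, which is why reformatting
    # the parameter list has no effect on what actually runs.
    with pytest.raises(ValueError):
        validate_num_components(bad_value)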
+@pytest.mark.parametrize( + "optional_hyper_parameters, value", [("algorithm_mode", 0), ("extra_components", "string")] +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -125,9 +131,7 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame PCA(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('algorithm_mode', 'string') -]) +@pytest.mark.parametrize("optional_hyper_parameters, value", [("algorithm_mode", "string")]) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -135,16 +139,21 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param PCA(sagemaker_session=sagemaker_session, **test_params) -PREFIX = 'prefix' +PREFIX = "prefix" FEATURE_DIM = 10 MINI_BATCH_SIZE = 200 -@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit') +@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") def test_call_fit(base_fit, sagemaker_session): - pca = PCA(base_job_name='pca', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + pca = PCA(base_job_name="pca", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) pca.fit(data, MINI_BATCH_SIZE) @@ -155,30 +164,42 @@ def test_call_fit(base_fit, sagemaker_session): def test_prepare_for_training_no_mini_batch_size(sagemaker_session): - pca = PCA(base_job_name='pca', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + pca = PCA(base_job_name="pca", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) pca._prepare_for_training(data) assert pca.mini_batch_size == 1 def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): - pca = PCA(base_job_name='pca', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + pca = PCA(base_job_name="pca", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises((TypeError, ValueError)): - pca.fit(data, 'some') + pca.fit(data, "some") def test_prepare_for_training_multiple_channel(sagemaker_session): - lr = PCA(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = PCA(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) lr._prepare_for_training([data, data]) @@ -186,29 +207,43 @@ def test_prepare_for_training_multiple_channel(sagemaker_session): def 
test_prepare_for_training_multiple_channel_no_train(sagemaker_session): - lr = PCA(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + lr = PCA(base_job_name="lr", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='mock') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="mock", + ) with pytest.raises(ValueError) as ex: lr._prepare_for_training([data, data]) - assert 'Must provide train channel.' in str(ex) + assert "Must provide train channel." in str(ex) def test_model_image(sagemaker_session): pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) pca.fit(data, MINI_BATCH_SIZE) model = pca.create_model() - assert model.image == registry(REGION, 'pca') + '/pca:1' + assert model.image == registry(REGION, "pca") + "/pca:1" def test_predictor_type(sagemaker_session): pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) pca.fit(data, MINI_BATCH_SIZE) model = pca.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 3c33ae2396..fef72271e6 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -35,18 +35,23 @@ ENDPOINT = "some-ep" -TIMESTAMP = '2017-10-10-14-14-15' -BUCKET_NAME = 'mybucket' +TIMESTAMP = "2017-10-10-14-14-15" +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -IMAGE_NAME = 'fakeimage' -REGION = 'us-west-2' +IMAGE_NAME = "fakeimage" +REGION = "us-west-2" class DummyFrameworkModel(FrameworkModel): - def __init__(self, sagemaker_session, **kwargs): - super(DummyFrameworkModel, self).__init__(MODEL_DATA_1, MODEL_IMAGE_1, ROLE, ENTRY_POINT, - sagemaker_session=sagemaker_session, **kwargs) + super(DummyFrameworkModel, self).__init__( + MODEL_DATA_1, + MODEL_IMAGE_1, + ROLE, + ENTRY_POINT, + sagemaker_session=sagemaker_session, + **kwargs + ) def create_predictor(self, endpoint_name): return RealTimePredictor(endpoint_name, self.sagemaker_session) @@ -54,88 +59,124 @@ def create_predictor(self, endpoint_name): @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return sms -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_prepare_container_def(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - 
sparkml_model = SparkMLModel(model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session, - env={'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': 'text/csv'}) - model = PipelineModel(models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, + role=ROLE, + sagemaker_session=sagemaker_session, + env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, + ) + model = PipelineModel( + models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session + ) assert model.pipeline_container_def(INSTANCE_TYPE) == [ { - 'Environment': { - 'SAGEMAKER_PROGRAM': 'blah.py', - 'SAGEMAKER_SUBMIT_DIRECTORY': 's3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', - 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false' + "Environment": { + "SAGEMAKER_PROGRAM": "blah.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", }, - 'Image': 'mi-1', - 'ModelDataUrl': 's3://bucket/model_1.tar.gz' + "Image": "mi-1", + "ModelDataUrl": "s3://bucket/model_1.tar.gz", }, { - 'Environment': {'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': 'text/csv'}, - 'Image': '246618743249.dkr.ecr.us-west-2.amazonaws.com' + '/sagemaker-sparkml-serving:2.2', - 'ModelDataUrl': 's3://bucket/model_2.tar.gz' - } + "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, + "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com" + + "/sagemaker-sparkml-serving:2.2", + "ModelDataUrl": "s3://bucket/model_2.tar.gz", + }, ] -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_deploy(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - sparkml_model = SparkMLModel(model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session) - model = PipelineModel(models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session + ) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) sagemaker_session.endpoint_from_production_variants.assert_called_with( - 'mi-1-2017-10-10-14-14-15', - [{'InitialVariantWeight': 1, - 'ModelName': 'mi-1-2017-10-10-14-14-15', - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + "mi-1-2017-10-10-14-14-15", + [ + { + "InitialVariantWeight": 1, + "ModelName": "mi-1-2017-10-10-14-14-15", + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], None, - wait=True) + wait=True, + ) -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_deploy_endpoint_name(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - sparkml_model = SparkMLModel(model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session) - model = PipelineModel(models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session) + sparkml_model = SparkMLModel( + 
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session + ) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) sagemaker_session.endpoint_from_production_variants.assert_called_with( - 'mi-1-2017-10-10-14-14-15', - [{'InitialVariantWeight': 1, - 'ModelName': 'mi-1-2017-10-10-14-14-15', - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + "mi-1-2017-10-10-14-14-15", + [ + { + "InitialVariantWeight": 1, + "ModelName": "mi-1-2017-10-10-14-14-15", + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], None, - wait=True) + wait=True, + ) -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_transformer(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - sparkml_model = SparkMLModel(model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session) - model_name = 'ModelName' - model = PipelineModel(models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session, - name=model_name) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + model_name = "ModelName" + model = PipelineModel( + models=[framework_model, sparkml_model], + role=ROLE, + sagemaker_session=sagemaker_session, + name=model_name, + ) instance_count = 55 - strategy = 'MultiRecord' - assemble_with = 'Line' + strategy = "MultiRecord" + assemble_with = "Line" output_path = "s3://output/path" output_kms_key = "output:kms:key" accept = "application/jsonlines" @@ -144,12 +185,20 @@ def test_transformer(tfo, time, sagemaker_session): max_payload = 5 tags = [{"my_tag": "my_value"}] volume_kms_key = "volume:kms:key" - transformer = model.transformer(instance_type=INSTANCE_TYPE, instance_count=instance_count, - strategy=strategy, assemble_with=assemble_with, output_path=output_path, - output_kms_key=output_kms_key, accept=accept, env=env, - max_concurrent_transforms=max_concurrent_transforms, - max_payload=max_payload, tags=tags, volume_kms_key=volume_kms_key - ) + transformer = model.transformer( + instance_type=INSTANCE_TYPE, + instance_count=instance_count, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=output_kms_key, + accept=accept, + env=env, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + volume_kms_key=volume_kms_key, + ) assert transformer.instance_type == INSTANCE_TYPE assert transformer.instance_count == instance_count assert transformer.strategy == strategy @@ -165,38 +214,49 @@ def test_transformer(tfo, time, sagemaker_session): assert transformer.model_name == model_name -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_deploy_tags(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - sparkml_model = SparkMLModel(model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session) - model = PipelineModel(models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session) - tags = [{'ModelName': 'TestModel'}] + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, 
sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session + ) + tags = [{"ModelName": "TestModel"}] model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) sagemaker_session.endpoint_from_production_variants.assert_called_with( - 'mi-1-2017-10-10-14-14-15', - [{'InitialVariantWeight': 1, - 'ModelName': 'mi-1-2017-10-10-14-14-15', - 'InstanceType': INSTANCE_TYPE, - 'InitialInstanceCount': 1, - 'VariantName': 'AllTraffic'}], + "mi-1-2017-10-10-14-14-15", + [ + { + "InitialVariantWeight": 1, + "ModelName": "mi-1-2017-10-10-14-14-15", + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], tags, - wait=True) + wait=True, + ) def test_delete_model_without_deploy(sagemaker_session): pipeline_model = PipelineModel([], role=ROLE, sagemaker_session=sagemaker_session) - expected_error_message = 'The SageMaker model must be created before attempting to delete.' + expected_error_message = "The SageMaker model must be created before attempting to delete." with pytest.raises(ValueError, match=expected_error_message): pipeline_model.delete_model() -@patch('tarfile.open') -@patch('time.strftime', return_value=TIMESTAMP) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) def test_delete_model(tfo, time, sagemaker_session): framework_model = DummyFrameworkModel(sagemaker_session) - pipeline_model = PipelineModel([framework_model], role=ROLE, sagemaker_session=sagemaker_session) + pipeline_model = PipelineModel( + [framework_model], role=ROLE, sagemaker_session=sagemaker_session + ) pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) pipeline_model.delete_model() diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index dd393ee1c1..821b069a47 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -21,9 +21,18 @@ import numpy as np from sagemaker.predictor import RealTimePredictor -from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, \ - csv_deserializer, BytesDeserializer, StringDeserializer, StreamDeserializer, \ - numpy_deserializer, npy_serializer, _NumpyDeserializer +from sagemaker.predictor import ( + json_serializer, + json_deserializer, + csv_serializer, + csv_deserializer, + BytesDeserializer, + StringDeserializer, + StreamDeserializer, + numpy_deserializer, + npy_serializer, + _NumpyDeserializer, +) from tests.unit import DATA_DIR # testing serialization functions @@ -32,23 +41,23 @@ def test_json_serializer_numpy_valid(): result = json_serializer(np.array([1, 2, 3])) - assert result == '[1, 2, 3]' + assert result == "[1, 2, 3]" def test_json_serializer_numpy_valid_2dimensional(): result = json_serializer(np.array([[1, 2, 3], [3, 4, 5]])) - assert result == '[[1, 2, 3], [3, 4, 5]]' + assert result == "[[1, 2, 3], [3, 4, 5]]" def test_json_serializer_empty(): - assert json_serializer(np.array([])) == '[]' + assert json_serializer(np.array([])) == "[]" def test_json_serializer_python_array(): result = json_serializer([1, 2, 3]) - assert result == '[1, 2, 3]' + assert result == "[1, 2, 3]" def test_json_serializer_python_dictionary(): @@ -60,11 +69,11 @@ def test_json_serializer_python_dictionary(): def test_json_serializer_python_invalid_empty(): - assert json_serializer([]) == '[]' + assert json_serializer([]) == "[]" def test_json_serializer_python_dictionary_invalid_empty(): - assert json_serializer({}) == '{}' + 
assert json_serializer({}) == "{}" def test_json_serializer_csv_buffer(): @@ -77,8 +86,8 @@ def test_json_serializer_csv_buffer(): def test_csv_serializer_str(): - original = '1,2,3' - result = csv_serializer('1,2,3') + original = "1,2,3" + result = csv_serializer("1,2,3") assert result == original @@ -86,31 +95,31 @@ def test_csv_serializer_str(): def test_csv_serializer_python_array(): result = csv_serializer([1, 2, 3]) - assert result == '1,2,3' + assert result == "1,2,3" def test_csv_serializer_numpy_valid(): result = csv_serializer(np.array([1, 2, 3])) - assert result == '1,2,3' + assert result == "1,2,3" def test_csv_serializer_numpy_valid_2dimensional(): result = csv_serializer(np.array([[1, 2, 3], [3, 4, 5]])) - assert result == '1,2,3\n3,4,5' + assert result == "1,2,3\n3,4,5" def test_csv_serializer_list_of_str(): - result = csv_serializer(['1,2,3', '4,5,6']) + result = csv_serializer(["1,2,3", "4,5,6"]) - assert result == '1,2,3\n4,5,6' + assert result == "1,2,3\n4,5,6" def test_csv_serializer_list_of_list(): result = csv_serializer([[1, 2, 3], [3, 4, 5]]) - assert result == '1,2,3\n3,4,5' + assert result == "1,2,3\n3,4,5" def test_csv_serializer_list_of_empty(): @@ -143,55 +152,55 @@ def test_csv_serializer_csv_reader(): def test_csv_deserializer_single_element(): - result = csv_deserializer(io.BytesIO(b'1'), 'text/csv') - assert result == [['1']] + result = csv_deserializer(io.BytesIO(b"1"), "text/csv") + assert result == [["1"]] def test_csv_deserializer_array(): - result = csv_deserializer(io.BytesIO(b'1,2,3'), 'text/csv') - assert result == [['1', '2', '3']] + result = csv_deserializer(io.BytesIO(b"1,2,3"), "text/csv") + assert result == [["1", "2", "3"]] def test_csv_deserializer_2dimensional(): - result = csv_deserializer(io.BytesIO(b'1,2,3\n3,4,5'), 'text/csv') - assert result == [['1', '2', '3'], ['3', '4', '5']] + result = csv_deserializer(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") + assert result == [["1", "2", "3"], ["3", "4", "5"]] def test_json_deserializer_array(): - result = json_deserializer(io.BytesIO(b'[1, 2, 3]'), 'application/json') + result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json") assert result == [1, 2, 3] def test_json_deserializer_2dimensional(): - result = json_deserializer(io.BytesIO(b'[[1, 2, 3], [3, 4, 5]]'), 'application/json') + result = json_deserializer(io.BytesIO(b"[[1, 2, 3], [3, 4, 5]]"), "application/json") assert result == [[1, 2, 3], [3, 4, 5]] def test_json_deserializer_invalid_data(): with pytest.raises(ValueError) as error: - json_deserializer(io.BytesIO(b'[[1]'), 'application/json') + json_deserializer(io.BytesIO(b"[[1]"), "application/json") assert "column" in str(error) def test_bytes_deserializer(): - result = BytesDeserializer()(io.BytesIO(b'[1, 2, 3]'), 'application/json') + result = BytesDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") - assert result == b'[1, 2, 3]' + assert result == b"[1, 2, 3]" def test_string_deserializer(): - result = StringDeserializer()(io.BytesIO(b'[1, 2, 3]'), 'application/json') + result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") - assert result == '[1, 2, 3]' + assert result == "[1, 2, 3]" def test_stream_deserializer(): - stream, content_type = StreamDeserializer()(io.BytesIO(b'[1, 2, 3]'), 'application/json') + stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") result = stream.read() - assert result == b'[1, 2, 3]' - assert content_type == 'application/json' + assert result == b"[1, 2, 3]" + assert 
content_type == "application/json" def test_npy_serializer_python_array(): @@ -203,7 +212,7 @@ def test_npy_serializer_python_array(): def test_npy_serializer_python_array_with_dtype(): array = [1, 2, 3] - dtype = 'float16' + dtype = "float16" result = npy_serializer(array, dtype) @@ -227,7 +236,7 @@ def test_npy_serializer_numpy_valid_multidimensional(): def test_npy_serializer_numpy_valid_list_of_strings(): - array = np.array(['one', 'two', 'three']) + array = np.array(["one", "two", "three"]) result = npy_serializer(array) assert np.array_equal(array, np.load(io.BytesIO(result))) @@ -273,35 +282,37 @@ def test_npy_serializer_python_invalid_empty(): def test_numpy_deser_from_csv(): - arr = numpy_deserializer(io.BytesIO(b'1,2,3\n4,5,6'), 'text/csv') + arr = numpy_deserializer(io.BytesIO(b"1,2,3\n4,5,6"), "text/csv") assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6]])) def test_numpy_deser_from_csv_ragged(): with pytest.raises(ValueError) as error: - numpy_deserializer(io.BytesIO(b'1,2,3\n4,5,6,7'), 'text/csv') + numpy_deserializer(io.BytesIO(b"1,2,3\n4,5,6,7"), "text/csv") assert "errors were detected" in str(error) def test_numpy_deser_from_csv_alpha(): - arr = _NumpyDeserializer(dtype='U5')(io.BytesIO(b'hello,2,3\n4,5,6'), 'text/csv') - assert np.array_equal(arr, np.array([['hello', 2, 3], [4, 5, 6]])) + arr = _NumpyDeserializer(dtype="U5")(io.BytesIO(b"hello,2,3\n4,5,6"), "text/csv") + assert np.array_equal(arr, np.array([["hello", 2, 3], [4, 5, 6]])) def test_numpy_deser_from_json(): - arr = numpy_deserializer(io.BytesIO(b'[[1,2,3],\n[4,5,6]]'), 'application/json') + arr = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6]]"), "application/json") assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6]])) # Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists def test_numpy_deser_from_json_ragged(): - arr = numpy_deserializer(io.BytesIO(b'[[1,2,3],\n[4,5,6,7]]'), 'application/json') + arr = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]"), "application/json") assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6, 7]])) def test_numpy_deser_from_json_alpha(): - arr = _NumpyDeserializer(dtype='U5')(io.BytesIO(b'[["hello",2,3],\n[4,5,6]]'), 'application/json') - assert np.array_equal(arr, np.array([['hello', 2, 3], [4, 5, 6]])) + arr = _NumpyDeserializer(dtype="U5")( + io.BytesIO(b'[["hello",2,3],\n[4,5,6]]'), "application/json" + ) + assert np.array_equal(arr, np.array([["hello", 2, 3], [4, 5, 6]])) def test_numpy_deser_from_npy(): @@ -316,7 +327,7 @@ def test_numpy_deser_from_npy(): def test_numpy_deser_from_npy_object_array(): - array = np.array(['one', 'two']) + array = np.array(["one", "two"]) stream = io.BytesIO() np.save(stream, array) stream.seek(0) @@ -325,37 +336,35 @@ def test_numpy_deser_from_npy_object_array(): assert np.array_equal(array, result) + # testing 'predict' invocations -ENDPOINT = 'mxnet_endpoint' -BUCKET_NAME = 'mxnet_endpoint' -DEFAULT_CONTENT_TYPE = 'application/json' -CSV_CONTENT_TYPE = 'text/csv' +ENDPOINT = "mxnet_endpoint" +BUCKET_NAME = "mxnet_endpoint" +DEFAULT_CONTENT_TYPE = "application/json" +CSV_CONTENT_TYPE = "text/csv" RETURN_VALUE = 0 CSV_RETURN_VALUE = "1,2,3\r\n" -ENDPOINT_DESC = { - 'EndpointConfigName': ENDPOINT -} +ENDPOINT_DESC = {"EndpointConfigName": ENDPOINT} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} def empty_sagemaker_session(): - 
ims = Mock(name='sagemaker_session') - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims = Mock(name="sagemaker_session") + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - response_body = Mock('body') - response_body.read = Mock('read', return_value=RETURN_VALUE) - response_body.close = Mock('close', return_value=None) - ims.sagemaker_runtime_client.invoke_endpoint = Mock(name='invoke_endpoint', return_value={'Body': response_body}) + response_body = Mock("body") + response_body.read = Mock("read", return_value=RETURN_VALUE) + response_body.close = Mock("close", return_value=None) + ims.sagemaker_runtime_client.invoke_endpoint = Mock( + name="invoke_endpoint", return_value={"Body": response_body} + ) return ims @@ -368,10 +377,7 @@ def test_predict_call_pass_through(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called - expected_request_args = { - 'Body': data, - 'EndpointName': ENDPOINT - } + expected_request_args = {"Body": data, "EndpointName": ENDPOINT} call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args @@ -380,9 +386,9 @@ def test_predict_call_pass_through(): def test_predict_call_with_headers(): sagemaker_session = empty_sagemaker_session() - predictor = RealTimePredictor(ENDPOINT, sagemaker_session, - content_type=DEFAULT_CONTENT_TYPE, - accept=DEFAULT_CONTENT_TYPE) + predictor = RealTimePredictor( + ENDPOINT, sagemaker_session, content_type=DEFAULT_CONTENT_TYPE, accept=DEFAULT_CONTENT_TYPE + ) data = "untouched" result = predictor.predict(data) @@ -390,10 +396,10 @@ def test_predict_call_with_headers(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - 'Accept': DEFAULT_CONTENT_TYPE, - 'Body': data, - 'ContentType': DEFAULT_CONTENT_TYPE, - 'EndpointName': ENDPOINT + "Accept": DEFAULT_CONTENT_TYPE, + "Body": data, + "ContentType": DEFAULT_CONTENT_TYPE, + "EndpointName": ENDPOINT, } call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args @@ -402,30 +408,34 @@ def test_predict_call_with_headers(): def json_sagemaker_session(): - ims = Mock(name='sagemaker_session') - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims = Mock(name="sagemaker_session") + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - response_body = Mock('body') - response_body.read = Mock('read', return_value=json.dumps([RETURN_VALUE])) - response_body.close = Mock('close', return_value=None) - ims.sagemaker_runtime_client.invoke_endpoint = Mock(name='invoke_endpoint', - return_value={'Body': response_body, - 'ContentType': DEFAULT_CONTENT_TYPE}) + 
response_body = Mock("body") + response_body.read = Mock("read", return_value=json.dumps([RETURN_VALUE])) + response_body.close = Mock("close", return_value=None) + ims.sagemaker_runtime_client.invoke_endpoint = Mock( + name="invoke_endpoint", + return_value={"Body": response_body, "ContentType": DEFAULT_CONTENT_TYPE}, + ) return ims def test_predict_call_with_headers_and_json(): sagemaker_session = json_sagemaker_session() - predictor = RealTimePredictor(ENDPOINT, sagemaker_session, - content_type='not/json', - accept='also/not-json', - serializer=json_serializer) + predictor = RealTimePredictor( + ENDPOINT, + sagemaker_session, + content_type="not/json", + accept="also/not-json", + serializer=json_serializer, + ) data = [1, 2] result = predictor.predict(data) @@ -433,10 +443,10 @@ def test_predict_call_with_headers_and_json(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - 'Accept': 'also/not-json', - 'Body': json.dumps(data), - 'ContentType': 'not/json', - 'EndpointName': ENDPOINT + "Accept": "also/not-json", + "Body": json.dumps(data), + "ContentType": "not/json", + "EndpointName": ENDPOINT, } call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args @@ -445,29 +455,30 @@ def test_predict_call_with_headers_and_json(): def ret_csv_sagemaker_session(): - ims = Mock(name='sagemaker_session') - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - ims.sagemaker_runtime_client = Mock(name='sagemaker_runtime') + ims = Mock(name="sagemaker_session") + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - response_body = Mock('body') - response_body.read = Mock('read', return_value=CSV_RETURN_VALUE) - response_body.close = Mock('close', return_value=None) - ims.sagemaker_runtime_client.invoke_endpoint = Mock(name='invoke_endpoint', - return_value={'Body': response_body, - 'ContentType': CSV_CONTENT_TYPE}) + response_body = Mock("body") + response_body.read = Mock("read", return_value=CSV_RETURN_VALUE) + response_body.close = Mock("close", return_value=None) + ims.sagemaker_runtime_client.invoke_endpoint = Mock( + name="invoke_endpoint", + return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE}, + ) return ims def test_predict_call_with_headers_and_csv(): sagemaker_session = ret_csv_sagemaker_session() - predictor = RealTimePredictor(ENDPOINT, sagemaker_session, - accept=CSV_CONTENT_TYPE, - serializer=csv_serializer) + predictor = RealTimePredictor( + ENDPOINT, sagemaker_session, accept=CSV_CONTENT_TYPE, serializer=csv_serializer + ) data = [1, 2] result = predictor.predict(data) @@ -475,10 +486,10 @@ def test_predict_call_with_headers_and_csv(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - 'Accept': CSV_CONTENT_TYPE, - 'Body': '1,2', - 'ContentType': CSV_CONTENT_TYPE, - 'EndpointName': ENDPOINT + "Accept": CSV_CONTENT_TYPE, + "Body": "1,2", + "ContentType": CSV_CONTENT_TYPE, + "EndpointName": ENDPOINT, } call_args, kwargs = 
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args @@ -488,12 +499,14 @@ def test_predict_call_with_headers_and_csv(): def test_delete_endpoint_with_config(): sagemaker_session = empty_sagemaker_session() - sagemaker_session.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointConfigName': 'endpoint-config'}) + sagemaker_session.sagemaker_client.describe_endpoint = Mock( + return_value={"EndpointConfigName": "endpoint-config"} + ) predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session) predictor.delete_endpoint() sagemaker_session.delete_endpoint.assert_called_with(ENDPOINT) - sagemaker_session.delete_endpoint_config.assert_called_with('endpoint-config') + sagemaker_session.delete_endpoint_config.assert_called_with("endpoint-config") def test_delete_endpoint_only(): @@ -512,15 +525,17 @@ def test_delete_model(): predictor.delete_model() expected_call_count = 2 - expected_call_args_list = [call('model-1'), call('model-2')] + expected_call_args_list = [call("model-1"), call("model-2")] assert sagemaker_session.delete_model.call_count == expected_call_count assert sagemaker_session.delete_model.call_args_list == expected_call_args_list def test_delete_model_fail(): sagemaker_session = empty_sagemaker_session() - sagemaker_session.sagemaker_client.delete_model = Mock(side_effect=Exception('Could not find model.')) - expected_error_message = 'One or more models cannot be deleted, please retry.' + sagemaker_session.sagemaker_client.delete_model = Mock( + side_effect=Exception("Could not find model.") + ) + expected_error_message = "One or more models cannot be deleted, please retry." predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 11df279517..7eeadaea27 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -25,130 +25,141 @@ from sagemaker.pytorch import PyTorchPredictor, PyTorchModel -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py') +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") MODEL_DATA = "s3://some/data.tar.gz" -TIMESTAMP = '2017-11-06-14:14:15.672' +TIMESTAMP = "2017-11-06-14:14:15.672" TIME = 1507167947 -BUCKET_NAME = 'mybucket' +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -PYTHON_VERSION = 'py' + str(sys.version_info.major) -IMAGE_NAME = 'sagemaker-pytorch' -JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) +INSTANCE_TYPE = "ml.c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +PYTHON_VERSION = "py" + str(sys.version_info.major) +IMAGE_NAME = "sagemaker-pytorch" +JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" -ROLE = 'Dummy' -REGION = 'us-west-2' -GPU = 'ml.p2.xlarge' -CPU = 'ml.c4.xlarge' +ROLE = "Dummy" +REGION = "us-west-2" +GPU = "ml.p2.xlarge" +CPU = "ml.c4.xlarge" -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} -LIST_TAGS_RESULT = { - 'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] -} 
+LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} -@pytest.fixture(name='sagemaker_session') +@pytest.fixture(name="sagemaker_session") def fixture_sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - - describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) - session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION): - return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'cpu', py_version) + return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", py_version) def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION): - return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'gpu', py_version) + return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "gpu", py_version) def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION): - return _get_full_cpu_image_uri(version, py_version=py_version) + '-eia' - - -def _pytorch_estimator(sagemaker_session, framework_version=defaults.PYTORCH_VERSION, train_instance_type=None, - base_job_name=None, **kwargs): - return PyTorch(entry_point=SCRIPT_PATH, - framework_version=framework_version, - py_version=PYTHON_VERSION, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE, - base_job_name=base_job_name, - **kwargs) + return _get_full_cpu_image_uri(version, py_version=py_version) + "-eia" + + +def _pytorch_estimator( + sagemaker_session, + framework_version=defaults.PYTORCH_VERSION, + train_instance_type=None, + base_job_name=None, + **kwargs +): + return PyTorch( + entry_point=SCRIPT_PATH, + framework_version=framework_version, + py_version=PYTHON_VERSION, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE, + base_job_name=base_job_name, + **kwargs + ) def _create_train_job(version): return { - 'image': _get_full_cpu_image_uri(version), - 'input_mode': 'File', - 'input_config': [{ - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix' - } + "image": _get_full_cpu_image_uri(version), + "input_mode": "File", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + } + }, } - }], - 'role': ROLE, - 'job_name': JOB_NAME, - 
'output_config': { - 'S3OutputPath': 's3://{}/'.format(BUCKET_NAME), + ], + "role": ROLE, + "job_name": JOB_NAME, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": "ml.c4.4xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, }, - 'resource_config': { - 'InstanceType': 'ml.c4.4xlarge', - 'InstanceCount': 1, - 'VolumeSizeInGB': 30, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + ), + "sagemaker_region": '"us-west-2"', }, - 'hyperparameters': { - 'sagemaker_program': json.dumps('dummy_script.py'), - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': str(logging.INFO), - 'sagemaker_job_name': json.dumps(JOB_NAME), - 'sagemaker_submit_directory': - json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)), - 'sagemaker_region': '"us-west-2"' - }, - 'stop_condition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'tags': None, - 'vpc_config': None, - 'metric_definitions': None + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "tags": None, + "vpc_config": None, + "metric_definitions": None, } def test_create_model(sagemaker_session, pytorch_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=pytorch_version, container_log_level=container_log_level, - base_job_name='job', source_dir=source_dir) - - job_name = 'new_name' - pytorch.fit(inputs='s3://mybucket/train', job_name='new_name') + source_dir = "s3://mybucket/source" + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=pytorch_version, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "new_name" + pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") model = pytorch.create_model() assert model.sagemaker_session == sagemaker_session @@ -164,20 +175,28 @@ def test_create_model(sagemaker_session, pytorch_version): def test_create_model_with_optional_params(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - enable_cloudwatch_metrics = 'true' - pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - container_log_level=container_log_level, base_job_name='job', source_dir=source_dir, - enable_cloudwatch_metrics=enable_cloudwatch_metrics) - - pytorch.fit(inputs='s3://mybucket/train', job_name='new_name') - - new_role = 'role' + source_dir = "s3://mybucket/source" + enable_cloudwatch_metrics = "true" + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + ) + + 
pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") + + new_role = "role" model_server_workers = 2 - vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - model = pytorch.create_model(role=new_role, model_server_workers=model_server_workers, - vpc_config_override=vpc_config) + vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + model = pytorch.create_model( + role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config + ) assert model.role == new_role assert model.model_server_workers == model_server_workers @@ -186,15 +205,22 @@ def test_create_model_with_optional_params(sagemaker_session): def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - image = 'pytorch:9000' - pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - container_log_level=container_log_level, image_name=image, - base_job_name='job', source_dir=source_dir) - - job_name = 'new_name' - pytorch.fit(inputs='s3://mybucket/train', job_name='new_name') + source_dir = "s3://mybucket/source" + image = "pytorch:9000" + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, + image_name=image, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "new_name" + pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") model = pytorch.create_model() assert model.sagemaker_session == sagemaker_session @@ -206,205 +232,249 @@ def test_create_model_with_custom_image(sagemaker_session): assert model.source_dir == source_dir -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) def test_pytorch(strftime, sagemaker_session, pytorch_version): - pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=pytorch_version, py_version=PYTHON_VERSION) - - inputs = 's3://mybucket/train' + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=pytorch_version, + py_version=PYTHON_VERSION, + ) + + inputs = "s3://mybucket/train" pytorch.fit(inputs=inputs) sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job'] + assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ['resource'] + assert boto_call_names == ["resource"] expected_train_args = _create_train_job(pytorch_version) - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args model = pytorch.create_model() - expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-gpu-{}' - assert {'Environment': - 
{'SAGEMAKER_SUBMIT_DIRECTORY': - 's3://mybucket/sagemaker-pytorch-{}/source/sourcedir.tar.gz'.format(TIMESTAMP), - 'SAGEMAKER_PROGRAM': 'dummy_script.py', - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'}, - 'Image': expected_image_base.format(pytorch_version, PYTHON_VERSION), - 'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU) - - assert 'cpu' in model.prepare_container_def(CPU)['Image'] + expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-gpu-{}" + assert { + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-pytorch-{}/source/sourcedir.tar.gz".format( + TIMESTAMP + ), + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + }, + "Image": expected_image_base.format(pytorch_version, PYTHON_VERSION), + "ModelDataUrl": "s3://m/m.tar.gz", + } == model.prepare_container_def(GPU) + + assert "cpu" in model.prepare_container_def(CPU)["Image"] predictor = pytorch.deploy(1, GPU) assert isinstance(predictor, PyTorchPredictor) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_model(sagemaker_session): - model = PyTorchModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = PyTorchModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) predictor = model.deploy(1, GPU) assert isinstance(predictor, PyTorchPredictor) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_model_image_accelerator(sagemaker_session): - model = PyTorchModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = PyTorchModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) with pytest.raises(ValueError): model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) def test_train_image_default(sagemaker_session): - pytorch = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) + pytorch = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) - assert _get_full_cpu_image_uri(defaults.PYTORCH_VERSION, defaults.PYTHON_VERSION) in pytorch.train_image() + assert ( + _get_full_cpu_image_uri(defaults.PYTORCH_VERSION, defaults.PYTHON_VERSION) + in pytorch.train_image() + ) def test_train_image_cpu_instances(sagemaker_session, pytorch_version): - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type='ml.c2.2xlarge') + pytorch = _pytorch_estimator( + sagemaker_session, pytorch_version, train_instance_type="ml.c2.2xlarge" + ) assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type='ml.c4.2xlarge') + pytorch = _pytorch_estimator( + sagemaker_session, pytorch_version, train_instance_type="ml.c4.2xlarge" + ) assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type='ml.m16') + pytorch = 
_pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type="ml.m16") assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) def test_train_image_gpu_instances(sagemaker_session, pytorch_version): - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type='ml.g2.2xlarge') + pytorch = _pytorch_estimator( + sagemaker_session, pytorch_version, train_instance_type="ml.g2.2xlarge" + ) assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version) - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type='ml.p2.2xlarge') + pytorch = _pytorch_estimator( + sagemaker_session, pytorch_version, train_instance_type="ml.p2.2xlarge" + ) assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version) def test_attach(sagemaker_session, pytorch_version): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-cpu-{}'.format(pytorch_version, - PYTHON_VERSION) - returned_job_description = {'AlgorithmSpecification': - {'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) - - estimator = PyTorch.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-cpu-{}".format( + pytorch_version, PYTHON_VERSION + ) + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + 
sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" assert estimator.py_version == PYTHON_VERSION assert estimator.framework_version == pytorch_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" def test_attach_wrong_framework(sagemaker_session): - rjd = {'AlgorithmSpecification': - {'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4'}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) + rjd = { + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4", + }, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, 
+ } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) with pytest.raises(ValueError) as error: - PyTorch.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert "didn't use image for requested framework" in str(error) def test_attach_custom_image(sagemaker_session): - training_image = 'pytorch:latest' - returned_job_description = {'AlgorithmSpecification': - {'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) - - estimator = PyTorch.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' + training_image = "pytorch:latest" + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" assert estimator.image_name == training_image assert estimator.train_image() == training_image -@patch('sagemaker.pytorch.estimator.empty_framework_version_warning') +@patch("sagemaker.pytorch.estimator.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): - estimator = PyTorch(entry_point=SCRIPT_PATH, role=ROLE, 
sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=None) + estimator = PyTorch( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=None, + ) assert estimator.framework_version == defaults.PYTORCH_VERSION warning.assert_called_with(defaults.PYTORCH_VERSION, defaults.PYTORCH_VERSION) diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index 561f7c29ac..f182969c54 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -18,45 +18,45 @@ from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestPredictor from sagemaker.amazon.amazon_estimator import registry, RecordSet -ROLE = 'myrole' +ROLE = "myrole" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" NUM_SAMPLES_PER_TREE = 20 NUM_TREES = 50 -EVAL_METRICS = ['accuracy', 'precision_recall_fscore'] +EVAL_METRICS = ["accuracy", "precision_recall_fscore"] -COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, - 'train_instance_type': TRAIN_INSTANCE_TYPE} +COMMON_TRAIN_ARGS = { + "role": ROLE, + "train_instance_count": TRAIN_INSTANCE_COUNT, + "train_instance_type": TRAIN_INSTANCE_TYPE, +} ALL_REQ_ARGS = dict(**COMMON_TRAIN_ARGS) REGION = "us-west-2" BUCKET_NAME = "Some-Bucket" -DESCRIBE_TRAINING_JOB_RESULT = { - 'ModelArtifacts': { - 'S3ModelArtifacts': "s3://bucket/model.tar.gz" - } -} +DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": "s3://bucket/model.tar.gz"}} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) - sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=DESCRIBE_TRAINING_JOB_RESULT) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=DESCRIBE_TRAINING_JOB_RESULT + ) sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -64,9 +64,15 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - randomcutforest = RandomCutForest(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, - NUM_SAMPLES_PER_TREE, NUM_TREES, EVAL_METRICS, - sagemaker_session=sagemaker_session) + randomcutforest = RandomCutForest( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_SAMPLES_PER_TREE, + NUM_TREES, + EVAL_METRICS, + sagemaker_session=sagemaker_session, + ) assert randomcutforest.role == ROLE 
assert randomcutforest.train_instance_count == TRAIN_INSTANCE_COUNT assert randomcutforest.train_instance_type == TRAIN_INSTANCE_TYPE @@ -78,30 +84,34 @@ def test_init_required_positional(sagemaker_session): def test_init_required_named(sagemaker_session): randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert randomcutforest.role == COMMON_TRAIN_ARGS['role'] + assert randomcutforest.role == COMMON_TRAIN_ARGS["role"] assert randomcutforest.train_instance_count == TRAIN_INSTANCE_COUNT - assert randomcutforest.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] + assert randomcutforest.train_instance_type == COMMON_TRAIN_ARGS["train_instance_type"] def test_all_hyperparameters(sagemaker_session): - randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, num_trees=NUM_TREES, - num_samples_per_tree=NUM_SAMPLES_PER_TREE, - eval_metrics=EVAL_METRICS, **ALL_REQ_ARGS) + randomcutforest = RandomCutForest( + sagemaker_session=sagemaker_session, + num_trees=NUM_TREES, + num_samples_per_tree=NUM_SAMPLES_PER_TREE, + eval_metrics=EVAL_METRICS, + **ALL_REQ_ARGS + ) assert randomcutforest.hyperparameters() == dict( num_samples_per_tree=str(NUM_SAMPLES_PER_TREE), num_trees=str(NUM_TREES), - eval_metrics="{}".format(EVAL_METRICS) + eval_metrics="{}".format(EVAL_METRICS), ) def test_image(sagemaker_session): randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - assert randomcutforest.train_image() == registry(REGION, "randomcutforest") + '/randomcutforest:1' + assert ( + randomcutforest.train_image() == registry(REGION, "randomcutforest") + "/randomcutforest:1" + ) -@pytest.mark.parametrize('iterable_hyper_parameters, value', [ - ('eval_metrics', 0) -]) +@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)]) def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value): with pytest.raises(TypeError): test_params = ALL_REQ_ARGS.copy() @@ -109,10 +119,10 @@ def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parame RandomCutForest(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('num_trees', 'string'), - ('num_samples_per_tree', 'string'), -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [("num_trees", "string"), ("num_samples_per_tree", "string")], +) def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -120,10 +130,15 @@ def test_optional_hyper_parameters_type(sagemaker_session, optional_hyper_parame RandomCutForest(sagemaker_session=sagemaker_session, **test_params) -@pytest.mark.parametrize('optional_hyper_parameters, value', [ - ('num_trees', 49), ('num_trees', 1001), - ('num_samples_per_tree', 0), ('num_samples_per_tree', 2049), -]) +@pytest.mark.parametrize( + "optional_hyper_parameters, value", + [ + ("num_trees", 49), + ("num_trees", 1001), + ("num_samples_per_tree", 0), + ("num_samples_per_tree", 2049), + ], +) def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_parameters, value): with pytest.raises(ValueError): test_params = ALL_REQ_ARGS.copy() @@ -139,10 +154,16 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param @patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") def test_call_fit(base_fit, sagemaker_session): - randomcutforest = 
RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session, - **ALL_REQ_ARGS) + randomcutforest = RandomCutForest( + base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) randomcutforest.fit(data, MINI_BATCH_SIZE) @@ -153,33 +174,48 @@ def test_call_fit(base_fit, sagemaker_session): def test_prepare_for_training_no_mini_batch_size(sagemaker_session): - randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session, - **ALL_REQ_ARGS) + randomcutforest = RandomCutForest( + base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) randomcutforest._prepare_for_training(data) assert randomcutforest.mini_batch_size == MINI_BATCH_SIZE def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session): - randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session, - **ALL_REQ_ARGS) + randomcutforest = RandomCutForest( + base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) with pytest.raises((TypeError, ValueError)): randomcutforest._prepare_for_training(data, 1234) def test_prepare_for_training_feature_dim_greater_than_max_allowed(sagemaker_session): - randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session, - **ALL_REQ_ARGS) + randomcutforest = RandomCutForest( + base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS + ) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=MAX_FEATURE_DIM + 1, - channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=MAX_FEATURE_DIM + 1, + channel="train", + ) with pytest.raises((TypeError, ValueError)): randomcutforest._prepare_for_training(data) @@ -187,16 +223,26 @@ def test_prepare_for_training_feature_dim_greater_than_max_allowed(sagemaker_ses def test_model_image(sagemaker_session): randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) randomcutforest.fit(data, MINI_BATCH_SIZE) model = randomcutforest.create_model() - assert model.image == registry(REGION, "randomcutforest") + '/randomcutforest:1' + assert model.image == registry(REGION, "randomcutforest") + "/randomcutforest:1" def test_predictor_type(sagemaker_session): randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) - data = RecordSet("s3://{}/{}".format(BUCKET_NAME, 
PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) randomcutforest.fit(data, MINI_BATCH_SIZE) model = randomcutforest.create_model() predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 35c8fa706f..234e19cf4f 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -25,136 +25,148 @@ import sagemaker.tensorflow.serving as tfs -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py') -TIMESTAMP = '2017-11-06-14:14:15.672' +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +TIMESTAMP = "2017-11-06-14:14:15.672" TIME = 1507167947 -BUCKET_NAME = 'notmybucket' +BUCKET_NAME = "notmybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.4xlarge' -IMAGE_NAME = 'sagemaker-rl' +INSTANCE_TYPE = "ml.c4.4xlarge" +IMAGE_NAME = "sagemaker-rl" IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}-{}:{}{}-{}-py3" -PYTHON_VERSION = 'py3' -ROLE = 'Dummy' -REGION = 'us-west-2' -GPU = 'ml.p2.xlarge' -CPU = 'ml.c4.xlarge' +PYTHON_VERSION = "py3" +ROLE = "Dummy" +REGION = "us-west-2" +GPU = "ml.p2.xlarge" +CPU = "ml.c4.xlarge" -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} -LIST_TAGS_RESULT = { - 'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] -} +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} -@pytest.fixture(name='sagemaker_session') +@pytest.fixture(name="sagemaker_session") def fixture_sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - - describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) - session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) return session def _get_full_cpu_image_uri(toolkit, toolkit_version, framework): - return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, framework, - toolkit, toolkit_version, 'cpu') + return IMAGE_URI_FORMAT_STRING.format( + REGION, IMAGE_NAME, framework, toolkit, toolkit_version, "cpu" + ) def _get_full_gpu_image_uri(toolkit, toolkit_version, framework): - return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, framework, - toolkit, toolkit_version, 'gpu') - - -def 
_rl_estimator(sagemaker_session, toolkit=RLToolkit.COACH, - toolkit_version=RLEstimator.COACH_LATEST_VERSION_MXNET, framework=RLFramework.MXNET, - train_instance_type=None, base_job_name=None, **kwargs): - return RLEstimator(entry_point=SCRIPT_PATH, - toolkit=toolkit, - toolkit_version=toolkit_version, - framework=framework, - role=ROLE, - sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=train_instance_type or INSTANCE_TYPE, - base_job_name=base_job_name, - **kwargs) + return IMAGE_URI_FORMAT_STRING.format( + REGION, IMAGE_NAME, framework, toolkit, toolkit_version, "gpu" + ) + + +def _rl_estimator( + sagemaker_session, + toolkit=RLToolkit.COACH, + toolkit_version=RLEstimator.COACH_LATEST_VERSION_MXNET, + framework=RLFramework.MXNET, + train_instance_type=None, + base_job_name=None, + **kwargs +): + return RLEstimator( + entry_point=SCRIPT_PATH, + toolkit=toolkit, + toolkit_version=toolkit_version, + framework=framework, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=train_instance_type or INSTANCE_TYPE, + base_job_name=base_job_name, + **kwargs + ) def _create_train_job(toolkit, toolkit_version, framework): - job_name = '{}-{}-{}'.format(IMAGE_NAME, framework, TIMESTAMP) + job_name = "{}-{}-{}".format(IMAGE_NAME, framework, TIMESTAMP) return { - 'image': _get_full_cpu_image_uri(toolkit, toolkit_version, framework), - 'input_mode': 'File', - 'input_config': [{ - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix' - } + "image": _get_full_cpu_image_uri(toolkit, toolkit_version, framework), + "input_mode": "File", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + } + }, } - }], - 'role': ROLE, - 'job_name': job_name, - 'output_config': { - 'S3OutputPath': 's3://{}/'.format(BUCKET_NAME), + ], + "role": ROLE, + "job_name": job_name, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": "ml.c4.4xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, }, - 'resource_config': { - 'InstanceType': 'ml.c4.4xlarge', - 'InstanceCount': 1, - 'VolumeSizeInGB': 30, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_estimator": '"RLEstimator"', + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(job_name), + "sagemaker_s3_output": '"s3://{}/"'.format(BUCKET_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, job_name) + ), + "sagemaker_region": '"us-west-2"', }, - 'hyperparameters': { - 'sagemaker_program': json.dumps('dummy_script.py'), - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_estimator': '"RLEstimator"', - 'sagemaker_container_log_level': str(logging.INFO), - 'sagemaker_job_name': json.dumps(job_name), - 'sagemaker_s3_output': '"s3://{}/"'.format(BUCKET_NAME), - 'sagemaker_submit_directory': - json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, job_name)), - 'sagemaker_region': '"us-west-2"' - }, - 'stop_condition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'tags': None, - 'vpc_config': None, - 'metric_definitions': [ - {'Name': 'reward-training', - 'Regex': '^Training>.*Total reward=(.*?),'}, - {'Name': 
'reward-testing', - 'Regex': '^Testing>.*Total reward=(.*?),'} - ] + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "tags": None, + "vpc_config": None, + "metric_definitions": [ + {"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"}, + {"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"}, + ], } def test_create_tf_model(sagemaker_session, rl_coach_tf_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - toolkit=RLToolkit.COACH, toolkit_version=rl_coach_tf_version, - framework=RLFramework.TENSORFLOW, container_log_level=container_log_level, - source_dir=source_dir) - - job_name = 'new_name' - rl.fit(inputs='s3://mybucket/train', job_name='new_name') + source_dir = "s3://mybucket/source" + rl = RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + toolkit=RLToolkit.COACH, + toolkit_version=rl_coach_tf_version, + framework=RLFramework.TENSORFLOW, + container_log_level=container_log_level, + source_dir=source_dir, + ) + + job_name = "new_name" + rl.fit(inputs="s3://mybucket/train", job_name="new_name") model = rl.create_model() supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value] framework_version = supported_versions[rl_coach_tf_version][RLFramework.TENSORFLOW.value] @@ -170,15 +182,22 @@ def test_create_tf_model(sagemaker_session, rl_coach_tf_version): def test_create_mxnet_model(sagemaker_session, rl_coach_mxnet_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version, - framework=RLFramework.MXNET, container_log_level=container_log_level, - source_dir=source_dir) - - job_name = 'new_name' - rl.fit(inputs='s3://mybucket/train', job_name='new_name') + source_dir = "s3://mybucket/source" + rl = RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + toolkit=RLToolkit.COACH, + toolkit_version=rl_coach_mxnet_version, + framework=RLFramework.MXNET, + container_log_level=container_log_level, + source_dir=source_dir, + ) + + job_name = "new_name" + rl.fit(inputs="s3://mybucket/train", job_name="new_name") model = rl.create_model() supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value] framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value] @@ -197,20 +216,28 @@ def test_create_mxnet_model(sagemaker_session, rl_coach_mxnet_version): def test_create_model_with_optional_params(sagemaker_session, rl_coach_mxnet_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version, - framework=RLFramework.MXNET, container_log_level=container_log_level, - source_dir=source_dir) - - rl.fit(job_name='new_name') - - new_role = 'role' - new_entry_point = 
'deploy_script.py' - vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - model = rl.create_model(role=new_role, entry_point=new_entry_point, - vpc_config_override=vpc_config) + source_dir = "s3://mybucket/source" + rl = RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + toolkit=RLToolkit.COACH, + toolkit_version=rl_coach_mxnet_version, + framework=RLFramework.MXNET, + container_log_level=container_log_level, + source_dir=source_dir, + ) + + rl.fit(job_name="new_name") + + new_role = "role" + new_entry_point = "deploy_script.py" + vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + model = rl.create_model( + role=new_role, entry_point=new_entry_point, vpc_config_override=vpc_config + ) assert model.role == new_role assert model.vpc_config == vpc_config @@ -219,16 +246,22 @@ def test_create_model_with_optional_params(sagemaker_session, rl_coach_mxnet_ver def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - image = 'selfdrivingcars:9000' - rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - image_name=image, container_log_level=container_log_level, - source_dir=source_dir) - - job_name = 'new_name' + source_dir = "s3://mybucket/source" + image = "selfdrivingcars:9000" + rl = RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + image_name=image, + container_log_level=container_log_level, + source_dir=source_dir, + ) + + job_name = "new_name" rl.fit(job_name=job_name) - new_entry_point = 'deploy_script.py' + new_entry_point = "deploy_script.py" model = rl.create_model(entry_point=new_entry_point) assert model.sagemaker_session == sagemaker_session @@ -240,26 +273,33 @@ def test_create_model_with_custom_image(sagemaker_session): assert model.source_dir == source_dir -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('time.strftime', return_value=TIMESTAMP) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) def test_rl(strftime, sagemaker_session, rl_coach_mxnet_version): - rl = RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - toolkit=RLToolkit.COACH, toolkit_version=rl_coach_mxnet_version, - framework=RLFramework.MXNET) - - inputs = 's3://mybucket/train' + rl = RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + toolkit=RLToolkit.COACH, + toolkit_version=rl_coach_mxnet_version, + framework=RLFramework.MXNET, + ) + + inputs = "s3://mybucket/train" rl.fit(inputs=inputs) sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] - assert sagemaker_call_names == ['train', 'logs_for_job'] + assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ['resource'] + assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(RLToolkit.COACH.value, rl_coach_mxnet_version, - RLFramework.MXNET.value) - 
expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args = _create_train_job( + RLToolkit.COACH.value, rl_coach_mxnet_version, RLFramework.MXNET.value + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args @@ -268,232 +308,301 @@ def test_rl(strftime, sagemaker_session, rl_coach_mxnet_version): supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value] framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value] - expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py3' - submit_dir = 's3://notmybucket/sagemaker-rl-mxnet-{}/source/sourcedir.tar.gz'.format(TIMESTAMP) - assert {'Environment': {'SAGEMAKER_SUBMIT_DIRECTORY': submit_dir, - 'SAGEMAKER_PROGRAM': 'dummy_script.py', - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false', - 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'}, - 'Image': expected_image_base.format(framework_version), - 'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU) + expected_image_base = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py3" + submit_dir = "s3://notmybucket/sagemaker-rl-mxnet-{}/source/sourcedir.tar.gz".format(TIMESTAMP) + assert { + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": submit_dir, + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + }, + "Image": expected_image_base.format(framework_version), + "ModelDataUrl": "s3://m/m.tar.gz", + } == model.prepare_container_def(GPU) - assert 'cpu' in model.prepare_container_def(CPU)['Image'] + assert "cpu" in model.prepare_container_def(CPU)["Image"] -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_deploy_mxnet(sagemaker_session, rl_coach_mxnet_version): - rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_mxnet_version, RLFramework.MXNET, - train_instance_type='ml.g2.2xlarge') + rl = _rl_estimator( + sagemaker_session, + RLToolkit.COACH, + rl_coach_mxnet_version, + RLFramework.MXNET, + train_instance_type="ml.g2.2xlarge", + ) rl.fit() predictor = rl.deploy(1, CPU) assert isinstance(predictor, MXNetPredictor) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_deploy_tfs(sagemaker_session, rl_coach_tf_version): - rl = _rl_estimator(sagemaker_session, RLToolkit.COACH, rl_coach_tf_version, RLFramework.TENSORFLOW, - train_instance_type='ml.g2.2xlarge') + rl = _rl_estimator( + sagemaker_session, + RLToolkit.COACH, + rl_coach_tf_version, + RLFramework.TENSORFLOW, + train_instance_type="ml.g2.2xlarge", + ) rl.fit() predictor = rl.deploy(1, GPU) assert isinstance(predictor, tfs.Predictor) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_deploy_ray(sagemaker_session, rl_ray_version): - rl = _rl_estimator(sagemaker_session, RLToolkit.RAY, rl_ray_version, RLFramework.TENSORFLOW, - train_instance_type='ml.g2.2xlarge') + rl = _rl_estimator( + sagemaker_session, + RLToolkit.RAY, + rl_ray_version, + RLFramework.TENSORFLOW, + train_instance_type="ml.g2.2xlarge", + ) rl.fit() with pytest.raises(NotImplementedError) as e: rl.deploy(1, GPU) - assert 
'deployment of Ray models is not currently available' in str(e.value) + assert "deployment of Ray models is not currently available" in str(e.value) def test_train_image_cpu_instances(sagemaker_session, rl_ray_version): toolkit = RLToolkit.RAY framework = RLFramework.TENSORFLOW - rl = _rl_estimator(sagemaker_session, toolkit, rl_ray_version, framework, - train_instance_type='ml.c2.2xlarge') - assert rl.train_image() == _get_full_cpu_image_uri(toolkit.value, rl_ray_version, - framework.value) - - rl = _rl_estimator(sagemaker_session, toolkit, rl_ray_version, framework, - train_instance_type='ml.c4.2xlarge') - assert rl.train_image() == _get_full_cpu_image_uri(toolkit.value, rl_ray_version, - framework.value) - - rl = _rl_estimator(sagemaker_session, toolkit, rl_ray_version, framework, - train_instance_type='ml.m16') - assert rl.train_image() == _get_full_cpu_image_uri(toolkit.value, rl_ray_version, - framework.value) + rl = _rl_estimator( + sagemaker_session, toolkit, rl_ray_version, framework, train_instance_type="ml.c2.2xlarge" + ) + assert rl.train_image() == _get_full_cpu_image_uri( + toolkit.value, rl_ray_version, framework.value + ) + + rl = _rl_estimator( + sagemaker_session, toolkit, rl_ray_version, framework, train_instance_type="ml.c4.2xlarge" + ) + assert rl.train_image() == _get_full_cpu_image_uri( + toolkit.value, rl_ray_version, framework.value + ) + + rl = _rl_estimator( + sagemaker_session, toolkit, rl_ray_version, framework, train_instance_type="ml.m16" + ) + assert rl.train_image() == _get_full_cpu_image_uri( + toolkit.value, rl_ray_version, framework.value + ) def test_train_image_gpu_instances(sagemaker_session, rl_coach_mxnet_version): toolkit = RLToolkit.COACH framework = RLFramework.MXNET - rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_mxnet_version, framework, - train_instance_type='ml.g2.2xlarge') - assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_mxnet_version, - framework.value) - - rl = _rl_estimator(sagemaker_session, toolkit, rl_coach_mxnet_version, framework, - train_instance_type='ml.p2.2xlarge') - assert rl.train_image() == _get_full_gpu_image_uri(toolkit.value, rl_coach_mxnet_version, - framework.value) + rl = _rl_estimator( + sagemaker_session, + toolkit, + rl_coach_mxnet_version, + framework, + train_instance_type="ml.g2.2xlarge", + ) + assert rl.train_image() == _get_full_gpu_image_uri( + toolkit.value, rl_coach_mxnet_version, framework.value + ) + + rl = _rl_estimator( + sagemaker_session, + toolkit, + rl_coach_mxnet_version, + framework, + train_instance_type="ml.p2.2xlarge", + ) + assert rl.train_image() == _get_full_gpu_image_uri( + toolkit.value, rl_coach_mxnet_version, framework.value + ) def test_attach(sagemaker_session, rl_coach_mxnet_version): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-{}:{}{}-cpu-py3'\ - .format(RLFramework.MXNET.value, RLToolkit.COACH.value, rl_coach_mxnet_version) + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-{}:{}{}-cpu-py3".format( + RLFramework.MXNET.value, RLToolkit.COACH.value, rl_coach_mxnet_version + ) supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP[RLToolkit.COACH.value] framework_version = supported_versions[rl_coach_mxnet_version][RLFramework.MXNET.value] - returned_job_description = {'AlgorithmSpecification': {'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"train_coach.py"', - 
'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': { - 'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = \ - Mock(name='describe_training_job', return_value=returned_job_description) - - estimator = RLEstimator.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"train_coach.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = RLEstimator.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" assert estimator.framework == RLFramework.MXNET.value assert estimator.toolkit == RLToolkit.COACH.value assert estimator.framework_version == framework_version assert estimator.toolkit_version == rl_coach_mxnet_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'train_coach.py' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "train_coach.py" assert estimator.metric_definitions == RLEstimator.default_metric_definitions(RLToolkit.COACH) def 
test_attach_wrong_framework(sagemaker_session): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4' - rjd = {'AlgorithmSpecification': {'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=rjd) + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0.4" + rjd = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) with pytest.raises(ValueError) as error: - RLEstimator.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + RLEstimator.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert "didn't use image for requested framework" in str(error) def test_attach_custom_image(sagemaker_session): - training_image = 'rl:latest' - returned_job_description = {'AlgorithmSpecification': {'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': - {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'sagemaker_region': '"us-west-2"'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 
'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': - {'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': - {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = \ - Mock(name='describe_training_job', return_value=returned_job_description) - - estimator = RLEstimator.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' + training_image = "rl:latest" + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "sagemaker_region": '"us-west-2"', + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = RLEstimator.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" assert estimator.image_name == training_image assert estimator.train_image() == training_image def test_wrong_framework_format(sagemaker_session): with pytest.raises(ValueError) as e: - RLEstimator(toolkit=RLToolkit.RAY, framework='TF', - toolkit_version=RLEstimator.RAY_LATEST_VERSION, - entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=None) + RLEstimator( + toolkit=RLToolkit.RAY, + framework="TF", + toolkit_version=RLEstimator.RAY_LATEST_VERSION, + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=None, + ) - assert 'Invalid type' in str(e.value) + assert "Invalid type" in str(e.value) def test_wrong_toolkit_format(sagemaker_session): with pytest.raises(ValueError) as e: - RLEstimator(toolkit='coach', framework=RLFramework.TENSORFLOW, - toolkit_version=RLEstimator.COACH_LATEST_VERSION_TF, - entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=None) + RLEstimator( + toolkit="coach", + framework=RLFramework.TENSORFLOW, + toolkit_version=RLEstimator.COACH_LATEST_VERSION_TF, + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=None, + ) - assert 'Invalid type' in str(e.value) + assert "Invalid type" in str(e.value) def 
test_missing_required_parameters(sagemaker_session): with pytest.raises(AttributeError) as e: - RLEstimator(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) - assert 'Please provide `toolkit`, `toolkit_version`, `framework`' + \ - ' or `image_name` parameter.' in str(e.value) + RLEstimator( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + assert ( + "Please provide `toolkit`, `toolkit_version`, `framework`" + " or `image_name` parameter." + in str(e.value) + ) def test_wrong_type_parameters(sagemaker_session): with pytest.raises(AttributeError) as e: - RLEstimator(toolkit=RLToolkit.COACH, framework=RLFramework.TENSORFLOW, - toolkit_version=RLEstimator.RAY_LATEST_VERSION, - entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) - assert 'combination is not supported.' in str(e.value) + RLEstimator( + toolkit=RLToolkit.COACH, + framework=RLFramework.TENSORFLOW, + toolkit_version=RLEstimator.RAY_LATEST_VERSION, + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) + assert "combination is not supported." in str(e.value) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index cc428692d4..e401622353 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -26,11 +26,11 @@ from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done from sagemaker.tuner import WarmStartConfig, WarmStartTypes -STATIC_HPs = {"feature_dim": "784", } +STATIC_HPs = {"feature_dim": "784"} SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}] -REGION = 'us-west-2' +REGION = "us-west-2" @pytest.fixture() @@ -38,8 +38,9 @@ def boto_session(): boto_session = Mock(region_name=REGION) mock_client = Mock() - mock_client._client_config.user_agent = \ - 'Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource' + mock_client._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) boto_session.client.return_value = mock_client return boto_session @@ -47,68 +48,74 @@ def boto_session(): def test_get_execution_role(): session = Mock() - session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:role/SageMakerRole' + session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:role/SageMakerRole" actual = get_execution_role(session) - assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' + assert actual == "arn:aws:iam::369233609183:role/SageMakerRole" def test_get_execution_role_works_with_service_role(): session = Mock() - session.get_caller_identity_arn.return_value = \ - 'arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388' + session.get_caller_identity_arn.return_value = ( + "arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388" + ) actual = get_execution_role(session) - assert actual == 'arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388' + assert ( + actual + == "arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388" + ) def 
test_get_execution_role_throws_exception_if_arn_is_not_role(): session = Mock() - session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:user/marcos' + session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:user/marcos" with pytest.raises(ValueError) as error: get_execution_role(session) - assert 'ValueError: The current AWS identity is not a role' in str(error) + assert "ValueError: The current AWS identity is not a role" in str(error) def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_name(): session = Mock() - session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:user/marcos-role' + session.get_caller_identity_arn.return_value = "arn:aws:iam::369233609183:user/marcos-role" with pytest.raises(ValueError) as error: get_execution_role(session) - assert 'ValueError: The current AWS identity is not a role' in str(error) + assert "ValueError: The current AWS identity is not a role" in str(error) def test_get_caller_identity_arn_from_an_user(boto_session): sess = Session(boto_session) - arn = 'arn:aws:iam::369233609183:user/mia' - sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn} - sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': arn}} + arn = "arn:aws:iam::369233609183:user/mia" + sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn} + sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}} actual = sess.get_caller_identity_arn() - assert actual == 'arn:aws:iam::369233609183:user/mia' + assert actual == "arn:aws:iam::369233609183:user/mia" def test_get_caller_identity_arn_from_an_user_without_permissions(boto_session): sess = Session(boto_session) - arn = 'arn:aws:iam::369233609183:user/mia' - sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn} - sess.boto_session.client('iam').get_role.side_effect = ClientError({}, {}) + arn = "arn:aws:iam::369233609183:user/mia" + sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn} + sess.boto_session.client("iam").get_role.side_effect = ClientError({}, {}) - with patch('logging.Logger.warning') as mock_logger: + with patch("logging.Logger.warning") as mock_logger: actual = sess.get_caller_identity_arn() - assert actual == 'arn:aws:iam::369233609183:user/mia' + assert actual == "arn:aws:iam::369233609183:user/mia" mock_logger.assert_called_once() def test_get_caller_identity_arn_from_a_role(boto_session): sess = Session(boto_session) - arn = 'arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122' - sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn} + arn = ( + "arn:aws:sts::369233609183:assumed-role/SageMakerRole/6d009ef3-5306-49d5-8efc-78db644d8122" + ) + sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn} - expected_role = 'arn:aws:iam::369233609183:role/SageMakerRole' - sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': expected_role}} + expected_role = "arn:aws:iam::369233609183:role/SageMakerRole" + sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": expected_role}} actual = sess.get_caller_identity_arn() assert actual == expected_role @@ -116,23 +123,28 @@ def test_get_caller_identity_arn_from_a_role(boto_session): def test_get_caller_identity_arn_from_a_execution_role(boto_session): sess = Session(boto_session) - arn = 
'arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker' - sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': arn} - sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': arn}} + arn = "arn:aws:sts::369233609183:assumed-role/AmazonSageMaker-ExecutionRole-20171129T072388/SageMaker" + sess.boto_session.client("sts").get_caller_identity.return_value = {"Arn": arn} + sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": arn}} actual = sess.get_caller_identity_arn() - assert actual == 'arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388' + assert ( + actual + == "arn:aws:iam::369233609183:role/service-role/AmazonSageMaker-ExecutionRole-20171129T072388" + ) def test_get_caller_identity_arn_from_role_with_path(boto_session): sess = Session(boto_session) - arn_prefix = 'arn:aws:iam::369233609183:role' - role_name = 'name' - sess.boto_session.client('sts').get_caller_identity.return_value = {'Arn': '/'.join([arn_prefix, role_name])} + arn_prefix = "arn:aws:iam::369233609183:role" + role_name = "name" + sess.boto_session.client("sts").get_caller_identity.return_value = { + "Arn": "/".join([arn_prefix, role_name]) + } - role_path = 'path' - role_with_path = '/'.join([arn_prefix, role_path, role_name]) - sess.boto_session.client('iam').get_role.return_value = {'Role': {'Arn': role_with_path}} + role_path = "path" + role_with_path = "/".join([arn_prefix, role_path, role_name]) + sess.boto_session.client("iam").get_role.return_value = {"Role": {"Arn": role_with_path}} actual = sess.get_caller_identity_arn() assert actual == role_with_path @@ -140,75 +152,91 @@ def test_get_caller_identity_arn_from_role_with_path(boto_session): def test_delete_endpoint(boto_session): sess = Session(boto_session) - sess.delete_endpoint('my_endpoint') + sess.delete_endpoint("my_endpoint") - boto_session.client().delete_endpoint.assert_called_with(EndpointName='my_endpoint') + boto_session.client().delete_endpoint.assert_called_with(EndpointName="my_endpoint") def test_delete_endpoint_config(boto_session): sess = Session(boto_session) - sess.delete_endpoint_config('my_endpoint_config') + sess.delete_endpoint_config("my_endpoint_config") - boto_session.client().delete_endpoint_config.assert_called_with(EndpointConfigName='my_endpoint_config') + boto_session.client().delete_endpoint_config.assert_called_with( + EndpointConfigName="my_endpoint_config" + ) def test_delete_model(boto_session): sess = Session(boto_session) - model_name = 'my_model' + model_name = "my_model" sess.delete_model(model_name) boto_session.client().delete_model.assert_called_with(ModelName=model_name) def test_user_agent_injected(boto_session): - assert 'AWS-SageMaker-Python-SDK' not in boto_session.client('sagemaker')._client_config.user_agent + assert ( + "AWS-SageMaker-Python-SDK" not in boto_session.client("sagemaker")._client_config.user_agent + ) sess = Session(boto_session) - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_runtime_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' not in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' not in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in 
sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_runtime_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi(boto_session): - assert 'AWS-SageMaker-Python-SDK' not in boto_session.client('sagemaker')._client_config.user_agent + assert ( + "AWS-SageMaker-Python-SDK" not in boto_session.client("sagemaker")._client_config.user_agent + ) - with patch('six.moves.builtins.open', mock_open(read_data='120.0-0')) as mo: + with patch("six.moves.builtins.open", mock_open(read_data="120.0-0")) as mo: sess = Session(boto_session) - mo.assert_called_with('/etc/opt/ml/sagemaker-notebook-instance-version.txt') + mo.assert_called_with("/etc/opt/ml/sagemaker-notebook-instance-version.txt") - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_runtime_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent + assert ( + "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi_ioerror(boto_session): - assert 'AWS-SageMaker-Python-SDK' not in boto_session.client('sagemaker')._client_config.user_agent + assert ( + "AWS-SageMaker-Python-SDK" not in boto_session.client("sagemaker")._client_config.user_agent + ) - with patch('six.moves.builtins.open', MagicMock(side_effect=IOError('File not found'))) as mo: + with patch("six.moves.builtins.open", MagicMock(side_effect=IOError("File not found"))) as mo: sess = Session(boto_session) - mo.assert_called_with('/etc/opt/ml/sagemaker-notebook-instance-version.txt') + mo.assert_called_with("/etc/opt/ml/sagemaker-notebook-instance-version.txt") - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Python-SDK' in sess.sagemaker_runtime_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' not in sess.sagemaker_client._client_config.user_agent - assert 'AWS-SageMaker-Notebook-Instance' not in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_runtime_client._client_config.user_agent + ) def test_s3_input_all_defaults(): - prefix = 'pre' + prefix = "pre" actual = s3_input(s3_data=prefix) expected = { - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': prefix + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": prefix, } } } @@ -216,442 +244,547 @@ def test_s3_input_all_defaults(): def 
test_s3_input_all_arguments(): - prefix = 'pre' - distribution = 'FullyReplicated' - compression = 'Gzip' - content_type = 'text/csv' - record_wrapping = 'RecordIO' - s3_data_type = 'Manifestfile' - input_mode = 'Pipe' - result = s3_input(s3_data=prefix, distribution=distribution, compression=compression, input_mode=input_mode, - content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type) - expected = \ - {'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': distribution, - 'S3DataType': s3_data_type, - 'S3Uri': prefix, + prefix = "pre" + distribution = "FullyReplicated" + compression = "Gzip" + content_type = "text/csv" + record_wrapping = "RecordIO" + s3_data_type = "Manifestfile" + input_mode = "Pipe" + result = s3_input( + s3_data=prefix, + distribution=distribution, + compression=compression, + input_mode=input_mode, + content_type=content_type, + record_wrapping=record_wrapping, + s3_data_type=s3_data_type, + ) + expected = { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": distribution, + "S3DataType": s3_data_type, + "S3Uri": prefix, } }, - 'CompressionType': compression, - 'ContentType': content_type, - 'RecordWrapperType': record_wrapping, - 'InputMode': input_mode - } + "CompressionType": compression, + "ContentType": content_type, + "RecordWrapperType": record_wrapping, + "InputMode": input_mode, + } assert result.config == expected -IMAGE = 'myimage' -S3_INPUT_URI = 's3://mybucket/data' -S3_OUTPUT = 's3://sagemaker-123/output/jobname' -ROLE = 'SageMakerRole' -EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole' +IMAGE = "myimage" +S3_INPUT_URI = "s3://mybucket/data" +S3_OUTPUT = "s3://sagemaker-123/output/jobname" +ROLE = "SageMakerRole" +EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' +INSTANCE_TYPE = "ml.c4.xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" MAX_SIZE = 30 MAX_TIME = 3 * 60 * 60 -JOB_NAME = 'jobname' -TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] -VPC_CONFIG = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} -METRIC_DEFINITONS = [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}] +JOB_NAME = "jobname" +TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] +VPC_CONFIG = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} +METRIC_DEFINITONS = [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}] DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { - 'OutputDataConfig': { - 'S3OutputPath': S3_OUTPUT + "OutputDataConfig": {"S3OutputPath": S3_OUTPUT}, + "RoleArn": EXPANDED_ROLE, + "ResourceConfig": { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, }, - 'RoleArn': EXPANDED_ROLE, - 'ResourceConfig': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeSizeInGB': MAX_SIZE - }, - 'InputDataConfig': [ + "InputDataConfig": [ { - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': S3_INPUT_URI + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": S3_INPUT_URI, } }, - 'ChannelName': 'training' + "ChannelName": "training", } ], - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': IMAGE - }, - 'TrainingJobName': JOB_NAME, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': MAX_TIME - }, - 'VpcConfig': VPC_CONFIG, + "AlgorithmSpecification": {"TrainingInputMode": 
"File", "TrainingImage": IMAGE}, + "TrainingJobName": JOB_NAME, + "StoppingCondition": {"MaxRuntimeInSeconds": MAX_TIME}, + "VpcConfig": VPC_CONFIG, } COMPLETED_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS) -COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME}) -COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Completed'}) COMPLETED_DESCRIBE_JOB_RESULT.update( - {'ModelArtifacts': { - 'S3ModelArtifacts': S3_OUTPUT + '/model/model.tar.gz' - }}) + {"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/" + JOB_NAME} +) +COMPLETED_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "Completed"}) +COMPLETED_DESCRIBE_JOB_RESULT.update( + {"ModelArtifacts": {"S3ModelArtifacts": S3_OUTPUT + "/model/model.tar.gz"}} +) # TrainingStartTime and TrainingEndTime are for billable seconds calculation COMPLETED_DESCRIBE_JOB_RESULT.update( - {'TrainingStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000)}) + {"TrainingStartTime": datetime.datetime(2018, 2, 17, 7, 15, 0, 103000)} +) COMPLETED_DESCRIBE_JOB_RESULT.update( - {'TrainingEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000)}) + {"TrainingEndTime": datetime.datetime(2018, 2, 17, 7, 19, 34, 953000)} +) STOPPED_DESCRIBE_JOB_RESULT = dict(COMPLETED_DESCRIBE_JOB_RESULT) -STOPPED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Stopped'}) +STOPPED_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "Stopped"}) IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS) -IN_PROGRESS_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'InProgress'}) +IN_PROGRESS_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "InProgress"}) @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session') - boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} + boto_mock = Mock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"} ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) ims.expand_role = Mock(return_value=EXPANDED_ROLE) return ims def test_train_pack_to_request(sagemaker_session): - in_config = [{ - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': S3_INPUT_URI - } + in_config = [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": S3_INPUT_URI, + } + }, } - }] - - out_config = {'S3OutputPath': S3_OUTPUT} + ] - resource_config = {'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeSizeInGB': MAX_SIZE} + out_config = {"S3OutputPath": S3_OUTPUT} - stop_cond = {'MaxRuntimeInSeconds': MAX_TIME} + resource_config = { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, + } - sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, - job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=VPC_CONFIG, - metric_definitions=None) + stop_cond = {"MaxRuntimeInSeconds": MAX_TIME} + + sagemaker_session.train( + image=IMAGE, + input_mode="File", + input_config=in_config, + role=EXPANDED_ROLE, + job_name=JOB_NAME, + output_config=out_config, + resource_config=resource_config, + hyperparameters=None, + stop_condition=stop_cond, + tags=None, + 
vpc_config=VPC_CONFIG, + metric_definitions=None, + ) assert sagemaker_session.sagemaker_client.method_calls[0] == ( - 'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS) + "create_training_job", + (), + DEFAULT_EXPECTED_TRAIN_JOB_ARGS, + ) -SAMPLE_STOPPING_CONDITION = {'MaxRuntimeInSeconds': MAX_TIME} +SAMPLE_STOPPING_CONDITION = {"MaxRuntimeInSeconds": MAX_TIME} -RESOURCE_CONFIG = {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeSizeInGB': MAX_SIZE} +RESOURCE_CONFIG = { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, +} -SAMPLE_INPUT = [{'DataSource': { - 'S3DataSource': {'S3DataDistributionType': 'FullyReplicated', 'S3DataType': 'S3Prefix', 'S3Uri': S3_INPUT_URI}}, - 'ChannelName': 'training'}] +SAMPLE_INPUT = [ + { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": S3_INPUT_URI, + } + }, + "ChannelName": "training", + } +] -SAMPLE_OUTPUT = {'S3OutputPath': S3_OUTPUT} +SAMPLE_OUTPUT = {"S3OutputPath": S3_OUTPUT} -SAMPLE_OBJECTIVE = {'Type': "Maximize", 'MetricName': "val-score", } +SAMPLE_OBJECTIVE = {"Type": "Maximize", "MetricName": "val-score"} SAMPLE_METRIC_DEF = [{"Name": "train:progress", "Regex": "regex-1"}] SAMPLE_TUNING_JOB_REQUEST = { - 'HyperParameterTuningJobName': 'dummy-tuning-1', - 'HyperParameterTuningJobConfig': { - 'Strategy': "Bayesian", - 'HyperParameterTuningJobObjective': SAMPLE_OBJECTIVE, - 'ResourceLimits': { - 'MaxNumberOfTrainingJobs': 100, - 'MaxParallelTrainingJobs': 5, - }, - 'ParameterRanges': SAMPLE_PARAM_RANGES, - 'TrainingJobEarlyStoppingType': 'Off' + "HyperParameterTuningJobName": "dummy-tuning-1", + "HyperParameterTuningJobConfig": { + "Strategy": "Bayesian", + "HyperParameterTuningJobObjective": SAMPLE_OBJECTIVE, + "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, + "ParameterRanges": SAMPLE_PARAM_RANGES, + "TrainingJobEarlyStoppingType": "Off", }, - 'TrainingJobDefinition': { - 'StaticHyperParameters': STATIC_HPs, - 'AlgorithmSpecification': { - 'TrainingImage': "dummy-image-1", - 'TrainingInputMode': "File", - 'MetricDefinitions': SAMPLE_METRIC_DEF + "TrainingJobDefinition": { + "StaticHyperParameters": STATIC_HPs, + "AlgorithmSpecification": { + "TrainingImage": "dummy-image-1", + "TrainingInputMode": "File", + "MetricDefinitions": SAMPLE_METRIC_DEF, }, - 'RoleArn': EXPANDED_ROLE, - 'InputDataConfig': SAMPLE_INPUT, - 'OutputDataConfig': SAMPLE_OUTPUT, - - 'ResourceConfig': RESOURCE_CONFIG, - 'StoppingCondition': SAMPLE_STOPPING_CONDITION - } + "RoleArn": EXPANDED_ROLE, + "InputDataConfig": SAMPLE_INPUT, + "OutputDataConfig": SAMPLE_OUTPUT, + "ResourceConfig": RESOURCE_CONFIG, + "StoppingCondition": SAMPLE_STOPPING_CONDITION, + }, } -@pytest.mark.parametrize('warm_start_type, parents', [ - ("IdenticalDataAndAlgorithm", {"p1", "p2", "p3"}), - ("TransferLearning", {"p1", "p2", "p3"}), -]) +@pytest.mark.parametrize( + "warm_start_type, parents", + [("IdenticalDataAndAlgorithm", {"p1", "p2", "p3"}), ("TransferLearning", {"p1", "p2", "p3"})], +) def test_tune_warm_start(sagemaker_session, warm_start_type, parents): - def assert_create_tuning_job_request(**kwrags): - assert kwrags["HyperParameterTuningJobConfig"] == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + assert ( + kwrags["HyperParameterTuningJobConfig"] + == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + ) assert kwrags["HyperParameterTuningJobName"] == "dummy-tuning-1" assert 
kwrags["TrainingJobDefinition"] == SAMPLE_TUNING_JOB_REQUEST["TrainingJobDefinition"] assert kwrags["WarmStartConfig"] == { - 'WarmStartType': warm_start_type, - 'ParentHyperParameterTuningJobs': [{'HyperParameterTuningJobName': parent} for parent in parents] + "WarmStartType": warm_start_type, + "ParentHyperParameterTuningJobs": [ + {"HyperParameterTuningJobName": parent} for parent in parents + ], } - sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = assert_create_tuning_job_request - sagemaker_session.tune(job_name="dummy-tuning-1", - strategy="Bayesian", - objective_type="Maximize", - objective_metric_name="val-score", - max_jobs=100, - max_parallel_jobs=5, - parameter_ranges=SAMPLE_PARAM_RANGES, - static_hyperparameters=STATIC_HPs, - image="dummy-image-1", - input_mode="File", - metric_definitions=SAMPLE_METRIC_DEF, - role=EXPANDED_ROLE, - input_config=SAMPLE_INPUT, - output_config=SAMPLE_OUTPUT, - resource_config=RESOURCE_CONFIG, - stop_condition=SAMPLE_STOPPING_CONDITION, - tags=None, - warm_start_config=WarmStartConfig(warm_start_type=WarmStartTypes(warm_start_type), - parents=parents).to_input_req()) + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=WarmStartConfig( + warm_start_type=WarmStartTypes(warm_start_type), parents=parents + ).to_input_req(), + ) def test_tune(sagemaker_session): - def assert_create_tuning_job_request(**kwrags): - assert kwrags["HyperParameterTuningJobConfig"] == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + assert ( + kwrags["HyperParameterTuningJobConfig"] + == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + ) assert kwrags["HyperParameterTuningJobName"] == "dummy-tuning-1" assert kwrags["TrainingJobDefinition"] == SAMPLE_TUNING_JOB_REQUEST["TrainingJobDefinition"] assert kwrags.get("WarmStartConfig", None) is None - sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = assert_create_tuning_job_request - sagemaker_session.tune(job_name="dummy-tuning-1", - strategy="Bayesian", - objective_type="Maximize", - objective_metric_name="val-score", - max_jobs=100, - max_parallel_jobs=5, - parameter_ranges=SAMPLE_PARAM_RANGES, - static_hyperparameters=STATIC_HPs, - image="dummy-image-1", - input_mode="File", - metric_definitions=SAMPLE_METRIC_DEF, - role=EXPANDED_ROLE, - input_config=SAMPLE_INPUT, - output_config=SAMPLE_OUTPUT, - resource_config=RESOURCE_CONFIG, - stop_condition=SAMPLE_STOPPING_CONDITION, - tags=None, - warm_start_config=None) + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image="dummy-image-1", + input_mode="File", + 
metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + ) def test_tune_with_encryption_flag(sagemaker_session): - def assert_create_tuning_job_request(**kwrags): - assert kwrags["HyperParameterTuningJobConfig"] == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + assert ( + kwrags["HyperParameterTuningJobConfig"] + == SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"] + ) assert kwrags["HyperParameterTuningJobName"] == "dummy-tuning-1" assert kwrags["TrainingJobDefinition"]["EnableInterContainerTrafficEncryption"] is True assert kwrags.get("WarmStartConfig", None) is None - sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = assert_create_tuning_job_request - sagemaker_session.tune(job_name="dummy-tuning-1", - strategy="Bayesian", - objective_type="Maximize", - objective_metric_name="val-score", - max_jobs=100, - max_parallel_jobs=5, - parameter_ranges=SAMPLE_PARAM_RANGES, - static_hyperparameters=STATIC_HPs, - image="dummy-image-1", - input_mode="File", - metric_definitions=SAMPLE_METRIC_DEF, - role=EXPANDED_ROLE, - input_config=SAMPLE_INPUT, - output_config=SAMPLE_OUTPUT, - resource_config=RESOURCE_CONFIG, - stop_condition=SAMPLE_STOPPING_CONDITION, - tags=None, - warm_start_config=None, - encrypt_inter_container_traffic=True) + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + encrypt_inter_container_traffic=True, + ) def test_stop_tuning_job(sagemaker_session): sms = sagemaker_session - sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(name='stop_hyper_parameter_tuning_job') + sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock( + name="stop_hyper_parameter_tuning_job" + ) sagemaker_session.stop_tuning_job(JOB_NAME) - sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME) + sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with( + HyperParameterTuningJobName=JOB_NAME + ) def test_stop_tuning_job_client_error_already_stopped(sagemaker_session): sms = sagemaker_session - exception = ClientError({'Error': {'Code': 'ValidationException'}}, 'Operation') - sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(name='stop_hyper_parameter_tuning_job', - side_effect=exception) + exception = ClientError({"Error": {"Code": "ValidationException"}}, "Operation") + sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock( + name="stop_hyper_parameter_tuning_job", side_effect=exception + ) sagemaker_session.stop_tuning_job(JOB_NAME) - sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME) + sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with( + 
HyperParameterTuningJobName=JOB_NAME + ) def test_stop_tuning_job_client_error(sagemaker_session): - error_response = {'Error': {'Code': 'MockException', 'Message': 'MockMessage'}} - operation = 'Operation' + error_response = {"Error": {"Code": "MockException", "Message": "MockMessage"}} + operation = "Operation" exception = ClientError(error_response, operation) sms = sagemaker_session - sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(name='stop_hyper_parameter_tuning_job', - side_effect=exception) + sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock( + name="stop_hyper_parameter_tuning_job", side_effect=exception + ) with pytest.raises(ClientError) as e: sagemaker_session.stop_tuning_job(JOB_NAME) - sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME) - assert 'An error occurred (MockException) when calling the Operation operation: MockMessage' in str(e) + sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with( + HyperParameterTuningJobName=JOB_NAME + ) + assert ( + "An error occurred (MockException) when calling the Operation operation: MockMessage" + in str(e) + ) def test_train_pack_to_request_with_optional_params(sagemaker_session): - in_config = [{ - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri': S3_INPUT_URI - } + in_config = [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + "S3Uri": S3_INPUT_URI, + } + }, } - }] - - out_config = {'S3OutputPath': S3_OUTPUT} + ] - resource_config = {'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeSizeInGB': MAX_SIZE} + out_config = {"S3OutputPath": S3_OUTPUT} - stop_cond = {'MaxRuntimeInSeconds': MAX_TIME} - hyperparameters = {'foo': 'bar'} + resource_config = { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": MAX_SIZE, + } - sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, - job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, - metric_definitions=METRIC_DEFINITONS, encrypt_inter_container_traffic=True) + stop_cond = {"MaxRuntimeInSeconds": MAX_TIME} + hyperparameters = {"foo": "bar"} + + sagemaker_session.train( + image=IMAGE, + input_mode="File", + input_config=in_config, + role=EXPANDED_ROLE, + job_name=JOB_NAME, + output_config=out_config, + resource_config=resource_config, + vpc_config=VPC_CONFIG, + hyperparameters=hyperparameters, + stop_condition=stop_cond, + tags=TAGS, + metric_definitions=METRIC_DEFINITONS, + encrypt_inter_container_traffic=True, + ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] - assert actual_train_args['VpcConfig'] == VPC_CONFIG - assert actual_train_args['HyperParameters'] == hyperparameters - assert actual_train_args['Tags'] == TAGS - assert actual_train_args['AlgorithmSpecification']['MetricDefinitions'] == METRIC_DEFINITONS - assert actual_train_args['EnableInterContainerTrafficEncryption'] is True + assert actual_train_args["VpcConfig"] == VPC_CONFIG + assert actual_train_args["HyperParameters"] == hyperparameters + assert actual_train_args["Tags"] == TAGS + assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS 
+    assert actual_train_args["EnableInterContainerTrafficEncryption"] is True
 
 
 def test_transform_pack_to_request(sagemaker_session):
-    model_name = 'my-model'
+    model_name = "my-model"
 
     in_config = {
-        'CompressionType': 'None',
-        'ContentType': 'text/csv',
-        'SplitType': 'None',
-        'DataSource': {
-            'S3DataSource': {
-                'S3DataType': 'S3Prefix',
-                'S3Uri': S3_INPUT_URI,
-            },
-        },
+        "CompressionType": "None",
+        "ContentType": "text/csv",
+        "SplitType": "None",
+        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}},
     }
 
-    out_config = {'S3OutputPath': S3_OUTPUT}
+    out_config = {"S3OutputPath": S3_OUTPUT}
 
-    resource_config = {
-        'InstanceCount': INSTANCE_COUNT,
-        'InstanceType': INSTANCE_TYPE,
-    }
+    resource_config = {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE}
 
     expected_args = {
-        'TransformJobName': JOB_NAME,
-        'ModelName': model_name,
-        'TransformInput': in_config,
-        'TransformOutput': out_config,
-        'TransformResources': resource_config,
+        "TransformJobName": JOB_NAME,
+        "ModelName": model_name,
+        "TransformInput": in_config,
+        "TransformOutput": out_config,
+        "TransformResources": resource_config,
     }
 
-    sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None,
-                                max_payload=None, env=None, input_config=in_config, output_config=out_config,
-                                resource_config=resource_config, tags=None, data_processing=None)
+    sagemaker_session.transform(
+        job_name=JOB_NAME,
+        model_name=model_name,
+        strategy=None,
+        max_concurrent_transforms=None,
+        max_payload=None,
+        env=None,
+        input_config=in_config,
+        output_config=out_config,
+        resource_config=resource_config,
+        tags=None,
+        data_processing=None,
+    )
 
     _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
     assert actual_args == expected_args
 
 
 def test_transform_pack_to_request_with_optional_params(sagemaker_session):
-    strategy = 'strategy'
+    strategy = "strategy"
     max_concurrent_transforms = 1
     max_payload = 0
-    env = {'FOO': 'BAR'}
-
-    sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy,
-                                max_concurrent_transforms=max_concurrent_transforms,
-                                env=env, max_payload=max_payload, input_config={}, output_config={},
-                                resource_config={}, tags=TAGS, data_processing=None)
+    env = {"FOO": "BAR"}
+
+    sagemaker_session.transform(
+        job_name=JOB_NAME,
+        model_name="my-model",
+        strategy=strategy,
+        max_concurrent_transforms=max_concurrent_transforms,
+        env=env,
+        max_payload=max_payload,
+        input_config={},
+        output_config={},
+        resource_config={},
+        tags=TAGS,
+        data_processing=None,
+    )
 
     _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
-    assert actual_args['BatchStrategy'] == strategy
-    assert actual_args['MaxConcurrentTransforms'] == max_concurrent_transforms
-    assert actual_args['MaxPayloadInMB'] == max_payload
-    assert actual_args['Environment'] == env
-    assert actual_args['Tags'] == TAGS
+    assert actual_args["BatchStrategy"] == strategy
+    assert actual_args["MaxConcurrentTransforms"] == max_concurrent_transforms
+    assert actual_args["MaxPayloadInMB"] == max_payload
+    assert actual_args["Environment"] == env
+    assert actual_args["Tags"] == TAGS
 
 
-@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
+@patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO)
 def test_color_wrap(bio):
     color_wrap = sagemaker.logs.ColorWrap()
-    color_wrap(0, 'hi there')
-    assert bio.getvalue() == 'hi there\n'
+    color_wrap(0, "hi there")
+    assert bio.getvalue() == "hi there\n"
 
 
 class MockBotoException(ClientError):
     def __init__(self, code):
-        self.response = {'Error': {'Code': code}}
-
-
-DEFAULT_LOG_STREAMS = {'logStreams': [{'logStreamName': JOB_NAME + '/xxxxxxxxx'}]}
-LIFECYCLE_LOG_STREAMS = [MockBotoException('ResourceNotFoundException'),
-                         DEFAULT_LOG_STREAMS,
-                         DEFAULT_LOG_STREAMS,
-                         DEFAULT_LOG_STREAMS,
-                         DEFAULT_LOG_STREAMS,
-                         DEFAULT_LOG_STREAMS,
-                         DEFAULT_LOG_STREAMS]
-
-DEFAULT_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]},
-                      {'nextForwardToken': None, 'events': []}]
-STREAM_LOG_EVENTS = [{'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'}]},
-                     {'nextForwardToken': None, 'events': []},
-                     {'nextForwardToken': None, 'events': [{'timestamp': 1, 'message': 'hi there #1'},
-                                                           {'timestamp': 2, 'message': 'hi there #2'}]},
-                     {'nextForwardToken': None, 'events': []},
-                     {'nextForwardToken': None, 'events': [{'timestamp': 2, 'message': 'hi there #2'},
-                                                           {'timestamp': 2, 'message': 'hi there #2a'},
-                                                           {'timestamp': 3, 'message': 'hi there #3'}]},
-                     {'nextForwardToken': None, 'events': []}]
+        self.response = {"Error": {"Code": code}}
+
+
+DEFAULT_LOG_STREAMS = {"logStreams": [{"logStreamName": JOB_NAME + "/xxxxxxxxx"}]}
+LIFECYCLE_LOG_STREAMS = [
+    MockBotoException("ResourceNotFoundException"),
+    DEFAULT_LOG_STREAMS,
+    DEFAULT_LOG_STREAMS,
+    DEFAULT_LOG_STREAMS,
+    DEFAULT_LOG_STREAMS,
+    DEFAULT_LOG_STREAMS,
+    DEFAULT_LOG_STREAMS,
+]
+
+DEFAULT_LOG_EVENTS = [
+    {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},
+    {"nextForwardToken": None, "events": []},
+]
+STREAM_LOG_EVENTS = [
+    {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},
+    {"nextForwardToken": None, "events": []},
+    {
+        "nextForwardToken": None,
+        "events": [
+            {"timestamp": 1, "message": "hi there #1"},
+            {"timestamp": 2, "message": "hi there #2"},
+        ],
+    },
+    {"nextForwardToken": None, "events": []},
+    {
+        "nextForwardToken": None,
+        "events": [
+            {"timestamp": 2, "message": "hi there #2"},
+            {"timestamp": 2, "message": "hi there #2a"},
+            {"timestamp": 3, "message": "hi there #3"},
+        ],
+    },
+    {"nextForwardToken": None, "events": []},
+]
 
 
 @pytest.fixture()
 def sagemaker_session_complete():
-    boto_mock = Mock(name='boto_session')
-    boto_mock.client('logs').describe_log_streams.return_value = DEFAULT_LOG_STREAMS
-    boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
+    boto_mock = Mock(name="boto_session")
+    boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS
+    boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS
     ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
     return ims
@@ -659,9 +792,9 @@ def sagemaker_session_complete():
 
 @pytest.fixture()
 def sagemaker_session_stopped():
-    boto_mock = Mock(name='boto_session')
-    boto_mock.client('logs').describe_log_streams.return_value = DEFAULT_LOG_STREAMS
-    boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
+    boto_mock = Mock(name="boto_session")
+    boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS
+    boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS
     ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
     ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
     return ims
@@ -669,188 +802,214 @@ def sagemaker_session_stopped():
 
 @pytest.fixture()
 def sagemaker_session_ready_lifecycle():
-    boto_mock = Mock(name='boto_session')
-    boto_mock.client('logs').describe_log_streams.return_value = DEFAULT_LOG_STREAMS
-    boto_mock.client('logs').get_log_events.side_effect = STREAM_LOG_EVENTS
+    boto_mock = Mock(name="boto_session")
+    boto_mock.client("logs").describe_log_streams.return_value = DEFAULT_LOG_STREAMS
+    boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS
     ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
-    ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
-                                                              IN_PROGRESS_DESCRIBE_JOB_RESULT,
-                                                              COMPLETED_DESCRIBE_JOB_RESULT]
+    ims.sagemaker_client.describe_training_job.side_effect = [
+        IN_PROGRESS_DESCRIBE_JOB_RESULT,
+        IN_PROGRESS_DESCRIBE_JOB_RESULT,
+        COMPLETED_DESCRIBE_JOB_RESULT,
+    ]
     return ims
 
 
 @pytest.fixture()
 def sagemaker_session_full_lifecycle():
-    boto_mock = Mock(name='boto_session')
-    boto_mock.client('logs').describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
-    boto_mock.client('logs').get_log_events.side_effect = STREAM_LOG_EVENTS
+    boto_mock = Mock(name="boto_session")
+    boto_mock.client("logs").describe_log_streams.side_effect = LIFECYCLE_LOG_STREAMS
+    boto_mock.client("logs").get_log_events.side_effect = STREAM_LOG_EVENTS
     ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
-    ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
-                                                              IN_PROGRESS_DESCRIBE_JOB_RESULT,
-                                                              COMPLETED_DESCRIBE_JOB_RESULT]
+    ims.sagemaker_client.describe_training_job.side_effect = [
+        IN_PROGRESS_DESCRIBE_JOB_RESULT,
+        IN_PROGRESS_DESCRIBE_JOB_RESULT,
+        COMPLETED_DESCRIBE_JOB_RESULT,
+    ]
     return ims
 
 
-@patch('sagemaker.logs.ColorWrap')
+@patch("sagemaker.logs.ColorWrap")
 def test_logs_for_job_no_wait(cw, sagemaker_session_complete):
     ims = sagemaker_session_complete
     ims.logs_for_job(JOB_NAME)
     ims.sagemaker_client.describe_training_job.assert_called_once_with(TrainingJobName=JOB_NAME)
-    cw().assert_called_with(0, 'hi there #1')
+    cw().assert_called_with(0, "hi there #1")
 
 
-@patch('sagemaker.logs.ColorWrap')
+@patch("sagemaker.logs.ColorWrap")
 def test_logs_for_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
     ims = sagemaker_session_stopped
     ims.logs_for_job(JOB_NAME)
     ims.sagemaker_client.describe_training_job.assert_called_once_with(TrainingJobName=JOB_NAME)
-    cw().assert_called_with(0, 'hi there #1')
+    cw().assert_called_with(0, "hi there #1")
 
 
-@patch('sagemaker.logs.ColorWrap')
+@patch("sagemaker.logs.ColorWrap")
 def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
     ims = sagemaker_session_complete
     ims.logs_for_job(JOB_NAME, wait=True, poll=0)
-    assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)]
-    cw().assert_called_with(0, 'hi there #1')
+    assert ims.sagemaker_client.describe_training_job.call_args_list == [
+        call(TrainingJobName=JOB_NAME)
+    ]
+    cw().assert_called_with(0, "hi there #1")
 
 
-@patch('sagemaker.logs.ColorWrap')
+@patch("sagemaker.logs.ColorWrap")
 def test_logs_for_job_wait_on_stopped(cw, sagemaker_session_stopped):
     ims = sagemaker_session_stopped
     ims.logs_for_job(JOB_NAME, wait=True, poll=0)
-    assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)]
-    cw().assert_called_with(0, 'hi there #1')
+    assert ims.sagemaker_client.describe_training_job.call_args_list == [
+        call(TrainingJobName=JOB_NAME)
+    ]
+    cw().assert_called_with(0, "hi there #1")
 
 
-@patch('sagemaker.logs.ColorWrap')
+@patch("sagemaker.logs.ColorWrap")
 def test_logs_for_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
     ims = sagemaker_session_ready_lifecycle
     ims.logs_for_job(JOB_NAME)
-    assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)]
-    cw().assert_called_with(0, 'hi there #1')
+    assert ims.sagemaker_client.describe_training_job.call_args_list == [
+        call(TrainingJobName=JOB_NAME)
+    ]
+    cw().assert_called_with(0, "hi there #1")
 
 
-@patch('sagemaker.logs.ColorWrap')
-@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180])
+@patch("sagemaker.logs.ColorWrap")
+@patch("time.time", side_effect=[0, 30, 60, 90, 120, 150, 180])
 def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle):
     ims = sagemaker_session_full_lifecycle
     ims.logs_for_job(JOB_NAME, wait=True, poll=0)
-    assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)] * 3
-    assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'),
-                                   call(0, 'hi there #2a'), call(0, 'hi there #3')]
-
-
-MODEL_NAME = 'some-model'
+    assert (
+        ims.sagemaker_client.describe_training_job.call_args_list
+        == [call(TrainingJobName=JOB_NAME)] * 3
+    )
+    assert cw().call_args_list == [
+        call(0, "hi there #1"),
+        call(0, "hi there #2"),
+        call(0, "hi there #2a"),
+        call(0, "hi there #3"),
+    ]
+
+
+MODEL_NAME = "some-model"
 PRIMARY_CONTAINER = {
-    'Environment': {},
-    'Image': IMAGE,
-    'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
+    "Environment": {},
+    "Image": IMAGE,
+    "ModelDataUrl": "s3://sagemaker-123/output/jobname/model/model.tar.gz",
 }
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model(expand_container_def, sagemaker_session):
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
 
     assert model == MODEL_NAME
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       PrimaryContainer=PRIMARY_CONTAINER)
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_with_tags(expand_container_def, sagemaker_session):
-    tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
+    tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER, tags=tags)
 
     assert model == MODEL_NAME
-    tags = [{'Value': 'TagtestValue', 'Key': 'TagtestKey'}]
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       PrimaryContainer=PRIMARY_CONTAINER,
-                                                                       Tags=tags)
+    tags = [{"Value": "TagtestValue", "Key": "TagtestKey"}]
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE,
+        ModelName=MODEL_NAME,
+        PrimaryContainer=PRIMARY_CONTAINER,
+        Tags=tags,
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_with_primary_container(expand_container_def, sagemaker_session):
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, container_defs=PRIMARY_CONTAINER)
 
     assert model == MODEL_NAME
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       PrimaryContainer=PRIMARY_CONTAINER)
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_with_both(expand_container_def, sagemaker_session):
     with pytest.raises(ValueError):
-        sagemaker_session.create_model(MODEL_NAME, ROLE, container_defs=PRIMARY_CONTAINER,
-                                       primary_container=PRIMARY_CONTAINER)
+        sagemaker_session.create_model(
+            MODEL_NAME, ROLE, container_defs=PRIMARY_CONTAINER, primary_container=PRIMARY_CONTAINER
+        )
 
 
 CONTAINERS = [
     {
-        'Environment': {'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': 'application/json'},
-        'Image': 'mi-1',
-        'ModelDataUrl': 's3://bucket/model_1.tar.gz'
+        "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json"},
+        "Image": "mi-1",
+        "ModelDataUrl": "s3://bucket/model_1.tar.gz",
     },
-    {
-        'Environment': {},
-        'Image': 'mi-2',
-        'ModelDataUrl': 's3://bucket/model_2.tar.gz'
-    }
+    {"Environment": {}, "Image": "mi-2", "ModelDataUrl": "s3://bucket/model_2.tar.gz"},
 ]
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_pipeline_model(expand_container_def, sagemaker_session):
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, container_defs=CONTAINERS)
 
     assert model == MODEL_NAME
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       Containers=CONTAINERS)
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE, ModelName=MODEL_NAME, Containers=CONTAINERS
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_vpc_config(expand_container_def, sagemaker_session):
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER, VPC_CONFIG)
 
     assert model == MODEL_NAME
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       PrimaryContainer=PRIMARY_CONTAINER,
-                                                                       VpcConfig=VPC_CONFIG)
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE,
+        ModelName=MODEL_NAME,
+        PrimaryContainer=PRIMARY_CONTAINER,
+        VpcConfig=VPC_CONFIG,
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_pipeline_model_vpc_config(expand_container_def, sagemaker_session):
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, CONTAINERS, VPC_CONFIG)
 
     assert model == MODEL_NAME
-    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                                       ModelName=MODEL_NAME,
-                                                                       Containers=CONTAINERS,
-                                                                       VpcConfig=VPC_CONFIG)
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE,
+        ModelName=MODEL_NAME,
+        Containers=CONTAINERS,
+        VpcConfig=VPC_CONFIG,
+    )
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog):
-    error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}}
-    exception = ClientError(error_response, 'Operation')
+    error_response = {
+        "Error": {"Code": "ValidationException", "Message": "Cannot create already existing model"}
+    }
+    exception = ClientError(error_response, "Operation")
 
     sagemaker_session.sagemaker_client.create_model.side_effect = exception
 
     model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
     assert model == MODEL_NAME
 
-    expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME))
+    expected_warning = (
+        "sagemaker",
+        logging.WARNING,
+        "Using already existing model: {}".format(MODEL_NAME),
+    )
     assert expected_warning in caplog.record_tuples
 
 
-@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+@patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER)
 def test_create_model_failure(expand_container_def, sagemaker_session):
-    error_message = 'this is expected'
+    error_message = "this is expected"
     sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message)
 
     with pytest.raises(RuntimeError) as e:
@@ -864,11 +1023,15 @@ def test_create_model_from_job(sagemaker_session):
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
     ims.create_model_from_job(JOB_NAME)
 
-    assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
-    ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                         ModelName=JOB_NAME,
-                                                         PrimaryContainer=PRIMARY_CONTAINER,
-                                                         VpcConfig=VPC_CONFIG)
+    assert (
+        call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
+    )
+    ims.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE,
+        ModelName=JOB_NAME,
+        PrimaryContainer=PRIMARY_CONTAINER,
+        VpcConfig=VPC_CONFIG,
+    )
 
 
 def test_create_model_from_job_with_tags(sagemaker_session):
@@ -876,109 +1039,127 @@ def test_create_model_from_job_with_tags(sagemaker_session):
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
     ims.create_model_from_job(JOB_NAME, tags=TAGS)
 
-    assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
-    ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
-                                                         ModelName=JOB_NAME,
-                                                         PrimaryContainer=PRIMARY_CONTAINER,
-                                                         VpcConfig=VPC_CONFIG,
-                                                         Tags=TAGS)
+    assert (
+        call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
+    )
+    ims.sagemaker_client.create_model.assert_called_with(
+        ExecutionRoleArn=EXPANDED_ROLE,
+        ModelName=JOB_NAME,
+        PrimaryContainer=PRIMARY_CONTAINER,
+        VpcConfig=VPC_CONFIG,
+        Tags=TAGS,
+    )
 
 
 def test_create_model_from_job_with_image(sagemaker_session):
     ims = sagemaker_session
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
-    ims.create_model_from_job(JOB_NAME, primary_container_image='some-image')
+    ims.create_model_from_job(JOB_NAME, primary_container_image="some-image")
     [create_model_call] = ims.sagemaker_client.create_model.call_args_list
-    assert dict(create_model_call[1]['PrimaryContainer'])['Image'] == 'some-image'
+    assert dict(create_model_call[1]["PrimaryContainer"])["Image"] == "some-image"
 
 
 def test_create_model_from_job_with_container_def(sagemaker_session):
     ims = sagemaker_session
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
-    ims.create_model_from_job(JOB_NAME, primary_container_image='some-image', model_data_url='some-data',
-                              env={'a': 'b'})
+    ims.create_model_from_job(
+        JOB_NAME, primary_container_image="some-image", model_data_url="some-data", env={"a": "b"}
+    )
     [create_model_call] = ims.sagemaker_client.create_model.call_args_list
-    c_def = create_model_call[1]['PrimaryContainer']
-    assert c_def['Image'] == 'some-image'
-    assert c_def['ModelDataUrl'] == 'some-data'
-    assert c_def['Environment'] == {'a': 'b'}
+    c_def = create_model_call[1]["PrimaryContainer"]
+    assert c_def["Image"] == "some-image"
+    assert c_def["ModelDataUrl"] == "some-data"
+    assert c_def["Environment"] == {"a": "b"}
 
 
 def test_create_model_from_job_with_vpc_config_override(sagemaker_session):
-    vpc_config_override = {'Subnets': ['foo', 'bar'], 'SecurityGroupIds': ['baz']}
+    vpc_config_override = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]}
    ims = sagemaker_session
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
 
     ims.create_model_from_job(JOB_NAME, vpc_config_override=vpc_config_override)
-    assert ims.sagemaker_client.create_model.call_args[1]['VpcConfig'] == vpc_config_override
+    assert ims.sagemaker_client.create_model.call_args[1]["VpcConfig"] == vpc_config_override
 
     ims.create_model_from_job(JOB_NAME, vpc_config_override=None)
-    assert 'VpcConfig' not in ims.sagemaker_client.create_model.call_args[1]
+    assert "VpcConfig" not in ims.sagemaker_client.create_model.call_args[1]
 
 
 def test_endpoint_from_production_variants(sagemaker_session):
     ims = sagemaker_session
-    ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
-    pvs = [sagemaker.production_variant('A', 'ml.p2.xlarge'), sagemaker.production_variant('B', 'p299.4096xlarge')]
-    ex = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find your thing'}}, 'b')
+    ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"})
+    pvs = [
+        sagemaker.production_variant("A", "ml.p2.xlarge"),
+        sagemaker.production_variant("B", "p299.4096xlarge"),
+    ]
+    ex = ClientError(
+        {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b"
+    )
     ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
-    sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
-    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
-                                                                          EndpointName='some-endpoint',
-                                                                          Tags=[])
+    sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs)
+    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
+        EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=[]
+    )
     sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
-        EndpointConfigName='some-endpoint',
-        ProductionVariants=pvs)
+        EndpointConfigName="some-endpoint", ProductionVariants=pvs
+    )
 
 
 def test_create_endpoint_config_with_tags(sagemaker_session):
-    tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
+    tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
 
-    sagemaker_session.create_endpoint_config('endpoint-test', 'simple-model', 1, 'local', tags=tags)
+    sagemaker_session.create_endpoint_config("endpoint-test", "simple-model", 1, "local", tags=tags)
 
     sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
-        EndpointConfigName='endpoint-test',
-        ProductionVariants=ANY,
-        Tags=tags)
+        EndpointConfigName="endpoint-test", ProductionVariants=ANY, Tags=tags
+    )
 
 
 def test_endpoint_from_production_variants_with_tags(sagemaker_session):
     ims = sagemaker_session
-    ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
-    pvs = [sagemaker.production_variant('A', 'ml.p2.xlarge'), sagemaker.production_variant('B', 'p299.4096xlarge')]
-    ex = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find your thing'}}, 'b')
+    ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"})
+    pvs = [
+        sagemaker.production_variant("A", "ml.p2.xlarge"),
+        sagemaker.production_variant("B", "p299.4096xlarge"),
+    ]
+    ex = ClientError(
+        {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b"
+    )
     ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
-    tags = [{'ModelName': 'TestModel'}]
-    sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
-    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
-                                                                          EndpointName='some-endpoint',
-                                                                          Tags=tags)
+    tags = [{"ModelName": "TestModel"}]
+    sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs, tags)
+    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
+        EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=tags
+    )
     sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
-        EndpointConfigName='some-endpoint',
-        ProductionVariants=pvs,
-        Tags=tags)
+        EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=tags
+    )
 
 
 def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_session):
     ims = sagemaker_session
-    ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
-    pvs = [sagemaker.production_variant('A', 'ml.p2.xlarge', accelerator_type=ACCELERATOR_TYPE),
-           sagemaker.production_variant('B', 'p299.4096xlarge', accelerator_type=ACCELERATOR_TYPE)]
-    ex = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find your thing'}}, 'b')
+    ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"})
+    pvs = [
+        sagemaker.production_variant("A", "ml.p2.xlarge", accelerator_type=ACCELERATOR_TYPE),
+        sagemaker.production_variant("B", "p299.4096xlarge", accelerator_type=ACCELERATOR_TYPE),
+    ]
+    ex = ClientError(
+        {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b"
+    )
     ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
-    tags = [{'ModelName': 'TestModel'}]
-    sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
-    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
-                                                                          EndpointName='some-endpoint',
-                                                                          Tags=tags)
+    tags = [{"ModelName": "TestModel"}]
+    sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs, tags)
+    sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(
+        EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=tags
+    )
     sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
-        EndpointConfigName='some-endpoint',
-        ProductionVariants=pvs,
-        Tags=tags)
+        EndpointConfigName="some-endpoint", ProductionVariants=pvs, Tags=tags
+    )
 
 
 def test_update_endpoint_succeed(sagemaker_session):
-    sagemaker_session.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
+    sagemaker_session.sagemaker_client.describe_endpoint = Mock(
+        return_value={"EndpointStatus": "InService"}
+    )
     endpoint_name = "some-endpoint"
     endpoint_config = "some-endpoint-config"
     returned_endpoint_name = sagemaker_session.update_endpoint(endpoint_name, endpoint_config)
@@ -986,100 +1167,119 @@ def test_update_endpoint_succeed(sagemaker_session):
 
 
 def test_update_endpoint_non_existing_endpoint(sagemaker_session):
-    error = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find entity'}}, 'foo')
-    expected_error_message = 'Endpoint with name "non-existing-endpoint" does not exist; ' \
-                             'please use an existing endpoint name'
+    error = ClientError(
+        {"Error": {"Code": "ValidationException", "Message": "Could not find entity"}}, "foo"
+    )
+    expected_error_message = (
+        'Endpoint with name "non-existing-endpoint" does not exist; '
+        "please use an existing endpoint name"
+    )
     sagemaker_session.sagemaker_client.describe_endpoint = Mock(side_effect=error)
 
     with pytest.raises(ValueError, match=expected_error_message):
         sagemaker_session.update_endpoint("non-existing-endpoint", "non-existing-config")
 
 
-@patch('time.sleep')
+@patch("time.sleep")
 def test_wait_for_tuning_job(sleep, sagemaker_session):
-    hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
+    hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
-        name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
+        name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc
+    )
 
     result = sagemaker_session.wait_for_tuning_job(JOB_NAME)
-    assert result['HyperParameterTuningJobStatus'] == 'Completed'
+    assert result["HyperParameterTuningJobStatus"] == "Completed"
 
 
 def test_tune_job_status(sagemaker_session):
-    hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
+    hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
-        name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
+        name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc
+    )
 
     result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
 
-    assert result['HyperParameterTuningJobStatus'] == 'Completed'
+    assert result["HyperParameterTuningJobStatus"] == "Completed"
 
 
 def test_tune_job_status_none(sagemaker_session):
-    hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'InProgress'}
+    hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "InProgress"}
     sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
-        name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
+        name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc
+    )
 
     result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
 
     assert result is None
 
 
-@patch('time.sleep')
+@patch("time.sleep")
 def test_wait_for_transform_job_completed(sleep, sagemaker_session):
-    transform_job_desc = {'TransformJobStatus': 'Completed'}
+    transform_job_desc = {"TransformJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_transform_job = Mock(
-        name='describe_transform_job', return_value=transform_job_desc)
+        name="describe_transform_job", return_value=transform_job_desc
+    )
 
-    assert sagemaker_session.wait_for_transform_job(JOB_NAME)['TransformJobStatus'] == 'Completed'
+    assert sagemaker_session.wait_for_transform_job(JOB_NAME)["TransformJobStatus"] == "Completed"
 
 
-@patch('time.sleep')
+@patch("time.sleep")
 def test_wait_for_transform_job_in_progress(sleep, sagemaker_session):
-    transform_job_desc_in_progress = {'TransformJobStatus': 'InProgress'}
-    transform_job_desc_in_completed = {'TransformJobStatus': 'Completed'}
+    transform_job_desc_in_progress = {"TransformJobStatus": "InProgress"}
+    transform_job_desc_in_completed = {"TransformJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_transform_job = Mock(
-        name='describe_transform_job', side_effect=[transform_job_desc_in_progress,
-                                                    transform_job_desc_in_completed])
+        name="describe_transform_job",
+        side_effect=[transform_job_desc_in_progress, transform_job_desc_in_completed],
+    )
 
-    assert sagemaker_session.wait_for_transform_job(JOB_NAME, 1)['TransformJobStatus'] == 'Completed'
+    assert (
+        sagemaker_session.wait_for_transform_job(JOB_NAME, 1)["TransformJobStatus"] == "Completed"
+    )
     assert 2 == sagemaker_session.sagemaker_client.describe_transform_job.call_count
 
 
 def test_transform_job_status(sagemaker_session):
-    transform_job_desc = {'TransformJobStatus': 'Completed'}
+    transform_job_desc = {"TransformJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_transform_job = Mock(
-        name='describe_transform_job', return_value=transform_job_desc)
+        name="describe_transform_job", return_value=transform_job_desc
+    )
 
     result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
-    assert result['TransformJobStatus'] == 'Completed'
+    assert result["TransformJobStatus"] == "Completed"
 
 
 def test_transform_job_status_none(sagemaker_session):
-    transform_job_desc = {'TransformJobStatus': 'InProgress'}
+    transform_job_desc = {"TransformJobStatus": "InProgress"}
     sagemaker_session.sagemaker_client.describe_transform_job = Mock(
-        name='describe_transform_job', return_value=transform_job_desc)
+        name="describe_transform_job", return_value=transform_job_desc
+    )
 
     result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
     assert result is None
 
 
 def test_train_done_completed(sagemaker_session):
-    training_job_desc = {'TrainingJobStatus': 'Completed'}
+    training_job_desc = {"TrainingJobStatus": "Completed"}
     sagemaker_session.sagemaker_client.describe_training_job = Mock(
-        name='describe_training_job', return_value=training_job_desc)
+        name="describe_training_job", return_value=training_job_desc
+    )
 
-    actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)
+    actual_job_desc, training_finished = _train_done(
+        sagemaker_session.sagemaker_client, JOB_NAME, None
+    )
 
-    assert actual_job_desc['TrainingJobStatus'] == 'Completed'
+    assert actual_job_desc["TrainingJobStatus"] == "Completed"
     assert training_finished is True
 
 
 def test_train_done_in_progress(sagemaker_session):
-    training_job_desc = {'TrainingJobStatus': 'InProgress'}
+    training_job_desc = {"TrainingJobStatus": "InProgress"}
     sagemaker_session.sagemaker_client.describe_training_job = Mock(
-        name='describe_training_job', return_value=training_job_desc)
+        name="describe_training_job", return_value=training_job_desc
+    )
 
-    actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)
+    actual_job_desc, training_finished = _train_done(
+        sagemaker_session.sagemaker_client, JOB_NAME, None
+    )
 
-    assert actual_job_desc['TrainingJobStatus'] == 'InProgress'
+    assert actual_job_desc["TrainingJobStatus"] == "InProgress"
     assert training_finished is False
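A note on the pattern the test_session.py fixtures above all share: a Mock stands in for the boto session, and because a Mock returns the same child object for repeated identical calls, configuring boto_mock.client("logs") up front is enough to control what Session.logs_for_job later reads. A minimal sketch of the wiring, with illustrative job names and responses rather than values taken from this patch:

    # Sketch: how the fixtures drive logs_for_job through canned responses.
    # `from mock import Mock` matches the import these tests already use;
    # unittest.mock behaves identically on Python 3.
    from mock import Mock

    import sagemaker

    boto_mock = Mock(name="boto_session")
    # Mock memoizes attribute/call chains, so this is the exact client object
    # that sagemaker.Session will receive from boto_session.client("logs").
    boto_mock.client("logs").describe_log_streams.return_value = {
        "logStreams": [{"logStreamName": "some-job/xxxxxxxxx"}]
    }
    # side_effect yields one response per call, simulating a stream that
    # produces a single event and then runs dry.
    boto_mock.client("logs").get_log_events.side_effect = [
        {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},
        {"nextForwardToken": None, "events": []},
    ]
    session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
    # The fixtures then point describe_training_job at a canned job
    # description and hand `session` to logs_for_job, as in the tests above.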
diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py
index da5ac483e6..88f3a524e7 100644
--- a/tests/unit/test_sklearn.py
+++ b/tests/unit/test_sklearn.py
@@ -26,144 +26,167 @@
 from sagemaker.sklearn import SKLearnPredictor, SKLearnModel
 
-DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
-SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py')
-TIMESTAMP = '2017-11-06-14:14:15.672'
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
+TIMESTAMP = "2017-11-06-14:14:15.672"
 TIME = 1507167947
-BUCKET_NAME = 'mybucket'
+BUCKET_NAME = "mybucket"
 INSTANCE_COUNT = 1
 DIST_INSTANCE_COUNT = 2
-INSTANCE_TYPE = 'ml.c4.4xlarge'
+INSTANCE_TYPE = "ml.c4.4xlarge"
 GPU_INSTANCE_TYPE = "ml.p2.xlarge"
-PYTHON_VERSION = 'py3'
-IMAGE_NAME = 'sagemaker-scikit-learn'
-JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
+PYTHON_VERSION = "py3"
+IMAGE_NAME = "sagemaker-scikit-learn"
+JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP)
 IMAGE_URI_FORMAT_STRING = "246618743249.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}"
-ROLE = 'Dummy'
-REGION = 'us-west-2'
-CPU = 'ml.c4.xlarge'
+ROLE = "Dummy"
+REGION = "us-west-2"
+CPU = "ml.c4.xlarge"
 
-ENDPOINT_DESC = {
-    'EndpointConfigName': 'test-endpoint'
-}
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}
 
-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
 
-LIST_TAGS_RESULT = {
-    'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
-}
+LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
 
 
 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    session = Mock(name='sagemaker_session', boto_session=boto_mock,
-                   boto_region_name=REGION, config=None, local_mode=False)
-
-    describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    session = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        boto_region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
+
+    describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
     session.sagemaker_client.describe_training_job = Mock(return_value=describe)
     session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
     session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
-    session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
+    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
     session.expand_role = Mock(name="expand_role", return_value=ROLE)
     return session
 
 
 def _get_full_cpu_image_uri(version):
-    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, 'cpu', PYTHON_VERSION)
-
-
-def _sklearn_estimator(sagemaker_session, framework_version=defaults.SKLEARN_VERSION, train_instance_type=None,
-                       base_job_name=None, **kwargs):
-    return SKLearn(entry_point=SCRIPT_PATH,
-                   framework_version=framework_version,
-                   role=ROLE,
-                   sagemaker_session=sagemaker_session,
-                   train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
-                   base_job_name=base_job_name,
-                   py_version=PYTHON_VERSION,
-                   **kwargs)
+    return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", PYTHON_VERSION)
+
+
+def _sklearn_estimator(
+    sagemaker_session,
+    framework_version=defaults.SKLEARN_VERSION,
+    train_instance_type=None,
+    base_job_name=None,
+    **kwargs
+):
+    return SKLearn(
+        entry_point=SCRIPT_PATH,
+        framework_version=framework_version,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
+        base_job_name=base_job_name,
+        py_version=PYTHON_VERSION,
+        **kwargs
+    )
 
 
 def _create_train_job(version):
     return {
-        'image': _get_full_cpu_image_uri(version),
-        'input_mode': 'File',
-        'input_config': [
+        "image": _get_full_cpu_image_uri(version),
+        "input_mode": "File",
+        "input_config": [
             {
-                'ChannelName': 'training',
-                'DataSource': {
-                    'S3DataSource': {
-                        'S3DataDistributionType': 'FullyReplicated',
-                        'S3DataType': 'S3Prefix'
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataDistributionType": "FullyReplicated",
+                        "S3DataType": "S3Prefix",
                     }
-                }
+                },
             }
         ],
-        'role': ROLE,
-        'job_name': JOB_NAME,
-        'output_config': {
-            'S3OutputPath': 's3://{}/'.format(BUCKET_NAME),
+        "role": ROLE,
+        "job_name": JOB_NAME,
+        "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+        "resource_config": {
+            "InstanceType": "ml.c4.4xlarge",
+            "InstanceCount": 1,
+            "VolumeSizeInGB": 30,
         },
-        'resource_config': {
-            'InstanceType': 'ml.c4.4xlarge',
-            'InstanceCount': 1,
-            'VolumeSizeInGB': 30,
+        "hyperparameters": {
+            "sagemaker_program": json.dumps("dummy_script.py"),
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": str(logging.INFO),
+            "sagemaker_job_name": json.dumps(JOB_NAME),
+            "sagemaker_submit_directory": json.dumps(
+                "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME)
+            ),
+            "sagemaker_region": '"us-west-2"',
         },
-        'hyperparameters': {
-            'sagemaker_program': json.dumps('dummy_script.py'),
-            'sagemaker_enable_cloudwatch_metrics': 'false',
-            'sagemaker_container_log_level': str(logging.INFO),
-            'sagemaker_job_name': json.dumps(JOB_NAME),
-            'sagemaker_submit_directory':
-                json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME)),
-            'sagemaker_region': '"us-west-2"'
-        },
-        'stop_condition': {
-            'MaxRuntimeInSeconds': 24 * 60 * 60
-        },
-        'metric_definitions': None,
-        'tags': None,
-        'vpc_config': None
+        "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "metric_definitions": None,
+        "tags": None,
+        "vpc_config": None,
     }
 
 
 def test_train_image(sagemaker_session, sklearn_version):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, framework_version=sklearn_version,
-                      container_log_level=container_log_level, py_version=PYTHON_VERSION,
-                      base_job_name='job', source_dir=source_dir)
+    source_dir = "s3://mybucket/source"
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        framework_version=sklearn_version,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+    )
 
     train_image = sklearn.train_image()
-    assert train_image == '246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3'
+    assert (
+        train_image
+        == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3"
+    )
 
 
 def test_create_model(sagemaker_session):
-    source_dir = 's3://mybucket/source'
-
-    sklearn_model = SKLearnModel(model_data=source_dir, role=ROLE, sagemaker_session=sagemaker_session,
-                                 entry_point=SCRIPT_PATH)
-    default_image_uri = _get_full_cpu_image_uri('0.20.0')
+    source_dir = "s3://mybucket/source"
+
+    sklearn_model = SKLearnModel(
+        model_data=source_dir,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        entry_point=SCRIPT_PATH,
+    )
+    default_image_uri = _get_full_cpu_image_uri("0.20.0")
     model_values = sklearn_model.prepare_container_def(CPU)
-    assert model_values['Image'] == default_image_uri
+    assert model_values["Image"] == default_image_uri
 
 
 def test_create_model_from_estimator(sagemaker_session, sklearn_version):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, framework_version=sklearn_version,
-                      container_log_level=container_log_level, py_version=PYTHON_VERSION,
-                      base_job_name='job', source_dir=source_dir)
-
-    job_name = 'new_name'
-    sklearn.fit(inputs='s3://mybucket/train', job_name=job_name)
+    source_dir = "s3://mybucket/source"
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        framework_version=sklearn_version,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+    )
+
+    job_name = "new_name"
+    sklearn.fit(inputs="s3://mybucket/train", job_name=job_name)
     model = sklearn.create_model()
 
     assert model.sagemaker_session == sagemaker_session
@@ -179,20 +202,28 @@ def test_create_model_from_estimator(sagemaker_session, sklearn_version):
 
 def test_create_model_with_optional_params(sagemaker_session):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    enable_cloudwatch_metrics = 'true'
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level,
-                      py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir,
-                      enable_cloudwatch_metrics=enable_cloudwatch_metrics)
-
-    sklearn.fit(inputs='s3://mybucket/train', job_name='new_name')
-
-    new_role = 'role'
+    source_dir = "s3://mybucket/source"
+    enable_cloudwatch_metrics = "true"
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+        enable_cloudwatch_metrics=enable_cloudwatch_metrics,
+    )
+
+    sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
+
+    new_role = "role"
     model_server_workers = 2
-    vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
-    model = sklearn.create_model(role=new_role, model_server_workers=model_server_workers,
-                                 vpc_config_override=vpc_config)
+    vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
+    model = sklearn.create_model(
+        role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config
+    )
 
     assert model.role == new_role
     assert model.model_server_workers == model_server_workers
@@ -201,205 +232,259 @@ def test_create_model_with_optional_params(sagemaker_session):
 
 def test_create_model_with_custom_image(sagemaker_session):
     container_log_level = '"logging.INFO"'
-    source_dir = 's3://mybucket/source'
-    custom_image = 'ubuntu:latest'
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, image_name=custom_image,
-                      container_log_level=container_log_level, py_version=PYTHON_VERSION,
-                      base_job_name='job', source_dir=source_dir)
-
-    sklearn.fit(inputs='s3://mybucket/train', job_name='new_name')
+    source_dir = "s3://mybucket/source"
+    custom_image = "ubuntu:latest"
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        image_name=custom_image,
+        container_log_level=container_log_level,
+        py_version=PYTHON_VERSION,
+        base_job_name="job",
+        source_dir=source_dir,
+    )
+
+    sklearn.fit(inputs="s3://mybucket/train", job_name="new_name")
     model = sklearn.create_model()
 
     assert model.image == custom_image
 
 
-@patch('time.strftime', return_value=TIMESTAMP)
+@patch("time.strftime", return_value=TIMESTAMP)
 def test_sklearn(strftime, sagemaker_session, sklearn_version):
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION, framework_version=sklearn_version)
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        py_version=PYTHON_VERSION,
+        framework_version=sklearn_version,
+    )
 
-    inputs = 's3://mybucket/train'
+    inputs = "s3://mybucket/train"
 
     sklearn.fit(inputs=inputs)
 
     sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
-    assert sagemaker_call_names == ['train', 'logs_for_job']
+    assert sagemaker_call_names == ["train", "logs_for_job"]
     boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
-    assert boto_call_names == ['resource']
+    assert boto_call_names == ["resource"]
 
     expected_train_args = _create_train_job(sklearn_version)
-    expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs
+    expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
 
     actual_train_args = sagemaker_session.method_calls[0][2]
     assert actual_train_args == expected_train_args
 
     model = sklearn.create_model()
 
-    expected_image_base = '246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}'
-    assert {'Environment':
-            {'SAGEMAKER_SUBMIT_DIRECTORY':
-             's3://mybucket/sagemaker-scikit-learn-{}/source/sourcedir.tar.gz'.format(TIMESTAMP),
-             'SAGEMAKER_PROGRAM': 'dummy_script.py',
-             'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
-             'SAGEMAKER_REGION': 'us-west-2',
-             'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'},
-            'Image': expected_image_base.format(sklearn_version, PYTHON_VERSION),
-            'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(CPU)
-
-    assert 'cpu' in model.prepare_container_def(CPU)['Image']
+    expected_image_base = (
+        "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}"
+    )
+    assert {
+        "Environment": {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-scikit-learn-{}/source/sourcedir.tar.gz".format(
+                TIMESTAMP
+            ),
+            "SAGEMAKER_PROGRAM": "dummy_script.py",
+            "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
+            "SAGEMAKER_REGION": "us-west-2",
+            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
+        },
+        "Image": expected_image_base.format(sklearn_version, PYTHON_VERSION),
+        "ModelDataUrl": "s3://m/m.tar.gz",
+    } == model.prepare_container_def(CPU)
+
+    assert "cpu" in model.prepare_container_def(CPU)["Image"]
 
     predictor = sklearn.deploy(1, CPU)
     assert isinstance(predictor, SKLearnPredictor)
 
 
 def test_fail_distributed_training(sagemaker_session, sklearn_version):
     with pytest.raises(AttributeError) as error:
-        SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                train_instance_count=DIST_INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
-                py_version=PYTHON_VERSION, framework_version=sklearn_version)
+        SKLearn(
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            train_instance_count=DIST_INSTANCE_COUNT,
+            train_instance_type=INSTANCE_TYPE,
+            py_version=PYTHON_VERSION,
+            framework_version=sklearn_version,
+        )
     assert "Scikit-Learn does not support distributed training." in str(error)
 
 
 def test_fail_GPU_training(sagemaker_session, sklearn_version):
     with pytest.raises(ValueError) as error:
-        SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                train_instance_type=GPU_INSTANCE_TYPE, py_version=PYTHON_VERSION,
-                framework_version=sklearn_version)
+        SKLearn(
+            entry_point=SCRIPT_PATH,
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            train_instance_type=GPU_INSTANCE_TYPE,
+            py_version=PYTHON_VERSION,
+            framework_version=sklearn_version,
+        )
     assert "GPU training in not supported for Scikit-Learn." in str(error)
 
 
 def test_model(sagemaker_session):
-    model = SKLearnModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
-                         sagemaker_session=sagemaker_session)
+    model = SKLearnModel(
+        "s3://some/data.tar.gz",
+        role=ROLE,
+        entry_point=SCRIPT_PATH,
+        sagemaker_session=sagemaker_session,
+    )
     predictor = model.deploy(1, CPU)
     assert isinstance(predictor, SKLearnPredictor)
 
 
 def test_train_image_default(sagemaker_session):
-    sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
-                      train_instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION)
+    sklearn = SKLearn(
+        entry_point=SCRIPT_PATH,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        train_instance_type=INSTANCE_TYPE,
+        py_version=PYTHON_VERSION,
+    )
 
     assert _get_full_cpu_image_uri(defaults.SKLEARN_VERSION) in sklearn.train_image()
 
 
 def test_train_image_cpu_instances(sagemaker_session, sklearn_version):
-    sklearn = _sklearn_estimator(sagemaker_session, sklearn_version, train_instance_type='ml.c2.2xlarge')
+    sklearn = _sklearn_estimator(
+        sagemaker_session, sklearn_version, train_instance_type="ml.c2.2xlarge"
+    )
     assert sklearn.train_image() == _get_full_cpu_image_uri(sklearn_version)
 
-    sklearn = _sklearn_estimator(sagemaker_session, sklearn_version, train_instance_type='ml.c4.2xlarge')
+    sklearn = _sklearn_estimator(
+        sagemaker_session, sklearn_version, train_instance_type="ml.c4.2xlarge"
+    )
     assert sklearn.train_image() == _get_full_cpu_image_uri(sklearn_version)
 
-    sklearn = _sklearn_estimator(sagemaker_session, sklearn_version, train_instance_type='ml.m16')
+    sklearn = _sklearn_estimator(sagemaker_session, sklearn_version, train_instance_type="ml.m16")
     assert sklearn.train_image() == _get_full_cpu_image_uri(sklearn_version)
 
 
 def test_attach(sagemaker_session, sklearn_version):
-    training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}'.format(sklearn_version,
-                                                                                                 PYTHON_VERSION)
-    returned_job_description = {'AlgorithmSpecification':
-                                {'TrainingInputMode': 'File',
-                                 'TrainingImage': training_image},
-                                'HyperParameters':
-                                {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                 'sagemaker_program': '"iris-dnn-classifier.py"',
-                                 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"',
-                                 'sagemaker_enable_cloudwatch_metrics': 'false',
-                                 'sagemaker_container_log_level': '"logging.INFO"',
-                                 'sagemaker_job_name': '"neo"',
-                                 'training_steps': '100',
-                                 'sagemaker_region': '"us-west-2"'},
-                                'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-                                'ResourceConfig':
-                                {'VolumeSizeInGB': 30,
-                                 'InstanceCount': 1,
-                                 'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                                    return_value=returned_job_description)
-
-    estimator = SKLearn.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
-    assert estimator._current_job_name == 'neo'
-    assert estimator.latest_training_job.job_name == 'neo'
+    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}".format(
+        sklearn_version, PYTHON_VERSION
+    )
+    returned_job_description = {
+        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
+
+    estimator = SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
+    assert estimator._current_job_name == "neo"
+    assert estimator.latest_training_job.job_name == "neo"
     assert estimator.py_version == PYTHON_VERSION
     assert estimator.framework_version == sklearn_version
-    assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
+    assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
     assert estimator.train_instance_count == 1
     assert estimator.train_max_run == 24 * 60 * 60
-    assert estimator.input_mode == 'File'
-    assert estimator.base_job_name == 'neo'
-    assert estimator.output_path == 's3://place/output/neo'
-    assert estimator.output_kms_key == ''
-    assert estimator.hyperparameters()['training_steps'] == '100'
-    assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
-    assert estimator.entry_point == 'iris-dnn-classifier.py'
+    assert estimator.input_mode == "File"
+    assert estimator.base_job_name == "neo"
+    assert estimator.output_path == "s3://place/output/neo"
+    assert estimator.output_kms_key == ""
+    assert estimator.hyperparameters()["training_steps"] == "100"
+    assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
+    assert estimator.entry_point == "iris-dnn-classifier.py"
 
 
 def test_attach_wrong_framework(sagemaker_session):
-    rjd = {'AlgorithmSpecification':
-           {'TrainingInputMode': 'File',
-            'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-cpu:1.0.4'},
-           'HyperParameters':
-           {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-            'checkpoint_path': '"s3://other/1508872349"',
-            'sagemaker_program': '"iris-dnn-classifier.py"',
-            'sagemaker_enable_cloudwatch_metrics': 'false',
-            'sagemaker_container_log_level': '"logging.INFO"',
-            'training_steps': '100',
-            'sagemaker_region': '"us-west-2"'},
-           'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-           'ResourceConfig':
-           {'VolumeSizeInGB': 30,
-            'InstanceCount': 1,
-            'InstanceType': 'ml.c4.xlarge'},
-           'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-           'TrainingJobName': 'neo',
-           'TrainingJobStatus': 'Completed',
-           'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-           'OutputDataConfig': {'KmsKeyId': '',
-                                'S3OutputPath': 's3://place/output/neo'},
-           'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
+    rjd = {
+        "AlgorithmSpecification": {
+            "TrainingInputMode": "File",
+            "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-cpu:1.0.4",
+        },
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "checkpoint_path": '"s3://other/1508872349"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=rjd
+    )
 
     with pytest.raises(ValueError) as error:
-        SKLearn.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
+        SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
     assert "didn't use image for requested framework" in str(error)
 
 
 def test_attach_custom_image(sagemaker_session):
-    training_image = '1.dkr.ecr.us-west-2.amazonaws.com/my_custom_sklearn_image:latest'
-    returned_job_description = {'AlgorithmSpecification':
-                                {'TrainingInputMode': 'File',
-                                 'TrainingImage': training_image},
-                                'HyperParameters':
-                                {'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
-                                 'sagemaker_program': '"iris-dnn-classifier.py"',
-                                 'sagemaker_s3_uri_training': '"sagemaker-3/integ-test-data/tf_iris"',
-                                 'sagemaker_enable_cloudwatch_metrics': 'false',
-                                 'sagemaker_container_log_level': '"logging.INFO"',
-                                 'sagemaker_job_name': '"neo"',
-                                 'training_steps': '100',
-                                 'sagemaker_region': '"us-west-2"'},
-                                'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
-                                'ResourceConfig':
-                                {'VolumeSizeInGB': 30,
-                                 'InstanceCount': 1,
-                                 'InstanceType': 'ml.c4.xlarge'},
-                                'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
-                                'TrainingJobName': 'neo',
-                                'TrainingJobStatus': 'Completed',
-                                'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
-                                'OutputDataConfig': {'KmsKeyId': '',
-                                                     'S3OutputPath': 's3://place/output/neo'},
-                                'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
-    sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
-                                                                    return_value=returned_job_description)
-
-    estimator = SKLearn.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
+    training_image = "1.dkr.ecr.us-west-2.amazonaws.com/my_custom_sklearn_image:latest"
+    returned_job_description = {
+        "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+        "HyperParameters": {
+            "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+            "sagemaker_program": '"iris-dnn-classifier.py"',
+            "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+            "sagemaker_enable_cloudwatch_metrics": "false",
+            "sagemaker_container_log_level": '"logging.INFO"',
+            "sagemaker_job_name": '"neo"',
+            "training_steps": "100",
+            "sagemaker_region": '"us-west-2"',
+        },
+        "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+        "ResourceConfig": {
+            "VolumeSizeInGB": 30,
+            "InstanceCount": 1,
+            "InstanceType": "ml.c4.xlarge",
+        },
+        "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+        "TrainingJobName": "neo",
+        "TrainingJobStatus": "Completed",
+        "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
+        "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
+        "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
    }
+    sagemaker_session.sagemaker_client.describe_training_job = Mock(
+        name="describe_training_job", return_value=returned_job_description
+    )
+
+    estimator = SKLearn.attach(training_job_name="neo", sagemaker_session=sagemaker_session)
     assert estimator.image_name == training_image
     assert estimator.train_image() == training_image
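One detail worth noting in the attach() assertions above: SageMaker framework hyperparameters are stored JSON-encoded, which is why string values in the job description carry embedded double quotes (e.g. '"iris-dnn-classifier.py"'). A small stand-alone illustration of that round trip, independent of this patch:

    # Sketch: hyperparameter values pass through json.dumps on the way in and
    # json.loads on the way out, so the quoting in returned_job_description is
    # expected rather than a typo.
    import json

    hp = {"sagemaker_program": json.dumps("iris-dnn-classifier.py"), "training_steps": "100"}
    assert hp["sagemaker_program"] == '"iris-dnn-classifier.py"'
    assert json.loads(hp["sagemaker_program"]) == "iris-dnn-classifier.py"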
diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py
index bacf64a81f..c7c409fa48 100644
--- a/tests/unit/test_sparkml_serving.py
+++ b/tests/unit/test_sparkml_serving.py
@@ -18,29 +18,29 @@
 from sagemaker.fw_registry import registry
 from sagemaker.sparkml import SparkMLModel, SparkMLPredictor
 
-MODEL_DATA = 's3://bucket/model.tar.gz'
-ROLE = 'myrole'
-TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge'
+MODEL_DATA = "s3://bucket/model.tar.gz"
+ROLE = "myrole"
+TRAIN_INSTANCE_TYPE = "ml.c4.xlarge"
 
-REGION = 'us-west-2'
-BUCKET_NAME = 'Some-Bucket'
-ENDPOINT = 'some-endpoint'
+REGION = "us-west-2"
+BUCKET_NAME = "Some-Bucket"
+ENDPOINT = "some-endpoint"
 
-ENDPOINT_DESC = {
-    'EndpointConfigName': ENDPOINT
-}
+ENDPOINT_DESC = {"EndpointConfigName": ENDPOINT}
 
-ENDPOINT_CONFIG_DESC = {
-    'ProductionVariants': [{'ModelName': 'model-1'},
-                           {'ModelName': 'model-2'}]
-}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
 
 
 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name='boto_session', region_name=REGION)
-    sms = Mock(name='sagemaker_session', boto_session=boto_mock,
-               region_name=REGION, config=None, local_mode=False)
+    boto_mock = Mock(name="boto_session", region_name=REGION)
+    sms = Mock(
+        name="sagemaker_session",
+        boto_session=boto_mock,
+        region_name=REGION,
+        config=None,
+        local_mode=False,
+    )
     sms.boto_region_name = REGION
     sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
     sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
@@ -49,7 +49,7 @@ def sagemaker_session():
 
 def test_sparkml_model(sagemaker_session):
     sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE)
-    assert sparkml.image == registry(REGION, 'sparkml-serving') + '/sagemaker-sparkml-serving:2.2'
+    assert sparkml.image == registry(REGION, "sparkml-serving") + "/sagemaker-sparkml-serving:2.2"
 
 
 def test_predictor_type(sagemaker_session):
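The image assertion in test_sparkml_model above composes the expected URI the same way the model class does: fw_registry.registry() returns the region-specific ECR prefix, and the repository name plus tag are appended to it. A sketch mirroring the test (the account number inside the prefix varies by region, so it is shown as a placeholder):

    # Sketch: assembling the SparkML serving image URI, as in the test above.
    from sagemaker.fw_registry import registry

    image = registry("us-west-2", "sparkml-serving") + "/sagemaker-sparkml-serving:2.2"
    # -> "<region-specific account>.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2"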
region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock, - region_name=REGION, config=None, local_mode=False) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + region_name=REGION, + config=None, + local_mode=False, + ) sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -49,7 +49,7 @@ def sagemaker_session(): def test_sparkml_model(sagemaker_session): sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE) - assert sparkml.image == registry(REGION, 'sparkml-serving') + '/sagemaker-sparkml-serving:2.2' + assert sparkml.image == registry(REGION, "sparkml-serving") + "/sagemaker-sparkml-serving:2.2" def test_predictor_type(sagemaker_session): diff --git a/tests/unit/test_sync_directories.py b/tests/unit/test_sync_directories.py index ef7caeac7b..74f93995d9 100644 --- a/tests/unit/test_sync_directories.py +++ b/tests/unit/test_sync_directories.py @@ -34,13 +34,13 @@ def create_test_directory(directory, variable_content="hello world"): directory (str): The path to a directory to create with fake files variable_content (str): Content to put in one of the files """ - child_dir = os.path.join(directory, 'child_directory') + child_dir = os.path.join(directory, "child_directory") os.mkdir(child_dir) - with open(os.path.join(directory, 'foo1.txt'), 'w') as f: - f.write('bar1') - with open(os.path.join(directory, 'foo2.txt'), 'w') as f: - f.write('bar2') - with open(os.path.join(child_dir, 'hello.txt'), 'w') as f: + with open(os.path.join(directory, "foo1.txt"), "w") as f: + f.write("bar1") + with open(os.path.join(directory, "foo2.txt"), "w") as f: + f.write("bar2") + with open(os.path.join(child_dir, "hello.txt"), "w") as f: f.write(variable_content) @@ -69,7 +69,7 @@ def same_dirs(a, b): def test_to_directory_doesnt_exist(): with Tensorboard._temporary_directory() as from_dir: create_test_directory(from_dir) - to_dir = './not_a_real_place_{}'.format(random.getrandbits(64)) + to_dir = "./not_a_real_place_{}".format(random.getrandbits(64)) Tensorboard._sync_directories(from_dir, to_dir) assert same_dirs(from_dir, to_dir) shutil.rmtree(to_dir) diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 7827cf5968..fd97f08d63 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -25,51 +25,51 @@ from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowModel, TensorFlowPredictor import sagemaker.tensorflow.estimator as tfe -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') -SCRIPT_FILE = 'dummy_script.py' +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_FILE = "dummy_script.py" SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE) MODEL_DATA = "s3://some/data.tar.gz" -REQUIREMENTS_FILE = 'dummy_requirements.txt' -TIMESTAMP = '2017-11-06-14:14:15.673' +REQUIREMENTS_FILE = "dummy_requirements.txt" +TIMESTAMP = "2017-11-06-14:14:15.673" TIME = 1510006209.073025 -BUCKET_NAME = 'mybucket' +BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia.medium' -IMAGE_REPO_NAME = 'sagemaker-tensorflow' -SM_IMAGE_REPO_NAME = 'sagemaker-tensorflow-scriptmode' -JOB_NAME = '{}-{}'.format(IMAGE_REPO_NAME, TIMESTAMP) -SM_JOB_NAME = '{}-{}'.format(SM_IMAGE_REPO_NAME, TIMESTAMP) -ROLE = 'Dummy' 
-REGION = 'us-west-2' -DOCKER_TAG = '1.0' +INSTANCE_TYPE = "ml.c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +IMAGE_REPO_NAME = "sagemaker-tensorflow" +SM_IMAGE_REPO_NAME = "sagemaker-tensorflow-scriptmode" +JOB_NAME = "{}-{}".format(IMAGE_REPO_NAME, TIMESTAMP) +SM_JOB_NAME = "{}-{}".format(SM_IMAGE_REPO_NAME, TIMESTAMP) +ROLE = "Dummy" +REGION = "us-west-2" +DOCKER_TAG = "1.0" IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" -SCRIPT_MODE_REPO_NAME = 'sagemaker-tensorflow-scriptmode' -DISTRIBUTION_ENABLED = {'parameter_server': {'enabled': True}} -DISTRIBUTION_MPI_ENABLED = {'mpi': {'enabled': True, 'custom_mpi_options': 'options', 'processes_per_host': 2}} - -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' +SCRIPT_MODE_REPO_NAME = "sagemaker-tensorflow-scriptmode" +DISTRIBUTION_ENABLED = {"parameter_server": {"enabled": True}} +DISTRIBUTION_MPI_ENABLED = { + "mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2} } -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -LIST_TAGS_RESULT = { - 'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) - describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -77,163 +77,188 @@ def sagemaker_session(): return session -def _get_full_cpu_image_uri(version, repo=IMAGE_REPO_NAME, py_version='py2'): - return IMAGE_URI_FORMAT_STRING.format(REGION, repo, version, 'cpu', py_version) +def _get_full_cpu_image_uri(version, repo=IMAGE_REPO_NAME, py_version="py2"): + return IMAGE_URI_FORMAT_STRING.format(REGION, repo, version, "cpu", py_version) -def _get_full_gpu_image_uri(version, repo=IMAGE_REPO_NAME, py_version='py2'): - return IMAGE_URI_FORMAT_STRING.format(REGION, repo, version, 'gpu', py_version) +def _get_full_gpu_image_uri(version, repo=IMAGE_REPO_NAME, py_version="py2"): + return IMAGE_URI_FORMAT_STRING.format(REGION, repo, version, "gpu", py_version) def _get_full_cpu_image_uri_with_ei(version): - return _get_full_cpu_image_uri(version, repo='{}-eia'.format(IMAGE_REPO_NAME)) + return _get_full_cpu_image_uri(version, repo="{}-eia".format(IMAGE_REPO_NAME)) def _hyperparameters(script_mode=False, horovod=False): job_name = SM_JOB_NAME if script_mode else JOB_NAME hps = { - 'sagemaker_program': 
json.dumps('dummy_script.py'), - 'sagemaker_submit_directory': json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format( - BUCKET_NAME, job_name)), - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': str(logging.INFO), - 'sagemaker_job_name': json.dumps(job_name), - 'sagemaker_region': json.dumps('us-west-2') + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, job_name) + ), + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(job_name), + "sagemaker_region": json.dumps("us-west-2"), } if script_mode: if horovod: - hps['model_dir'] = json.dumps('/opt/ml/model') + hps["model_dir"] = json.dumps("/opt/ml/model") else: - hps['model_dir'] = json.dumps('s3://{}/{}/model'.format(BUCKET_NAME, job_name)) + hps["model_dir"] = json.dumps("s3://{}/{}/model".format(BUCKET_NAME, job_name)) else: - hps['checkpoint_path'] = json.dumps('s3://{}/{}/checkpoints'.format(BUCKET_NAME, job_name)) - hps['training_steps'] = '1000' - hps['evaluation_steps'] = '10' - hps['sagemaker_requirements'] = '"{}"'.format(REQUIREMENTS_FILE) + hps["checkpoint_path"] = json.dumps("s3://{}/{}/checkpoints".format(BUCKET_NAME, job_name)) + hps["training_steps"] = "1000" + hps["evaluation_steps"] = "10" + hps["sagemaker_requirements"] = '"{}"'.format(REQUIREMENTS_FILE) return hps -def _create_train_job(tf_version, script_mode=False, horovod=False, repo_name=IMAGE_REPO_NAME, py_version='py2'): +def _create_train_job( + tf_version, script_mode=False, horovod=False, repo_name=IMAGE_REPO_NAME, py_version="py2" +): return { - 'image': _get_full_cpu_image_uri(tf_version, repo=repo_name, py_version=py_version), - 'input_mode': 'File', - 'input_config': [ + "image": _get_full_cpu_image_uri(tf_version, repo=repo_name, py_version=py_version), + "input_mode": "File", + "input_config": [ { - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix' + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", } - } + }, } ], - 'role': ROLE, - 'job_name': '{}-{}'.format(repo_name, TIMESTAMP), - 'output_config': { - 'S3OutputPath': 's3://{}/'.format(BUCKET_NAME), + "role": ROLE, + "job_name": "{}-{}".format(repo_name, TIMESTAMP), + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": "ml.c4.4xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, }, - 'resource_config': { - 'InstanceType': 'ml.c4.4xlarge', - 'InstanceCount': 1, - 'VolumeSizeInGB': 30, - }, - 'hyperparameters': _hyperparameters(script_mode, horovod), - 'stop_condition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'tags': None, - 'vpc_config': None, - 'metric_definitions': None + "hyperparameters": _hyperparameters(script_mode, horovod), + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "tags": None, + "vpc_config": None, + "metric_definitions": None, } -def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None, - checkpoint_path=None, base_job_name=None, - training_steps=None, evaluation_steps=None, **kwargs): - return TensorFlow(entry_point=SCRIPT_PATH, - training_steps=training_steps, - evaluation_steps=evaluation_steps, - framework_version=framework_version, - role=ROLE, - 
sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, - train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE, - checkpoint_path=checkpoint_path, - base_job_name=base_job_name, - **kwargs) +def _build_tf( + sagemaker_session, + framework_version=defaults.TF_VERSION, + train_instance_type=None, + checkpoint_path=None, + base_job_name=None, + training_steps=None, + evaluation_steps=None, + **kwargs +): + return TensorFlow( + entry_point=SCRIPT_PATH, + training_steps=training_steps, + evaluation_steps=evaluation_steps, + framework_version=framework_version, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE, + checkpoint_path=checkpoint_path, + base_job_name=base_job_name, + **kwargs + ) def test_tf_support_cpu_instances(sagemaker_session, tf_version): - tf = _build_tf(sagemaker_session, tf_version, train_instance_type='ml.c2.2xlarge') + tf = _build_tf(sagemaker_session, tf_version, train_instance_type="ml.c2.2xlarge") assert tf.train_image() == _get_full_cpu_image_uri(tf_version) - tf = _build_tf(sagemaker_session, tf_version, train_instance_type='ml.c4.2xlarge') + tf = _build_tf(sagemaker_session, tf_version, train_instance_type="ml.c4.2xlarge") assert tf.train_image() == _get_full_cpu_image_uri(tf_version) - tf = _build_tf(sagemaker_session, tf_version, train_instance_type='ml.m16') + tf = _build_tf(sagemaker_session, tf_version, train_instance_type="ml.m16") assert tf.train_image() == _get_full_cpu_image_uri(tf_version) def test_tf_support_gpu_instances(sagemaker_session, tf_version): - tf = _build_tf(sagemaker_session, tf_version, train_instance_type='ml.g2.2xlarge') + tf = _build_tf(sagemaker_session, tf_version, train_instance_type="ml.g2.2xlarge") assert tf.train_image() == _get_full_gpu_image_uri(tf_version) - tf = _build_tf(sagemaker_session, tf_version, train_instance_type='ml.p2.2xlarge') + tf = _build_tf(sagemaker_session, tf_version, train_instance_type="ml.p2.2xlarge") assert tf.train_image() == _get_full_gpu_image_uri(tf_version) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_deploy_model_server_workers(sagemaker_session): tf = _build_tf(sagemaker_session) - tf.fit(inputs=s3_input('s3://mybucket/train')) + tf.fit(inputs=s3_input("s3://mybucket/train")) - tf.deploy(initial_instance_count=1, instance_type='ml.c2.2xlarge', model_server_workers=2) + tf.deploy(initial_instance_count=1, instance_type="ml.c2.2xlarge", model_server_workers=2) - assert "2" == sagemaker_session.method_calls[3][1][2]['Environment'][ - MODEL_SERVER_WORKERS_PARAM_NAME.upper()] + assert ( + "2" + == sagemaker_session.method_calls[3][1][2]["Environment"][ + MODEL_SERVER_WORKERS_PARAM_NAME.upper() + ] + ) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_deploy_model_server_workers_unset(sagemaker_session): tf = _build_tf(sagemaker_session) - tf.fit(inputs=s3_input('s3://mybucket/train')) + tf.fit(inputs=s3_input("s3://mybucket/train")) - tf.deploy(initial_instance_count=1, instance_type='ml.c2.2xlarge') + tf.deploy(initial_instance_count=1, instance_type="ml.c2.2xlarge") - assert MODEL_SERVER_WORKERS_PARAM_NAME.upper() not in sagemaker_session.method_calls[3][1][2]['Environment'] + assert ( + MODEL_SERVER_WORKERS_PARAM_NAME.upper() + not in 
sagemaker_session.method_calls[3][1][2]["Environment"] + ) def test_tf_invalid_requirements_path(sagemaker_session): - requirements_file = '/foo/bar/requirements.txt' + requirements_file = "/foo/bar/requirements.txt" with pytest.raises(ValueError) as e: _build_tf(sagemaker_session, requirements_file=requirements_file, source_dir=DATA_DIR) - assert 'Requirements file {} is not a path relative to source_dir.'.format(requirements_file) in str(e.value) + assert "Requirements file {} is not a path relative to source_dir.".format( + requirements_file + ) in str(e.value) def test_tf_nonexistent_requirements_path(sagemaker_session): - requirements_file = 'nonexistent_requirements.txt' + requirements_file = "nonexistent_requirements.txt" with pytest.raises(ValueError) as e: _build_tf(sagemaker_session, requirements_file=requirements_file, source_dir=DATA_DIR) - assert 'Requirements file {} does not exist.'.format(requirements_file) in str(e.value) + assert "Requirements file {} does not exist.".format(requirements_file) in str(e.value) def test_create_model(sagemaker_session, tf_version): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, framework_version=tf_version, - container_log_level=container_log_level, base_job_name='job', - source_dir=source_dir) - - job_name = 'doing something' - tf.fit(inputs='s3://mybucket/train', job_name=job_name) + source_dir = "s3://mybucket/source" + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + training_steps=1000, + evaluation_steps=10, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=tf_version, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "doing something" + tf.fit(inputs="s3://mybucket/train", job_name=job_name) model = tf.create_model() assert model.sagemaker_session == sagemaker_session @@ -249,21 +274,31 @@ def test_create_model(sagemaker_session, tf_version): def test_create_model_with_optional_params(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - enable_cloudwatch_metrics = 'true' - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, base_job_name='job', - source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics) - - job_name = 'doing something' - tf.fit(inputs='s3://mybucket/train', job_name=job_name) - - new_role = 'role' + source_dir = "s3://mybucket/source" + enable_cloudwatch_metrics = "true" + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + training_steps=1000, + evaluation_steps=10, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + ) + + job_name = "doing something" + tf.fit(inputs="s3://mybucket/train", job_name=job_name) + + new_role = "role" model_server_workers = 2 - vpc_config = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']} - model = tf.create_model(role=new_role, 
model_server_workers=model_server_workers, - vpc_config_override=vpc_config) + vpc_config = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + model = tf.create_model( + role=new_role, model_server_workers=model_server_workers, vpc_config_override=vpc_config + ) assert model.role == new_role assert model.model_server_workers == model_server_workers @@ -272,38 +307,55 @@ def test_create_model_with_optional_params(sagemaker_session): def test_create_model_with_custom_image(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - custom_image = 'tensorflow:1.0' - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, image_name=custom_image, - container_log_level=container_log_level, base_job_name='job', - source_dir=source_dir) - - job_name = 'doing something' - tf.fit(inputs='s3://mybucket/train', job_name=job_name) + source_dir = "s3://mybucket/source" + custom_image = "tensorflow:1.0" + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + training_steps=1000, + evaluation_steps=10, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + image_name=custom_image, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "doing something" + tf.fit(inputs="s3://mybucket/train", job_name=job_name) model = tf.create_model() assert model.image == custom_image -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('time.strftime', MagicMock(return_value=TIMESTAMP)) -@patch('time.time', MagicMock(return_value=TIME)) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +@patch("time.time", MagicMock(return_value=TIME)) def test_tf(sagemaker_session, tf_version): - tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000, - evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR) - - inputs = 's3://mybucket/train' + tf = TensorFlow( + entry_point=SCRIPT_FILE, + role=ROLE, + sagemaker_session=sagemaker_session, + training_steps=1000, + evaluation_steps=10, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=tf_version, + requirements_file=REQUIREMENTS_FILE, + source_dir=DATA_DIR, + ) + + inputs = "s3://mybucket/train" tf.fit(inputs=inputs) call_names = [c[0] for c in sagemaker_session.method_calls] - assert call_names == ['train', 'logs_for_job'] + assert call_names == ["train", "logs_for_job"] expected_train_args = _create_train_job(tf_version) - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args @@ -311,435 +363,515 @@ def test_tf(sagemaker_session, tf_version): model = tf.create_model() environment = { - 'Environment': { - 'SAGEMAKER_SUBMIT_DIRECTORY': - 's3://mybucket/sagemaker-tensorflow-2017-11-06-14:14:15.673/source/sourcedir.tar.gz', - 'SAGEMAKER_PROGRAM': 'dummy_script.py', 'SAGEMAKER_REQUIREMENTS': 'dummy_requirements.txt', - 'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 
'false', 'SAGEMAKER_REGION': 'us-west-2', - 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20' + "Environment": { + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/sagemaker-tensorflow-2017-11-06-14:14:15.673/source/sourcedir.tar.gz", # noqa: E501 + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_REQUIREMENTS": "dummy_requirements.txt", + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + "SAGEMAKER_REGION": "us-west-2", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - 'Image': create_image_uri('us-west-2', "tensorflow", INSTANCE_TYPE, tf_version, "py2"), - 'ModelDataUrl': 's3://m/m.tar.gz' + "Image": create_image_uri("us-west-2", "tensorflow", INSTANCE_TYPE, tf_version, "py2"), + "ModelDataUrl": "s3://m/m.tar.gz", } assert environment == model.prepare_container_def(INSTANCE_TYPE) - assert 'cpu' in model.prepare_container_def(INSTANCE_TYPE)['Image'] + assert "cpu" in model.prepare_container_def(INSTANCE_TYPE)["Image"] predictor = tf.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, TensorFlowPredictor) -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('subprocess.Popen') -@patch('subprocess.call') -@patch('os.access', return_value=False) -def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, popen, call, access, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("subprocess.Popen") +@patch("subprocess.call") +@patch("os.access", return_value=False) +def test_run_tensorboard_locally_without_tensorboard_binary( + time, strftime, popen, call, access, sagemaker_session +): + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) with pytest.raises(EnvironmentError) as error: - tf.fit(inputs='s3://mybucket/train', run_tensorboard_locally=True) - assert str(error.value) == 'TensorBoard is not installed in the system. Please install TensorBoard using the ' \ - 'following command: \n pip install tensorboard' + tf.fit(inputs="s3://mybucket/train", run_tensorboard_locally=True) + assert ( + str(error.value) + == "TensorBoard is not installed in the system. 
Please install TensorBoard using the " + "following command: \n pip install tensorboard" + ) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_model(sagemaker_session, tf_version): - model = TensorFlowModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = TensorFlowModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, TensorFlowPredictor) -@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock()) +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_model_image_accelerator(sagemaker_session): - model = TensorFlowModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, - sagemaker_session=sagemaker_session) + model = TensorFlowModel( + MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session + ) container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert container_def['Image'] == _get_full_cpu_image_uri_with_ei(defaults.TF_VERSION) - - -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('subprocess.Popen') -@patch('subprocess.call') -@patch('os.access', side_effect=[False, True]) -def test_run_tensorboard_locally_without_awscli_binary(time, strftime, popen, call, access, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) + assert container_def["Image"] == _get_full_cpu_image_uri_with_ei(defaults.TF_VERSION) + + +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("subprocess.Popen") +@patch("subprocess.call") +@patch("os.access", side_effect=[False, True]) +def test_run_tensorboard_locally_without_awscli_binary( + time, strftime, popen, call, access, sagemaker_session +): + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) with pytest.raises(EnvironmentError) as error: - tf.fit(inputs='s3://mybucket/train', run_tensorboard_locally=True) - assert str(error.value) == 'The AWS CLI is not installed in the system. Please install the AWS CLI using the ' \ - 'following command: \n pip install awscli' - - -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories') -@patch('tempfile.mkdtemp', return_value='/my/temp/folder') -@patch('shutil.rmtree') -@patch('os.access', return_value=True) -@patch('subprocess.call') -@patch('subprocess.Popen') -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('time.sleep') -def test_run_tensorboard_locally(sleep, time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) + tf.fit(inputs="s3://mybucket/train", run_tensorboard_locally=True) + assert ( + str(error.value) + == "The AWS CLI is not installed in the system. 
Please install the AWS CLI using the " + "following command: \n pip install awscli" + ) + + +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.tensorflow.estimator.Tensorboard._sync_directories") +@patch("tempfile.mkdtemp", return_value="/my/temp/folder") +@patch("shutil.rmtree") +@patch("os.access", return_value=True) +@patch("subprocess.call") +@patch("subprocess.Popen") +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("time.sleep") +def test_run_tensorboard_locally( + sleep, time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session +): + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) popen().poll.return_value = None - tf.fit(inputs='s3://mybucket/train', run_tensorboard_locally=True) - - popen.assert_called_with(['tensorboard', '--logdir', '/my/temp/folder', '--host', 'localhost', '--port', '6006'], - stderr=-1, - stdout=-1) - - -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories') -@patch('tempfile.mkdtemp', return_value='/my/temp/folder') -@patch('shutil.rmtree') -@patch('socket.socket') -@patch('os.access', return_value=True) -@patch('subprocess.call') -@patch('subprocess.Popen') -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('time.sleep') -def test_run_tensorboard_locally_port_in_use(sleep, time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync, - sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE) + tf.fit(inputs="s3://mybucket/train", run_tensorboard_locally=True) + + popen.assert_called_with( + ["tensorboard", "--logdir", "/my/temp/folder", "--host", "localhost", "--port", "6006"], + stderr=-1, + stdout=-1, + ) + + +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.tensorflow.estimator.Tensorboard._sync_directories") +@patch("tempfile.mkdtemp", return_value="/my/temp/folder") +@patch("shutil.rmtree") +@patch("socket.socket") +@patch("os.access", return_value=True) +@patch("subprocess.call") +@patch("subprocess.Popen") +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("time.sleep") +def test_run_tensorboard_locally_port_in_use( + sleep, time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync, sagemaker_session +): + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) popen().poll.side_effect = [-1, None] - tf.fit(inputs='s3://mybucket/train', run_tensorboard_locally=True) + tf.fit(inputs="s3://mybucket/train", run_tensorboard_locally=True) - popen.assert_any_call(['tensorboard', '--logdir', '/my/temp/folder', '--host', 'localhost', '--port', '6006'], - stderr=-1, stdout=-1) + popen.assert_any_call( + ["tensorboard", "--logdir", "/my/temp/folder", "--host", "localhost", "--port", "6006"], + stderr=-1, + stdout=-1, + ) - popen.assert_any_call(['tensorboard', '--logdir', '/my/temp/folder', '--host', 'localhost', '--port', '6007'], - stderr=-1, stdout=-1) + popen.assert_any_call( + ["tensorboard", "--logdir", "/my/temp/folder", "--host", "localhost", "--port", "6007"], + stderr=-1, + stdout=-1, + ) 
-@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_checkpoint_not_set(sagemaker_session): job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09" - tf = _build_tf(sagemaker_session, checkpoint_path=None, base_job_name=job_name, - output_path="s3://{}/".format(sagemaker_session.default_bucket())) - tf.fit(inputs=s3_input('s3://mybucket/train'), job_name=job_name) + tf = _build_tf( + sagemaker_session, + checkpoint_path=None, + base_job_name=job_name, + output_path="s3://{}/".format(sagemaker_session.default_bucket()), + ) + tf.fit(inputs=s3_input("s3://mybucket/train"), job_name=job_name) - expected_result = '"s3://{}/{}/checkpoints"'.format(sagemaker_session.default_bucket(), job_name) - assert tf.hyperparameters()['checkpoint_path'] == expected_result + expected_result = '"s3://{}/{}/checkpoints"'.format( + sagemaker_session.default_bucket(), job_name + ) + assert tf.hyperparameters()["checkpoint_path"] == expected_result -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_training_and_evaluation_steps_not_set(sagemaker_session): job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09" output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name) - tf = _build_tf(sagemaker_session, training_steps=None, evaluation_steps=None, output_path=output_path) - tf.fit(inputs=s3_input('s3://mybucket/train')) - assert tf.hyperparameters()['training_steps'] == 'null' - assert tf.hyperparameters()['evaluation_steps'] == 'null' + tf = _build_tf( + sagemaker_session, training_steps=None, evaluation_steps=None, output_path=output_path + ) + tf.fit(inputs=s3_input("s3://mybucket/train")) + assert tf.hyperparameters()["training_steps"] == "null" + assert tf.hyperparameters()["evaluation_steps"] == "null" -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_training_and_evaluation_steps(sagemaker_session): job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09" output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name) - tf = _build_tf(sagemaker_session, training_steps=123, evaluation_steps=456, output_path=output_path) - tf.fit(inputs=s3_input('s3://mybucket/train')) - assert tf.hyperparameters()['training_steps'] == '123' - assert tf.hyperparameters()['evaluation_steps'] == '456' + tf = _build_tf( + sagemaker_session, training_steps=123, evaluation_steps=456, output_path=output_path + ) + tf.fit(inputs=s3_input("s3://mybucket/train")) + assert tf.hyperparameters()["training_steps"] == "123" + assert tf.hyperparameters()["evaluation_steps"] == "456" -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_checkpoint_set(sagemaker_session): - tf = _build_tf(sagemaker_session, checkpoint_path='s3://my_checkpoint_bucket') - assert tf.hyperparameters()['checkpoint_path'] == json.dumps("s3://my_checkpoint_bucket") + tf = _build_tf(sagemaker_session, checkpoint_path="s3://my_checkpoint_bucket") + assert tf.hyperparameters()["checkpoint_path"] == json.dumps("s3://my_checkpoint_bucket") -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_train_image_default(sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, - role=ROLE, - sagemaker_session=sagemaker_session, - 
train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE) + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + ) assert _get_full_cpu_image_uri(defaults.TF_VERSION) in tf.train_image() -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_attach(sagemaker_session, tf_version): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:{}-cpu-py2'.format(tf_version) + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:{}-cpu-py2".format( + tf_version + ) rjd = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'evaluation_steps': '10' + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "evaluation_steps": "10", }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) - - estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py2' + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) + + estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py2" assert estimator.framework_version == tf_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert 
estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' + assert estimator.input_mode == "File" assert estimator.training_steps == 100 assert estimator.evaluation_steps == 10 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' - assert estimator.checkpoint_path == 's3://other/1508872349' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + assert estimator.checkpoint_path == "s3://other/1508872349" -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_attach_new_repo_name(sagemaker_session, tf_version): - training_image = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:{}-cpu-py2'.format(tf_version) + training_image = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:{}-cpu-py2".format( + tf_version + ) rjd = { - 'AlgorithmSpecification': {'TrainingInputMode': 'File', 'TrainingImage': training_image}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'evaluation_steps': '10' + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "evaluation_steps": "10", }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) - - estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py2' + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": 
"arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) + + estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py2" assert estimator.framework_version == tf_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' + assert estimator.input_mode == "File" assert estimator.training_steps == 100 assert estimator.evaluation_steps == 10 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' - assert estimator.checkpoint_path == 's3://other/1508872349' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + assert estimator.checkpoint_path == "s3://other/1508872349" assert estimator.train_image() == training_image -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_attach_old_container(sagemaker_session): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0' + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0" rjd = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'evaluation_steps': '10'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) - - estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py2' - assert estimator.framework_version == '1.4' - assert 
estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "evaluation_steps": "10", + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) + + estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py2" + assert estimator.framework_version == "1.4" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' + assert estimator.input_mode == "File" assert estimator.training_steps == 100 assert estimator.evaluation_steps == 10 - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' - assert estimator.hyperparameters()['training_steps'] == '100' - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' - assert estimator.checkpoint_path == 's3://other/1508872349' + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + assert estimator.checkpoint_path == "s3://other/1508872349" def test_attach_wrong_framework(sagemaker_session): returned_job_description = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0' - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'training_steps': '100' - + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-cpu:1.0", }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': - {'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' - }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": 
'"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "training_steps": "100", }, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 's3://place/output/neo' + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'TrainingJobOutput': { - 'S3TrainingJobOutput': 's3://here/output.tar.gz' - } + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, } - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', - return_value=returned_job_description) + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) with pytest.raises(ValueError) as error: - TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert "didn't use image for requested framework" in str(error) def test_attach_custom_image(sagemaker_session): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/tensorflow_with_custom_binary:1.0' + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/tensorflow_with_custom_binary:1.0" rjd = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image}, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - 'evaluation_steps': '10'}, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge'}, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) - - estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "evaluation_steps": "10", + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + 
"InstanceType": "ml.c4.xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) + + estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.image_name == training_image assert estimator.train_image() == training_image -@patch('sagemaker.fw_utils.empty_framework_version_warning') +@patch("sagemaker.fw_utils.empty_framework_version_warning") def test_empty_framework_version(warning, sagemaker_session): - estimator = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version=None) + estimator = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version=None, + ) assert estimator.framework_version == defaults.TF_VERSION warning.assert_called_with(defaults.TF_VERSION, estimator.LATEST_VERSION) def _deprecated_args_msg(args): - return '{} are deprecated in script mode. Please do not set {}.'.format( - ', '.join(tfe._FRAMEWORK_MODE_ARGS), args) + return "{} are deprecated in script mode. Please do not set {}.".format( + ", ".join(tfe._FRAMEWORK_MODE_ARGS), args + ) def test_script_mode_deprecated_args(sagemaker_session): with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, py_version='py3', checkpoint_path='some_path') - assert _deprecated_args_msg('checkpoint_path') in str(e.value) + _build_tf( + sagemaker_session=sagemaker_session, py_version="py3", checkpoint_path="some_path" + ) + assert _deprecated_args_msg("checkpoint_path") in str(e.value) with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, py_version='py3', training_steps=1) - assert _deprecated_args_msg('training_steps') in str(e.value) + _build_tf(sagemaker_session=sagemaker_session, py_version="py3", training_steps=1) + assert _deprecated_args_msg("training_steps") in str(e.value) with pytest.raises(AttributeError) as e: _build_tf(sagemaker_session=sagemaker_session, script_mode=True, evaluation_steps=1) - assert _deprecated_args_msg('evaluation_steps') in str(e.value) + assert _deprecated_args_msg("evaluation_steps") in str(e.value) with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, script_mode=True, requirements_file='some_file') - assert _deprecated_args_msg('requirements_file') in str(e.value) + _build_tf( + sagemaker_session=sagemaker_session, script_mode=True, requirements_file="some_file" + ) + assert _deprecated_args_msg("requirements_file") in str(e.value) with pytest.raises(AttributeError) as e: - _build_tf(sagemaker_session=sagemaker_session, script_mode=True, checkpoint_path='some_path', - requirements_file='some_file', training_steps=1, evaluation_steps=1) - assert _deprecated_args_msg('training_steps, evaluation_steps, requirements_file, checkpoint_path') in str(e.value) + _build_tf( + sagemaker_session=sagemaker_session, + script_mode=True, + checkpoint_path="some_path", + 
requirements_file="some_file", + training_steps=1, + evaluation_steps=1, + ) + assert _deprecated_args_msg( + "training_steps, evaluation_steps, requirements_file, checkpoint_path" + ) in str(e.value) def test_script_mode_enabled(sagemaker_session): - tf = _build_tf(sagemaker_session=sagemaker_session, py_version='py3') + tf = _build_tf(sagemaker_session=sagemaker_session, py_version="py3") assert tf._script_mode_enabled() is True tf = _build_tf(sagemaker_session=sagemaker_session, script_mode=True) @@ -749,140 +881,180 @@ def test_script_mode_enabled(sagemaker_session): assert tf._script_mode_enabled() is False -@patch('sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model') +@patch("sagemaker.tensorflow.estimator.TensorFlow._create_tfs_model") def test_script_mode_create_model(create_tfs_model, sagemaker_session): - tf = _build_tf(sagemaker_session=sagemaker_session, py_version='py3') + tf = _build_tf(sagemaker_session=sagemaker_session, py_version="py3") tf.create_model() create_tfs_model.assert_called_once() -@patch('sagemaker.utils.create_tar_file', MagicMock()) -@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories') -@patch('sagemaker.tensorflow.estimator.Tensorboard.start') -@patch('os.access', return_value=True) -@patch('subprocess.call') -@patch('subprocess.Popen') -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('time.sleep') -def test_script_mode_tensorboard(sleep, time, strftime, popen, call, access, start, sync, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session, - train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version='some_version', script_mode=True) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.tensorflow.estimator.Tensorboard._sync_directories") +@patch("sagemaker.tensorflow.estimator.Tensorboard.start") +@patch("os.access", return_value=True) +@patch("subprocess.call") +@patch("subprocess.Popen") +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("time.sleep") +def test_script_mode_tensorboard( + sleep, time, strftime, popen, call, access, start, sync, sagemaker_session +): + tf = TensorFlow( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + framework_version="some_version", + script_mode=True, + ) popen().poll.return_value = None - tf.fit(inputs='s3://mybucket/train', run_tensorboard_locally=True) + tf.fit(inputs="s3://mybucket/train", run_tensorboard_locally=True) start.assert_not_called() -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_script_mode(time, strftime, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, py_version='py3', - train_instance_type=INSTANCE_TYPE, train_instance_count=1, framework_version='1.11', - source_dir=DATA_DIR) - - inputs = 's3://mybucket/train' + tf = TensorFlow( + entry_point=SCRIPT_FILE, + role=ROLE, + sagemaker_session=sagemaker_session, + py_version="py3", + train_instance_type=INSTANCE_TYPE, + train_instance_count=1, + framework_version="1.11", + source_dir=DATA_DIR, + ) + + inputs = 
"s3://mybucket/train" tf.fit(inputs=inputs) call_names = [c[0] for c in sagemaker_session.method_calls] - assert call_names == ['train', 'logs_for_job'] + assert call_names == ["train", "logs_for_job"] - expected_train_args = _create_train_job('1.11', script_mode=True, repo_name=SM_IMAGE_REPO_NAME, py_version='py3') - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs + expected_train_args = _create_train_job( + "1.11", script_mode=True, repo_name=SM_IMAGE_REPO_NAME, py_version="py3" + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_script_mode_ps(time, strftime, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, py_version='py3', - train_instance_type=INSTANCE_TYPE, train_instance_count=1, framework_version='1.11', - source_dir=DATA_DIR, distributions=DISTRIBUTION_ENABLED) - - inputs = 's3://mybucket/train' + tf = TensorFlow( + entry_point=SCRIPT_FILE, + role=ROLE, + sagemaker_session=sagemaker_session, + py_version="py3", + train_instance_type=INSTANCE_TYPE, + train_instance_count=1, + framework_version="1.11", + source_dir=DATA_DIR, + distributions=DISTRIBUTION_ENABLED, + ) + + inputs = "s3://mybucket/train" tf.fit(inputs=inputs) call_names = [c[0] for c in sagemaker_session.method_calls] - assert call_names == ['train', 'logs_for_job'] + assert call_names == ["train", "logs_for_job"] - expected_train_args = _create_train_job('1.11', script_mode=True, repo_name=SM_IMAGE_REPO_NAME, py_version='py3') - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs - expected_train_args['hyperparameters'][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True) + expected_train_args = _create_train_job( + "1.11", script_mode=True, repo_name=SM_IMAGE_REPO_NAME, py_version="py3" + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True) actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args -@patch('time.strftime', return_value=TIMESTAMP) -@patch('time.time', return_value=TIME) -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_script_mode_mpi(time, strftime, sagemaker_session): - tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, py_version='py3', - train_instance_type=INSTANCE_TYPE, train_instance_count=1, framework_version='1.11', - source_dir=DATA_DIR, distributions=DISTRIBUTION_MPI_ENABLED) - - inputs = 's3://mybucket/train' + tf = TensorFlow( + entry_point=SCRIPT_FILE, + role=ROLE, + sagemaker_session=sagemaker_session, + py_version="py3", + train_instance_type=INSTANCE_TYPE, + train_instance_count=1, + framework_version="1.11", + source_dir=DATA_DIR, + distributions=DISTRIBUTION_MPI_ENABLED, + ) + + inputs = "s3://mybucket/train" tf.fit(inputs=inputs) call_names = 
[c[0] for c in sagemaker_session.method_calls] - assert call_names == ['train', 'logs_for_job'] - - expected_train_args = _create_train_job('1.11', script_mode=True, horovod=True, - repo_name=SM_IMAGE_REPO_NAME, py_version='py3') - expected_train_args['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] = inputs - expected_train_args['hyperparameters'][TensorFlow.LAUNCH_MPI_ENV_NAME] = json.dumps(True) - expected_train_args['hyperparameters'][TensorFlow.MPI_NUM_PROCESSES_PER_HOST] = json.dumps(2) - expected_train_args['hyperparameters'][TensorFlow.MPI_CUSTOM_MPI_OPTIONS] = json.dumps('options') + assert call_names == ["train", "logs_for_job"] + + expected_train_args = _create_train_job( + "1.11", script_mode=True, horovod=True, repo_name=SM_IMAGE_REPO_NAME, py_version="py3" + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MPI_ENV_NAME] = json.dumps(True) + expected_train_args["hyperparameters"][TensorFlow.MPI_NUM_PROCESSES_PER_HOST] = json.dumps(2) + expected_train_args["hyperparameters"][TensorFlow.MPI_CUSTOM_MPI_OPTIONS] = json.dumps( + "options" + ) actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args -@patch('sagemaker.utils.create_tar_file', MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) def test_tf_script_mode_attach(sagemaker_session, tf_version): - training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py3-cpu:{}-cpu-py3'.format(tf_version) + training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py3-cpu:{}-cpu-py3".format( + tf_version + ) rjd = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': training_image - }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"' + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', }, - 'RoleArn': 'arn:aws:iam::366:role/SageMakerRole', - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60}, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'}, - 'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}} - sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd) - - estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session) - assert estimator.latest_training_job.job_name == 'neo' - assert estimator.py_version == 'py3' + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + 
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=rjd + ) + + estimator = TensorFlow.attach(training_job_name="neo", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "neo" + assert estimator.py_version == "py3" assert estimator.framework_version == tf_version - assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole' + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 assert estimator.train_max_run == 24 * 60 * 60 - assert estimator.input_mode == 'File' - assert estimator.input_mode == 'File' - assert estimator.base_job_name == 'neo' - assert estimator.output_path == 's3://place/output/neo' - assert estimator.output_kms_key == '' + assert estimator.input_mode == "File" + assert estimator.input_mode == "File" + assert estimator.base_job_name == "neo" + assert estimator.output_path == "s3://place/output/neo" + assert estimator.output_kms_key == "" assert estimator.hyperparameters() is not None - assert estimator.source_dir == 's3://some/sourcedir.tar.gz' - assert estimator.entry_point == 'iris-dnn-classifier.py' + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" diff --git a/tests/unit/test_tf_predictor.py b/tests/unit/test_tf_predictor.py index b58971adbf..7ba93a5c56 100644 --- a/tests/unit/test_tf_predictor.py +++ b/tests/unit/test_tf_predictor.py @@ -23,41 +23,53 @@ import tensorflow as tf import six from six import BytesIO -from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, PREDICT_INPUTS +from tensorflow.python.saved_model.signature_constants import ( + DEFAULT_SERVING_SIGNATURE_DEF_KEY, + PREDICT_INPUTS, +) from sagemaker.predictor import RealTimePredictor -from sagemaker.tensorflow.predictor import tf_csv_serializer, tf_deserializer, tf_json_deserializer, \ - tf_json_serializer, tf_serializer +from sagemaker.tensorflow.predictor import ( + tf_csv_serializer, + tf_deserializer, + tf_json_deserializer, + tf_json_serializer, + tf_serializer, +) from sagemaker.tensorflow.tensorflow_serving.apis import classification_pb2 -BUCKET_NAME = 'mybucket' -ENDPOINT = 'myendpoint' -REGION = 'us-west-2' - -CLASSIFICATION_RESPONSE = {'result': {'classifications': [{'classes': [{'label': '0', 'score': 0.0012890376383438706}, - {'label': '1', 'score': 0.9814321994781494}, - {'label': '2', - 'score': 0.017278732731938362}]}]}} +BUCKET_NAME = "mybucket" +ENDPOINT = "myendpoint" +REGION = "us-west-2" + +CLASSIFICATION_RESPONSE = { + "result": { + "classifications": [ + { + "classes": [ + {"label": "0", "score": 0.0012890376383438706}, + {"label": "1", "score": 0.9814321994781494}, + {"label": "2", "score": 0.017278732731938362}, + ] + } + ] + } +} CSV_CONTENT_TYPE = "text/csv" JSON_CONTENT_TYPE = "application/json" PROTO_CONTENT_TYPE = "application/octet-stream" -ENDPOINT_DESC = { - 'EndpointConfigName': ENDPOINT -} +ENDPOINT_DESC = {"EndpointConfigName": ENDPOINT} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def 
sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - ims = Mock(name='sagemaker_session', boto_session=boto_mock) - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + boto_mock = Mock(name="boto_session", region_name=REGION) + ims = Mock(name="sagemaker_session", boto_session=boto_mock) + ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) return ims @@ -71,20 +83,24 @@ def test_endpoint_initialization(sagemaker_session): def test_classification_request_json(sagemaker_session): data = [1, 2, 3] - predictor = RealTimePredictor(endpoint=ENDPOINT, - sagemaker_session=sagemaker_session, - deserializer=tf_json_deserializer, - serializer=tf_json_serializer) + predictor = RealTimePredictor( + endpoint=ENDPOINT, + sagemaker_session=sagemaker_session, + deserializer=tf_json_deserializer, + serializer=tf_json_serializer, + ) - mock_response(json.dumps(CLASSIFICATION_RESPONSE).encode('utf-8'), sagemaker_session, JSON_CONTENT_TYPE) + mock_response( + json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE + ) result = predictor.predict(data) sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with( Accept=JSON_CONTENT_TYPE, - Body='[1, 2, 3]', + Body="[1, 2, 3]", ContentType=JSON_CONTENT_TYPE, - EndpointName='myendpoint' + EndpointName="myendpoint", ) assert result == CLASSIFICATION_RESPONSE @@ -92,10 +108,12 @@ def test_classification_request_json(sagemaker_session): def test_classification_request_csv(sagemaker_session): data = [1, 2, 3] - predictor = RealTimePredictor(serializer=tf_csv_serializer, - deserializer=tf_deserializer, - sagemaker_session=sagemaker_session, - endpoint=ENDPOINT) + predictor = RealTimePredictor( + serializer=tf_csv_serializer, + deserializer=tf_deserializer, + sagemaker_session=sagemaker_session, + endpoint=ENDPOINT, + ) expected_response = json_format.Parse( json.dumps(CLASSIFICATION_RESPONSE), classification_pb2.ClassificationResponse() @@ -107,15 +125,17 @@ def test_classification_request_csv(sagemaker_session): sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with( Accept=PROTO_CONTENT_TYPE, - Body='1,2,3', + Body="1,2,3", ContentType=CSV_CONTENT_TYPE, - EndpointName='myendpoint' + EndpointName="myendpoint", ) # python 2 and 3 protobuf serialization has different precision so I'm checking # the version here if sys.version_info < (3, 0): - assert str(result) == """result { + assert ( + str(result) + == """result { classifications { classes { label: "0" @@ -132,8 +152,11 @@ def test_classification_request_csv(sagemaker_session): } } """ + ) else: - assert str(result) == """result { + assert ( + str(result) + == """result { classifications { classes { label: "0" @@ -150,6 +173,7 @@ def test_classification_request_csv(sagemaker_session): } } """ + ) def test_json_deserializer_should_work_with_predict_response(): @@ -193,31 +217,31 @@ def test_json_deserializer_should_work_with_predict_response(): stream = BytesIO(data) - response = tf_json_deserializer(stream, 'application/json') + response = tf_json_deserializer(stream, "application/json") if six.PY2: - string_vals = ['apple', 'banana', 'orange'] + string_vals = ["apple", "banana", "orange"] else: - string_vals = [b'apple', b'banana', b'orange'] + string_vals = [b"apple", b"banana", b"orange"] 
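+    # The expected response below is version-dependent (hence the six.PY2
+    # branches): sizes and versions deserialize as floats on Python 2 but ints
+    # on Python 3, e.g. {"value": 1531758457.0} vs {"value": 1531758457}, and
+    # string_val holds native strings vs. byte strings.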
assert response == { - 'model_spec': { - 'name': u'generic_model', - 'signature_name': u'serving_default', - 'version': {'value': 1531758457. if six.PY2 else 1531758457} + "model_spec": { + "name": u"generic_model", + "signature_name": u"serving_default", + "version": {"value": 1531758457.0 if six.PY2 else 1531758457}, }, - 'outputs': { - u'ages': { - 'dtype': 1, - 'float_val': [4.954165935516357], - 'tensor_shape': {'dim': [{'size': 1. if six.PY2 else 1}]} + "outputs": { + u"ages": { + "dtype": 1, + "float_val": [4.954165935516357], + "tensor_shape": {"dim": [{"size": 1.0 if six.PY2 else 1}]}, }, - u'example_strings': { - 'dtype': 7, - 'string_val': string_vals, - 'tensor_shape': {'dim': [{'size': 3. if six.PY2 else 3}]} - } - } + u"example_strings": { + "dtype": 7, + "string_val": string_vals, + "tensor_shape": {"dim": [{"size": 3.0 if six.PY2 else 3}]}, + }, + }, } @@ -228,10 +252,12 @@ def test_classification_request_pb(sagemaker_session): example = request.input.example_list.examples.add() example.features.feature[PREDICT_INPUTS].float_list.value.extend([6.4, 3.2, 4.5, 1.5]) - predictor = RealTimePredictor(sagemaker_session=sagemaker_session, - endpoint=ENDPOINT, - deserializer=tf_deserializer, - serializer=tf_serializer) + predictor = RealTimePredictor( + sagemaker_session=sagemaker_session, + endpoint=ENDPOINT, + deserializer=tf_deserializer, + serializer=tf_serializer, + ) expected_response = classification_pb2.ClassificationResponse() classes = expected_response.result.classifications.add().classes @@ -256,13 +282,15 @@ def test_classification_request_pb(sagemaker_session): Accept=PROTO_CONTENT_TYPE, Body=request.SerializeToString(), ContentType=PROTO_CONTENT_TYPE, - EndpointName='myendpoint' + EndpointName="myendpoint", ) # python 2 and 3 protobuf serialization has different precision so I'm checking # the version here if sys.version_info < (3, 0): - assert str(result) == """result { + assert ( + str(result) + == """result { classifications { classes { label: "0" @@ -279,8 +307,11 @@ def test_classification_request_pb(sagemaker_session): } } """ + ) else: - assert str(result) == """result { + assert ( + str(result) + == """result { classifications { classes { label: "0" @@ -297,17 +328,24 @@ def test_classification_request_pb(sagemaker_session): } } """ + ) def test_predict_request_json(sagemaker_session): - data = [6.4, 3.2, .5, 1.5] - tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32) - predictor = RealTimePredictor(sagemaker_session=sagemaker_session, - endpoint=ENDPOINT, - deserializer=tf_json_deserializer, - serializer=tf_json_serializer) + data = [6.4, 3.2, 0.5, 1.5] + tensor_proto = tf.make_tensor_proto( + values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32 + ) + predictor = RealTimePredictor( + sagemaker_session=sagemaker_session, + endpoint=ENDPOINT, + deserializer=tf_json_deserializer, + serializer=tf_json_serializer, + ) - mock_response(json.dumps(CLASSIFICATION_RESPONSE).encode('utf-8'), sagemaker_session, JSON_CONTENT_TYPE) + mock_response( + json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE + ) result = predictor.predict(tensor_proto) @@ -315,29 +353,35 @@ def test_predict_request_json(sagemaker_session): Accept=JSON_CONTENT_TYPE, Body=json_format.MessageToJson(tensor_proto), ContentType=JSON_CONTENT_TYPE, - EndpointName='myendpoint' + EndpointName="myendpoint", ) assert result == CLASSIFICATION_RESPONSE def test_predict_tensor_request_csv(sagemaker_session): - data 
= [6.4, 3.2, .5, 1.5] - tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32) - predictor = RealTimePredictor(serializer=tf_csv_serializer, - deserializer=tf_json_deserializer, - sagemaker_session=sagemaker_session, - endpoint=ENDPOINT) + data = [6.4, 3.2, 0.5, 1.5] + tensor_proto = tf.make_tensor_proto( + values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32 + ) + predictor = RealTimePredictor( + serializer=tf_csv_serializer, + deserializer=tf_json_deserializer, + sagemaker_session=sagemaker_session, + endpoint=ENDPOINT, + ) - mock_response(json.dumps(CLASSIFICATION_RESPONSE).encode('utf-8'), sagemaker_session, JSON_CONTENT_TYPE) + mock_response( + json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE + ) result = predictor.predict(tensor_proto) sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with( Accept=JSON_CONTENT_TYPE, - Body='6.4,3.2,0.5,1.5', + Body="6.4,3.2,0.5,1.5", ContentType=CSV_CONTENT_TYPE, - EndpointName='myendpoint' + EndpointName="myendpoint", ) assert result == CLASSIFICATION_RESPONSE @@ -345,12 +389,13 @@ def test_predict_tensor_request_csv(sagemaker_session): def mock_response(expected_response, sagemaker_session, content_type): sagemaker_session.sagemaker_runtime_client.invoke_endpoint.return_value = { - 'ContentType': content_type, - 'Body': io.BytesIO(expected_response)} + "ContentType": content_type, + "Body": io.BytesIO(expected_response), + } def test_json_serialize_dict(): - data = {'tensor1': [1, 2, 3], 'tensor2': [4, 5, 6]} + data = {"tensor1": [1, 2, 3], "tensor2": [4, 5, 6]} serialized = tf_json_serializer(data) # deserialize again for assertion, since dict order is not guaranteed deserialized = json.loads(serialized) @@ -358,13 +403,13 @@ def test_json_serialize_dict(): def test_json_serialize_dict_with_numpy(): - data = {'tensor1': np.asarray([1, 2, 3]), 'tensor2': np.asarray([4, 5, 6])} + data = {"tensor1": np.asarray([1, 2, 3]), "tensor2": np.asarray([4, 5, 6])} serialized = tf_json_serializer(data) # deserialize again for assertion, since dict order is not guaranteed deserialized = json.loads(serialized) - assert deserialized == {'tensor1': [1, 2, 3], 'tensor2': [4, 5, 6]} + assert deserialized == {"tensor1": [1, 2, 3], "tensor2": [4, 5, 6]} def test_json_serialize_numpy(): data = np.asarray([[1, 2, 3], [4, 5, 6]]) - assert tf_json_serializer(data) == '[[1, 2, 3], [4, 5, 6]]' + assert tf_json_serializer(data) == "[[1, 2, 3], [4, 5, 6]]" diff --git a/tests/unit/test_tfs.py b/tests/unit/test_tfs.py index d2d59e0c2d..9ab59e6b7d 100644 --- a/tests/unit/test_tfs.py +++ b/tests/unit/test_tfs.py @@ -23,44 +23,44 @@ from sagemaker.tensorflow.predictor import csv_serializer from sagemaker.tensorflow.serving import Model, Predictor -JSON_CONTENT_TYPE = 'application/json' -CSV_CONTENT_TYPE = 'text/csv' +JSON_CONTENT_TYPE = "application/json" +CSV_CONTENT_TYPE = "text/csv" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.c4.4xlarge' -ACCELERATOR_TYPE = 'ml.eia1.medium' -ROLE = 'Dummy' -REGION = 'us-west-2' -PREDICT_INPUT = {'instances': [1.0, 2.0, 5.0]} -PREDICT_RESPONSE = {'predictions': [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} +INSTANCE_TYPE = "ml.c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia1.medium" +ROLE = "Dummy" +REGION = "us-west-2" +PREDICT_INPUT = {"instances": [1.0, 2.0, 5.0]} +PREDICT_RESPONSE = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} CLASSIFY_INPUT = { - 'signature_name': 'tensorflow/serving/classify', - 'examples': [{'x': 1.0}, {'x': 
2.0}] + "signature_name": "tensorflow/serving/classify", + "examples": [{"x": 1.0}, {"x": 2.0}], } -CLASSIFY_RESPONSE = {'result': [[0.4, 0.6], [0.2, 0.8]]} +CLASSIFY_RESPONSE = {"result": [[0.4, 0.6], [0.2, 0.8]]} REGRESS_INPUT = { - 'signature_name': 'tensorflow/serving/regress', - 'examples': [{'x': 1.0}, {'x': 2.0}] + "signature_name": "tensorflow/serving/regress", + "examples": [{"x": 1.0}, {"x": 2.0}], } -REGRESS_RESPONSE = {'results': [3.5, 4.0]} +REGRESS_RESPONSE = {"results": [3.5, 4.0]} -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - session = Mock(name='sagemaker_session', boto_session=boto_mock, - boto_region_name=REGION, config=None, local_mode=False) - session.default_bucket = Mock(name='default_bucket', return_value='my_bucket') + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + ) + session.default_bucket = Mock(name="default_bucket", return_value="my_bucket") session.expand_role = Mock(name="expand_role", return_value=ROLE) - describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}} + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} session.sagemaker_client.describe_training_job = Mock(return_value=describe) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) @@ -68,216 +68,284 @@ def sagemaker_session(): def test_tfs_model(sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version, - sagemaker_session=sagemaker_session) + model = Model( + "s3://some/data.tar.gz", + role=ROLE, + framework_version=tf_version, + sagemaker_session=sagemaker_session, + ) cdef = model.prepare_container_def(INSTANCE_TYPE) - assert cdef['Image'].endswith('sagemaker-tensorflow-serving:{}-cpu'.format(tf_version)) - assert cdef['Environment'] == {} + assert cdef["Image"].endswith("sagemaker-tensorflow-serving:{}-cpu".format(tf_version)) + assert cdef["Environment"] == {} predictor = model.deploy(INSTANCE_COUNT, INSTANCE_TYPE) assert isinstance(predictor, Predictor) def test_tfs_model_image_accelerator(sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version, - sagemaker_session=sagemaker_session) + model = Model( + "s3://some/data.tar.gz", + role=ROLE, + framework_version=tf_version, + sagemaker_session=sagemaker_session, + ) cdef = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) - assert cdef['Image'].endswith('sagemaker-tensorflow-serving-eia:{}-cpu'.format(tf_version)) + assert cdef["Image"].endswith("sagemaker-tensorflow-serving-eia:{}-cpu".format(tf_version)) predictor = model.deploy(INSTANCE_COUNT, INSTANCE_TYPE) assert isinstance(predictor, Predictor) def test_tfs_model_with_log_level(sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version, - container_log_level=logging.INFO, - sagemaker_session=sagemaker_session) + model = Model( 
+ "s3://some/data.tar.gz", + role=ROLE, + framework_version=tf_version, + container_log_level=logging.INFO, + sagemaker_session=sagemaker_session, + ) cdef = model.prepare_container_def(INSTANCE_TYPE) - assert cdef['Environment'] == {Model.LOG_LEVEL_PARAM_NAME: 'info'} + assert cdef["Environment"] == {Model.LOG_LEVEL_PARAM_NAME: "info"} def test_tfs_model_with_custom_image(sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", role=ROLE, framework_version=tf_version, - image='my-image', - sagemaker_session=sagemaker_session) + model = Model( + "s3://some/data.tar.gz", + role=ROLE, + framework_version=tf_version, + image="my-image", + sagemaker_session=sagemaker_session, + ) cdef = model.prepare_container_def(INSTANCE_TYPE) - assert cdef['Image'] == 'my-image' - - -@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') -@mock.patch('sagemaker.utils.repack_model') -def test_tfs_model_with_entry_point(repack_model, model_code_key_prefix, sagemaker_session, - tf_version): - model = Model("s3://some/data.tar.gz", - entry_point='train.py', - role=ROLE, framework_version=tf_version, - image='my-image', sagemaker_session=sagemaker_session) + assert cdef["Image"] == "my-image" + + +@mock.patch("sagemaker.fw_utils.model_code_key_prefix", return_value="key-prefix") +@mock.patch("sagemaker.utils.repack_model") +def test_tfs_model_with_entry_point( + repack_model, model_code_key_prefix, sagemaker_session, tf_version +): + model = Model( + "s3://some/data.tar.gz", + entry_point="train.py", + role=ROLE, + framework_version=tf_version, + image="my-image", + sagemaker_session=sagemaker_session, + ) model.prepare_container_def(INSTANCE_TYPE) model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) - repack_model.assert_called_with('train.py', None, [], 's3://some/data.tar.gz', - 's3://my_bucket/key-prefix/model.tar.gz', - sagemaker_session) + repack_model.assert_called_with( + "train.py", + None, + [], + "s3://some/data.tar.gz", + "s3://my_bucket/key-prefix/model.tar.gz", + sagemaker_session, + ) -@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') -@mock.patch('sagemaker.utils.repack_model') +@mock.patch("sagemaker.fw_utils.model_code_key_prefix", return_value="key-prefix") +@mock.patch("sagemaker.utils.repack_model") def test_tfs_model_with_source(repack_model, model_code_key_prefix, sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", - entry_point='train.py', - source_dir='src', - role=ROLE, framework_version=tf_version, - image='my-image', sagemaker_session=sagemaker_session) + model = Model( + "s3://some/data.tar.gz", + entry_point="train.py", + source_dir="src", + role=ROLE, + framework_version=tf_version, + image="my-image", + sagemaker_session=sagemaker_session, + ) model.prepare_container_def(INSTANCE_TYPE) model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) - repack_model.assert_called_with('train.py', 'src', [], 's3://some/data.tar.gz', - 's3://my_bucket/key-prefix/model.tar.gz', - sagemaker_session) - - -@mock.patch('sagemaker.fw_utils.model_code_key_prefix', return_value='key-prefix') -@mock.patch('sagemaker.utils.repack_model') -def test_tfs_model_with_dependencies(repack_model, model_code_key_prefix, sagemaker_session, tf_version): - model = Model("s3://some/data.tar.gz", - entry_point='train.py', - dependencies=['src', 'lib'], - role=ROLE, framework_version=tf_version, - image='my-image', sagemaker_session=sagemaker_session) + 
repack_model.assert_called_with( + "train.py", + "src", + [], + "s3://some/data.tar.gz", + "s3://my_bucket/key-prefix/model.tar.gz", + sagemaker_session, + ) + + +@mock.patch("sagemaker.fw_utils.model_code_key_prefix", return_value="key-prefix") +@mock.patch("sagemaker.utils.repack_model") +def test_tfs_model_with_dependencies( + repack_model, model_code_key_prefix, sagemaker_session, tf_version +): + model = Model( + "s3://some/data.tar.gz", + entry_point="train.py", + dependencies=["src", "lib"], + role=ROLE, + framework_version=tf_version, + image="my-image", + sagemaker_session=sagemaker_session, + ) model.prepare_container_def(INSTANCE_TYPE) model_code_key_prefix.assert_called_with(model.key_prefix, model.name, model.image) - repack_model.assert_called_with('train.py', None, ['src', 'lib'], 's3://some/data.tar.gz', - 's3://my_bucket/key-prefix/model.tar.gz', - sagemaker_session) + repack_model.assert_called_with( + "train.py", + None, + ["src", "lib"], + "s3://some/data.tar.gz", + "s3://my_bucket/key-prefix/model.tar.gz", + sagemaker_session, + ) def test_estimator_deploy(sagemaker_session): container_log_level = '"logging.INFO"' - source_dir = 's3://mybucket/source' - custom_image = 'custom:1.0' - tf = TensorFlow(entry_point='script.py', role=ROLE, sagemaker_session=sagemaker_session, - training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT, - train_instance_type=INSTANCE_TYPE, image_name=custom_image, - container_log_level=container_log_level, base_job_name='job', - source_dir=source_dir) - - job_name = 'doing something' - tf.fit(inputs='s3://mybucket/train', job_name=job_name) - predictor = tf.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name='endpoint', - endpoint_type='tensorflow-serving') + source_dir = "s3://mybucket/source" + custom_image = "custom:1.0" + tf = TensorFlow( + entry_point="script.py", + role=ROLE, + sagemaker_session=sagemaker_session, + training_steps=1000, + evaluation_steps=10, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + image_name=custom_image, + container_log_level=container_log_level, + base_job_name="job", + source_dir=source_dir, + ) + + job_name = "doing something" + tf.fit(inputs="s3://mybucket/train", job_name=job_name) + predictor = tf.deploy( + INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name="endpoint", endpoint_type="tensorflow-serving" + ) assert isinstance(predictor, Predictor) def test_predictor(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session) + predictor = Predictor("endpoint", sagemaker_session) - mock_response(json.dumps(PREDICT_RESPONSE).encode('utf-8'), sagemaker_session) + mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session) result = predictor.predict(PREDICT_INPUT) - assert_invoked(sagemaker_session, - EndpointName='endpoint', - ContentType=JSON_CONTENT_TYPE, - Accept=JSON_CONTENT_TYPE, - Body=json.dumps(PREDICT_INPUT)) + assert_invoked( + sagemaker_session, + EndpointName="endpoint", + ContentType=JSON_CONTENT_TYPE, + Accept=JSON_CONTENT_TYPE, + Body=json.dumps(PREDICT_INPUT), + ) assert PREDICT_RESPONSE == result def test_predictor_jsons(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, serializer=None, - content_type='application/jsons') + predictor = Predictor( + "endpoint", sagemaker_session, serializer=None, content_type="application/jsons" + ) - mock_response(json.dumps(PREDICT_RESPONSE).encode('utf-8'), sagemaker_session) - result = predictor.predict('[1.0, 2.0, 3.0]\n[4.0, 5.0, 6.0]') + 
mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session) + result = predictor.predict("[1.0, 2.0, 3.0]\n[4.0, 5.0, 6.0]") - assert_invoked(sagemaker_session, - EndpointName='endpoint', - ContentType='application/jsons', - Accept=JSON_CONTENT_TYPE, - Body='[1.0, 2.0, 3.0]\n[4.0, 5.0, 6.0]') + assert_invoked( + sagemaker_session, + EndpointName="endpoint", + ContentType="application/jsons", + Accept=JSON_CONTENT_TYPE, + Body="[1.0, 2.0, 3.0]\n[4.0, 5.0, 6.0]", + ) assert PREDICT_RESPONSE == result def test_predictor_csv(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, serializer=csv_serializer) + predictor = Predictor("endpoint", sagemaker_session, serializer=csv_serializer) - mock_response(json.dumps(PREDICT_RESPONSE).encode('utf-8'), sagemaker_session) + mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session) result = predictor.predict([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - assert_invoked(sagemaker_session, - EndpointName='endpoint', - ContentType=CSV_CONTENT_TYPE, - Accept=JSON_CONTENT_TYPE, - Body='1.0,2.0,3.0\n4.0,5.0,6.0') + assert_invoked( + sagemaker_session, + EndpointName="endpoint", + ContentType=CSV_CONTENT_TYPE, + Accept=JSON_CONTENT_TYPE, + Body="1.0,2.0,3.0\n4.0,5.0,6.0", + ) assert PREDICT_RESPONSE == result def test_predictor_model_attributes(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, model_name='model', model_version='123') + predictor = Predictor("endpoint", sagemaker_session, model_name="model", model_version="123") - mock_response(json.dumps(PREDICT_RESPONSE).encode('utf-8'), sagemaker_session) + mock_response(json.dumps(PREDICT_RESPONSE).encode("utf-8"), sagemaker_session) result = predictor.predict(PREDICT_INPUT) - assert_invoked(sagemaker_session, - EndpointName='endpoint', - ContentType=JSON_CONTENT_TYPE, - Accept=JSON_CONTENT_TYPE, - CustomAttributes='tfs-model-name=model,tfs-model-version=123', - Body=json.dumps(PREDICT_INPUT)) + assert_invoked( + sagemaker_session, + EndpointName="endpoint", + ContentType=JSON_CONTENT_TYPE, + Accept=JSON_CONTENT_TYPE, + CustomAttributes="tfs-model-name=model,tfs-model-version=123", + Body=json.dumps(PREDICT_INPUT), + ) assert PREDICT_RESPONSE == result def test_predictor_classify(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session) + predictor = Predictor("endpoint", sagemaker_session) - mock_response(json.dumps(CLASSIFY_RESPONSE).encode('utf-8'), sagemaker_session) + mock_response(json.dumps(CLASSIFY_RESPONSE).encode("utf-8"), sagemaker_session) result = predictor.classify(CLASSIFY_INPUT) - assert_invoked_with_body_dict(sagemaker_session, - EndpointName='endpoint', - ContentType=JSON_CONTENT_TYPE, - Accept=JSON_CONTENT_TYPE, - CustomAttributes='tfs-method=classify', - Body=json.dumps(CLASSIFY_INPUT)) + assert_invoked_with_body_dict( + sagemaker_session, + EndpointName="endpoint", + ContentType=JSON_CONTENT_TYPE, + Accept=JSON_CONTENT_TYPE, + CustomAttributes="tfs-method=classify", + Body=json.dumps(CLASSIFY_INPUT), + ) assert CLASSIFY_RESPONSE == result def test_predictor_regress(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, model_name='model', model_version='123') + predictor = Predictor("endpoint", sagemaker_session, model_name="model", model_version="123") - mock_response(json.dumps(REGRESS_RESPONSE).encode('utf-8'), sagemaker_session) + mock_response(json.dumps(REGRESS_RESPONSE).encode("utf-8"), sagemaker_session) result = predictor.regress(REGRESS_INPUT) - 
assert_invoked_with_body_dict(sagemaker_session, - EndpointName='endpoint', - ContentType=JSON_CONTENT_TYPE, - Accept=JSON_CONTENT_TYPE, - CustomAttributes='tfs-method=regress,tfs-model-name=model,tfs-model-version=123', - Body=json.dumps(REGRESS_INPUT)) + assert_invoked_with_body_dict( + sagemaker_session, + EndpointName="endpoint", + ContentType=JSON_CONTENT_TYPE, + Accept=JSON_CONTENT_TYPE, + CustomAttributes="tfs-method=regress,tfs-model-name=model,tfs-model-version=123", + Body=json.dumps(REGRESS_INPUT), + ) assert REGRESS_RESPONSE == result def test_predictor_regress_bad_content_type(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, csv_serializer) + predictor = Predictor("endpoint", sagemaker_session, csv_serializer) with pytest.raises(ValueError): predictor.regress(REGRESS_INPUT) def test_predictor_classify_bad_content_type(sagemaker_session): - predictor = Predictor('endpoint', sagemaker_session, csv_serializer) + predictor = Predictor("endpoint", sagemaker_session, csv_serializer) with pytest.raises(ValueError): predictor.classify(CLASSIFY_INPUT) @@ -293,7 +361,7 @@ def assert_invoked_with_body_dict(sagemaker_session, **kwargs): assert not cargs assert len(kwargs) == len(ckwargs) for k in ckwargs: - if k != 'Body': + if k != "Body": assert kwargs[k] == ckwargs[k] else: actual_body = json.loads(ckwargs[k]) @@ -305,6 +373,6 @@ def assert_invoked_with_body_dict(sagemaker_session, **kwargs): def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTENT_TYPE): sagemaker_session.sagemaker_runtime_client.invoke_endpoint.return_value = { - 'ContentType': content_type, - 'Body': io.BytesIO(expected_response) + "ContentType": content_type, + "Body": io.BytesIO(expected_response), } diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index d549a6ad37..88c4495e37 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -18,105 +18,125 @@ from sagemaker.transformer import _TransformJob, Transformer from tests.integ import test_local_mode -MODEL_NAME = 'model' -IMAGE_NAME = 'image-for-model' -JOB_NAME = 'job' +MODEL_NAME = "model" +IMAGE_NAME = "image-for-model" +JOB_NAME = "job" INSTANCE_COUNT = 1 -INSTANCE_TYPE = 'ml.m4.xlarge' -KMS_KEY_ID = 'kms-key-id' +INSTANCE_TYPE = "ml.m4.xlarge" +KMS_KEY_ID = "kms-key-id" -S3_DATA_TYPE = 'S3Prefix' -S3_BUCKET = 'bucket' -DATA = 's3://{}/input-data'.format(S3_BUCKET) -OUTPUT_PATH = 's3://{}/output'.format(S3_BUCKET) +S3_DATA_TYPE = "S3Prefix" +S3_BUCKET = "bucket" +DATA = "s3://{}/input-data".format(S3_BUCKET) +OUTPUT_PATH = "s3://{}/output".format(S3_BUCKET) -TIMESTAMP = '2018-07-12' +TIMESTAMP = "2018-07-12" INIT_PARAMS = { - 'model_name': MODEL_NAME, - 'instance_count': INSTANCE_COUNT, - 'instance_type': INSTANCE_TYPE, - 'base_transform_job_name': JOB_NAME + "model_name": MODEL_NAME, + "instance_count": INSTANCE_COUNT, + "instance_type": INSTANCE_TYPE, + "base_transform_job_name": JOB_NAME, } -MODEL_DESC_PRIMARY_CONTAINER = { - 'PrimaryContainer': { - 'Image': IMAGE_NAME - } -} +MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer": {"Image": IMAGE_NAME}} -MODEL_DESC_CONTAINERS_ONLY = { - 'Containers': [ - {'Image': IMAGE_NAME} - ] -} +MODEL_DESC_CONTAINERS_ONLY = {"Containers": [{"Image": IMAGE_NAME}]} @pytest.fixture(autouse=True) def mock_create_tar_file(): - with patch('sagemaker.utils.create_tar_file', MagicMock()) as create_tar_file: + with patch("sagemaker.utils.create_tar_file", MagicMock()) as create_tar_file: yield create_tar_file 
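+# Because the fixture above is marked autouse=True, pytest applies the
+# create_tar_file patch to every test in this module automatically, so no
+# transform test can write a real tarball to disk.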
@pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session') - return Mock(name='sagemaker_session', boto_session=boto_mock, local_mode=False) + boto_mock = Mock(name="boto_session") + return Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) @pytest.fixture() def transformer(sagemaker_session): - return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, - output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session, - volume_kms_key=KMS_KEY_ID) + return Transformer( + MODEL_NAME, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path=OUTPUT_PATH, + sagemaker_session=sagemaker_session, + volume_kms_key=KMS_KEY_ID, + ) def test_delete_model(sagemaker_session): - transformer = Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session) + transformer = Transformer( + MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + ) transformer.delete_model() sagemaker_session.delete_model.assert_called_with(MODEL_NAME) def test_transformer_fails_without_model(): - transformer = Transformer(model_name='remote-model', - sagemaker_session=test_local_mode.LocalNoS3Session(), - instance_type='local', - instance_count=1) + transformer = Transformer( + model_name="remote-model", + sagemaker_session=test_local_mode.LocalNoS3Session(), + instance_type="local", + instance_count=1, + ) with pytest.raises(ValueError) as error: - transformer.transform('empty-data') + transformer.transform("empty-data") - assert str(error.value) == 'Failed to fetch model information for remote-model. ' \ - 'Please ensure that the model exists. ' \ - 'Local instance types require locally created models.' + assert ( + str(error.value) == "Failed to fetch model information for remote-model. " + "Please ensure that the model exists. " + "Local instance types require locally created models." 
+ ) -@patch('sagemaker.transformer._TransformJob.start_new') +@patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_all_params(start_new_job, transformer): - content_type = 'text/csv' - compression = 'Gzip' - split = 'Line' + content_type = "text/csv" + compression = "Gzip" + split = "Line" input_filter = "$.feature" output_filter = "$['sagemaker_output', 'id']" join_source = "Input" - transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split, - job_name=JOB_NAME, input_filter=input_filter, output_filter=output_filter, - join_source=join_source) + transformer.transform( + DATA, + S3_DATA_TYPE, + content_type=content_type, + compression_type=compression, + split_type=split, + job_name=JOB_NAME, + input_filter=input_filter, + output_filter=output_filter, + join_source=join_source, + ) assert transformer._current_job_name == JOB_NAME assert transformer.output_path == OUTPUT_PATH - start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, - split, input_filter, output_filter, join_source) - - -@patch('sagemaker.transformer.name_from_base') -@patch('sagemaker.transformer._TransformJob.start_new') + start_new_job.assert_called_once_with( + transformer, + DATA, + S3_DATA_TYPE, + content_type, + compression, + split, + input_filter, + output_filter, + join_source, + ) + + +@patch("sagemaker.transformer.name_from_base") +@patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_base_job_name_provided(start_new_job, name_from_base, transformer): - base_name = 'base-job-name' - full_name = '{}-{}'.format(base_name, TIMESTAMP) + base_name = "base-job-name" + full_name = "{}-{}".format(base_name, TIMESTAMP) transformer.base_transform_job_name = base_name name_from_base.return_value = full_name @@ -127,11 +147,11 @@ def test_transform_with_base_job_name_provided(start_new_job, name_from_base, tr assert transformer._current_job_name == full_name -@patch('sagemaker.transformer.Transformer._retrieve_base_name', return_value=IMAGE_NAME) -@patch('sagemaker.transformer.name_from_base') -@patch('sagemaker.transformer._TransformJob.start_new') +@patch("sagemaker.transformer.Transformer._retrieve_base_name", return_value=IMAGE_NAME) +@patch("sagemaker.transformer.name_from_base") +@patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_base_name(start_new_job, name_from_base, retrieve_base_name, transformer): - full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) + full_name = "{}-{}".format(IMAGE_NAME, TIMESTAMP) name_from_base.return_value = full_name transformer.transform(DATA) @@ -141,11 +161,13 @@ def test_transform_with_base_name(start_new_job, name_from_base, retrieve_base_n assert transformer._current_job_name == full_name -@patch('sagemaker.transformer.Transformer._retrieve_image_name', return_value=IMAGE_NAME) -@patch('sagemaker.transformer.name_from_base') -@patch('sagemaker.transformer._TransformJob.start_new') -def test_transform_with_job_name_based_on_image(start_new_job, name_from_base, retrieve_image_name, transformer): - full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) +@patch("sagemaker.transformer.Transformer._retrieve_image_name", return_value=IMAGE_NAME) +@patch("sagemaker.transformer.name_from_base") +@patch("sagemaker.transformer._TransformJob.start_new") +def test_transform_with_job_name_based_on_image( + start_new_job, name_from_base, retrieve_image_name, transformer +): + full_name = "{}-{}".format(IMAGE_NAME, TIMESTAMP) 
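+    # name_from_base appends a timestamp to the base name, so the test pins its
+    # return value (here "image-for-model-2018-07-12") to keep the job-name
+    # assertion deterministic.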
name_from_base.return_value = full_name transformer.transform(DATA) @@ -155,70 +177,76 @@ def test_transform_with_job_name_based_on_image(start_new_job, name_from_base, r assert transformer._current_job_name == full_name -@pytest.mark.parametrize('model_desc', [MODEL_DESC_PRIMARY_CONTAINER, - MODEL_DESC_CONTAINERS_ONLY]) -@patch('sagemaker.transformer.name_from_base') -@patch('sagemaker.transformer._TransformJob.start_new') -def test_transform_with_job_name_based_on_containers(start_new_job, name_from_base, model_desc, transformer): +@pytest.mark.parametrize("model_desc", [MODEL_DESC_PRIMARY_CONTAINER, MODEL_DESC_CONTAINERS_ONLY]) +@patch("sagemaker.transformer.name_from_base") +@patch("sagemaker.transformer._TransformJob.start_new") +def test_transform_with_job_name_based_on_containers( + start_new_job, name_from_base, model_desc, transformer +): transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc - full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) + full_name = "{}-{}".format(IMAGE_NAME, TIMESTAMP) name_from_base.return_value = full_name transformer.transform(DATA) - transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME) + transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( + ModelName=MODEL_NAME + ) name_from_base.assert_called_once_with(IMAGE_NAME) assert transformer._current_job_name == full_name -@pytest.mark.parametrize('model_desc', [{'PrimaryContainer': dict()}, - {'Containers': [dict()]}, - dict(), - ]) -@patch('sagemaker.transformer.name_from_base') -@patch('sagemaker.transformer._TransformJob.start_new') -def test_transform_with_job_name_based_on_model_name(start_new_job, name_from_base, model_desc, transformer): +@pytest.mark.parametrize( + "model_desc", [{"PrimaryContainer": dict()}, {"Containers": [dict()]}, dict()] +) +@patch("sagemaker.transformer.name_from_base") +@patch("sagemaker.transformer._TransformJob.start_new") +def test_transform_with_job_name_based_on_model_name( + start_new_job, name_from_base, model_desc, transformer +): transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc - full_name = '{}-{}'.format(MODEL_NAME, TIMESTAMP) + full_name = "{}-{}".format(MODEL_NAME, TIMESTAMP) name_from_base.return_value = full_name transformer.transform(DATA) - transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME) + transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( + ModelName=MODEL_NAME + ) name_from_base.assert_called_once_with(MODEL_NAME) assert transformer._current_job_name == full_name -@patch('sagemaker.transformer._TransformJob.start_new') +@patch("sagemaker.transformer._TransformJob.start_new") def test_transform_with_generated_output_path(start_new_job, transformer, sagemaker_session): transformer.output_path = None sagemaker_session.default_bucket.return_value = S3_BUCKET transformer.transform(DATA, job_name=JOB_NAME) - assert transformer.output_path == 's3://{}/{}'.format(S3_BUCKET, JOB_NAME) + assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, JOB_NAME) def test_transform_with_invalid_s3_uri(transformer): with pytest.raises(ValueError) as e: - transformer.transform('not-an-s3-uri') + transformer.transform("not-an-s3-uri") - assert 'Invalid S3 URI' in str(e) + assert "Invalid S3 URI" in str(e) def test_retrieve_image_name(sagemaker_session, transformer): - sage_mock = Mock(name='sagemaker_client') - 
sage_mock.describe_model.return_value = {'PrimaryContainer': {'Image': IMAGE_NAME}} + sage_mock = Mock(name="sagemaker_client") + sage_mock.describe_model.return_value = {"PrimaryContainer": {"Image": IMAGE_NAME}} sagemaker_session.sagemaker_client = sage_mock assert transformer._retrieve_image_name() == IMAGE_NAME -@patch('sagemaker.transformer.Transformer._ensure_last_transform_job') +@patch("sagemaker.transformer.Transformer._ensure_last_transform_job") def test_wait(ensure_last_transform_job, transformer): - transformer.latest_transform_job = Mock(name='latest_transform_job') + transformer.latest_transform_job = Mock(name="latest_transform_job") transformer.wait() @@ -227,7 +255,7 @@ def test_wait(ensure_last_transform_job, transformer): def test_ensure_last_transform_job_exists(transformer, sagemaker_session): - transformer.latest_transform_job = _TransformJob(sagemaker_session, 'some-transform-job') + transformer.latest_transform_job = _TransformJob(sagemaker_session, "some-transform-job") transformer._ensure_last_transform_job() @@ -236,12 +264,15 @@ def test_ensure_last_transform_job_none(transformer): with pytest.raises(ValueError) as e: transformer._ensure_last_transform_job() - assert 'No transform job available' in str(e) + assert "No transform job available" in str(e) -@patch('sagemaker.transformer.Transformer._prepare_init_params_from_job_description', return_value=INIT_PARAMS) +@patch( + "sagemaker.transformer.Transformer._prepare_init_params_from_job_description", + return_value=INIT_PARAMS, +) def test_attach(prepare_init_params, transformer, sagemaker_session): - sagemaker_session.sagemaker_client.describe_transform_job = Mock(name='describe_transform_job') + sagemaker_session.sagemaker_client.describe_transform_job = Mock(name="describe_transform_job") attached = Transformer.attach(JOB_NAME, sagemaker_session) assert prepare_init_params.called_once @@ -253,60 +284,55 @@ def test_attach(prepare_init_params, transformer, sagemaker_session): def test_prepare_init_params_from_job_description_missing_keys(transformer): job_details = { - 'ModelName': MODEL_NAME, - 'TransformResources': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE - }, - 'TransformOutput': { - 'S3OutputPath': None - }, - 'TransformJobName': JOB_NAME + "ModelName": MODEL_NAME, + "TransformResources": {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE}, + "TransformOutput": {"S3OutputPath": None}, + "TransformJobName": JOB_NAME, } init_params = transformer._prepare_init_params_from_job_description(job_details) - assert init_params['model_name'] == MODEL_NAME - assert init_params['instance_count'] == INSTANCE_COUNT - assert init_params['instance_type'] == INSTANCE_TYPE + assert init_params["model_name"] == MODEL_NAME + assert init_params["instance_count"] == INSTANCE_COUNT + assert init_params["instance_type"] == INSTANCE_TYPE def test_prepare_init_params_from_job_description_all_keys(transformer): job_details = { - 'ModelName': MODEL_NAME, - 'TransformResources': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeKmsKeyId': KMS_KEY_ID + "ModelName": MODEL_NAME, + "TransformResources": { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeKmsKeyId": KMS_KEY_ID, }, - 'BatchStrategy': None, - 'TransformOutput': { - 'AssembleWith': None, - 'S3OutputPath': None, - 'KmsKeyId': None, - 'Accept': None + "BatchStrategy": None, + "TransformOutput": { + "AssembleWith": None, + "S3OutputPath": None, + "KmsKeyId": None, + "Accept": None, 
}, - 'MaxConcurrentTransforms': None, - 'MaxPayloadInMB': None, - 'TransformJobName': JOB_NAME + "MaxConcurrentTransforms": None, + "MaxPayloadInMB": None, + "TransformJobName": JOB_NAME, } init_params = transformer._prepare_init_params_from_job_description(job_details) - assert init_params['model_name'] == MODEL_NAME - assert init_params['instance_count'] == INSTANCE_COUNT - assert init_params['instance_type'] == INSTANCE_TYPE - assert init_params['volume_kms_key'] == KMS_KEY_ID + assert init_params["model_name"] == MODEL_NAME + assert init_params["instance_count"] == INSTANCE_COUNT + assert init_params["instance_type"] == INSTANCE_TYPE + assert init_params["volume_kms_key"] == KMS_KEY_ID # _TransformJob tests + def test_start_new(transformer, sagemaker_session): transformer._current_job_name = JOB_NAME job = _TransformJob(sagemaker_session, JOB_NAME) - started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None, - None, None, None) + started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None, None, None, None) assert started_job.sagemaker_session == sagemaker_session sagemaker_session.transform.assert_called_once() @@ -314,21 +340,14 @@ def test_start_new(transformer, sagemaker_session): def test_load_config(transformer): expected_config = { - 'input_config': { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': S3_DATA_TYPE, - 'S3Uri': DATA, - }, - }, - }, - 'output_config': { - 'S3OutputPath': OUTPUT_PATH, + "input_config": { + "DataSource": {"S3DataSource": {"S3DataType": S3_DATA_TYPE, "S3Uri": DATA}} }, - 'resource_config': { - 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE, - 'VolumeKmsKeyId': KMS_KEY_ID, + "output_config": {"S3OutputPath": OUTPUT_PATH}, + "resource_config": { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeKmsKeyId": KMS_KEY_ID, }, } @@ -337,79 +356,77 @@ def test_load_config(transformer): def test_format_inputs_to_input_config(): - expected_config = { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': S3_DATA_TYPE, - 'S3Uri': DATA, - }, - }, - } + expected_config = {"DataSource": {"S3DataSource": {"S3DataType": S3_DATA_TYPE, "S3Uri": DATA}}} - actual_config = _TransformJob._format_inputs_to_input_config(DATA, S3_DATA_TYPE, None, None, None) + actual_config = _TransformJob._format_inputs_to_input_config( + DATA, S3_DATA_TYPE, None, None, None + ) assert actual_config == expected_config def test_format_inputs_to_input_config_with_optional_params(): - compression = 'Gzip' - content_type = 'text/csv' - split = 'Line' + compression = "Gzip" + content_type = "text/csv" + split = "Line" expected_config = { - 'DataSource': { - 'S3DataSource': { - 'S3DataType': S3_DATA_TYPE, - 'S3Uri': DATA, - }, - }, - 'CompressionType': compression, - 'ContentType': content_type, - 'SplitType': split, + "DataSource": {"S3DataSource": {"S3DataType": S3_DATA_TYPE, "S3Uri": DATA}}, + "CompressionType": compression, + "ContentType": content_type, + "SplitType": split, } - actual_config = _TransformJob._format_inputs_to_input_config(DATA, S3_DATA_TYPE, content_type, compression, split) + actual_config = _TransformJob._format_inputs_to_input_config( + DATA, S3_DATA_TYPE, content_type, compression, split + ) assert actual_config == expected_config def test_prepare_output_config(): config = _TransformJob._prepare_output_config(OUTPUT_PATH, None, None, None) - assert config == {'S3OutputPath': OUTPUT_PATH} + assert config == {"S3OutputPath": OUTPUT_PATH} def test_prepare_output_config_with_optional_params(): 
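+    # Optional output settings (KMS key, assembly strategy, accept type) should
+    # pass through to the output config unchanged.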
- kms_key = 'key' - assemble_with = 'Line' - accept = 'text/csv' + kms_key = "key" + assemble_with = "Line" + accept = "text/csv" expected_config = { - 'S3OutputPath': OUTPUT_PATH, - 'KmsKeyId': kms_key, - 'AssembleWith': assemble_with, - 'Accept': accept, + "S3OutputPath": OUTPUT_PATH, + "KmsKeyId": kms_key, + "AssembleWith": assemble_with, + "Accept": accept, } - actual_config = _TransformJob._prepare_output_config(OUTPUT_PATH, kms_key, assemble_with, accept) + actual_config = _TransformJob._prepare_output_config( + OUTPUT_PATH, kms_key, assemble_with, accept + ) assert actual_config == expected_config def test_prepare_resource_config(): config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID) - assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID} + assert config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeKmsKeyId": KMS_KEY_ID, + } def test_data_processing_config(): actual_config = _TransformJob._prepare_data_processing("$", None, None) - assert actual_config == {'InputFilter': "$"} + assert actual_config == {"InputFilter": "$"} actual_config = _TransformJob._prepare_data_processing(None, "$", None) - assert actual_config == {'OutputFilter': "$"} + assert actual_config == {"OutputFilter": "$"} actual_config = _TransformJob._prepare_data_processing(None, None, "Input") - assert actual_config == {'JoinSource': "Input"} + assert actual_config == {"JoinSource": "Input"} actual_config = _TransformJob._prepare_data_processing("$[0]", "$[1]", "Input") - assert actual_config == {'InputFilter': "$[0]", 'OutputFilter': "$[1]", 'JoinSource': "Input"} + assert actual_config == {"InputFilter": "$[0]", "OutputFilter": "$[1]", "JoinSource": "Input"} actual_config = _TransformJob._prepare_data_processing(None, None, None) assert actual_config is None diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index df931451d0..199ccc4624 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -23,128 +23,127 @@ from sagemaker.amazon.pca import PCA from sagemaker.estimator import Estimator from sagemaker.mxnet import MXNet -from sagemaker.parameter import (CategoricalParameter, ContinuousParameter, - IntegerParameter, ParameterRange) -from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner, - create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig, - WarmStartTypes) +from sagemaker.parameter import ( + CategoricalParameter, + ContinuousParameter, + IntegerParameter, + ParameterRange, +) +from sagemaker.tuner import ( + _TuningJob, + create_identical_dataset_and_algorithm_tuner, + create_transfer_learning_tuner, + HyperparameterTuner, + WarmStartConfig, + WarmStartTypes, +) from sagemaker.session import s3_input -DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") MODEL_DATA = "s3://bucket/model.tar.gz" -JOB_NAME = 'tuning_job' -REGION = 'us-west-2' -BUCKET_NAME = 'Some-Bucket' -ROLE = 'myrole' -IMAGE_NAME = 'image' +JOB_NAME = "tuning_job" +REGION = "us-west-2" +BUCKET_NAME = "Some-Bucket" +ROLE = "myrole" +IMAGE_NAME = "image" TRAIN_INSTANCE_COUNT = 1 -TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +TRAIN_INSTANCE_TYPE = "ml.c4.xlarge" NUM_COMPONENTS = 5 -SCRIPT_NAME = 'my_script.py' -FRAMEWORK_VERSION = '1.0.0' +SCRIPT_NAME = "my_script.py" +FRAMEWORK_VERSION = "1.0.0" -INPUTS = 's3://mybucket/train' -OBJECTIVE_METRIC_NAME = 'mock_metric' 
-HYPERPARAMETER_RANGES = {'validated': ContinuousParameter(0, 5), - 'elizabeth': IntegerParameter(0, 5), - 'blank': CategoricalParameter([0, 5])} -METRIC_DEFINITIONS = 'mock_metric_definitions' +INPUTS = "s3://mybucket/train" +OBJECTIVE_METRIC_NAME = "mock_metric" +HYPERPARAMETER_RANGES = { + "validated": ContinuousParameter(0, 5), + "elizabeth": IntegerParameter(0, 5), + "blank": CategoricalParameter([0, 5]), +} +METRIC_DEFINITIONS = "mock_metric_definitions" TUNING_JOB_DETAILS = { - 'HyperParameterTuningJobConfig': { - 'ResourceLimits': { - 'MaxParallelTrainingJobs': 1, - 'MaxNumberOfTrainingJobs': 1 - }, - 'HyperParameterTuningJobObjective': { - 'MetricName': OBJECTIVE_METRIC_NAME, - 'Type': 'Minimize' + "HyperParameterTuningJobConfig": { + "ResourceLimits": {"MaxParallelTrainingJobs": 1, "MaxNumberOfTrainingJobs": 1}, + "HyperParameterTuningJobObjective": { + "MetricName": OBJECTIVE_METRIC_NAME, + "Type": "Minimize", }, - 'Strategy': 'Bayesian', - 'ParameterRanges': { - 'CategoricalParameterRanges': [], - 'ContinuousParameterRanges': [], - 'IntegerParameterRanges': [ + "Strategy": "Bayesian", + "ParameterRanges": { + "CategoricalParameterRanges": [], + "ContinuousParameterRanges": [], + "IntegerParameterRanges": [ { - 'MaxValue': '100', - 'Name': 'mini_batch_size', - 'MinValue': '10', - 'ScalingType': 'Auto' - }, - ] + "MaxValue": "100", + "Name": "mini_batch_size", + "MinValue": "10", + "ScalingType": "Auto", + } + ], }, - 'TrainingJobEarlyStoppingType': 'Off' + "TrainingJobEarlyStoppingType": "Off", }, - 'HyperParameterTuningJobName': JOB_NAME, - 'TrainingJobDefinition': { - 'RoleArn': ROLE, - 'StaticHyperParameters': { - 'num_components': '1', - '_tuning_objective_metric': 'train:throughput', - 'feature_dim': '784', - 'sagemaker_estimator_module': '"sagemaker.amazon.pca"', - 'sagemaker_estimator_class_name': '"PCA"', + "HyperParameterTuningJobName": JOB_NAME, + "TrainingJobDefinition": { + "RoleArn": ROLE, + "StaticHyperParameters": { + "num_components": "1", + "_tuning_objective_metric": "train:throughput", + "feature_dim": "784", + "sagemaker_estimator_module": '"sagemaker.amazon.pca"', + "sagemaker_estimator_class_name": '"PCA"', }, - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceType': 'ml.c4.xlarge', - 'InstanceCount': 1 + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceType": "ml.c4.xlarge", + "InstanceCount": 1, }, - 'AlgorithmSpecification': { - 'TrainingImage': IMAGE_NAME, - 'TrainingInputMode': 'File', - 'MetricDefinitions': METRIC_DEFINITIONS, + "AlgorithmSpecification": { + "TrainingImage": IMAGE_NAME, + "TrainingInputMode": "File", + "MetricDefinitions": METRIC_DEFINITIONS, }, - 'InputDataConfig': [ + "InputDataConfig": [ { - 'ChannelName': 'train', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'ShardedByS3Key', - 'S3Uri': INPUTS, - 'S3DataType': 'ManifestFile' + "ChannelName": "train", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "ShardedByS3Key", + "S3Uri": INPUTS, + "S3DataType": "ManifestFile", } - } + }, } ], - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 86400 - }, - 'OutputDataConfig': { - 'S3OutputPath': BUCKET_NAME, - } + "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "OutputDataConfig": {"S3OutputPath": BUCKET_NAME}, }, - 'TrainingJobCounters': { - 'ClientError': 0, - 'Completed': 1, - 'InProgress': 0, - 'Fault': 0, - 'Stopped': 0 + "TrainingJobCounters": { + "ClientError": 0, + "Completed": 1, + "InProgress": 0, + "Fault": 0, + "Stopped": 0, }, - 'HyperParameterTuningEndTime': 1526605831.0, - 
'CreationTime': 1526605605.0, - 'HyperParameterTuningJobArn': 'arn:tuning_job', + "HyperParameterTuningEndTime": 1526605831.0, + "CreationTime": 1526605605.0, + "HyperParameterTuningJobArn": "arn:tuning_job", } -ENDPOINT_DESC = { - 'EndpointConfigName': 'test-endpoint' -} +ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"} -ENDPOINT_CONFIG_DESC = { - 'ProductionVariants': [{'ModelName': 'model-1'}, - {'ModelName': 'model-2'}] -} +ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session', region_name=REGION) - sms = Mock(name='sagemaker_session', boto_session=boto_mock) + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock(name="sagemaker_session", boto_session=boto_mock) sms.boto_region_name = REGION - sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.config = None sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) @@ -155,90 +154,131 @@ def sagemaker_session(): @pytest.fixture() def estimator(sagemaker_session): - return Estimator(IMAGE_NAME, ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, output_path='s3://bucket/prefix', - sagemaker_session=sagemaker_session) + return Estimator( + IMAGE_NAME, + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + output_path="s3://bucket/prefix", + sagemaker_session=sagemaker_session, + ) @pytest.fixture() def tuner(estimator): - return HyperparameterTuner(estimator, OBJECTIVE_METRIC_NAME, HYPERPARAMETER_RANGES, METRIC_DEFINITIONS) + return HyperparameterTuner( + estimator, OBJECTIVE_METRIC_NAME, HYPERPARAMETER_RANGES, METRIC_DEFINITIONS + ) def test_prepare_for_training(tuner): - static_hyperparameters = {'validated': 1, 'another_one': 0} + static_hyperparameters = {"validated": 1, "another_one": 0} tuner.estimator.set_hyperparameters(**static_hyperparameters) tuner._prepare_for_training() assert tuner._current_job_name.startswith(IMAGE_NAME) assert len(tuner.static_hyperparameters) == 1 - assert tuner.static_hyperparameters['another_one'] == '0' + assert tuner.static_hyperparameters["another_one"] == "0" def test_prepare_for_training_with_amazon_estimator(tuner, sagemaker_session): - tuner.estimator = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - sagemaker_session=sagemaker_session) + tuner.estimator = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + sagemaker_session=sagemaker_session, + ) tuner._prepare_for_training() - assert 'sagemaker_estimator_class_name' not in tuner.static_hyperparameters - assert 'sagemaker_estimator_module' not in tuner.static_hyperparameters + assert "sagemaker_estimator_class_name" not in tuner.static_hyperparameters + assert "sagemaker_estimator_module" not in tuner.static_hyperparameters def test_prepare_for_training_include_estimator_cls(tuner): tuner._prepare_for_training(include_cls_metadata=True) - assert 'sagemaker_estimator_class_name' in tuner.static_hyperparameters - assert 'sagemaker_estimator_module' in tuner.static_hyperparameters + assert "sagemaker_estimator_class_name" in tuner.static_hyperparameters + assert "sagemaker_estimator_module" in tuner.static_hyperparameters def test_prepare_for_training_with_job_name(tuner): - static_hyperparameters = {'validated': 1, 'another_one': 0} + static_hyperparameters = {"validated": 1, "another_one": 0} 
tuner.estimator.set_hyperparameters(**static_hyperparameters) - tuner._prepare_for_training(job_name='some-other-job-name') - assert tuner._current_job_name == 'some-other-job-name' + tuner._prepare_for_training(job_name="some-other-job-name") + assert tuner._current_job_name == "some-other-job-name" def test_validate_parameter_ranges_number_validation_error(sagemaker_session): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session) + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + ) - invalid_hyperparameter_ranges = {'num_components': IntegerParameter(-1, 2)} + invalid_hyperparameter_ranges = {"num_components": IntegerParameter(-1, 2)} with pytest.raises(ValueError) as e: - HyperparameterTuner(estimator=pca, objective_metric_name=OBJECTIVE_METRIC_NAME, - hyperparameter_ranges=invalid_hyperparameter_ranges, metric_definitions=METRIC_DEFINITIONS) + HyperparameterTuner( + estimator=pca, + objective_metric_name=OBJECTIVE_METRIC_NAME, + hyperparameter_ranges=invalid_hyperparameter_ranges, + metric_definitions=METRIC_DEFINITIONS, + ) - assert 'Value must be an integer greater than zero' in str(e) + assert "Value must be an integer greater than zero" in str(e) def test_validate_parameter_ranges_string_value_validation_error(sagemaker_session): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session) + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + ) - invalid_hyperparameter_ranges = {'algorithm_mode': CategoricalParameter([0, 5])} + invalid_hyperparameter_ranges = {"algorithm_mode": CategoricalParameter([0, 5])} with pytest.raises(ValueError) as e: - HyperparameterTuner(estimator=pca, objective_metric_name=OBJECTIVE_METRIC_NAME, - hyperparameter_ranges=invalid_hyperparameter_ranges, metric_definitions=METRIC_DEFINITIONS) + HyperparameterTuner( + estimator=pca, + objective_metric_name=OBJECTIVE_METRIC_NAME, + hyperparameter_ranges=invalid_hyperparameter_ranges, + metric_definitions=METRIC_DEFINITIONS, + ) assert 'Value must be one of "regular" and "randomized"' in str(e) def test_fit_pca(sagemaker_session, tuner): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session) - - pca.algorithm_mode = 'randomized' + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + ) + + pca.algorithm_mode = "randomized" pca.subtract_mean = True pca.extra_components = 5 tuner.estimator = pca - tags = [{'Name': 'some-tag-without-a-value'}] + tags = [{"Name": "some-tag-without-a-value"}] tuner.tags = tags - hyperparameter_ranges = {'num_components': IntegerParameter(2, 4), - 'algorithm_mode': CategoricalParameter(['regular', 'randomized'])} + hyperparameter_ranges = { + "num_components": IntegerParameter(2, 4), + "algorithm_mode": CategoricalParameter(["regular", "randomized"]), + } tuner._hyperparameter_ranges = hyperparameter_ranges records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1) @@ -246,76 +286,99 @@ def test_fit_pca(sagemaker_session, tuner): _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] - assert len(tune_kwargs['static_hyperparameters']) == 4 - assert 
tune_kwargs['static_hyperparameters']['extra_components'] == '5' - assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1 - assert tune_kwargs['job_name'].startswith('pca') - assert tune_kwargs['tags'] == tags - assert tune_kwargs['early_stopping_type'] == 'Off' + assert len(tune_kwargs["static_hyperparameters"]) == 4 + assert tune_kwargs["static_hyperparameters"]["extra_components"] == "5" + assert len(tune_kwargs["parameter_ranges"]["IntegerParameterRanges"]) == 1 + assert tune_kwargs["job_name"].startswith("pca") + assert tune_kwargs["tags"] == tags + assert tune_kwargs["early_stopping_type"] == "Off" assert tuner.estimator.mini_batch_size == 9999 def test_fit_pca_with_early_stopping(sagemaker_session, tuner): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session) + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + ) tuner.estimator = pca - tuner.early_stopping_type = 'Auto' + tuner.early_stopping_type = "Auto" records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1) tuner.fit(records, mini_batch_size=9999) _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] - assert tune_kwargs['job_name'].startswith('pca') - assert tune_kwargs['early_stopping_type'] == 'Auto' + assert tune_kwargs["job_name"].startswith("pca") + assert tune_kwargs["early_stopping_type"] == "Auto" def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner): - subnets = ['foo'] - security_group_ids = ['bar'] - - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session, - subnets=subnets, security_group_ids=security_group_ids) + subnets = ["foo"] + security_group_ids = ["bar"] + + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + subnets=subnets, + security_group_ids=security_group_ids, + ) tuner.estimator = pca records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1) tuner.fit(records, mini_batch_size=9999) _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] - assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids} + assert tune_kwargs["vpc_config"] == {"Subnets": subnets, "SecurityGroupIds": security_group_ids} def test_s3_input_mode(sagemaker_session, tuner): - expected_input_mode = 'Pipe' - - script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py') - mxnet = MXNet(entry_point=script_path, - role=ROLE, - framework_version=FRAMEWORK_VERSION, - train_instance_count=TRAIN_INSTANCE_COUNT, - train_instance_type=TRAIN_INSTANCE_TYPE, - sagemaker_session=sagemaker_session) + expected_input_mode = "Pipe" + + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") + mxnet = MXNet( + entry_point=script_path, + role=ROLE, + framework_version=FRAMEWORK_VERSION, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) tuner.estimator = mxnet - tags = [{'Name': 'some-tag-without-a-value'}] + tags = [{"Name": "some-tag-without-a-value"}] tuner.tags = tags - hyperparameter_ranges = {'num_components': IntegerParameter(2, 4), - 'algorithm_mode': CategoricalParameter(['regular', 'randomized'])} + hyperparameter_ranges = { + "num_components": IntegerParameter(2, 4), + "algorithm_mode": 
CategoricalParameter(["regular", "randomized"]), + } tuner._hyperparameter_ranges = hyperparameter_ranges - tuner.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode)) + tuner.fit(inputs=s3_input("s3://mybucket/train_manifest", input_mode=expected_input_mode)) - actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode'] + actual_input_mode = sagemaker_session.method_calls[1][2]["input_mode"] assert actual_input_mode == expected_input_mode def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner): - pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, - base_job_name='pca', sagemaker_session=sagemaker_session, - encrypt_inter_container_traffic=True) + pca = PCA( + ROLE, + TRAIN_INSTANCE_COUNT, + TRAIN_INSTANCE_TYPE, + NUM_COMPONENTS, + base_job_name="pca", + sagemaker_session=sagemaker_session, + encrypt_inter_container_traffic=True, + ) tuner.estimator = pca @@ -324,14 +387,15 @@ def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] - assert tune_kwargs['job_name'].startswith('pca') - assert tune_kwargs['encrypt_inter_container_traffic'] is True + assert tune_kwargs["job_name"].startswith("pca") + assert tune_kwargs["encrypt_inter_container_traffic"] is True def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) assert tuner.latest_tuning_job.name == JOB_NAME @@ -339,46 +403,54 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session assert tuner.max_jobs == 1 assert tuner.max_parallel_jobs == 1 assert tuner.metric_definitions == METRIC_DEFINITIONS - assert tuner.strategy == 'Bayesian' - assert tuner.objective_type == 'Minimize' - assert tuner.early_stopping_type == 'Off' + assert tuner.strategy == "Bayesian" + assert tuner.objective_type == "Minimize" + assert tuner.early_stopping_type == "Off" assert isinstance(tuner.estimator, PCA) assert tuner.estimator.role == ROLE assert tuner.estimator.train_instance_count == 1 assert tuner.estimator.train_max_run == 24 * 60 * 60 - assert tuner.estimator.input_mode == 'File' + assert tuner.estimator.input_mode == "File" assert tuner.estimator.output_path == BUCKET_NAME - assert tuner.estimator.output_kms_key == '' + assert tuner.estimator.output_kms_key == "" - assert '_tuning_objective_metric' not in tuner.estimator.hyperparameters() - assert tuner.estimator.hyperparameters()['num_components'] == '1' + assert "_tuning_objective_metric" not in tuner.estimator.hyperparameters() + assert tuner.estimator.hyperparameters()["num_components"] == "1" -def test_attach_tuning_job_with_estimator_from_hyperparameters_with_early_stopping(sagemaker_session): +def test_attach_tuning_job_with_estimator_from_hyperparameters_with_early_stopping( + sagemaker_session +): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - job_details['HyperParameterTuningJobConfig']['TrainingJobEarlyStoppingType'] = 'Auto' - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + 
job_details["HyperParameterTuningJobConfig"]["TrainingJobEarlyStoppingType"] = "Auto" + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) assert tuner.latest_tuning_job.name == JOB_NAME - assert tuner.early_stopping_type == 'Auto' + assert tuner.early_stopping_type == "Auto" assert isinstance(tuner.estimator, PCA) def test_attach_tuning_job_with_job_details(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session, job_details=job_details) + HyperparameterTuner.attach( + JOB_NAME, sagemaker_session=sagemaker_session, job_details=job_details + ) sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_not_called def test_attach_tuning_job_with_estimator_from_image(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - job_details['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = '1111.amazonaws.com/pca:1' - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + job_details["TrainingJobDefinition"]["AlgorithmSpecification"][ + "TrainingImage" + ] = "1111.amazonaws.com/pca:1" + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) assert isinstance(tuner.estimator, PCA) @@ -386,31 +458,39 @@ def test_attach_tuning_job_with_estimator_from_image(sagemaker_session): def test_attach_tuning_job_with_estimator_from_kwarg(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) - tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session, - estimator_cls='sagemaker.estimator.Estimator') + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) + tuner = HyperparameterTuner.attach( + JOB_NAME, sagemaker_session=sagemaker_session, estimator_cls="sagemaker.estimator.Estimator" + ) assert isinstance(tuner.estimator, Estimator) def test_attach_with_no_specified_estimator(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - del job_details['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_estimator_module'] - del job_details['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_estimator_class_name'] - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + del job_details["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_estimator_module"] + del job_details["TrainingJobDefinition"]["StaticHyperParameters"][ + "sagemaker_estimator_class_name" + ] + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) assert isinstance(tuner.estimator, Estimator) def test_attach_with_warm_start_config(sagemaker_session): - warm_start_config = WarmStartConfig(warm_start_type=WarmStartTypes.TRANSFER_LEARNING, parents={"p1", "p2"}) + 
warm_start_config = WarmStartConfig( + warm_start_type=WarmStartTypes.TRANSFER_LEARNING, parents={"p1", "p2"} + ) job_details = copy.deepcopy(TUNING_JOB_DETAILS) job_details["WarmStartConfig"] = warm_start_config.to_input_req() - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) assert tuner.warm_start_config.type == warm_start_config.type @@ -421,44 +501,46 @@ def test_serialize_parameter_ranges(tuner): hyperparameter_ranges = tuner.hyperparameter_ranges() for key, value in HYPERPARAMETER_RANGES.items(): - assert hyperparameter_ranges[value.__name__ + 'ParameterRanges'][0]['Name'] == key + assert hyperparameter_ranges[value.__name__ + "ParameterRanges"][0]["Name"] == key def test_analytics(tuner): - tuner.latest_tuning_job = _TuningJob(tuner.sagemaker_session, 'testjob') + tuner.latest_tuning_job = _TuningJob(tuner.sagemaker_session, "testjob") tuner_analytics = tuner.analytics() assert tuner_analytics is not None - assert tuner_analytics.name.find('testjob') > -1 + assert tuner_analytics.name.find("testjob") > -1 def test_serialize_categorical_ranges_for_frameworks(sagemaker_session, tuner): - tuner.estimator = MXNet(entry_point=SCRIPT_NAME, - role=ROLE, - framework_version=FRAMEWORK_VERSION, - train_instance_count=TRAIN_INSTANCE_COUNT, - train_instance_type=TRAIN_INSTANCE_TYPE, - sagemaker_session=sagemaker_session) + tuner.estimator = MXNet( + entry_point=SCRIPT_NAME, + role=ROLE, + framework_version=FRAMEWORK_VERSION, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) hyperparameter_ranges = tuner.hyperparameter_ranges() - assert hyperparameter_ranges['CategoricalParameterRanges'][0]['Name'] == 'blank' - assert hyperparameter_ranges['CategoricalParameterRanges'][0]['Values'] == ['"0"', '"5"'] + assert hyperparameter_ranges["CategoricalParameterRanges"][0]["Name"] == "blank" + assert hyperparameter_ranges["CategoricalParameterRanges"][0]["Values"] == ['"0"', '"5"'] def test_serialize_nonexistent_parameter_ranges(tuner): temp_hyperparameter_ranges = HYPERPARAMETER_RANGES.copy() - parameter_type = temp_hyperparameter_ranges['validated'].__name__ + parameter_type = temp_hyperparameter_ranges["validated"].__name__ - temp_hyperparameter_ranges['validated'] = None + temp_hyperparameter_ranges["validated"] = None tuner._hyperparameter_ranges = temp_hyperparameter_ranges ranges = tuner.hyperparameter_ranges() assert len(ranges.keys()) == 3 - assert not ranges[parameter_type + 'ParameterRanges'] + assert not ranges[parameter_type + "ParameterRanges"] def test_stop_tuning_job(sagemaker_session, tuner): - sagemaker_session.stop_tuning_job = Mock(name='stop_hyper_parameter_tuning_job') + sagemaker_session.stop_tuning_job = Mock(name="stop_hyper_parameter_tuning_job") tuner.latest_tuning_job = _TuningJob(sagemaker_session, JOB_NAME) tuner.stop_tuning_job() @@ -469,35 +551,38 @@ def test_stop_tuning_job(sagemaker_session, tuner): def test_stop_tuning_job_no_tuning_job(tuner): with pytest.raises(ValueError) as e: tuner.stop_tuning_job() - assert 'No tuning job available' in str(e) + assert "No tuning job available" in str(e) def test_best_tuning_job(tuner): - tuning_job_description = {'BestTrainingJob': 
{'TrainingJobName': JOB_NAME}} + tuning_job_description = {"BestTrainingJob": {"TrainingJobName": JOB_NAME}} tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description) + name="describe_hyper_parameter_tuning_job", return_value=tuning_job_description + ) tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) best_training_job = tuner.best_training_job() assert best_training_job == JOB_NAME tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_once_with( - HyperParameterTuningJobName=JOB_NAME) + HyperParameterTuningJobName=JOB_NAME + ) def test_best_tuning_job_no_latest_job(tuner): with pytest.raises(Exception) as e: tuner.best_training_job() - assert 'No tuning job available' in str(e) + assert "No tuning job available" in str(e) def test_best_tuning_job_no_best_job(tuner): - tuning_job_description = {'BestTrainingJob': {'Mock': None}} + tuning_job_description = {"BestTrainingJob": {"Mock": None}} tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description) + name="describe_hyper_parameter_tuning_job", return_value=tuning_job_description + ) tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) @@ -505,61 +590,55 @@ def test_best_tuning_job_no_best_job(tuner): tuner.best_training_job() tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_once_with( - HyperParameterTuningJobName=JOB_NAME) - assert 'Best training job not available for tuning job:' in str(e) + HyperParameterTuningJobName=JOB_NAME + ) + assert "Best training job not available for tuning job:" in str(e) def test_deploy_default(tuner): returned_training_job_description = { - 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'TrainingImage': IMAGE_NAME, - 'MetricDefinitions': METRIC_DEFINITIONS, + "AlgorithmSpecification": { + "TrainingInputMode": "File", + "TrainingImage": IMAGE_NAME, + "MetricDefinitions": METRIC_DEFINITIONS, }, - 'HyperParameters': { - 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', - 'checkpoint_path': '"s3://other/1508872349"', - 'sagemaker_program': '"iris-dnn-classifier.py"', - 'sagemaker_enable_cloudwatch_metrics': 'false', - 'sagemaker_container_log_level': '"logging.INFO"', - 'sagemaker_job_name': '"neo"', - 'training_steps': '100', - '_tuning_objective_metric': 'Validation-accuracy', + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "checkpoint_path": '"s3://other/1508872349"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_enable_cloudwatch_metrics": "false", + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"neo"', + "training_steps": "100", + "_tuning_objective_metric": "Validation-accuracy", }, - - 'RoleArn': ROLE, - 'ResourceConfig': { - 'VolumeSizeInGB': 30, - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge' + "RoleArn": ROLE, + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", }, - 'StoppingCondition': { - 'MaxRuntimeInSeconds': 24 * 60 * 60 - }, - 'TrainingJobName': 'neo', - 'TrainingJobStatus': 'Completed', - 'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo', - 'OutputDataConfig': { - 'KmsKeyId': '', - 'S3OutputPath': 
's3://place/output/neo' - }, - 'TrainingJobOutput': { - 'S3TrainingJobOutput': 's3://here/output.tar.gz' - }, - 'ModelArtifacts': { - 'S3ModelArtifacts': MODEL_DATA - } + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "neo", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + "ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}, } - tuning_job_description = {'BestTrainingJob': {'TrainingJobName': JOB_NAME}} - returned_list_tags = {'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]} - - tuner.estimator.sagemaker_session.sagemaker_client.describe_training_job = \ - Mock(name='describe_training_job', return_value=returned_training_job_description) - tuner.estimator.sagemaker_session.sagemaker_client.list_tags = \ - Mock(name='list_tags', return_value=returned_list_tags) + tuning_job_description = {"BestTrainingJob": {"TrainingJobName": JOB_NAME}} + returned_list_tags = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} + + tuner.estimator.sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_training_job_description + ) + tuner.estimator.sagemaker_session.sagemaker_client.list_tags = Mock( + name="list_tags", return_value=returned_list_tags + ) tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description) - tuner.estimator.sagemaker_session.log_for_jobs = Mock(name='log_for_jobs') + name="describe_hyper_parameter_tuning_job", return_value=tuning_job_description + ) + tuner.estimator.sagemaker_session.log_for_jobs = Mock(name="log_for_jobs") tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) predictor = tuner.deploy(TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE) @@ -568,8 +647,8 @@ def test_deploy_default(tuner): args = tuner.estimator.sagemaker_session.create_model.call_args[0] assert args[0].startswith(IMAGE_NAME) assert args[1] == ROLE - assert args[2]['Image'] == IMAGE_NAME - assert args[2]['ModelDataUrl'] == MODEL_DATA + assert args[2]["Image"] == IMAGE_NAME + assert args[2]["ModelDataUrl"] == MODEL_DATA assert isinstance(predictor, RealTimePredictor) assert predictor.endpoint.startswith(JOB_NAME) @@ -578,7 +657,7 @@ def test_deploy_default(tuner): def test_wait(tuner): tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) - tuner.estimator.sagemaker_session.wait_for_tuning_job = Mock(name='wait_for_tuning_job') + tuner.estimator.sagemaker_session.wait_for_tuning_job = Mock(name="wait_for_tuning_job") tuner.wait() @@ -588,34 +667,38 @@ def test_wait(tuner): def test_delete_endpoint(tuner): tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) - tuning_job_description = {'BestTrainingJob': {'TrainingJobName': JOB_NAME}} + tuning_job_description = {"BestTrainingJob": {"TrainingJobName": JOB_NAME}} tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description) + name="describe_hyper_parameter_tuning_job", return_value=tuning_job_description + ) tuner.delete_endpoint() tuner.sagemaker_session.delete_endpoint.assert_called_with(JOB_NAME) def test_fit_no_inputs(tuner, 
sagemaker_session): - script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py') - tuner.estimator = MXNet(entry_point=script_path, - role=ROLE, - framework_version=FRAMEWORK_VERSION, - train_instance_count=TRAIN_INSTANCE_COUNT, - train_instance_type=TRAIN_INSTANCE_TYPE, - sagemaker_session=sagemaker_session) + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py") + tuner.estimator = MXNet( + entry_point=script_path, + role=ROLE, + framework_version=FRAMEWORK_VERSION, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) tuner.fit() _, _, tune_kwargs = sagemaker_session.tune.mock_calls[0] - assert tune_kwargs['input_config'] is None + assert tune_kwargs["input_config"] is None def test_identical_dataset_and_algorithm_tuner(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) parent_tuner = tuner.identical_dataset_and_algorithm_tuner(additional_parents={"p1", "p2"}) @@ -625,11 +708,14 @@ def test_identical_dataset_and_algorithm_tuner(sagemaker_session): def test_transfer_learning_tuner_with_estimator(sagemaker_session, estimator): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) - parent_tuner = tuner.transfer_learning_tuner(additional_parents={"p1", "p2"}, estimator=estimator) + parent_tuner = tuner.transfer_learning_tuner( + additional_parents={"p1", "p2"}, estimator=estimator + ) assert parent_tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING assert parent_tuner.warm_start_config.parents == {tuner.latest_tuning_job.name, "p1", "p2"} @@ -638,8 +724,9 @@ def test_transfer_learning_tuner_with_estimator(sagemaker_session, estimator): def test_transfer_learning_tuner(sagemaker_session): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session) parent_tuner = tuner.transfer_learning_tuner(additional_parents={"p1", "p2"}) @@ -656,77 +743,78 @@ def test_transfer_learning_tuner(sagemaker_session): def test_continuous_parameter(): cont_param = ContinuousParameter(0.1, 1e-2) assert isinstance(cont_param, ParameterRange) - assert cont_param.__name__ == 'Continuous' + assert cont_param.__name__ == "Continuous" def test_continuous_parameter_ranges(): cont_param = ContinuousParameter(0.1, 1e-2) - ranges = cont_param.as_tuning_range('some') + ranges = cont_param.as_tuning_range("some") assert len(ranges.keys()) == 4 - assert ranges['Name'] == 'some' - assert ranges['MinValue'] == '0.1' - assert ranges['MaxValue'] == '0.01' 
- assert ranges['ScalingType'] == 'Auto' + assert ranges["Name"] == "some" + assert ranges["MinValue"] == "0.1" + assert ranges["MaxValue"] == "0.01" + assert ranges["ScalingType"] == "Auto" def test_continuous_parameter_scaling_type(): - cont_param = ContinuousParameter(0.1, 2, scaling_type='ReverseLogarithmic') - cont_range = cont_param.as_tuning_range('range') - assert cont_range['ScalingType'] == 'ReverseLogarithmic' + cont_param = ContinuousParameter(0.1, 2, scaling_type="ReverseLogarithmic") + cont_range = cont_param.as_tuning_range("range") + assert cont_range["ScalingType"] == "ReverseLogarithmic" def test_integer_parameter(): int_param = IntegerParameter(1, 2) assert isinstance(int_param, ParameterRange) - assert int_param.__name__ == 'Integer' + assert int_param.__name__ == "Integer" def test_integer_parameter_ranges(): int_param = IntegerParameter(1, 2) - ranges = int_param.as_tuning_range('some') + ranges = int_param.as_tuning_range("some") assert len(ranges.keys()) == 4 - assert ranges['Name'] == 'some' - assert ranges['MinValue'] == '1' - assert ranges['MaxValue'] == '2' - assert ranges['ScalingType'] == 'Auto' + assert ranges["Name"] == "some" + assert ranges["MinValue"] == "1" + assert ranges["MaxValue"] == "2" + assert ranges["ScalingType"] == "Auto" def test_integer_parameter_scaling_type(): - int_param = IntegerParameter(2, 3, scaling_type='Linear') - int_range = int_param.as_tuning_range('range') - assert int_range['ScalingType'] == 'Linear' + int_param = IntegerParameter(2, 3, scaling_type="Linear") + int_range = int_param.as_tuning_range("range") + assert int_range["ScalingType"] == "Linear" def test_categorical_parameter_list(): - cat_param = CategoricalParameter(['a', 'z']) + cat_param = CategoricalParameter(["a", "z"]) assert isinstance(cat_param, ParameterRange) - assert cat_param.__name__ == 'Categorical' + assert cat_param.__name__ == "Categorical" def test_categorical_parameter_list_ranges(): cat_param = CategoricalParameter([1, 10]) - ranges = cat_param.as_tuning_range('some') + ranges = cat_param.as_tuning_range("some") assert len(ranges.keys()) == 2 - assert ranges['Name'] == 'some' - assert ranges['Values'] == ['1', '10'] + assert ranges["Name"] == "some" + assert ranges["Values"] == ["1", "10"] def test_categorical_parameter_value(): - cat_param = CategoricalParameter('a') + cat_param = CategoricalParameter("a") assert isinstance(cat_param, ParameterRange) def test_categorical_parameter_value_ranges(): - cat_param = CategoricalParameter('a') - ranges = cat_param.as_tuning_range('some') + cat_param = CategoricalParameter("a") + ranges = cat_param.as_tuning_range("some") assert len(ranges.keys()) == 2 - assert ranges['Name'] == 'some' - assert ranges['Values'] == ['a'] + assert ranges["Name"] == "some" + assert ranges["Values"] == ["a"] ################################################################################# # _TuningJob Tests + def test_start_new(tuner, sagemaker_session): tuning_job = _TuningJob(sagemaker_session, JOB_NAME) @@ -745,7 +833,7 @@ def test_stop(sagemaker_session): def test_tuning_job_wait(sagemaker_session): - sagemaker_session.wait_for_tuning_job = Mock(name='wait_for_tuning_job') + sagemaker_session.wait_for_tuning_job = Mock(name="wait_for_tuning_job") tuning_job = _TuningJob(sagemaker_session, JOB_NAME) tuning_job.wait() @@ -756,77 +844,104 @@ def test_tuning_job_wait(sagemaker_session): ################################################################################# # WarmStartConfig Tests -@pytest.mark.parametrize('type, 
parents', [ - (WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, {"p1", "p2", "p3"}), - (WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, {"p1", "p3", "p3"}), - (WarmStartTypes.TRANSFER_LEARNING, {"p3"}), -]) + +@pytest.mark.parametrize( + "type, parents", + [ + (WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, {"p1", "p2", "p3"}), + (WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, {"p1", "p3", "p3"}), + (WarmStartTypes.TRANSFER_LEARNING, {"p3"}), + ], +) def test_warm_start_config_init(type, parents): warm_start_config = WarmStartConfig(warm_start_type=type, parents=parents) assert warm_start_config.type == type, "Warm start type initialization failed." - assert warm_start_config.parents == set(parents), "Warm start parents config initialization failed." + assert warm_start_config.parents == set( + parents + ), "Warm start parents config initialization failed." warm_start_config_req = warm_start_config.to_input_req() assert warm_start_config.type == WarmStartTypes(warm_start_config_req["WarmStartType"]) for parent in warm_start_config_req["ParentHyperParameterTuningJobs"]: - assert parent['HyperParameterTuningJobName'] in parents - - -@pytest.mark.parametrize('type, parents', [ - ("InvalidType", {"p1", "p2", "p3"}), - (None, {"p1", "p2", "p3"}), - ("", {"p1", "p2", "p3"}), - (WarmStartTypes.TRANSFER_LEARNING, None), - (WarmStartTypes.TRANSFER_LEARNING, {}), -]) + assert parent["HyperParameterTuningJobName"] in parents + + +@pytest.mark.parametrize( + "type, parents", + [ + ("InvalidType", {"p1", "p2", "p3"}), + (None, {"p1", "p2", "p3"}), + ("", {"p1", "p2", "p3"}), + (WarmStartTypes.TRANSFER_LEARNING, None), + (WarmStartTypes.TRANSFER_LEARNING, {}), + ], +) def test_warm_start_config_init_negative(type, parents): with pytest.raises(ValueError): WarmStartConfig(warm_start_type=type, parents=parents) -@pytest.mark.parametrize('warm_start_config_req', [ - ({}), - (None), - ({'WarmStartType': 'TransferLearning'}), - ({'ParentHyperParameterTuningJobs': []}), -]) +@pytest.mark.parametrize( + "warm_start_config_req", + [ + ({}), + (None), + ({"WarmStartType": "TransferLearning"}), + ({"ParentHyperParameterTuningJobs": []}), + ], +) def test_prepare_warm_start_config_cls_negative(warm_start_config_req): warm_start_config = WarmStartConfig.from_job_desc(warm_start_config_req) assert warm_start_config is None, "Warm start config should be None for invalid type/parents" -@pytest.mark.parametrize('warm_start_config_req', [ - ({'WarmStartType': 'TransferLearning', 'ParentHyperParameterTuningJobs': [{'HyperParameterTuningJobName': 'p1'}, - {'HyperParameterTuningJobName': 'p2'}]}), - ({'WarmStartType': 'IdenticalDataAndAlgorithm', - 'ParentHyperParameterTuningJobs': [{'HyperParameterTuningJobName': 'p1'}, - {'HyperParameterTuningJobName': 'p1'}]}), -]) +@pytest.mark.parametrize( + "warm_start_config_req", + [ + ( + { + "WarmStartType": "TransferLearning", + "ParentHyperParameterTuningJobs": [ + {"HyperParameterTuningJobName": "p1"}, + {"HyperParameterTuningJobName": "p2"}, + ], + } + ), + ( + { + "WarmStartType": "IdenticalDataAndAlgorithm", + "ParentHyperParameterTuningJobs": [ + {"HyperParameterTuningJobName": "p1"}, + {"HyperParameterTuningJobName": "p1"}, + ], + } + ), + ], +) def test_prepare_warm_start_config_cls(warm_start_config_req): warm_start_config = WarmStartConfig.from_job_desc(warm_start_config_req) assert warm_start_config.type == WarmStartTypes( - warm_start_config_req["WarmStartType"]), "Warm start type initialization failed." 
+ warm_start_config_req["WarmStartType"] + ), "Warm start type initialization failed." for p in warm_start_config_req["ParentHyperParameterTuningJobs"]: - assert p['HyperParameterTuningJobName'] in warm_start_config.parents, \ - "Warm start parents config initialization failed." + assert ( + p["HyperParameterTuningJobName"] in warm_start_config.parents + ), "Warm start parents config initialization failed." -@pytest.mark.parametrize('additional_parents', [ - {"p1", "p2"}, - {}, - None, -]) +@pytest.mark.parametrize("additional_parents", [{"p1", "p2"}, {}, None]) def test_create_identical_dataset_and_algorithm_tuner(sagemaker_session, additional_parents): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) - tuner = create_identical_dataset_and_algorithm_tuner(parent=JOB_NAME, - additional_parents=additional_parents, - sagemaker_session=sagemaker_session) + tuner = create_identical_dataset_and_algorithm_tuner( + parent=JOB_NAME, additional_parents=additional_parents, sagemaker_session=sagemaker_session + ) assert tuner.warm_start_config.type == WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM if additional_parents: @@ -836,20 +951,19 @@ def test_create_identical_dataset_and_algorithm_tuner(sagemaker_session, additio assert tuner.warm_start_config.parents == {JOB_NAME} -@pytest.mark.parametrize('additional_parents', [ - {"p1", "p2"}, - {}, - None, -]) +@pytest.mark.parametrize("additional_parents", [{"p1", "p2"}, {}, None]) def test_create_transfer_learning_tuner(sagemaker_session, estimator, additional_parents): job_details = copy.deepcopy(TUNING_JOB_DETAILS) - sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job', - return_value=job_details) - - tuner = create_transfer_learning_tuner(parent=JOB_NAME, - additional_parents=additional_parents, - sagemaker_session=sagemaker_session, - estimator=estimator) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name="describe_tuning_job", return_value=job_details + ) + + tuner = create_transfer_learning_tuner( + parent=JOB_NAME, + additional_parents=additional_parents, + sagemaker_session=sagemaker_session, + estimator=estimator, + ) assert tuner.warm_start_config.type == WarmStartTypes.TRANSFER_LEARNING assert tuner.estimator == estimator diff --git a/tests/unit/test_upload_data.py b/tests/unit/test_upload_data.py index 9a4d9e5aeb..6b731cc25b 100644 --- a/tests/unit/test_upload_data.py +++ b/tests/unit/test_upload_data.py @@ -20,64 +20,80 @@ import sagemaker from tests.unit import DATA_DIR -UPLOAD_DATA_TESTS_FILES_DIR = os.path.join(DATA_DIR, 'upload_data_tests') -SINGLE_FILE_NAME = 'file1.py' +UPLOAD_DATA_TESTS_FILES_DIR = os.path.join(DATA_DIR, "upload_data_tests") +SINGLE_FILE_NAME = "file1.py" UPLOAD_DATA_TESTS_SINGLE_FILE = os.path.join(UPLOAD_DATA_TESTS_FILES_DIR, SINGLE_FILE_NAME) -BUCKET_NAME = 'mybucket' -AES_ENCRYPTION_ENABLED = {'ServerSideEncryption': 'AES256'} +BUCKET_NAME = "mybucket" +AES_ENCRYPTION_ENABLED = {"ServerSideEncryption": "AES256"} @pytest.fixture() def sagemaker_session(): - boto_mock = Mock(name='boto_session') + boto_mock = Mock(name="boto_session") ims = sagemaker.Session(boto_session=boto_mock) - ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + 
ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) return ims def test_upload_data_absolute_dir(sagemaker_session): result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR) - uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls - if name == 'resource().Object().upload_file'] - assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME) + uploaded_files_with_args = [ + (args[0], kwargs) + for name, args, kwargs in sagemaker_session.boto_session.mock_calls + if name == "resource().Object().upload_file" + ] + assert result_s3_uri == "s3://{}/data".format(BUCKET_NAME) assert len(uploaded_files_with_args) == 4 for file, kwargs in uploaded_files_with_args: assert os.path.exists(file) - assert kwargs['ExtraArgs'] is None + assert kwargs["ExtraArgs"] is None def test_upload_data_absolute_file(sagemaker_session): result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE) - uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls - if name == 'resource().Object().upload_file'] - assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME) + uploaded_files_with_args = [ + (args[0], kwargs) + for name, args, kwargs in sagemaker_session.boto_session.mock_calls + if name == "resource().Object().upload_file" + ] + assert result_s3_uri == "s3://{}/data/{}".format(BUCKET_NAME, SINGLE_FILE_NAME) assert len(uploaded_files_with_args) == 1 (file, kwargs) = uploaded_files_with_args[0] assert os.path.exists(file) - assert kwargs['ExtraArgs'] is None + assert kwargs["ExtraArgs"] is None def test_upload_data_aes_encrypted_absolute_dir(sagemaker_session): - result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_FILES_DIR, extra_args=AES_ENCRYPTION_ENABLED) - - uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls - if name == 'resource().Object().upload_file'] - assert result_s3_uri == 's3://{}/data'.format(BUCKET_NAME) + result_s3_uri = sagemaker_session.upload_data( + UPLOAD_DATA_TESTS_FILES_DIR, extra_args=AES_ENCRYPTION_ENABLED + ) + + uploaded_files_with_args = [ + (args[0], kwargs) + for name, args, kwargs in sagemaker_session.boto_session.mock_calls + if name == "resource().Object().upload_file" + ] + assert result_s3_uri == "s3://{}/data".format(BUCKET_NAME) assert len(uploaded_files_with_args) == 4 for file, kwargs in uploaded_files_with_args: assert os.path.exists(file) - assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED + assert kwargs["ExtraArgs"] == AES_ENCRYPTION_ENABLED def test_upload_data_aes_encrypted_absolute_file(sagemaker_session): - result_s3_uri = sagemaker_session.upload_data(UPLOAD_DATA_TESTS_SINGLE_FILE, extra_args=AES_ENCRYPTION_ENABLED) - - uploaded_files_with_args = [(args[0], kwargs) for name, args, kwargs in sagemaker_session.boto_session.mock_calls - if name == 'resource().Object().upload_file'] - assert result_s3_uri == 's3://{}/data/{}'.format(BUCKET_NAME, SINGLE_FILE_NAME) + result_s3_uri = sagemaker_session.upload_data( + UPLOAD_DATA_TESTS_SINGLE_FILE, extra_args=AES_ENCRYPTION_ENABLED + ) + + uploaded_files_with_args = [ + (args[0], kwargs) + for name, args, kwargs in sagemaker_session.boto_session.mock_calls + if name == "resource().Object().upload_file" + ] + assert result_s3_uri == "s3://{}/data/{}".format(BUCKET_NAME, SINGLE_FILE_NAME) assert len(uploaded_files_with_args) == 1 (file, kwargs) = 
uploaded_files_with_args[0] assert os.path.exists(file) - assert kwargs['ExtraArgs'] == AES_ENCRYPTION_ENABLED + assert kwargs["ExtraArgs"] == AES_ENCRYPTION_ENABLED diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2f2f706ff0..e59c05f865 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -27,29 +27,24 @@ import sagemaker -BUCKET_WITHOUT_WRITING_PERMISSION = 's3://bucket-without-writing-permission' +BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission" -NAME = 'base_name' -BUCKET_NAME = 'some_bucket' +NAME = "base_name" +BUCKET_NAME = "some_bucket" def test_get_config_value(): - config = { - 'local': { - 'region_name': 'us-west-2', - 'port': '123' - }, - 'other': { - 'key': 1 - } - } + config = {"local": {"region_name": "us-west-2", "port": "123"}, "other": {"key": 1}} - assert sagemaker.utils.get_config_value('local.region_name', config) == 'us-west-2' - assert sagemaker.utils.get_config_value('local', config) == {'region_name': 'us-west-2', 'port': '123'} + assert sagemaker.utils.get_config_value("local.region_name", config) == "us-west-2" + assert sagemaker.utils.get_config_value("local", config) == { + "region_name": "us-west-2", + "port": "123", + } - assert sagemaker.utils.get_config_value('does_not.exist', config) is None - assert sagemaker.utils.get_config_value('other.key', None) is None + assert sagemaker.utils.get_config_value("does_not.exist", config) is None + assert sagemaker.utils.get_config_value("other.key", None) is None def test_deferred_error(): @@ -68,70 +63,74 @@ def test_bad_import(): pd.DataFrame() -@patch('sagemaker.utils.sagemaker_timestamp') +@patch("sagemaker.utils.sagemaker_timestamp") def test_name_from_base(sagemaker_timestamp): sagemaker.utils.name_from_base(NAME, short=False) assert sagemaker_timestamp.called_once -@patch('sagemaker.utils.sagemaker_short_timestamp') +@patch("sagemaker.utils.sagemaker_short_timestamp") def test_name_from_base_short(sagemaker_short_timestamp): sagemaker.utils.name_from_base(NAME, short=True) assert sagemaker_short_timestamp.called_once def test_unique_name_from_base(): - assert re.match(r'base-\d{10}-[a-f0-9]{4}', sagemaker.utils.unique_name_from_base('base')) + assert re.match(r"base-\d{10}-[a-f0-9]{4}", sagemaker.utils.unique_name_from_base("base")) def test_unique_name_from_base_truncated(): - assert re.match(r'real-\d{10}-[a-f0-9]{4}', - sagemaker.utils.unique_name_from_base('really-long-name', max_length=20)) + assert re.match( + r"real-\d{10}-[a-f0-9]{4}", + sagemaker.utils.unique_name_from_base("really-long-name", max_length=20), + ) def test_to_str_with_native_string(): - value = 'some string' + value = "some string" assert sagemaker.utils.to_str(value) == value def test_to_str_with_unicode_string(): - value = u'åñøthér strîng' + value = u"åñøthér strîng" assert sagemaker.utils.to_str(value) == value def test_name_from_tuning_arn(): - arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11' + arn = "arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11" name = sagemaker.utils.extract_name_from_job_arn(arn) - assert name == 'resnet-sgd-tuningjob-11-07-34-11' + assert name == "resnet-sgd-tuningjob-11-07-34-11" def test_name_from_training_arn(): - arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b' + arn = 
"arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b" name = sagemaker.utils.extract_name_from_job_arn(arn) - assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b' + assert name == "resnet-sgd-tuningjob-11-22-38-46-002-2927640b" -MESSAGE = 'message' -STATUS = 'status' +MESSAGE = "message" +STATUS = "status" TRAINING_JOB_DESCRIPTION_1 = { - 'SecondaryStatusTransitions': [{'StatusMessage': MESSAGE, 'Status': STATUS}] + "SecondaryStatusTransitions": [{"StatusMessage": MESSAGE, "Status": STATUS}] } TRAINING_JOB_DESCRIPTION_2 = { - 'SecondaryStatusTransitions': [{'StatusMessage': 'different message', 'Status': STATUS}] + "SecondaryStatusTransitions": [{"StatusMessage": "different message", "Status": STATUS}] } -TRAINING_JOB_DESCRIPTION_EMPTY = { - 'SecondaryStatusTransitions': [] -} +TRAINING_JOB_DESCRIPTION_EMPTY = {"SecondaryStatusTransitions": []} def test_secondary_training_status_changed_true(): - changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) + changed = sagemaker.utils.secondary_training_status_changed( + TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2 + ) assert changed is True def test_secondary_training_status_changed_false(): - changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed( + TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1 + ) assert changed is False @@ -151,50 +150,62 @@ def test_secondary_training_status_changed_current_missing(): def test_secondary_training_status_changed_empty(): - changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, - TRAINING_JOB_DESCRIPTION_1) + changed = sagemaker.utils.secondary_training_status_changed( + TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1 + ) assert changed is False def test_secondary_training_status_message_status_changed(): now = datetime.now() - TRAINING_JOB_DESCRIPTION_1['LastModifiedTime'] = now - expected = '{} {} - {}'.format( - datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), + TRAINING_JOB_DESCRIPTION_1["LastModifiedTime"] = now + expected = "{} {} - {}".format( + datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime("%Y-%m-%d %H:%M:%S"), STATUS, - MESSAGE + MESSAGE, + ) + assert ( + sagemaker.utils.secondary_training_status_message( + TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY + ) + == expected ) - assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, - TRAINING_JOB_DESCRIPTION_EMPTY) == expected def test_secondary_training_status_message_status_not_changed(): now = datetime.now() - TRAINING_JOB_DESCRIPTION_1['LastModifiedTime'] = now - expected = '{} {} - {}'.format( - datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), + TRAINING_JOB_DESCRIPTION_1["LastModifiedTime"] = now + expected = "{} {} - {}".format( + datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime("%Y-%m-%d %H:%M:%S"), STATUS, - MESSAGE + MESSAGE, + ) + assert ( + sagemaker.utils.secondary_training_status_message( + TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2 + ) + == expected ) - assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, - TRAINING_JOB_DESCRIPTION_2) == expected def test_secondary_training_status_message_prev_missing(): now = datetime.now() - 
TRAINING_JOB_DESCRIPTION_1['LastModifiedTime'] = now - expected = '{} {} - {}'.format( - datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), + TRAINING_JOB_DESCRIPTION_1["LastModifiedTime"] = now + expected = "{} {} - {}".format( + datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime("%Y-%m-%d %H:%M:%S"), STATUS, - MESSAGE + MESSAGE, + ) + assert ( + sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) + == expected ) - assert sagemaker.utils.secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, {}) == expected -@patch('os.makedirs') +@patch("os.makedirs") def test_download_folder(makedirs): - boto_mock = Mock(name='boto_session') - boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} + boto_mock = Mock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"} session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) @@ -202,81 +213,84 @@ def test_download_folder(makedirs): validation_data = Mock() train_data.bucket_name.return_value = BUCKET_NAME - train_data.key = 'prefix/train/train_data.csv' + train_data.key = "prefix/train/train_data.csv" validation_data.bucket_name.return_value = BUCKET_NAME - validation_data.key = 'prefix/train/validation_data.csv' + validation_data.key = "prefix/train/validation_data.csv" s3_files = [train_data, validation_data] - boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files + boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.return_value = s3_files obj_mock = Mock() - boto_mock.resource('s3').Object.return_value = obj_mock + boto_mock.resource("s3").Object.return_value = obj_mock # all the S3 mocks are set, the test itself begins now. - sagemaker.utils.download_folder(BUCKET_NAME, '/prefix', '/tmp', session) + sagemaker.utils.download_folder(BUCKET_NAME, "/prefix", "/tmp", session) obj_mock.download_file.assert_called() - calls = [call(os.path.join('/tmp', 'train/train_data.csv')), - call(os.path.join('/tmp', 'train/validation_data.csv'))] + calls = [ + call(os.path.join("/tmp", "train/train_data.csv")), + call(os.path.join("/tmp", "train/validation_data.csv")), + ] obj_mock.download_file.assert_has_calls(calls) obj_mock.reset_mock() # Testing with a trailing slash for the prefix. 
- sagemaker.utils.download_folder(BUCKET_NAME, '/prefix/', '/tmp', session) + sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/", "/tmp", session) obj_mock.download_file.assert_called() obj_mock.download_file.assert_has_calls(calls) -@patch('os.makedirs') +@patch("os.makedirs") def test_download_folder_points_to_single_file(makedirs): - boto_mock = Mock(name='boto_session') - boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} + boto_mock = Mock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"} session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) train_data = Mock() train_data.bucket_name.return_value = BUCKET_NAME - train_data.key = 'prefix/train/train_data.csv' + train_data.key = "prefix/train/train_data.csv" s3_files = [train_data] - boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.return_value = s3_files + boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.return_value = s3_files obj_mock = Mock() - boto_mock.resource('s3').Object.return_value = obj_mock + boto_mock.resource("s3").Object.return_value = obj_mock # all the S3 mocks are set, the test itself begins now. - sagemaker.utils.download_folder(BUCKET_NAME, '/prefix/train/train_data.csv', '/tmp', session) + sagemaker.utils.download_folder(BUCKET_NAME, "/prefix/train/train_data.csv", "/tmp", session) obj_mock.download_file.assert_called() - calls = [call(os.path.join('/tmp', 'train_data.csv'))] + calls = [call(os.path.join("/tmp", "train_data.csv"))] obj_mock.download_file.assert_has_calls(calls) - assert boto_mock.resource('s3').Bucket(BUCKET_NAME).objects.filter.call_count == 1 + assert boto_mock.resource("s3").Bucket(BUCKET_NAME).objects.filter.call_count == 1 obj_mock.reset_mock() def test_download_file(): - boto_mock = Mock(name='boto_session') - boto_mock.client('sts').get_caller_identity.return_value = {'Account': '123'} + boto_mock = Mock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"} bucket_mock = Mock() - boto_mock.resource('s3').Bucket.return_value = bucket_mock + boto_mock.resource("s3").Bucket.return_value = bucket_mock session = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock()) - sagemaker.utils.download_file(BUCKET_NAME, '/prefix/path/file.tar.gz', - '/tmp/file.tar.gz', session) + sagemaker.utils.download_file( + BUCKET_NAME, "/prefix/path/file.tar.gz", "/tmp/file.tar.gz", session + ) - bucket_mock.download_file.assert_called_with('prefix/path/file.tar.gz', '/tmp/file.tar.gz') + bucket_mock.download_file.assert_called_with("prefix/path/file.tar.gz", "/tmp/file.tar.gz") -@patch('tarfile.open') +@patch("tarfile.open") def test_create_tar_file_with_provided_path(open): files = mock_tarfile(open) - file_list = ['/tmp/a', '/tmp/b'] + file_list = ["/tmp/a", "/tmp/b"] - path = sagemaker.utils.create_tar_file(file_list, target='/my/custom/path.tar.gz') - assert path == '/my/custom/path.tar.gz' - assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] + path = sagemaker.utils.create_tar_file(file_list, target="/my/custom/path.tar.gz") + assert path == "/my/custom/path.tar.gz" + assert files == [["/tmp/a", "a"], ["/tmp/b", "b"]] def mock_tarfile(open): @@ -292,14 +306,14 @@ def add_files(filename, arcname): return files -@patch('tarfile.open') -@patch('tempfile.mkstemp', Mock(return_value=(None, '/auto/generated/path'))) +@patch("tarfile.open") +@patch("tempfile.mkstemp", Mock(return_value=(None, "/auto/generated/path"))) def 
test_create_tar_file_with_auto_generated_path(open): files = mock_tarfile(open) - path = sagemaker.utils.create_tar_file(['/tmp/a', '/tmp/b']) - assert path == '/auto/generated/path' - assert files == [['/tmp/a', 'a'], ['/tmp/b', 'b']] + path = sagemaker.utils.create_tar_file(["/tmp/a", "/tmp/b"]) + assert path == "/auto/generated/path" + assert files == [["/tmp/a", "a"], ["/tmp/b", "b"]] def create_file_tree(root, tree): @@ -308,7 +322,7 @@ def create_file_tree(root, tree): os.makedirs(os.path.join(root, os.path.dirname(file))) except: # noqa: E722 Using bare except because p2/3 incompatibility issues. pass - with open(os.path.join(root, file), 'a') as f: + with open(os.path.join(root, file), "a") as f: f.write(file) @@ -319,138 +333,171 @@ def tmp(tmpdir): def test_repack_model_without_source_dir(tmp, fake_s3): - create_file_tree(tmp, ['model-dir/model', - 'dependencies/a', - 'dependencies/b', - 'source-dir/inference.py', - 'source-dir/this-file-should-not-be-included.py']) + create_file_tree( + tmp, + [ + "model-dir/model", + "dependencies/a", + "dependencies/b", + "source-dir/inference.py", + "source-dir/this-file-should-not-be-included.py", + ], + ) - fake_s3.tar_and_upload('model-dir', 's3://fake/location') + fake_s3.tar_and_upload("model-dir", "s3://fake/location") - sagemaker.utils.repack_model(inference_script=os.path.join(tmp, 'source-dir/inference.py'), - source_directory=None, - dependencies=[os.path.join(tmp, 'dependencies/a'), - os.path.join(tmp, 'dependencies/b')], - model_uri='s3://fake/location', - repacked_model_uri='s3://destination-bucket/model.tar.gz', - sagemaker_session=fake_s3.sagemaker_session) + sagemaker.utils.repack_model( + inference_script=os.path.join(tmp, "source-dir/inference.py"), + source_directory=None, + dependencies=[os.path.join(tmp, "dependencies/a"), os.path.join(tmp, "dependencies/b")], + model_uri="s3://fake/location", + repacked_model_uri="s3://destination-bucket/model.tar.gz", + sagemaker_session=fake_s3.sagemaker_session, + ) - assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/model', '/code/a', - '/code/b', '/code/inference.py'} + assert list_tar_files(fake_s3.fake_upload_path, tmp) == { + "/model", + "/code/a", + "/code/b", + "/code/inference.py", + } def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake_s3): - create_file_tree(tmp, ['model-dir/model', - 'source-dir/inference.py', - 'source-dir/this-file-should-not-be-included.py']) + create_file_tree( + tmp, + [ + "model-dir/model", + "source-dir/inference.py", + "source-dir/this-file-should-not-be-included.py", + ], + ) - fake_s3.tar_and_upload('model-dir', 's3://fake/location') + fake_s3.tar_and_upload("model-dir", "s3://fake/location") cwd = os.getcwd() try: - os.chdir(os.path.join(tmp, 'source-dir')) - - sagemaker.utils.repack_model('inference.py', - None, - None, - 's3://fake/location', - 's3://destination-bucket/model.tar.gz', - fake_s3.sagemaker_session) + os.chdir(os.path.join(tmp, "source-dir")) + + sagemaker.utils.repack_model( + "inference.py", + None, + None, + "s3://fake/location", + "s3://destination-bucket/model.tar.gz", + fake_s3.sagemaker_session, + ) finally: os.chdir(cwd) - assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/inference.py', '/model'} + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {"/code/inference.py", "/model"} def test_repack_model_from_s3_to_s3(tmp, fake_s3): - create_file_tree(tmp, ['model-dir/model', - 'source-dir/inference.py', - 'source-dir/this-file-should-be-included.py']) + 
create_file_tree( + tmp, + [ + "model-dir/model", + "source-dir/inference.py", + "source-dir/this-file-should-be-included.py", + ], + ) - fake_s3.tar_and_upload('model-dir', 's3://fake/location') + fake_s3.tar_and_upload("model-dir", "s3://fake/location") - sagemaker.utils.repack_model('inference.py', - os.path.join(tmp, 'source-dir'), - None, - 's3://fake/location', - 's3://destination-bucket/model.tar.gz', - fake_s3.sagemaker_session) + sagemaker.utils.repack_model( + "inference.py", + os.path.join(tmp, "source-dir"), + None, + "s3://fake/location", + "s3://destination-bucket/model.tar.gz", + fake_s3.sagemaker_session, + ) - assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/this-file-should-be-included.py', - '/code/inference.py', - '/model'} + assert list_tar_files(fake_s3.fake_upload_path, tmp) == { + "/code/this-file-should-be-included.py", + "/code/inference.py", + "/model", + } def test_repack_model_from_file_to_file(tmp): - create_file_tree(tmp, ['model', - 'dependencies/a', - 'source-dir/inference.py']) + create_file_tree(tmp, ["model", "dependencies/a", "source-dir/inference.py"]) - model_tar_path = os.path.join(tmp, 'model.tar.gz') - sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path) + model_tar_path = os.path.join(tmp, "model.tar.gz") + sagemaker.utils.create_tar_file([os.path.join(tmp, "model")], model_tar_path) sagemaker_session = MagicMock() - file_mode_path = 'file://%s' % model_tar_path - destination_path = 'file://%s' % os.path.join(tmp, 'repacked-model.tar.gz') + file_mode_path = "file://%s" % model_tar_path + destination_path = "file://%s" % os.path.join(tmp, "repacked-model.tar.gz") - sagemaker.utils.repack_model('inference.py', - os.path.join(tmp, 'source-dir'), - [os.path.join(tmp, 'dependencies/a')], - file_mode_path, - destination_path, - sagemaker_session) + sagemaker.utils.repack_model( + "inference.py", + os.path.join(tmp, "source-dir"), + [os.path.join(tmp, "dependencies/a")], + file_mode_path, + destination_path, + sagemaker_session, + ) - assert list_tar_files(destination_path, tmp) == {'/code/a', '/code/inference.py', '/model'} + assert list_tar_files(destination_path, tmp) == {"/code/a", "/code/inference.py", "/model"} def test_repack_model_with_inference_code_should_replace_the_code(tmp, fake_s3): - create_file_tree(tmp, ['model-dir/model', - 'source-dir/new-inference.py', - 'model-dir/code/old-inference.py']) + create_file_tree( + tmp, ["model-dir/model", "source-dir/new-inference.py", "model-dir/code/old-inference.py"] + ) - fake_s3.tar_and_upload('model-dir', 's3://fake/location') + fake_s3.tar_and_upload("model-dir", "s3://fake/location") - sagemaker.utils.repack_model('inference.py', - os.path.join(tmp, 'source-dir'), - None, - 's3://fake/location', - 's3://destination-bucket/repacked-model', - fake_s3.sagemaker_session) + sagemaker.utils.repack_model( + "inference.py", + os.path.join(tmp, "source-dir"), + None, + "s3://fake/location", + "s3://destination-bucket/repacked-model", + fake_s3.sagemaker_session, + ) - assert list_tar_files(fake_s3.fake_upload_path, tmp) == {'/code/new-inference.py', '/model'} + assert list_tar_files(fake_s3.fake_upload_path, tmp) == {"/code/new-inference.py", "/model"} def test_repack_model_from_file_to_folder(tmp): - create_file_tree(tmp, ['model', - 'source-dir/inference.py']) + create_file_tree(tmp, ["model", "source-dir/inference.py"]) - model_tar_path = os.path.join(tmp, 'model.tar.gz') - sagemaker.utils.create_tar_file([os.path.join(tmp, 'model')], model_tar_path) + 
model_tar_path = os.path.join(tmp, "model.tar.gz") + sagemaker.utils.create_tar_file([os.path.join(tmp, "model")], model_tar_path) - file_mode_path = 'file://%s' % model_tar_path + file_mode_path = "file://%s" % model_tar_path - sagemaker.utils.repack_model('inference.py', - os.path.join(tmp, 'source-dir'), - [], - file_mode_path, - 'file://%s/repacked-model.tar.gz' % tmp, - MagicMock()) + sagemaker.utils.repack_model( + "inference.py", + os.path.join(tmp, "source-dir"), + [], + file_mode_path, + "file://%s/repacked-model.tar.gz" % tmp, + MagicMock(), + ) - assert list_tar_files('file://%s/repacked-model.tar.gz' % tmp, tmp) == {'/code/inference.py', '/model'} + assert list_tar_files("file://%s/repacked-model.tar.gz" % tmp, tmp) == { + "/code/inference.py", + "/model", + } class FakeS3(object): - def __init__(self, tmp): self.tmp = tmp self.sagemaker_session = MagicMock() self.location_map = {} self.current_bucket = None - self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = self.download_file + self.sagemaker_session.boto_session.resource().Bucket().download_file.side_effect = ( + self.download_file + ) self.sagemaker_session.boto_session.resource().Bucket.side_effect = self.bucket self.fake_upload_path = self.mock_s3_upload() @@ -459,22 +506,21 @@ def bucket(self, name): return self def download_file(self, path, target): - key = '%s/%s' % (self.current_bucket, path) + key = "%s/%s" % (self.current_bucket, path) shutil.copy2(self.location_map[key], target) def tar_and_upload(self, path, fake_location): - tar_location = os.path.join(self.tmp, 'model-%s.tar.gz' % time.time()) - with tarfile.open(tar_location, mode='w:gz') as t: + tar_location = os.path.join(self.tmp, "model-%s.tar.gz" % time.time()) + with tarfile.open(tar_location, mode="w:gz") as t: t.add(os.path.join(self.tmp, path), arcname=os.path.sep) - self.location_map[fake_location.replace('s3://', '')] = tar_location + self.location_map[fake_location.replace("s3://", "")] = tar_location return tar_location def mock_s3_upload(self): - dst = os.path.join(self.tmp, 'dst') + dst = os.path.join(self.tmp, "dst") class MockS3Object(object): - def __init__(self, bucket, key): self.bucket = bucket self.key = key @@ -494,18 +540,18 @@ def fake_s3(tmp): def list_tar_files(tar_ball, tmp): - tar_ball = tar_ball.replace('file://', '') - startpath = os.path.join(tmp, 'startpath') + tar_ball = tar_ball.replace("file://", "") + startpath = os.path.join(tmp, "startpath") os.mkdir(startpath) - with tarfile.open(name=tar_ball, mode='r:gz') as t: + with tarfile.open(name=tar_ball, mode="r:gz") as t: t.extractall(path=startpath) def walk(): for root, dirs, files in os.walk(startpath): - path = root.replace(startpath, '') + path = root.replace(startpath, "") for f in files: - yield '%s/%s' % (path, f) + yield "%s/%s" % (path, f) result = set(walk()) return result if result else {} diff --git a/tests/unit/test_vpc_utils.py b/tests/unit/test_vpc_utils.py index 119989537d..dcefd64190 100644 --- a/tests/unit/test_vpc_utils.py +++ b/tests/unit/test_vpc_utils.py @@ -18,13 +18,10 @@ from sagemaker.vpc_utils import SUBNETS_KEY, SECURITY_GROUP_IDS_KEY, to_dict, from_dict, sanitize -subnets = ['subnet'] -security_groups = ['sg'] -good_vpc_config = {SUBNETS_KEY: subnets, - SECURITY_GROUP_IDS_KEY: security_groups} -foo_vpc_config = {SUBNETS_KEY: subnets, - SECURITY_GROUP_IDS_KEY: security_groups, - 'foo': 1} +subnets = ["subnet"] +security_groups = ["sg"] +good_vpc_config = {SUBNETS_KEY: subnets, SECURITY_GROUP_IDS_KEY: 
security_groups} +foo_vpc_config = {SUBNETS_KEY: subnets, SECURITY_GROUP_IDS_KEY: security_groups, "foo": 1} def test_to_dict(): @@ -32,8 +29,10 @@ def test_to_dict(): assert to_dict(subnets, None) is None assert to_dict(None, security_groups) is None - assert to_dict(subnets, security_groups) == {SUBNETS_KEY: subnets, - SECURITY_GROUP_IDS_KEY: security_groups} + assert to_dict(subnets, security_groups) == { + SUBNETS_KEY: subnets, + SECURITY_GROUP_IDS_KEY: security_groups, + } def test_from_dict(): @@ -71,8 +70,6 @@ def test_sanitize(): sanitize({SUBNETS_KEY: []}) with pytest.raises(ValueError): - sanitize({SECURITY_GROUP_IDS_KEY: 1, - SUBNETS_KEY: subnets}) + sanitize({SECURITY_GROUP_IDS_KEY: 1, SUBNETS_KEY: subnets}) with pytest.raises(ValueError): - sanitize({SECURITY_GROUP_IDS_KEY: [], - SUBNETS_KEY: subnets}) + sanitize({SECURITY_GROUP_IDS_KEY: [], SUBNETS_KEY: subnets}) diff --git a/tox.ini b/tox.ini index b9baae836f..c1d8d5196f 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = flake8,pylint,twine,sphinx,py27,py36 +envlist = black-format,flake8,pylint,twine,sphinx,py27,py36 skip_missing_interpreters = False @@ -25,6 +25,7 @@ exclude = max-complexity = 10 ignore = + E203, # whitespace before ':': Black disagrees with and explicitly violates this. FI10, FI12, FI13, @@ -112,3 +113,17 @@ deps = commands = pip install --exists-action=w -r requirements.txt sphinx-build -T -W -b html -d _build/doctrees-readthedocs -D language=en . _build/html + +[testenv:black-format] +# Used during development (before committing) to format .py files. +basepython = python3 +deps = black==19.3b0 +commands = + black -l 100 ./ + +[testenv:black-check] +# Used by automated build steps to check that all files are properly formatted. +basepython = python3 +deps = black==19.3b0 +commands = + black -l 100 --check ./
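
Reviewer note: every hunk in this patch is a mechanical reformat produced by Black 19.3b0 at a 100-character line length; no behavior change is intended. The transformation has two visible effects: string literals are normalized to double quotes, and statements that overflow the limit are re-wrapped, with call arguments either moved to a single indented line or, when that still overflows, split one element per line with a trailing comma. A minimal before/after sketch, lifted verbatim from the tests/unit/test_utils.py hunks above:

    # Before: single quotes, hand-aligned continuation lines.
    changed = sagemaker.utils.secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY,
                                                                TRAINING_JOB_DESCRIPTION_1)

    # After `black -l 100`: double quotes; arguments collapsed onto one
    # indented line, since they fit within the limit there.
    changed = sagemaker.utils.secondary_training_status_changed(
        TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1
    )

    # When elements must go one per line, Black also adds a trailing comma:
    calls = [
        call(os.path.join("/tmp", "train/train_data.csv")),
        call(os.path.join("/tmp", "train/validation_data.csv")),
    ]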
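Workflow-wise, the two new tox environments divide the labor: `tox -e black-format` rewrites files in place and is meant to be run locally before committing, while `tox -e black-check` runs `black --check`, which fails without modifying anything, so automated builds can verify formatting. Both environments pin black==19.3b0, so local runs and the build should always agree on the output.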