Support TensorFlow-1.5.0 and MXNet-1.0.0 #82
tests/unit/test_estimator.py
Outdated
assert framework == 'mxnet'
assert py_ver == 'py2'
assert tag == '2.5.6-gpu-py2'
👍 🥇 I love these tags.
src/sagemaker/fw_utils.py
Outdated
@@ -28,26 +28,28 @@
"""


def create_image_uri(region, framework, instance_type, py_version='py2', tag='1.0', account='520713654638'):
def create_image_uri(region, framework, instance_type, framework_version, py_version='py2', account='520713654638'):
I've noticed that create_image_uri does not have unit tests. Can you write unit tests for this function?
It doesn't; it should 👍
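A minimal sketch of what such unit tests could look like. The `create_image_uri` body below is a stand-in written only so the tests can run on their own; the real function lives in src/sagemaker/fw_utils.py and may differ. The ECR URI layout, the `<version>-<device>-<py_version>` tag layout (modelled on the '2.5.6-gpu-py2' tag in the tests above), and the GPU detection rule are all assumptions.

```python
def create_image_uri(region, framework, instance_type, framework_version,
                     py_version='py2', account='520713654638'):
    """Illustrative stand-in, NOT the SDK implementation."""
    # Assumption: instance families starting with 'g' or 'p' are GPU powered.
    family = instance_type.split('.')[1] if instance_type.startswith('ml.') else instance_type
    device_type = 'gpu' if family[0] in ('g', 'p') else 'cpu'
    # Assumption: tag layout is "<framework_version>-<device_type>-<py_version>".
    tag = '{}-{}-{}'.format(framework_version, device_type, py_version)
    # Assumption: standard ECR registry host plus a "sagemaker-<framework>" repository.
    return '{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}'.format(account, region, framework, tag)


def test_create_image_uri_cpu():
    uri = create_image_uri('us-west-2', 'mxnet', 'ml.c4.xlarge', '1.0')
    assert uri == '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.0-cpu-py2'


def test_create_image_uri_gpu():
    uri = create_image_uri('us-west-2', 'tensorflow', 'ml.p2.xlarge', '1.5.0', py_version='py3')
    assert uri == '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.5.0-gpu-py3'
```

The tests are plain functions with bare asserts, so they can be collected by pytest or called directly.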
tests/unit/test_mxnet.py
Outdated
'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU)

assert 'cpu' in model.prepare_container_def(CPU)['Image']
predictor = mx.deploy(1, GPU)
assert isinstance(predictor, MXNetPredictor)


def test_model(sagemaker_session):
model = MXNetModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session)
def test_model(sagemaker_session, mxnet_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
def test_tf_deploy_model_server_workers(sagemaker_session):
tf = _build_tf(sagemaker_session)
def test_tf_deploy_model_server_workers(sagemaker_session, tf_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
@@ -139,8 +149,8 @@ def test_tf_deploy_model_server_workers(sagemaker_session):
MODEL_SERVER_WORKERS_PARAM_NAME.upper()]


def test_tf_deploy_model_server_workers_unset(sagemaker_session):
tf = _build_tf(sagemaker_session)
def test_tf_deploy_model_server_workers_unset(sagemaker_session, tf_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
@@ -192,19 +202,21 @@ def test_tf(time, strftime, sagemaker_session):
@patch('subprocess.Popen')
@patch('subprocess.call')
@patch('os.access', return_value=False)
def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, popen, call, access, sagemaker_session):
def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, popen, call, access,
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
@@ -251,9 +266,11 @@ def test_run_tensorboard_locally(time, strftime, popen, call, access, sagemaker_
@patch('time.strftime', return_value=TIMESTAMP)
@patch('time.time', return_value=TIME)
@pytest.mark.skip(reason="this test fails sometimes and it needs further investigation")
def test_run_tensorboard_locally_port_in_use(time, strftime, popen, call, access, socket, sagemaker_session):
def test_run_tensorboard_locally_port_in_use(time, strftime, popen, call, access, socket,
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
@@ -266,51 +283,103 @@ def test_run_tensorboard_locally_port_in_use(time, strftime, popen, call, access
stderr=-1, stdout=-1)


def test_tf_checkpoint_not_set(sagemaker_session):
def test_tf_checkpoint_not_set(sagemaker_session, tf_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
output_path="s3://{}/".format(sagemaker_session.default_bucket()))
tf.fit(inputs=s3_input('s3://mybucket/train'), job_name=job_name)

expected_result = '"s3://{}/{}/checkpoints"'.format(sagemaker_session.default_bucket(), job_name)
assert tf.hyperparameters()['checkpoint_path'] == expected_result


def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
def test_tf_training_and_evaluation_steps_not_set(sagemaker_session, tf_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
tf.fit(inputs=s3_input('s3://mybucket/train'))
assert tf.hyperparameters()['training_steps'] == 'null'
assert tf.hyperparameters()['evaluation_steps'] == 'null'


def test_tf_training_and_evaluation_steps(sagemaker_session):
def test_tf_training_and_evaluation_steps(sagemaker_session, tf_version):
This test does not need to run for both versions.
tests/unit/test_tf_estimator.py
Outdated
tf.fit(inputs=s3_input('s3://mybucket/train'))
assert tf.hyperparameters()['training_steps'] == '123'
assert tf.hyperparameters()['evaluation_steps'] == '456'


def test_tf_checkpoint_set(sagemaker_session):
tf = _build_tf(sagemaker_session, checkpoint_path='s3://my_checkpoint_bucket')
def test_tf_checkpoint_set(sagemaker_session, tf_version):
This test does not need to run for both versions.
Something I'm not set on, but worth thinking about: right now, if the user passes in an invalid framework_version, it won't fail until EASE tries to pull the image, right? If we have some kind of validation, that could save frustration.
On the other hand, we don't want to over-validate - when we release new images, people should be able to use them without updating the python SDK.
Maybe a good compromise would be to validate just the format of the tag?
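One way the format-only compromise could look, as a rough sketch. The helper name `validate_framework_version_format` is hypothetical, and the accepted shape (two or three dot-separated numbers, such as "1.4" or "1.5.0") is an assumption based on the versions mentioned in this PR, not a rule taken from the SDK.

```python
import re

# Assumption: a framework version is two or three dot-separated integers,
# e.g. "1.4", "1.4.1", "1.5.0". Nothing else about the value is checked.
_VERSION_FORMAT = re.compile(r'^\d+\.\d+(\.\d+)?\Z')


def validate_framework_version_format(framework_version):
    """Return the version unchanged if it is well formed, else raise ValueError.

    Only the shape is checked, so a newly released image version can be used
    without updating this code.
    """
    if not _VERSION_FORMAT.match(str(framework_version)):
        raise ValueError(
            'framework_version {!r} is not in the form "X.Y" or "X.Y.Z"'.format(framework_version))
    return framework_version
```

With this shape check, a typo such as "1,5" or "latest" fails immediately on the client, while a not-yet-known version such as "1.6.0" still passes through and is only rejected later if no matching image exists.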
README.rst
Outdated
Each Docker container has the following dependencies installed:
The Docker containers have the following dependencies installed:
Let's change this to "Docker images".
@@ -682,19 +682,24 @@ When training and deploying training scripts, SageMaker runs your Python script

SageMaker runs MXNet Estimator scripts in either Python 2.7 or Python 3.5. You can select the Python version by passing a ``py_version`` keyword arg to the MXNet Estimator constructor. Setting this to ``py2`` (the default) will cause your training script to be run on Python 2.7. Setting this to ``py3`` will cause your training script to be run on Python 3.5. This Python version applies to both the Training Job, created by fit, and the Endpoint, created by deploy.

Your MXNet training script will be run on version 0.12 of MXNet, built for either GPU or CPU use. The decision to use the GPU or CPU version of MXNet is made by the train_instance_type, set on the MXNet constructor. If you choose a GPU instance type, your training job will be run on a GPU version of MXNet. If you choose a CPU instance type, your training job will be run on a CPU version of MXNet. Similarly, when you call deploy, specifying a GPU or CPU deploy_instance_type, will control which MXNet build your Endpoint runs.
Your MXNet training script will be run on version 1.0.0 (by default) or 0.12 of MXNet, built for either GPU or CPU use. The decision to use the GPU or CPU version of MXNet is made by the ``train_instance_type``, set on the MXNet constructor. If you choose a GPU instance type, your training job will be run on a GPU version of MXNet. If you choose a CPU instance type, your training job will be run on a CPU version of MXNet. Similarly, when you call deploy, specifying a GPU or CPU deploy_instance_type, will control which MXNet build your Endpoint runs.
Can we write this to be more future-proof? Maybe something like, "will run using the latest supported version of MXNet by default. See (something) for our currently supported versions. You can choose older versions by..."
But we still would need to make an sdk change to change the default.
I believe it adds value to say what version we support by default.
@@ -735,7 +740,7 @@ Preparing the TensorFlow training script
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Your TensorFlow training script must be a **Python 2.7** source file. The current supported TensorFlow
version is **1.4.0**. This training script **must contain** the following functions:
versions are **1.5.0 (default)** and **1.4.1**. This training script **must contain** the following functions:
Same comment here about future-proofing.
src/sagemaker/fw_utils.py
Outdated
@@ -28,26 +28,28 @@
"""


def create_image_uri(region, framework, instance_type, py_version='py2', tag='1.0', account='520713654638'):
def create_image_uri(region, framework, instance_type, framework_version, py_version='py2', account='520713654638'):
It might make sense to take out the default argument for py_version as well - what do you think?
I don't have an explanation for why it was defaulted at all.
src/sagemaker/fw_utils.py
Outdated
account (str): AWS account that contains the image. (default: '520713654638')

Returns:
str: The appropriate image URI based on the given parameters.
"""
device_version = 'cpu'
device_type = 'cpu'
# Instance types that start with G, P are GPU powered: https://aws.amazon.com/ec2/instance-types/
Should link to https://aws.amazon.com/sagemaker/pricing/instance-types/ instead
Good point :)
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
MXNET_VERSION = '1.0'
Maybe "DEFAULT_MXNET_VERSION"?
defaults.DEFAULT_MXNET_VERSION ? I think it's a little bit redundant.
Good point, whoops.
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
TF_VERSION = '1.5'
"DEFAULT_TF_VERSION"?
defaults.DEFAULT_TF_VERSION ? I think it's a little bit redundant.
init_params['py_version'] = py_version

# For backward compatibility map deprecated container tag to a framework version
container tag -> image tag
@@ -78,8 +82,8 @@ def create_model(self, model_server_workers=None):
return MXNetModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
container_log_level=self.container_log_level, code_location=self.code_location,
py_version=self.py_version, model_server_workers=model_server_workers,
sagemaker_session=self.sagemaker_session)
py_version=self.py_version, framework_version=self.framework_version,
We probably want an extra assertion in the appropriate unit test that we pass along the framework version correctly to the created model (for both TF and MXNet).
We already do this indirectly.
I guess it would be good to have a unit test that tests just model creation.
@@ -47,10 +48,12 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_versio
to convert them before training.
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
One of 'py2' or 'py3'.
framework_version (str): MXNet version you want to use for executing your model training code.
Can you add info to this docstring to tell users how they can figure out what valid values are? (Same for TF as well.)
Yep! 👍
1.1.0
=====

* feature: Estimators: add support for TensorFlow-1.5.0
Please update after master branch merge.
Will do.
init_params['py_version'] = py_version

# For backward compatibility map deprecated container tag to a framework version
# otherwise extract framework version from the tag itself.
Could we have a utility function that returns the version from the tag? Otherwise we hardcode this logic (element [0] of the tag) in 2 different places (for both frameworks).
I really thought about it, but the version is specific to MXNet and TensorFlow, and it's too little to generalize.
Also, we should be able to remove it at some point in the future.
What I had in mind is that the tag-parsing logic should be externalized, because we have the same pattern for all the frameworks. It's the same way we do it with 'framework_name_from_image'.
The version is specific to each framework, but getting that version is common. If the pattern changes in the future, we will only need to update it in one place.
Got it, yeah, it probably makes sense even for a one-line function.
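A possible shape for that shared helper, as a sketch only. The name `framework_version_from_tag` and the `<version>-<device>-<py_version>` tag pattern are assumptions (the pattern is inferred from the '2.5.6-gpu-py2' tag used in the unit tests); mapping a deprecated tag to a framework version is left to the caller.

```python
import re

# Assumption: new-style tags look like "<framework_version>-<cpu|gpu>-<py2|py3>".
_TAG_PATTERN = re.compile(r'^(.*)-(cpu|gpu)-(py2|py3)$')


def framework_version_from_tag(image_tag):
    """Extract the framework version from an image tag.

    Returns None when the tag does not follow the assumed pattern
    (for example a deprecated tag), so each framework can apply its
    own backward-compatibility mapping in that case.
    """
    match = _TAG_PATTERN.match(image_tag)
    return None if match is None else match.group(1)
```

Keeping the pattern in one function means both the MXNet and TensorFlow estimators would call the same code, and a future change to the tag layout touches a single place.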
import pytest


@pytest.fixture(scope='module', params=["1.4", "1.4.1", "1.5", "1.5.0"])
It would be great to cover the default case, e.g. pass a special word like 'default' and, in the test, do not pass a version at all.
But it would run one of the versions again for every test.
We have a specific test for testing default value.
That's true about the double run but I didn't see the special (integ) test for default value.
https://github.com/aws/sagemaker-python-sdk/pull/82/files/bc748929fc3088d99d1e5c00787e3ddad1a3a51b#diff-8e5bb247678c2ede89b070ed8c1926f2R321
and we have another similar one for mxnet
tests/unit/test_estimator.py
Outdated
assert framework is None
assert py_ver is None
assert tag is None
Why would we get 'None' here? The tag will be '1', right?
Not the way framework_name_from_image works. In this case it:
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
name_match = name_pattern.match(sagemaker_match.group(8))
name_match is None and everything is None as a result.
I don't see the value in changing this behavior since we can't really use the tag if we don't have the framework, similar to python version.
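To make the behavior concrete, here is a small self-contained sketch built around the regular expression quoted above. The wrapper `parse_image_name` is hypothetical and only covers the name-matching step, not the full `framework_name_from_image` function (which first extracts the image name from the ECR URI).

```python
import re

# Pattern quoted in the comment above.
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')


def parse_image_name(image_name):
    """Hypothetical helper: return (framework, py_version, tag) or all None."""
    name_match = name_pattern.match(image_name)
    if name_match is None:
        # An unknown framework makes the whole pattern fail, so the tag is
        # dropped as well, even though it is present in the name.
        return None, None, None
    framework, py_version, _device, tag = name_match.groups()
    return framework, py_version, tag
```

With an image name such as 'sagemaker-myown-py2-cpu:1', the framework alternative does not match, so the function returns three `None` values rather than a tag of '1'.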
Ah yes, sorry - I missed the 'myown' framework name.
@@ -21,7 +21,8 @@ class MXNet(Framework):

__framework_name__ = "mxnet"

def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2', **kwargs):
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version='py2',
framework_version=MXNET_VERSION, **kwargs):
Something to consider: shall the framework_version be part of the Framework class and not be repeated in each of the child classes?
Maybe in the future; we decided not to do it as part of this change and to keep it similar to py_version.
🥇 🚢 🎉
* Add data_type to hyperparameters (aws#54) When we describe a training job the data type of the hyper parameters is lost because we use a dict[str, str]. This adds a new field to Hyperparameter so that we can convert the datatypes at runtime. instead of validating with isinstance(), we cast the hp value to the type it is meant to be. This enforces a "strongly typed" value. When we deserialize from the API string responses it becomes easier to deal with too.
* Add wrapper for LDA. (aws#56) Update CHANGELOG and bump the version number.
* Add support for async fit() (aws#59) when calling fit(wait=False) it will return immediately. The training job will carry on even if the process exits. by using attach() the estimator can be retrieved by providing the training job name. _prepare_init_params_from_job_description() is now a classmethod instead of being a static method. Each class is responsible to implement their specific logic to convert a training job description into arguments that can be passed to its own __init__()
* Fix Estimator role expansion (aws#68) Instead of manually constructing the role ARN, use the IAM boto client to do it. This properly expands service-roles and regular roles.
* Add FM and LDA to the documentation. (aws#66)
* Fix description of an argument of sagemaker.session.train (aws#69)
  * Fix description of an argument of sagemaker.session.train 'input_config' should be an array which has channel objects.
  * Add a link to the botocore docs
  * Use 'list' instead of 'array' in the description
* Add ntm algorithm with doc, unit tests, integ tests (aws#73)
* JSON serializer: predictor.predict accepts dictionaries (aws#62) Add support for serializing python dictionaries to json Add prediction with dictionary in tf iris integ test
* Fixing timeouts for PCA async integration test. (aws#78) Execute tf_cifar test without logs to eliminate delay to detect that job has finished.
* Fixes in LinearLearner and unit tests addition. (aws#77)
* Print out billable seconds after training completes (aws#30)
  * Added: print out billable seconds after training completes
  * Fixed: test_session.py to pass unit tests
  * Fixed: removed offending tzlocal()
* Use sagemaker_timestamp when creating endpoint names in integration tests. (aws#81)
* Support TensorFlow-1.5.0 and MXNet-1.0.0 (aws#82)
  * Update .gitignore to ignore pytest_cache.
  * Support TensorFlow-1.5.0 and MXNet-1.0.0
  * Update and refactor tests. Add tests for fw_utils.
  * Fix typo.
* Update changelog for 1.1.0 (aws#85)
Arpin ensemble move
Add support for TensorFlow-1.5.0 and MXNet-1.0.0