Skip to content

change: convert TF legacy mode parameters to hyperparameters in v2 migration script #1534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions buildspec-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,12 @@ phases:
- start_time=`date +%s`
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
tox -e py27,py36,py37 --parallel all -- tests/unit
- ./ci-scripts/displaytime.sh 'py27,py36,py37 unit' $start_time
tox -e py36,py37 --parallel all -- tests/unit
- ./ci-scripts/displaytime.sh 'py36,py37 unit' $start_time

# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
- start_time=`date +%s`
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
IGNORE_COVERAGE=- tox -e py27 --parallel all -- tests/unit
- ./ci-scripts/displaytime.sh 'py27 unit' $start_time
9 changes: 7 additions & 2 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

import ast

from sagemaker.cli.compatibility.v2.modifiers import framework_version
from sagemaker.cli.compatibility.v2 import modifiers

FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
FUNCTION_CALL_MODIFIERS = [
modifiers.framework_version.FrameworkVersionEnforcer(),
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
]


class ASTTransformer(ast.NodeTransformer):
Expand All @@ -38,4 +41,6 @@ def visit_Call(self, node):
"""
for function_checker in FUNCTION_CALL_MODIFIERS:
function_checker.check_and_modify_node(node)

ast.fix_missing_locations(node)
return node
5 changes: 5 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@
# language governing permissions and limitations under the License.
"""Classes for modifying AST nodes"""
from __future__ import absolute_import

from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
framework_version,
tf_legacy_mode,
)
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def _is_framework_constructor(self, node):
"""
# Check for <Framework> call
if isinstance(node.func, ast.Name):
if node.func.id in FRAMEWORK_CLASSES:
return True
return node.func.id in FRAMEWORK_CLASSES

# Check for sagemaker.<framework>.<Framework> call
ends_with_framework_constructor = (
Expand Down
150 changes: 150 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2."""
# TODO: handle fit(run_tensorboard_locally=True)
from __future__ import absolute_import

import ast

import six

from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier


class TensorFlowLegacyModeConstructorUpgrader(Modifier):
    """A class to turn legacy mode parameters into hyperparameters when
    instantiating a TensorFlow estimator.
    """

    LEGACY_MODE_PARAMETERS = (
        "checkpoint_path",
        "evaluation_steps",
        "requirements_file",
        "training_steps",
    )

    # Sentinel returned by ``_literal_value`` for expressions that are not literals
    # (``None`` cannot be used because ``None`` is itself a valid literal value).
    _UNRESOLVED = object()

    def node_should_be_modified(self, node):
        """Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.

        This looks for the following formats:

        - ``TensorFlow``
        - ``sagemaker.tensorflow.TensorFlow``

        Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
        and (2) if ``py_version`` is ``py2`` or not specified.

        If either argument is given as something other than a literal (for example a
        variable), its value cannot be determined statically, so the call is *not*
        treated as legacy mode and is left untouched.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with legacy mode.
        """
        return self._is_tf_constructor(node) and self._is_legacy_mode(node)

    def _is_tf_constructor(self, node):
        """Checks if the ``ast.Call`` node represents a call of the form
        ``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.

        Any other callee shape (e.g. ``factory()()`` or ``(lambda: x)()``) returns
        ``False`` instead of raising ``AttributeError``.
        """
        func = node.func

        # Check for TensorFlow()
        if isinstance(func, ast.Name):
            return func.id == "TensorFlow"

        # Check for sagemaker.tensorflow.TensorFlow(). The isinstance guard must come
        # first: only ``ast.Attribute`` is known to carry ``attr`` and ``value``.
        if not (isinstance(func, ast.Attribute) and func.attr == "TensorFlow"):
            return False

        module = func.value
        return (
            isinstance(module, ast.Attribute)
            and module.attr == "tensorflow"
            and isinstance(module.value, ast.Name)
            and module.value.id == "sagemaker"
        )

    def _literal_value(self, value_node):
        """Returns the Python value spelled by a literal AST node.

        Returns ``_UNRESOLVED`` when the node is not a literal (for example a variable
        or a function call), since its value is only known at runtime.
        """
        try:
            return ast.literal_eval(value_node)
        except (ValueError, TypeError, SyntaxError):
            return self._UNRESOLVED

    def _is_legacy_mode(self, node):
        """Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
        script_mode = False
        py_version = "py2"

        for kw in node.keywords:
            if kw.arg == "script_mode":
                value = self._literal_value(kw.value)
                # A non-literal value might be truthy at runtime; assume script mode so
                # that code which may not be legacy mode is never rewritten.
                script_mode = True if value is self._UNRESOLVED else bool(value)
            if kw.arg == "py_version":
                value = self._literal_value(kw.value)
                if isinstance(value, six.string_types):
                    py_version = value
                elif value is not None:
                    # Non-literal (or non-string) version: be conservative, as above.
                    py_version = "py3"
                # NOTE(review): an explicit ``py_version=None`` keeps the "py2" default,
                # mirroring how ``script_mode=None`` is handled - confirm with the SDK.

        return not (py_version.startswith("py3") or script_mode)

    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
        into hyperparameters and set ``script_mode=False``.

        The parameters that are converted into hyperparameters:

        - ``training_steps``
        - ``evaluation_steps``
        - ``checkpoint_path``
        - ``requirements_file``

        If ``hyperparameters`` is a dict literal, the converted parameters are merged into
        it, replacing any entry with the same key. If it is any other expression (e.g. a
        variable), the expression is wrapped as ``dict(<expr>, training_steps=..., ...)``
        so the existing values are kept and the converted parameters take precedence.

        Args:
            node (ast.Call): a node that represents a TensorFlow constructor.
        """
        base_hps = {}
        legacy_hps = []  # (hyperparameter name, value node) pairs, in source order
        dynamic_hps_kw = None  # ``hyperparameters`` keyword whose value is not a dict literal
        kw_to_remove = []  # remove keyword args after so that none are skipped during iteration

        for kw in node.keywords:
            if kw.arg == "script_mode":
                # remove here because is set to False later regardless of current value
                kw_to_remove.append(kw)
            elif kw.arg == "hyperparameters" and kw.value:
                if isinstance(kw.value, ast.Dict):
                    base_hps = dict(zip(kw.value.keys, kw.value.values))
                    kw_to_remove.append(kw)
                else:
                    dynamic_hps_kw = kw
            elif kw.arg in self.LEGACY_MODE_PARAMETERS and kw.value:
                legacy_hps.append((self._hyperparameter_name_for_param(kw.arg), kw.value))
                kw_to_remove.append(kw)

        self._remove_keywords(node, kw_to_remove)

        if dynamic_hps_kw is None:
            additional_hps = {ast.Str(s=name): value for name, value in legacy_hps}
            self._add_updated_hyperparameters(node, base_hps, additional_hps)
        elif legacy_hps:
            self._extend_dynamic_hyperparameters(dynamic_hps_kw, legacy_hps)

        node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))

    def _hyperparameter_name_for_param(self, arg):
        """Returns the hyperparameter name (a ``str``) replacing a legacy mode parameter."""
        return "sagemaker_requirements" if arg == "requirements_file" else arg

    def _hyperparameter_key_for_param(self, arg):
        """Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
        return ast.Str(s=self._hyperparameter_name_for_param(arg))

    def _remove_keywords(self, node, keywords):
        """Removes the keywords from the ``ast.Call`` node."""
        for kw in keywords:
            node.keywords.remove(kw)

    def _extend_dynamic_hyperparameters(self, hp_keyword, legacy_hps):
        """Folds converted parameters into a ``hyperparameters`` value that is not a dict literal.

        Args:
            hp_keyword (ast.keyword): the existing ``hyperparameters`` keyword; its value
                is replaced in place.
            legacy_hps (list[tuple[str, ast.AST]]): hyperparameter names and value nodes.
        """
        if self._literal_value(hp_keyword.value) is None:
            # ``hyperparameters=None``: there is nothing to preserve, and ``dict(None, ...)``
            # would fail at runtime, so emit a plain dict literal instead.
            hp_keyword.value = ast.Dict(
                keys=[ast.Str(s=name) for name, _ in legacy_hps],
                values=[value for _, value in legacy_hps],
            )
            return

        # Every converted name is a valid identifier, so it can be passed as a keyword
        # argument; ``dict(mapping, **kwargs)`` lets the keyword values win on conflicts.
        hp_keyword.value = ast.Call(
            func=ast.Name(id="dict", ctx=ast.Load()),
            args=[hp_keyword.value],
            keywords=[ast.keyword(arg=name, value=value) for name, value in legacy_hps],
        )

    def _string_literal(self, key_node):
        """Returns the string spelled by a dict key node, or ``None`` if it isn't a string literal."""
        if key_node is None:  # ``**mapping`` entry inside a dict literal has no key node
            return None
        value = self._literal_value(key_node)
        return value if isinstance(value, six.string_types) else None

    def _add_updated_hyperparameters(self, node, base_hps, additional_hps):
        """Combines and adds the hyperparameters to the ``ast.Call`` node's keywords.

        Both mappings are keyed by AST nodes, which hash by identity, so a plain
        ``dict.update`` would never replace an existing entry. Entries of ``base_hps``
        whose key is a string literal also present in ``additional_hps`` are dropped
        explicitly so the emitted dict literal has no duplicate keys.
        """
        overriding_names = set()
        for key in additional_hps:
            name = self._string_literal(key)
            if name is not None:
                overriding_names.add(name)

        merged_hps = {
            key: value
            for key, value in six.iteritems(base_hps)
            if self._string_literal(key) not in overriding_names
        }
        merged_hps.update(additional_hps)

        updated_hp_keyword = self._to_ast_keyword(merged_hps)

        if updated_hp_keyword:
            node.keywords.append(updated_hp_keyword)

    def _to_ast_keyword(self, hps):
        """Returns an ``ast.keyword`` for the ``hyperparameters`` kwarg if there are any."""
        if hps:
            keys, values = zip(*six.iteritems(hps))
            # ``ast.Dict`` fields are lists in the AST grammar; ``zip`` yields tuples.
            return ast.keyword(
                arg="hyperparameters", value=ast.Dict(keys=list(keys), values=list(values))
            )

        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode


@pytest.fixture(autouse=True)
def skip_if_py2():
    """Skip every test in this module when running under Python 2."""
    # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
    running_python_2 = sys.version_info[0] < 3
    if running_python_2:
        pytest.skip("v2 migration script doesn't support Python 2.")


def test_node_should_be_modified_tf_constructor_legacy_mode():
    """Each legacy-mode TensorFlow constructor spelling is flagged for modification."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

    legacy_sources = [
        "TensorFlow(script_mode=False)",
        "TensorFlow(script_mode=None)",
        "TensorFlow(py_version='py2')",
        "TensorFlow()",
        "sagemaker.tensorflow.TensorFlow(script_mode=False)",
        "sagemaker.tensorflow.TensorFlow(script_mode=None)",
        "sagemaker.tensorflow.TensorFlow(py_version='py2')",
        "sagemaker.tensorflow.TensorFlow()",
    ]

    for source in legacy_sources:
        call_node = _ast_call(source)
        assert upgrader.node_should_be_modified(call_node) is True


def test_node_should_be_modified_tf_constructor_script_mode():
    """Constructors using script mode or Python 3 are not flagged for modification."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

    script_mode_sources = [
        "TensorFlow(script_mode=True)",
        "TensorFlow(py_version='py3')",
        "TensorFlow(py_version='py37')",
        "TensorFlow(py_version='py3', script_mode=False)",
        "sagemaker.tensorflow.TensorFlow(script_mode=True)",
        "sagemaker.tensorflow.TensorFlow(py_version='py3')",
        "sagemaker.tensorflow.TensorFlow(py_version='py37')",
        "sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
    ]

    for source in script_mode_sources:
        call_node = _ast_call(source)
        assert upgrader.node_should_be_modified(call_node) is False


def test_node_should_be_modified_random_function_call():
    """A call that is not a TensorFlow constructor is never flagged for modification."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    unrelated_call = _ast_call("MXNet(py_version='py3')")
    assert upgrader.node_should_be_modified(unrelated_call) is False


def test_modify_node_set_script_mode_false():
    """Modifying a constructor always leaves exactly ``script_mode=False`` behind."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()

    sources = [
        "TensorFlow()",
        "TensorFlow(script_mode=False)",
        "TensorFlow(script_mode=None)",
    ]

    for source in sources:
        call_node = _ast_call(source)
        upgrader.modify_node(call_node)
        rendered = pasta.dump(call_node)
        assert "TensorFlow(script_mode=False)" == rendered


def test_modify_node_set_hyperparameters():
    """All four legacy mode parameters end up as hyperparameters."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    call_node = _ast_call(
        """TensorFlow(
checkpoint_path='s3://foo/bar',
training_steps=100,
evaluation_steps=10,
requirements_file='source/requirements.txt',
)"""
    )

    upgrader.modify_node(call_node)
    actual_hyperparameters = _hyperparameters_from_node(call_node)

    # ``requirements_file`` is expected under the ``sagemaker_requirements`` key.
    expected_hyperparameters = {
        "checkpoint_path": "s3://foo/bar",
        "evaluation_steps": 10,
        "sagemaker_requirements": "source/requirements.txt",
        "training_steps": 100,
    }
    assert expected_hyperparameters == actual_hyperparameters


def test_modify_node_preserve_other_hyperparameters():
    """Hyperparameters already passed to the constructor survive the conversion."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    call_node = _ast_call(
        """sagemaker.tensorflow.TensorFlow(
training_steps=100,
evaluation_steps=10,
requirements_file='source/requirements.txt',
hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
)"""
    )

    upgrader.modify_node(call_node)
    actual_hyperparameters = _hyperparameters_from_node(call_node)

    expected_hyperparameters = {
        "optimizer": "sgd",
        "lr": 0.1,
        "checkpoint_path": "s3://foo/bar",
        "evaluation_steps": 10,
        "sagemaker_requirements": "source/requirements.txt",
        "training_steps": 100,
    }
    assert expected_hyperparameters == actual_hyperparameters


def test_modify_node_prefer_param_over_hyperparameter():
    """A legacy mode parameter wins over a hyperparameter with the same key."""
    upgrader = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
    call_node = _ast_call(
        """sagemaker.tensorflow.TensorFlow(
training_steps=100,
requirements_file='source/requirements.txt',
hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
)"""
    )

    upgrader.modify_node(call_node)
    actual_hyperparameters = _hyperparameters_from_node(call_node)

    expected_hyperparameters = {
        "sagemaker_requirements": "source/requirements.txt",
        "training_steps": 100,
    }
    assert expected_hyperparameters == actual_hyperparameters


def _hyperparameters_from_node(node):
    """Return the ``hyperparameters`` keyword of a call node as a plain dict.

    Keys are read from each key node's ``s`` attribute and values from the first
    field of each value node. Returns ``None`` when the call has no
    ``hyperparameters`` keyword.
    """
    for keyword in node.keywords:
        if keyword.arg != "hyperparameters":
            continue

        hp_literal = keyword.value
        names = [key_node.s for key_node in hp_literal.keys]
        # Later entries overwrite earlier ones with the same name, as with dict(zip(...)).
        return {
            name: getattr(value_node, value_node._fields[0])
            for name, value_node in zip(names, hp_literal.values)
        }

    return None


def _ast_call(code):
    """Parse ``code`` with pasta and return the value of its first statement."""
    module = pasta.parse(code)
    first_statement = module.body[0]
    return first_statement.value