Skip to content

Commit 872ba39

Browse files
author
owahab
committed
Major WIP.
1 parent 57b2a22 commit 872ba39

File tree

7 files changed

+60
-2
lines changed

7 files changed

+60
-2
lines changed

tests/unit/v2/__init__.py

Whitespace-only changes.

tests/unit/v2/samples/simple.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
TensorFlow(entry_point="foo.py")
2+
sagemaker.tensorflow.TensorFlow()
3+
m = MXNet()
4+
sagemaker.mxnet.MXNet()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
3+
4+
class FrameworkVersion(unittest.TestCase):
5+
def setUp(self) -> None:
6+
pass
7+
8+
def test_something(self):
9+
self.assertEqual(True, False)
10+
11+
12+
if __name__ == '__main__':
13+
unittest.main()

tests/unit/v2/test_transformer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import ast
2+
import unittest
3+
4+
from tests.unit.v2.utils import get_sample_file
5+
from tools.compatibility.v2.ast_transformer import ASTTransformer
6+
import pasta
7+
8+
9+
class TransformerTest(unittest.TestCase):
10+
def setUp(self) -> None:
11+
self.transformer_class = ASTTransformer()
12+
13+
def test_simple_transform(self):
14+
sample = get_sample_file('simple.txt')
15+
rewrite = self.transformer_class.visit(
16+
ast.parse(
17+
sample
18+
)
19+
)
20+
21+
expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
22+
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
23+
m = MXNet(framework_version='1.2.0')
24+
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""
25+
26+
self.assertEqual(pasta.dump(rewrite), expected)
27+
28+
29+
if __name__ == '__main__':
30+
unittest.main()
31+
32+

tests/unit/v2/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from os.path import join
2+
3+
SAMPLES_DIRECTORY = "/Users/owahab/Desktop/personal/sagemaker-python-sdk/tests/unit/v2/samples/"
4+
5+
6+
def get_sample_file(filename):
7+
file_path = join(SAMPLES_DIRECTORY, filename)
8+
with open(file_path) as file_content:
9+
return file_content.read()

tools/compatibility/v2/ast_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from modifiers import framework_version
18+
from tools.compatibility.v2.modifiers import framework_version
1919

2020
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
2121

tools/compatibility/v2/modifiers/framework_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from modifiers.modifier import Modifier
18+
from tools.compatibility.v2.modifiers.modifier import Modifier
1919

2020
FRAMEWORK_DEFAULTS = {
2121
"Chainer": "4.1.0",

0 commit comments

Comments
 (0)