File tree Expand file tree Collapse file tree 7 files changed +60
-2
lines changed Expand file tree Collapse file tree 7 files changed +60
-2
lines changed Original file line number Diff line number Diff line change
1
+ TensorFlow(entry_point="foo.py")
2
+ sagemaker.tensorflow.TensorFlow()
3
+ m = MXNet()
4
+ sagemaker.mxnet.MXNet()
Original file line number Diff line number Diff line change
1
+ import unittest
2
+
3
+
4
+ class FrameworkVersion (unittest .TestCase ):
5
+ def setUp (self ) -> None :
6
+ pass
7
+
8
+ def test_something (self ):
9
+ self .assertEqual (True , False )
10
+
11
+
12
+ if __name__ == '__main__' :
13
+ unittest .main ()
Original file line number Diff line number Diff line change
1
+ import ast
2
+ import unittest
3
+
4
+ from tests .unit .v2 .utils import get_sample_file
5
+ from tools .compatibility .v2 .ast_transformer import ASTTransformer
6
+ import pasta
7
+
8
+
9
+ class TransformerTest (unittest .TestCase ):
10
+ def setUp (self ) -> None :
11
+ self .transformer_class = ASTTransformer ()
12
+
13
+ def test_simple_transform (self ):
14
+ sample = get_sample_file ('simple.txt' )
15
+ rewrite = self .transformer_class .visit (
16
+ ast .parse (
17
+ sample
18
+ )
19
+ )
20
+
21
+ expected = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
22
+ sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
23
+ m = MXNet(framework_version='1.2.0')
24
+ sagemaker.mxnet.MXNet(framework_version='1.2.0')\n """
25
+
26
+ self .assertEqual (pasta .dump (rewrite ), expected )
27
+
28
+
29
+ if __name__ == '__main__' :
30
+ unittest .main ()
31
+
32
+
Original file line number Diff line number Diff line change
1
+ from os .path import join
2
+
3
+ SAMPLES_DIRECTORY = "/Users/owahab/Desktop/personal/sagemaker-python-sdk/tests/unit/v2/samples/"
4
+
5
+
6
+ def get_sample_file (filename ):
7
+ file_path = join (SAMPLES_DIRECTORY , filename )
8
+ with open (file_path ) as file_content :
9
+ return file_content .read ()
Original file line number Diff line number Diff line change 15
15
16
16
import ast
17
17
18
- from modifiers import framework_version
18
+ from tools . compatibility . v2 . modifiers import framework_version
19
19
20
20
FUNCTION_CALL_MODIFIERS = [framework_version .FrameworkVersionEnforcer ()]
21
21
Original file line number Diff line number Diff line change 15
15
16
16
import ast
17
17
18
- from modifiers .modifier import Modifier
18
+ from tools . compatibility . v2 . modifiers .modifier import Modifier
19
19
20
20
FRAMEWORK_DEFAULTS = {
21
21
"Chainer" : "4.1.0" ,
You can’t perform that action at this time.
0 commit comments