17
17
from sagemaker .cli .compatibility .v2 .modifiers import airflow
18
18
from tests .unit .sagemaker .cli .compatibility .v2 .modifiers .ast_converter import ast_call
19
19
20
-
21
- def test_node_should_be_modified_model_config_with_args ():
22
- model_config_calls = (
23
- " model_config(instance_type, model )" ,
24
- " airflow.model_config(instance_type, model )" ,
25
- "workflow.airflow.model_config(instance_type, model )" ,
26
- "sagemaker.workflow. airflow.model_config(instance_type, model )" ,
27
- " model_config_from_estimator(instance_type, model )" ,
28
- " airflow.model_config_from_estimator(instance_type, model )" ,
29
- "workflow.airflow.model_config_from_estimator(instance_type, model)" ,
30
- "sagemaker.workflow.airflow.model_config_from_estimator(instance_type, model)" ,
31
- )
32
-
20
+ MODEL_CONFIG_CALL_TEMPLATES = (
21
+ "model_config({})" ,
22
+ "airflow.model_config({})" ,
23
+ "workflow.airflow. model_config({} )" ,
24
+ "sagemaker.workflow. airflow.model_config({} )" ,
25
+ "model_config_from_estimator({} )" ,
26
+ " airflow.model_config_from_estimator({} )" ,
27
+ "workflow.airflow. model_config_from_estimator({} )" ,
28
+ "sagemaker.workflow. airflow.model_config_from_estimator({} )" ,
29
+ )
30
+
31
+
32
+ def test_arg_order_node_should_be_modified_model_config_with_args ():
33
33
modifier = airflow .ModelConfigArgModifier ()
34
34
35
- for call in model_config_calls :
36
- node = ast_call (call )
35
+ for template in MODEL_CONFIG_CALL_TEMPLATES :
36
+ node = ast_call (template . format ( "instance_type, model" ) )
37
37
assert modifier .node_should_be_modified (node ) is True
38
38
39
39
40
- def test_node_should_be_modified_model_config_without_args ():
41
- model_config_calls = (
42
- "model_config()" ,
43
- "airflow.model_config()" ,
44
- "workflow.airflow.model_config()" ,
45
- "sagemaker.workflow.airflow.model_config()" ,
46
- "model_config_from_estimator()" ,
47
- "airflow.model_config_from_estimator()" ,
48
- "workflow.airflow.model_config_from_estimator()" ,
49
- "sagemaker.workflow.airflow.model_config_from_estimator()" ,
50
- )
51
-
40
+ def test_arg_order_node_should_be_modified_model_config_without_args ():
52
41
modifier = airflow .ModelConfigArgModifier ()
53
42
54
- for call in model_config_calls :
55
- node = ast_call (call )
43
+ for template in MODEL_CONFIG_CALL_TEMPLATES :
44
+ node = ast_call (template . format ( "" ) )
56
45
assert modifier .node_should_be_modified (node ) is False
57
46
58
47
59
- def test_node_should_be_modified_random_function_call ():
48
+ def test_arg_order_node_should_be_modified_random_function_call ():
60
49
node = ast_call ("sagemaker.workflow.airflow.prepare_framework_container_def()" )
61
50
modifier = airflow .ModelConfigArgModifier ()
62
51
assert modifier .node_should_be_modified (node ) is False
63
52
64
53
65
- def test_modify_node ():
54
+ def test_arg_order_modify_node ():
66
55
model_config_calls = (
67
56
("model_config(instance_type, model)" , "model_config(model, instance_type=instance_type)" ),
68
57
(
@@ -89,3 +78,42 @@ def test_modify_node():
89
78
node = ast_call (call )
90
79
modifier .modify_node (node )
91
80
assert expected == pasta .dump (node )
81
+
82
+
83
+ def test_image_arg_node_should_be_modified_model_config_with_arg ():
84
+ modifier = airflow .ModelConfigImageURIRenamer ()
85
+
86
+ for template in MODEL_CONFIG_CALL_TEMPLATES :
87
+ node = ast_call (template .format ("image=my_image" ))
88
+ assert modifier .node_should_be_modified (node ) is True
89
+
90
+
91
+ def test_image_arg_node_should_be_modified_model_config_without_arg ():
92
+ modifier = airflow .ModelConfigImageURIRenamer ()
93
+
94
+ for template in MODEL_CONFIG_CALL_TEMPLATES :
95
+ node = ast_call (template .format ("" ))
96
+ assert modifier .node_should_be_modified (node ) is False
97
+
98
+
99
+ def test_image_arg_node_should_be_modified_random_function_call ():
100
+ node = ast_call ("sagemaker.workflow.airflow.prepare_framework_container_def()" )
101
+ modifier = airflow .ModelConfigImageURIRenamer ()
102
+ assert modifier .node_should_be_modified (node ) is False
103
+
104
+
105
+ def test_image_arg_modify_node ():
106
+ model_config_calls = (
107
+ ("model_config(image='image:latest')" , "model_config(image_uri='image:latest')" ),
108
+ (
109
+ "model_config_from_estimator(image=my_image)" ,
110
+ "model_config_from_estimator(image_uri=my_image)" ,
111
+ ),
112
+ )
113
+
114
+ modifier = airflow .ModelConfigImageURIRenamer ()
115
+
116
+ for call , expected in model_config_calls :
117
+ node = ast_call (call )
118
+ modifier .modify_node (node )
119
+ assert expected == pasta .dump (node )
0 commit comments