46
46
47
47
if sys .version_info >= (3 , 8 ):
48
48
namedExpr = ast .NamedExpr
49
+ astNameConstant = ast .Constant
50
+ astStr = ast .Constant
51
+ astNum = ast .Constant
49
52
else :
50
53
namedExpr = ast .Expr
54
+ astNameConstant = ast .NameConstant
55
+ astStr = ast .Str
56
+ astNum = ast .Num
51
57
52
58
53
59
assertstate_key = StashKey ["AssertionState" ]()
@@ -680,9 +686,12 @@ def run(self, mod: ast.Module) -> None:
680
686
if (
681
687
expect_docstring
682
688
and isinstance (item , ast .Expr )
683
- and isinstance (item .value , ast . Str )
689
+ and isinstance (item .value , astStr )
684
690
):
685
- doc = item .value .s
691
+ if sys .version_info >= (3 , 8 ):
692
+ doc = item .value .value
693
+ else :
694
+ doc = item .value .s
686
695
if self .is_rewrite_disabled (doc ):
687
696
return
688
697
expect_docstring = False
@@ -814,7 +823,7 @@ def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
814
823
current = self .stack .pop ()
815
824
if self .stack :
816
825
self .explanation_specifiers = self .stack [- 1 ]
817
- keys = [ast . Str (key ) for key in current .keys ()]
826
+ keys = [astStr (key ) for key in current .keys ()]
818
827
format_dict = ast .Dict (keys , list (current .values ()))
819
828
form = ast .BinOp (expl_expr , ast .Mod (), format_dict )
820
829
name = "@py_format" + str (next (self .variable_counter ))
@@ -868,16 +877,16 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
868
877
negation = ast .UnaryOp (ast .Not (), top_condition )
869
878
870
879
if self .enable_assertion_pass_hook : # Experimental pytest_assertion_pass hook
871
- msg = self .pop_format_context (ast . Str (explanation ))
880
+ msg = self .pop_format_context (astStr (explanation ))
872
881
873
882
# Failed
874
883
if assert_ .msg :
875
884
assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
876
885
gluestr = "\n >assert "
877
886
else :
878
- assertmsg = ast . Str ("" )
887
+ assertmsg = astStr ("" )
879
888
gluestr = "assert "
880
- err_explanation = ast .BinOp (ast . Str (gluestr ), ast .Add (), msg )
889
+ err_explanation = ast .BinOp (astStr (gluestr ), ast .Add (), msg )
881
890
err_msg = ast .BinOp (assertmsg , ast .Add (), err_explanation )
882
891
err_name = ast .Name ("AssertionError" , ast .Load ())
883
892
fmt = self .helper ("_format_explanation" , err_msg )
@@ -893,8 +902,8 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
893
902
hook_call_pass = ast .Expr (
894
903
self .helper (
895
904
"_call_assertion_pass" ,
896
- ast . Num (assert_ .lineno ),
897
- ast . Str (orig ),
905
+ astNum (assert_ .lineno ),
906
+ astStr (orig ),
898
907
fmt_pass ,
899
908
)
900
909
)
@@ -913,7 +922,7 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
913
922
variables = [
914
923
ast .Name (name , ast .Store ()) for name in self .format_variables
915
924
]
916
- clear_format = ast .Assign (variables , ast . NameConstant (None ))
925
+ clear_format = ast .Assign (variables , astNameConstant (None ))
917
926
self .statements .append (clear_format )
918
927
919
928
else : # Original assertion rewriting
@@ -924,9 +933,9 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
924
933
assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
925
934
explanation = "\n >assert " + explanation
926
935
else :
927
- assertmsg = ast . Str ("" )
936
+ assertmsg = astStr ("" )
928
937
explanation = "assert " + explanation
929
- template = ast .BinOp (assertmsg , ast .Add (), ast . Str (explanation ))
938
+ template = ast .BinOp (assertmsg , ast .Add (), astStr (explanation ))
930
939
msg = self .pop_format_context (template )
931
940
fmt = self .helper ("_format_explanation" , msg )
932
941
err_name = ast .Name ("AssertionError" , ast .Load ())
@@ -938,7 +947,7 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
938
947
# Clear temporary variables by setting them to None.
939
948
if self .variables :
940
949
variables = [ast .Name (name , ast .Store ()) for name in self .variables ]
941
- clear = ast .Assign (variables , ast . NameConstant (None ))
950
+ clear = ast .Assign (variables , astNameConstant (None ))
942
951
self .statements .append (clear )
943
952
# Fix locations (line numbers/column offsets).
944
953
for stmt in self .statements :
@@ -952,20 +961,20 @@ def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
952
961
# thinks it's acceptable.
953
962
locs = ast .Call (self .builtin ("locals" ), [], [])
954
963
target_id = name .target .id # type: ignore[attr-defined]
955
- inlocs = ast .Compare (ast . Str (target_id ), [ast .In ()], [locs ])
964
+ inlocs = ast .Compare (astStr (target_id ), [ast .In ()], [locs ])
956
965
dorepr = self .helper ("_should_repr_global_name" , name )
957
966
test = ast .BoolOp (ast .Or (), [inlocs , dorepr ])
958
- expr = ast .IfExp (test , self .display (name ), ast . Str (target_id ))
967
+ expr = ast .IfExp (test , self .display (name ), astStr (target_id ))
959
968
return name , self .explanation_param (expr )
960
969
961
970
def visit_Name (self , name : ast .Name ) -> Tuple [ast .Name , str ]:
962
971
# Display the repr of the name if it's a local variable or
963
972
# _should_repr_global_name() thinks it's acceptable.
964
973
locs = ast .Call (self .builtin ("locals" ), [], [])
965
- inlocs = ast .Compare (ast . Str (name .id ), [ast .In ()], [locs ])
974
+ inlocs = ast .Compare (astStr (name .id ), [ast .In ()], [locs ])
966
975
dorepr = self .helper ("_should_repr_global_name" , name )
967
976
test = ast .BoolOp (ast .Or (), [inlocs , dorepr ])
968
- expr = ast .IfExp (test , self .display (name ), ast . Str (name .id ))
977
+ expr = ast .IfExp (test , self .display (name ), astStr (name .id ))
969
978
return name , self .explanation_param (expr )
970
979
971
980
def visit_BoolOp (self , boolop : ast .BoolOp ) -> Tuple [ast .Name , str ]:
@@ -1003,7 +1012,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
1003
1012
self .push_format_context ()
1004
1013
res , expl = self .visit (v )
1005
1014
body .append (ast .Assign ([ast .Name (res_var , ast .Store ())], res ))
1006
- expl_format = self .pop_format_context (ast . Str (expl ))
1015
+ expl_format = self .pop_format_context (astStr (expl ))
1007
1016
call = ast .Call (app , [expl_format ], [])
1008
1017
self .expl_stmts .append (ast .Expr (call ))
1009
1018
if i < levels :
@@ -1015,7 +1024,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
1015
1024
self .statements = body = inner
1016
1025
self .statements = save
1017
1026
self .expl_stmts = fail_save
1018
- expl_template = self .helper ("_format_boolop" , expl_list , ast . Num (is_or ))
1027
+ expl_template = self .helper ("_format_boolop" , expl_list , astNum (is_or ))
1019
1028
expl = self .pop_format_context (expl_template )
1020
1029
return ast .Name (res_var , ast .Load ()), self .explanation_param (expl )
1021
1030
@@ -1118,9 +1127,9 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1118
1127
next_expl = f"({ next_expl } )"
1119
1128
results .append (next_res )
1120
1129
sym = BINOP_MAP [op .__class__ ]
1121
- syms .append (ast . Str (sym ))
1130
+ syms .append (astStr (sym ))
1122
1131
expl = f"{ left_expl } { sym } { next_expl } "
1123
- expls .append (ast . Str (expl ))
1132
+ expls .append (astStr (expl ))
1124
1133
res_expr = ast .Compare (left_res , [op ], [next_res ])
1125
1134
self .statements .append (ast .Assign ([store_names [i ]], res_expr ))
1126
1135
left_res , left_expl = next_res , next_expl
0 commit comments