7
7
"go/token"
8
8
"go/types"
9
9
"strconv"
10
- "strings"
11
10
12
11
"golang.org/x/tools/go/analysis"
13
12
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -98,20 +97,14 @@ func checkForStmt(pass *analysis.Pass, forStmt *ast.ForStmt) {
98
97
return
99
98
}
100
99
101
- var (
102
- nExpr ast.Expr
103
- nStr string
104
- )
100
+ var operand ast.Expr
105
101
106
102
switch cond .Op {
107
103
case token .LSS : // ;i < n;
108
104
if isBenchmark (cond .Y ) {
109
105
return
110
106
}
111
107
112
- nExpr = findNExpr (cond .Y )
113
- nStr = findNStr (cond .Y )
114
-
115
108
x , ok := cond .X .(* ast.Ident )
116
109
if ! ok {
117
110
return
@@ -120,14 +113,13 @@ func checkForStmt(pass *analysis.Pass, forStmt *ast.ForStmt) {
120
113
if x .Name != initIdent .Name {
121
114
return
122
115
}
116
+
117
+ operand = cond .Y
123
118
case token .GTR : // ;n > i;
124
119
if isBenchmark (cond .X ) {
125
120
return
126
121
}
127
122
128
- nExpr = findNExpr (cond .X )
129
- nStr = findNStr (cond .X )
130
-
131
123
y , ok := cond .Y .(* ast.Ident )
132
124
if ! ok {
133
125
return
@@ -136,6 +128,8 @@ func checkForStmt(pass *analysis.Pass, forStmt *ast.ForStmt) {
136
128
if y .Name != initIdent .Name {
137
129
return
138
130
}
131
+
132
+ operand = cond .X
139
133
default :
140
134
return
141
135
}
@@ -234,7 +228,7 @@ func checkForStmt(pass *analysis.Pass, forStmt *ast.ForStmt) {
234
228
235
229
bc := & bodyChecker {
236
230
initIdent : initIdent ,
237
- nExpr : nExpr ,
231
+ nExpr : findNExpr ( operand ) ,
238
232
}
239
233
240
234
ast .Inspect (forStmt .Body , bc .check )
@@ -243,19 +237,19 @@ func checkForStmt(pass *analysis.Pass, forStmt *ast.ForStmt) {
243
237
return
244
238
}
245
239
246
- nStr = castNStr (pass , initIdent , nStr )
240
+ rangeX := operandToString (pass , initIdent , operand )
247
241
248
242
pass .Report (analysis.Diagnostic {
249
243
Pos : forStmt .Pos (),
250
244
Message : msg ,
251
245
SuggestedFixes : []analysis.SuggestedFix {
252
246
{
253
- Message : fmt .Sprintf ("Replace loop with `%s := range %s`" , initIdent .Name , nStr ),
247
+ Message : fmt .Sprintf ("Replace loop with `%s := range %s`" , initIdent .Name , rangeX ),
254
248
TextEdits : []analysis.TextEdit {
255
249
{
256
250
Pos : forStmt .Init .Pos (),
257
251
End : forStmt .Post .End (),
258
- NewText : []byte (fmt .Sprintf ("%s := range %s" , initIdent .Name , nStr )),
252
+ NewText : []byte (fmt .Sprintf ("%s := range %s" , initIdent .Name , rangeX )),
259
253
},
260
254
},
261
255
},
@@ -383,22 +377,30 @@ func findNExpr(expr ast.Expr) ast.Expr {
383
377
}
384
378
}
385
379
386
- func findNStr (expr ast.Expr ) string {
380
+ func recursiveOperandToString (expr ast.Expr ) string {
387
381
switch e := expr .(type ) {
388
382
case * ast.CallExpr :
389
- args := make ([]string , len (e .Args ))
383
+ args := ""
384
+
390
385
for i , v := range e .Args {
391
- args [i ] = findNStr (v )
386
+ if i > 0 {
387
+ args += ", "
388
+ }
389
+
390
+ args += recursiveOperandToString (v )
392
391
}
393
- return findNStr (e .Fun ) + "(" + strings .Join (args , ", " ) + ")"
392
+
393
+ return recursiveOperandToString (e .Fun ) + "(" + args + ")"
394
394
case * ast.BasicLit :
395
395
return e .Value
396
396
case * ast.Ident :
397
- return e .String ()
397
+ return e .Name
398
398
case * ast.SelectorExpr :
399
- return findNStr (e .X ) + "." + findNStr (e .Sel )
399
+ return recursiveOperandToString (e .X ) + "." + recursiveOperandToString (e .Sel )
400
400
case * ast.IndexExpr :
401
- return findNStr (e .X ) + "[" + findNStr (e .Index ) + "]"
401
+ return recursiveOperandToString (e .X ) + "[" + recursiveOperandToString (e .Index ) + "]"
402
+ case * ast.BinaryExpr :
403
+ return recursiveOperandToString (e .X ) + " " + e .Op .String () + " " + recursiveOperandToString (e .Y )
402
404
default :
403
405
return ""
404
406
}
@@ -539,13 +541,21 @@ func compareNumberLit(exp ast.Expr, val int) bool {
539
541
}
540
542
}
541
543
542
- func castNStr (pass * analysis.Pass , i * ast.Ident , n string ) string {
543
- initType := pass .TypesInfo .TypeOf (i ).String ()
544
- if initType == "int" {
545
- return n
544
+ func operandToString (pass * analysis.Pass , i * ast.Ident , operand ast.Expr ) string {
545
+ s := recursiveOperandToString (operand )
546
+ t := pass .TypesInfo .TypeOf (i )
547
+
548
+ if t == types .Typ [types .Int ] {
549
+ if len (s ) > 5 && s [:4 ] == "int(" && s [len (s )- 1 ] == ')' {
550
+ s = s [4 : len (s )- 1 ]
551
+ }
552
+
553
+ return s
546
554
}
547
- if _ , err := strconv .Atoi (n ); err != nil {
548
- return n
555
+
556
+ if len (s ) > 2 && s [len (s )- 1 :] == ")" {
557
+ return s
549
558
}
550
- return initType + "(" + n + ")"
559
+
560
+ return t .String () + "(" + s + ")"
551
561
}
0 commit comments