Skip to content

Commit 13a1a83

Browse files
authored
dev: refactor some checks to use forEachKey() (#43)
1 parent 6dd777b commit 13a1a83

File tree

1 file changed

+58
-125
lines changed

1 file changed

+58
-125
lines changed

sloglint.go

+58-125
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ var slogFuncs = map[string]struct {
118118
argsPos int
119119
skipContextCheck bool
120120
}{
121-
// funcName: {argsPos, skipContextCheck}
122121
"log/slog.With": {argsPos: 0, skipContextCheck: true},
123122
"log/slog.Log": {argsPos: 3},
124123
"log/slog.LogAttrs": {argsPos: 3},
@@ -198,7 +197,7 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
198197

199198
switch opts.NoGlobal {
200199
case "all":
201-
if strings.HasPrefix(name, "log/slog.") || globalLoggerUsed(pass.TypesInfo, call.Fun) {
200+
if strings.HasPrefix(name, "log/slog.") || isGlobalLoggerUsed(pass.TypesInfo, call.Fun) {
202201
pass.Reportf(call.Pos(), "global logger should not be used")
203202
}
204203
case "default":
@@ -217,13 +216,14 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
217216
}
218217
case "scope":
219218
typ := pass.TypesInfo.TypeOf(call.Args[0])
220-
if typ != nil && typ.String() != "context.Context" && hasContextInScope(pass.TypesInfo, stack) {
219+
if typ != nil && typ.String() != "context.Context" && isContextInScope(pass.TypesInfo, stack) {
221220
pass.Reportf(call.Pos(), "%sContext should be used instead", fn.Name())
222221
}
223222
}
224223
}
225224

226-
if opts.StaticMsg && !staticMsg(call.Args[funcInfo.argsPos-1]) {
225+
msgPos := funcInfo.argsPos - 1
226+
if opts.StaticMsg && !isStaticMsg(call.Args[msgPos]) {
227227
pass.Reportf(call.Pos(), "message should be a string literal or a constant")
228228
}
229229

@@ -259,54 +259,48 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
259259
pass.Reportf(call.Pos(), "key-value pairs and attributes should not be mixed")
260260
}
261261

262-
if opts.NoRawKeys && rawKeysUsed(pass.TypesInfo, keys, attrs) {
263-
pass.Reportf(call.Pos(), "raw keys should not be used")
262+
if opts.NoRawKeys {
263+
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
264+
if ident, ok := key.(*ast.Ident); !ok || ident.Obj == nil || ident.Obj.Kind != ast.Con {
265+
pass.Reportf(call.Pos(), "raw keys should not be used")
266+
}
267+
})
264268
}
265269

266-
if opts.ArgsOnSepLines && argsOnSameLine(pass.Fset, call, keys, attrs) {
267-
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
270+
checkKeyNamingCase := func(caseFn func(string) string, caseName string) {
271+
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
272+
if name, ok := getKeyName(key); ok && name != caseFn(name) {
273+
pass.Reportf(call.Pos(), "keys should be written in %s", caseName)
274+
}
275+
})
268276
}
269277

270-
if len(opts.ForbiddenKeys) > 0 {
271-
if name, found := badKeyNames(pass.TypesInfo, isForbiddenKey(opts.ForbiddenKeys), keys, attrs); found {
272-
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
273-
}
278+
switch opts.KeyNamingCase {
279+
case snakeCase:
280+
checkKeyNamingCase(strcase.ToSnake, "snake_case")
281+
case kebabCase:
282+
checkKeyNamingCase(strcase.ToKebab, "kebab-case")
283+
case camelCase:
284+
checkKeyNamingCase(strcase.ToCamel, "camelCase")
285+
case pascalCase:
286+
checkKeyNamingCase(strcase.ToPascal, "PascalCase")
274287
}
275288

276-
switch {
277-
case opts.KeyNamingCase == snakeCase:
278-
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToSnake), keys, attrs); found {
279-
pass.Reportf(call.Pos(), "keys should be written in snake_case")
280-
}
281-
case opts.KeyNamingCase == kebabCase:
282-
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToKebab), keys, attrs); found {
283-
pass.Reportf(call.Pos(), "keys should be written in kebab-case")
284-
}
285-
case opts.KeyNamingCase == camelCase:
286-
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToCamel), keys, attrs); found {
287-
pass.Reportf(call.Pos(), "keys should be written in camelCase")
288-
}
289-
case opts.KeyNamingCase == pascalCase:
290-
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToPascal), keys, attrs); found {
291-
pass.Reportf(call.Pos(), "keys should be written in PascalCase")
292-
}
293-
}
294-
}
295-
296-
func isForbiddenKey(forbiddenKeys []string) func(string) bool {
297-
return func(name string) bool {
298-
return slices.Contains(forbiddenKeys, name)
289+
if len(opts.ForbiddenKeys) > 0 {
290+
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
291+
if name, ok := getKeyName(key); ok && slices.Contains(opts.ForbiddenKeys, name) {
292+
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
293+
}
294+
})
299295
}
300-
}
301296

302-
func valueChanged(handler func(string) string) func(string) bool {
303-
return func(name string) bool {
304-
return handler(name) != name
297+
if opts.ArgsOnSepLines && areArgsOnSameLine(pass.Fset, call, keys, attrs) {
298+
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
305299
}
306300
}
307301

308-
func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
309-
selector, ok := expr.(*ast.SelectorExpr)
302+
func isGlobalLoggerUsed(info *types.Info, call ast.Expr) bool {
303+
selector, ok := call.(*ast.SelectorExpr)
310304
if !ok {
311305
return false
312306
}
@@ -318,7 +312,7 @@ func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
318312
return obj.Parent() == obj.Pkg().Scope()
319313
}
320314

321-
func hasContextInScope(info *types.Info, stack []ast.Node) bool {
315+
func isContextInScope(info *types.Info, stack []ast.Node) bool {
322316
for i := len(stack) - 1; i >= 0; i-- {
323317
decl, ok := stack[i].(*ast.FuncDecl)
324318
if !ok {
@@ -336,8 +330,8 @@ func hasContextInScope(info *types.Info, stack []ast.Node) bool {
336330
return false
337331
}
338332

339-
func staticMsg(expr ast.Expr) bool {
340-
switch msg := expr.(type) {
333+
func isStaticMsg(msg ast.Expr) bool {
334+
switch msg := msg.(type) {
341335
case *ast.BasicLit: // e.g. slog.Info("msg")
342336
return msg.Kind == token.STRING
343337
case *ast.Ident: // e.g. const msg = "msg"; slog.Info(msg)
@@ -347,114 +341,53 @@ func staticMsg(expr ast.Expr) bool {
347341
}
348342
}
349343

350-
func rawKeysUsed(info *types.Info, keys, attrs []ast.Expr) bool {
351-
isConst := func(expr ast.Expr) bool {
352-
ident, ok := expr.(*ast.Ident)
353-
return ok && ident.Obj != nil && ident.Obj.Kind == ast.Con
354-
}
355-
356-
for _, key := range keys {
357-
if !isConst(key) {
358-
return true
359-
}
360-
}
361-
362-
for _, attr := range attrs {
363-
switch attr := attr.(type) {
364-
case *ast.CallExpr: // e.g. slog.Int()
365-
fn := typeutil.StaticCallee(info, attr)
366-
if _, ok := attrFuncs[fn.FullName()]; ok && !isConst(attr.Args[0]) {
367-
return true
368-
}
369-
370-
case *ast.CompositeLit: // slog.Attr{}
371-
isRawKey := func(kv *ast.KeyValueExpr) bool {
372-
return kv.Key.(*ast.Ident).Name == "Key" && !isConst(kv.Value)
373-
}
374-
375-
switch len(attr.Elts) {
376-
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
377-
kv := attr.Elts[0].(*ast.KeyValueExpr)
378-
if isRawKey(kv) {
379-
return true
380-
}
381-
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
382-
kv1, ok := attr.Elts[0].(*ast.KeyValueExpr)
383-
if ok {
384-
kv2 := attr.Elts[1].(*ast.KeyValueExpr)
385-
if isRawKey(kv1) || isRawKey(kv2) {
386-
return true
387-
}
388-
} else if !isConst(attr.Elts[0]) {
389-
return true
390-
}
391-
}
392-
}
393-
}
394-
395-
return false
396-
}
397-
398-
func badKeyNames(info *types.Info, validationFn func(string) bool, keys, attrs []ast.Expr) (string, bool) {
344+
func forEachKey(info *types.Info, keys, attrs []ast.Expr, fn func(key ast.Expr)) {
399345
for _, key := range keys {
400-
if name, ok := getKeyName(key); ok && validationFn(name) {
401-
return name, true
402-
}
346+
fn(key)
403347
}
404348

405349
for _, attr := range attrs {
406-
var expr ast.Expr
407-
408350
switch attr := attr.(type) {
409351
case *ast.CallExpr: // e.g. slog.Int()
410-
fn := typeutil.StaticCallee(info, attr)
411-
if fn == nil {
352+
callee := typeutil.StaticCallee(info, attr)
353+
if callee == nil {
412354
continue
413355
}
414-
if _, ok := attrFuncs[fn.FullName()]; !ok {
356+
if _, ok := attrFuncs[callee.FullName()]; !ok {
415357
continue
416358
}
417-
expr = attr.Args[0]
359+
fn(attr.Args[0])
418360

419361
case *ast.CompositeLit: // slog.Attr{}
420362
switch len(attr.Elts) {
421363
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
422364
if kv := attr.Elts[0].(*ast.KeyValueExpr); kv.Key.(*ast.Ident).Name == "Key" {
423-
expr = kv.Value
365+
fn(kv.Value)
424366
}
425-
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
426-
expr = attr.Elts[0]
427-
if kv1, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv1.Key.(*ast.Ident).Name == "Key" {
428-
expr = kv1.Value
429-
}
430-
if kv2, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv2.Key.(*ast.Ident).Name == "Key" {
431-
expr = kv2.Value
367+
case 2: // slog.Attr{Key: ..., Value: ...} | slog.Attr{Value: ..., Key: ...} | slog.Attr{..., ...}
368+
if kv, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
369+
fn(kv.Value)
370+
} else if kv, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
371+
fn(kv.Value)
372+
} else {
373+
fn(attr.Elts[0])
432374
}
433375
}
434376
}
435-
436-
if name, ok := getKeyName(expr); ok && validationFn(name) {
437-
return name, true
438-
}
439377
}
440-
441-
return "", false
442378
}
443379

444-
func getKeyName(expr ast.Expr) (string, bool) {
445-
if expr == nil {
446-
return "", false
447-
}
448-
if ident, ok := expr.(*ast.Ident); ok {
380+
func getKeyName(key ast.Expr) (string, bool) {
381+
if ident, ok := key.(*ast.Ident); ok {
449382
if ident.Obj == nil || ident.Obj.Decl == nil || ident.Obj.Kind != ast.Con {
450383
return "", false
451384
}
452385
if spec, ok := ident.Obj.Decl.(*ast.ValueSpec); ok && len(spec.Values) > 0 {
453-
// TODO: support len(spec.Values) > 1; e.g. "const foo, bar = 1, 2"
454-
expr = spec.Values[0]
386+
// TODO: support len(spec.Values) > 1; e.g. const foo, bar = 1, 2
387+
key = spec.Values[0]
455388
}
456389
}
457-
if lit, ok := expr.(*ast.BasicLit); ok && lit.Kind == token.STRING {
390+
if lit, ok := key.(*ast.BasicLit); ok && lit.Kind == token.STRING {
458391
// string literals are always quoted.
459392
value, err := strconv.Unquote(lit.Value)
460393
if err != nil {
@@ -465,7 +398,7 @@ func getKeyName(expr ast.Expr) (string, bool) {
465398
return "", false
466399
}
467400

468-
func argsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
401+
func areArgsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
469402
if len(keys)+len(attrs) <= 1 {
470403
return false // special case: slog.Info("msg", "key", "value") is ok.
471404
}

0 commit comments

Comments
 (0)