Skip to content

Commit 3e5a2fc

Browse files
committed
add UnnecessarySortRewriteConfig
fix #19
1 parent 5c9ace4 commit 3e5a2fc

9 files changed

+175
-16
lines changed

build.sbt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ lazy val rules = projectMatrix
101101
.jvmPlatform(rulesCrossVersions)
102102

103103
lazy val inputOutputCommon = Def.settings(
104+
libraryDependencies += "org.scala-lang.modules" %% "scala-collection-compat" % "2.6.0",
104105
libraryDependencies ++= {
105106
if (scalaBinaryVersion.value == "2.13") {
106107
Seq("io.circe" %% "circe-generic-extras" % "0.14.1")
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package fix
2+
3+
/*
4+
rule = UnnecessarySortRewrite
5+
UnnecessarySortRewrite.rewriteConfig = addCompatImport
6+
*/
7+
abstract class UnnecessarySortRewriteTestAddImport1 {
8+
def seq: Seq[(Int, Int)]
9+
10+
def foo: Unit = {
11+
seq.sortBy(_._1).head
12+
seq.sortBy(_._2).last
13+
seq.sortBy(_._1).headOption
14+
seq.sortBy(_._2).lastOption
15+
}
16+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package fix
2+
3+
/*
4+
rule = UnnecessarySortRewrite
5+
UnnecessarySortRewrite.rewriteConfig = addCompatImport
6+
*/
7+
abstract class UnnecessarySortRewriteTestAddImport2 {
8+
def seq: Seq[(Int, Int)]
9+
10+
def foo: Unit = {
11+
seq.sortBy(_._1).head
12+
seq.sortBy(_._2).last
13+
}
14+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package fix
2+
3+
/*
4+
rule = UnnecessarySortRewrite
5+
UnnecessarySortRewrite.rewriteConfig = only212methods
6+
*/
7+
abstract class UnnecessarySortRewriteTestOnly212 {
8+
def seq: Seq[(Int, Int)]
9+
10+
def foo: Unit = {
11+
seq.sortBy(_._1).head
12+
seq.sortBy(_._2).last
13+
seq.sortBy(_._1).headOption
14+
seq.sortBy(_._2).lastOption
15+
}
16+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package fix
2+
3+
4+
5+
import scala.collection.compat._
6+
abstract class UnnecessarySortRewriteTestAddImport1 {
7+
def seq: Seq[(Int, Int)]
8+
9+
def foo: Unit = {
10+
seq.minBy(_._1)
11+
seq.maxBy(_._2)
12+
seq.minByOption(_._1)
13+
seq.maxByOption(_._2)
14+
}
15+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package fix
2+
3+
4+
abstract class UnnecessarySortRewriteTestAddImport2 {
5+
def seq: Seq[(Int, Int)]
6+
7+
def foo: Unit = {
8+
seq.minBy(_._1)
9+
seq.maxBy(_._2)
10+
}
11+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package fix
2+
3+
4+
abstract class UnnecessarySortRewriteTestOnly212 {
5+
def seq: Seq[(Int, Int)]
6+
7+
def foo: Unit = {
8+
seq.minBy(_._1)
9+
seq.maxBy(_._2)
10+
seq.sortBy(_._1).headOption
11+
seq.sortBy(_._2).lastOption
12+
}
13+
}

rules/src/main/scala/fix/UnnecessarySort.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ import scala.meta.Term
99
import scala.meta.Position
1010

1111
object UnnecessarySort {
12-
val map: Map[String, String] = Map(
13-
"head" -> "minBy",
12+
val scala213Methods: Map[String, String] = Map(
1413
"headOption" -> "minByOption",
15-
"last" -> "maxBy",
1614
"lastOption" -> "maxByOption"
1715
)
16+
val map: Map[String, String] = Map(
17+
"head" -> "minBy",
18+
"last" -> "maxBy",
19+
) ++ scala213Methods
1820
}
1921

2022
class UnnecessarySort extends SyntacticRule("UnnecessarySort") {
Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,95 @@
11
package fix
22

3+
import metaconfig.ConfDecoder
4+
import metaconfig.ConfError
5+
import metaconfig.Configured
6+
import metaconfig.generic.Surface
37
import scalafix.Patch
8+
import scalafix.v1.Configuration
9+
import scalafix.v1.Rule
410
import scalafix.v1.SyntacticDocument
511
import scalafix.v1.SyntacticRule
12+
import scala.meta.Pkg
13+
import scala.meta.Source
614
import scala.meta.Term
715

8-
class UnnecessarySortRewrite extends SyntacticRule("UnnecessarySortRewrite") {
16+
case class UnnecessarySortRewriteConfig(rewriteConfig: UnnecessarySortRewriteConfig.RewriteConfig)
17+
18+
object UnnecessarySortRewriteConfig {
19+
sealed abstract class RewriteConfig(val value: String) extends Product with Serializable
20+
object RewriteConfig {
21+
val map: Map[String, RewriteConfig] = Seq(Only212Methods, AddCompatImport, Default).map(a => a.value -> a).toMap
22+
23+
implicit val decoder: ConfDecoder[RewriteConfig] =
24+
implicitly[ConfDecoder[String]].flatMap { str =>
25+
Configured.opt(map.get(str))(ConfError.message(s"invalid type ${str}"))
26+
}
27+
}
28+
case object Only212Methods extends RewriteConfig("only212methods")
29+
case object AddCompatImport extends RewriteConfig("addCompatImport")
30+
case object Default extends RewriteConfig("default")
31+
32+
val default: UnnecessarySortRewriteConfig = UnnecessarySortRewriteConfig(rewriteConfig = Default)
33+
34+
implicit val surface: Surface[UnnecessarySortRewriteConfig] =
35+
metaconfig.generic.deriveSurface[UnnecessarySortRewriteConfig]
36+
37+
implicit val decoder: ConfDecoder[UnnecessarySortRewriteConfig] =
38+
metaconfig.generic.deriveDecoder(default)
39+
}
40+
41+
class UnnecessarySortRewrite(config: UnnecessarySortRewriteConfig) extends SyntacticRule("UnnecessarySortRewrite") {
42+
43+
def this() = this(UnnecessarySortRewriteConfig.default)
44+
45+
override def withConfiguration(config: Configuration): Configured[Rule] = {
46+
config.conf.getOrElse("UnnecessarySortRewrite")(this.config).map(newConfig => new UnnecessarySortRewrite(newConfig))
47+
}
48+
949
override def fix(implicit doc: SyntacticDocument): Patch = {
10-
doc.tree.collect {
11-
case t @ Term.Select(
12-
Term.Apply(Term.Select(x1, Term.Name("sortBy")), List(x2)),
13-
Term.Name(methodName)
14-
) if UnnecessarySort.map.contains(methodName) =>
15-
// TODO add `import scala.collection.compat._`
16-
// if Scala 2.12 and minByOption, minByOption
17-
18-
Patch.replaceTree(
19-
t,
20-
Term.Apply(Term.Select(x1, Term.Name(UnnecessarySort.map(methodName))), List(x2)).toString,
21-
)
50+
doc.tree.collect { case src: Source =>
51+
val result = src.collect {
52+
case t @ Term.Select(
53+
Term.Apply(Term.Select(x1, Term.Name("sortBy")), List(x2)),
54+
Term.Name(methodName)
55+
) if UnnecessarySort.map.contains(methodName) =>
56+
val patch1 = Patch.replaceTree(
57+
t,
58+
Term.Apply(Term.Select(x1, Term.Name(UnnecessarySort.map(methodName))), List(x2)).toString,
59+
)
60+
61+
config.rewriteConfig match {
62+
case UnnecessarySortRewriteConfig.Only212Methods =>
63+
if (!UnnecessarySort.scala213Methods.contains(methodName)) {
64+
Option((patch1, false))
65+
} else {
66+
None
67+
}
68+
case UnnecessarySortRewriteConfig.AddCompatImport =>
69+
if (UnnecessarySort.scala213Methods.contains(methodName)) {
70+
Option((patch1, true))
71+
} else {
72+
Option((patch1, false))
73+
}
74+
case UnnecessarySortRewriteConfig.Default =>
75+
Option((patch1, false))
76+
}
77+
}.flatten
78+
79+
if (result.nonEmpty) {
80+
val patch = result.map(_._1).asPatch
81+
val pkg = src.collect { case p: Pkg => p.stats.head }.head
82+
if (result.exists(_._2)) {
83+
Seq(
84+
Patch.addLeft(pkg, "\nimport scala.collection.compat._\n"),
85+
patch,
86+
).asPatch
87+
} else {
88+
patch
89+
}
90+
} else {
91+
Patch.empty
92+
}
2293
}.asPatch
2394
}
2495
}

0 commit comments

Comments
 (0)