Skip to content

Commit 67511b2

Browse files
Make the breakOut rewrite rule cross compatible (fix #80)
1 parent d18fe86 commit 67511b2

File tree

11 files changed

+252
-140
lines changed

11 files changed

+252
-140
lines changed

build.sbt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ lazy val compat = crossProject(JSPlatform, JVMPlatform)
2222
.jvmSettings(scalaModuleSettingsJVM)
2323
.settings(
2424
name := "scala-collection-compat",
25-
version := "0.1-SNAPSHOT",
25+
version := "0.2.0-SNAPSHOT",
2626
scalacOptions ++= Seq("-feature", "-language:higherKinds", "-language:implicitConversions"),
2727
unmanagedSourceDirectories in Compile += {
2828
val sharedSourceDir = baseDirectory.value.getParentFile / "src/main"
@@ -52,6 +52,7 @@ lazy val scalafixRules = project
5252
.in(file("scalafix/rules"))
5353
.settings(
5454
organization := (organization in compatJVM).value,
55+
version := (version in compatJVM).value,
5556
name := "scala-collection-migrations",
5657
scalaVersion := scalafixScala212,
5758
libraryDependencies += "ch.epfl.scala" %% "scalafix-core" % scalafixVersion

compat/src/main/scala-2.11_2.12/scala/collection/compat/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ package object compat {
5858
def sameElements[B >: A](that: IterableOnce[B]): Boolean = {
5959
self.sameElements(that.iterator)
6060
}
61+
def concat[B >: A](that: IterableOnce[B]): IterableOnce[B] = self ++ that
6162
}
6263

6364
implicit class TraversableOnceExtensionMethods[A](private val self: TraversableOnce[A]) extends AnyVal {
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package test.scala.collection
2+
3+
import org.junit.Test
4+
import org.junit.Assert._
5+
import scala.collection.compat._
6+
7+
class Playground {
8+
@Test
9+
def t(): Unit = {
10+
11+
}
12+
}

scalafix/input/src/main/scala/fix/BreakoutSrc.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
rule = "scala:fix.NewCollections"
2+
rule = "scala:fix.CrossCompat"
33
*/
44
package fix
55

@@ -13,13 +13,14 @@ object BreakoutSrc {
1313
xs.map(_ + 1)(breakOut): Set[Int]
1414
xs.reverseMap(_ + 1)(breakOut): Set[Int]
1515
xs.scanLeft(0)((a, b) => a + b)(breakOut): Set[Int]
16-
xs.union(xs)(breakOut): Set[Int]
1716
xs.updated(0, 1)(breakOut): Set[Int]
18-
xs.zip(xs)(breakOut): Map[Int, Int]
19-
xs.zipAll(xs, 0, 0)(breakOut): Array[(Int, Int)]
2017

2118
(xs ++ xs)(breakOut): Set[Int]
2219
(1 +: xs)(breakOut): Set[Int]
2320
(xs :+ 1)(breakOut): Set[Int]
2421
(xs ++: xs)(breakOut): Set[Int]
22+
23+
xs.union(xs)(breakOut): Set[Int]
24+
xs.zip(xs)(breakOut): Map[Int, Int]
25+
xs.zipAll(xs, 0, 0)(breakOut): Array[(Int, Int)]
2526
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/*
2+
rule = "scala:fix.CrossCompat"
3+
*/
4+
package fix
5+
6+
object Playground {
7+
8+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
3+
4+
package fix
5+
6+
import scala.collection.compat._
7+
8+
object BreakoutSrc {
9+
val xs = List(1, 2, 3)
10+
11+
xs.iterator.collect{ case x => x }.to(scala.collection.immutable.Set): Set[Int]
12+
xs.iterator.flatMap(x => List(x)).to(scala.collection.SortedSet): collection.SortedSet[Int]
13+
xs.iterator.map(_ + 1).to(scala.collection.immutable.Set): Set[Int]
14+
xs.reverseIterator.map(_ + 1).to(scala.collection.immutable.Set): Set[Int]
15+
xs.iterator.scanLeft(0)((a, b) => a + b).to(scala.collection.immutable.Set): Set[Int]
16+
xs.view.updated(0, 1).to(scala.collection.immutable.Set): Set[Int]
17+
18+
(xs.iterator ++ xs).to(scala.collection.immutable.Set): Set[Int]
19+
(1 +: xs.view).to(scala.collection.immutable.Set): Set[Int]
20+
(xs.view :+ 1).to(scala.collection.immutable.Set): Set[Int]
21+
(xs ++: xs.view).to(scala.collection.immutable.Set): Set[Int]
22+
23+
xs.iterator.concat(xs).to(scala.collection.immutable.Set): Set[Int]
24+
xs.iterator.zip(xs.iterator).toMap: Map[Int, Int]
25+
xs.iterator.zipAll(xs.iterator, 0, 0).to(scala.Array): Array[(Int, Int)]
26+
}
27+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
3+
4+
package fix
5+
6+
object Playground {
7+
8+
}

scalafix/output213/src/main/scala/fix/BreakoutSrc.scala

Lines changed: 0 additions & 24 deletions
This file was deleted.

scalafix/rules/src/main/scala/fix/NewCollections.scala

Lines changed: 2 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -21,60 +21,6 @@ case class NewCollections(index: SemanticdbIndex)
2121
val retainMap = normalized("_root_.scala.collection.mutable.MapLike.retain.")
2222
val retainSet = normalized("_root_.scala.collection.mutable.SetLike.retain.")
2323

24-
object Breakout {
25-
implicit class RichSymbol(val symbol: Symbol) {
26-
def exact(tree: Tree)(implicit index: SemanticdbIndex): Boolean =
27-
index.symbol(tree).fold(false)(_ == symbol)
28-
}
29-
30-
val breakOut = SymbolMatcher.exact(Symbol("_root_.scala.collection.package.breakOut(Lscala/collection/generic/CanBuildFrom;)Lscala/collection/generic/CanBuildFrom;."))
31-
32-
// infix operators
33-
val `List ++` = Symbol("_root_.scala.collection.immutable.List#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
34-
val `List +:` = Symbol("_root_.scala.collection.immutable.List#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
35-
val `SeqLike :+` = Symbol("_root_.scala.collection.SeqLike#`:+`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
36-
val `TraversableLike ++:` = Symbol("_root_.scala.collection.TraversableLike#`++:`(Lscala/collection/Traversable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
37-
38-
val operatorsIteratorSymbols = List(`List ++`)
39-
val operatorsViewSymbols = List(`List +:`, `SeqLike :+`, `TraversableLike ++:`)
40-
val operatorsSymbols = operatorsViewSymbols ++ operatorsIteratorSymbols
41-
42-
val operatorsIterator = SymbolMatcher.exact(operatorsIteratorSymbols: _*)
43-
val operatorsView = SymbolMatcher.exact(operatorsViewSymbols: _*)
44-
val operators = SymbolMatcher.exact(operatorsSymbols: _*)
45-
46-
// select
47-
val `List.collect` = Symbol("_root_.scala.collection.immutable.List#collect(Lscala/PartialFunction;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
48-
val `List.flatMap` = Symbol("_root_.scala.collection.immutable.List#flatMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
49-
val `List.map` = Symbol("_root_.scala.collection.immutable.List#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
50-
val `IterableLike.zip` = Symbol("_root_.scala.collection.IterableLike#zip(Lscala/collection/GenIterable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
51-
val `IterableLike.zipAll` = Symbol("_root_.scala.collection.IterableLike#zipAll(Lscala/collection/GenIterable;Ljava/lang/Object;Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
52-
val `SeqLike.union` = Symbol("_root_.scala.collection.SeqLike#union(Lscala/collection/GenSeq;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
53-
val `SeqLike.updated` = Symbol("_root_.scala.collection.SeqLike#updated(ILjava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
54-
val `SeqLike.reverseMap` = Symbol("_root_.scala.collection.SeqLike#reverseMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")
55-
56-
val functionsIteratorSymbols = List(`List.collect`, `List.flatMap`, `List.map`, `IterableLike.zip`, `IterableLike.zipAll`, `SeqLike.union`)
57-
val functionsViewSymbols = List(`SeqLike.updated`)
58-
val functionsReverseIteratorSymbols = List(`SeqLike.reverseMap`)
59-
val functionsSymbols = functionsIteratorSymbols ++ functionsViewSymbols ++ functionsReverseIteratorSymbols
60-
61-
val functionsIterator = SymbolMatcher.exact(functionsIteratorSymbols: _*)
62-
val functionsReverseIterator = SymbolMatcher.exact(functionsReverseIteratorSymbols: _*)
63-
val functionsView = SymbolMatcher.exact(functionsViewSymbols: _*)
64-
val functions = SymbolMatcher.exact(functionsSymbols: _*)
65-
66-
// special select
67-
68-
// iterator
69-
val `TraversableLike.scanLeft` = SymbolMatcher.exact(Symbol("_root_.scala.collection.TraversableLike#scanLeft(Ljava/lang/Object;Lscala/Function2;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;."))
70-
71-
def isLeftAssociative(tree: Tree): Boolean =
72-
tree match {
73-
case Term.Name(value) => value.last != ':'
74-
case _ => false
75-
}
76-
}
77-
7824
// == Rules ==
7925

8026
def replaceSymbols(ctx: RuleCtx): Patch = {
@@ -146,66 +92,13 @@ case class NewCollections(index: SemanticdbIndex)
14692
}.asPatch
14793
}
14894

149-
def replaceBreakout(ctx: RuleCtx): Patch = {
150-
import Breakout._
151-
152-
def fixIt(intermediate: String, lhs: Term, ap: Term, breakout: Tree): Patch = {
153-
ctx.addRight(lhs, "." + intermediate) +
154-
ctx.addRight(ap, ".to") +
155-
ctx.replaceTree(breakout, "implicitly")
156-
}
157-
158-
ctx.tree.collect {
159-
case i: Importee if breakOut.matches(i) =>
160-
ctx.removeImportee(i)
161-
162-
case Term.Apply(ap @ Term.ApplyInfix(lhs, operators(op), _, List(rhs)), List(breakOut(bo))) =>
163-
val subject =
164-
if(isLeftAssociative(op)) lhs
165-
else rhs
166-
167-
val intermediate =
168-
op match {
169-
case operatorsIterator(_) => "iterator"
170-
case operatorsView(_) => "view"
171-
// since operators(op) matches iterator and view
172-
case _ => throw new Exception("impossible")
173-
}
174-
175-
fixIt(intermediate, subject, ap, bo)
176-
177-
case Term.Apply(ap @ Term.Apply(Term.Select(lhs, functions(op)), _), List(breakOut(bo))) =>
178-
val intermediate =
179-
op match {
180-
case functionsIterator(_) => "iterator"
181-
case functionsView(_) => "view"
182-
case functionsReverseIterator(_) => "reverseIterator"
183-
// since functions(op) matches iterator, view and reverseIterator
184-
case _ => throw new Exception("impossible")
185-
}
186-
187-
val replaceUnion =
188-
if (`SeqLike.union`.exact(op)) ctx.replaceTree(op, "concat")
189-
else Patch.empty
190-
191-
val isReversed = `SeqLike.reverseMap`.exact(op)
192-
val replaceReverseMap =
193-
if (isReversed) ctx.replaceTree(op, "map")
194-
else Patch.empty
195-
196-
fixIt(intermediate, lhs, ap, bo) + replaceUnion + replaceReverseMap
197-
198-
case Term.Apply(ap @ Term.Apply(Term.Apply(Term.Select(lhs, `TraversableLike.scanLeft`(op)), _), _), List(breakOut(bo))) =>
199-
fixIt("iterator", lhs, ap, bo)
200-
}.asPatch
201-
}
95+
20296

20397
override def fix(ctx: RuleCtx): Patch = {
20498
super.fix(ctx) +
20599
replaceSymbols(ctx) +
206100
replaceTupleZipped(ctx) +
207101
replaceMutableMap(ctx) +
208-
replaceMutableSet(ctx) +
209-
replaceBreakout(ctx)
102+
replaceMutableSet(ctx)
210103
}
211104
}

0 commit comments

Comments
 (0)