diff --git a/build.sbt b/build.sbt index b5b3229f..f0728d3f 100644 --- a/build.sbt +++ b/build.sbt @@ -22,7 +22,7 @@ lazy val compat = crossProject(JSPlatform, JVMPlatform) .jvmSettings(scalaModuleSettingsJVM) .settings( name := "scala-collection-compat", - version := "0.1-SNAPSHOT", + version := "0.2.0-SNAPSHOT", scalacOptions ++= Seq("-feature", "-language:higherKinds", "-language:implicitConversions"), unmanagedSourceDirectories in Compile += { val sharedSourceDir = baseDirectory.value.getParentFile / "src/main" @@ -69,6 +69,7 @@ lazy val scalafixRules = project .in(file("scalafix/rules")) .settings( organization := (organization in compatJVM).value, + version := (version in compatJVM).value, name := "scala-collection-migrations", scalaVersion := scalafixScala212, libraryDependencies += "ch.epfl.scala" %% "scalafix-core" % scalafixVersion diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/package.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/package.scala index dc99f4b0..3587e63e 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/package.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/package.scala @@ -44,6 +44,46 @@ package object compat { def fromSpecific(source: TraversableOnce[Int]): C = fact.apply(source.toSeq: _*) } + private def build[T, CC](builder: m.Builder[T, CC], source: TraversableOnce[T]): CC = { + builder ++= source + builder.result() + } + + implicit class ImmutableSortedMapExtensions(private val fact: i.SortedMap.type) extends AnyVal { + def from[K: Ordering, V](source: TraversableOnce[(K, V)]): i.SortedMap[K, V] = + build(i.SortedMap.newBuilder[K, V], source) + } + + implicit class ImmutableTreeMapExtensions(private val fact: i.TreeMap.type) extends AnyVal { + def from[K: Ordering, V](source: TraversableOnce[(K, V)]): i.TreeMap[K, V] = + build(i.TreeMap.newBuilder[K, V], source) + } + + implicit class ImmutableIntMapExtensions(private val fact: i.IntMap.type) extends AnyVal { + def from[V](source: TraversableOnce[(Int, V)]): i.IntMap[V] = + build(i.IntMap.canBuildFrom[Int, V](), source) + } + + implicit class ImmutableLongMapExtensions(private val fact: i.LongMap.type) extends AnyVal { + def from[V](source: TraversableOnce[(Long, V)]): i.LongMap[V] = + build(i.LongMap.canBuildFrom[Long, V](), source) + } + + implicit class MutableSortedMapExtensions(private val fact: m.SortedMap.type) extends AnyVal { + def from[K: Ordering, V](source: TraversableOnce[(K, V)]): m.SortedMap[K, V] = + build(m.SortedMap.newBuilder[K, V], source) + } + + implicit class MutableTreeMapExtensions(private val fact: m.TreeMap.type) extends AnyVal { + def from[K: Ordering, V](source: TraversableOnce[(K, V)]): m.TreeMap[K, V] = + build(m.TreeMap.newBuilder[K, V], source) + } + + implicit class MutableLongMapExtensions(private val fact: m.LongMap.type) extends AnyVal { + def from[V](source: TraversableOnce[(Long, V)]): m.LongMap[V] = + build(m.LongMap.canBuildFrom[Long, V](), source) + } + implicit class StreamExtensionMethods[A](private val stream: Stream[A]) extends AnyVal { def lazyAppendedAll(as: => TraversableOnce[A]): Stream[A] = stream.append(as) } @@ -58,6 +98,7 @@ package object compat { def sameElements[B >: A](that: IterableOnce[B]): Boolean = { self.sameElements(that.iterator) } + def concat[B >: A](that: IterableOnce[B]): IterableOnce[B] = self ++ that } implicit class TraversableOnceExtensionMethods[A](private val self: TraversableOnce[A]) extends AnyVal { diff --git a/scalafix/input/src/main/scala/fix/BreakoutSrc.scala b/scalafix/input/src/main/scala/fix/BreakoutSrc.scala index 50382e33..d58c703f 100644 --- a/scalafix/input/src/main/scala/fix/BreakoutSrc.scala +++ b/scalafix/input/src/main/scala/fix/BreakoutSrc.scala @@ -1,25 +1,108 @@ /* -rule = "scala:fix.NewCollections" +rule = "scala:fix.CrossCompat" */ package fix import scala.collection.breakOut +import scala.collection.{immutable, mutable} +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global -object BreakoutSrc { - val xs = List(1, 2, 3) - - xs.collect{ case x => x }(breakOut): Set[Int] - xs.flatMap(x => List(x))(breakOut): collection.SortedSet[Int] - xs.map(_ + 1)(breakOut): Set[Int] - xs.reverseMap(_ + 1)(breakOut): Set[Int] - xs.scanLeft(0)((a, b) => a + b)(breakOut): Set[Int] - xs.union(xs)(breakOut): Set[Int] - xs.updated(0, 1)(breakOut): Set[Int] - xs.zip(xs)(breakOut): Map[Int, Int] - xs.zipAll(xs, 0, 0)(breakOut): Array[(Int, Int)] - - (xs ++ xs)(breakOut): Set[Int] - (1 +: xs)(breakOut): Set[Int] - (xs :+ 1)(breakOut): Set[Int] - (xs ++: xs)(breakOut): Set[Int] +class BreakoutSrc(ts: Traversable[Int], vec: Vector[Int], list: List[Int], seq: Seq[Int]) { + + // `IndexedSeqOptimized.zip` + vec.zip(vec)(breakOut): Map[Int, Int] + + // `IterableLike.zip` + seq.zip(seq)(breakOut): Map[Int, Int] + + // `IterableLike.zipAll` + seq.zipAll(seq, 0, 0)(breakOut): Array[(Int, Int)] + + // `List ++` + (list ++ list)(breakOut): Set[Int] + + // `List +:` + (1 +: list)(breakOut): Set[Int] + + // `List.collect` + list.collect{ case x => x}(breakOut): Set[Int] + + // `List.flatMap` + list.flatMap(x => List(x))(breakOut): Set[Int] + + // `List.map` + list.map(x => x)(breakOut): Set[Int] + + // `SeqLike.reverseMap` + seq.reverseMap(_ + 1)(breakOut): Set[Int] + + // `SeqLike +:` + (1 +: seq)(breakOut): List[Int] + + // `SeqLike :+` + (seq :+ 1)(breakOut): List[Int] + + // `SeqLike.updated` + (seq.updated(0, 0))(breakOut): List[Int] + + // `SeqLike.union` + seq.union(seq)(breakOut): List[Int] + + + //`SetLike.map` + Set(1).map(x => x)(breakOut): List[Int] + + + // `TraversableLike ++` + (ts ++ ts )(breakOut): Set[Int] + + // `TraversableLike ++:` + (ts ++: ts)(breakOut): Set[Int] + + // `TraversableLike.collect` + ts.collect{ case x => x }(breakOut): Set[Int] + + // `TraversableLike.flatMap` + ts.flatMap(x => List(x))(breakOut): collection.SortedSet[Int] + + // `TraversableLike.map` + ts.map(_ + 1)(breakOut): Set[Int] + + // `TraversableLike.scanLeft` + ts.scanLeft(0)((a, b) => a + b)(breakOut): Set[Int] + + + // `Vector ++` + (vec ++ List(1))(breakOut): List[Int] + + // `Vector +:` + (1 +: vec)(breakOut): List[Int] + + // `Vector :+` + (vec :+ 1)(breakOut): List[Int] + + // `Vector.updated` + (vec.updated(0, 0))(breakOut): List[Int] + + // Future + Future.sequence(List(Future(1)))(breakOut, global): Future[Seq[Int]] + Future.traverse(List(1))(x => Future(x))(breakOut, global): Future[Seq[Int]] + + // Iterable + List(1).map(x => x)(breakOut): Iterator[Int] + + // Specific collections + List(1 -> "1").map(x => x)(breakOut): immutable.SortedMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): immutable.HashMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): immutable.ListMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): immutable.TreeMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): mutable.SortedMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): mutable.HashMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): mutable.ListMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): mutable.TreeMap[Int, String] + List(1 -> "1").map(x => x)(breakOut): mutable.Map[Int, String] + List(1 -> "1").map(x => x)(breakOut): immutable.IntMap[String] + List(1L -> "1").map(x => x)(breakOut): immutable.LongMap[String] + List(1L -> "1").map(x => x)(breakOut): mutable.LongMap[String] } diff --git a/scalafix/output212/src/main/scala/fix/BreakoutSrc.scala b/scalafix/output212/src/main/scala/fix/BreakoutSrc.scala new file mode 100644 index 00000000..d4d54055 --- /dev/null +++ b/scalafix/output212/src/main/scala/fix/BreakoutSrc.scala @@ -0,0 +1,108 @@ + + + +package fix + +import scala.collection.{immutable, mutable} +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global +import scala.collection.compat._ + +class BreakoutSrc(ts: Iterable[Int], vec: Vector[Int], list: List[Int], seq: Seq[Int]) { + + // `IndexedSeqOptimized.zip` + vec.iterator.zip(vec.iterator).toMap: Map[Int, Int] + + // `IterableLike.zip` + seq.iterator.zip(seq.iterator).toMap: Map[Int, Int] + + // `IterableLike.zipAll` + seq.iterator.zipAll(seq.iterator, 0, 0).to(scala.Array): Array[(Int, Int)] + + // `List ++` + (list.iterator ++ list).to(scala.collection.immutable.Set): Set[Int] + + // `List +:` + (1 +: list.view).to(scala.collection.immutable.Set): Set[Int] + + // `List.collect` + list.iterator.collect{ case x => x}.to(scala.collection.immutable.Set): Set[Int] + + // `List.flatMap` + list.iterator.flatMap(x => List(x)).to(scala.collection.immutable.Set): Set[Int] + + // `List.map` + list.iterator.map(x => x).to(scala.collection.immutable.Set): Set[Int] + + // `SeqLike.reverseMap` + seq.reverseIterator.map(_ + 1).to(scala.collection.immutable.Set): Set[Int] + + // `SeqLike +:` + (1 +: seq.view).to(scala.collection.immutable.List): List[Int] + + // `SeqLike :+` + (seq.view :+ 1).to(scala.collection.immutable.List): List[Int] + + // `SeqLike.updated` + (seq.view.updated(0, 0)).to(scala.collection.immutable.List): List[Int] + + // `SeqLike.union` + seq.iterator.concat(seq).to(scala.collection.immutable.List): List[Int] + + + //`SetLike.map` + Set(1).iterator.map(x => x).to(scala.collection.immutable.List): List[Int] + + + // `TraversableLike ++` + (ts.iterator ++ ts ).to(scala.collection.immutable.Set): Set[Int] + + // `TraversableLike ++:` + (ts ++: ts.view).to(scala.collection.immutable.Set): Set[Int] + + // `TraversableLike.collect` + ts.iterator.collect{ case x => x }.to(scala.collection.immutable.Set): Set[Int] + + // `TraversableLike.flatMap` + ts.iterator.flatMap(x => List(x)).to(scala.collection.SortedSet): collection.SortedSet[Int] + + // `TraversableLike.map` + ts.iterator.map(_ + 1).to(scala.collection.immutable.Set): Set[Int] + + // `TraversableLike.scanLeft` + ts.iterator.scanLeft(0)((a, b) => a + b).to(scala.collection.immutable.Set): Set[Int] + + + // `Vector ++` + (vec.iterator ++ List(1)).to(scala.collection.immutable.List): List[Int] + + // `Vector +:` + (1 +: vec.view).to(scala.collection.immutable.List): List[Int] + + // `Vector :+` + (vec.view :+ 1).to(scala.collection.immutable.List): List[Int] + + // `Vector.updated` + (vec.view.updated(0, 0)).to(scala.collection.immutable.List): List[Int] + + // Future + Future.sequence(List(Future(1)))(scala.collection.immutable.List, global): Future[Seq[Int]] + Future.traverse(List(1))(x => Future(x))(scala.collection.immutable.List, global): Future[Seq[Int]] + + // Iterable + List(1).iterator.map(x => x): Iterator[Int] + + // Specific collections + scala.collection.immutable.SortedMap.from(List(1 -> "1").iterator.map(x => x)): immutable.SortedMap[Int, String] + scala.collection.immutable.HashMap.from(List(1 -> "1").iterator.map(x => x)): immutable.HashMap[Int, String] + scala.collection.immutable.ListMap.from(List(1 -> "1").iterator.map(x => x)): immutable.ListMap[Int, String] + scala.collection.immutable.TreeMap.from(List(1 -> "1").iterator.map(x => x)): immutable.TreeMap[Int, String] + scala.collection.mutable.SortedMap.from(List(1 -> "1").iterator.map(x => x)): mutable.SortedMap[Int, String] + scala.collection.mutable.HashMap.from(List(1 -> "1").iterator.map(x => x)): mutable.HashMap[Int, String] + scala.collection.mutable.ListMap.from(List(1 -> "1").iterator.map(x => x)): mutable.ListMap[Int, String] + scala.collection.mutable.TreeMap.from(List(1 -> "1").iterator.map(x => x)): mutable.TreeMap[Int, String] + scala.collection.mutable.Map.from(List(1 -> "1").iterator.map(x => x)): mutable.Map[Int, String] + scala.collection.immutable.IntMap.from(List(1 -> "1").iterator.map(x => x)): immutable.IntMap[String] + scala.collection.immutable.LongMap.from(List(1L -> "1").iterator.map(x => x)): immutable.LongMap[String] + scala.collection.mutable.LongMap.from(List(1L -> "1").iterator.map(x => x)): mutable.LongMap[String] +} diff --git a/scalafix/output213/src/main/scala/fix/BreakoutSrc.scala b/scalafix/output213/src/main/scala/fix/BreakoutSrc.scala deleted file mode 100644 index 7a086ec5..00000000 --- a/scalafix/output213/src/main/scala/fix/BreakoutSrc.scala +++ /dev/null @@ -1,24 +0,0 @@ - - - -package fix - - -object BreakoutSrc { - val xs = List(1, 2, 3) - - xs.iterator.collect{ case x => x }.to(implicitly): Set[Int] - xs.iterator.flatMap(x => List(x)).to(implicitly): collection.SortedSet[Int] - xs.iterator.map(_ + 1).to(implicitly): Set[Int] - xs.reverseIterator.map(_ + 1).to(implicitly): Set[Int] - xs.iterator.scanLeft(0)((a, b) => a + b).to(implicitly): Set[Int] - xs.iterator.concat(xs).to(implicitly): Set[Int] - xs.view.updated(0, 1).to(implicitly): Set[Int] - xs.iterator.zip(xs).to(implicitly): Map[Int, Int] - xs.iterator.zipAll(xs, 0, 0).to(implicitly): Array[(Int, Int)] - - (xs.iterator ++ xs).to(implicitly): Set[Int] - (1 +: xs.view).to(implicitly): Set[Int] - (xs.view :+ 1).to(implicitly): Set[Int] - (xs ++: xs.view).to(implicitly): Set[Int] -} diff --git a/scalafix/rules/src/main/scala/fix/Breakout.scala b/scalafix/rules/src/main/scala/fix/Breakout.scala new file mode 100644 index 00000000..5dcab796 --- /dev/null +++ b/scalafix/rules/src/main/scala/fix/Breakout.scala @@ -0,0 +1,279 @@ +package fix + +import scalafix._ +import scalafix.util._ +import scala.meta._ + +class BreakoutRewrite(addCompatImport: RuleCtx => Patch)(implicit val index: SemanticdbIndex) { + implicit class RichSymbol(val symbol: Symbol) { + def exact(tree: Tree): Boolean = + index.symbol(tree).fold(false)(_ == symbol) + } + + def isLeftAssociative(tree: Tree): Boolean = + tree match { + case Term.Name(value) => value.last != ':' + case _ => false + } + + val breakOut = SymbolMatcher.exact(Symbol("_root_.scala.collection.package.breakOut(Lscala/collection/generic/CanBuildFrom;)Lscala/collection/generic/CanBuildFrom;.")) + + // == infix operators == + + val `TraversableLike ++` = Symbol("_root_.scala.collection.TraversableLike#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `Vector ++` = Symbol("_root_.scala.collection.immutable.Vector#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List ++` = Symbol("_root_.scala.collection.immutable.List#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike +:` = Symbol("_root_.scala.collection.SeqLike#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `Vector +:` = Symbol("_root_.scala.collection.immutable.Vector#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List +:` = Symbol("_root_.scala.collection.immutable.List#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike :+` = Symbol("_root_.scala.collection.SeqLike#`:+`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `Vector :+` = Symbol("_root_.scala.collection.immutable.Vector#`:+`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `TraversableLike ++:` = Symbol("_root_.scala.collection.TraversableLike#`++:`(Lscala/collection/Traversable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + + val operatorsIteratorSymbols = List(`TraversableLike ++`, `List ++`, `Vector ++`) + val operatorsViewSymbols = List( + `SeqLike +:`, `Vector +:`, `List +:`, + `SeqLike :+`, `Vector :+`, + `TraversableLike ++:` + ) + + val operatorsSymbols = operatorsViewSymbols ++ operatorsIteratorSymbols + + val operatorsIterator = SymbolMatcher.exact(operatorsIteratorSymbols: _*) + val operatorsView = SymbolMatcher.exact(operatorsViewSymbols: _*) + val operators = SymbolMatcher.exact(operatorsSymbols: _*) + + // == select == + + val `List.collect` = Symbol("_root_.scala.collection.immutable.List#collect(Lscala/PartialFunction;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `TraversableLike.collect` = Symbol("_root_.scala.collection.TraversableLike#collect(Lscala/PartialFunction;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List.flatMap` = Symbol("_root_.scala.collection.immutable.List#flatMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `TraversableLike.flatMap` = Symbol("_root_.scala.collection.TraversableLike#flatMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List.map` = Symbol("_root_.scala.collection.immutable.List#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SetLike.map` = Symbol("_root_.scala.collection.SetLike#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `TraversableLike.map` = Symbol("_root_.scala.collection.TraversableLike#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `IterableLike.zip` = Symbol("_root_.scala.collection.IterableLike#zip(Lscala/collection/GenIterable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `IndexedSeqOptimized.zip` = Symbol("_root_.scala.collection.IndexedSeqOptimized#zip(Lscala/collection/GenIterable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `IterableLike.zipAll` = Symbol("_root_.scala.collection.IterableLike#zipAll(Lscala/collection/GenIterable;Ljava/lang/Object;Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.union` = Symbol("_root_.scala.collection.SeqLike#union(Lscala/collection/GenSeq;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.updated` = Symbol("_root_.scala.collection.SeqLike#updated(ILjava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `Vector.updated` = Symbol("_root_.scala.collection.immutable.Vector#updated(ILjava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.reverseMap` = Symbol("_root_.scala.collection.SeqLike#reverseMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + + val functionsZipSymbols = List( + `IterableLike.zip`, + `IndexedSeqOptimized.zip`, + `IterableLike.zipAll` + ) + val functionsIteratorSymbols = List( + `List.collect`, `TraversableLike.collect`, + `List.flatMap`, `TraversableLike.flatMap`, + `List.map`, `SetLike.map`, `TraversableLike.map`, + `SeqLike.union` + ) ++ functionsZipSymbols + val functionsViewSymbols = List(`SeqLike.updated`, `Vector.updated`) + val functionsReverseIteratorSymbols = List(`SeqLike.reverseMap`) + val functionsSymbols = functionsIteratorSymbols ++ functionsViewSymbols ++ functionsReverseIteratorSymbols + + val functionsIterator = SymbolMatcher.exact(functionsIteratorSymbols: _*) + val functionsReverseIterator = SymbolMatcher.exact(functionsReverseIteratorSymbols: _*) + val functionsView = SymbolMatcher.exact(functionsViewSymbols: _*) + val functions = SymbolMatcher.exact(functionsSymbols: _*) + + val functionsZip = SymbolMatcher.exact(functionsZipSymbols: _*) + + // == special select == + + val `TraversableLike.scanLeft` = SymbolMatcher.exact(Symbol("_root_.scala.collection.TraversableLike#scanLeft(Ljava/lang/Object;Lscala/Function2;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")) + val `Future.sequence` = SymbolMatcher.exact(Symbol("_root_.scala.concurrent.Future.sequence(Lscala/collection/TraversableOnce;Lscala/collection/generic/CanBuildFrom;Lscala/concurrent/ExecutionContext;)Lscala/concurrent/Future;.")) + val `Future.traverse` = SymbolMatcher.exact(Symbol("_root_.scala.concurrent.Future.traverse(Lscala/collection/TraversableOnce;Lscala/Function1;Lscala/collection/generic/CanBuildFrom;Lscala/concurrent/ExecutionContext;)Lscala/concurrent/Future;.")) + + val toSpecificCollectionBuiltIn = Map( + "scala.collection.immutable.Map" -> "toMap" + ) + + val toSpecificCollectionFrom = Set( + "scala.collection.Map", + "scala.collection.immutable.SortedMap", + "scala.collection.immutable.HashMap", + "scala.collection.immutable.ListMap", + "scala.collection.immutable.TreeMap", + "scala.collection.mutable.SortedMap", + "scala.collection.mutable.HashMap", + "scala.collection.mutable.ListMap", + "scala.collection.mutable.TreeMap", + "scala.collection.mutable.Map", + "scala.collection.immutable.IntMap", + "scala.collection.immutable.LongMap", + "scala.collection.mutable.LongMap" + ) + + // == rule == + def apply(ctx: RuleCtx): Patch = { + + var requiresCompatImport = false + + def covertToCollection(intermediateLhs: String, + lhs: Term, + ap: Term, + breakout: Tree, + ap0: Term, + intermediateRhs: Option[String] = None, + rhs: Option[Term] = None): Patch = { + + val toCollection = extractCollectionFromBreakout(breakout) + + val patchRhs = + (intermediateRhs, rhs) match { + case (Some(i), Some(r)) => ctx.addRight(r, "." + i) + case _ => Patch.empty + } + + val patchSpecificCollection = + toSpecificCollectionBuiltIn.get(toCollection) match { + case Some(toX) => ctx.addRight(ap0, '.' + toX) + case None => + if (toSpecificCollectionFrom.contains(toCollection)) { + requiresCompatImport = true + ctx.addLeft(ap0, toCollection + ".from(") + + ctx.addRight(ap0, ")") + } else { + Patch.empty + } + } + + val isIterator = toCollection == "scala.collection.Iterator" + + val sharedPatch = + ctx.addRight(lhs, "." + intermediateLhs) + + patchRhs + + def removeBreakout: Patch = { + val breakoutWithParens = ap0.tokens.slice(ap.tokens.size, ap0.tokens.size) + ctx.removeTokens(breakoutWithParens) + } + + val toColl = + if (patchSpecificCollection.isEmpty && !isIterator) { + requiresCompatImport = true + ctx.addRight(ap, ".to") + + ctx.replaceTree(breakout, toCollection) + } else { + patchSpecificCollection + + removeBreakout + } + + sharedPatch + toColl + } + + def replaceBreakoutWithCollection(breakout: Tree): Patch = { + requiresCompatImport = true + + val toCollection = extractCollectionFromBreakout(breakout) + ctx.replaceTree(breakout, toCollection) + } + + def extractCollectionFromBreakout(breakout: Tree): String = { + val synth = ctx.index.synthetics.find(_.position.end == breakout.pos.end).get + val Term.Apply(_, List(implicitCbf)) = synth.text.parse[Term].get + + implicitCbf match { + case Term.ApplyType(q"scala.Predef.fallbackStringCanBuildFrom", _) => + "scala.collection.immutable.IndexedSeq" + + case Term.ApplyType(Term.Select(coll,_), _) => + coll.syntax + + case Term.Apply(Term.ApplyType(Term.Select(coll, _), _), _) => + coll.syntax + + case Term.Select(coll,_) => + coll.syntax + + case _ => { + throw new Exception( + s"""|cannot extract breakout collection: + | + |--------------------------------------------- + |syntax: + |${implicitCbf.syntax} + | + |--------------------------------------------- + |structure: + |${implicitCbf.structure}""".stripMargin + ) + } + } + } + + val rewriteBreakout = + ctx.tree.collect { + // (xs ++ ys)(breakOut) + case ap0 @ Term.Apply(ap @ Term.ApplyInfix(lhs, operators(op), _, List(rhs)), List(breakOut(bo))) => + val subject = + if(isLeftAssociative(op)) lhs + else rhs + + val intermediate = + op match { + case operatorsIterator(_) => "iterator" + case operatorsView(_) => "view" + // since operators(op) matches iterator and view + case _ => throw new Exception("impossible") + } + + covertToCollection(intermediate, subject, ap, bo, ap0) + + // xs.map(f)(breakOut) + case ap0 @ Term.Apply(ap @ Term.Apply(Term.Select(lhs, functions(op)), rhs :: _), List(breakOut(bo))) => + val intermediateLhs = + op match { + case functionsIterator(_) => "iterator" + case functionsView(_) => "view" + case functionsReverseIterator(_) => "reverseIterator" + // since functions(op) matches iterator, view and reverseIterator + case _ => throw new Exception("impossible") + } + + val intermediateRhs = + op match { + case functionsZip(_) => Some("iterator") + case _ => None + } + + val replaceUnion = + if (`SeqLike.union`.exact(op)) ctx.replaceTree(op, "concat") + else Patch.empty + + val isReversed = `SeqLike.reverseMap`.exact(op) + val replaceReverseMap = + if (isReversed) ctx.replaceTree(op, "map") + else Patch.empty + + covertToCollection(intermediateLhs, lhs, ap, bo, ap0, intermediateRhs, Some(rhs)) + replaceUnion + replaceReverseMap + + // ts.scanLeft(d)(f)(breakOut) + case ap0 @ Term.Apply(ap @ Term.Apply(Term.Apply(Term.Select(lhs, `TraversableLike.scanLeft`(op)), _), _), List(breakOut(bo))) => + covertToCollection("iterator", lhs, ap, bo, ap0) + + // sequence(xs)(breakOut, ec) + case Term.Apply(Term.Apply(`Future.sequence`(_), _), List(breakOut(bo), _)) => + replaceBreakoutWithCollection(bo) + + // traverse(xs)(f)(breakOut, ec) + case Term.Apply(Term.Apply(Term.Apply(`Future.traverse`(_),_), _), List(breakOut(bo), _)) => + replaceBreakoutWithCollection(bo) + + // import scala.collection.breakOut + case i: Importee if breakOut.matches(i) => + ctx.removeImportee(i) + + }.asPatch + + val compatImport = + if (requiresCompatImport) addCompatImport(ctx) + else Patch.empty + + rewriteBreakout + compatImport + } +} diff --git a/scalafix/rules/src/main/scala/fix/NewCollections.scala b/scalafix/rules/src/main/scala/fix/NewCollections.scala index a5033c9f..b283ffdb 100644 --- a/scalafix/rules/src/main/scala/fix/NewCollections.scala +++ b/scalafix/rules/src/main/scala/fix/NewCollections.scala @@ -21,60 +21,6 @@ case class NewCollections(index: SemanticdbIndex) val retainMap = normalized("_root_.scala.collection.mutable.MapLike.retain.") val retainSet = normalized("_root_.scala.collection.mutable.SetLike.retain.") - object Breakout { - implicit class RichSymbol(val symbol: Symbol) { - def exact(tree: Tree)(implicit index: SemanticdbIndex): Boolean = - index.symbol(tree).fold(false)(_ == symbol) - } - - val breakOut = SymbolMatcher.exact(Symbol("_root_.scala.collection.package.breakOut(Lscala/collection/generic/CanBuildFrom;)Lscala/collection/generic/CanBuildFrom;.")) - - // infix operators - val `List ++` = Symbol("_root_.scala.collection.immutable.List#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `List +:` = Symbol("_root_.scala.collection.immutable.List#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `SeqLike :+` = Symbol("_root_.scala.collection.SeqLike#`:+`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `TraversableLike ++:` = Symbol("_root_.scala.collection.TraversableLike#`++:`(Lscala/collection/Traversable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - - val operatorsIteratorSymbols = List(`List ++`) - val operatorsViewSymbols = List(`List +:`, `SeqLike :+`, `TraversableLike ++:`) - val operatorsSymbols = operatorsViewSymbols ++ operatorsIteratorSymbols - - val operatorsIterator = SymbolMatcher.exact(operatorsIteratorSymbols: _*) - val operatorsView = SymbolMatcher.exact(operatorsViewSymbols: _*) - val operators = SymbolMatcher.exact(operatorsSymbols: _*) - - // select - val `List.collect` = Symbol("_root_.scala.collection.immutable.List#collect(Lscala/PartialFunction;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `List.flatMap` = Symbol("_root_.scala.collection.immutable.List#flatMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `List.map` = Symbol("_root_.scala.collection.immutable.List#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `IterableLike.zip` = Symbol("_root_.scala.collection.IterableLike#zip(Lscala/collection/GenIterable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `IterableLike.zipAll` = Symbol("_root_.scala.collection.IterableLike#zipAll(Lscala/collection/GenIterable;Ljava/lang/Object;Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `SeqLike.union` = Symbol("_root_.scala.collection.SeqLike#union(Lscala/collection/GenSeq;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `SeqLike.updated` = Symbol("_root_.scala.collection.SeqLike#updated(ILjava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - val `SeqLike.reverseMap` = Symbol("_root_.scala.collection.SeqLike#reverseMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") - - val functionsIteratorSymbols = List(`List.collect`, `List.flatMap`, `List.map`, `IterableLike.zip`, `IterableLike.zipAll`, `SeqLike.union`) - val functionsViewSymbols = List(`SeqLike.updated`) - val functionsReverseIteratorSymbols = List(`SeqLike.reverseMap`) - val functionsSymbols = functionsIteratorSymbols ++ functionsViewSymbols ++ functionsReverseIteratorSymbols - - val functionsIterator = SymbolMatcher.exact(functionsIteratorSymbols: _*) - val functionsReverseIterator = SymbolMatcher.exact(functionsReverseIteratorSymbols: _*) - val functionsView = SymbolMatcher.exact(functionsViewSymbols: _*) - val functions = SymbolMatcher.exact(functionsSymbols: _*) - - // special select - - // iterator - val `TraversableLike.scanLeft` = SymbolMatcher.exact(Symbol("_root_.scala.collection.TraversableLike#scanLeft(Ljava/lang/Object;Lscala/Function2;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")) - - def isLeftAssociative(tree: Tree): Boolean = - tree match { - case Term.Name(value) => value.last != ':' - case _ => false - } - } - // == Rules == def replaceSymbols(ctx: RuleCtx): Patch = { @@ -146,66 +92,13 @@ case class NewCollections(index: SemanticdbIndex) }.asPatch } - def replaceBreakout(ctx: RuleCtx): Patch = { - import Breakout._ - - def fixIt(intermediate: String, lhs: Term, ap: Term, breakout: Tree): Patch = { - ctx.addRight(lhs, "." + intermediate) + - ctx.addRight(ap, ".to") + - ctx.replaceTree(breakout, "implicitly") - } - - ctx.tree.collect { - case i: Importee if breakOut.matches(i) => - ctx.removeImportee(i) - - case Term.Apply(ap @ Term.ApplyInfix(lhs, operators(op), _, List(rhs)), List(breakOut(bo))) => - val subject = - if(isLeftAssociative(op)) lhs - else rhs - - val intermediate = - op match { - case operatorsIterator(_) => "iterator" - case operatorsView(_) => "view" - // since operators(op) matches iterator and view - case _ => throw new Exception("impossible") - } - - fixIt(intermediate, subject, ap, bo) - - case Term.Apply(ap @ Term.Apply(Term.Select(lhs, functions(op)), _), List(breakOut(bo))) => - val intermediate = - op match { - case functionsIterator(_) => "iterator" - case functionsView(_) => "view" - case functionsReverseIterator(_) => "reverseIterator" - // since functions(op) matches iterator, view and reverseIterator - case _ => throw new Exception("impossible") - } - - val replaceUnion = - if (`SeqLike.union`.exact(op)) ctx.replaceTree(op, "concat") - else Patch.empty - - val isReversed = `SeqLike.reverseMap`.exact(op) - val replaceReverseMap = - if (isReversed) ctx.replaceTree(op, "map") - else Patch.empty - - fixIt(intermediate, lhs, ap, bo) + replaceUnion + replaceReverseMap - - case Term.Apply(ap @ Term.Apply(Term.Apply(Term.Select(lhs, `TraversableLike.scanLeft`(op)), _), _), List(breakOut(bo))) => - fixIt("iterator", lhs, ap, bo) - }.asPatch - } + override def fix(ctx: RuleCtx): Patch = { super.fix(ctx) + replaceSymbols(ctx) + replaceTupleZipped(ctx) + replaceMutableMap(ctx) + - replaceMutableSet(ctx) + - replaceBreakout(ctx) + replaceMutableSet(ctx) } } diff --git a/scalafix/rules/src/main/scala/fix/Stable212Base.scala b/scalafix/rules/src/main/scala/fix/Stable212Base.scala index b0199e4f..9cc7c110 100644 --- a/scalafix/rules/src/main/scala/fix/Stable212Base.scala +++ b/scalafix/rules/src/main/scala/fix/Stable212Base.scala @@ -44,20 +44,27 @@ trait Stable212Base extends CrossCompatibility { self: SemanticRule => val traversable = exact( "_root_.scala.package.Traversable#", - "_root_.scala.collection.Traversable#", - "_root_.scala.package.Iterable#", - "_root_.scala.collection.Iterable#" + "_root_.scala.collection.Traversable#" ) // == Rules == + val breakoutRewrite = new BreakoutRewrite(addCompatImport) + def replaceBreakout(ctx: RuleCtx): Patch = breakoutRewrite(ctx) + def replaceIterableSameElements(ctx: RuleCtx): Patch = { - ctx.tree.collect { - case Term.Apply(Term.Select(lhs, iterableSameElement(_)), List(_)) => - ctx.addRight(lhs, ".iterator") - }.asPatch - } + val sameElements = + ctx.tree.collect { + case Term.Apply(Term.Select(lhs, iterableSameElement(_)), List(_)) => + ctx.addRight(lhs, ".iterator") + }.asPatch + val compatImport = + if(sameElements.nonEmpty) addCompatImport(ctx) + else Patch.empty + + sameElements + compatImport + } def replaceSymbols0(ctx: RuleCtx): Patch = { val traversableToIterable = @@ -179,9 +186,10 @@ trait Stable212Base extends CrossCompatibility { self: SemanticRule => ctx.removeImportee(i) }.asPatch - val compatImport = addCompatImport(ctx) - - if (useSites.nonEmpty) useSites + imports + compatImport + if (useSites.nonEmpty) { + val compatImport = addCompatImport(ctx) + useSites + imports + compatImport + } else Patch.empty } @@ -214,40 +222,55 @@ trait Stable212Base extends CrossCompatibility { self: SemanticRule => } def replaceToList(ctx: RuleCtx): Patch = { - ctx.tree.collect { - case iterator(t: Name) => - ctx.replaceTree(t, "iterator") - - case Term.ApplyType(Term.Select(_, t @ toTpe(n: Name)), _) if !handledTo.contains(n) => - trailingBrackets(n, ctx).map { case (open, close) => - ctx.replaceToken(open, "(") + ctx.replaceToken(close, ")") - }.asPatch - - case t @ Term.Select(_, to @ toTpe(n: Name)) if !handledTo.contains(n) => - // we only want f.to, not f.to(X) - val applied = - t.parent match { - case Some(_:Term.Apply) => true - case _ => false - } - - if (!applied) { - val synth = ctx.index.synthetics.find(_.position.end == to.pos.end) - synth.map{ s => - val res = s.text.parse[Term].get - val Term.Apply(_, List(toCol)) = res - val col = extractCollection(toCol) - ctx.addRight(to, "(" + col + ")") - }.getOrElse(Patch.empty) - } else Patch.empty + val replaceToIterator = + ctx.tree.collect { + case iterator(t: Name) => + ctx.replaceTree(t, "iterator") + }.asPatch - }.asPatch + val replaceTo = + ctx.tree.collect { + case Term.ApplyType(Term.Select(_, t @ toTpe(n: Name)), _) if !handledTo.contains(n) => + trailingBrackets(n, ctx).map { case (open, close) => + ctx.replaceToken(open, "(") + ctx.replaceToken(close, ")") + }.asPatch + + case t @ Term.Select(_, to @ toTpe(n: Name)) if !handledTo.contains(n) => + // we only want f.to, not f.to(X) + val applied = + t.parent match { + case Some(_:Term.Apply) => true + case _ => false + } + + if (!applied) { + val synth = ctx.index.synthetics.find(_.position.end == to.pos.end) + synth.map{ s => + val res = s.text.parse[Term].get + val Term.Apply(_, List(toCol)) = res + val col = extractCollection(toCol) + ctx.addRight(to, "(" + col + ")") + }.getOrElse(Patch.empty) + } else Patch.empty + + }.asPatch + + val compatImport = + if (replaceTo.nonEmpty) addCompatImport(ctx) + else Patch.empty + + compatImport + replaceToIterator + replaceTo } + private val compatImportAdded = mutable.Set[Input]() def addCompatImport(ctx: RuleCtx): Patch = { - if (isCrossCompatible) ctx.addGlobalImport(importer"scala.collection.compat._") - else Patch.empty + if (isCrossCompatible && !compatImportAdded.contains(ctx.input)) { + compatImportAdded += ctx.input + ctx.addGlobalImport(importer"scala.collection.compat._") + } else { + Patch.empty + } } override def fix(ctx: RuleCtx): Patch = { @@ -260,7 +283,8 @@ trait Stable212Base extends CrossCompatibility { self: SemanticRule => replaceMutSetMapPlus(ctx) + replaceMutMapUpdated(ctx) + replaceArrayBuilderMake(ctx) + - replaceIterableSameElements(ctx) + replaceIterableSameElements(ctx) + + replaceBreakout(ctx) } }