diff --git a/scalafix/input/src/main/scala/fix/BreakoutSrc.scala b/scalafix/input/src/main/scala/fix/BreakoutSrc.scala new file mode 100644 index 00000000..09b7e16a --- /dev/null +++ b/scalafix/input/src/main/scala/fix/BreakoutSrc.scala @@ -0,0 +1,25 @@ +/* +rule = "scala:fix.Scalacollectioncompat_newcollections" + */ +package fix + +import scala.collection.breakOut + +object BreakoutSrc { + val xs = List(1, 2, 3) + + xs.collect{ case x => x }(breakOut): Set[Int] + xs.flatMap(x => List(x))(breakOut): collection.SortedSet[Int] + xs.map(_ + 1)(breakOut): Set[Int] + xs.reverseMap(_ + 1)(breakOut): Set[Int] + xs.scanLeft(0)((a, b) => a + b)(breakOut): Set[Int] + xs.union(xs)(breakOut): Set[Int] + xs.updated(0, 1)(breakOut): Set[Int] + xs.zip(xs)(breakOut): Map[Int, Int] + xs.zipAll(xs, 0, 0)(breakOut): Array[(Int, Int)] + + (xs ++ xs)(breakOut): Set[Int] + (1 +: xs)(breakOut): Set[Int] + (xs :+ 1)(breakOut): Set[Int] + (xs ++: xs)(breakOut): Set[Int] +} \ No newline at end of file diff --git a/scalafix/output/src/main/scala/fix/BreakoutSrc.scala b/scalafix/output/src/main/scala/fix/BreakoutSrc.scala new file mode 100644 index 00000000..e9ef5a05 --- /dev/null +++ b/scalafix/output/src/main/scala/fix/BreakoutSrc.scala @@ -0,0 +1,24 @@ + + + +package fix + + +object BreakoutSrc { + val xs = List(1, 2, 3) + + xs.iterator.collect{ case x => x }.to(implicitly): Set[Int] + xs.iterator.flatMap(x => List(x)).to(implicitly): collection.SortedSet[Int] + xs.iterator.map(_ + 1).to(implicitly): Set[Int] + xs.reverseIterator.map(_ + 1).to(implicitly): Set[Int] + xs.iterator.scanLeft(0)((a, b) => a + b).to(implicitly): Set[Int] + xs.iterator.concat(xs).to(implicitly): Set[Int] + xs.view.updated(0, 1).to(implicitly): Set[Int] + xs.iterator.zip(xs).to(implicitly): Map[Int, Int] + xs.iterator.zipAll(xs, 0, 0).to(implicitly): Array[(Int, Int)] + + (xs.iterator ++ xs).to(implicitly): Set[Int] + (1 +: xs.view).to(implicitly): Set[Int] + (xs.view :+ 1).to(implicitly): Set[Int] + (xs ++: xs.view).to(implicitly): Set[Int] +} \ No newline at end of file diff --git a/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala b/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala index 34ef7049..3c3662a0 100644 --- a/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala +++ b/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala @@ -26,31 +26,6 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) close <- ctx.matchingParens.close(open) } yield (open, close) - // terms dont give us terms https://github.com/scalameta/scalameta/issues/1212 - // WARNING: TOTAL HACK - // this is only to unblock us until Term.tpe is available: https://github.com/scalameta/scalameta/issues/1212 - // if we have a simple identifier, we can look at his definition at query it's type - // this should be improved in future version of scalameta - object TypeMatcher { - def apply(symbols: Symbol*)(implicit index: SemanticdbIndex): TypeMatcher = - new TypeMatcher(symbols: _*)(index) - } - - final class TypeMatcher(symbols: Symbol*)(implicit index: SemanticdbIndex) { - def unapply(tree: Tree): Boolean = { - index.denotation(tree) - .exists(_.names.headOption.exists(n => symbols.exists(_ == n.symbol))) - } - } - - val CollectionMap: TypeMatcher = TypeMatcher( - Symbol("_root_.scala.collection.immutable.Map#"), - Symbol("_root_.scala.collection.mutable.Map#"), - Symbol("_root_.scala.Predef.Map#") - ) - - val CollectionSet: TypeMatcher = TypeMatcher(Symbol("_root_.scala.collection.Set#")) - def replaceSymbols(ctx: RuleCtx): Patch = { ctx.replaceSymbols( "scala.collection.LinearSeq" -> "scala.collection.immutable.List", @@ -141,6 +116,60 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) Symbol("_root_.scala.collection.mutable.ArrayBuilder.make(Lscala/reflect/ClassTag;)Lscala/collection/mutable/ArrayBuilder;.") ) + object Breakout { + implicit class RichSymbol(val symbol: Symbol) { + def exact(tree: Tree)(implicit index: SemanticdbIndex): Boolean = + index.symbol(tree).fold(false)(_ == symbol) + } + + val breakOut = SymbolMatcher.exact(Symbol("_root_.scala.collection.package.breakOut(Lscala/collection/generic/CanBuildFrom;)Lscala/collection/generic/CanBuildFrom;.")) + + // infix operators + val `List ++` = Symbol("_root_.scala.collection.immutable.List#`++`(Lscala/collection/GenTraversableOnce;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List +:` = Symbol("_root_.scala.collection.immutable.List#`+:`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike :+` = Symbol("_root_.scala.collection.SeqLike#`:+`(Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `TraversableLike ++:` = Symbol("_root_.scala.collection.TraversableLike#`++:`(Lscala/collection/Traversable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + + val operatorsIteratorSymbols = List(`List ++`) + val operatorsViewSymbols = List(`List +:`, `SeqLike :+`, `TraversableLike ++:`) + val operatorsSymbols = operatorsViewSymbols ++ operatorsIteratorSymbols + + val operatorsIterator = SymbolMatcher.exact(operatorsIteratorSymbols: _*) + val operatorsView = SymbolMatcher.exact(operatorsViewSymbols: _*) + val operators = SymbolMatcher.exact(operatorsSymbols: _*) + + // select + val `List.collect` = Symbol("_root_.scala.collection.immutable.List#collect(Lscala/PartialFunction;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List.flatMap` = Symbol("_root_.scala.collection.immutable.List#flatMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `List.map` = Symbol("_root_.scala.collection.immutable.List#map(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `IterableLike.zip` = Symbol("_root_.scala.collection.IterableLike#zip(Lscala/collection/GenIterable;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `IterableLike.zipAll` = Symbol("_root_.scala.collection.IterableLike#zipAll(Lscala/collection/GenIterable;Ljava/lang/Object;Ljava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.union` = Symbol("_root_.scala.collection.SeqLike#union(Lscala/collection/GenSeq;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.updated` = Symbol("_root_.scala.collection.SeqLike#updated(ILjava/lang/Object;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + val `SeqLike.reverseMap` = Symbol("_root_.scala.collection.SeqLike#reverseMap(Lscala/Function1;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.") + + val functionsIteratorSymbols = List(`List.collect`, `List.flatMap`, `List.map`, `IterableLike.zip`, `IterableLike.zipAll`, `SeqLike.union`) + val functionsViewSymbols = List(`SeqLike.updated`) + val functionsReverseIteratorSymbols = List(`SeqLike.reverseMap`) + val functionsSymbols = functionsIteratorSymbols ++ functionsViewSymbols ++ functionsReverseIteratorSymbols + + val functionsIterator = SymbolMatcher.exact(functionsIteratorSymbols: _*) + val functionsReverseIterator = SymbolMatcher.exact(functionsReverseIteratorSymbols: _*) + val functionsView = SymbolMatcher.exact(functionsViewSymbols: _*) + val functions = SymbolMatcher.exact(functionsSymbols: _*) + + // special select + + // iterator + val `TraversableLike.scanLeft` = SymbolMatcher.exact(Symbol("_root_.scala.collection.TraversableLike#scanLeft(Ljava/lang/Object;Lscala/Function2;Lscala/collection/generic/CanBuildFrom;)Ljava/lang/Object;.")) + + def isLeftAssociative(tree: Tree): Boolean = + tree match { + case Term.Name(value) => value.last != ':' + case _ => false + } + } + def startsWithParens(tree: Tree): Boolean = tree.tokens.headOption.map(_.is[Token.LeftParen]).getOrElse(false) @@ -500,6 +529,60 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) if (useSites.nonEmpty) useSites + imports else Patch.empty } + + def replaceBreakout(ctx: RuleCtx): Patch = { + import Breakout._ + + def fixIt(intermediate: String, lhs: Term, ap: Term, breakout: Tree): Patch = { + ctx.addRight(lhs, "." + intermediate) + + ctx.addRight(ap, ".to") + + ctx.replaceTree(breakout, "implicitly") + } + + ctx.tree.collect { + case i: Importee if breakOut.matches(i) => + ctx.removeImportee(i) + + case Term.Apply(ap @ Term.ApplyInfix(lhs, operators(op), _, List(rhs)), List(breakOut(bo))) => + val subject = + if(isLeftAssociative(op)) lhs + else rhs + + val intermediate = + op match { + case operatorsIterator(_) => "iterator" + case operatorsView(_) => "view" + // since operators(op) matches iterator and view + case _ => throw new Exception("impossible") + } + + fixIt(intermediate, subject, ap, bo) + + case Term.Apply(ap @ Term.Apply(Term.Select(lhs, functions(op)), _), List(breakOut(bo))) => + val intermediate = + op match { + case functionsIterator(_) => "iterator" + case functionsView(_) => "view" + case functionsReverseIterator(_) => "reverseIterator" + // since functions(op) matches iterator, view and reverseIterator + case _ => throw new Exception("impossible") + } + + val replaceUnion = + if (`SeqLike.union`.exact(op)) ctx.replaceTree(op, "concat") + else Patch.empty + + val isReversed = `SeqLike.reverseMap`.exact(op) + val replaceReverseMap = + if (isReversed) ctx.replaceTree(op, "map") + else Patch.empty + + fixIt(intermediate, lhs, ap, bo) + replaceUnion + replaceReverseMap + + case Term.Apply(ap @ Term.Apply(Term.Apply(Term.Select(lhs, `TraversableLike.scanLeft`(op)), _), _), List(breakOut(bo))) => + fixIt("iterator", lhs, ap, bo) + }.asPatch + } override def fix(ctx: RuleCtx): Patch = { replaceCanBuildFrom(ctx) + @@ -516,6 +599,7 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) replaceMutMapUpdated(ctx) + replaceArrayBuilderMake(ctx) + replaceIterableSameElements(ctx) + - replaceMapMapValues(ctx) + replaceMapMapValues(ctx) + + replaceBreakout(ctx) } }