diff --git a/scalafix/input/src/main/scala/fix/SetMapSrc.scala b/scalafix/input/src/main/scala/fix/SetMapSrc.scala index 0597c513..7477a4c3 100644 --- a/scalafix/input/src/main/scala/fix/SetMapSrc.scala +++ b/scalafix/input/src/main/scala/fix/SetMapSrc.scala @@ -3,10 +3,22 @@ rule = "scala:fix.Scalacollectioncompat_newcollections" */ package fix -class SetMapSrc(set: Set[Int], map: Map[Int, Int]) { - set + (2, 3) - map + (2 -> 3, 3 -> 4) - (set + (2, 3)).map(x => x) - set + (2, 3) - 4 - map.mapValues(_ + 1) -} \ No newline at end of file +import scala.collection +import scala.collection.immutable +import scala.collection.mutable.{Map, Set} // Challenge to make sure the scoping is correct + +class SetMapSrc(iset: immutable.Set[Int], + cset: collection.Set[Int], + imap: immutable.Map[Int, Int], + cmap: collection.Map[Int, Int]) { + iset + (2, 3) + imap + (2 -> 3, 3 -> 4) + (iset + (2, 3)).toString + iset + (2, 3) - 4 + imap.mapValues(_ + 1) + iset + 1 + iset - 2 + cset + 1 + cset - 2 + cmap + (2 -> 3) + ((4, 5)) +} diff --git a/scalafix/output/src/main/scala/fix/SetMapSrc.scala b/scalafix/output/src/main/scala/fix/SetMapSrc.scala index 978b0bfa..06fea105 100644 --- a/scalafix/output/src/main/scala/fix/SetMapSrc.scala +++ b/scalafix/output/src/main/scala/fix/SetMapSrc.scala @@ -3,10 +3,22 @@ package fix -class SetMapSrc(set: Set[Int], map: Map[Int, Int]) { - set + 2 + 3 - map + (2 -> 3) + (3 -> 4) - (set + 2 + 3).map(x => x) - set + 2 + 3 - 4 - map.mapValues(_ + 1).toMap +import scala.collection +import scala.collection.immutable +import scala.collection.mutable.{Map, Set} // Challenge to make sure the scoping is correct + +class SetMapSrc(iset: immutable.Set[Int], + cset: collection.Set[Int], + imap: immutable.Map[Int, Int], + cmap: collection.Map[Int, Int]) { + iset + 2 + 3 + imap + (2 -> 3) + (3 -> 4) + (iset + 2 + 3).toString + iset + 2 + 3 - 4 + imap.mapValues(_ + 1).toMap + iset + 1 + iset - 2 + cset ++ _root_.scala.collection.Set(1) + cset -- _root_.scala.collection.Set(2) + cmap ++ _root_.scala.collection.Map(2 -> 3) ++ _root_.scala.collection.Map((4, 5)) } \ No newline at end of file diff --git a/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala b/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala index f58a0b11..77ca199d 100644 --- a/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala +++ b/scalafix/rules/src/main/scala/fix/Scalacollectioncompat_newcollections.scala @@ -8,6 +8,24 @@ import scala.meta._ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) extends SemanticRule(index, "Scalacollectioncompat_newcollections") { + // WARNING: TOTAL HACK + // this is only to unblock us until Term.tpe is available: https://github.com/scalameta/scalameta/issues/1212 + // if we have a simple identifier, we can look at his definition at query it's type + // this should be improved in future version of scalameta + object TypeMatcher { + def apply(symbols: Symbol*)(implicit index: SemanticdbIndex): TypeMatcher = + new TypeMatcher(symbols: _*)(index) + } + + final class TypeMatcher(symbols: Symbol*)(implicit index: SemanticdbIndex) { + def unapply(tree: Tree): Boolean = { + index.denotation(tree) + .exists(_.names.headOption.exists(n => symbols.exists(_ == n.symbol))) + } + } + + val CollectionSet: TypeMatcher = TypeMatcher(Symbol("_root_.scala.collection.Set#")) + def replaceSymbols(ctx: RuleCtx): Patch = { ctx.replaceSymbols( "scala.collection.LinearSeq" -> "scala.collection.immutable.List", @@ -30,6 +48,18 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) Symbol("_root_.scala.runtime.Tuple2Zipped.Ops.zipped."), Symbol("_root_.scala.runtime.Tuple3Zipped.Ops.zipped.") ) + val setPlus = + SymbolMatcher.exact( + Symbol("_root_.scala.collection.SetLike#`+`(Ljava/lang/Object;)Lscala/collection/Set;.") + ) + val setMinus = + SymbolMatcher.exact( + Symbol("_root_.scala.collection.SetLike#`-`(Ljava/lang/Object;)Lscala/collection/Set;.") + ) + val mapPlus = + SymbolMatcher.exact( + Symbol("_root_.scala.collection.MapLike#`+`(Lscala/Tuple2;)Lscala/collection/Map;.") + ) val setPlus2 = SymbolMatcher.exact( Symbol("_root_.scala.collection.SetLike#`+`(Ljava/lang/Object;Ljava/lang/Object;Lscala/collection/Seq;)Lscala/collection/Set;.") ) @@ -82,6 +112,9 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) Symbol("_root_.scala.collection.mutable.ArrayBuilder.make(Lscala/reflect/ClassTag;)Lscala/collection/mutable/ArrayBuilder;.") ) + def startsWithParens(tree: Tree): Boolean = + tree.tokens.headOption.map(_.is[Token.LeftParen]).getOrElse(false) + def replaceMutableSet(ctx: RuleCtx) = ctx.tree.collect { case retainSet(n: Name) => @@ -192,7 +225,7 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) def replaceSetMapPlus2(ctx: RuleCtx): Patch = { def rewritePlus(ap: Term.ApplyInfix, lhs: Term, op: Term.Name, rhs1: Term, rhs2: Term): Patch = { val tokensToReplace = - if(ap.tokens.headOption.map(_.is[Token.LeftParen]).getOrElse(false)) { + if(startsWithParens(ap)) { // don't drop surrounding parens ap.tokens.slice(1, ap.tokens.size - 1) } else ap.tokens @@ -217,6 +250,34 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) }.asPatch } + def replaceSetMapPlusMinus(ctx: RuleCtx): Patch = { + def rewriteOp(op: Tree, rhs: Tree, doubleOp: String, col0: String): Patch = { + val col = "_root_.scala.collection." + col0 + val callSite = + if (startsWithParens(rhs)) { + ctx.addLeft(rhs, col) + } + else { + ctx.addLeft(rhs, col + "(") + + ctx.addRight(rhs, ")") + } + + ctx.addRight(op, doubleOp) + callSite + } + + ctx.tree.collect { + case Term.ApplyInfix(CollectionSet(), op @ setPlus(_), Nil, List(rhs)) => + rewriteOp(op, rhs, "+", "Set") + + case Term.ApplyInfix(CollectionSet(), op @ setMinus(_), Nil, List(rhs)) => + rewriteOp(op, rhs, "-", "Set") + + case Term.ApplyInfix(_, op @ mapPlus(_), Nil, List(rhs)) => + rewriteOp(op, rhs, "+", "Map") + }.asPatch + } + + def replaceMutSetMapPlus(ctx: RuleCtx): Patch = { def rewriteMutPlus(lhs: Term, op: Term.Name): Patch = { ctx.addRight(lhs, ".clone()") + @@ -265,7 +326,7 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) ctx.addRight(ap, ".toMap") }.asPatch } - + override def fix(ctx: RuleCtx): Patch = { replaceToList(ctx) + replaceSymbols(ctx) + @@ -276,6 +337,7 @@ case class Scalacollectioncompat_newcollections(index: SemanticdbIndex) replaceMutableSet(ctx) + replaceSymbolicFold(ctx) + replaceSetMapPlus2(ctx) + + replaceSetMapPlusMinus(ctx) + replaceMutSetMapPlus(ctx) + replaceMutMapUpdated(ctx) + replaceIterableSameElements(ctx) +