Skip to content

Commit b9ecfc0

Browse files
committed
Filter only for generators starting with case.
But wait with this for now, since we can't cross-compile easily otherwise. So currently this is enabled only under -strict.
1 parent ca29d33 commit b9ecfc0

File tree

5 files changed

+131
-56
lines changed

5 files changed

+131
-56
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,19 @@ object desugar {
3333
*/
3434
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key
3535

36-
/** An attachment for match expressions generated from a PatDef */
37-
val PatDefMatch: Property.Key[Unit] = new Property.Key
36+
/** An attachment for match expressions generated from a PatDef or GenFrom.
37+
* Value of key == one of IrrefutablePatDef, IrrefutableGenFrom
38+
*/
39+
val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey
40+
41+
/** What static check should be applied to a Match (none, irrefutable, exhaustive) */
42+
class MatchCheck(val n: Int) extends AnyVal
43+
object MatchCheck {
44+
val None = new MatchCheck(0)
45+
val Exhaustive = new MatchCheck(1)
46+
val IrrefutablePatDef = new MatchCheck(2)
47+
val IrrefutableGenFrom = new MatchCheck(3)
48+
}
3849

3950
/** Info of a variable in a pattern: The named tree and its type */
4051
private type VarInfo = (NameTree, Tree)
@@ -926,6 +937,22 @@ object desugar {
926937
}
927938
}
928939

940+
/** The selector of a match, which depends of the given `checkMode`.
941+
* @param sel the original selector
942+
* @return if `checkMode` is
943+
* - None : sel @unchecked
944+
* - Exhaustive : sel
945+
* - IrrefutablePatDef,
946+
* IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode`
947+
*/
948+
def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree =
949+
if (checkMode == MatchCheck.Exhaustive) sel
950+
else {
951+
val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType)))
952+
if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode)
953+
sel1
954+
}
955+
929956
/** If `pat` is a variable pattern,
930957
*
931958
* val/var/lazy val p = e
@@ -960,11 +987,6 @@ object desugar {
960987
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
961988
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
962989

963-
def rhsUnchecked = {
964-
val rhs1 = makeAnnotated("scala.unchecked", rhs)
965-
rhs1.pushAttachment(PatDefMatch, ())
966-
rhs1
967-
}
968990
val vars =
969991
if (tupleOptimizable) // include `_`
970992
pat match {
@@ -977,7 +999,7 @@ object desugar {
977999
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
9781000
val matchExpr =
9791001
if (tupleOptimizable) rhs
980-
else Match(rhsUnchecked, caseDef :: Nil)
1002+
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
9811003
vars match {
9821004
case Nil =>
9831005
matchExpr
@@ -1126,14 +1148,10 @@ object desugar {
11261148
*
11271149
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases }
11281150
*/
1129-
def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = {
1151+
def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = {
11301152
val params = (1 to nparams).toList.map(makeSyntheticParameter(_))
11311153
val selector = makeTuple(params.map(p => Ident(p.name)))
1132-
1133-
if (unchecked)
1134-
Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases))
1135-
else
1136-
Function(params, Match(selector, cases))
1154+
Function(params, Match(makeSelector(selector, checkMode), cases))
11371155
}
11381156

11391157
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
@@ -1264,13 +1282,14 @@ object desugar {
12641282

12651283
/** Make a function value pat => body.
12661284
* If pat is a var pattern id: T then this gives (id: T) => body
1267-
* Otherwise this gives { case pat => body }
1285+
* Otherwise this gives { case pat => body }, where `pat` is allowed to be
1286+
* refutable only if `checkMode` is MatchCheck.None.
12681287
*/
1269-
def makeLambda(pat: Tree, body: Tree): Tree = pat match {
1288+
def makeLambda(pat: Tree, body: Tree, checkMode: MatchCheck): Tree = pat match {
12701289
case IdPattern(named, tpt) =>
12711290
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12721291
case _ =>
1273-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil)
1292+
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil, checkMode)
12741293
}
12751294

12761295
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1316,7 +1335,7 @@ object desugar {
13161335
val cases = List(
13171336
CaseDef(pat, EmptyTree, Literal(Constant(true))),
13181337
CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))
1319-
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases))
1338+
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None))
13201339
}
13211340

13221341
/** Is pattern `pat` irrefutable when matched against `rhs`?
@@ -1355,26 +1374,30 @@ object desugar {
13551374
Select(rhs, name)
13561375
}
13571376

1377+
def checkMode(gen: GenFrom) =
1378+
if (gen.filtering) MatchCheck.None // refutable paterns were already eliminated in filter step
1379+
else MatchCheck.IrrefutableGenFrom
1380+
13581381
enums match {
13591382
case (gen: GenFrom) :: Nil =>
1360-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body))
1383+
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body, checkMode(gen)))
13611384
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13621385
val cont = makeFor(mapName, flatMapName, rest, body)
1363-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
1364-
case (gen @ GenFrom(pat, rhs, _)) :: (rest @ GenAlias(_, _) :: _) =>
1386+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont, checkMode(gen)))
1387+
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13651388
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13661389
val pats = valeqs map { case GenAlias(pat, _) => pat }
13671390
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
1368-
val (defpat0, id0) = makeIdPat(pat)
1391+
val (defpat0, id0) = makeIdPat(gen.pat)
13691392
val (defpats, ids) = (pats map makeIdPat).unzip
13701393
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
1371-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1372-
val allpats = pat :: pats
1394+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1395+
val allpats = gen.pat :: pats
13731396
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, filtering = false)
13741397
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13751398
case (gen: GenFrom) :: test :: rest =>
1376-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test))
1377-
val genFrom = new GenFrom(gen.pat, filtered, filtering = false)
1399+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1400+
val genFrom = GenFrom(gen.pat, filtered, filtering = false)
13781401
makeFor(mapName, flatMapName, genFrom :: rest, body)
13791402
case _ =>
13801403
EmptyTree //may happen for erroneous input

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -602,41 +602,47 @@ trait Checking {
602602
* This means `pat` is either marked @unchecked or `pt` conforms to the
603603
* pattern's type. If pattern is an UnApply, do the check recursively.
604604
*/
605-
def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = {
606-
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")
605+
def checkIrrefutable(pat: Tree, pt: Type, isPatDef: Boolean)(implicit ctx: Context): Boolean = {
607606

608607
def fail(pat: Tree, pt: Type): Boolean = {
608+
var reportedPt = pt.dropAnnot(defn.UncheckedAnnot)
609+
if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen
610+
val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before"
609611
ctx.errorOrMigrationWarning(
610-
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)}
612+
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type $reportedPt
611613
|
612-
|If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""",
614+
|If the narrowing is intentional, this can be communicated by writing $fix the full pattern.${err.rewriteNotice}""",
613615
pat.sourcePos)
614616
false
615617
}
616618

617619
def check(pat: Tree, pt: Type): Boolean = (pt <:< pat.tpe) || fail(pat, pt)
618620

619-
!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
620-
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
621-
pat match {
622-
case Bind(_, pat1) =>
623-
checkIrrefutable(pat1, pt)
624-
case UnApply(fn, _, pats) =>
625-
check(pat, pt) &&
626-
(isIrrefutableUnapply(fn) || fail(pat, pt)) && {
627-
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos)
628-
pats.corresponds(argPts)(checkIrrefutable)
629-
}
630-
case Alternative(pats) =>
631-
pats.forall(checkIrrefutable(_, pt))
632-
case Typed(arg, tpt) =>
633-
check(pat, pt) && checkIrrefutable(arg, pt)
634-
case Ident(nme.WILDCARD) =>
635-
true
636-
case _ =>
637-
check(pat, pt)
621+
def recur(pat: Tree, pt: Type): Boolean =
622+
!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
623+
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
624+
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")
625+
pat match {
626+
case Bind(_, pat1) =>
627+
recur(pat1, pt)
628+
case UnApply(fn, _, pats) =>
629+
check(pat, pt) &&
630+
(isIrrefutableUnapply(fn) || fail(pat, pt)) && {
631+
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.sourcePos)
632+
pats.corresponds(argPts)(recur)
633+
}
634+
case Alternative(pats) =>
635+
pats.forall(recur(_, pt))
636+
case Typed(arg, tpt) =>
637+
check(pat, pt) && recur(arg, pt)
638+
case Ident(nme.WILDCARD) =>
639+
true
640+
case _ =>
641+
check(pat, pt)
642+
}
638643
}
639-
}
644+
645+
recur(pat, pt)
640646
}
641647

642648
/** Check that `path` is a legal prefix for an import or export clause */

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,19 +1029,26 @@ class Typer extends Namer
10291029
}
10301030
else {
10311031
val (protoFormals, _) = decomposeProtoFunction(pt, 1)
1032-
val unchecked = pt.isRef(defn.PartialFunctionClass)
1033-
typed(desugar.makeCaseLambda(tree.cases, protoFormals.length, unchecked).withSpan(tree.span), pt)
1032+
val checkMode =
1033+
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
1034+
else desugar.MatchCheck.Exhaustive
1035+
typed(desugar.makeCaseLambda(tree.cases, checkMode, protoFormals.length).withSpan(tree.span), pt)
10341036
}
10351037
case _ =>
10361038
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
10371039
val sel1 = typedExpr(tree.selector)
10381040
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
10391041
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
10401042
result match {
1041-
case Match(sel, CaseDef(pat, _, _) :: _)
1042-
if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) =>
1043-
if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode)
1044-
patch(Span(pat.span.end), ": @unchecked")
1043+
case Match(sel, CaseDef(pat, _, _) :: _) =>
1044+
tree.selector.removeAttachment(desugar.CheckIrrefutable) match {
1045+
case Some(checkMode) =>
1046+
val isPatDef = checkMode == desugar.MatchCheck.IrrefutablePatDef
1047+
if (!checkIrrefutable(pat, sel.tpe, isPatDef) && ctx.settings.migration.value)
1048+
if (isPatDef) patch(Span(pat.span.end), ": @unchecked")
1049+
else patch(Span(pat.span.start), "case ")
1050+
case _ =>
1051+
}
10451052
case _ =>
10461053
}
10471054
result

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class CompilationTests extends ParallelTesting {
194194
compileFilesInDir("tests/run-custom-args/Yretain-trees", defaultOptions and "-Yretain-trees"),
195195
compileFile("tests/run-custom-args/tuple-cons.scala", allowDeepSubtypes),
196196
compileFile("tests/run-custom-args/i5256.scala", allowDeepSubtypes),
197+
compileFile("tests/run-custom-args/fors.scala", defaultOptions and "-strict"),
197198
compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
198199
compileFilesInDir("tests/run", defaultOptions)
199200
).checkRuns()

tests/neg/zipped.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// This test shows some un-intuitive behavior of the `zipped` method.
2+
object Test {
3+
val xs: List[Int] = ???
4+
5+
// 1. This works, since withFilter is not defined on Tuple3zipped. Instead,
6+
// an implicit conversion from Tuple3zipped to Traversable[(Int, Int, Int)] is inserted.
7+
// The subsequent map operation has the right type for this Traversable.
8+
(xs, xs, xs).zipped
9+
.withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK
10+
.map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // OK
11+
12+
13+
// 2. This works as well, because of auto untupling i.e. `case` is inserted.
14+
// But it does not work in Scala2.
15+
(xs, xs, xs).zipped
16+
.withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK
17+
.map( (x: Int, y: Int, z: Int) => x + y + z ) // OK
18+
// works, because of auto untupling i.e. `case` is inserted
19+
// does not work in Scala2
20+
21+
// 3. Now, without withFilter, it's the opposite, we need the 3 parameter map.
22+
(xs, xs, xs).zipped
23+
.map( (x: Int, y: Int, z: Int) => x + y + z ) // OK
24+
25+
// 4. The single parameter map does not work.
26+
(xs, xs, xs).zipped
27+
.map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // error
28+
29+
// 5. If we leave out the parameter type, we get a "Wrong number of parameters" error instead
30+
(xs, xs, xs).zipped
31+
.map( x => x match { case (x, y, z) => x + y + z }) // error
32+
33+
// This means that the following works in Dotty in normal mode, since a `withFilter`
34+
// is inserted. But it does no work under -strict. And it will not work in Scala 3.1.
35+
// The reason is that without -strict, the code below is mapped to (1), but with -strict
36+
// it is mapped to (5).
37+
for ((x, y, z) <- (xs, xs, xs).zipped) yield x + y + z
38+
}

0 commit comments

Comments
 (0)