Skip to content

Commit 50b3b31

Browse files
sjrdnicolasstucki
authored andcommitted
SIP-56: Better foundations for match types
See #18262
1 parent 16f1680 commit 50b3b31

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1508
-285
lines changed

compiler/src/dotty/tools/dotc/core/MatchTypeTrace.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ object MatchTypeTrace:
1212

1313
private enum TraceEntry:
1414
case TryReduce(scrut: Type)
15-
case Stuck(scrut: Type, stuckCase: Type, otherCases: List[Type])
16-
case NoInstance(scrut: Type, stuckCase: Type, fails: List[(Name, TypeBounds)])
15+
case Stuck(scrut: Type, stuckCase: MatchTypeCaseSpec, otherCases: List[MatchTypeCaseSpec])
16+
case NoInstance(scrut: Type, stuckCase: MatchTypeCaseSpec, fails: List[(Name, TypeBounds)])
1717
case EmptyScrutinee(scrut: Type)
1818
import TraceEntry.*
1919

@@ -54,10 +54,10 @@ object MatchTypeTrace:
5454
* not disjoint from it either, which means that the remaining cases `otherCases`
5555
* cannot be visited. Only the first failure is recorded.
5656
*/
57-
def stuck(scrut: Type, stuckCase: Type, otherCases: List[Type])(using Context) =
57+
def stuck(scrut: Type, stuckCase: MatchTypeCaseSpec, otherCases: List[MatchTypeCaseSpec])(using Context) =
5858
matchTypeFail(Stuck(scrut, stuckCase, otherCases))
5959

60-
def noInstance(scrut: Type, stuckCase: Type, fails: List[(Name, TypeBounds)])(using Context) =
60+
def noInstance(scrut: Type, stuckCase: MatchTypeCaseSpec, fails: List[(Name, TypeBounds)])(using Context) =
6161
matchTypeFail(NoInstance(scrut, stuckCase, fails))
6262

6363
/** Record a failure that scrutinee `scrut` is provably empty.
@@ -80,13 +80,16 @@ object MatchTypeTrace:
8080
case _ =>
8181
op
8282

83+
def caseText(spec: MatchTypeCaseSpec)(using Context): String =
84+
caseText(spec.origMatchCase)
85+
8386
def caseText(tp: Type)(using Context): String = tp match
8487
case tp: HKTypeLambda => caseText(tp.resultType)
8588
case defn.MatchCase(any, body) if any eq defn.AnyType => i"case _ => $body"
8689
case defn.MatchCase(pat, body) => i"case $pat => $body"
8790
case _ => i"case $tp"
8891

89-
private def casesText(cases: List[Type])(using Context) =
92+
private def casesText(cases: List[MatchTypeCaseSpec])(using Context) =
9093
i"${cases.map(caseText)}%\n %"
9194

9295
private def explainEntry(entry: TraceEntry)(using Context): String = entry match
@@ -116,10 +119,15 @@ object MatchTypeTrace:
116119
| ${fails.map((name, bounds) => i"$name$bounds")}%\n %"""
117120

118121
/** The failure message when the scrutinee `scrut` does not match any case in `cases`. */
119-
def noMatchesText(scrut: Type, cases: List[Type])(using Context): String =
122+
def noMatchesText(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): String =
120123
i"""failed since selector $scrut
121124
|matches none of the cases
122125
|
123126
| ${casesText(cases)}"""
124127

128+
def illegalPatternText(scrut: Type, cas: MatchTypeCaseSpec.LegacyPatMat)(using Context): String =
129+
i"""The match type contains an illegal case:
130+
| ${caseText(cas)}
131+
|(this error can be ignored for now with `-source:3.3`)"""
132+
125133
end MatchTypeTrace

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 460 additions & 211 deletions
Large diffs are not rendered by default.

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 181 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package dotty.tools
22
package dotc
33
package core
44

5+
import java.util as ju
6+
57
import Symbols.*
68
import Flags.*
79
import Names.*
810
import StdNames.*, NameOps.*
911
import NullOpsDecorator.*
10-
import NameKinds.SkolemName
12+
import NameKinds.{SkolemName, WildcardParamName}
1113
import Scopes.*
1214
import Constants.*
1315
import Contexts.*
@@ -30,6 +32,8 @@ import Hashable.*
3032
import Uniques.*
3133
import collection.mutable
3234
import config.Config
35+
import config.Feature.sourceVersion
36+
import config.SourceVersion
3337
import annotation.{tailrec, constructorOnly}
3438
import scala.util.hashing.{ MurmurHash3 => hashing }
3539
import config.Printers.{core, typr, matchTypes}
@@ -5036,7 +5040,7 @@ object Types extends TypeUtils {
50365040
trace(i"reduce match type $this $hashCode", matchTypes, show = true)(inMode(Mode.Type) {
50375041
def matchCases(cmp: TrackingTypeComparer): Type =
50385042
val saved = ctx.typerState.snapshot()
5039-
try cmp.matchCases(scrutinee.normalized, cases)
5043+
try cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
50405044
catch case ex: Throwable =>
50415045
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
50425046
finally
@@ -5088,6 +5092,181 @@ object Types extends TypeUtils {
50885092
case _ => None
50895093
}
50905094

5095+
enum MatchTypeCasePattern:
5096+
case Capture(num: Int, isWildcard: Boolean)
5097+
case TypeTest(tpe: Type)
5098+
case BaseTypeTest(classType: TypeRef, argPatterns: List[MatchTypeCasePattern], needsConcreteScrut: Boolean)
5099+
case CompileTimeS(argPattern: MatchTypeCasePattern)
5100+
case AbstractTypeConstructor(tycon: Type, argPatterns: List[MatchTypeCasePattern])
5101+
case TypeMemberExtractor(typeMemberName: TypeName, capture: Capture)
5102+
5103+
def isTypeTest: Boolean =
5104+
this.isInstanceOf[TypeTest]
5105+
5106+
def needsConcreteScrutInVariantPos: Boolean = this match
5107+
case Capture(_, isWildcard) => !isWildcard
5108+
case TypeTest(_) => false
5109+
case _ => true
5110+
end MatchTypeCasePattern
5111+
5112+
enum MatchTypeCaseSpec:
5113+
case SubTypeTest(origMatchCase: Type, pattern: Type, body: Type)
5114+
case SpeccedPatMat(origMatchCase: HKTypeLambda, captureCount: Int, pattern: MatchTypeCasePattern, body: Type)
5115+
case LegacyPatMat(origMatchCase: HKTypeLambda)
5116+
case MissingCaptures(origMatchCase: HKTypeLambda, missing: ju.BitSet)
5117+
5118+
def origMatchCase: Type
5119+
end MatchTypeCaseSpec
5120+
5121+
object MatchTypeCaseSpec:
5122+
def analyze(cas: Type)(using Context): MatchTypeCaseSpec =
5123+
cas match
5124+
case cas: HKTypeLambda if !sourceVersion.isAtLeast(SourceVersion.`3.4`) =>
5125+
// Always apply the legacy algorithm under -source:3.3 and below
5126+
LegacyPatMat(cas)
5127+
case cas: HKTypeLambda =>
5128+
val defn.MatchCase(pat, body) = cas.resultType: @unchecked
5129+
val missing = checkCapturesPresent(cas, pat)
5130+
if !missing.isEmpty then
5131+
MissingCaptures(cas, missing)
5132+
else
5133+
val specPattern = tryConvertToSpecPattern(cas, pat)
5134+
if specPattern != null then
5135+
SpeccedPatMat(cas, cas.paramNames.size, specPattern, body)
5136+
else
5137+
LegacyPatMat(cas)
5138+
case _ =>
5139+
val defn.MatchCase(pat, body) = cas: @unchecked
5140+
SubTypeTest(cas, pat, body)
5141+
end analyze
5142+
5143+
/** Checks that all the captures of the case are present in the case.
5144+
*
5145+
* Sometimes, because of earlier substitutions of an abstract type constructor,
5146+
* we can end up with patterns that do not mention all their captures anymore.
5147+
* This can happen even when the body still refers to these missing captures.
5148+
* In that case, we must always consider the case to be unmatchable, i.e., to
5149+
* become `Stuck`.
5150+
*
5151+
* See pos/i12127.scala for an example.
5152+
*/
5153+
def checkCapturesPresent(cas: HKTypeLambda, pat: Type)(using Context): ju.BitSet =
5154+
val captureCount = cas.paramNames.size
5155+
val missing = new java.util.BitSet(captureCount)
5156+
missing.set(0, captureCount)
5157+
new CheckCapturesPresent(cas).apply(missing, pat)
5158+
5159+
private class CheckCapturesPresent(cas: HKTypeLambda)(using Context) extends TypeAccumulator[ju.BitSet]:
5160+
def apply(missing: ju.BitSet, tp: Type): ju.BitSet = tp match
5161+
case TypeParamRef(binder, num) if binder eq cas =>
5162+
missing.clear(num)
5163+
missing
5164+
case _ =>
5165+
foldOver(missing, tp)
5166+
end CheckCapturesPresent
5167+
5168+
private def tryConvertToSpecPattern(caseLambda: HKTypeLambda, pat: Type)(using Context): MatchTypeCasePattern | Null =
5169+
var typeParamRefsAccountedFor: Int = 0
5170+
5171+
def rec(pat: Type, variance: Int): MatchTypeCasePattern | Null =
5172+
pat match
5173+
case pat @ TypeParamRef(binder, num) if binder eq caseLambda =>
5174+
typeParamRefsAccountedFor += 1
5175+
MatchTypeCasePattern.Capture(num, isWildcard = pat.paramName.is(WildcardParamName))
5176+
5177+
case pat @ AppliedType(tycon: TypeRef, args) if variance == 1 =>
5178+
val tyconSym = tycon.symbol
5179+
if tyconSym.isClass then
5180+
if tyconSym.name.startsWith("Tuple") && defn.isTupleNType(pat) then
5181+
rec(pat.toNestedPairs, variance)
5182+
else
5183+
recArgPatterns(pat) { argPatterns =>
5184+
val needsConcreteScrut = argPatterns.zip(tycon.typeParams).exists {
5185+
(argPattern, tparam) => tparam.paramVarianceSign != 0 && argPattern.needsConcreteScrutInVariantPos
5186+
}
5187+
MatchTypeCasePattern.BaseTypeTest(tycon, argPatterns, needsConcreteScrut)
5188+
}
5189+
else if defn.isCompiletime_S(tyconSym) && args.sizeIs == 1 then
5190+
val argPattern = rec(args.head, variance)
5191+
if argPattern == null then
5192+
null
5193+
else if argPattern.isTypeTest then
5194+
MatchTypeCasePattern.TypeTest(pat)
5195+
else
5196+
MatchTypeCasePattern.CompileTimeS(argPattern)
5197+
else
5198+
tycon.info match
5199+
case _: RealTypeBounds =>
5200+
recAbstractTypeConstructor(pat)
5201+
case TypeAlias(tl @ HKTypeLambda(onlyParam :: Nil, resType: RefinedType)) =>
5202+
/* Unlike for eta-expanded classes, the typer does not automatically
5203+
* dealias poly type aliases to refined types. So we have to give them
5204+
* a chance here.
5205+
* We are quite specific about the shape of type aliases that we are willing
5206+
* to dealias this way, because we must not dealias arbitrary type constructors
5207+
* that could refine the bounds of the captures; those would amount of
5208+
* type-test + capture combos, which are out of the specced match types.
5209+
*/
5210+
rec(pat.superType, variance)
5211+
case _ =>
5212+
null
5213+
5214+
case pat @ AppliedType(tycon: TypeParamRef, _) if variance == 1 =>
5215+
recAbstractTypeConstructor(pat)
5216+
5217+
case pat @ RefinedType(parent, refinedName: TypeName, TypeAlias(alias @ TypeParamRef(binder, num)))
5218+
if variance == 1 && (binder eq caseLambda) =>
5219+
parent.member(refinedName) match
5220+
case refinedMember: SingleDenotation if refinedMember.exists =>
5221+
// Check that the bounds of the capture contain the bounds of the inherited member
5222+
val refinedMemberBounds = refinedMember.info
5223+
val captureBounds = caseLambda.paramInfos(num)
5224+
if captureBounds.contains(refinedMemberBounds) then
5225+
/* In this case, we know that any member we eventually find during reduction
5226+
* will have bounds that fit in the bounds of the capture. Therefore, no
5227+
* type-test + capture combo is necessary, and we can apply the specced match types.
5228+
*/
5229+
val capture = rec(alias, variance = 0).asInstanceOf[MatchTypeCasePattern.Capture]
5230+
MatchTypeCasePattern.TypeMemberExtractor(refinedName, capture)
5231+
else
5232+
// Otherwise, a type-test + capture combo might be necessary, and we are out of spec
5233+
null
5234+
case _ =>
5235+
// If the member does not refine a member of the `parent`, we are out of spec
5236+
null
5237+
5238+
case _ =>
5239+
MatchTypeCasePattern.TypeTest(pat)
5240+
end rec
5241+
5242+
def recAbstractTypeConstructor(pat: AppliedType): MatchTypeCasePattern | Null =
5243+
recArgPatterns(pat) { argPatterns =>
5244+
MatchTypeCasePattern.AbstractTypeConstructor(pat.tycon, argPatterns)
5245+
}
5246+
end recAbstractTypeConstructor
5247+
5248+
def recArgPatterns(pat: AppliedType)(whenNotTypeTest: List[MatchTypeCasePattern] => MatchTypeCasePattern | Null): MatchTypeCasePattern | Null =
5249+
val AppliedType(tycon, args) = pat
5250+
val tparams = tycon.typeParams
5251+
val argPatterns = args.zip(tparams).map { (arg, tparam) =>
5252+
rec(arg, tparam.paramVarianceSign)
5253+
}
5254+
if argPatterns.exists(_ == null) then
5255+
null
5256+
else
5257+
val argPatterns1 = argPatterns.asInstanceOf[List[MatchTypeCasePattern]] // they are not null
5258+
if argPatterns1.forall(_.isTypeTest) then
5259+
MatchTypeCasePattern.TypeTest(pat)
5260+
else
5261+
whenNotTypeTest(argPatterns1)
5262+
end recArgPatterns
5263+
5264+
val result = rec(pat, variance = 1)
5265+
if typeParamRefsAccountedFor == caseLambda.paramNames.size then result
5266+
else null
5267+
end tryConvertToSpecPattern
5268+
end MatchTypeCaseSpec
5269+
50915270
// ------ ClassInfo, Type Bounds --------------------------------------------------
50925271

50935272
type TypeOrSymbol = Type | Symbol

compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
204204
case VarArgsParamCannotBeGivenID // errorNumber: 188
205205
case ExtractorNotFoundID // errorNumber: 189
206206
case PureUnitExpressionID // errorNumber: 190
207+
case MatchTypeLegacyPatternID // errorNumber: 191
207208

208209
def errorNumber = ordinal - 1
209210

compiler/src/dotty/tools/dotc/reporting/messages.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3073,6 +3073,10 @@ class MatchTypeScrutineeCannotBeHigherKinded(tp: Type)(using Context)
30733073
def msg(using Context) = i"the scrutinee of a match type cannot be higher-kinded"
30743074
def explain(using Context) = ""
30753075

3076+
class MatchTypeLegacyPattern(errorText: String)(using Context) extends TypeMsg(MatchTypeLegacyPatternID):
3077+
def msg(using Context) = errorText
3078+
def explain(using Context) = ""
3079+
30763080
class ClosureCannotHaveInternalParameterDependencies(mt: Type)(using Context)
30773081
extends TypeMsg(ClosureCannotHaveInternalParameterDependenciesID):
30783082
def msg(using Context) =

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ object TypeTestsCasts {
153153

154154
case x =>
155155
// always false test warnings are emitted elsewhere
156-
TypeComparer.provablyDisjoint(x, tpe.derivedAppliedType(tycon, targs.map(_ => WildcardType)))
156+
// provablyDisjoint wants fully applied types as input; because we're in the middle of erasure, we sometimes get raw types here
157+
val xApplied =
158+
val tparams = x.typeParams
159+
if tparams.isEmpty then x else x.appliedTo(tparams.map(_ => WildcardType))
160+
TypeComparer.provablyDisjoint(xApplied, tpe.derivedAppliedType(tycon, targs.map(_ => WildcardType)))
157161
|| typeArgsDeterminable(X, tpe)
158162
||| i"its type arguments can't be determined from $X"
159163
}

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,8 @@ trait TypeAssigner {
297297
else fntpe.resultType // fast path optimization
298298
else
299299
errorType(em"wrong number of arguments at ${ctx.phase.prev} for $fntpe: ${fn.tpe}, expected: ${fntpe.paramInfos.length}, found: ${args.length}", tree.srcPos)
300+
case err: ErrorType =>
301+
err
300302
case t =>
301303
if (ctx.settings.Ydebug.value) new FatalError("").printStackTrace()
302304
errorType(err.takesNoParamsMsg(fn, ""), tree.srcPos)
@@ -563,5 +565,3 @@ object TypeAssigner extends TypeAssigner:
563565
def seqLitType(tree: untpd.SeqLiteral, elemType: Type)(using Context) = tree match
564566
case tree: untpd.JavaSeqLiteral => defn.ArrayOf(elemType)
565567
case _ => if ctx.erasedTypes then defn.SeqType else defn.SeqType.appliedTo(elemType)
566-
567-

tests/neg/12800.scala

Lines changed: 0 additions & 21 deletions
This file was deleted.

tests/neg/6314-1.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
object G {
2-
final class X
3-
final class Y
2+
trait X
3+
class Y
4+
class Z
45

56
trait FooSig {
67
type Type
@@ -13,14 +14,14 @@ object G {
1314
type Foo = Foo.Type
1415

1516
type Bar[A] = A match {
16-
case X & Y => String
17+
case X & Z => String
1718
case Y => Int
1819
}
1920

2021
def main(args: Array[String]): Unit = {
2122
val a: Bar[X & Y] = "hello" // error
2223
val i: Bar[Y & Foo] = Foo.apply[Bar](a)
23-
val b: Int = i // error
24+
val b: Int = i
2425
println(b + 1)
2526
}
2627
}

tests/neg/6314-6.check

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
-- Error: tests/neg/6314-6.scala:26:3 ----------------------------------------------------------------------------------
2+
26 | (new YY {}).boom // error: object creation impossible
3+
| ^
4+
|object creation impossible, since def apply(fa: String): Int in trait XX in object Test3 is not defined
5+
|(Note that
6+
| parameter String in def apply(fa: String): Int in trait XX in object Test3 does not match
7+
| parameter Test3.Bar[X & Object with Test3.YY {...}#Foo] in def apply(fa: Test3.Bar[X & YY.this.Foo]): Test3.Bar[Y & YY.this.Foo] in trait YY in object Test3
8+
| )
9+
-- Error: tests/neg/6314-6.scala:52:3 ----------------------------------------------------------------------------------
10+
52 | (new YY {}).boom // error: object creation impossible
11+
| ^
12+
|object creation impossible, since def apply(fa: String): Int in trait XX in object Test4 is not defined
13+
|(Note that
14+
| parameter String in def apply(fa: String): Int in trait XX in object Test4 does not match
15+
| parameter Test4.Bar[X & Object with Test4.YY {...}#FooAlias] in def apply(fa: Test4.Bar[X & YY.this.FooAlias]): Test4.Bar[Y & YY.this.FooAlias] in trait YY in object Test4
16+
| )

0 commit comments

Comments
 (0)