diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c1dd78451bae..af973db89aac 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -346,6 +346,8 @@ object desugar { // Propagate down the expected type to the leafs of the expression case Block(stats, expr) => cpy.Block(tree)(stats, adaptToExpectedTpt(expr)) + case AssumeInfo(sym, info, body) => + cpy.AssumeInfo(tree)(sym, info, adaptToExpectedTpt(body)) case If(cond, thenp, elsep) => cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep)) case untpd.Parens(expr) => @@ -1645,6 +1647,7 @@ object desugar { case Tuple(trees) => (pats corresponds trees)(isIrrefutable) case Parens(rhs1) => matchesTuple(pats, rhs1) case Block(_, rhs1) => matchesTuple(pats, rhs1) + case AssumeInfo(_, _, rhs1) => matchesTuple(pats, rhs1) case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep) case Match(_, cases) => cases forall (matchesTuple(pats, _)) case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index c2147b6af2d3..b38b4c67a1a5 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -330,6 +330,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p) case Match(_, cases) => cases forall (c => forallResults(c.body, p)) case Block(_, expr) => forallResults(expr, p) + case AssumeInfo(_, _, body) => forallResults(body, p) case _ => p(tree) } @@ -1088,6 +1089,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => case Typed(expr, _) => unapply(expr) case Inlined(_, Nil, expr) => unapply(expr) case Block(Nil, expr) => unapply(expr) + case AssumeInfo(_, _, body) => unapply(body) case _ => tree.tpe.widenTermRefExpr.dealias.normalized match case ConstantType(Constant(x)) => Some(x) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala index faeafae97f5e..b7a7903dbf0e 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala @@ -135,6 +135,20 @@ class TreeTypeMap( cpy.LambdaTypeTree(tdef)(tparams1, tmap1.transform(body)) case inlined: Inlined => transformInlined(inlined) + case tree: AssumeInfo => + def mapBody(body: Tree) = body match + case tree @ AssumeInfo(_, _, _) => + val tree1 = treeMap(tree) + tree1.withType(mapType(tree1.tpe)) + case _ => body + tree.fold(transform, mapBody) { case (assumeInfo @ AssumeInfo(sym, info, _), body) => + mapType(sym.typeRef) match + case tp: TypeRef if tp eq sym.typeRef => + val sym1 = sym.subst(substFrom, substTo) + val info1 = mapType(info) + cpy.AssumeInfo(assumeInfo)(sym = sym1, info = info1, body = body) + case _ => body // if the AssumeInfo symbol maps (as a type) to another type, we lose the associated info + } case cdef @ CaseDef(pat, guard, rhs) => val tmap = withMappedSyms(patVars(pat)) val pat1 = tmap.transform(pat) diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index c0b5987c3875..750d2ed54c28 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -567,6 +567,20 @@ object Trees { override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary } + case class AssumeInfo[+T <: Untyped] private[ast] (sym: Symbol, info: Type, body: Tree[T])(implicit @constructorOnly src: SourceFile) + extends ProxyTree[T] { + type ThisTree[+T <: Untyped] <: AssumeInfo[T] + def forwardTo: Tree[T] = body + + def fold[U >: T <: Untyped, A]( + start: Context ?=> Tree[U] => A, mapBody: Tree[U] => Tree[U] = (body: Tree[U]) => body, + )(combine: Context ?=> (AssumeInfo[U], A) => A)(using Context): A = + val body1 = mapBody(body) + inContext(ctx.withAssumeInfo(ctx.assumeInfo.add(sym, info))) { + combine(this, start(body1)) + } + } + /** if cond then thenp else elsep */ case class If[+T <: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile) extends TermTree[T] { @@ -1074,6 +1088,7 @@ object Trees { type NamedArg = Trees.NamedArg[T] type Assign = Trees.Assign[T] type Block = Trees.Block[T] + type AssumeInfo = Trees.AssumeInfo[T] type If = Trees.If[T] type InlineIf = Trees.InlineIf[T] type Closure = Trees.Closure[T] @@ -1212,6 +1227,9 @@ object Trees { case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree))) } + def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo = tree match + case tree: AssumeInfo if (sym eq tree.sym) && (info eq tree.info) && (body eq tree.body) => tree + case _ => finalize(tree, untpd.AssumeInfo(sym, info, body)(sourceFile(tree))) def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match { case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree))) @@ -1344,6 +1362,8 @@ object Trees { // Copier methods with default arguments; these demand that the original tree // is of the same class as the copy. We only include trees with more than 2 elements here. + def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo = + AssumeInfo(tree: Tree)(sym, info, body) def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If = If(tree: Tree)(cond, thenp, elsep) def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure = @@ -1433,6 +1453,10 @@ object Trees { cpy.Closure(tree)(transform(env), transform(meth), transform(tpt)) case Match(selector, cases) => cpy.Match(tree)(transform(selector), transformSub(cases)) + case tree @ AssumeInfo(sym, info, body) => + tree.fold(transform) { (assumeInfo, body) => + cpy.AssumeInfo(assumeInfo)(body = body) + } case CaseDef(pat, guard, body) => cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body)) case Labeled(bind, expr) => @@ -1569,6 +1593,8 @@ object Trees { this(this(this(x, env), meth), tpt) case Match(selector, cases) => this(this(x, selector), cases) + case tree @ AssumeInfo(sym, info, body) => + tree.fold(this(x, _))((_, x) => x) case CaseDef(pat, guard, body) => this(this(this(x, pat), guard), body) case Labeled(bind, expr) => diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index d1b1cdf607b5..5fedf740fc57 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -98,6 +98,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { Block(stats, expr) } + def AssumeInfo(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo = + ta.assignType(untpd.AssumeInfo(sym, info, body), body) + def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep) @@ -683,6 +686,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + override def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo = + val tree1 = untpdCpy.AssumeInfo(tree)(sym, info, body) + tree match + case tree: AssumeInfo if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe) + case _ => ta.assignType(tree1, body) + override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = { val tree1 = untpdCpy.If(tree)(cond, thenp, elsep) tree match { @@ -767,6 +776,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + override def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo = + AssumeInfo(tree: Tree)(sym, info, body) override def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If = If(tree: Tree)(cond, thenp, elsep) override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure = diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index a262c3658399..dd3521fb2dbe 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -388,6 +388,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg) def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs) def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr) + def AssumeInfo(sym: Symbol, info: Type, body: Tree)(implicit src: SourceFile): AssumeInfo = new AssumeInfo(sym, info, body) def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep) def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep) def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt) diff --git a/compiler/src/dotty/tools/dotc/core/AssumeInfoMap.scala b/compiler/src/dotty/tools/dotc/core/AssumeInfoMap.scala new file mode 100644 index 000000000000..5f4364a6be12 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/core/AssumeInfoMap.scala @@ -0,0 +1,28 @@ +package dotty.tools +package dotc +package core + +import Contexts.*, Decorators.*, NameKinds.*, Symbols.*, Types.* +import ast.*, Trees.* +import printing.*, Texts.* + +import scala.annotation.internal.sharable +import util.{SimpleIdentitySet, SimpleIdentityMap} + +object AssumeInfoMap: + @sharable val empty: AssumeInfoMap = AssumeInfoMap(SimpleIdentityMap.empty) + +class AssumeInfoMap private ( + private val map: SimpleIdentityMap[Symbol, Type], +) extends Showable: + def info(sym: Symbol)(using Context): Type | Null = map(sym) + + def add(sym: Symbol, info: Type) = new AssumeInfoMap(map.updated(sym, info)) + + override def toText(p: Printer): Text = + given Context = p match + case p: PlainPrinter => p.printerContext + case _ => Contexts.NoContext + val deps = for (sym, info) <- map.toList yield + (p.toText(sym.typeRef) ~ p.toText(info)).close + ("AssumeInfo(" ~ Text(deps, ", ") ~ ")").close diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 2f28975dd066..f28200dd1c8e 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -143,6 +143,7 @@ object Contexts { def typerState: TyperState def gadt: GadtConstraint = gadtState.gadt def gadtState: GadtState + def assumeInfo: AssumeInfoMap def searchHistory: SearchHistory def source: SourceFile @@ -470,6 +471,15 @@ object Contexts { case None => fresh.dropProperty(key) } + final def withGadt(gadt: GadtConstraint): Context = + if this.gadt eq gadt then this else fresh.setGadtState(GadtState(gadt)) + + final def withGadtState(gadt: GadtState): Context = + if this.gadtState eq gadt then this else fresh.setGadtState(gadt) + + final def withAssumeInfo(assumeInfo: AssumeInfoMap): Context = + if this.assumeInfo eq assumeInfo then this else fresh.setAssumeInfo(assumeInfo) + def typer: Typer = this.typeAssigner match { case typer: Typer => typer case _ => new Typer @@ -545,6 +555,9 @@ object Contexts { private var _gadtState: GadtState = uninitialized final def gadtState: GadtState = _gadtState + private var _assumeInfo: AssumeInfoMap = uninitialized + final def assumeInfo: AssumeInfoMap = _assumeInfo + private var _searchHistory: SearchHistory = uninitialized final def searchHistory: SearchHistory = _searchHistory @@ -569,6 +582,7 @@ object Contexts { _tree = origin.tree _scope = origin.scope _gadtState = origin.gadtState + _assumeInfo = origin.assumeInfo _searchHistory = origin.searchHistory _source = origin.source _moreProperties = origin.moreProperties @@ -632,6 +646,10 @@ object Contexts { def setFreshGADTBounds: this.type = setGadtState(gadtState.fresh) + def setAssumeInfo(assumeInfo: AssumeInfoMap): this.type = + this._assumeInfo= assumeInfo + this + def setSearchHistory(searchHistory: SearchHistory): this.type = util.Stats.record("Context.setSearchHistory") this._searchHistory = searchHistory @@ -723,6 +741,7 @@ object Contexts { .updated(compilationUnitLoc, NoCompilationUnit) c._searchHistory = new SearchRoot c._gadtState = GadtState(GadtConstraint.empty) + c._assumeInfo = AssumeInfoMap.empty c end FreshContext diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index bb65cce84042..2eff7d902fa7 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -4,6 +4,7 @@ package core import Contexts.*, Decorators.*, Symbols.*, Types.* import NameKinds.UniqueName +import ast.*, Trees.* import config.Printers.{gadts, gadtsConstr} import util.{SimpleIdentitySet, SimpleIdentityMap} import printing._ @@ -27,6 +28,7 @@ class GadtConstraint private ( def symbols: List[Symbol] = mapping.keys def withConstraint(c: Constraint) = copy(myConstraint = c) def withWasConstrained = copy(wasConstrained = true) + def isEmpty: Boolean = mapping.isEmpty def add(sym: Symbol, tv: TypeVar): GadtConstraint = copy( mapping = mapping.updated(sym, tv), @@ -136,6 +138,13 @@ class GadtConstraint private ( override def toText(printer: Printer): Texts.Text = printer.toText(this) + def eql(that: GadtConstraint): Boolean = (this eq that) || { + myConstraint == that.myConstraint + && mapping == that.mapping + && reverseMapping == that.reverseMapping + && wasConstrained == that.wasConstrained + } + /** Provides more information than toText, by showing the underlying Constraint details. */ def debugBoundsDescription(using Context): String = i"$this\n$constraint" @@ -201,7 +210,7 @@ sealed trait GadtState { ) val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) => - val tv = TypeVar(paramRef, creatorState = null) + val tv = TypeVar(paramRef, creatorState = null, ctx.nestingLevel) gadt = gadt.add(sym, tv) tv } @@ -277,6 +286,8 @@ sealed trait GadtState { override def fullLowerBound(param: TypeParamRef)(using Context): Type = gadt.fullLowerBound(param) override def fullUpperBound(param: TypeParamRef)(using Context): Type = gadt.fullUpperBound(param) + def symbols: List[Symbol] = gadt.symbols + // ---- Debug ------------------------------------------------------------ override def constr = gadtsConstr diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 465978d329e6..3ac028566123 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -44,7 +44,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling myContext = c state = c.typerState monitored = false - GADTused = false recCount = 0 needsGc = false if Config.checkTypeComparerReset then checkReset() @@ -57,9 +56,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private var canCompareAtoms: Boolean = true // used for internal consistency checking - /** Indicates whether the subtype check used GADT bounds */ - private var GADTused: Boolean = false - private var myInstance: TypeComparer = this def currentInstance: TypeComparer = myInstance @@ -100,22 +96,15 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling assert(leftRoot == null) assert(frozenGadt == false) - /** Record that GADT bounds of `sym` were used in a subtype check. - * But exclude constructor type parameters, as these are aliased - * to the corresponding class parameters, which does not constitute - * a true usage of a GADT symbol. - */ - private def GADTusage(sym: Symbol): true = recordGadtUsageIf(!sym.owner.isConstructor) - - private def recordGadtUsageIf(cond: Boolean): true = { - if cond then - GADTused = true - true - } - private def isBottom(tp: Type) = tp.widen.isRef(NothingClass) - protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) + protected def gadtBounds(sym: Symbol)(using Context): TypeBounds | Null = + val bounds = ctx.gadt.bounds(sym) + if bounds == null then + val info = ctx.assumeInfo.info(sym) + if info == null then null else info.bounds + else bounds + protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadtState.addBound(sym, b, isUpper) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying @@ -139,12 +128,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling try topLevelSubType(tp1, tp2) finally myNecessaryConstraintsOnly = saved - def testSubType(tp1: Type, tp2: Type): CompareResult = - GADTused = false - if !topLevelSubType(tp1, tp2) then CompareResult.Fail - else if GADTused then CompareResult.OKwithGADTUsed - else CompareResult.OK - /** The current approximation state. See `ApproxState`. */ private var approx: ApproxState = ApproxState.Fresh protected def approxState: ApproxState = approx @@ -542,7 +525,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp1: MatchType => val reduced = tp1.reduced if reduced.exists then - recur(reduced, tp2) && recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) } + recur(reduced, tp2) else thirdTry case _: FlexType => true @@ -562,10 +545,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling // one of which is constrained to be a subtype of another. // We do not need similar code in fourthTry, since we only need to care about // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + ctx.gadt.isLess(tp1.symbol, tp2.symbol) case _ => false || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) - && (isBottom(tp1) || GADTusage(tp2.symbol)) isSubApproxHi(tp1, info2.lo.boxedIfTypeParam(tp2.symbol)) && (trustBounds || isSubApproxHi(tp1, info2.hi)) || compareGADT @@ -784,7 +766,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp2: MatchType => val reduced = tp2.reduced if reduced.exists then - recur(tp1, reduced) && recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp2) } + recur(tp1, reduced) else fourthTry case tp2: MethodType => @@ -871,7 +853,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling && (!caseLambda.exists || canWidenAbstract || tp1.widen.underlyingClassRef(refinementOK = true).exists) then isSubType(base, tp2, if (tp1.isRef(cls2)) approx else approx.addLow) - && recordGadtUsageIf { MatchType.thatReducesUsingGadt(tp1) } || base.isInstanceOf[OrType] && fourthTry // if base is a disjunction, this might have come from a tp1 type that // expands to a match type. In this case, we should try to reduce the type @@ -887,7 +868,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.symbol.onGadtBounds(gbounds1 => isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) - && (tp2.isAny || GADTusage(tp1.symbol)) (!caseLambda.exists || canWidenAbstract) && isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow)) @@ -1152,6 +1132,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val tparams = tycon2.typeParams if (tparams.isEmpty) return false // can happen for ill-typed programs, e.g. neg/tcpoly_overloaded.scala + def compareParamRef2(param2: TypeParamRef): Boolean = + isMatchingApply(tp1) + || canConstrain(param2) && canInstantiate(param2) + || compareLower(bounds(param2), tyconIsTypeRef = false) + /** True if `tp1` and `tp2` have compatible type constructors and their * corresponding arguments are subtypes relative to their variance (see `isSubArgs`). */ @@ -1227,7 +1212,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling && ctx.gadt.contains(tycon2sym) && ctx.gadt.isLess(tycon1sym, tycon2sym) - val res = ( + ( tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix) || tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2)) || tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo)) @@ -1254,7 +1239,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling isSubArgs(args1, args2, tp1, tparams) } } - res && recordGadtUsageIf(touchedGADTs) case _ => false } @@ -1313,27 +1297,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling */ def compareLower(tycon2bounds: TypeBounds, tyconIsTypeRef: Boolean): Boolean = if ((tycon2bounds.lo `eq` tycon2bounds.hi) && !tycon2bounds.isInstanceOf[MatchAlias]) - if (tyconIsTypeRef) recur(tp1, tp2.superTypeNormalized) && recordGadtUsageIf(MatchType.thatReducesUsingGadt(tp2)) + if (tyconIsTypeRef) recur(tp1, tp2.superTypeNormalized) else isSubApproxHi(tp1, tycon2bounds.lo.applyIfParameterized(args2)) else fallback(tycon2bounds.lo) def byGadtBounds: Boolean = - { - tycon2 match - case tycon2: TypeRef => - val tycon2sym = tycon2.symbol - tycon2sym.onGadtBounds { bounds2 => - inFrozenGadt { compareLower(bounds2, tyconIsTypeRef = false) } - } - case _ => false - } && recordGadtUsageIf(true) + tycon2 match + case tycon2: TypeRef => + val tycon2sym = tycon2.symbol + tycon2sym.onGadtBounds { bounds2 => + inFrozenGadt { compareLower(bounds2, tyconIsTypeRef = false) } + } + case _ => false tycon2 match { case param2: TypeParamRef => - isMatchingApply(tp1) || - canConstrain(param2) && canInstantiate(param2) || - compareLower(bounds(param2), tyconIsTypeRef = false) + compareParamRef2(param2) case tycon2: TypeRef => isMatchingApply(tp1) || byGadtBounds || @@ -1354,7 +1334,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling if tv.isInstantiated then recur(tp1, tp2.superType) else - compareAppliedType2(tp2, tv.origin, args2) + compareParamRef2(tv.origin) case tycon2: AnnotatedType if !tycon2.isRefining => recur(tp1, tp2.superType) case tycon2: AppliedType => @@ -1383,12 +1363,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def byGadtBounds: Boolean = sym.onGadtBounds { bounds1 => inFrozenGadt { isSubType(bounds1.hi.applyIfParameterized(args1), tp2, approx.addLow) } - } && recordGadtUsageIf(true) - + } !sym.isClass && { defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) || - { recur(tp1.superTypeNormalized, tp2) && recordGadtUsageIf(MatchType.thatReducesUsingGadt(tp1)) } || + recur(tp1.superTypeNormalized, tp2) || tryLiftedToThis1 } || byGadtBounds case tycon1: TypeProxy => @@ -2887,9 +2866,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling object TypeComparer { - enum CompareResult: - case OK, Fail, OKwithGADTUsed - /** Class for unification variables used in `natValue`. */ private class AnyConstantType extends UncachedGroundType with ValueType { var tpe: Type = NoType @@ -2951,9 +2927,6 @@ object TypeComparer { def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = comparing(_.isSubTypeWhenFrozen(tp1, tp2)) - def testSubType(tp1: Type, tp2: Type)(using Context): CompareResult = - comparing(_.testSubType(tp1, tp2)) - def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = comparing(_.isSameTypeWhenFrozen(tp1, tp2)) diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index fe0fc8a6dc2d..b4bac4a5b069 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4923,13 +4923,6 @@ object Types { myReduced.nn } - /** True if the reduction uses GADT constraints. */ - def reducesUsingGadt(using Context): Boolean = - (reductionContext ne null) && reductionContext.keysIterator.exists { - case tp: TypeRef => reductionContext(tp).exists - case _ => false - } - override def computeHash(bs: Binders): Int = doHash(bs, scrutinee, bound :: cases) override def eql(that: Type): Boolean = that match { @@ -4945,11 +4938,6 @@ object Types { def apply(bound: Type, scrutinee: Type, cases: List[Type])(using Context): MatchType = unique(new CachedMatchType(bound, scrutinee, cases)) - def thatReducesUsingGadt(tp: Type)(using Context): Boolean = tp match - case MatchType.InDisguise(mt) => mt.reducesUsingGadt - case mt: MatchType => mt.reducesUsingGadt - case _ => false - /** Extractor for match types hidden behind an AppliedType/MatchAlias. */ object InDisguise: def unapply(tp: AppliedType)(using Context): Option[MatchType] = tp match diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala b/compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala index 5876b69edfde..3dfeb550f58b 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala @@ -109,6 +109,7 @@ class TastyPrinter(bytes: Array[Byte]) { val length = treeStr("%5d".format(index(currentAddr) - index(startAddr))) sb.append(s"\n $length:" + " " * indent) } + def printInt() = sb.append(treeStr(" " + readInt())) def printNat() = sb.append(treeStr(" " + readNat())) def printName() = { val idx = readNat() @@ -139,6 +140,8 @@ class TastyPrinter(bytes: Array[Byte]) { printTrees() case PARAMtype => printNat(); printNat() + case ASSUMEINFO => + until(end) { printNat(); printInt(); printTree(); printTree() } case _ => printTrees() } diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala index 8a396921f32b..8c761c315efa 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala @@ -488,6 +488,13 @@ class TreePickler(pickler: TastyPickler) { writeByte(BLOCK) stats.foreach(preRegister) withLength { pickleTree(expr); stats.foreach(pickleTree) } + case AssumeInfo(sym, info, body) => + writeByte(ASSUMEINFO) + withLength { + pickleSymRef(sym) + pickleType(info) + pickleTree(body) + } case tree @ If(cond, thenp, elsep) => writeByte(IF) withLength { diff --git a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala index 9078a8959112..90f1494f5611 100644 --- a/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala +++ b/compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala @@ -1303,6 +1303,12 @@ class TreeUnpickler(reader: TastyReader, skipTree() readStats(ctx.owner, end, (stats, ctx) => Block(stats, exprReader.readTerm()(using ctx))) + case ASSUMEINFO => + val sym = readSymRef() + val info = readType() + inContext(ctx.withAssumeInfo(ctx.assumeInfo.add(sym, info))) { + AssumeInfo(sym, info, readTerm()) + } case INLINED => val exprReader = fork skipTree() diff --git a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala index 872dc7793ff4..d794af9d18ed 100644 --- a/compiler/src/dotty/tools/dotc/inlines/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/inlines/Inliner.scala @@ -856,9 +856,10 @@ class Inliner(val call: tpd.Tree)(using Context): case Some((caseBindings, rhs0)) => // drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match) // note that any actually necessary casts will be reinserted by the typing pass below - val rhs1 = rhs0 match { - case Block(stats, t) if t.span.isSynthetic => - t match { + def dropCasts(tree: Tree): Tree = tree match { + case tree @ AssumeInfo(_, _, body) => cpy.AssumeInfo(tree)(body = dropCasts(body)) + case Block(stats, tree) if tree.span.isSynthetic => + tree match { case Typed(expr, _) => Block(stats, expr) case TypeApply(sel@Select(expr, _), _) if sel.symbol.isTypeCast => @@ -866,8 +867,9 @@ class Inliner(val call: tpd.Tree)(using Context): case _ => rhs0 } - case _ => rhs0 + case _ => tree } + val rhs1 = dropCasts(rhs0) val (usedBindings, rhs2) = dropUnusedDefs(caseBindings, rhs1) val rhs = seq(usedBindings, rhs2) inlining.println(i"""--- reduce: @@ -943,6 +945,11 @@ class Inliner(val call: tpd.Tree)(using Context): case ident: Ident if ident.isType && typeBindingsSet.contains(ident.symbol) => val TypeAlias(r) = ident.symbol.info: @unchecked TypeTree(r).withSpan(ident.span) + case tree @ AssumeInfo(_, _, _) => + def rec(tree: Tree): Tree = tree match + case AssumeInfo(sym, _, body) if typeBindingsSet.contains(sym) => rec(body) + case _ => tree + rec(tree) case tree => tree } ) diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index ee0062f77dcd..f8867dd348f4 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -297,11 +297,15 @@ class PlainPrinter(_ctx: Context) extends Printer { protected def paramsText(lam: LambdaType): Text = { val erasedParams = lam.erasedParams def paramText(name: Name, tp: Type, erased: Boolean) = - keywordText("erased ").provided(erased) ~ toText(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true) + keywordText("erased ").provided(erased) ~ ParamRefNameString(name) ~ lambdaHash(lam) ~ toTextRHS(tp, isParameter = true) Text(lam.paramNames.lazyZip(lam.paramInfos).lazyZip(erasedParams).map(paramText), ", ") } - protected def ParamRefNameString(name: Name): String = nameString(name) + protected def ParamRefNameString(name: Name): String = + val name1 = name match + case name: TermName if homogenizedView && name.info.kind == NameKinds.DepParamName => name.underlying + case name => name + nameString(name1) protected def ParamRefNameString(param: ParamRef): String = ParamRefNameString(param.binder.paramNames(param.paramNum)) @@ -409,8 +413,8 @@ class PlainPrinter(_ctx: Context) extends Printer { val names = if lam.isDeclaredVarianceLambda then lam.paramNames.lazyZip(lam.declaredVariances).map((name, v) => - varianceSign(v) + name) - else lam.paramNames.map(_.toString) + varianceSign(v) ~ toText(name)) + else lam.paramNames.map(toText) val infos = lam.paramInfos.map(toText) val tparams = names.zip(infos).map(_ ~ _) ("[" ~ Text(tparams, ",") ~ "]", lam.resType) @@ -428,7 +432,7 @@ class PlainPrinter(_ctx: Context) extends Printer { /** String representation of a definition's type following its name */ protected def toTextRHS(tp: Type, isParameter: Boolean = false): Text = controlled { - homogenize(tp) match { + homogenizeArg(tp) match { case tp: TypeBounds => val (tparamStr, rhs) = decomposeLambdas(tp) val binder = rhs match @@ -662,7 +666,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case _ => "{...}" s"import $exprStr.$selectorStr" - def toText(c: OrderingConstraint): Text = + def toText(c: Constraint): Text = val savedConstraint = ctx.typerState.constraint try // The current TyperState constraint determines how type variables are printed diff --git a/compiler/src/dotty/tools/dotc/printing/Printer.scala b/compiler/src/dotty/tools/dotc/printing/Printer.scala index 326630844dde..076eb1af463e 100644 --- a/compiler/src/dotty/tools/dotc/printing/Printer.scala +++ b/compiler/src/dotty/tools/dotc/printing/Printer.scala @@ -160,7 +160,7 @@ abstract class Printer { def toText(result: ImportInfo): Text /** Textual representation of a constraint */ - def toText(c: OrderingConstraint): Text + def toText(c: Constraint): Text /** Textual representation of a GADT constraint */ def toText(c: GadtConstraint): Text diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 014e5ddf0d66..2fdf07737b39 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -478,6 +478,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { changePrec(GlobalPrec) { toTextLocal(lhs) ~ " = " ~ toText(rhs) } case block: Block => blockToText(block) + case AssumeInfo(sym, info, body) => + (typeText(toText(sym.typeRef)) ~ toText(info)).close ~ " ~ " ~ toText(body) case If(cond, thenp, elsep) => val isInline = tree.isInstanceOf[Trees.InlineIf[?]] changePrec(GlobalPrec) { diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index f54baeb7256c..8a8a66ab159e 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -774,6 +774,11 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { h = constantHash(c, h) case n: Name => h = nameHash(n, h) + case sym: Symbol => + h = MurmurHash3.mix(h, apiDefinition(sym, inlineOrigin).hashCode) + case tp: TypeBounds => + h = MurmurHash3.mix(h, apiType(tp.lo).hashCode) + h = MurmurHash3.mix(h, apiType(tp.hi).hashCode) case elem => cannotHash(what = i"`${elem.tryToShow}` of unknown class ${elem.getClass}", elem, tree) h diff --git a/compiler/src/dotty/tools/dotc/transform/Erasure.scala b/compiler/src/dotty/tools/dotc/transform/Erasure.scala index 981dd5f60aea..a3ee1ecb473d 100644 --- a/compiler/src/dotty/tools/dotc/transform/Erasure.scala +++ b/compiler/src/dotty/tools/dotc/transform/Erasure.scala @@ -1059,6 +1059,8 @@ object Erasure { */ override def typedImport(tree: untpd.Import)(using Context) = EmptyTree + override def typedAssumeInfo(tree: untpd.AssumeInfo, pt: Type)(using Context): Tree = super.typed(tree.body, pt) + override def adapt(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree = trace(i"adapting ${tree.showSummary()}: ${tree.tpe} to $pt", show = true) { if ctx.phase != erasurePhase && ctx.phase != erasurePhase.next then diff --git a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala index d4dd911241d3..2d9d05ed80f8 100644 --- a/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala +++ b/compiler/src/dotty/tools/dotc/transform/MegaPhase.scala @@ -321,6 +321,10 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase { val tpt = transformTree(tree.tpt, start) goTyped(cpy.Typed(tree)(expr, tpt), start) } + case tree: AssumeInfo => + tree.fold(transformTree(_, start)) { (assumeInfo, body) => + cpy.AssumeInfo(assumeInfo)(body = body) + } case tree: CaseDef => inContext(prepCaseDef(tree, start)(using outerCtx)) { val pat = withMode(Mode.Pattern)(transformTree(tree.pat, start)) diff --git a/compiler/src/dotty/tools/dotc/transform/Pickler.scala b/compiler/src/dotty/tools/dotc/transform/Pickler.scala index f5fe34bafc2f..01009615b7c3 100644 --- a/compiler/src/dotty/tools/dotc/transform/Pickler.scala +++ b/compiler/src/dotty/tools/dotc/transform/Pickler.scala @@ -83,7 +83,8 @@ class Pickler extends Phase { cls <- dropCompanionModuleClasses(topLevelClasses(unit.tpdTree)) tree <- sliceTopLevel(unit.tpdTree, cls) do - if ctx.settings.YtestPickler.value then beforePickling(cls) = tree.show + if ctx.settings.YtestPickler.value then + beforePickling(cls) = tree.show(using ctx.fresh.setSetting(ctx.settings.YnoDeepSubtypes, false)) val pickler = new TastyPickler(cls) val treePkl = new TreePickler(pickler) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 7f3e47c14732..ce1fe00d5521 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -271,14 +271,6 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase override def transform(tree: Tree)(using Context): Tree = try tree match { - // TODO move CaseDef case lower: keep most probable trees first for performance - case CaseDef(pat, _, _) => - val gadtCtx = - pat.removeAttachment(typer.Typer.InferredGadtConstraints) match - case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt)) - case None => - ctx - super.transform(tree)(using gadtCtx) case tree: Ident => if tree.isType then checkNotPackage(tree) @@ -433,6 +425,19 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase case tree: New if isCheckable(tree) => Checking.checkInstantiable(tree.tpe, tree.tpe, tree.srcPos) super.transform(tree) + case cdef @ CaseDef(pat, guard, body) => + // test case pos/i9833 + val assumeInfo = new TreeAccumulator[AssumeInfoMap] { + def apply(assumeInfo: AssumeInfoMap, tree: Tree)(using Context) = tree match + case AssumeInfo(sym, info, body) => apply(assumeInfo.add(sym, info), body) + case _ => assumeInfo + }.apply(ctx.assumeInfo, body) + inContext(ctx.withAssumeInfo(assumeInfo)) { + val pat1 = transform(pat) + val guard1 = transform(guard) + val body1 = transform(body) + cpy.CaseDef(cdef)(pat1, guard1, body1) + } case tree: Closure if !tree.tpt.isEmpty => Checking.checkRealizable(tree.tpt.tpe, tree.srcPos, "SAM type") super.transform(tree) diff --git a/compiler/src/dotty/tools/dotc/transform/Recheck.scala b/compiler/src/dotty/tools/dotc/transform/Recheck.scala index c524bbb7702f..fd01412863bb 100644 --- a/compiler/src/dotty/tools/dotc/transform/Recheck.scala +++ b/compiler/src/dotty/tools/dotc/transform/Recheck.scala @@ -338,6 +338,9 @@ abstract class Recheck extends Phase, SymTransformer: recheck(tree.guard, defn.BooleanType) recheck(tree.body, pt) + def recheckAssumeInfo(tree: AssumeInfo, pt: Type)(using Context): Type = + tree.fold(recheck(_, pt))((_, tp) => tp) + def recheckReturn(tree: Return)(using Context): Type = // Avoid local pattern defined symbols in returns from matchResult blocks // that are inserted by the PatternMatcher transform. @@ -454,6 +457,7 @@ abstract class Recheck extends Phase, SymTransformer: case tree: If => recheckIf(tree, pt) case tree: Closure => recheckClosure(tree, pt) case tree: Match => recheckMatch(tree, pt) + case tree: AssumeInfo => recheckAssumeInfo(tree, pt) case tree: Return => recheckReturn(tree) case tree: WhileDo => recheckWhileDo(tree) case tree: Try => recheckTry(tree, pt) diff --git a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala index 4548dccb598f..6a7ea2ecbf80 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala @@ -1307,6 +1307,9 @@ object Semantic: eval(stats, thisV, klass) eval(expr, thisV, klass) + case AssumeInfo(_, _, body) => + eval(body, thisV, klass) + case If(cond, thenp, elsep) => eval(cond :: thenp :: elsep :: Nil, thisV, klass).join diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 79d6501ccb2d..7fcea7ef5c7f 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1353,7 +1353,8 @@ trait Applications extends Compatibility { qual match case TypeApply(qual1, targs) => - tryWithTypeArgs(qual1, targs.mapconserve(typedType(_)))((t, ts) => + val gadtCtx = ctx.fresh.setFreshGADTBounds + tryWithTypeArgs(qual1, targs.mapconserve(typedType(_)(using gadtCtx)))((t, ts) => tryWithTypeArgs(qual, Nil)(fallBack)) case _ => tryWithTypeArgs(qual, Nil)(fallBack) diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index 98e9cb638c17..f117c0fa587a 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -389,6 +389,9 @@ trait TypeAssigner { def assignType(tree: untpd.Block, stats: List[Tree], expr: Tree)(using Context): Block = tree.withType(avoidingType(expr, stats)) + def assignType(tree: untpd.AssumeInfo, body: Tree)(using Context): AssumeInfo = + tree.withType(body.tpe) + def assignType(tree: untpd.Inlined, bindings: List[Tree], expansion: Tree)(using Context): Inlined = tree.withType(avoidingType(expansion, bindings)) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 16b256e69059..55e77cdf2eed 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -28,7 +28,6 @@ import Checking._ import Inferencing._ import Dynamic.isDynamicExpansion import EtaExpansion.etaExpand -import TypeComparer.CompareResult import inlines.{Inlines, PrepareInlineable} import util.Spans._ import util.common._ @@ -71,9 +70,6 @@ object Typer { if (!tree.isEmpty && !tree.isInstanceOf[untpd.TypedSplice] && ctx.typerState.isGlobalCommittable) assert(tree.span.exists, i"position not set for $tree # ${tree.uniqueId} of ${tree.getClass} in ${tree.source}") - /** An attachment for GADT constraints that were inferred for a pattern. */ - val InferredGadtConstraints = new Property.StickyKey[core.GadtConstraint] - /** An attachment on a Select node with an `apply` field indicating that the `apply` * was inserted by the Typer. */ @@ -1196,17 +1192,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case elsep: untpd.If => isIncomplete(elsep) case _ => false - // Insert a GADT cast if the type of the branch does not conform - // to the type assigned to the whole if tree. - // This happens when the computation of the type of the if tree - // uses GADT constraints. See #15646. - def gadtAdaptBranch(tree: Tree, branchPt: Type): Tree = - TypeComparer.testSubType(tree.tpe.widenExpr, branchPt) match { - case CompareResult.OKwithGADTUsed => - insertGadtCast(tree, tree.tpe.widen, branchPt) - case _ => tree - } - val branchPt = if isIncomplete(tree) then defn.UnitType else pt.dropIfProto val result = @@ -1220,16 +1205,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val elsep0 = typed(tree.elsep, branchPt)(using cond1.nullableContextIf(false)) thenp0 :: elsep0 :: Nil }: @unchecked - val resType = thenp1.tpe | elsep1.tpe - val thenp2 :: elsep2 :: Nil = - (thenp1 :: elsep1 :: Nil) map { t => - // Adapt each branch to ensure that their types conforms to the - // type assigned to the if tree by inserting GADT casts. - gadtAdaptBranch(t, resType) - }: @unchecked - - cpy.If(tree)(cond1, thenp2, elsep2).withType(resType) + cpy.If(tree)(cond1, thenp1, elsep1).withType(resType) def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo) def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo) @@ -1823,15 +1800,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val pat1 = indexPattern(tree).transform(pat) val guard1 = typedExpr(tree.guard, defn.BooleanType) var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList) - if ctx.gadt.isNarrowing then - // Store GADT constraint to later retrieve it (in PostTyper, for now). - // GADT constraints are necessary to correctly check bounds of type app, - // see tests/pos/i12226 and issue #12226. It might be possible that this - // will end up taking too much memory. If it does, we should just limit - // how much GADT constraints we infer - it's always sound to infer less. - pat1.putAttachment(InferredGadtConstraints, ctx.gadt) if (pt1.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds body1 = body1.ensureConforms(pt1)(using originalCtx) + if !ctx.gadt.eql(originalCtx.gadt) then + for sym <- ctx.gadt.symbols.reverseIterator do + val bounds = ctx.gadt.fullBounds(sym).nn + if !bounds.lo.isExactlyNothing || !bounds.hi.isExactlyAny then + body1 = AssumeInfo(sym, bounds, body1) assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1) } @@ -1841,6 +1816,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer using gadtCtx)) } + def typedAssumeInfo(tree: untpd.AssumeInfo, pt: Type)(using Context): Tree = { + tree.fold(typed(_, pt)) { (assumeInfo, body) => + assignType(cpy.AssumeInfo(assumeInfo)(body = body), body) + } + } + def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = { val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind] val expr1 = typed(tree.expr, bind1.symbol.info) @@ -2404,14 +2385,14 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer } // Register GADT constraint for class type parameters from outer to inner class definition. (Useful when nested classes exist.) But do not cross a function definition. - if sym.flags.is(Method) then + if !ctx.isAfterTyper && sym.flags.is(Method) then rhsCtx.setFreshGADTBounds ctx.outer.outersIterator.takeWhile(!_.owner.is(Method)) .filter(ctx => ctx.owner.isClass && ctx.owner.typeParams.nonEmpty) .toList.reverse .foreach(ctx => rhsCtx.gadtState.addToConstraint(ctx.owner.typeParams)) - if tparamss.nonEmpty then + if !ctx.isAfterTyper && tparamss.nonEmpty then rhsCtx.setFreshGADTBounds val tparamSyms = tparamss.flatten.map(_.symbol) if !sym.isConstructor then @@ -3003,6 +2984,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case tree: untpd.Import => typedImport(tree) case tree: untpd.Export => typedExport(tree) case tree: untpd.Match => typedMatch(tree, pt) + case tree: untpd.AssumeInfo => typedAssumeInfo(tree, pt) case tree: untpd.Return => typedReturn(tree) case tree: untpd.WhileDo => typedWhileDo(tree) case tree: untpd.Try => typedTry(tree, pt) @@ -3068,9 +3050,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer || tree.isDef // ... unless tree is a definition then interpolateTypeVars(tree, pt, locked) - val simplified = tree.tpe.simplified - if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743 - tree.overwriteType(simplified) + tree.overwriteType(tree.tpe.simplified(using ctx.withGadt(GadtConstraint.empty))) tree protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { @@ -3893,27 +3873,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer tree.srcPos.startPos) tree } - else TypeComparer.testSubType(tree.tpe.widenExpr, pt) match - case CompareResult.Fail => - wtp match - case wtp: MethodType => missingArgs(wtp) - case _ => - typr.println(i"adapt to subtype ${tree.tpe} !<:< $pt") - //typr.println(TypeComparer.explained(tree.tpe <:< pt)) - adaptToSubType(wtp) - case CompareResult.OKwithGADTUsed - if pt.isValueType - && !inContext(ctx.fresh.setGadtState(GadtState(GadtConstraint.empty))) { - val res = (tree.tpe.widenExpr frozen_<:< pt) - if res then - // we overshot; a cast is not needed, after all. - gadts.println(i"unnecessary GADTused for $tree: ${tree.tpe.widenExpr} vs $pt in ${ctx.source}") - res - } => - insertGadtCast(tree, wtp, pt) - case _ => - //typr.println(i"OK ${tree.tpe}\n${TypeComparer.explained(_.isSubType(tree.tpe, pt))}") // uncomment for unexpected successes - tree + else if !(tree.tpe.widenExpr <:< pt) then + wtp match + case wtp: MethodType => missingArgs(wtp) + case _ => + typr.println(i"adapt to subtype ${tree.tpe} !<:< $pt") + adaptToSubType(wtp) + else tree } // Follow proxies and approximate type paramrefs by their upper bound @@ -4384,37 +4350,4 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer report.error(em"Invalid Scala 2 macro $call", call.srcPos) EmptyTree else typedExpr(call, defn.AnyType) - - /** Insert GADT cast to target type `pt` on the `tree` - * so that -Ycheck in later phases succeeds. - * The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts. - */ - private def insertGadtCast(tree: Tree, wtp: Type, pt: Type)(using Context): Tree = - val target = - if tree.tpe.isSingleton then - // In the target type, when the singleton type is intersected, we also intersect - // the GADT-approximated type of the singleton to avoid the loss of - // information. See #15646. - val gadtApprox = Inferencing.approximateGADT(wtp) - gadts.println(i"gadt approx $wtp ~~~ $gadtApprox") - val conj = - TypeComparer.testSubType(gadtApprox, pt) match { - case CompareResult.OK => - // GADT approximation of the tree type is a subtype of expected type under empty GADT - // constraints, so it is enough to only have the GADT approximation. - AndType(tree.tpe, gadtApprox) - case _ => - // In other cases, we intersect both the approximated type and the expected type. - AndType(AndType(tree.tpe, gadtApprox), pt) - } - if tree.tpe.isStable && !conj.isStable then - // this is needed for -Ycheck. Without the annotation Ycheck will - // skolemize the result type which will lead to different types before - // and after checking. See i11955.scala. - AnnotatedType(conj, Annotation(defn.UncheckedStableAnnot, tree.symbol.span)) - else conj - else pt - gadts.println(i"insert GADT cast from $tree to $target") - tree.cast(target) - end insertGadtCast } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index a575238f7cd4..522f2b2f3b7e 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -784,6 +784,16 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler end extension end BlockMethods + type AssumeInfo = tpd.AssumeInfo + object AssumeInfoTypeTest extends TypeTest[Tree, AssumeInfo]: + def unapply(x: Tree): Option[AssumeInfo & x.type] = x match + case x: (tpd.AssumeInfo & x.type) => Some(x) + case _ => None + object AssumeInfo extends AssumeInfoModule: + def copy(original: Tree)(sym: Symbol, info: TypeRepr, body: Term): AssumeInfo = + tpd.cpy.AssumeInfo(original)(sym, info, body) + def unapply(x: AssumeInfo): (Symbol, TypeRepr, Term) = (x.sym, x.info, x.body) + type Closure = tpd.Closure object ClosureTypeTest extends TypeTest[Tree, Closure]: diff --git a/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala b/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala index c229338ad228..ba45f6a76b7e 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala @@ -99,6 +99,8 @@ object Extractors { this += "Assign(" += lhs += ", " += rhs += ")" case Block(stats, expr) => this += "Block(" ++= stats += ", " += expr += ")" + case AssumeInfo(sym, info, body) => + this += "AssumeInfo(" += sym += ", " += info += ", " += body += ")" case If(cond, thenp, elsep) => this += "If(" += cond += ", " += thenp += ", " += elsep += ")" case Closure(meth, tpt) => @@ -259,6 +261,7 @@ object Extractors { else if x.isDefDef then this += "IsDefDefSymbol(<" += x.fullName += ">)" else if x.isValDef then this += "IsValDefSymbol(<" += x.fullName += ">)" else if x.isTypeDef then this += "IsTypeDefSymbol(<" += x.fullName += ">)" + else if x.isBind then this += "IsBindSymbol(<" += x.fullName += ">)" else { assert(x.isNoSymbol); this += "NoSymbol()" } def visitParamClause(x: ParamClause): this.type = diff --git a/compiler/test-resources/repl/i8548 b/compiler/test-resources/repl/i8548 index 8cb8104ab9d1..76dd2f9c43c3 100644 --- a/compiler/test-resources/repl/i8548 +++ b/compiler/test-resources/repl/i8548 @@ -1,2 +1,2 @@ scala> def foo[F[_],A](f:F[A]):F[A] = f -def foo[F[_$1], A](f: F[A]): F[A] \ No newline at end of file +def foo[F[_], A](f: F[A]): F[A] diff --git a/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala b/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala index a85cc9ad80f9..b1878be6a416 100644 --- a/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala +++ b/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala @@ -102,7 +102,7 @@ class SemanticdbTests: |inspect with: | diff $expect ${expect.resolveSibling("" + expect.getFileName + ".out")} |Or else update all expect files with - | sbt 'scala3-compiler-bootstrapped/test:runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin) + | sbt 'scala3-compiler-bootstrapped/Test/runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin) Files.walk(target).sorted(Comparator.reverseOrder).forEach(Files.delete) if errors.nonEmpty then fail(s"${errors.size} errors in expect test.") diff --git a/library/src/scala/quoted/Quotes.scala b/library/src/scala/quoted/Quotes.scala index edf8aa61b559..dbe432551c35 100644 --- a/library/src/scala/quoted/Quotes.scala +++ b/library/src/scala/quoted/Quotes.scala @@ -1340,6 +1340,14 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => end extension end BlockMethods + type AssumeInfo <: Term + given AssumeInfoTypeTest: TypeTest[Tree, AssumeInfo] + val AssumeInfo: AssumeInfoModule + trait AssumeInfoModule: + this: AssumeInfo.type => + def copy(original: Tree)(sym: Symbol, info: TypeRepr, body: Term): AssumeInfo + def unapply(x: AssumeInfo): (Symbol, TypeRepr, Term) + /** `TypeTest` that allows testing at runtime in a pattern match if a `Tree` is a `Closure` */ given ClosureTypeTest: TypeTest[Tree, Closure] @@ -4745,6 +4753,8 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => foldTree(foldTree(x, lhs)(owner), rhs)(owner) case Block(stats, expr) => foldTree(foldTrees(x, stats)(owner), expr)(owner) + case AssumeInfo(sym, info, body) => + foldTree(x, body)(owner) case If(cond, thenp, elsep) => foldTree(foldTree(foldTree(x, cond)(owner), thenp)(owner), elsep)(owner) case While(cond, body) => @@ -4948,6 +4958,8 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => Assign.copy(tree)(transformTerm(lhs)(owner), transformTerm(rhs)(owner)) case Block(stats, expr) => Block.copy(tree)(transformStats(stats)(owner), transformTerm(expr)(owner)) + case AssumeInfo(sym, info, body) => + AssumeInfo.copy(tree)(sym, info, transformTerm(body)(owner)) case If(cond, thenp, elsep) => If.copy(tree)(transformTerm(cond)(owner), transformTerm(thenp)(owner), transformTerm(elsep)(owner)) case Closure(meth, tpt) => diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index cb15d82affb8..d50a972bdcbe 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -3,6 +3,10 @@ import com.typesafe.tools.mima.core._ object MiMaFilters { val Library: Seq[ProblemFilter] = Seq( + ProblemFilters.exclude[MissingMethodProblem]("scala.quoted.Quotes#reflectModule.AssumeInfo"), + ProblemFilters.exclude[MissingMethodProblem]("scala.quoted.Quotes#reflectModule.AssumeInfoTypeTest"), + ProblemFilters.exclude[MissingClassProblem]("scala.quoted.Quotes$reflectModule$AssumeInfoModule"), + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.caps.unsafeBox"), ProblemFilters.exclude[DirectMissingMethodProblem]("scala.caps.unsafeUnbox"), ProblemFilters.exclude[DirectMissingMethodProblem]("scala.CanEqual.canEqualMap"), @@ -31,9 +35,11 @@ object MiMaFilters { // Added java.io.Serializable as LazyValControlState supertype ProblemFilters.exclude[MissingTypesProblem]("scala.runtime.LazyVals$LazyValControlState"), ProblemFilters.exclude[MissingTypesProblem]("scala.runtime.LazyVals$Waiting"), - ) val TastyCore: Seq[ProblemFilter] = Seq( + // New TASTy tags + ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyFormat.ASSUMEINFO"), + ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyBuffer.reset"), ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyFormat.APPLYsigpoly"), ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyHash.pjwHash64"), diff --git a/tasty/src/dotty/tools/tasty/TastyFormat.scala b/tasty/src/dotty/tools/tasty/TastyFormat.scala index 226fc14acb39..be5c82435fff 100644 --- a/tasty/src/dotty/tools/tasty/TastyFormat.scala +++ b/tasty/src/dotty/tools/tasty/TastyFormat.scala @@ -97,6 +97,7 @@ Standard-Section: "ASTs" TopLevelStat* TYPED Length expr_Term ascriptionType_Term -- expr: ascription ASSIGN Length lhs_Term rhs_Term -- lhs = rhs BLOCK Length expr_Term Stat* -- { stats; expr } + ASSUMEINFO Length sym_ASTRef info_Type body_Term -- Contextual info of a symbol, such as GADT bounds INLINED Length expr_Term call_Term? ValOrDefDef* -- Inlined code from call, with given body `expr` and given bindings LAMBDA Length meth_Term target_Type? -- Closure over method `f` of type `target` (omitted id `target` is a function type) IF Length [INLINE] cond_Term then_Term else_Term -- inline? if cond then thenPart else elsePart @@ -581,6 +582,7 @@ object TastyFormat { // final val ??? = 179 final val METHODtype = 180 final val APPLYsigpoly = 181 + final val ASSUMEINFO = 182 final val MATCHtype = 190 final val MATCHtpt = 191 @@ -754,6 +756,7 @@ object TastyFormat { case NAMEDARG => "NAMEDARG" case ASSIGN => "ASSIGN" case BLOCK => "BLOCK" + case ASSUMEINFO => "ASSUMEINFO" case IF => "IF" case LAMBDA => "LAMBDA" case MATCH => "MATCH" @@ -811,7 +814,7 @@ object TastyFormat { */ def numRefs(tag: Int): Int = tag match { case VALDEF | DEFDEF | TYPEDEF | TYPEPARAM | PARAM | NAMEDARG | RETURN | BIND | - SELFDEF | REFINEDtype | TERMREFin | TYPEREFin | SELECTin | HOLE => 1 + SELFDEF | REFINEDtype | TERMREFin | TYPEREFin | SELECTin | ASSUMEINFO | HOLE => 1 case RENAMED | PARAMtype => 2 case POLYtype | TYPELAMBDAtype | METHODtype => -1 case _ => 0 diff --git a/tests/neg-custom-args/kind-projector-underscores.check b/tests/neg-custom-args/kind-projector-underscores.check index 2a832ae3d7a2..f15caecaad50 100644 --- a/tests/neg-custom-args/kind-projector-underscores.check +++ b/tests/neg-custom-args/kind-projector-underscores.check @@ -21,8 +21,8 @@ -- Error: tests/neg-custom-args/kind-projector-underscores.scala:5:23 -------------------------------------------------- 5 |class Bar1 extends Foo[Either[_, _]] // error | ^^^^^^^^^^^^ - | Type argument Either does not have the same kind as its bound [_$1] + | Type argument Either does not have the same kind as its bound [_] -- Error: tests/neg-custom-args/kind-projector-underscores.scala:6:22 -------------------------------------------------- 6 |class Bar2 extends Foo[_] // error | ^ - | Type argument _ does not have the same kind as its bound [_$1] + | Type argument _ does not have the same kind as its bound [_] diff --git a/tests/neg-custom-args/kind-projector.check b/tests/neg-custom-args/kind-projector.check index f6c258c5c58d..9a1568ec425a 100644 --- a/tests/neg-custom-args/kind-projector.check +++ b/tests/neg-custom-args/kind-projector.check @@ -5,8 +5,8 @@ -- Error: tests/neg-custom-args/kind-projector.scala:5:23 -------------------------------------------------------------- 5 |class Bar1 extends Foo[Either[*, *]] // error | ^^^^^^^^^^^^ - | Type argument Either does not have the same kind as its bound [_$1] + | Type argument Either does not have the same kind as its bound [_] -- Error: tests/neg-custom-args/kind-projector.scala:6:22 -------------------------------------------------------------- 6 |class Bar2 extends Foo[*] // error | ^ - | Type argument _ does not have the same kind as its bound [_$1] + | Type argument _ does not have the same kind as its bound [_] diff --git a/tests/neg/enum-values.check b/tests/neg/enum-values.check index 37990e8f312e..5413dfaca8d0 100644 --- a/tests/neg/enum-values.check +++ b/tests/neg/enum-values.check @@ -39,7 +39,7 @@ | failed with: | | Found: Array[example.Tag[?]] - | Required: Array[example.TypeCtorsK[?[_$1]]] + | Required: Array[example.TypeCtorsK[?[_]]] -- [E008] Not Found Error: tests/neg/enum-values.scala:36:6 ------------------------------------------------------------ 36 | Tag.valueOf("Int") // error | ^^^^^^^^^^^ diff --git a/tests/pos/gadt-expr.http4s.scala b/tests/pos/gadt-expr.http4s.scala new file mode 100644 index 000000000000..3730f2bea316 --- /dev/null +++ b/tests/pos/gadt-expr.http4s.scala @@ -0,0 +1,17 @@ +// A minimisation of a http4s test failure +// that came while implementing GadtExpr +// which was caused by a change in the selection of which type parameter +// to keep during constraint parameter replacement +// and was fixed by reverting the change +// and instead tweaking the order in which gadt bounds are added during unpickling +class Cat[F[_]] +object Cat: + given [G[_]]: Cat[G] = new Cat[G] + +class Dog + +trait Foo[F[_]]: + def bar[G[x] >: F[x]](using cat: Cat[G], dog: Dog = new Dog) = () + +class Test: + def meth[F[_]](foo: Foo[F]) = foo.bar diff --git a/tests/pos/gadt-expr.libretto.scala b/tests/pos/gadt-expr.libretto.scala new file mode 100644 index 000000000000..c3d5d9b1bcc7 --- /dev/null +++ b/tests/pos/gadt-expr.libretto.scala @@ -0,0 +1,20 @@ +// A minimisation of a libretto compilation failure +// that came while implementing GadtExpr +// which induced a type parameter reference pickling leak +// due to how, while constraining the pattern types, +// to assign any type variable constraint or GADT constraints +// TypeComparer unwraps (shallowly) a TypeVar to the origin type parameter ref +// which are the parameter and variable for a pattern bind symbol. +// When the type variable is then instantiated, that doesn't assign in the GADT bound +// so it leaks. +// This was fixed by guarding the TypeVar unwrapping +// to allow the GADT constraints to record +// from 3rd/4th try to 1st/2nd try. +class Rec[F[_]] + +sealed trait Foo[A] +case class Bar[F[_]]() extends Foo[F[Rec[F]]] + +class Test: + def meth[X](foo: Foo[X]) = foo match + case Bar() => diff --git a/tests/pos/gadt-expr.play-json2/Macro.scala b/tests/pos/gadt-expr.play-json2/Macro.scala new file mode 100644 index 000000000000..4e331a4a454d --- /dev/null +++ b/tests/pos/gadt-expr.play-json2/Macro.scala @@ -0,0 +1,23 @@ +final case class JsNumber() + +sealed trait Val[+A] +final case class Box[T](value: T) extends Val[T] + +trait Reads[A]: + def reads(json: JsNumber): Val[A] +object Reads: + given Reads[Int] = { case JsNumber() => Box(0) case null => ??? } + +object Macro: + import scala.compiletime.*, scala.deriving.* + + inline def reads[A](using m: Mirror.ProductOf[A]): Reads[A] = new Reads[A]: + def reads(js: JsNumber) = rec[A, m.MirroredElemLabels, m.MirroredElemTypes](js) + + inline def rec[A, L <: Tuple, T <: Tuple](js: JsNumber)(using m: Mirror.ProductOf[A]): Val[A] = + inline (erasedValue[L], erasedValue[T]) match + case _: (EmptyTuple, EmptyTuple) => ??? + case _: ( l *: ls, t *: ts) => summonInline[Reads[t]].reads(js) match + case Box(x) => rec[A, ls, ts](js) + +final case class Foo(bar: Int, baz: Int) diff --git a/tests/pos/gadt-expr.play-json2/Test.scala b/tests/pos/gadt-expr.play-json2/Test.scala new file mode 100644 index 000000000000..d4cf9e3b1478 --- /dev/null +++ b/tests/pos/gadt-expr.play-json2/Test.scala @@ -0,0 +1,2 @@ +class Test: + def reads = Macro.reads[Foo] diff --git a/tests/pos/gadt-expr.protoquill.scala b/tests/pos/gadt-expr.protoquill.scala new file mode 100644 index 000000000000..68a6ef32f68d --- /dev/null +++ b/tests/pos/gadt-expr.protoquill.scala @@ -0,0 +1,14 @@ +// A minimisation of a protoquill test failure +// that came while implementing GadtExpr +// which was causing a unresolved symbols pickler error. +// This is caused because during typedUnapply +// pattern bound symbols are added to the GADT constraint +// that aren't actually part of the unapply +// so it was fixed by giving that its own fresh GADT constraint +class Foo[A] +class Bar[B]: + def unapply(any: Any): true = true + +class Test: + def test(any: Any) = any match + case Bar[Foo[?]]() => 1 diff --git a/tests/pos/i11050.gadt-expr.scala b/tests/pos/i11050.gadt-expr.scala new file mode 100644 index 000000000000..2a13f8e92524 --- /dev/null +++ b/tests/pos/i11050.gadt-expr.scala @@ -0,0 +1,16 @@ +package pkg + +case class Box(value: Int) + +// Original: tests/run-custom-args/fatal-warnings/i11050.scala +// minimised to fix implementation of GadtExpr +class Test: + def test: String = foo[Box] + + transparent inline def foo[T](using m: deriving.Mirror.Of[T]) = + bar[m.MirroredElemLabels] + + transparent inline def bar[L] = + inline new Tuple2(compiletime.erasedValue[L], 1) match + case _: Tuple2[l *: _, _] => + compiletime.constValue[l].toString diff --git a/tests/pos/i15872.scala b/tests/pos/i15872.scala new file mode 100644 index 000000000000..4d85cbc1e83e --- /dev/null +++ b/tests/pos/i15872.scala @@ -0,0 +1,10 @@ +// From https://github.com/lampepfl/dotty/pull/15872#issuecomment-1218041165 +trait Tag[S] +case class TupTag[A, T <: Tuple]() extends Tag[A *: T] + +def tupleId[T <: Tuple](x: T): x.type = x + +def foo[S](x: S, ev: Tag[S]) = ev.match { + case _: TupTag[a, t] => + val t1: t = tupleId(x.tail) +} diff --git a/tests/pos/i4471-gadt.gadt-expr1.scala b/tests/pos/i4471-gadt.gadt-expr1.scala new file mode 100644 index 000000000000..1759d79ef52f --- /dev/null +++ b/tests/pos/i4471-gadt.gadt-expr1.scala @@ -0,0 +1,10 @@ +sealed trait Foo[A] + +case class Foo1[X]() extends Foo[(X, Int)] +case class Foo2[Y]() extends Foo[(Y, Int)] + +// A minimisation of pos/i4471-gadt to fix its breakage while implementing GadtExpr +class Test: + def test[T](f1: Foo[T], f2: Foo[T]) = (f1, f2) match + case (_: Foo1[x], _: Foo2[y]) => 1 + case _ => 2 diff --git a/tests/pos/i4471-gadt.gadt-expr2.scala b/tests/pos/i4471-gadt.gadt-expr2.scala new file mode 100644 index 000000000000..3079b27f34a1 --- /dev/null +++ b/tests/pos/i4471-gadt.gadt-expr2.scala @@ -0,0 +1,7 @@ +sealed trait Foo[A, B] +case class Id[T]() extends Foo[T, T] + +// A minimisation of pos/i4471-gadt to fix its breakage while implementing GadtExpr +class Test: + def mat[X, Y](scr: Foo[X, Y]) = scr match + case Id() => diff --git a/tests/pos/i4471-gadt.gadt-expr3.scala b/tests/pos/i4471-gadt.gadt-expr3.scala new file mode 100644 index 000000000000..18a77102f81b --- /dev/null +++ b/tests/pos/i4471-gadt.gadt-expr3.scala @@ -0,0 +1,7 @@ +sealed trait Foo[A, B] +case class Id[T]() extends Foo[T, T] + +// A minimisation of pos/i4471-gadt to fix its breakage while implementing GadtExpr +class Test: + def mat[X, Y](scr: Foo[X, Y]) = scr match + case _: Id[t] => diff --git a/tests/pos/i4471-gadt.gadt-expr4.scala b/tests/pos/i4471-gadt.gadt-expr4.scala new file mode 100644 index 000000000000..3811e39fc5fe --- /dev/null +++ b/tests/pos/i4471-gadt.gadt-expr4.scala @@ -0,0 +1,10 @@ +sealed trait Foo[S, T] + +case class Bar[A, B](_1: Foo[A, A]) extends Foo[(A, A), (B, B)] +case class Baz[F, G]() extends Foo[(F, G), (G, F)] + +// A minimisation of pos/i4471-gadt to fix its breakage while implementing GadtExpr +class Test: + def meth[X, Y, Z](x: Foo[X, Y], y: Foo[Y, Z]) = (x, y) match + case (z: Bar[a, b], _: Bar[c, d]) => (z._1 : Foo[a, a]) match + case _: Baz[f, g] => diff --git a/tests/pos/infinite-loop-potential.scala b/tests/pos/infinite-loop-potential.scala deleted file mode 100644 index 89a2bdd2a9a2..000000000000 --- a/tests/pos/infinite-loop-potential.scala +++ /dev/null @@ -1,10 +0,0 @@ -object InfiniteSubtypingLoopPossibility { - trait A[X] - trait B extends A[B] - trait Min[+S <: B with A[S]] - - def c: Any = ??? - c match { - case pc: Min[_] => - } -} diff --git a/tests/pos/lst.gadt-expr.scala b/tests/pos/lst.gadt-expr.scala new file mode 100644 index 000000000000..a30233262901 --- /dev/null +++ b/tests/pos/lst.gadt-expr.scala @@ -0,0 +1,6 @@ +class Lst[+T](val elems: Any) extends AnyVal: + override def equals(that: Any) = that match + case that: Lst[t] => eqLst(that) + case _ => false + + def eqLst[U](that: Lst[U]) = elems.asInstanceOf[AnyRef] eq that.elems.asInstanceOf[AnyRef] diff --git a/tests/run/gadt-expr.play-json/Macro.scala b/tests/run/gadt-expr.play-json/Macro.scala new file mode 100644 index 000000000000..f2356ab4fcc1 --- /dev/null +++ b/tests/run/gadt-expr.play-json/Macro.scala @@ -0,0 +1,40 @@ +sealed trait JsValue +final case class JsNumber(value: Int) extends JsValue +final case class JsObject(underlying: Map[String, JsValue]) extends JsValue + +sealed trait JsResult[+A] +final case class JsSuccess[T](value: T) extends JsResult[T] +case object JsError extends JsResult[Nothing] + +trait Reads[A]: + def reads(json: JsValue): JsResult[A] +object Reads: + given Reads[Int] = { case JsNumber(n) => JsSuccess(n) case _ => JsError } + +object Macro: + import scala.compiletime.*, scala.deriving.* + + inline def reads[A](using m: Mirror.ProductOf[A]): Reads[A] = new Reads[A]: + def reads(js: JsValue) = js match + case obj @ JsObject(_) => + val elems = new Array[Any](constValue[Tuple.Size[m.MirroredElemTypes]]) + rec[A, m.MirroredElemLabels, m.MirroredElemTypes](obj, elems)(0) + case _ => JsError + + inline def rec[A, L <: Tuple, T <: Tuple]( + obj: JsObject, elems: Array[Any])(n: Int)(using m: Mirror.ProductOf[A]): JsResult[A] = + inline (erasedValue[L], erasedValue[T]) match + case _: (EmptyTuple, EmptyTuple) => JsSuccess(m.fromProduct(ArrayProduct(elems))) + case _: (l *: ls, t *: ts) => + val key = inline erasedValue[l] match { case s: String => s } + val reads = summonInline[Reads[t]] + reads.reads(obj.underlying(key)) match + case JsSuccess(x) => + elems(n) = x + rec[A, ls, ts](obj, elems)(n + 1) + case _ => JsError + +final class ArrayProduct[A](elems: Array[A]) extends Product: + def canEqual(that: Any): Boolean = that.isInstanceOf[ArrayProduct[_]] + def productArity: Int = elems.size + def productElement(idx: Int): Any = elems(idx) diff --git a/tests/run/gadt-expr.play-json/Test.scala b/tests/run/gadt-expr.play-json/Test.scala new file mode 100644 index 000000000000..a63401fdd07b --- /dev/null +++ b/tests/run/gadt-expr.play-json/Test.scala @@ -0,0 +1,14 @@ +final case class Foo(bar: Int, baz: Int) + +// A minimisation of a play-json test failure +// that came while implementing GadtExpr +// which was causing an inline match to fall through to the default case +// because a bug in TreeTypeMap wasn't substituting symbols correctly +object Test: + def main(args: Array[String]): Unit = + val json = JsObject(Map("bar" -> JsNumber(1), "baz" -> JsNumber(2))) + val reads = Macro.reads[Foo] + val baz = reads.reads(json) match + case JsSuccess(foo) => foo.baz + case _ => ??? + assert(baz == 2, s"expected 2 but was $baz") diff --git a/tests/run/gadt-expr/Expr.scala b/tests/run/gadt-expr/Expr.scala new file mode 100644 index 000000000000..54067b3c0aca --- /dev/null +++ b/tests/run/gadt-expr/Expr.scala @@ -0,0 +1,3 @@ +sealed trait Expr[A] +final case class IntExpr() extends Expr[Int] +final case class BoolExpr() extends Expr[Boolean] diff --git a/tests/run/gadt-expr/Lib.scala b/tests/run/gadt-expr/Lib.scala new file mode 100644 index 000000000000..f381b1d9c6f0 --- /dev/null +++ b/tests/run/gadt-expr/Lib.scala @@ -0,0 +1,5 @@ +class Lib: + def extract[A](expr: Expr[A]): A = (expr: @unchecked) match + case IntExpr() => + val res: A = 1 + res diff --git a/tests/run/gadt-expr/Test.scala b/tests/run/gadt-expr/Test.scala new file mode 100644 index 000000000000..d4e649dc88ad --- /dev/null +++ b/tests/run/gadt-expr/Test.scala @@ -0,0 +1,8 @@ +// The starting example of the kind of thing that +// GadtExpr should be able to handle without the help +// of GADT cast insertions and re-inferring GADT constraints all the time. +object Test extends Lib: + def main(args: Array[String]): Unit = + val expr: Expr[Int] = IntExpr() + val int: Int = extract(expr) + assert(int + 1 == 2) diff --git a/tests/semanticdb/metac.expect b/tests/semanticdb/metac.expect index f5556e28bd1b..4f31231e8f28 100644 --- a/tests/semanticdb/metac.expect +++ b/tests/semanticdb/metac.expect @@ -922,7 +922,7 @@ Text => empty Language => Scala Symbols => 181 entries Occurrences => 148 entries -Synthetics => 6 entries +Synthetics => 5 entries Symbols: _empty_/Enums. => final object Enums extends Object { self: Enums.type => +30 decls } @@ -1259,7 +1259,6 @@ Occurrences: Synthetics: [52:9..52:13):Refl => *.unapply[Option[B]] -[52:31..52:50):identity[Option[B]] => *[Function1[A, Option[B]]] [54:14..54:18):Some => *.apply[Some[Int]] [54:14..54:34):Some(Some(1)).unwrap => *(given_<:<_T_T[Option[Int]]) [54:19..54:23):Some => *.apply[Int]