Skip to content

Commit fe8d39a

Browse files
committed
Instantiate more type variables to hard unions
Fixes #14770
1 parent 1724d84 commit fe8d39a

File tree

8 files changed

+92
-32
lines changed

8 files changed

+92
-32
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import config.Printers.typr
1212
import typer.ProtoTypes.{newTypeVar, representedParamRef}
1313
import UnificationDirection.*
1414
import NameKinds.AvoidNameKind
15+
import NullOpsDecorator.stripNull
1516

1617
/** Methods for adding constraints and solving them.
1718
*
@@ -525,10 +526,12 @@ trait ConstraintHandling {
525526
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
526527
* as those could leak the annotation to users (see run/inferred-repeated-result).
527528
*/
528-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
529+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
529530
def widenOr(tp: Type) =
530-
val tpw = tp.widenUnion
531-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
531+
if widenUnions then
532+
val tpw = tp.widenUnion
533+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
534+
else tp.hardenUnions
532535

533536
def widenSingle(tp: Type) =
534537
val tpw = tp.widenSingletons
@@ -548,24 +551,40 @@ trait ConstraintHandling {
548551
wideInst.dropRepeatedAnnot
549552
end widenInferred
550553

554+
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
555+
case tp: AndType =>
556+
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
557+
case tp: RefinedType =>
558+
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
559+
case tp: RecType =>
560+
tp.rebind(tp.parent.hardenUnions)
561+
case tp: HKTypeLambda =>
562+
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
563+
case tp: OrType =>
564+
val tp1 = tp.stripNull
565+
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
566+
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
567+
case _ =>
568+
tp
569+
551570
/** The instance type of `param` in the current constraint (which contains `param`).
552571
* If `fromBelow` is true, the instance type is the lub of the parameter's
553572
* lower bounds; otherwise it is the glb of its upper bounds. However,
554573
* a lower bound instantiation can be a singleton type only if the upper bound
555574
* is also a singleton type.
556575
*/
557-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
576+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type = {
558577
val approx = approximation(param, fromBelow).simplified
559578
if fromBelow then
560-
val widened = widenInferred(approx, param)
579+
val widened = widenInferred(approx, param, widenUnions)
561580
// Widening can add extra constraints, in particular the widened type might
562581
// be a type variable which is now instantiated to `param`, and therefore
563582
// cannot be used as an instantiation of `param` without creating a loop.
564583
// If that happens, we run `instanceType` again to find a new instantation.
565584
// (we do not check for non-toplevel occurences: those should never occur
566585
// since `addOneBound` disallows recursive lower bounds).
567586
if constraint.occursAtToplevel(param, widened) then
568-
instanceType(param, fromBelow)
587+
instanceType(param, fromBelow, widenUnions)
569588
else
570589
widened
571590
else

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -492,23 +492,35 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
492492
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
493493
case _ => true
494494

495-
widenOK
496-
|| joinOK
497-
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
498-
|| containsAnd(tp1)
499-
&& !joined
500-
&& {
501-
joined = true
502-
try inFrozenGadt(recur(tp1.join, tp2))
503-
finally joined = false
504-
}
505-
// An & on the left side loses information. We compensate by also trying the join.
506-
// This is less ad-hoc than it looks since we produce joins in type inference,
507-
// and then need to check that they are indeed supertypes of the original types
508-
// under -Ycheck. Test case is i7965.scala.
509-
// On the other hand, we could get a combinatorial explosion by applying such joins
510-
// recursively, so we do it only once. See i14870.scala as a test case, which would
511-
// loop for a very long time without the recursion brake.
495+
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
496+
case tvar: TypeVar if constraint.contains(tvar.origin) =>
497+
tvar.widenUnions = false
498+
case tp2: TypeParamRef if constraint.contains(tp2) =>
499+
hardenTypeVars(constraint.typeVarOfParam(tp2))
500+
case tp2: AndOrType =>
501+
hardenTypeVars(tp2.tp1)
502+
hardenTypeVars(tp2.tp2)
503+
case _ =>
504+
505+
val res = widenOK
506+
|| joinOK
507+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
508+
|| containsAnd(tp1)
509+
&& !joined
510+
&& {
511+
joined = true
512+
try inFrozenGadt(recur(tp1.join, tp2))
513+
finally joined = false
514+
}
515+
// An & on the left side loses information. We compensate by also trying the join.
516+
// This is less ad-hoc than it looks since we produce joins in type inference,
517+
// and then need to check that they are indeed supertypes of the original types
518+
// under -Ycheck. Test case is i7965.scala.
519+
// On the other hand, we could get a combinatorial explosion by applying such joins
520+
// recursively, so we do it only once. See i14870.scala as a test case, which would
521+
// loop for a very long time without the recursion brake.
522+
if res then hardenTypeVars(tp2)
523+
res
512524

513525
case tp1: MatchType =>
514526
val reduced = tp1.reduced
@@ -2851,8 +2863,8 @@ object TypeComparer {
28512863
def subtypeCheckInProgress(using Context): Boolean =
28522864
comparing(_.subtypeCheckInProgress)
28532865

2854-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
2855-
comparing(_.instanceType(param, fromBelow))
2866+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type =
2867+
comparing(_.instanceType(param, fromBelow, widenUnions))
28562868

28572869
def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
28582870
comparing(_.approximation(param, fromBelow))
@@ -2872,8 +2884,8 @@ object TypeComparer {
28722884
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
28732885
comparing(_.addToConstraint(tl, tvars))
28742886

2875-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
2876-
comparing(_.widenInferred(inst, bound))
2887+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
2888+
comparing(_.widenInferred(inst, bound, widenUnions))
28772889

28782890
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
28792891
comparing(_.dropTransparentTraits(tp, bound))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,9 @@ object TypeOps:
517517
override def apply(tp: Type): Type = tp match
518518
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
519519
val lo = TypeComparer.instanceType(
520-
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
520+
tp.origin,
521+
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
522+
widenUnions = tp.widenUnions)(using mapCtx)
521523
val lo1 = apply(lo)
522524
if (lo1 ne lo) lo1 else tp
523525
case _ =>

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4507,6 +4507,8 @@ object Types {
45074507
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState | Null, val nestingLevel: Int) extends CachedProxyType with ValueType {
45084508
private var currentOrigin = initOrigin
45094509

4510+
var widenUnions = true
4511+
45104512
def origin: TypeParamRef = currentOrigin
45114513

45124514
/** Set origin to new parameter. Called if we merge two conflicting constraints.
@@ -4569,7 +4571,7 @@ object Types {
45694571
* is also a singleton type.
45704572
*/
45714573
def instantiate(fromBelow: Boolean)(using Context): Type =
4572-
val tp = TypeComparer.instanceType(origin, fromBelow)
4574+
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions)
45734575
if myInst.exists then // The line above might have triggered instantiation of the current type variable
45744576
myInst
45754577
else

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ class Namer { typer: Typer =>
18841884
TypeOps.simplify(tp.widenTermRefExpr,
18851885
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
18861886
case ctp: ConstantType if sym.isInlineVal => ctp
1887-
case tp => TypeComparer.widenInferred(tp, pt)
1887+
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)
18881888

18891889
// Replace aliases to Unit by Unit itself. If we leave the alias in
18901890
// it would be erased to BoxedUnit.

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
489489
val tparams = poly.paramRefs
490490
val variances = childClass.typeParams.map(_.paramVarianceSign)
491491
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
492-
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
492+
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true))
493493
resType.substParams(poly, instanceTypes)
494494
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
495495
case _ =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2808,7 +2808,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28082808
if (ctx.mode.is(Mode.Pattern)) app1
28092809
else {
28102810
val elemTpes = elems.lazyZip(pts).map((elem, pt) =>
2811-
TypeComparer.widenInferred(elem.tpe, pt))
2811+
TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true))
28122812
val resTpe = TypeOps.nestedPairs(elemTpes)
28132813
app1.cast(resTpe)
28142814
}

tests/pos/i14770.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
type UndefOr[A] = A | Unit
2+
3+
extension [A](maybe: UndefOr[A])
4+
def foreach(f: A => Unit): Unit =
5+
maybe match
6+
case () => ()
7+
case a: A => f(a)
8+
9+
trait Foo
10+
trait Bar
11+
12+
object Baz:
13+
var booBap: Foo | Bar = _
14+
15+
def z: UndefOr[Foo | Bar] = ???
16+
17+
@main
18+
def main =
19+
z.foreach(x => Baz.booBap = x)
20+
21+
def test[A](v: A | Unit): A | Unit = v
22+
val x1 = test(5: Int | Unit)
23+
val x2 = test(5: String | Int | Unit)
24+
val _: Int | Unit = x1
25+
val _: String | Int | Unit = x2

0 commit comments

Comments
 (0)