Skip to content

Commit 5d28c9f

Browse files
oderskymichelou
authored andcommitted
Use all available context info for healing ambiguous implicits
When retrying after an ambiguous implicit, we now make use of all the information in the prototype, including ignored parts. We also try to match formal parameters with actually given arguments. Fixes scala#11243 Fixes scala#5773, which previously was closed with a more detailed error message.
1 parent 1bcca9e commit 5d28c9f

File tree

6 files changed

+196
-26
lines changed

6 files changed

+196
-26
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,11 @@ object Types {
16521652
*/
16531653
def deepenProto(using Context): Type = this
16541654

1655+
/** If this is a prototype with some ignored component, reveal it, and
1656+
* deepen the result transitively. Otherwise the type itself.
1657+
*/
1658+
def deepenProtoTrans(using Context): Type = this
1659+
16551660
/** If this is an ignored proto type, its underlying type, otherwise the type itself */
16561661
def revealIgnored: Type = this
16571662

@@ -3436,7 +3441,7 @@ object Types {
34363441
case tp: TermRef => applyPrefix(tp)
34373442
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
34383443
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
3439-
case TermParamRef(`thisLambdaType`, _) => TrueDeps
3444+
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
34403445
case _: ThisType | _: BoundType | NoPrefix => status
34413446
case _ =>
34423447
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,34 @@ object ProtoTypes {
4949
/** Test compatibility after normalization.
5050
* Do this in a fresh typerstate unless `keepConstraint` is true.
5151
*/
52-
def normalizedCompatible(tp: Type, pt: Type, keepConstraint: Boolean)(using Context): Boolean = {
53-
def testCompat(using Context): Boolean = {
52+
def normalizedCompatible(tp: Type, pt: Type, keepConstraint: Boolean)(using Context): Boolean =
53+
54+
def testCompat(using Context): Boolean =
5455
val normTp = normalize(tp, pt)
5556
isCompatible(normTp, pt) || pt.isRef(defn.UnitClass) && normTp.isParameterless
56-
}
57-
if (keepConstraint)
58-
tp.widenSingleton match {
57+
58+
if keepConstraint || ctx.mode.is(Mode.ConstrainResultDeep) then
59+
tp.widenSingleton match
5960
case poly: PolyType =>
60-
// We can't keep the constraint in this case, since we have to add type parameters
61-
// to it, but there's no place to associate them with type variables.
62-
// So we'd get a "inconsistent: no typevars were added to committable constraint"
63-
// assertion failure in `constrained`. To do better, we'd have to change the
64-
// constraint handling architecture so that some type parameters are committable
65-
// and others are not. But that's a whole different ballgame.
66-
normalizedCompatible(tp, pt, keepConstraint = false)
61+
val newctx = ctx.fresh.setNewTyperState()
62+
val result = testCompat(using newctx)
63+
typr.println(
64+
i"""normalizedCompatible for $poly, $pt = $result
65+
|constraint was: ${ctx.typerState.constraint}
66+
|constraint now: ${newctx.typerState.constraint}""")
67+
val existingVars = ctx.typerState.uninstVars.toSet
68+
if result
69+
&& (ctx.typerState.constraint ne newctx.typerState.constraint)
70+
&& newctx.typerState.uninstVars.forall(existingVars.contains)
71+
then newctx.typerState.commit()
72+
// If the new constrait contains fresh type variables we cannot keep it,
73+
// since those type variables are not instantiated anywhere in the source.
74+
// See pos/i6682a.scala for a test case. See pos/11243.scala and pos/i5773b.scala
75+
// for tests where it matters that we keep the constraint otherwise.
76+
result
6777
case _ => testCompat
68-
}
6978
else explore(testCompat)
70-
}
79+
end normalizedCompatible
7180

7281
private def disregardProto(pt: Type)(using Context): Boolean =
7382
pt.dealias.isRef(defn.UnitClass)
@@ -80,7 +89,16 @@ object ProtoTypes {
8089
val res = pt.widenExpr match {
8190
case pt: FunProto =>
8291
mt match {
83-
case mt: MethodType => constrainResult(resultTypeApprox(mt), pt.resultType)
92+
case mt: MethodType =>
93+
constrainResult(resultTypeApprox(mt), pt.resultType)
94+
&& {
95+
if ctx.mode.is(Mode.ConstrainResultDeep) then
96+
if mt.isImplicitMethod == (pt.applyKind == ApplyKind.Using) then
97+
val tpargs = pt.args.lazyZip(mt.paramInfos).map(pt.typedArg)
98+
tpargs.tpes.corresponds(mt.paramInfos)(_ <:< _)
99+
else true
100+
else true
101+
}
84102
case _ => true
85103
}
86104
case _: ValueTypeOrProto if !disregardProto(pt) =>
@@ -123,6 +141,7 @@ object ProtoTypes {
123141
abstract case class IgnoredProto(ignored: Type) extends CachedGroundType with MatchAlways:
124142
override def revealIgnored = ignored
125143
override def deepenProto(using Context): Type = ignored
144+
override def deepenProtoTrans(using Context): Type = ignored.deepenProtoTrans
126145

127146
override def computeHash(bs: Hashable.Binders): Int = doHash(bs, ignored)
128147

@@ -202,7 +221,12 @@ object ProtoTypes {
202221
def map(tm: TypeMap)(using Context): SelectionProto = derivedSelectionProto(name, tm(memberProto), compat)
203222
def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T = ta(x, memberProto)
204223

205-
override def deepenProto(using Context): SelectionProto = derivedSelectionProto(name, memberProto.deepenProto, compat)
224+
override def deepenProto(using Context): SelectionProto =
225+
derivedSelectionProto(name, memberProto.deepenProto, compat)
226+
227+
override def deepenProtoTrans(using Context): SelectionProto =
228+
derivedSelectionProto(name, memberProto.deepenProtoTrans, compat)
229+
206230
override def computeHash(bs: Hashable.Binders): Int = {
207231
val delta = (if (compat eq NoViewsAllowed) 1 else 0) | (if (privateOK) 2 else 0)
208232
addDelta(doHash(bs, name, memberProto), delta)
@@ -419,7 +443,11 @@ object ProtoTypes {
419443
def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T =
420444
ta(ta.foldOver(x, typedArgs().tpes), resultType)
421445

422-
override def deepenProto(using Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer)
446+
override def deepenProto(using Context): FunProto =
447+
derivedFunProto(args, resultType.deepenProto, typer)
448+
449+
override def deepenProtoTrans(using Context): FunProto =
450+
derivedFunProto(args, resultType.deepenProtoTrans, typer)
423451

424452
override def withContext(newCtx: Context): ProtoType =
425453
if newCtx `eq` protoCtx then this
@@ -472,7 +500,11 @@ object ProtoTypes {
472500
def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T =
473501
ta(ta(x, argType), resultType)
474502

475-
override def deepenProto(using Context): ViewProto = derivedViewProto(argType, resultType.deepenProto)
503+
override def deepenProto(using Context): ViewProto =
504+
derivedViewProto(argType, resultType.deepenProto)
505+
506+
override def deepenProtoTrans(using Context): ViewProto =
507+
derivedViewProto(argType, resultType.deepenProtoTrans)
476508
}
477509

478510
class CachedViewProto(argType: Type, resultType: Type) extends ViewProto(argType, resultType) {
@@ -522,7 +554,11 @@ object ProtoTypes {
522554
def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T =
523555
ta(ta.foldOver(x, targs.tpes), resultType)
524556

525-
override def deepenProto(using Context): PolyProto = derivedPolyProto(targs, resultType.deepenProto)
557+
override def deepenProto(using Context): PolyProto =
558+
derivedPolyProto(targs, resultType.deepenProto)
559+
560+
override def deepenProtoTrans(using Context): PolyProto =
561+
derivedPolyProto(targs, resultType.deepenProtoTrans)
526562
}
527563

528564
/** A prototype for expressions [] that are known to be functions:

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3182,9 +3182,11 @@ class Typer extends Namer
31823182
val arg = inferImplicitArg(formal, tree.span.endPos)
31833183
arg.tpe match
31843184
case failed: AmbiguousImplicits =>
3185-
val pt1 = pt.deepenProto
3185+
val pt1 = pt.deepenProtoTrans
31863186
if (pt1 `ne` pt) && (pt1 ne sharpenedPt)
3187-
&& constrainResult(tree.symbol, wtp, pt1)
3187+
&& withMode(Mode.ConstrainResultDeep)(
3188+
constrainResult(tree.symbol, wtp, pt1)
3189+
)
31883190
then implicitArgs(formals, argIndex, pt1)
31893191
else arg :: implicitArgs(formals1, argIndex + 1, pt1)
31903192
case failed: SearchFailureType =>

tests/neg/i6391.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
object Test {
22
def foo(x: String, y: x.type): Any = ???
3-
val f = foo // error // error: cannot convert to closure
4-
}
3+
val f = foo // error
4+
}

tests/pos/i11243.scala

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
object WriterTest extends App {
2+
3+
object Functor:
4+
def apply[F[_]](using f: Functor[F]) = f
5+
6+
trait Functor[F[_]]:
7+
extension [A, B](x: F[A])
8+
def map(f: A => B): F[B]
9+
10+
object Applicative:
11+
def apply[F[_]](using a: Applicative[F]) = a
12+
13+
trait Applicative[F[_]] extends Functor[F]:
14+
def pure[A](x:A):F[A]
15+
16+
extension [A,B](x: F[A])
17+
def ap(f: F[A => B]): F[B]
18+
19+
def map(f: A => B): F[B] = {
20+
x.ap(pure(f))
21+
}
22+
23+
extension [A,B,C](fa: F[A]) def map2(fb: F[B])(f: (A,B) => C): F[C] = {
24+
val fab: F[B => C] = fa.map((a: A) => (b: B) => f(a,b))
25+
fb.ap(fab)
26+
}
27+
28+
end Applicative
29+
30+
31+
object Monad:
32+
def apply[F[_]](using m: Monad[F]) = m
33+
34+
trait Monad[F[_]] extends Applicative[F]:
35+
36+
// The unit value for a monad
37+
def pure[A](x:A):F[A]
38+
39+
extension[A,B](fa :F[A])
40+
// The fundamental composition operation
41+
def flatMap(f :A=>F[B]):F[B]
42+
43+
// Monad can also implement `ap` in terms of `map` and `flatMap`
44+
def ap(fab: F[A => B]): F[B] = {
45+
fab.flatMap {
46+
f =>
47+
fa.flatMap {
48+
a =>
49+
pure(f(a))
50+
}
51+
}
52+
53+
}
54+
55+
end Monad
56+
57+
given eitherMonad[Err]: Monad[[X] =>> Either[Err,X]] with
58+
def pure[A](a: A): Either[Err, A] = Right(a)
59+
extension [A,B](x: Either[Err,A]) def flatMap(f: A => Either[Err, B]) = {
60+
x match {
61+
case Right(a) => f(a)
62+
case Left(err) => Left(err)
63+
}
64+
}
65+
66+
given optionMonad: Monad[Option] with
67+
def pure[A](a: A) = Some(a)
68+
extension[A,B](fa: Option[A])
69+
def flatMap(f: A => Option[B]) = {
70+
fa match {
71+
case Some(a) =>
72+
f(a)
73+
case None =>
74+
None
75+
}
76+
}
77+
78+
given listMonad: Monad[List] with
79+
def pure[A](a: A): List[A] = List(a)
80+
81+
extension[A,B](x: List[A])
82+
def flatMap(f: A => List[B]): List[B] = {
83+
x match {
84+
case hd :: tl => f(hd) ++ tl.flatMap(f)
85+
case Nil => Nil
86+
}
87+
}
88+
89+
case class Transformer[F[_]: Monad,A](val wrapped: F[A])
90+
91+
given transformerMonad[F[_]: Monad]: Monad[[X] =>> Transformer[F,X]] with {
92+
93+
def pure[A](a: A): Transformer[F,A] = Transformer(summon[Monad[F]].pure(a))
94+
95+
extension [A,B](fa: Transformer[F,A])
96+
def flatMap(f: A => Transformer[F,B]) = {
97+
val ffa: F[B] = Monad[F].flatMap(fa.wrapped) {
98+
case a => {
99+
f(a).wrapped.map {
100+
case b =>
101+
b
102+
}
103+
}
104+
}
105+
Transformer(ffa)
106+
}
107+
}
108+
109+
type EString[A] = Either[String,A]
110+
111+
def incrementEven(a: Int): Transformer[EString,Int] = {
112+
if(a % 2 == 1) Transformer(Left("Odd number provided"))
113+
else Transformer(Right(a + 1))
114+
}
115+
116+
def doubleOdd(a: Int): Transformer[EString, Int] = {
117+
if(a % 2 == 0) Transformer(Left("Even number provided"))
118+
else Transformer(Right(a * 2))
119+
}
120+
121+
val writerExample = incrementEven(8)
122+
val example =
123+
WriterTest.transformerMonad.flatMap(writerExample)(doubleOdd)
124+
//writerExample.flatMap(doubleOdd) // Error ambiguous F
125+
126+
127+
}

tests/neg/i5773.scala renamed to tests/pos/i5773b.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ object Semigroup {
1010

1111
implicit def sumSemigroup[N](implicit N: Numeric[N]): Semigroup[N] = new {
1212
extension (lhs: N) override def append(rhs: N): N = N.plus(lhs, rhs)
13-
extension (lhs: Int) def appendS(rhs: N): N = ??? // N.plus(lhs, rhs)
13+
extension (lhs: Int) override def appendS(rhs: N): N = ??? // N.plus(lhs, rhs)
1414
}
1515
}
1616

1717

1818
object Main {
1919
import Semigroup.sumSemigroup // this is not sufficient
2020
def f1 = {
21-
println(1 appendS 2) // error This should give the following error message:
21+
println(1 appendS 2) // This used to give the following error message:
2222
/*
2323
21 | println(1 appendS 2)
2424
| ^^^^^^^^^

0 commit comments

Comments
 (0)