diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 0b1646e23620..f10a1cc7372c 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -526,9 +526,11 @@ object Contexts { final def withOwner(owner: Symbol): Context = if (owner ne this.owner) fresh.setOwner(owner) else this + final def withTyperState(typerState: TyperState): Context = + if typerState ne this.typerState then fresh.setTyperState(typerState) else this + final def withUncommittedTyperState: Context = - val ts = typerState.uncommittedAncestor - if ts ne typerState then fresh.setTyperState(ts) else this + withTyperState(typerState.uncommittedAncestor) final def withProperty[T](key: Key[T], value: Option[T]): Context = if (property(key) == value) this @@ -599,8 +601,8 @@ object Contexts { this.scope = newScope this def setTyperState(typerState: TyperState): this.type = { this.typerState = typerState; this } - def setNewTyperState(): this.type = setTyperState(typerState.fresh().setCommittable(true)) - def setExploreTyperState(): this.type = setTyperState(typerState.fresh().setCommittable(false)) + def setNewTyperState(): this.type = setTyperState(typerState.fresh(committable = true)) + def setExploreTyperState(): this.type = setTyperState(typerState.fresh(committable = false)) def setReporter(reporter: Reporter): this.type = setTyperState(typerState.fresh().setReporter(reporter)) def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) } def setGadt(gadt: GadtConstraint): this.type = diff --git a/compiler/src/dotty/tools/dotc/core/TyperState.scala b/compiler/src/dotty/tools/dotc/core/TyperState.scala index cd5cb1a65105..9aa51a6714c3 100644 --- a/compiler/src/dotty/tools/dotc/core/TyperState.scala +++ b/compiler/src/dotty/tools/dotc/core/TyperState.scala @@ -103,11 +103,12 @@ class TyperState() { this /** A fresh typer state with the same constraint as this one. */ - def fresh(reporter: Reporter = StoreReporter(this.reporter)): TyperState = + def fresh(reporter: Reporter = StoreReporter(this.reporter), + committable: Boolean = this.isCommittable): TyperState = util.Stats.record("TyperState.fresh") TyperState().init(this, this.constraint) .setReporter(reporter) - .setCommittable(this.isCommittable) + .setCommittable(committable) /** The uninstantiated variables */ def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars @@ -182,24 +183,25 @@ class TyperState() { /** Integrate the constraints from `that` into this TyperState. * - * @pre If `that` is committable, it must not contain any type variable which + * @pre If `this` and `that` are committable, `that` must not contain any type variable which * does not exist in `this` (in other words, all its type variables must * be owned by a common parent of `this` and `that`). */ - def mergeConstraintWith(that: TyperState)(using Context): Unit = + def mergeConstraintWith(that: TyperState)(using Context): this.type = + if this eq that then return this + that.ensureNotConflicting(constraint) - val comparingCtx = - if ctx.typerState == this then ctx - else ctx.fresh.setTyperState(this) + val comparingCtx = ctx.withTyperState(this) - comparing(typeComparer => + inContext(comparingCtx)(comparing(typeComparer => val other = that.constraint val res = other.domainLambdas.forall(tl => // Integrate the type lambdas from `other` constraint.contains(tl) || other.isRemovable(tl) || { val tvars = tl.paramRefs.map(other.typeVarOfParam(_)).collect { case tv: TypeVar => tv } - tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) + if this.isCommittable then + tvars.foreach(tvar => if !tvar.inst.exists && !isOwnedAnywhere(this, tvar) then includeVar(tvar)) typeComparer.addToConstraint(tl, tvars) }) && // Integrate the additional constraints on type variables from `other` @@ -220,10 +222,11 @@ class TyperState() { ) ) assert(res || ctx.reporter.errorsReported, i"cannot merge $constraint with $other.") - )(using comparingCtx) + )) for tl <- constraint.domainLambdas do if constraint.isRemovable(tl) then constraint = constraint.remove(tl) + this end mergeConstraintWith /** Take ownership of `tvar`. diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index bf46c38c9a39..2c8064ca23b4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -1177,7 +1177,29 @@ trait Implicits: // compare the extension methods instead of their wrappers. def stripExtension(alt: SearchSuccess) = methPart(stripApply(alt.tree)).tpe (stripExtension(alt1), stripExtension(alt2)) match - case (ref1: TermRef, ref2: TermRef) => diff = compare(ref1, ref2) + case (ref1: TermRef, ref2: TermRef) => + // ref1 and ref2 might refer to type variables owned by + // alt1.tstate and alt2.tstate respectively, to compare the + // alternatives correctly we need a TyperState that includes + // constraints from both sides, see + // tests/*/extension-specificity2.scala for test cases. + val constraintsIn1 = alt1.tstate.constraint ne ctx.typerState.constraint + val constraintsIn2 = alt2.tstate.constraint ne ctx.typerState.constraint + def exploreState(alt: SearchSuccess): TyperState = + alt.tstate.fresh(committable = false) + val comparisonState = + if constraintsIn1 && constraintsIn2 then + exploreState(alt1).mergeConstraintWith(alt2.tstate) + else if constraintsIn1 then + exploreState(alt1) + else if constraintsIn2 then + exploreState(alt2) + else + ctx.typerState + + diff = inContext(ctx.withTyperState(comparisonState)) { + compare(ref1, ref2) + } case _ => if diff < 0 then alt2 else if diff > 0 then alt1 diff --git a/tests/neg/extension-specificity2.scala b/tests/neg/extension-specificity2.scala new file mode 100644 index 000000000000..0087dbbe7165 --- /dev/null +++ b/tests/neg/extension-specificity2.scala @@ -0,0 +1,10 @@ +trait Bla1[A]: + extension (x: A) def foo(y: A): Int +trait Bla2[A]: + extension (x: A) def foo(y: A): Int + +def test = + given bla1[T <: Int]: Bla1[T] = ??? + given bla2[S <: Int]: Bla2[S] = ??? + + 1.foo(2) // error: never extension is more specific than the other diff --git a/tests/run/extension-specificity2.scala b/tests/run/extension-specificity2.scala new file mode 100644 index 000000000000..eeaad80a3687 --- /dev/null +++ b/tests/run/extension-specificity2.scala @@ -0,0 +1,37 @@ +trait Foo[F[_]]: + extension [A](fa: F[A]) + def foo[B](fb: F[B]): Int + +def test1 = + // Simplified from https://github.com/typelevel/spotted-leopards/issues/2 + given listFoo: Foo[List] with + extension [A](fa: List[A]) + def foo[B](fb: List[B]): Int = 1 + + given functionFoo[T]: Foo[[A] =>> T => A] with + extension [A](fa: T => A) + def foo[B](fb: T => B): Int = 2 + + val x = List(1, 2).foo(List(3, 4)) + assert(x == 1, x) + +def test2 = + // This test case would fail if we used `wildApprox` on the method types + // instead of using the correct typer state. + trait Bar1[A]: + extension (x: A => A) def bar(y: A): Int + trait Bar2: + extension (x: Int => 1) def bar(y: Int): Int + + given bla1[T]: Bar1[T] with + extension (x: T => T) def bar(y: T): Int = 1 + given bla2: Bar2 with + extension (x: Int => 1) def bar(y: Int): Int = 2 + + val f: Int => 1 = x => 1 + val x = f.bar(1) + assert(x == 2, x) + +@main def Test = + test1 + test2