diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 59f8c37cbf55..44152c072468 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1652,6 +1652,11 @@ object Types { */ def deepenProto(using Context): Type = this + /** If this is a prototype with some ignored component, reveal it, and + * deepen the result transitively. Otherwise the type itself. + */ + def deepenProtoTrans(using Context): Type = this + /** If this is an ignored proto type, its underlying type, otherwise the type itself */ def revealIgnored: Type = this @@ -3436,7 +3441,7 @@ object Types { case tp: TermRef => applyPrefix(tp) case tp: AppliedType => tp.fold(status, compute(_, _, theAcc)) case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional) - case TermParamRef(`thisLambdaType`, _) => TrueDeps + case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps case _: ThisType | _: BoundType | NoPrefix => status case _ => (if theAcc != null then theAcc else DepAcc()).foldOver(status, tp) diff --git a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala index 3d43cc5976c8..f19741668d32 100644 --- a/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala +++ b/compiler/src/dotty/tools/dotc/typer/ImportInfo.scala @@ -12,7 +12,6 @@ import Implicits.RenamedImplicitRef import config.SourceVersion import StdNames.nme import printing.Texts.Text -import ProtoTypes.NoViewsAllowed.normalizedCompatible import NameKinds.QualifiedName import Decorators._ diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 1da24e4f8146..0d2c60c677b0 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -47,27 +47,40 @@ object ProtoTypes { necessarySubType(tpn, pt) || tpn.isValueSubType(pt) || viewExists(tpn, pt) /** Test compatibility after normalization. - * Do this in a fresh typerstate unless `keepConstraint` is true. + * If `keepConstraint` is false, the current constraint set will not be modified by this call. */ - def normalizedCompatible(tp: Type, pt: Type, keepConstraint: Boolean)(using Context): Boolean = { - def testCompat(using Context): Boolean = { + def normalizedCompatible(tp: Type, pt: Type, keepConstraint: Boolean)(using Context): Boolean = + + def testCompat(using Context): Boolean = val normTp = normalize(tp, pt) isCompatible(normTp, pt) || pt.isRef(defn.UnitClass) && normTp.isParameterless - } - if (keepConstraint) - tp.widenSingleton match { + + if keepConstraint then + tp.widenSingleton match case poly: PolyType => - // We can't keep the constraint in this case, since we have to add type parameters - // to it, but there's no place to associate them with type variables. - // So we'd get a "inconsistent: no typevars were added to committable constraint" - // assertion failure in `constrained`. To do better, we'd have to change the - // constraint handling architecture so that some type parameters are committable - // and others are not. But that's a whole different ballgame. - normalizedCompatible(tp, pt, keepConstraint = false) + val newctx = ctx.fresh.setNewTyperState() + val result = testCompat(using newctx) + typr.println( + i"""normalizedCompatible for $poly, $pt = $result + |constraint was: ${ctx.typerState.constraint} + |constraint now: ${newctx.typerState.constraint}""") + if result + && (ctx.typerState.constraint ne newctx.typerState.constraint) + && { + val existingVars = ctx.typerState.uninstVars.toSet + newctx.typerState.uninstVars.forall(existingVars.contains) + } + then newctx.typerState.commit() + // If the new constrait contains fresh type variables we cannot keep it, + // since those type variables are not instantiated anywhere in the source. + // See pos/i6682a.scala for a test case. See pos/11243.scala and pos/i5773b.scala + // for tests where it matters that we keep the constraint otherwise. + // TODO: A better solution would clean the new constraint, so that it "avoids" + // the problematic type variables. But we have not implemented such an algorithm yet. + result case _ => testCompat - } else explore(testCompat) - } + end normalizedCompatible private def disregardProto(pt: Type)(using Context): Boolean = pt.dealias.isRef(defn.UnitClass) @@ -79,10 +92,18 @@ object ProtoTypes { val savedConstraint = ctx.typerState.constraint val res = pt.widenExpr match { case pt: FunProto => - mt match { - case mt: MethodType => constrainResult(resultTypeApprox(mt), pt.resultType) + mt match + case mt: MethodType => + constrainResult(resultTypeApprox(mt), pt.resultType) + && { + if pt.constrainResultDeep + && mt.isImplicitMethod == (pt.applyKind == ApplyKind.Using) + then + pt.args.lazyZip(mt.paramInfos).forall((arg, paramInfo) => + pt.typedArg(arg, paramInfo).tpe <:< paramInfo) + else true + } case _ => true - } case _: ValueTypeOrProto if !disregardProto(pt) => necessarilyCompatible(mt, pt) case pt: WildcardType if pt.optBounds.exists => @@ -123,6 +144,7 @@ object ProtoTypes { abstract case class IgnoredProto(ignored: Type) extends CachedGroundType with MatchAlways: override def revealIgnored = ignored override def deepenProto(using Context): Type = ignored + override def deepenProtoTrans(using Context): Type = ignored.deepenProtoTrans override def computeHash(bs: Hashable.Binders): Int = doHash(bs, ignored) @@ -202,7 +224,12 @@ object ProtoTypes { def map(tm: TypeMap)(using Context): SelectionProto = derivedSelectionProto(name, tm(memberProto), compat) def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T = ta(x, memberProto) - override def deepenProto(using Context): SelectionProto = derivedSelectionProto(name, memberProto.deepenProto, compat) + override def deepenProto(using Context): SelectionProto = + derivedSelectionProto(name, memberProto.deepenProto, compat) + + override def deepenProtoTrans(using Context): SelectionProto = + derivedSelectionProto(name, memberProto.deepenProtoTrans, compat) + override def computeHash(bs: Hashable.Binders): Int = { val delta = (if (compat eq NoViewsAllowed) 1 else 0) | (if (privateOK) 2 else 0) addDelta(doHash(bs, name, memberProto), delta) @@ -276,9 +303,21 @@ object ProtoTypes { /** A prototype for expressions that appear in function position * * [](args): resultType + * + * @param args The untyped arguments to which the function is applied + * @param resType The expeected result type + * @param typer The typer to use for typing the arguments + * @param applyKind The kind of application (regular/using/tupled infix operand) + * @param state The state object to use for tracking the changes to this prototype + * @param constrainResultDeep + * A flag to indicate that constrainResult on this prototype + * should typecheck and compare the arguments. */ - case class FunProto(args: List[untpd.Tree], resType: Type)(typer: Typer, - override val applyKind: ApplyKind, state: FunProtoState = new FunProtoState)(using protoCtx: Context) + case class FunProto(args: List[untpd.Tree], resType: Type)( + typer: Typer, + override val applyKind: ApplyKind, + state: FunProtoState = new FunProtoState, + val constrainResultDeep: Boolean = false)(using protoCtx: Context) extends UncachedGroundType with ApplyingProto with FunOrPolyProto { override def resultType(using Context): Type = resType @@ -290,9 +329,17 @@ object ProtoTypes { typer.isApplicableType(tp, args, resultType, keepConstraint && !args.exists(isPoly)) } - def derivedFunProto(args: List[untpd.Tree] = this.args, resultType: Type, typer: Typer = this.typer): FunProto = - if ((args eq this.args) && (resultType eq this.resultType) && (typer eq this.typer)) this - else new FunProto(args, resultType)(typer, applyKind) + def derivedFunProto( + args: List[untpd.Tree] = this.args, + resultType: Type = this.resultType, + typer: Typer = this.typer, + constrainResultDeep: Boolean = this.constrainResultDeep): FunProto = + if (args eq this.args) + && (resultType eq this.resultType) + && (typer eq this.typer) + && constrainResultDeep == this.constrainResultDeep + then this + else new FunProto(args, resultType)(typer, applyKind, constrainResultDeep = constrainResultDeep) /** @return True if all arguments have types. */ @@ -419,7 +466,11 @@ object ProtoTypes { def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T = ta(ta.foldOver(x, typedArgs().tpes), resultType) - override def deepenProto(using Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer) + override def deepenProto(using Context): FunProto = + derivedFunProto(args, resultType.deepenProto) + + override def deepenProtoTrans(using Context): FunProto = + derivedFunProto(args, resultType.deepenProtoTrans, constrainResultDeep = true) override def withContext(newCtx: Context): ProtoType = if newCtx `eq` protoCtx then this @@ -472,7 +523,11 @@ object ProtoTypes { def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T = ta(ta(x, argType), resultType) - override def deepenProto(using Context): ViewProto = derivedViewProto(argType, resultType.deepenProto) + override def deepenProto(using Context): ViewProto = + derivedViewProto(argType, resultType.deepenProto) + + override def deepenProtoTrans(using Context): ViewProto = + derivedViewProto(argType, resultType.deepenProtoTrans) } class CachedViewProto(argType: Type, resultType: Type) extends ViewProto(argType, resultType) { @@ -522,7 +577,11 @@ object ProtoTypes { def fold[T](x: T, ta: TypeAccumulator[T])(using Context): T = ta(ta.foldOver(x, targs.tpes), resultType) - override def deepenProto(using Context): PolyProto = derivedPolyProto(targs, resultType.deepenProto) + override def deepenProto(using Context): PolyProto = + derivedPolyProto(targs, resultType.deepenProto) + + override def deepenProtoTrans(using Context): PolyProto = + derivedPolyProto(targs, resultType.deepenProtoTrans) } /** A prototype for expressions [] that are known to be functions: diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index e37c2e482af7..02a85c6fd20c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3182,9 +3182,8 @@ class Typer extends Namer val arg = inferImplicitArg(formal, tree.span.endPos) arg.tpe match case failed: AmbiguousImplicits => - val pt1 = pt.deepenProto - if (pt1 `ne` pt) && (pt1 ne sharpenedPt) - && constrainResult(tree.symbol, wtp, pt1) + val pt1 = pt.deepenProtoTrans + if (pt1 `ne` pt) && (pt1 ne sharpenedPt) && constrainResult(tree.symbol, wtp, pt1) then implicitArgs(formals, argIndex, pt1) else arg :: implicitArgs(formals1, argIndex + 1, pt1) case failed: SearchFailureType => diff --git a/tests/neg/i6391.scala b/tests/neg/i6391.scala index fbe3e770d8ef..6df99467e012 100644 --- a/tests/neg/i6391.scala +++ b/tests/neg/i6391.scala @@ -1,4 +1,4 @@ object Test { def foo(x: String, y: x.type): Any = ??? - val f = foo // error // error: cannot convert to closure -} \ No newline at end of file + val f = foo // error +} diff --git a/tests/pos/i11243.scala b/tests/pos/i11243.scala new file mode 100644 index 000000000000..7966df0c8243 --- /dev/null +++ b/tests/pos/i11243.scala @@ -0,0 +1,127 @@ +object WriterTest extends App { + + object Functor: + def apply[F[_]](using f: Functor[F]) = f + + trait Functor[F[_]]: + extension [A, B](x: F[A]) + def map(f: A => B): F[B] + + object Applicative: + def apply[F[_]](using a: Applicative[F]) = a + + trait Applicative[F[_]] extends Functor[F]: + def pure[A](x:A):F[A] + + extension [A,B](x: F[A]) + def ap(f: F[A => B]): F[B] + + def map(f: A => B): F[B] = { + x.ap(pure(f)) + } + + extension [A,B,C](fa: F[A]) def map2(fb: F[B])(f: (A,B) => C): F[C] = { + val fab: F[B => C] = fa.map((a: A) => (b: B) => f(a,b)) + fb.ap(fab) + } + + end Applicative + + + object Monad: + def apply[F[_]](using m: Monad[F]) = m + + trait Monad[F[_]] extends Applicative[F]: + + // The unit value for a monad + def pure[A](x:A):F[A] + + extension[A,B](fa :F[A]) + // The fundamental composition operation + def flatMap(f :A=>F[B]):F[B] + + // Monad can also implement `ap` in terms of `map` and `flatMap` + def ap(fab: F[A => B]): F[B] = { + fab.flatMap { + f => + fa.flatMap { + a => + pure(f(a)) + } + } + + } + + end Monad + + given eitherMonad[Err]: Monad[[X] =>> Either[Err,X]] with + def pure[A](a: A): Either[Err, A] = Right(a) + extension [A,B](x: Either[Err,A]) def flatMap(f: A => Either[Err, B]) = { + x match { + case Right(a) => f(a) + case Left(err) => Left(err) + } + } + + given optionMonad: Monad[Option] with + def pure[A](a: A) = Some(a) + extension[A,B](fa: Option[A]) + def flatMap(f: A => Option[B]) = { + fa match { + case Some(a) => + f(a) + case None => + None + } + } + + given listMonad: Monad[List] with + def pure[A](a: A): List[A] = List(a) + + extension[A,B](x: List[A]) + def flatMap(f: A => List[B]): List[B] = { + x match { + case hd :: tl => f(hd) ++ tl.flatMap(f) + case Nil => Nil + } + } + + case class Transformer[F[_]: Monad,A](val wrapped: F[A]) + + given transformerMonad[F[_]: Monad]: Monad[[X] =>> Transformer[F,X]] with { + + def pure[A](a: A): Transformer[F,A] = Transformer(summon[Monad[F]].pure(a)) + + extension [A,B](fa: Transformer[F,A]) + def flatMap(f: A => Transformer[F,B]) = { + val ffa: F[B] = Monad[F].flatMap(fa.wrapped) { + case a => { + f(a).wrapped.map { + case b => + b + } + } + } + Transformer(ffa) + } + } + + type EString[A] = Either[String,A] + + def incrementEven(a: Int): Transformer[EString,Int] = { + if(a % 2 == 1) Transformer(Left("Odd number provided")) + else Transformer(Right(a + 1)) + } + + def doubleOdd(a: Int): Transformer[EString, Int] = { + if(a % 2 == 0) Transformer(Left("Even number provided")) + else Transformer(Right(a * 2)) + } + + val writerExample = incrementEven(8) + val example = + WriterTest.transformerMonad.flatMap(writerExample)(doubleOdd) + //writerExample.flatMap(doubleOdd) // Error ambiguous F + + +} \ No newline at end of file diff --git a/tests/neg/i5773.scala b/tests/pos/i5773b.scala similarity index 83% rename from tests/neg/i5773.scala rename to tests/pos/i5773b.scala index 617720e085cf..015a4c2828a4 100644 --- a/tests/neg/i5773.scala +++ b/tests/pos/i5773b.scala @@ -10,7 +10,7 @@ object Semigroup { implicit def sumSemigroup[N](implicit N: Numeric[N]): Semigroup[N] = new { extension (lhs: N) override def append(rhs: N): N = N.plus(lhs, rhs) - extension (lhs: Int) def appendS(rhs: N): N = ??? // N.plus(lhs, rhs) + extension (lhs: Int) override def appendS(rhs: N): N = ??? // N.plus(lhs, rhs) } } @@ -18,7 +18,7 @@ object Semigroup { object Main { import Semigroup.sumSemigroup // this is not sufficient def f1 = { - println(1 appendS 2) // error This should give the following error message: + println(1 appendS 2) // This used to give the following error message: /* 21 | println(1 appendS 2) | ^^^^^^^^^