Skip to content

Fix #8111: Use better algorithms to infer parameter types #8232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import typer.Applications._
import typer.ProtoTypes._
import typer.ForceDegree
import typer.Inferencing.isFullyDefined
import typer.IfBottom

import scala.annotation.internal.sharable

Expand Down Expand Up @@ -644,7 +645,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
tvar =>
!(ctx.typerState.constraint.entry(tvar.origin) `eq` tvar.origin.underlying) ||
(tvar `eq` removeThisType.prefixTVar),
allowBottom = false
IfBottom.flip
)

// If parent contains a reference to an abstract type, then we should
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1865,7 +1865,7 @@ trait Applications extends Compatibility {
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
else defn.FunctionOf(commonParamTypes, WildcardType)
overload.println(i"pretype arg $arg with expected type $commonFormal")
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.noBottom)))
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
}
case None =>
Expand Down
31 changes: 21 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object Inferencing {
def instantiateSelected(tp: Type, tvars: List[Type])(implicit ctx: Context): Unit =
if (tvars.nonEmpty)
IsFullyDefinedAccumulator(
ForceDegree.Value(tvars.contains, allowBottom = false), minimizeSelected = true
ForceDegree.Value(tvars.contains, IfBottom.flip), minimizeSelected = true
).process(tp)

/** Instantiate any type variables in `tp` whose bounds contain a reference to
Expand Down Expand Up @@ -98,7 +98,7 @@ object Inferencing {

* If (1) and (2) do not apply, and minimizeSelected is not set:
* 6: T is maximized if it appears only contravariantly in the given type,
* or if forceDegree is `noBottom` and T has no lower bound different from Nothing.
* or if forceDegree is `flipBottom` and T has no lower bound different from Nothing.
* 7. Otherwise, T is minimized.
*
* The instantiation for (6) and (7) is done in two phases:
Expand Down Expand Up @@ -132,8 +132,10 @@ object Inferencing {
if tvar.hasLowerBound then instantiate(tvar, fromBelow = true)
else if tvar.hasUpperBound then instantiate(tvar, fromBelow = false)
else () // hold off instantiating unbounded unconstrained variables
else if variance >= 0 && (force.allowBottom || tvar.hasLowerBound) then
else if variance >= 0 && (force.ifBottom == IfBottom.ok || tvar.hasLowerBound) then
instantiate(tvar, fromBelow = true)
else if variance >= 0 && force.ifBottom == IfBottom.fail then
return false
else
toMaximize = tvar :: toMaximize
foldOver(x, tvar)
Expand All @@ -150,9 +152,14 @@ object Inferencing {
if !tvar.isInstantiated then
instantiate(tvar, fromBelow = false)
case nil =>
val res = apply(true, tp)
if res then maximize(toMaximize)
res
apply(true, tp)
&& (
toMaximize.isEmpty
|| { maximize(toMaximize)
toMaximize = Nil // Do another round since the maximixing instances
process(tp) // might have type uninstantiated variables themselves.
}
)
}

/** For all type parameters occurring in `tp`:
Expand Down Expand Up @@ -509,9 +516,13 @@ trait Inferencing { this: Typer =>

/** An enumeration controlling the degree of forcing in "is-dully-defined" checks. */
@sharable object ForceDegree {
class Value(val appliesTo: TypeVar => Boolean, val allowBottom: Boolean)
val none: Value = new Value(_ => false, allowBottom = true)
val all: Value = new Value(_ => true, allowBottom = true)
val noBottom: Value = new Value(_ => true, allowBottom = false)
class Value(val appliesTo: TypeVar => Boolean, val ifBottom: IfBottom)
val none: Value = new Value(_ => false, IfBottom.ok)
val all: Value = new Value(_ => true, IfBottom.ok)
val failBottom: Value = new Value(_ => true, IfBottom.fail)
val flipBottom: Value = new Value(_ => true, IfBottom.flip)
}

enum IfBottom:
case ok, fail, flip

2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
// Make sure all type arguments to the call are fully determined,
// but continue if that's not achievable (or else i7459.scala would crash).
for arg <- callTypeArgs do
isFullyDefined(arg.tpe, ForceDegree.noBottom)
isFullyDefined(arg.tpe, ForceDegree.flipBottom)

/** A map from parameter names of the inlineable method to references of the actual arguments.
* For a type argument this is the full argument type.
Expand Down
71 changes: 43 additions & 28 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class Typer extends Namer
var templ1 = templ
def isEligible(tp: Type) = tp.exists && !tp.typeSymbol.is(Final) && !tp.isRef(defn.AnyClass)
if (templ1.parents.isEmpty &&
isFullyDefined(pt, ForceDegree.noBottom) &&
isFullyDefined(pt, ForceDegree.flipBottom) &&
isSkolemFree(pt) &&
isEligible(pt.underlyingClassRef(refinementOK = false)))
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
Expand Down Expand Up @@ -1009,16 +1009,20 @@ class Typer extends Namer
yield param.name -> idx
}.toMap
if (paramIndex.size == params.length)
expr match {
expr match
case untpd.TypedSplice(expr1) =>
expr1.tpe
case _ =>
given nestedCtx as Context = ctx.fresh.setNewTyperState()
val protoArgs = args map (_ withType WildcardType)
val callProto = FunProto(protoArgs, WildcardType)(this, app.isGivenApply)
val expr1 = typedExpr(expr, callProto)
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
expr1.tpe
}
if nestedCtx.reporter.hasErrors then NoType
else
given Context = ctx
nestedCtx.typerState.commit()
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
expr1.tpe
else NoType
case _ =>
NoType
Expand All @@ -1030,42 +1034,53 @@ class Typer extends Namer
// try to instantiate `pt` if this is possible. If it does not
// work the error will be reported later in `inferredParam`,
// when we try to infer the parameter type.
isFullyDefined(pt, ForceDegree.noBottom)
isFullyDefined(pt, ForceDegree.flipBottom)
case _ =>
}

val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)

/** Two attempts: First, if expected type is fully defined pick this one.
* Second, if function is of the form
* (x1, ..., xN) => f(... x1, ..., XN, ...)
* where each `xi` occurs exactly once in the argument list of `f` (in
* any order), and f has a method type MT, pick the corresponding parameter
* type in MT, if this one is fully defined.
* If both attempts fail, issue a "missing parameter type" error.
*/
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
calleeType.widen match {
/** The inferred parameter type for a parameter in a lambda that does
* not have an explicit type given.
* An inferred parameter type I has two possible sources:
* - the type S known from the context
* - the "target type" T known from the callee `f` if the lambda is of a form like `x => f(x)`
* If `T` exists, we know that `S <: I <: T`.
*
* The inference makes three attempts:
*
* 1. If the expected type `S` is already fully defined under ForceDegree.failBottom
* pick this one.
* 2. Compute the target type `T` and make it known that `S <: T`.
* If the expected type `S` can be fully defined under ForceDegree.flipBottom,
* pick this one (this might use the fact that S <: T for an upper approximation).
* 3. Otherwise, if the target type `T` can be fully defined under ForceDegree.flipBottom,
* pick this one.
*
* If all attempts fail, issue a "missing parameter type" error.
*/
def inferredParamType(param: untpd.ValDef, formal: Type): Type =
if isFullyDefined(formal, ForceDegree.failBottom) then return formal
val target = calleeType.widen match
case mtpe: MethodType =>
val pos = paramIndex(param.name)
if (pos < mtpe.paramInfos.length) {
if pos < mtpe.paramInfos.length then
val ptype = mtpe.paramInfos(pos)
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
return ptype
}
case _ =>
}
errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)
}
if ptype.isRepeatedParam then NoType else ptype
else NoType
case _ => NoType
if target.exists then formal <:< target
if isFullyDefined(formal, ForceDegree.flipBottom) then formal
else if target.exists && isFullyDefined(target, ForceDegree.flipBottom) then target
else errorType(AnonymousFunctionMissingParamType(param, params, tree, formal), param.sourcePos)

def protoFormal(i: Int): Type =
if (protoFormals.length == params.length) protoFormals(i)
else errorType(WrongNumberOfParameters(protoFormals.length), tree.sourcePos)

/** Is `formal` a product type which is elementwise compatible with `params`? */
def ptIsCorrectProduct(formal: Type) =
isFullyDefined(formal, ForceDegree.noBottom) &&
isFullyDefined(formal, ForceDegree.flipBottom) &&
(defn.isProductSubType(formal) || formal.derivesFrom(defn.PairClass)) &&
productSelectorTypes(formal, tree.sourcePos).corresponds(params) {
(argType, param) =>
Expand Down Expand Up @@ -1379,7 +1394,7 @@ class Typer extends Namer
}
case _ =>
tree.withType(
if (isFullyDefined(pt, ForceDegree.noBottom)) pt
if (isFullyDefined(pt, ForceDegree.flipBottom)) pt
else if (ctx.reporter.errorsReported) UnspecifiedErrorType
else errorType(i"cannot infer type; expected type $pt is not fully defined", tree.sourcePos))
}
Expand Down Expand Up @@ -3054,7 +3069,7 @@ class Typer extends Namer
pt match {
case SAMType(sam)
if wtp <:< sam.toFunctionType() =>
// was ... && isFullyDefined(pt, ForceDegree.noBottom)
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
// but this prevents case blocks from implementing polymorphic partial functions,
// since we do not know the result parameter a priori. Have to wait until the
// body is typechecked.
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i8111.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object Example extends App {

def assertLazy[A, B](f: (A) => B): Boolean = ???

def fromEither[E, F](eea: Either[E, F]): Unit = ???

lazy val result = assertLazy(fromEither)

println("It compiles!")
}