Skip to content

New footprint calculation scheme #19639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 13 additions & 35 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3054,7 +3054,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
end provablyDisjointTypeArgs

protected def explainingTypeComparer(short: Boolean) = ExplainingTypeComparer(comparerContext, short)
protected def trackingTypeComparer = TrackingTypeComparer(comparerContext)
protected def matchReducer = MatchReducer(comparerContext)

private def inSubComparer[T, Cmp <: TypeComparer](comparer: Cmp)(op: Cmp => T): T =
val saved = myInstance
Expand All @@ -3068,8 +3068,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
inSubComparer(cmp)(op)
cmp.lastTrace(header)

def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
inSubComparer(trackingTypeComparer)(op)
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
inSubComparer(matchReducer)(op)
}

object TypeComparer {
Expand Down Expand Up @@ -3236,14 +3236,14 @@ object TypeComparer {
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
comparing(_.explained(op, header, short))

def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
comparing(_.tracked(op))
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
comparing(_.reduceMatchWith(op))

def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult =
comparing(_.subCaptures(refs1, refs2, frozen))
}

object TrackingTypeComparer:
object MatchReducer:
import printing.*, Texts.*
enum MatchResult extends Showable:
case Reduced(tp: Type)
Expand All @@ -3259,38 +3259,16 @@ object TrackingTypeComparer:
case Stuck => "Stuck"
case NoInstance(fails) => "NoInstance(" ~ Text(fails.map(p.toText(_) ~ p.toText(_)), ", ") ~ ")"

class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
import TrackingTypeComparer.*
/** A type comparer for reducing match types.
* TODO: Not sure this needs to be a type comparer. Can we make it a
* separate class?
*/
class MatchReducer(initctx: Context) extends TypeComparer(initctx) {
import MatchReducer.*

init(initctx)

override def trackingTypeComparer = this

val footprint: mutable.Set[Type] = mutable.Set[Type]()

override def bounds(param: TypeParamRef)(using Context): TypeBounds = {
if (param.binder `ne` caseLambda) footprint += param
super.bounds(param)
}

override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean = {
if (param.binder `ne` caseLambda) footprint += param
super.addOneBound(param, bound, isUpper)
}

override def gadtBounds(sym: Symbol)(using Context): TypeBounds | Null = {
if (sym.exists) footprint += sym.typeRef
super.gadtBounds(sym)
}

override def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean =
if (sym.exists) footprint += sym.typeRef
super.gadtAddBound(sym, b, isUpper)

override def typeVarInstance(tvar: TypeVar)(using Context): Type = {
footprint += tvar
super.typeVarInstance(tvar)
}
override def matchReducer = this

def matchCases(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): Type = {
// a reference for the type parameters poisoned during matching
Expand Down
75 changes: 53 additions & 22 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5009,6 +5009,8 @@ object Types extends TypeUtils {
case ex: Throwable =>
handleRecursive("normalizing", s"${scrutinee.show} match ..." , ex)

private def thisMatchType = this

def reduced(using Context): Type = {

def contextInfo(tp: Type): Type = tp match {
Expand All @@ -5023,16 +5025,43 @@ object Types extends TypeUtils {
tp.underlying
}

def updateReductionContext(footprint: collection.Set[Type]): Unit =
reductionContext = util.HashMap()
for (tp <- footprint)
reductionContext(tp) = contextInfo(tp)
typr.println(i"footprint for $this $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")

def isUpToDate: Boolean =
reductionContext.keysIterator.forall { tp =>
reductionContext.keysIterator.forall: tp =>
reductionContext(tp) `eq` contextInfo(tp)
}

def setReductionContext(): Unit =
new TypeTraverser:
var footprint: Set[Type] = Set()
var deep: Boolean = true
val seen = util.HashSet[Type]()
def traverse(tp: Type) =
if !seen.contains(tp) then
seen += tp
tp match
case tp: NamedType =>
if tp.symbol.is(TypeParam) then footprint += tp
traverseChildren(tp)
case _: AppliedType | _: RefinedType =>
if deep then traverseChildren(tp)
case TypeBounds(lo, hi) =>
traverse(hi)
case tp: TypeVar =>
footprint += tp
traverse(tp.underlying)
case tp: TypeParamRef =>
footprint += tp
case _ =>
traverseChildren(tp)
end traverse

traverse(scrutinee)
deep = false
cases.foreach(traverse)
reductionContext = util.HashMap()
for tp <- footprint do
reductionContext(tp) = contextInfo(tp)
matchTypes.println(i"footprint for $thisMatchType $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")
end setReductionContext

record("MatchType.reduce called")
if !Config.cacheMatchReduced
Expand All @@ -5043,20 +5072,22 @@ object Types extends TypeUtils {
record("MatchType.reduce computed")
if (myReduced != null) record("MatchType.reduce cache miss")
myReduced =
trace(i"reduce match type $this $hashCode", matchTypes, show = true)(withMode(Mode.Type) {
def matchCases(cmp: TrackingTypeComparer): Type =
val saved = ctx.typerState.snapshot()
try cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
catch case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
finally
updateReductionContext(cmp.footprint)
ctx.typerState.resetTo(saved)
// this drops caseLambdas in constraint and undoes any typevar
// instantiations during matchtype reduction

TypeComparer.tracked(matchCases)
})
trace(i"reduce match type $this $hashCode", matchTypes, show = true):
withMode(Mode.Type):
setReductionContext()
def matchCases(cmp: MatchReducer): Type =
val saved = ctx.typerState.snapshot()
try
cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
catch case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
finally
ctx.typerState.resetTo(saved)
// this drops caseLambdas in constraint and undoes any typevar
// instantiations during matchtype reduction
TypeComparer.reduceMatchWith(matchCases)

//else println(i"no change for $this $hashCode / $myReduced")
myReduced.nn
}

Expand Down
99 changes: 99 additions & 0 deletions tests/pos/bad-footprint.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

object NamedTuple:

opaque type AnyNamedTuple = Any
opaque type NamedTuple[N <: Tuple, +V <: Tuple] >: V <: AnyNamedTuple = V

export NamedTupleDecomposition.{Names, DropNames}

/** The type of the named tuple `X` mapped with the type-level function `F`.
* If `X = (n1 : T1, ..., ni : Ti)` then `Map[X, F] = `(n1 : F[T1], ..., ni : F[Ti])`.
*/
type Map[X <: AnyNamedTuple, F[_ <: Tuple.Union[DropNames[X]]]] =
NamedTuple[Names[X], Tuple.Map[DropNames[X], F]]

end NamedTuple

object NamedTupleDecomposition:
import NamedTuple.*

/** The names of a named tuple, represented as a tuple of literal string values. */
type Names[X <: AnyNamedTuple] <: Tuple = X match
case NamedTuple[n, _] => n

/** The value types of a named tuple represented as a regular tuple. */
type DropNames[NT <: AnyNamedTuple] <: Tuple = NT match
case NamedTuple[_, x] => x
end NamedTupleDecomposition

class Expr[Result]

object Expr:
import NamedTuple.{NamedTuple, AnyNamedTuple}

type Of[A] = Expr[A]

type StripExpr[E] = E match
case Expr.Of[b] => b

case class Ref[A]($name: String = "") extends Expr.Of[A]

case class Join[A <: AnyNamedTuple](a: A)
extends Expr.Of[NamedTuple.Map[A, StripExpr]]
end Expr

trait Query[A]

object Query:
// Extension methods to support for-expression syntax for queries
extension [R](x: Query[R])
def map[B](f: Expr.Ref[R] => Expr.Of[B]): Query[B] = ???

case class City(zipCode: Int, name: String, population: Int)

object Test:
import Expr.StripExpr
import NamedTuple.{NamedTuple, AnyNamedTuple}

val cities: Query[City] = ???
val q6 =
cities.map: city =>
val x: NamedTuple[
("name", "zipCode"),
(Expr.Of[String], Expr.Of[Int])] = ???
Expr.Join(x)

/* Was error:

-- [E007] Type Mismatch Error: bad-footprint.scala:60:16 -----------------------
60 | cities.map: city =>
| ^
|Found: Expr.Ref[City] =>
| Expr[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (String,
| Int)]
| ]
|Required: Expr.Ref[City] =>
| Expr[
| NamedTuple.NamedTuple[
| NamedTupleDecomposition.Names[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
| Expr[String], Expr[Int])]
| ],
| Tuple.Map[
| NamedTupleDecomposition.DropNames[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
| Expr[String], Expr[Int])]
| ],
| Expr.StripExpr]
| ]
| ]
61 | val x: NamedTuple[
62 | ("name", "zipCode"),
63 | (Expr.Of[String], Expr.Of[Int])] = ???
64 | Expr.Join(x)
|
| longer explanation available when compiling with `-explain`
1 error found

*/