Skip to content

Base multiversal equality on typeclass derivation #5843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 2 additions & 41 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -538,45 +538,6 @@ object desugar {
if (isEnum)
parents1 = parents1 :+ ref(defn.EnumType)

// The Eq instance for an Enum class. For an enum class
//
// enum class C[T1, ..., Tn]
//
// we generate:
//
// implicit def eqInstance[T1$1, ..., Tn$1, T1$2, ..., Tn$2](implicit
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
// : Eq[C[T1$, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
//
// Higher-kinded type arguments `Ti` are omitted as evidence parameters.
//
// FIXME: This is too simplistic. Instead of just generating evidence arguments
// for every first-kinded type parameter, we should look instead at the
// actual types occurring in cases and derive parameters from these. E.g. in
//
// enum HK[F[_]] {
// case C1(x: F[Int]) extends HK[F[Int]]
// case C2(y: F[String]) extends HL[F[Int]]
//
// we would need evidence parameters for `F[Int]` and `F[String]`
// We should generate Eq instances with the techniques
// of typeclass derivation once that is available.
def eqInstance = {
val leftParams = constrTparams.map(derivedTypeParam(_, "$1"))
val rightParams = constrTparams.map(derivedTypeParam(_, "$2"))
val subInstances =
for ((param1, param2) <- leftParams `zip` rightParams if !isHK(param1))
yield appliedRef(ref(defn.EqType), List(param1, param2), widenHK = true)
DefDef(
name = nme.eqInstance,
tparams = leftParams ++ rightParams,
vparamss = if (subInstances.isEmpty) Nil else List(makeImplicitParameters(subInstances)),
tpt = appliedTypeTree(ref(defn.EqType),
appliedRef(classTycon, leftParams) :: appliedRef(classTycon, rightParams) :: Nil),
rhs = ref(defn.EqModule.termRef)).withFlags(Synthetic | Implicit)
}
def eqInstances = if (isEnum) eqInstance :: Nil else Nil

// derived type classes of non-module classes go to their companions
val (clsDerived, companionDerived) =
if (mods.is(Module)) (impl.derived, Nil) else (Nil, impl.derived)
Expand All @@ -595,7 +556,7 @@ object desugar {
mdefs
}

val companionMembers = defaultGetters ::: eqInstances ::: enumCases
val companionMembers = defaultGetters ::: enumCases

// The companion object definitions, if a companion is needed, Nil otherwise.
// companion definitions include:
Expand Down Expand Up @@ -645,7 +606,7 @@ object desugar {
}
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
}
else if (companionMembers.nonEmpty || companionDerived.nonEmpty)
else if (companionMembers.nonEmpty || companionDerived.nonEmpty || isEnum)
companionDefs(anyRef, companionMembers)
else if (isValueClass) {
impl.constr.vparamss match {
Expand Down
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,11 @@ class Definitions {
lazy val TastyReflectionModule: TermSymbol = ctx.requiredModule("scala.tasty.Reflection")
lazy val TastyReflection_macroContext: TermSymbol = TastyReflectionModule.requiredMethod("macroContext")

lazy val EqType: TypeRef = ctx.requiredClassRef("scala.Eq")
def EqClass(implicit ctx: Context): ClassSymbol = EqType.symbol.asClass
def EqModule(implicit ctx: Context): Symbol = EqClass.companionModule
lazy val EqlType: TypeRef = ctx.requiredClassRef("scala.Eql")
def EqlClass(implicit ctx: Context): ClassSymbol = EqlType.symbol.asClass
def EqlModule(implicit ctx: Context): Symbol = EqlClass.companionModule

def Eq_eqAny(implicit ctx: Context): TermSymbol = EqModule.requiredMethod(nme.eqAny)
def Eql_eqlAny(implicit ctx: Context): TermSymbol = EqlModule.requiredMethod(nme.eqlAny)

lazy val NotType: TypeRef = ctx.requiredClassRef("scala.implicits.Not")
def NotClass(implicit ctx: Context): ClassSymbol = NotType.symbol.asClass
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Denotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ import collection.mutable.ListBuffer
*/
object Denotations {

implicit def eqDenotation: Eq[Denotation, Denotation] = Eq
implicit def eqDenotation: Eql[Denotation, Denotation] = Eql.derived

/** A PreDenotation represents a group of single denotations or a single multi-denotation
* It is used as an optimization to avoid forming MultiDenotations too eagerly.
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ object Mode {
/** Allow GADTFlexType labelled types to have their bounds adjusted */
val GADTflexible: Mode = newMode(8, "GADTflexible")

/** Assume -language:strictEquality */
val StrictEquality: Mode = newMode(9, "StrictEquality")

/** We are currently printing something: avoid to produce more logs about
* the printing
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object Names {
def toTermName: TermName
}

implicit def eqName: Eq[Name, Name] = Eq
implicit def eqName: Eql[Name, Name] = Eql.derived

/** A common superclass of Name and Symbol. After bootstrap, this should be
* just the type alias Name | Symbol
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ object StdNames {
val equals_ : N = "equals"
val error: N = "error"
val eval: N = "eval"
val eqAny: N = "eqAny"
val eqlAny: N = "eqlAny"
val ex: N = "ex"
val experimental: N = "experimental"
val f: N = "f"
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ trait Symbols { this: Context =>

object Symbols {

implicit def eqSymbol: Eq[Symbol, Symbol] = Eq
implicit def eqSymbol: Eql[Symbol, Symbol] = Eql.derived

/** Tree attachment containing the identifiers in a tree as a sorted array */
val Ids: Property.Key[Array[String]] = new Property.Key
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object Types {

@sharable private[this] var nextId = 0

implicit def eqType: Eq[Type, Type] = Eq
implicit def eqType: Eql[Type, Type] = Eql.derived

/** Main class representing types.
*
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Deriving.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ trait Deriving { this: Typer =>
if (nparams == 0) Nil
else if (nparams == 1) tparam :: Nil
else typeClass.typeParams.map(tcparam =>
tparam.copy(name = s"${tparam.name}_${tcparam.name}".toTypeName)
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
.asInstanceOf[TypeSymbol])
}
val firstKindedParamss = clsParamss.filter {
Expand Down
105 changes: 76 additions & 29 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -699,16 +699,65 @@ trait Implicits { self: Typer =>
if (ctx.inInlineMethod || enclosingInlineds.nonEmpty) ref(defn.TastyReflection_macroContext)
else EmptyTree

/** If `formal` is of the form Eq[T, U], where no `Eq` instance exists for
* either `T` or `U`, synthesize `Eq.eqAny[T, U]` as solution.
/** If `formal` is of the form Eql[T, U], try to synthesize an
* `Eql.eqlAny[T, U]` as solution.
*/
def synthesizedEq(formal: Type)(implicit ctx: Context): Tree = {
//println(i"synth eq $formal / ${formal.argTypes}%, %")

/** Is there an `Eql[T, T]` instance, assuming -strictEquality? */
def hasEq(tp: Type)(implicit ctx: Context): Boolean = {
val inst = inferImplicitArg(defn.EqlType.appliedTo(tp, tp), span)
!inst.isEmpty && !inst.tpe.isError
}

/** Can we assume the eqlAny instance for `tp1`, `tp2`?
* This is the case if assumedCanEqual(tp1, tp2), or
* one of `tp1`, `tp2` has a reflexive `Eql` instance.
*/
def validEqAnyArgs(tp1: Type, tp2: Type)(implicit ctx: Context) =
assumedCanEqual(tp1, tp2) || {
val nestedCtx = ctx.fresh.addMode(Mode.StrictEquality)
!hasEq(tp1)(nestedCtx) && !hasEq(tp2)(nestedCtx)
}

/** Is an `Eql[cls1, cls2]` instance assumed for predefined classes `cls1`, cls2`? */
def canComparePredefinedClasses(cls1: ClassSymbol, cls2: ClassSymbol): Boolean = {
def cmpWithBoxed(cls1: ClassSymbol, cls2: ClassSymbol) =
cls2 == defn.boxedType(cls1.typeRef).symbol ||
cls1.isNumericValueClass && cls2.derivesFrom(defn.BoxedNumberClass)

if (cls1.isPrimitiveValueClass)
if (cls2.isPrimitiveValueClass)
cls1 == cls2 || cls1.isNumericValueClass && cls2.isNumericValueClass
else
cmpWithBoxed(cls1, cls2)
else if (cls2.isPrimitiveValueClass)
cmpWithBoxed(cls2, cls1)
else if (cls1 == defn.NullClass)
cls1 == cls2 || cls2.derivesFrom(defn.ObjectClass)
else if (cls2 == defn.NullClass)
cls1.derivesFrom(defn.ObjectClass)
else
false
}

/** Some simulated `Eql` instances for predefined types. It's more efficient
* to do this directly instead of setting up a lot of `Eql` instances to
* interpret.
*/
def canComparePredefined(tp1: Type, tp2: Type) =
tp1.classSymbols.exists(cls1 =>
tp2.classSymbols.exists(cls2 => canComparePredefinedClasses(cls1, cls2)))

formal.argTypes match {
case args @ (arg1 :: arg2 :: Nil)
if !ctx.featureEnabled(defn.LanguageModuleClass, nme.strictEquality) &&
ctx.test(implicit ctx => validEqAnyArgs(arg1, arg2)) =>
ref(defn.Eq_eqAny).appliedToTypes(args).withSpan(span)
case args @ (arg1 :: arg2 :: Nil) =>
List(arg1, arg2).foreach(fullyDefinedType(_, "eq argument", span))
if (canComparePredefined(arg1, arg2)
||
!strictEquality &&
ctx.test(implicit ctx => validEqAnyArgs(arg1, arg2)))
ref(defn.Eql_eqlAny).appliedToTypes(args).withSpan(span)
else EmptyTree
case _ =>
EmptyTree
}
Expand Down Expand Up @@ -737,14 +786,6 @@ trait Implicits { self: Typer =>
}
}

def hasEq(tp: Type): Boolean =
inferImplicit(defn.EqType.appliedTo(tp, tp), EmptyTree, span).isSuccess

def validEqAnyArgs(tp1: Type, tp2: Type)(implicit ctx: Context) = {
List(tp1, tp2).foreach(fullyDefinedType(_, "eqAny argument", span))
assumedCanEqual(tp1, tp2) || !hasEq(tp1) && !hasEq(tp2)
}

/** If `formal` is of the form `scala.reflect.Generic[T]` for some class type `T`,
* synthesize an instance for it.
*/
Expand Down Expand Up @@ -776,7 +817,7 @@ trait Implicits { self: Typer =>
trySpecialCase(defn.QuotedTypeClass, synthesizedTypeTag,
trySpecialCase(defn.GenericClass, synthesizedGeneric,
trySpecialCase(defn.TastyReflectionClass, synthesizedTastyContext,
trySpecialCase(defn.EqClass, synthesizedEq,
trySpecialCase(defn.EqlClass, synthesizedEq,
trySpecialCase(defn.ValueOfClass, synthesizedValueOf, failed))))))
}
}
Expand Down Expand Up @@ -885,16 +926,16 @@ trait Implicits { self: Typer =>
em"parameter ${paramName} of $methodStr"
}

private def assumedCanEqual(ltp: Type, rtp: Type)(implicit ctx: Context) = {
def eqNullable: Boolean = {
val other =
if (ltp.isRef(defn.NullClass)) rtp
else if (rtp.isRef(defn.NullClass)) ltp
else NoType

(other ne NoType) && !other.derivesFrom(defn.AnyValClass)
}
private def strictEquality(implicit ctx: Context): Boolean =
ctx.mode.is(Mode.StrictEquality) ||
ctx.featureEnabled(defn.LanguageModuleClass, nme.strictEquality)

/** An Eql[T, U] instance is assumed
* - if one of T, U is an error type, or
* - if one of T, U is a subtype of the lifted version of the other,
* unless strict equality is set.
*/
private def assumedCanEqual(ltp: Type, rtp: Type)(implicit ctx: Context) = {
// Map all non-opaque abstract types to their upper bound.
// This is done to check whether such types might plausibly be comparable to each other.
val lift = new TypeMap {
Expand All @@ -910,14 +951,20 @@ trait Implicits { self: Typer =>
if (variance > 0) mapOver(t) else t
}
}
ltp.isError || rtp.isError || ltp <:< lift(rtp) || rtp <:< lift(ltp) || eqNullable

ltp.isError ||
rtp.isError ||
!strictEquality && {
ltp <:< lift(rtp) ||
rtp <:< lift(ltp)
}
}

/** Check that equality tests between types `ltp` and `rtp` make sense */
def checkCanEqual(ltp: Type, rtp: Type, span: Span)(implicit ctx: Context): Unit =
if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) {
val res = implicitArgTree(defn.EqType.appliedTo(ltp, rtp), span)
implicits.println(i"Eq witness found for $ltp / $rtp: $res: ${res.tpe}")
val res = implicitArgTree(defn.EqlType.appliedTo(ltp, rtp), span)
implicits.println(i"Eql witness found for $ltp / $rtp: $res: ${res.tpe}")
}

/** Find an implicit parameter or conversion.
Expand Down Expand Up @@ -985,7 +1032,7 @@ trait Implicits { self: Typer =>
if (argument.isEmpty) f(resultType) else ViewProto(f(argument.tpe.widen), f(resultType))
// Not clear whether we need to drop the `.widen` here. All tests pass with it in place, though.

private def isCoherent = pt.isRef(defn.EqClass)
private def isCoherent = pt.isRef(defn.EqlClass)

private val cmpContext = nestedContext()
private val cmpCandidates = (c1: Candidate, c2: Candidate) => compare(c1.ref, c2.ref, c1.level, c2.level)(cmpContext)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ class Namer { typer: Typer =>

if (impl.derived.nonEmpty) {
val (derivingClass, derivePos) = original.removeAttachment(desugar.DerivingCompanion) match {
case Some(pos) => (cls.companionClass.asClass, pos)
case Some(pos) => (cls.companionClass.orElse(cls).asClass, pos)
case None => (cls, impl.sourcePos.startPos)
}
val deriver = new Deriver(derivingClass, derivePos)(localCtx)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/util/SourceFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class SourceFile(val file: AbstractFile, computeContent: => Array[Char]) extends
}
}
object SourceFile {
implicit def eqSource: Eq[SourceFile, SourceFile] = Eq
implicit def eqSource: Eql[SourceFile, SourceFile] = Eql.derived

implicit def fromContext(implicit ctx: Context): SourceFile = ctx.source

Expand Down
4 changes: 2 additions & 2 deletions compiler/test-resources/repl/i4184
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ scala> object foo { class Foo }
// defined object foo
scala> object bar { class Foo }
// defined object bar
scala> implicit def eqFoo: Eq[foo.Foo, foo.Foo] = Eq
def eqFoo: Eq[foo.Foo, foo.Foo]
scala> implicit def eqFoo: Eql[foo.Foo, foo.Foo] = Eql.derived
def eqFoo: Eql[foo.Foo, foo.Foo]
scala> object Bar { new foo.Foo == new bar.Foo }
1 | object Bar { new foo.Foo == new bar.Foo }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
Loading