@@ -16,12 +16,6 @@ import Inferencing._
16
16
import transform .TypeUtils ._
17
17
import transform .SymUtils ._
18
18
19
-
20
- // TODOs:
21
- // - handle case where there's no companion object
22
- // - check that derived instances are stable
23
- // - reference derived instances with correct prefix instead of just the symbol
24
-
25
19
/** A typer mixin that implements typeclass derivation functionality */
26
20
trait Deriving { this : Typer =>
27
21
@@ -131,14 +125,23 @@ trait Deriving { this: Typer =>
131
125
flags : FlagSet = EmptyFlags )(implicit ctx : Context ): TermSymbol =
132
126
newSymbol(name, info, pos, flags | Method ).asTerm
133
127
128
+ /** A version of Type#underlyingClassRef that works also for higher-kinded types */
129
+ private def underlyingClassRef (tp : Type ): Type = tp match {
130
+ case tp : TypeRef if tp.symbol.isClass => tp
131
+ case tp : TypeRef if tp.symbol.isAbstractType => NoType
132
+ case tp : TermRef => NoType
133
+ case tp : TypeProxy => underlyingClassRef(tp.underlying)
134
+ case _ => NoType
135
+ }
136
+
134
137
/** Enter type class instance with given name and info in current scope, provided
135
138
* an instance with the same name does not exist already.
136
139
* @param reportErrors Report an error if an instance with the same name exists already
137
140
*/
138
141
private def addDerivedInstance (clsName : Name , info : Type , pos : Position , reportErrors : Boolean ) = {
139
142
val instanceName = s " derived $$ $clsName" .toTermName
140
143
if (ctx.denotNamed(instanceName).exists) {
141
- if (reportErrors) ctx.error(i " duplicate typeclass derivation for $clsName" )
144
+ if (reportErrors) ctx.error(i " duplicate typeclass derivation for $clsName" , pos )
142
145
}
143
146
else add(newMethod(instanceName, info, pos, Implicit ))
144
147
}
@@ -162,8 +165,9 @@ trait Deriving { this: Typer =>
162
165
* that have the same name but different prefixes through selective aliasing.
163
166
*/
164
167
private def processDerivedInstance (derived : untpd.Tree ): Unit = {
165
- val uncheckedType = typedAheadType(derived, AnyTypeConstructorProto ).tpe.dealias
166
- val derivedType = checkClassType(uncheckedType, derived.pos, traitReq = false , stablePrefixReq = true )
168
+ val originalType = typedAheadType(derived, AnyTypeConstructorProto ).tpe
169
+ val underlyingType = underlyingClassRef(originalType)
170
+ val derivedType = checkClassType(underlyingType, derived.pos, traitReq = false , stablePrefixReq = true )
167
171
val nparams = derivedType.classSymbol.typeParams.length
168
172
if (nparams == 1 ) {
169
173
val typeClass = derivedType.classSymbol
@@ -174,7 +178,7 @@ trait Deriving { this: Typer =>
174
178
val instanceInfo =
175
179
if (cls.typeParams.isEmpty) ExprType (resultType)
176
180
else PolyType .fromParams(cls.typeParams, ImplicitMethodType (evidenceParamInfos, resultType))
177
- addDerivedInstance(derivedType .typeSymbol.name, instanceInfo, derived.pos, reportErrors = true )
181
+ addDerivedInstance(originalType .typeSymbol.name, instanceInfo, derived.pos, reportErrors = true )
178
182
}
179
183
else
180
184
ctx.error(
@@ -377,14 +381,20 @@ trait Deriving { this: Typer =>
377
381
def instantiated (info : Type ): Type = info match {
378
382
case info : PolyType => instantiated(info.instantiate(tparamRefs))
379
383
case info : MethodType => info.instantiate(params.map(_.termRef))
380
- case info => info
384
+ case info => info.widenExpr
385
+ }
386
+ def classAndCompanionRef (tp : Type ): (ClassSymbol , TermRef ) = tp match {
387
+ case tp @ TypeRef (prefix, _) if tp.symbol.isClass =>
388
+ (tp.symbol.asClass, prefix.select(tp.symbol.companionModule).asInstanceOf [TermRef ])
389
+ case tp : TypeProxy =>
390
+ classAndCompanionRef(tp.underlying)
381
391
}
382
392
val resultType = instantiated(sym.info)
383
- val typeCls = resultType.classSymbol
393
+ val ( typeCls, companionRef) = classAndCompanionRef( resultType)
384
394
if (typeCls == defn.ShapedClass )
385
395
shapedRHS(resultType)
386
396
else {
387
- val module = untpd.ref(typeCls.companionModule.termRef ).withPos(sym.pos)
397
+ val module = untpd.ref(companionRef ).withPos(sym.pos)
388
398
val rhs = untpd.Select (module, nme.derived)
389
399
typed(rhs, resultType)
390
400
}
0 commit comments