@@ -43,10 +43,23 @@ class ClassTags extends MiniPhaseTransform with IdentityDenotTransformer { thisT
43
43
override def transformTypeApply (tree : tpd.TypeApply )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree =
44
44
if (tree.fun.symbol eq classTagCache) {
45
45
val tp = tree.args.head.tpe
46
+ val defn = ctx.definitions
47
+ val (elemType, ndims) = tp match {
48
+ case defn.MultiArrayType (elem, ndims) => (elem, ndims)
49
+ case _ => (tp, 0 )
50
+ }
51
+
46
52
val claz = tp.classSymbol
53
+ val elemClaz = elemType.classSymbol
47
54
assert(! claz.isPrimitiveValueClass) // should be inserted by typer
48
- if (ValueClasses .isDerivedValueClass(claz)) ref(claz.companionModule)
49
- else if (claz eq defn.AnyClass ) ref(scala2ClassTagModule).select(nme.Any ).ensureConforms(tree.tpe)
50
- else ref(scala2ClassTagModule).select(nme.apply).appliedToType(tp).appliedTo(Literal (Constant (claz.typeRef)))
55
+ val elemTag = if (defn.ScalaValueClasses .contains(elemClaz) || elemClaz == defn.NothingClass || elemClaz == defn.NullClass )
56
+ ref(defn.DottyPredefModule ).select(s " ${elemClaz.name}ClassTag " .toTermName)
57
+ else if (ValueClasses .isDerivedValueClass(elemClaz)) ref(claz.companionModule)
58
+ else if (elemClaz eq defn.AnyClass ) ref(scala2ClassTagModule).select(nme.Any )
59
+ else {
60
+ val erazedTp = TypeErasure .erasure(elemType).classSymbol.typeRef
61
+ ref(scala2ClassTagModule).select(nme.apply).appliedToType(erazedTp).appliedTo(Literal (Constant (erazedTp)))
62
+ }
63
+ (1 to ndims).foldLeft(elemTag)((arr, level) => Select (arr, nme.wrap).ensureApplied).ensureConforms(tree.tpe)
51
64
} else tree
52
65
}
0 commit comments