Skip to content

Commit 996ad36

Browse files
committed
Merge pull request #598 from dotty-staging/add/simplify-primitives
Add/simplify primitives
2 parents 6ec4b0a + d07d669 commit 996ad36

File tree

5 files changed

+90
-5
lines changed

5 files changed

+90
-5
lines changed

src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ class Definitions {
539539

540540
// ----- primitive value class machinery ------------------------------------------
541541

542-
lazy val ScalaNumericValueClasses: collection.Set[Symbol] = Set(
542+
lazy val ScalaNumericValueClassList = List(
543543
ByteClass,
544544
ShortClass,
545545
CharClass,
@@ -548,6 +548,7 @@ class Definitions {
548548
FloatClass,
549549
DoubleClass)
550550

551+
lazy val ScalaNumericValueClasses: collection.Set[Symbol] = ScalaNumericValueClassList.toSet
551552
lazy val ScalaValueClasses: collection.Set[Symbol] = ScalaNumericValueClasses + UnitClass + BooleanClass
552553

553554
lazy val ScalaBoxedClasses = ScalaValueClasses map boxedClass

src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,9 @@ object SymDenotations {
386386
/** Is symbol a primitive value class? */
387387
def isPrimitiveValueClass(implicit ctx: Context) = defn.ScalaValueClasses contains symbol
388388

389+
/** Is symbol a primitive numeric value class? */
390+
def isNumericValueClass(implicit ctx: Context) = defn.ScalaNumericValueClasses contains symbol
391+
389392
/** Is symbol a phantom class for which no runtime representation exists? */
390393
def isPhantomClass(implicit ctx: Context) = defn.PhantomClasses contains symbol
391394

src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ trait Applications extends Compatibility { self: Typer =>
127127
*/
128128
protected def makeVarArg(n: Int, elemFormal: Type): Unit
129129

130+
/** If all `args` have primitive numeric types, make sure it's the same one */
131+
protected def harmonizeArgs(args: List[TypedArg]): List[TypedArg]
132+
130133
/** Signal failure with given message at position of given argument */
131134
protected def fail(msg: => String, arg: Arg): Unit
132135

@@ -334,7 +337,14 @@ trait Applications extends Compatibility { self: Typer =>
334337
addTyped(arg, formal)
335338
case _ =>
336339
val elemFormal = formal.widenExpr.argTypesLo.head
337-
args foreach (addTyped(_, elemFormal))
340+
val origConstraint = ctx.typerState.constraint
341+
var typedArgs = args.map(typedArg(_, elemFormal))
342+
val harmonizedArgs = harmonizeArgs(typedArgs)
343+
if (harmonizedArgs ne typedArgs) {
344+
ctx.typerState.constraint = origConstraint
345+
typedArgs = harmonizedArgs
346+
}
347+
typedArgs.foreach(addArg(_, elemFormal))
338348
makeVarArg(args.length, elemFormal)
339349
}
340350
else args match {
@@ -389,6 +399,7 @@ trait Applications extends Compatibility { self: Typer =>
389399
def argType(arg: Tree, formal: Type): Type = normalize(arg.tpe, formal)
390400
def treeToArg(arg: Tree): Tree = arg
391401
def isVarArg(arg: Tree): Boolean = tpd.isWildcardStarArg(arg)
402+
def harmonizeArgs(args: List[Tree]) = harmonize(args)
392403
}
393404

394405
/** Subclass of Application for applicability tests with type arguments and value
@@ -405,6 +416,7 @@ trait Applications extends Compatibility { self: Typer =>
405416
def argType(arg: Type, formal: Type): Type = arg
406417
def treeToArg(arg: Tree): Type = arg.tpe
407418
def isVarArg(arg: Type): Boolean = arg.isRepeatedParam
419+
def harmonizeArgs(args: List[Type]) = harmonizeTypes(args)
408420
}
409421

410422
/** Subclass of Application for type checking an Apply node, where
@@ -430,6 +442,8 @@ trait Applications extends Compatibility { self: Typer =>
430442
typedArgBuf += seqToRepeated(seqLit)
431443
}
432444

445+
def harmonizeArgs(args: List[TypedArg]) = harmonize(args)
446+
433447
override def appPos = app.pos
434448

435449
def fail(msg: => String, arg: Trees.Tree[T]) = {
@@ -1024,6 +1038,41 @@ trait Applications extends Compatibility { self: Typer =>
10241038
result
10251039
}
10261040
}
1041+
1042+
private def harmonizeWith[T <: AnyRef](ts: List[T])(tpe: T => Type, adapt: (T, Type) => T)(implicit ctx: Context): List[T] = {
1043+
def numericClasses(ts: List[T], acc: Set[Symbol]): Set[Symbol] = ts match {
1044+
case t :: ts1 =>
1045+
val sym = tpe(t).widen.classSymbol
1046+
if (sym.isNumericValueClass) numericClasses(ts1, acc + sym)
1047+
else Set()
1048+
case Nil =>
1049+
acc
1050+
}
1051+
val clss = numericClasses(ts, Set())
1052+
if (clss.size > 1) {
1053+
val lub = defn.ScalaNumericValueClassList.find(lubCls =>
1054+
clss.forall(defn.isValueSubClass(_, lubCls))).get.typeRef
1055+
ts.mapConserve(adapt(_, lub))
1056+
}
1057+
else ts
1058+
}
1059+
1060+
/** If `trees` all have numeric value types, and they do not have all the same type,
1061+
* pick a common numeric supertype and convert all trees to this type.
1062+
*/
1063+
def harmonize(trees: List[Tree])(implicit ctx: Context): List[Tree] = {
1064+
def adapt(tree: Tree, pt: Type): Tree = tree match {
1065+
case cdef: CaseDef => tpd.cpy.CaseDef(cdef)(body = adapt(cdef.body, pt))
1066+
case _ => adaptInterpolated(tree, pt, tree)
1067+
}
1068+
harmonizeWith(trees)(_.tpe, adapt)
1069+
}
1070+
1071+
/** If all `types` are numeric value types, and they are not all the same type,
1072+
* pick a common numeric supertype and return it instead of every original type.
1073+
*/
1074+
def harmonizeTypes(tpes: List[Type])(implicit ctx: Context): List[Type] =
1075+
harmonizeWith(tpes)(identity, (tp, pt) => pt)
10271076
}
10281077

10291078
/*

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
493493
val cond1 = typed(tree.cond, defn.BooleanType)
494494
val thenp1 = typed(tree.thenp, pt)
495495
val elsep1 = typed(tree.elsep orElse untpd.unitLiteral withPos tree.pos, pt)
496-
assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1)
496+
val thenp2 :: elsep2 :: Nil = harmonize(thenp1 :: elsep1 :: Nil)
497+
assignType(cpy.If(tree)(cond1, thenp2, elsep2), thenp2, elsep2)
497498
}
498499

499500
def typedFunction(tree: untpd.Function, pt: Type)(implicit ctx: Context) = track("typedFunction") {
@@ -629,7 +630,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
629630
fullyDefinedType(sel1.tpe, "pattern selector", tree.pos))
630631

631632
val cases1 = typedCases(tree.cases, selType, pt)
632-
assignType(cpy.Match(tree)(sel1, cases1), cases1)
633+
val cases2 = harmonize(cases1).asInstanceOf[List[CaseDef]]
634+
assignType(cpy.Match(tree)(sel1, cases2), cases2)
633635
}
634636
}
635637

@@ -737,7 +739,9 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
737739
val expr1 = typed(tree.expr, pt)
738740
val cases1 = typedCases(tree.cases, defn.ThrowableType, pt)
739741
val finalizer1 = typed(tree.finalizer, defn.UnitType)
740-
assignType(cpy.Try(tree)(expr1, cases1, finalizer1), expr1, cases1)
742+
val expr2 :: cases2x = harmonize(expr1 :: cases1)
743+
val cases2 = cases2x.asInstanceOf[List[CaseDef]]
744+
assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2)
741745
}
742746

743747
def typedThrow(tree: untpd.Throw)(implicit ctx: Context): Tree = track("typedThrow") {

tests/pos/harmonize.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
object Test {
2+
3+
def main(args: Array[String]) = {
4+
val x = true
5+
val n = 1
6+
val y = if (x) 'A' else n
7+
val z: Int = y
8+
9+
val yy = n match {
10+
case 1 => 'A'
11+
case 2 => n
12+
case 3 => 1.0
13+
}
14+
val zz: Double = yy
15+
16+
val a = try {
17+
'A'
18+
} catch {
19+
case ex: Exception => n
20+
case ex: Error => 3L
21+
}
22+
val b: Long = a
23+
24+
val xs = List(1.0, n, 'c')
25+
val ys: List[Double] = xs
26+
}
27+
28+
}

0 commit comments

Comments
 (0)