Skip to content

Add support for Java records in patterns. #19577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
55 changes: 52 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,22 @@ object desugar {
case Select(qual, tpnme.AnyVal) => isScala(qual)
case _ => false
}

def isScala(tree: Tree): Boolean = tree match {
case Ident(nme.scala) => true
case Select(Ident(nme.ROOTPKG), nme.scala) => true
case _ => false
}

def isRecord(tree: Tree): Boolean = tree match {
case Select(Select(Select(Ident(nme.ROOTPKG), nme.java), nme.lang), tpnme.Record) => true
case _ => false
}

def namePos = cdef.sourcePos.withSpan(cdef.nameSpan)

val isJavaRecord = mods.is(JavaDefined) && parents.exists(isRecord)

val isObject = mods.is(Module)
val isCaseClass = mods.is(Case) && !isObject
val isCaseObject = mods.is(Case) && isObject
Expand Down Expand Up @@ -769,6 +777,11 @@ object desugar {

val companionMembers = defaultGetters ::: enumCases

def tupleApply(params: List[untpd.Tree]): untpd.Apply = {
val fun = Select(Ident(nme.scala), s"${StdNames.str.Tuple}$arity".toTermName)
Apply(fun, params)
}

// The companion object definitions, if a companion is needed, Nil otherwise.
// companion definitions include:
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
Expand Down Expand Up @@ -801,9 +814,8 @@ object desugar {
case vparam :: Nil =>
Apply(scalaDot(nme.Option), Select(Ident(unapplyParamName), vparam.name))
case vparams =>
val tupleApply = Select(Ident(nme.scala), s"Tuple$arity".toTermName)
val members = vparams.map(vparam => Select(Ident(unapplyParamName), vparam.name))
Apply(scalaDot(nme.Option), Apply(tupleApply, members))
val members = vparams.map(param => Select(Ident(unapplyParamName), param.name))
Apply(scalaDot(nme.Option), tupleApply(members))

val hasRepeatedParam = constrVparamss.head.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
Expand Down Expand Up @@ -832,6 +844,43 @@ object desugar {
companionDefs(anyRef, companionMembers)
else if isValueClass && !isObject then
companionDefs(anyRef, Nil)
else if (isJavaRecord) {

/** Get the canonical constructor of the Java record.
*
* Java classes have a dummy constructor; see [[JavaParsers.makeTemplate]] for
* more details
*/
def canonicalConstructor(impl: Template): DefDef = {
impl.body.collectFirst {
case ddef: DefDef if ddef.name.isConstructorName && ddef.mods.is(Synthetic) =>
ddef
}.get
}

val constr1 = canonicalConstructor(impl)
val tParams = constr1.leadingTypeParams
val vParams = asTermOnly(constr1.trailingParamss).head
val arity = vParams.length

val classTypeRef = appliedRef(classTycon)

val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)

val unapplyRHS =
if (arity == 0) Literal(Constant(true))
else Ident(unapplyParam.name)

val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()

val unapplyMeth = DefDef(
nme.unapply,
joinParams(derivedTparams, (unapplyParam :: Nil) :: Nil),
unapplyResTp,
unapplyRHS
).withMods(synthetic | Inline)
companionDefs(anyRef, unapplyMeth :: Nil)
}
else Nil

enumCompanionRef match {
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,7 @@ class Definitions {
def isAbstractFunctionClass(cls: Symbol): Boolean = isVarArityClass(cls, str.AbstractFunction)
def isTupleClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Tuple)
def isProductClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Product)
def isJavaRecordClass(cls: Symbol): Boolean = cls.is(JavaDefined) && cls.derivesFrom(JavaRecordClass)

def isBoxedUnitClass(cls: Symbol): Boolean =
cls.isClass && (cls.owner eq ScalaRuntimePackageClass) && cls.name == tpnme.BoxedUnit
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ class SymUtils:
def caseAccessors(using Context): List[Symbol] =
self.info.decls.filter(_.is(CaseAccessor))

// TODO: Check if `Synthetic` is stamped properly
def javaRecordComponents(using Context): List[Symbol] =
self.info.decls.filter(sym => sym.is(Synthetic) && sym.is(Method) && !sym.isConstructor)

def getter(using Context): Symbol =
if (self.isGetter) self else accessorNamed(self.asTerm.name.getterName)

Expand Down
24 changes: 22 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,28 @@ object PatternMatcher {
.select(defn.RuntimeTuples_apply)
.appliedTo(receiver, Literal(Constant(i)))

if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
def resultTypeSym = unapp.symbol.info.resultType.typeSymbol

def isSyntheticJavaRecordUnapply(sym: Symbol) =
// Since the `unapply` symbol is marked as inline, the `Typer` wraps the body of the `unapply` in a separate
// anonymous class. The result type alone is not enough to distinguish that we're calling the synthesized unapply —
// we could have defined a separate `unapply` method returning a Java record somewhere, hence we resort to using
// the `coord`.
sym.is(Synthetic) && sym.isAnonymousClass && {
val resultSym = resultTypeSym
// TODO: Can a user define a separate unapply function in Java?
val unapplyFn = resultSym.linkedClass.info.decl(nme.unapply)
// TODO: This is nasty, can we add an attachment on the anonymous function for a prior link?
defn.isJavaRecordClass(resultSym) && unapplyFn.symbol.coord == sym.coord
}

def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
def recordSel(sym: Symbol) = tupleSel(sym).appliedToTermArgs(Nil)

if (isSyntheticJavaRecordUnapply(unapp.symbol.owner))
val components = resultTypeSym.javaRecordComponents.map(recordSel)
matchArgsPlan(components, args, onSuccess)
else if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
val isGenericTuple = defn.isTupleClass(caseClass) &&
!defn.isTupleNType(tree.tpe match { case tp: OrType => tp.join case tp => tp }) // widen even hard unions, to see if it's a union of tuples
val components = if isGenericTuple then caseAccessors.indices.toList.map(tupleApp(_, ref(scrutinee))) else caseAccessors.map(tupleSel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
def computeFromCaseClass: (Type, List[Type]) =
val (baseRef, baseInfo) =
val rawRef = caseClass.typeRef
// TODO: HERE!!
val rawInfo = caseClass.primaryConstructor.info
optInfo match
case Some(info) =>
Expand Down
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ object Applications {
ref.info.widenExpr.annotatedToRepeated
}

// TODO: Error here
def isJavaRecordMatch(tp: Type, numArgs: Int, errorPos: SrcPos = NoSourcePosition)(using Context): Boolean =
defn.isJavaRecordClass(tp.typeSymbol)

/** Does `tp` fit the "product match" conditions as an unapply result type
* for a pattern with `numArgs` subpatterns?
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
Expand Down Expand Up @@ -108,6 +112,20 @@ object Applications {
if (isValid) elemTp else NoType
}

def javaRecordComponentTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {

val params = tp.typeSymbol.javaRecordComponents.map(_.info.resultType)

tp match
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is surely too standard to need to do it by hand...

case tp: AppliedType =>
val argsIter = tp.args.iterator
for (param <- params) yield param match
case param if param.typeSymbol.isTypeParam => argsIter.next()
case param => param
case _ => params

}

def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos)
sels.takeWhile(_.exists).toList
Expand Down Expand Up @@ -192,6 +210,8 @@ object Applications {
// which is better than the message in `fail`.
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
foldApplyTupleType(unapplyResult)
else if (isJavaRecordMatch(unapplyResult, args.length, pos)) then
javaRecordComponentTypes(unapplyResult, pos)
else fail
}
}
Expand Down
5 changes: 1 addition & 4 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,6 @@ class Namer { typer: Typer =>
*/
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =

def isJavaRecord(owner: Symbol) =
owner.is(JavaDefined) && owner.derivesFrom(defn.JavaRecordClass)

def isCaseClassOrCompanion(owner: Symbol) =
owner.isClass && {
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
Expand All @@ -907,7 +904,7 @@ class Namer { typer: Typer =>
)
||
// remove synthetic constructor or method of a java Record if it clashes with a non-synthetic constructor
(isJavaRecord(denot.owner)
(defn.isJavaRecordClass(denot.owner)
&& denot.is(Method)
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
)
Expand Down
13 changes: 10 additions & 3 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CompilationTests {
import CompilationTests._
import CompilationTest.aggregateTests

def `isJava16+`: Boolean = scala.util.Properties.isJavaAtLeast("16")

// Positive tests ------------------------------------------------------------

@Test def pos: Unit = {
Expand All @@ -50,7 +52,7 @@ class CompilationTests {
else Nil
)

if scala.util.Properties.isJavaAtLeast("16") then
if `isJava16+` then
tests ::= compileFilesInDir("tests/pos-java16+", defaultOptions.and("-Ysafe-init"))

aggregateTests(tests*).checkCompile()
Expand Down Expand Up @@ -155,13 +157,18 @@ class CompilationTests {

@Test def runAll: Unit = {
implicit val testGroup: TestGroup = TestGroup("runAll")
aggregateTests(
var tests = List(
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init")),
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking")),
// Run tests for legacy lazy vals.
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)),
).checkRuns()
)

if `isJava16+` then
tests ::= compileFilesInDir("tests/run-java16+", defaultOptions.and("-Ysafe-init"))

aggregateTests(tests*).checkRuns()
}

// Generic java signatures tests ---------------------------------------------
Expand Down
31 changes: 31 additions & 0 deletions tests/pos-java16+/java-records-patmatch/FromScala.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
object C:

def useR0: Unit =
val r = R0()

// unapply in valdef
val R0() = r

// unapply in patmatch
r match {
case R0() =>
}


def useR1: Int =
val r = R1(1, "foo")

// unapply in valdef
val R1(i, _) = r
val a: Int = i

// unapply in patmatch
r match {
case R1(i, _) => i
}

def useR2: String =
val r = R2("asd")
r match {
case R2(s) => s
}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R0.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R0() {}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R1(int i, String s) {}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R2<T>(T t) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R0.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R0() {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R1(int i) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R2<T>(T t, int i) {}
Empty file.
21 changes: 21 additions & 0 deletions tests/run-java16+/java-records/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// scalajs: --skip

object Test:
def main(args: Array[String]): Unit =
val r0 = R0()
r0 match
case R0() =>

val r1 = R1(42)
r1 match
case R1(i) => assert(i == 42)

val R1(i) = r1
assert(i == 42)

val r2 = R2("foo", 9)
val R2(s, _) = r2
assert(s == "foo")

r2 match
case R2(_, i) => assert(i == 9)