Skip to content

Commit 0d277a1

Browse files
committed
First-class Java record pattern matching.
This adds Java records as a first-class citizen within Scala's pattern matching. Broadly, we synthesize an `inline unapply` on the companion module which we later use to correlate with the bound variables within the pattern matcher.
1 parent 7b44500 commit 0d277a1

File tree

8 files changed

+51
-11
lines changed

8 files changed

+51
-11
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -869,13 +869,10 @@ object desugar {
869869

870870
val unapplyRHS =
871871
if (arity == 0) Literal(Constant(true))
872-
else tupleApply(
873-
vParams.map(
874-
param => Select(Ident(unapplyParam.name), param.name)
875-
)
876-
)
872+
else Ident(unapplyParam.name)
877873

878874
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
875+
879876
val unapplyMeth = DefDef(
880877
nme.unapply,
881878
joinParams(derivedTparams, (unapplyParam :: Nil) :: Nil),

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,7 @@ class Definitions {
16361636
def isAbstractFunctionClass(cls: Symbol): Boolean = isVarArityClass(cls, str.AbstractFunction)
16371637
def isTupleClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Tuple)
16381638
def isProductClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Product)
1639+
def isJavaRecordClass(cls: Symbol): Boolean = cls.is(JavaDefined) && cls.derivesFrom(JavaRecordClass)
16391640

16401641
def isBoxedUnitClass(cls: Symbol): Boolean =
16411642
cls.isClass && (cls.owner eq ScalaRuntimePackageClass) && cls.name == tpnme.BoxedUnit

compiler/src/dotty/tools/dotc/core/SymUtils.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,10 @@ class SymUtils:
249249
def caseAccessors(using Context): List[Symbol] =
250250
self.info.decls.filter(_.is(CaseAccessor))
251251

252+
// TODO: Check if `Synthetic` is stamped properly
253+
def javaRecordComponents(using Context): List[Symbol] =
254+
self.info.decls.filter(sym => sym.is(Synthetic) && sym.is(Method) && !sym.isConstructor)
255+
252256
def getter(using Context): Symbol =
253257
if (self.isGetter) self else accessorNamed(self.asTerm.name.getterName)
254258

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,28 @@ object PatternMatcher {
343343
.select(defn.RuntimeTuples_apply)
344344
.appliedTo(receiver, Literal(Constant(i)))
345345

346-
if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
347-
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
346+
def resultTypeSym = unapp.symbol.info.resultType.typeSymbol
347+
348+
def isSyntheticJavaRecordUnapply(sym: Symbol) =
349+
// Since the `unapply` symbol is marked as inline, the `Typer` wraps the body of the `unapply` in a separate
350+
// anonymous class. The result type alone is not enough to distinguish that we're calling the synthesized unapply —
351+
// we could have defined a separate `unapply` method returning a Java record somewhere, hence we resort to using
352+
// the `coord`.
353+
sym.is(Synthetic) && sym.isAnonymousClass && {
354+
val resultSym = resultTypeSym
355+
// TODO: Can a user define a separate unapply function in Java?
356+
val unapplyFn = resultSym.linkedClass.info.decl(nme.unapply)
357+
// TODO: This is nasty, can we add an attachment on the anonymous function for a prior link?
358+
defn.isJavaRecordClass(resultSym) && unapplyFn.symbol.coord == sym.coord
359+
}
360+
361+
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
362+
def recordSel(sym: Symbol) = tupleSel(sym).appliedToTermArgs(Nil)
363+
364+
if (isSyntheticJavaRecordUnapply(unapp.symbol.owner))
365+
val components = resultTypeSym.javaRecordComponents.map(recordSel)
366+
matchArgsPlan(components, args, onSuccess)
367+
else if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
348368
val isGenericTuple = defn.isTupleClass(caseClass) &&
349369
!defn.isTupleNType(tree.tpe match { case tp: OrType => tp.join case tp => tp }) // widen even hard unions, to see if it's a union of tuples
350370
val components = if isGenericTuple then caseAccessors.indices.toList.map(tupleApp(_, ref(scrutinee))) else caseAccessors.map(tupleSel)

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
521521
def computeFromCaseClass: (Type, List[Type]) =
522522
val (baseRef, baseInfo) =
523523
val rawRef = caseClass.typeRef
524+
// TODO: HERE!!
524525
val rawInfo = caseClass.primaryConstructor.info
525526
optInfo match
526527
case Some(info) =>

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ object Applications {
4949
ref.info.widenExpr.annotatedToRepeated
5050
}
5151

52+
// TODO: Error here
53+
def isJavaRecordMatch(tp: Type, numArgs: Int, errorPos: SrcPos = NoSourcePosition)(using Context): Boolean =
54+
defn.isJavaRecordClass(tp.typeSymbol)
55+
5256
/** Does `tp` fit the "product match" conditions as an unapply result type
5357
* for a pattern with `numArgs` subpatterns?
5458
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
@@ -108,6 +112,20 @@ object Applications {
108112
if (isValid) elemTp else NoType
109113
}
110114

115+
def javaRecordComponentTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
116+
117+
val params = tp.typeSymbol.javaRecordComponents.map(_.info.resultType)
118+
119+
tp match
120+
case tp: AppliedType =>
121+
val argsIter = tp.args.iterator
122+
for (param <- params) yield param match
123+
case param if param.typeSymbol.isTypeParam => argsIter.next()
124+
case param => param
125+
case _ => params
126+
127+
}
128+
111129
def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
112130
val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos)
113131
sels.takeWhile(_.exists).toList
@@ -192,6 +210,8 @@ object Applications {
192210
// which is better than the message in `fail`.
193211
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
194212
foldApplyTupleType(unapplyResult)
213+
else if (isJavaRecordMatch(unapplyResult, args.length, pos)) then
214+
javaRecordComponentTypes(unapplyResult, pos)
195215
else fail
196216
}
197217
}

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,6 @@ class Namer { typer: Typer =>
880880
*/
881881
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =
882882

883-
def isJavaRecord(owner: Symbol) =
884-
owner.is(JavaDefined) && owner.derivesFrom(defn.JavaRecordClass)
885-
886883
def isCaseClassOrCompanion(owner: Symbol) =
887884
owner.isClass && {
888885
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
@@ -907,7 +904,7 @@ class Namer { typer: Typer =>
907904
)
908905
||
909906
// remove synthetic constructor or method of a java Record if it clashes with a non-synthetic constructor
910-
(isJavaRecord(denot.owner)
907+
(defn.isJavaRecordClass(denot.owner)
911908
&& denot.is(Method)
912909
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
913910
)

tests/run-java16+/java-records/R3.java

Whitespace-only changes.

0 commit comments

Comments
 (0)