Skip to content

Add support for Java records in patterns. #19577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
55 changes: 52 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -495,14 +495,22 @@ object desugar {
case Select(qual, tpnme.AnyVal) => isScala(qual)
case _ => false
}

def isScala(tree: Tree): Boolean = tree match {
case Ident(nme.scala) => true
case Select(Ident(nme.ROOTPKG), nme.scala) => true
case _ => false
}

def isRecord(tree: Tree): Boolean = tree match {
case Select(Select(Select(Ident(nme.ROOTPKG), nme.java), nme.lang), tpnme.Record) => true
case _ => false
}

def namePos = cdef.sourcePos.withSpan(cdef.nameSpan)

val isJavaRecord = mods.is(JavaDefined) && parents.exists(isRecord)

val isObject = mods.is(Module)
val isCaseClass = mods.is(Case) && !isObject
val isCaseObject = mods.is(Case) && isObject
Expand Down Expand Up @@ -769,6 +777,11 @@ object desugar {

val companionMembers = defaultGetters ::: enumCases

def tupleApply(params: List[untpd.Tree]): untpd.Apply = {
val fun = Select(Ident(nme.scala), s"${StdNames.str.Tuple}$arity".toTermName)
Apply(fun, params)
}

// The companion object definitions, if a companion is needed, Nil otherwise.
// companion definitions include:
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
Expand Down Expand Up @@ -801,9 +814,8 @@ object desugar {
case vparam :: Nil =>
Apply(scalaDot(nme.Option), Select(Ident(unapplyParamName), vparam.name))
case vparams =>
val tupleApply = Select(Ident(nme.scala), s"Tuple$arity".toTermName)
val members = vparams.map(vparam => Select(Ident(unapplyParamName), vparam.name))
Apply(scalaDot(nme.Option), Apply(tupleApply, members))
val members = vparams.map(param => Select(Ident(unapplyParamName), param.name))
Apply(scalaDot(nme.Option), tupleApply(members))

val hasRepeatedParam = constrVparamss.head.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
Expand Down Expand Up @@ -832,6 +844,43 @@ object desugar {
companionDefs(anyRef, companionMembers)
else if isValueClass && !isObject then
companionDefs(anyRef, Nil)
else if (isJavaRecord) {

/** Get the canonical constructor of the Java record.
*
* Java classes have a dummy constructor; see [[JavaParsers.makeTemplate]] for
* more details
*/
def canonicalConstructor(impl: Template): DefDef = {
impl.body.collectFirst {
case ddef: DefDef if ddef.name.isConstructorName && ddef.mods.is(Synthetic) =>
ddef
}.get
}

val constr1 = canonicalConstructor(impl)
val tParams = constr1.leadingTypeParams
val vParams = asTermOnly(constr1.trailingParamss).head
val arity = vParams.length

val classTypeRef = appliedRef(classTycon)

val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)

val unapplyRHS =
if (arity == 0) Literal(Constant(true))
else Ident(unapplyParam.name)

val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()

val unapplyMeth = DefDef(
nme.unapply,
joinParams(derivedTparams, (unapplyParam :: Nil) :: Nil),
unapplyResTp,
unapplyRHS
).withMods(synthetic | Inline)
companionDefs(anyRef, unapplyMeth :: Nil)
}
else Nil

enumCompanionRef match {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,9 @@ class Definitions {
@tu lazy val RuntimeTuples_isInstanceOfEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfEmptyTuple")
@tu lazy val RuntimeTuples_isInstanceOfNonEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfNonEmptyTuple")

@tu lazy val JavaRecordReflectMirrorTypeRef: TypeRef = requiredClassRef("scala.runtime.JavaRecordMirror")
@tu lazy val JavaRecordReflectMirrorModule: Symbol = requiredModule("scala.runtime.JavaRecordMirror")

@tu lazy val TupledFunctionTypeRef: TypeRef = requiredClassRef("scala.util.TupledFunction")
def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass
def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions")
Expand Down Expand Up @@ -1636,6 +1639,7 @@ class Definitions {
def isAbstractFunctionClass(cls: Symbol): Boolean = isVarArityClass(cls, str.AbstractFunction)
def isTupleClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Tuple)
def isProductClass(cls: Symbol): Boolean = isVarArityClass(cls, str.Product)
def isJavaRecordClass(cls: Symbol): Boolean = cls.is(JavaDefined) && cls.derivesFrom(JavaRecordClass)

def isBoxedUnitClass(cls: Symbol): Boolean =
cls.isClass && (cls.owner eq ScalaRuntimePackageClass) && cls.name == tpnme.BoxedUnit
Expand Down
17 changes: 14 additions & 3 deletions compiler/src/dotty/tools/dotc/core/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ class SymUtils:
def canAccessCtor: Boolean =
def isAccessible(sym: Symbol): Boolean = ctx.owner.isContainedIn(sym)
def isSub(sym: Symbol): Boolean = ctx.owner.ownersIterator.exists(_.derivesFrom(sym))
val ctor = self.primaryConstructor
val ctor = if defn.isJavaRecordClass(self) then self.javaCanonicalConstructor else self.primaryConstructor
(!ctor.isOneOf(Private | Protected) || isSub(self)) // we cant access the ctor because we do not extend cls
&& (!ctor.privateWithin.exists || isAccessible(ctor.privateWithin)) // check scope is compatible


def companionMirror = self.useCompanionAsProductMirror
if (!self.is(CaseClass)) "it is not a case class"

if (!(self.is(CaseClass) || defn.isJavaRecordClass(self))) "it is not a case class or record class"
else if (self.is(Abstract)) "it is an abstract class"
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
else if self.isDerivedValueClass then "it is a value class"
Expand Down Expand Up @@ -146,7 +147,7 @@ class SymUtils:
&& (!self.is(Method) || self.is(Accessor))

def useCompanionAsProductMirror(using Context): Boolean =
self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case)
self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case) && !defn.isJavaRecordClass(self)

def useCompanionAsSumMirror(using Context): Boolean =
def companionExtendsSum(using Context): Boolean =
Expand Down Expand Up @@ -249,6 +250,16 @@ class SymUtils:
def caseAccessors(using Context): List[Symbol] =
self.info.decls.filter(_.is(CaseAccessor))

// TODO: I'm convinced that we need to introduce a flag to get the canonical constructor.
// we should also check whether the names are erased in the ctor. If not, we should
// be able to infer the components directly from the constructor.
def javaCanonicalConstructor(using Context): Symbol =
self.info.decls.filter(_.isConstructor).tail.head

// TODO: Check if `Synthetic` is stamped properly
def javaRecordComponents(using Context): List[Symbol] =
self.info.decls.filter(sym => sym.is(Synthetic) && sym.is(Method) && !sym.isConstructor)

def getter(using Context): Symbol =
if (self.isGetter) self else accessorNamed(self.asTerm.name.getterName)

Expand Down
14 changes: 12 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,18 @@ object PatternMatcher {
.select(defn.RuntimeTuples_apply)
.appliedTo(receiver, Literal(Constant(i)))

if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
def resultTypeSym = unapp.symbol.info.resultType.typeSymbol

// TODO: Check Scala -> Java, erased?
def isJavaRecordUnapply(sym: Symbol) = defn.isJavaRecordClass(resultTypeSym)
def tupleSel(sym: Symbol) = ref(scrutinee).select(sym)
def recordSel(sym: Symbol) = tupleSel(sym).appliedToTermArgs(Nil)

// TODO: Move this to the correct location
if (isJavaRecordUnapply(unapp.symbol.owner))
val components = resultTypeSym.javaRecordComponents.map(recordSel)
matchArgsPlan(components, args, onSuccess)
else if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length)
val isGenericTuple = defn.isTupleClass(caseClass) &&
!defn.isTupleNType(tree.tpe match { case tp: OrType => tp.join case tp => tp }) // widen even hard unions, to see if it's a union of tuples
val components = if isGenericTuple then caseAccessors.indices.toList.map(tupleApp(_, ref(scrutinee))) else caseAccessors.map(tupleSel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
def computeFromCaseClass: (Type, List[Type]) =
val (baseRef, baseInfo) =
val rawRef = caseClass.typeRef
// TODO: HERE!!
val rawInfo = caseClass.primaryConstructor.info
optInfo match
case Some(info) =>
Expand Down
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ object Applications {
ref.info.widenExpr.annotatedToRepeated
}

// TODO: Error here
def isJavaRecordMatch(tp: Type, numArgs: Int, errorPos: SrcPos = NoSourcePosition)(using Context): Boolean =
defn.isJavaRecordClass(tp.typeSymbol)

/** Does `tp` fit the "product match" conditions as an unapply result type
* for a pattern with `numArgs` subpatterns?
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
Expand Down Expand Up @@ -108,6 +112,20 @@ object Applications {
if (isValid) elemTp else NoType
}

def javaRecordComponentTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {

val params = tp.typeSymbol.javaRecordComponents.map(_.info.resultType)

tp match
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is surely too standard to need to do it by hand...

case tp: AppliedType =>
val argsIter = tp.args.iterator
for (param <- params) yield param match
case param if param.typeSymbol.isTypeParam => argsIter.next()
case param => param
case _ => params

}

def productSelectorTypes(tp: Type, errorPos: SrcPos)(using Context): List[Type] = {
val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos)
sels.takeWhile(_.exists).toList
Expand Down Expand Up @@ -192,6 +210,8 @@ object Applications {
// which is better than the message in `fail`.
else if unapplyResult.derivesFrom(defn.NonEmptyTupleClass) then
foldApplyTupleType(unapplyResult)
else if (isJavaRecordMatch(unapplyResult, args.length, pos)) then
javaRecordComponentTypes(unapplyResult, pos)
else fail
}
}
Expand Down
5 changes: 1 addition & 4 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,6 @@ class Namer { typer: Typer =>
*/
private def invalidateIfClashingSynthetic(denot: SymDenotation): Unit =

def isJavaRecord(owner: Symbol) =
owner.is(JavaDefined) && owner.derivesFrom(defn.JavaRecordClass)

def isCaseClassOrCompanion(owner: Symbol) =
owner.isClass && {
if (owner.is(Module)) owner.linkedClass.is(CaseClass)
Expand All @@ -907,7 +904,7 @@ class Namer { typer: Typer =>
)
||
// remove synthetic constructor or method of a java Record if it clashes with a non-synthetic constructor
(isJavaRecord(denot.owner)
(defn.isJavaRecordClass(denot.owner)
&& denot.is(Method)
&& denot.owner.unforcedDecls.lookupAll(denot.name).exists(c => c != denot.symbol && c.info.matches(denot.info))
)
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
def newTupleMirror(arity: Int): Tree =
New(defn.RuntimeTupleMirrorTypeRef, Literal(Constant(arity)) :: Nil)

def newJavaRecordReflectMirror(tpe: Type) =
Copy link
Contributor Author

@yilinwei yilinwei Feb 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mirror is of limited use and requires the user to check the type at the derivation site for the mirrored mono type. This isn't the end of the world - we could add an extension method which uses the RecordComponent to invoke the accessor, but can we add a method to the mirror like productElement(a: A)(i: Int) which allows users which derive with the mirror to hide the implementation details?

Maybe this doesn't matter because actual practitioners would use the dotty reflect to select the field labels.

ref(defn.JavaRecordReflectMirrorModule)
.select(nme.apply)
.appliedToType(tpe)
.appliedTo(clsOf(tpe))

def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors =
val accessors = cls.caseAccessors
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
Expand All @@ -427,6 +433,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
}
val mirrorRef =
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
else if defn.isJavaRecordClass(cls) then newJavaRecordReflectMirror(cls.typeRef)
else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22
else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span)
withNoErrors(mirrorRef.cast(mirrorType).withSpan(span))
Expand Down Expand Up @@ -458,6 +465,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val reason = s"it reduces to a tuple with arity $arity, expected arity <= $maxArity"
withErrors(i"${defn.PairClass} is not a generic product because $reason")
case MirrorSource.ClassSymbol(pre, cls) =>

if cls.isGenericProduct then
if ctx.runZincPhases then
// The mirror should be resynthesized if the constructor of the
Expand Down
13 changes: 10 additions & 3 deletions compiler/test/dotty/tools/dotc/CompilationTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class CompilationTests {
import CompilationTests._
import CompilationTest.aggregateTests

def `isJava16+`: Boolean = scala.util.Properties.isJavaAtLeast("16")

// Positive tests ------------------------------------------------------------

@Test def pos: Unit = {
Expand All @@ -50,7 +52,7 @@ class CompilationTests {
else Nil
)

if scala.util.Properties.isJavaAtLeast("16") then
if `isJava16+` then
tests ::= compileFilesInDir("tests/pos-java16+", defaultOptions.and("-Ysafe-init"))

aggregateTests(tests*).checkCompile()
Expand Down Expand Up @@ -155,13 +157,18 @@ class CompilationTests {

@Test def runAll: Unit = {
implicit val testGroup: TestGroup = TestGroup("runAll")
aggregateTests(
var tests = List(
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init")),
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
compileFilesInDir("tests/run-custom-args/captures", allowDeepSubtypes.and("-language:experimental.captureChecking")),
// Run tests for legacy lazy vals.
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init", "-Ylegacy-lazy-vals", "-Ycheck-constraint-deps"), FileFilter.include(TestSources.runLazyValsAllowlist)),
).checkRuns()
)

if `isJava16+` then
tests ::= compileFilesInDir("tests/run-java16+", defaultOptions.and("-Ysafe-init"))

aggregateTests(tests*).checkRuns()
}

// Generic java signatures tests ---------------------------------------------
Expand Down
34 changes: 34 additions & 0 deletions library/src/scala/runtime/JavaRecordMirror.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package scala.runtime

import java.lang.Record
import java.lang.reflect.Constructor
import scala.reflect.ClassTag

// TODO: Rename to JavaRecordReflectMirror
object JavaRecordMirror:
Copy link
Contributor Author

@yilinwei yilinwei Feb 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this going to be a big deal? It means that compiling the scala library would need JDK 16+ but could target older versions, but having this on the classpath and not loaded should not be an issue? Maybe I could exclude the source, if records are not supported but would that cause issues with the current publishing process?


def apply[T <: Record](clazz: Class[T]): JavaRecordMirror[T] =
val components = clazz.getRecordComponents.nn
val constructorTypes = components.map(_.nn.getType.nn)
val constr = clazz.getDeclaredConstructor(constructorTypes*).nn
new JavaRecordMirror(components.length, constr)

def of[T <: Record : ClassTag]: JavaRecordMirror[T] =
JavaRecordMirror(summon[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]])

// TODO: Is a constructor serializable?
final class JavaRecordMirror[T] private(arity: Int, constr: Constructor[T]) extends scala.deriving.Mirror.Product with Serializable:

override type MirroredMonoType <: Record

final def fromProduct(product: Product): MirroredMonoType =
if product.productArity != arity then
throw IllegalArgumentException(s"expected Product with $arity elements, got ${product.productArity}")
else
// TODO: Check this byte code, we want to unroll to give a happy medium between JIT'ing and having tons of extra classes
val t = arity match
case 0 => constr.newInstance()
case 1 => constr.newInstance(product.productElement(0))
case 2 => constr.newInstance(product.productElement(0), product.productElement(1))

t.nn.asInstanceOf[MirroredMonoType]
5 changes: 5 additions & 0 deletions tests/pos-java16+/java-records-mirror/FromScala.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import scala.deriving.Mirror

object C:
def useR2: Unit =
summon[Mirror.Of[R2]]
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-mirror/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R2(int i, String s) {}
31 changes: 31 additions & 0 deletions tests/pos-java16+/java-records-patmatch/FromScala.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
object C:

def useR0: Unit =
val r = R0()

// unapply in valdef
val R0() = r

// unapply in patmatch
r match {
case R0() =>
}


def useR1: Int =
val r = R1(1, "foo")

// unapply in valdef
val R1(i, _) = r
val a: Int = i

// unapply in patmatch
r match {
case R1(i, _) => i
}

def useR2: String =
val r = R2("asd")
r match {
case R2(s) => s
}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R0.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R0() {}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R1(int i, String s) {}
1 change: 1 addition & 0 deletions tests/pos-java16+/java-records-patmatch/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R2<T>(T t) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R0.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R0() {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R1.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R1(int i) {}
1 change: 1 addition & 0 deletions tests/run-java16+/java-records/R2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
public record R2<T>(T t, int i) {}
Empty file.
Loading