Skip to content

Commit 9cb0649

Browse files
committed
recursively check for product ctor accessibility
1 parent 965f164 commit 9cb0649

File tree

6 files changed

+93
-56
lines changed

6 files changed

+93
-56
lines changed

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,26 @@ object SymUtils:
8282
* parameter section.
8383
*/
8484
def whyNotGenericProduct(using Context): String =
85+
/** for a case class, if it will have an anonymous mirror,
86+
* check that its constructor can be accessed
87+
* from the calling scope.
88+
*/
89+
def canAccessCtor: Boolean =
90+
def isAccessible(sym: Symbol): Boolean = ctx.owner.isContainedIn(sym)
91+
def isSub(sym: Symbol): Boolean = ctx.owner.ownersIterator.exists(_.derivesFrom(sym))
92+
val ctor = self.primaryConstructor
93+
(!ctor.isOneOf(Private | Protected) || isSub(self)) // we cant access the ctor because we do not extend cls
94+
&& (!ctor.privateWithin.exists || isAccessible(ctor.privateWithin)) // check scope is compatible
95+
96+
97+
val companionMirror = self.useCompanionAsProductMirror
8598
if (!self.is(CaseClass)) "it is not a case class"
8699
else if (self.is(Abstract)) "it is an abstract class"
87100
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
88101
else if (isDerivedValueClass(self)) "it is a value class"
102+
else if (!(companionMirror || canAccessCtor)) s"the constructor of $self is innaccessible from the calling scope."
89103
else ""
104+
end whyNotGenericProduct
90105

91106
def isGenericProduct(using Context): Boolean = whyNotGenericProduct.isEmpty
92107

@@ -120,6 +135,9 @@ object SymUtils:
120135
self.isOneOf(FinalOrInline, butNot = Mutable)
121136
&& (!self.is(Method) || self.is(Accessor))
122137

138+
def useCompanionAsProductMirror(using Context): Boolean =
139+
self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case)
140+
123141
def useCompanionAsSumMirror(using Context): Boolean =
124142
def companionExtendsSum(using Context): Boolean =
125143
self.linkedClass.isSubClass(defn.Mirror_SumClass)
@@ -145,39 +163,39 @@ object SymUtils:
145163
* and also the location of the generated mirror.
146164
* - all of its children are generic products, singletons, or generic sums themselves.
147165
*/
148-
def whyNotGenericSum(declScope: Symbol)(using Context): String =
166+
def whyNotGenericSum(using Context): String =
149167
if (!self.is(Sealed))
150168
s"it is not a sealed ${self.kindString}"
151169
else if (!self.isOneOf(AbstractOrTrait))
152170
"it is not an abstract class"
153171
else {
154172
val children = self.children
155173
val companionMirror = self.useCompanionAsSumMirror
156-
assert(!(companionMirror && (declScope ne self.linkedClass)))
157174
def problem(child: Symbol) = {
158175

159176
def isAccessible(sym: Symbol): Boolean =
160-
(self.isContainedIn(sym) && (companionMirror || declScope.isContainedIn(sym)))
177+
(self.isContainedIn(sym) && (companionMirror || ctx.owner.isContainedIn(sym)))
161178
|| sym.is(Module) && isAccessible(sym.owner)
162179

163180
if (child == self) "it has anonymous or inaccessible subclasses"
164181
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
165-
else if (!child.isClass) ""
182+
else if (!child.isClass) "" // its a singleton enum value
166183
else {
167184
val s = child.whyNotGenericProduct
168-
if (s.isEmpty) s
169-
else if (child.is(Sealed)) {
170-
val s = child.whyNotGenericSum(if child.useCompanionAsSumMirror then child.linkedClass else ctx.owner)
171-
if (s.isEmpty) s
185+
if s.isEmpty then s
186+
else if child.is(Sealed) then
187+
val s = child.whyNotGenericSum
188+
if s.isEmpty then s
172189
else i"its child $child is not a generic sum because $s"
173-
} else i"its child $child is not a generic product because $s"
190+
else
191+
i"its child $child is not a generic product because $s"
174192
}
175193
}
176194
if (children.isEmpty) "it does not have subclasses"
177195
else children.map(problem).find(!_.isEmpty).getOrElse("")
178196
}
179197

180-
def isGenericSum(declScope: Symbol)(using Context): Boolean = whyNotGenericSum(declScope).isEmpty
198+
def isGenericSum(using Context): Boolean = whyNotGenericSum.isEmpty
181199

182200
/** If this is a constructor, its owner: otherwise this. */
183201
final def skipConstructor(using Context): Symbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
594594
if (clazz.is(Module)) {
595595
if (clazz.is(Case)) makeSingletonMirror()
596596
else if (linked.isGenericProduct) makeProductMirror(linked)
597-
else if (linked.isGenericSum(clazz)) makeSumMirror(linked)
597+
else if (linked.isGenericSum) makeSumMirror(linked)
598598
else if (linked.is(Sealed))
599-
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(clazz)}")
599+
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum}")
600600
}
601601
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
602602
makeSingletonMirror()

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2626
/** Handlers to synthesize implicits for special types */
2727
type SpecialHandler = (Type, Span) => Context ?=> TreeWithErrors
2828
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]
29-
29+
3030
val synthesizedClassTag: SpecialHandler = (formal, span) =>
3131
formal.argInfos match
3232
case arg :: Nil =>
@@ -285,22 +285,6 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
285285
case OrType(tp1, tp2) => acceptable(tp1, cls) && acceptable(tp2, cls)
286286
case _ => tp.classSymbol eq cls
287287

288-
/** for a case class, if it will have an anonymous mirror,
289-
* check that its constructor can be accessed
290-
* from the calling scope.
291-
*/
292-
def canAccessCtor(cls: Symbol): Boolean =
293-
!genAnonyousMirror(cls) || {
294-
def isAccessible(sym: Symbol): Boolean = ctx.owner.isContainedIn(sym)
295-
def isSub(sym: Symbol): Boolean = ctx.owner.ownersIterator.exists(_.derivesFrom(sym))
296-
val ctor = cls.primaryConstructor
297-
(!ctor.isOneOf(Private | Protected) || isSub(cls)) // we cant access the ctor because we do not extend cls
298-
&& (!ctor.privateWithin.exists || isAccessible(ctor.privateWithin)) // check scope is compatible
299-
}
300-
301-
def genAnonyousMirror(cls: Symbol): Boolean =
302-
cls.is(Scala2x) || cls.linkedClass.is(Case)
303-
304288
def makeProductMirror(cls: Symbol): TreeWithErrors =
305289
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
306290
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
@@ -318,21 +302,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
318302
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
319303
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
320304
val mirrorRef =
321-
if (genAnonyousMirror(cls)) anonymousMirror(monoType, ExtendsProductMirror, span)
322-
else companionPath(mirroredType, span)
305+
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
306+
else anonymousMirror(monoType, ExtendsProductMirror, span)
323307
withNoErrors(mirrorRef.cast(mirrorType))
324308
end makeProductMirror
325309

326-
def getError(cls: Symbol): String =
327-
val reason = if !cls.isGenericProduct then
328-
i"because ${cls.whyNotGenericProduct}"
329-
else if !canAccessCtor(cls) then
330-
i"because the constructor of $cls is innaccessible from the calling scope."
331-
else
332-
""
333-
i"$cls is not a generic product $reason"
334-
end getError
335-
336310
mirroredType match
337311
case AndType(tp1, tp2) =>
338312
orElse(productMirror(tp1, formal, span), productMirror(tp2, formal, span))
@@ -349,21 +323,19 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
349323
withNoErrors(modulePath.cast(mirrorType))
350324
else
351325
val cls = mirroredType.classSymbol
352-
if acceptable(mirroredType, cls)
353-
&& cls.isGenericProduct
354-
&& canAccessCtor(cls)
355-
then
326+
val clsIsGenericProduct = cls.isGenericProduct
327+
if acceptable(mirroredType, cls) && clsIsGenericProduct then
356328
makeProductMirror(cls)
329+
else if !clsIsGenericProduct then
330+
(EmptyTree, List(i"$cls is not a generic product because ${cls.whyNotGenericProduct}"))
357331
else
358-
(EmptyTree, List(getError(cls)))
332+
EmptyTreeNoError
359333
end productMirror
360334

361335
private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =
362336

363337
val cls = mirroredType.classSymbol
364-
val useCompanion = cls.useCompanionAsSumMirror
365-
val declScope = if useCompanion then cls.linkedClass else ctx.owner
366-
val clsIsGenericSum = cls.isGenericSum(declScope)
338+
val clsIsGenericSum = cls.isGenericSum
367339

368340
def acceptable(tp: Type): Boolean = tp match
369341
case tp: TermRef => false
@@ -423,12 +395,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
423395
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
424396
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
425397
val mirrorRef =
426-
if useCompanion then companionPath(mirroredType, span)
398+
if cls.useCompanionAsSumMirror then companionPath(mirroredType, span)
427399
else anonymousMirror(monoType, ExtendsSumMirror, span)
428400
withNoErrors(mirrorRef.cast(mirrorType))
429401
else if !clsIsGenericSum then
430-
(EmptyTree, List(i"$cls is not a generic sum because ${cls.whyNotGenericSum(declScope)}"))
431-
else
402+
(EmptyTree, List(i"$cls is not a generic sum because ${cls.whyNotGenericSum}"))
403+
else
432404
EmptyTreeNoError
433405
end sumMirror
434406

@@ -595,7 +567,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
595567
tp.baseType(cls)
596568
val base = baseWithRefinements(formal)
597569
val result =
598-
if (base <:< formal.widenExpr)
570+
if (base <:< formal.widenExpr)
599571
// With the subtype test we enforce that the searched type `formal` is of the right form
600572
handler(base, span)
601573
else EmptyTreeNoError
@@ -609,19 +581,19 @@ end Synthesizer
609581

610582
object Synthesizer:
611583

612-
/** Tuple used to store the synthesis result with a list of errors. */
584+
/** Tuple used to store the synthesis result with a list of errors. */
613585
type TreeWithErrors = (Tree, List[String])
614586
private def withNoErrors(tree: Tree): TreeWithErrors = (tree, List.empty)
615587

616588
private val EmptyTreeNoError: TreeWithErrors = withNoErrors(EmptyTree)
617589

618590
private def orElse(treeWithErrors1: TreeWithErrors, treeWithErrors2: => TreeWithErrors): TreeWithErrors = treeWithErrors1 match
619-
case (tree, errors) if tree eq genericEmptyTree =>
591+
case (tree, errors) if tree eq genericEmptyTree =>
620592
val (tree2, errors2) = treeWithErrors2
621593
(tree2, errors ::: errors2)
622594
case _ => treeWithErrors1
623595

624-
private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
596+
private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
625597
case (tree, _) if tree eq genericEmptyTree => treeWithErrors
626598
case (tree, _) => withNoErrors(tree)
627599

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import scala.deriving.Mirror
2+
3+
package lib {
4+
sealed trait Foo
5+
object Foo // normally, would cache a mirror if one exists.
6+
case class Bar private[lib] () extends Foo
7+
case object Bar // force mirror for Bar to be anonymous.
8+
9+
10+
object CallSiteSucceed {
11+
val mFoo = summon[Mirror.SumOf[lib.Foo]] // ok
12+
val mBar = summon[Mirror.ProductOf[lib.Bar]] // ok
13+
}
14+
15+
}
16+
17+
package app {
18+
19+
object MustFail {
20+
// we are outsite of accessible scope for Bar's ctor, so this should fail.
21+
22+
val mFoo = summon[Mirror.SumOf[lib.Foo]] // error
23+
val mBar = summon[Mirror.ProductOf[lib.Bar]] // error
24+
}
25+
26+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package lib
2+
3+
import scala.deriving.Mirror
4+
5+
sealed trait Foo
6+
object Foo // normally, would cache a mirror if one exists.
7+
case class Bar private[lib] () extends Foo
8+
case object Bar // force mirror for Bar to be anonymous.
9+
10+
11+
object CallSiteSucceed {
12+
val mFoo = summon[Mirror.SumOf[Foo]] // ok
13+
val mBar = summon[Mirror.ProductOf[Bar]] // ok
14+
val sampleBar = Bar()
15+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@main def Test =
2+
assert(lib.CallSiteSucceed.mFoo eq lib.Foo) // binary compatibility with 3.1
3+
assert(lib.CallSiteSucceed.mBar ne lib.Bar) // anonymous mirror
4+
5+
assert(lib.CallSiteSucceed.mFoo.ordinal(lib.CallSiteSucceed.sampleBar) == 0)
6+
assert(lib.CallSiteSucceed.mBar.fromProduct(EmptyTuple) == lib.CallSiteSucceed.sampleBar)

0 commit comments

Comments
 (0)