Skip to content

Map opaque types in arguments of inlined calls to proxies #12922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 68 additions & 37 deletions compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {

private val methPart = funPart(call)
private val callTypeArgs = typeArgss(call).flatten
private val callValueArgss = termArgss(call)
private val rawCallValueArgss = termArgss(call)
private val inlinedMethod = methPart.symbol
private val inlineCallPrefix =
qualifier(methPart).orElse(This(inlinedMethod.enclosingClass.asClass))
Expand Down Expand Up @@ -581,31 +581,17 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
case (from, to) if from.symbol == ref.symbol && from =:= ref => to
}

/** If `binding` contains TermRefs that refer to objects with opaque
* type aliases, add proxy definitions that expose these aliases
* and substitute such TermRefs with theproxies. Example from pos/opaque-inline1.scala:
*
* object refined:
* opaque type Positive = Int
* inline def Positive(value: Int): Positive = f(value)
* def f(x: Positive): Positive = x
* def run: Unit = { val x = 9; val nine = refined.Positive(x) }
*
* This generates the following proxies:
*
* val $proxy1: refined.type{type Positive = Int} =
* refined.$asInstanceOf$[refined.type{type Positive = Int}]
* val refined$_this: ($proxy1 : refined.type{Positive = Int}) =
* $proxy1
*
* and every reference to `refined` in the inlined expression is replaced by
* `refined_$this`.
/** If `tp` contains TermRefs that refer to objects with opaque
* type aliases, add proxy definitions to `opaqueProxies` that expose these aliases.
*/
def accountForOpaques(binding: ValDef)(using Context): ValDef =
binding.symbol.info.foreachPart {
def addOpaqueProxies(tp: Type, span: Span, forThisProxy: Boolean)(using Context): Unit =
tp.foreachPart {
case ref: TermRef =>
for cls <- ref.widen.classSymbols do
if cls.containsOpaques && mapRef(ref).isEmpty then
if cls.containsOpaques
&& (forThisProxy || inlinedMethod.isContainedIn(cls))
&& mapRef(ref).isEmpty
then
def openOpaqueAliases(selfType: Type): List[(Name, Type)] = selfType match
case RefinedType(parent, rname, TypeAlias(alias)) =>
val opaq = cls.info.member(rname).symbol
Expand All @@ -620,27 +606,67 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
RefinedType(parent, refinement._1, TypeAlias(refinement._2))
)
val refiningSym = newSym(InlineBinderName.fresh(), Synthetic, refinedType).asTerm
val refiningDef = ValDef(refiningSym, tpd.ref(ref).cast(refinedType)).withSpan(binding.span)
inlining.println(i"add opaque alias proxy $refiningDef")
val refiningDef = ValDef(refiningSym, tpd.ref(ref).cast(refinedType)).withSpan(span)
inlining.println(i"add opaque alias proxy $refiningDef for $ref in $tp")
bindingsBuf += refiningDef
opaqueProxies += ((ref, refiningSym.termRef))
case _ =>
}

/** Map all TermRefs that match left element in `opaqueProxies` to the
* corresponding right element.
*/
val mapOpaques = TreeTypeMap(
typeMap = new TypeMap:
override def stopAt = StopAt.Package
def apply(t: Type) = mapOver {
t match
case ref: TermRef => mapRef(ref).getOrElse(ref)
case _ => t
}
)

/** If `binding` contains TermRefs that refer to objects with opaque
* type aliases, add proxy definitions that expose these aliases
* and substitute such TermRefs with theproxies. Example from pos/opaque-inline1.scala:
*
* object refined:
* opaque type Positive = Int
* inline def Positive(value: Int): Positive = f(value)
* def f(x: Positive): Positive = x
* def run: Unit = { val x = 9; val nine = refined.Positive(x) }
*
* This generates the following proxies:
*
* val $proxy1: refined.type{type Positive = Int} =
* refined.$asInstanceOf$[refined.type{type Positive = Int}]
* val refined$_this: ($proxy1 : refined.type{Positive = Int}) =
* $proxy1
*
* and every reference to `refined` in the inlined expression is replaced by
* `refined_$this`.
*/
def accountForOpaques(binding: ValDef)(using Context): ValDef =
addOpaqueProxies(binding.symbol.info, binding.span, forThisProxy = true)
if opaqueProxies.isEmpty then binding
else
val mapType = new TypeMap:
override def stopAt = StopAt.Package
def apply(t: Type) = mapOver {
t match
case ref: TermRef => mapRef(ref).getOrElse(ref)
case _ => t
}
binding.symbol.info = mapType(binding.symbol.info)
val mapTree = TreeTypeMap(typeMap = mapType)
mapTree.transform(binding).asInstanceOf[ValDef]
binding.symbol.info = mapOpaques.typeMap(binding.symbol.info)
mapOpaques.transform(binding).asInstanceOf[ValDef]
.showing(i"transformed this binding exposing opaque aliases: $result", inlining)
end accountForOpaques

/** If value argument contains references to objects that contain opaque types,
* map them to their opaque proxies.
*/
def mapOpaquesInValueArg(arg: Tree)(using Context): Tree =
val argType = arg.tpe.widen
addOpaqueProxies(argType, arg.span, forThisProxy = false)
if opaqueProxies.nonEmpty then
val mappedType = mapOpaques.typeMap(argType)
if mappedType ne argType then arg.cast(AndType(arg.tpe, mappedType))
else arg
else arg

private def canElideThis(tpe: ThisType): Boolean =
inlineCallPrefix.tpe == tpe && ctx.owner.isContainedIn(tpe.cls)
|| tpe.cls.isContainedIn(inlinedMethod)
Expand Down Expand Up @@ -773,7 +799,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
def inlined(sourcePos: SrcPos): Tree = {

// Special handling of `requireConst` and `codeOf`
callValueArgss match
rawCallValueArgss match
case (arg :: Nil) :: Nil =>
if inlinedMethod == defn.Compiletime_requireConst then
arg match
Expand Down Expand Up @@ -823,6 +849,11 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
case TypeApply(fn, _) => paramTypess(fn, acc)
case _ => acc

val callValueArgss = rawCallValueArgss.nestedMapConserve(mapOpaquesInValueArg)

if callValueArgss ne rawCallValueArgss then
inlining.println(i"mapped value args = ${callValueArgss.flatten}%, %")

// Compute bindings for all parameters, appending them to bindingsBuf
if !computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss, paramTypess(call, Nil)) then
return call
Expand Down Expand Up @@ -1254,7 +1285,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
case fail: Implicits.SearchFailureType =>
false
case _ =>
//inliner.println(i"inferred implicit $sym: ${sym.info} with $evidence: ${evidence.tpe.widen}, ${evCtx.gadt.constraint}, ${evCtx.typerState.constraint}")
//inlining.println(i"inferred implicit $sym: ${sym.info} with $evidence: ${evidence.tpe.widen}, ${evCtx.gadt.constraint}, ${evCtx.typerState.constraint}")
newTermBinding(sym, evidence)
true
}
Expand Down
8 changes: 8 additions & 0 deletions tests/run/i12914.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
ASD
asd
ASD
asd
ASD
asd
aSdaSdaSd
aSdaSdaSd
27 changes: 27 additions & 0 deletions tests/run/i12914.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

class opq:
opaque type Str = java.lang.String
object Str:
def apply(s: String): Str = s
inline def lower(s: Str): String = s.toLowerCase
extension (s: Str)
transparent inline def upper: String = s.toUpperCase
inline def concat(xs: List[Str]): Str = String(xs.flatten.toArray)
transparent inline def concat2(xs: List[Str]): Str = String(xs.flatten.toArray)


@main def Test =
val opq = new opq()
import opq.*
val a: Str = Str("aSd")
println(a.upper)
println(opq.lower(a))
def b: Str = Str("aSd")
println(b.upper)
println(opq.lower(b))
def c(): Str = Str("aSd")
println(c().upper)
println(opq.lower(c()))
println(opq.concat(List(a, b, c())))
println(opq.concat2(List(a, b, c())))