diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 4d41ee55f3a9..a2dd3e450a50 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -187,7 +187,7 @@ class CheckCaptures extends Recheck, SymTransformer: capt.println(i"solving $t") refs.solve() traverse(parent) - case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) => + case t @ defn.RefinedFunctionOf(rinfo) => traverse(rinfo) case tp: TypeVar => case tp: TypeRef => @@ -769,7 +769,7 @@ class CheckCaptures extends Recheck, SymTransformer: case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) => adaptFun(actual, args.init, args.last, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1)) - case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) => + case actual @ defn.RefinedFunctionOf(rinfo: MethodType) => // TODO Find a way to combine handling of generic and dependent function types (here and elsewhere) adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox, (aargs1, ares1) => @@ -779,11 +779,11 @@ class CheckCaptures extends Recheck, SymTransformer: adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedLambdaType(paramInfos = aargs1, resType = ares1)) - case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) => + case actual @ defn.RefinedFunctionOf(rinfo: PolyType) => adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox, ares1 => val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1) - val actual1 = actual.derivedRefinedType(p, nme, rinfo1) + val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1) actual1 ) case _ => @@ -996,7 +996,7 @@ class CheckCaptures extends Recheck, SymTransformer: case CapturingType(parent, refs) => healCaptureSet(refs) traverse(parent) - case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => traverse(rinfo) case tp: TermLambda => val saved = allowed diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 463919e85893..758486532512 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -54,7 +54,7 @@ extends tpd.TreeTraverser: val boxedRes = recur(res) if boxedRes eq res then tp else tp1.derivedAppliedType(tycon, args.init :+ boxedRes) - case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) => + case tp1 @ defn.RefinedFunctionOf(rinfo: MethodType) => val boxedRinfo = recur(rinfo) if boxedRinfo eq rinfo then tp else boxedRinfo.toFunctionType(alwaysDependent = true) @@ -231,7 +231,7 @@ extends tpd.TreeTraverser: tp.derivedAppliedType(tycon1, args1 :+ res1) else tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg))) - case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => val rinfo1 = apply(rinfo) if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true) else tp diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index cebc2cb67c45..2411899b1740 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1132,6 +1132,20 @@ class Definitions { case _ => None } + object RefinedFunctionOf { + /** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = { + tpe.refinedInfo match + case mt: MethodOrPoly + if tpe.refinedName == nme.apply + && (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) => + Some(mt) + case _ => None + } + } + object PolyFunctionOf { /** Matches a refined `PolyFunction` type and extracts the apply info. * diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 755214954e61..e09778786290 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -4055,8 +4055,8 @@ object Types { tp.derivedAppliedType(tycon, addInto(args.head) :: Nil) case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) => wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last))) - case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) => - wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo))) + case tp @ defn.RefinedFunctionOf(rinfo) => + wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo))) case tp: MethodOrPoly => tp.derivedLambdaType(resType = addInto(tp.resType)) case ExprType(resType) =>