Skip to content

Commit 4d5d9e3

Browse files
authored
Merge pull request swiftlang#41792 from DougGregor/collapse-thunk-types
2 parents 8298c72 + 9c13426 commit 4d5d9e3

File tree

5 files changed

+330
-509
lines changed

5 files changed

+330
-509
lines changed

include/swift/SIL/TypeLowering.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,39 @@ CanSILFunctionType getNativeSILFunctionType(
12041204
Optional<SubstitutionMap> reqtSubs = None,
12051205
ProtocolConformanceRef witnessMethodConformance = ProtocolConformanceRef());
12061206

1207+
/// The thunk kinds used in the differentiation transform.
1208+
enum class DifferentiationThunkKind {
1209+
/// A reabstraction thunk.
1210+
///
1211+
/// Reabstraction thunks transform a function-typed value to another one with
1212+
/// different parameter/result abstraction patterns. This is identical to the
1213+
/// thunks generated by SILGen.
1214+
Reabstraction,
1215+
1216+
/// An index subset thunk.
1217+
///
1218+
/// An index subset thunk is used transform JVP/VJPs into a version that is
1219+
/// "wrt" fewer differentiation parameters.
1220+
/// - Differentials of thunked JVPs use zero for non-requested differentiation
1221+
/// parameters.
1222+
/// - Pullbacks of thunked VJPs discard results for non-requested
1223+
/// differentiation parameters.
1224+
IndexSubset
1225+
};
1226+
1227+
/// Build the type of a function transformation thunk.
1228+
CanSILFunctionType buildSILFunctionThunkType(
1229+
SILFunction *fn,
1230+
CanSILFunctionType &sourceType,
1231+
CanSILFunctionType &expectedType,
1232+
CanType &inputSubstType,
1233+
CanType &outputSubstType,
1234+
GenericEnvironment *&genericEnv,
1235+
SubstitutionMap &interfaceSubs,
1236+
CanType &dynamicSelfType,
1237+
bool withoutActuallyEscaping,
1238+
Optional<DifferentiationThunkKind> differentiationThunkKind = None);
1239+
12071240
} // namespace swift
12081241

12091242
namespace llvm {

include/swift/SILOptimizer/Differentiation/Thunk.h

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,41 +47,6 @@ class ADContext;
4747
// moved to a shared location.
4848
//===----------------------------------------------------------------------===//
4949

50-
/// The thunk kinds used in the differentiation transform.
51-
enum class DifferentiationThunkKind {
52-
/// A reabstraction thunk.
53-
///
54-
/// Reabstraction thunks transform a function-typed value to another one with
55-
/// different parameter/result abstraction patterns. This is identical to the
56-
/// thunks generated by SILGen.
57-
Reabstraction,
58-
59-
/// An index subset thunk.
60-
///
61-
/// An index subset thunk is used transform JVP/VJPs into a version that is
62-
/// "wrt" fewer differentiation parameters.
63-
/// - Differentials of thunked JVPs use zero for non-requested differentiation
64-
/// parameters.
65-
/// - Pullbacks of thunked VJPs discard results for non-requested
66-
/// differentiation parameters.
67-
IndexSubset
68-
};
69-
70-
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
71-
OpenedArchetypeType *openedExistential,
72-
GenericEnvironment *&genericEnv,
73-
SubstitutionMap &contextSubs,
74-
SubstitutionMap &interfaceSubs,
75-
ArchetypeType *&newArchetype);
76-
77-
/// Build the type of a function transformation thunk.
78-
CanSILFunctionType buildThunkType(SILFunction *fn,
79-
CanSILFunctionType &sourceType,
80-
CanSILFunctionType &expectedType,
81-
GenericEnvironment *&genericEnv,
82-
SubstitutionMap &interfaceSubs,
83-
bool withoutActuallyEscaping,
84-
DifferentiationThunkKind thunkKind);
8550

8651
/// Get or create a reabstraction thunk from `fromType` to `toType`, to be
8752
/// called in `caller`.

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,6 +2407,296 @@ CanSILFunctionType swift::getNativeSILFunctionType(
24072407
substConstant, reqtSubs, witnessMethodConformance, None);
24082408
}
24092409

2410+
/// Build a generic signature and environment for a re-abstraction thunk.
2411+
///
2412+
/// Most thunks share the generic environment with their original function.
2413+
/// The one exception is if the thunk type involves an open existential,
2414+
/// in which case we "promote" the opened existential to a new generic parameter.
2415+
///
2416+
/// \param SGF - the parent function
2417+
/// \param openedExistential - the opened existential to promote to a generic
2418+
// parameter, if any
2419+
/// \param inheritGenericSig - whether to inherit the generic signature from the
2420+
/// parent function.
2421+
/// \param genericEnv - the new generic environment
2422+
/// \param contextSubs - map old archetypes to new archetypes
2423+
/// \param interfaceSubs - map interface types to old archetypes
2424+
static CanGenericSignature
2425+
buildThunkSignature(SILFunction *fn,
2426+
bool inheritGenericSig,
2427+
OpenedArchetypeType *openedExistential,
2428+
GenericEnvironment *&genericEnv,
2429+
SubstitutionMap &contextSubs,
2430+
SubstitutionMap &interfaceSubs,
2431+
ArchetypeType *&newArchetype) {
2432+
auto *mod = fn->getModule().getSwiftModule();
2433+
auto &ctx = mod->getASTContext();
2434+
2435+
// If there's no opened existential, we just inherit the generic environment
2436+
// from the parent function.
2437+
if (openedExistential == nullptr) {
2438+
auto genericSig =
2439+
fn->getLoweredFunctionType()->getInvocationGenericSignature();
2440+
genericEnv = fn->getGenericEnvironment();
2441+
interfaceSubs = fn->getForwardingSubstitutionMap();
2442+
contextSubs = interfaceSubs;
2443+
return genericSig;
2444+
}
2445+
2446+
// Add the existing generic signature.
2447+
int depth = 0;
2448+
GenericSignature baseGenericSig;
2449+
if (inheritGenericSig) {
2450+
if (auto genericSig =
2451+
fn->getLoweredFunctionType()->getInvocationGenericSignature()) {
2452+
baseGenericSig = genericSig;
2453+
depth = genericSig.getGenericParams().back()->getDepth() + 1;
2454+
}
2455+
}
2456+
2457+
// Add a new generic parameter to replace the opened existential.
2458+
auto *newGenericParam =
2459+
GenericTypeParamType::get(/*type sequence*/ false, depth, 0, ctx);
2460+
2461+
assert(openedExistential->isRoot());
2462+
auto constraint = openedExistential->getExistentialType();
2463+
if (auto existential = constraint->getAs<ExistentialType>())
2464+
constraint = existential->getConstraintType();
2465+
2466+
Requirement newRequirement(RequirementKind::Conformance, newGenericParam,
2467+
constraint);
2468+
2469+
auto genericSig = buildGenericSignature(ctx, baseGenericSig,
2470+
{ newGenericParam },
2471+
{ newRequirement });
2472+
genericEnv = genericSig.getGenericEnvironment();
2473+
2474+
newArchetype = genericEnv->mapTypeIntoContext(newGenericParam)
2475+
->castTo<ArchetypeType>();
2476+
2477+
// Calculate substitutions to map the caller's archetypes to the thunk's
2478+
// archetypes.
2479+
if (auto calleeGenericSig = fn->getLoweredFunctionType()
2480+
->getInvocationGenericSignature()) {
2481+
contextSubs = SubstitutionMap::get(
2482+
calleeGenericSig,
2483+
[&](SubstitutableType *type) -> Type {
2484+
return genericEnv->mapTypeIntoContext(type);
2485+
},
2486+
MakeAbstractConformanceForGenericType());
2487+
}
2488+
2489+
// Calculate substitutions to map interface types to the caller's archetypes.
2490+
interfaceSubs = SubstitutionMap::get(
2491+
genericSig,
2492+
[&](SubstitutableType *type) -> Type {
2493+
if (type->isEqual(newGenericParam))
2494+
return openedExistential;
2495+
return fn->mapTypeIntoContext(type);
2496+
},
2497+
MakeAbstractConformanceForGenericType());
2498+
2499+
return genericSig.getCanonicalSignature();
2500+
}
2501+
2502+
/// Build the type of a function transformation thunk.
2503+
CanSILFunctionType swift::buildSILFunctionThunkType(
2504+
SILFunction *fn,
2505+
CanSILFunctionType &sourceType,
2506+
CanSILFunctionType &expectedType,
2507+
CanType &inputSubstType,
2508+
CanType &outputSubstType,
2509+
GenericEnvironment *&genericEnv,
2510+
SubstitutionMap &interfaceSubs,
2511+
CanType &dynamicSelfType,
2512+
bool withoutActuallyEscaping,
2513+
Optional<DifferentiationThunkKind> differentiationThunkKind) {
2514+
// We shouldn't be thunking generic types here, and substituted function types
2515+
// ought to have their substitutions applied before we get here.
2516+
assert(!expectedType->isPolymorphic() &&
2517+
!expectedType->getCombinedSubstitutions());
2518+
assert(!sourceType->isPolymorphic() &&
2519+
!sourceType->getCombinedSubstitutions());
2520+
2521+
// This may inherit @noescape from the expectedType. The @noescape attribute
2522+
// is only stripped when using this type to materialize a new decl.
2523+
auto extInfoBuilder = expectedType->getExtInfo().intoBuilder();
2524+
if (!differentiationThunkKind ||
2525+
*differentiationThunkKind == DifferentiationThunkKind::Reabstraction ||
2526+
extInfoBuilder.hasContext()) {
2527+
// Can't build a reabstraction thunk without context, so we require
2528+
// ownership semantics on the result type.
2529+
assert(expectedType->getExtInfo().hasContext());
2530+
2531+
extInfoBuilder = extInfoBuilder.withRepresentation(
2532+
SILFunctionType::Representation::Thin);
2533+
}
2534+
2535+
if (withoutActuallyEscaping)
2536+
extInfoBuilder = extInfoBuilder.withNoEscape(false);
2537+
2538+
// Does the thunk type involve archetypes other than opened existentials?
2539+
bool hasArchetypes = false;
2540+
// Does the thunk type involve an open existential type?
2541+
CanOpenedArchetypeType openedExistential;
2542+
auto archetypeVisitor = [&](CanType t) {
2543+
if (auto archetypeTy = dyn_cast<ArchetypeType>(t)) {
2544+
if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) {
2545+
const auto root = cast<OpenedArchetypeType>(CanType(opened->getRoot()));
2546+
assert((openedExistential == CanArchetypeType() ||
2547+
openedExistential == root) &&
2548+
"one too many open existentials");
2549+
openedExistential = root;
2550+
} else {
2551+
hasArchetypes = true;
2552+
}
2553+
}
2554+
};
2555+
2556+
// Use the generic signature from the context if the thunk involves
2557+
// generic parameters.
2558+
CanGenericSignature genericSig;
2559+
SubstitutionMap contextSubs;
2560+
ArchetypeType *newArchetype = nullptr;
2561+
2562+
if (expectedType->hasArchetype() || sourceType->hasArchetype()) {
2563+
expectedType.visit(archetypeVisitor);
2564+
sourceType.visit(archetypeVisitor);
2565+
2566+
genericSig = buildThunkSignature(fn,
2567+
hasArchetypes,
2568+
openedExistential,
2569+
genericEnv,
2570+
contextSubs,
2571+
interfaceSubs,
2572+
newArchetype);
2573+
}
2574+
2575+
auto substTypeHelper = [&](SubstitutableType *type) -> Type {
2576+
if (CanType(type) == openedExistential)
2577+
return newArchetype;
2578+
2579+
// If a nested archetype is rooted on our opened existential, fail:
2580+
// Type::subst attempts to substitute the parent of a nested archetype
2581+
// only if it fails to find a replacement for the nested one.
2582+
if (auto *opened = dyn_cast<OpenedArchetypeType>(type)) {
2583+
if (openedExistential->isEqual(opened->getRoot())) {
2584+
return nullptr;
2585+
}
2586+
}
2587+
2588+
return Type(type).subst(contextSubs);
2589+
};
2590+
auto substConformanceHelper =
2591+
LookUpConformanceInSubstitutionMap(contextSubs);
2592+
2593+
// Utility function to apply contextSubs, and also replace the
2594+
// opened existential with the new archetype.
2595+
auto substFormalTypeIntoThunkContext =
2596+
[&](CanType t) -> CanType {
2597+
return t.subst(substTypeHelper, substConformanceHelper)
2598+
->getCanonicalType();
2599+
};
2600+
auto substLoweredTypeIntoThunkContext =
2601+
[&](CanSILFunctionType t) -> CanSILFunctionType {
2602+
return SILType::getPrimitiveObjectType(t)
2603+
.subst(fn->getModule(), substTypeHelper, substConformanceHelper)
2604+
.castTo<SILFunctionType>();
2605+
};
2606+
2607+
sourceType = substLoweredTypeIntoThunkContext(sourceType);
2608+
expectedType = substLoweredTypeIntoThunkContext(expectedType);
2609+
2610+
bool hasDynamicSelf = false;
2611+
2612+
if (inputSubstType) {
2613+
inputSubstType = substFormalTypeIntoThunkContext(inputSubstType);
2614+
hasDynamicSelf |= inputSubstType->hasDynamicSelfType();
2615+
}
2616+
2617+
if (outputSubstType) {
2618+
outputSubstType = substFormalTypeIntoThunkContext(outputSubstType);
2619+
hasDynamicSelf |= outputSubstType->hasDynamicSelfType();
2620+
}
2621+
2622+
hasDynamicSelf |= sourceType->hasDynamicSelfType();
2623+
hasDynamicSelf |= expectedType->hasDynamicSelfType();
2624+
2625+
// If our parent function was pseudogeneric, this thunk must also be
2626+
// pseudogeneric, since we have no way to pass generic parameters.
2627+
if (genericSig)
2628+
if (fn->getLoweredFunctionType()->isPseudogeneric())
2629+
extInfoBuilder = extInfoBuilder.withIsPseudogeneric();
2630+
2631+
// Add the function type as the parameter.
2632+
auto contextConvention =
2633+
fn->getTypeLowering(sourceType).isTrivial()
2634+
? ParameterConvention::Direct_Unowned
2635+
: ParameterConvention::Direct_Guaranteed;
2636+
SmallVector<SILParameterInfo, 4> params;
2637+
params.append(expectedType->getParameters().begin(),
2638+
expectedType->getParameters().end());
2639+
2640+
if (!differentiationThunkKind ||
2641+
*differentiationThunkKind == DifferentiationThunkKind::Reabstraction) {
2642+
params.push_back({sourceType,
2643+
sourceType->getExtInfo().hasContext()
2644+
? contextConvention
2645+
: ParameterConvention::Direct_Unowned});
2646+
}
2647+
2648+
// If this thunk involves DynamicSelfType in any way, add a capture for it
2649+
// in case we need to recover metadata.
2650+
if (hasDynamicSelf) {
2651+
dynamicSelfType = fn->getDynamicSelfMetadata()->getType().getASTType();
2652+
if (!isa<MetatypeType>(dynamicSelfType)) {
2653+
dynamicSelfType = CanMetatypeType::get(dynamicSelfType,
2654+
MetatypeRepresentation::Thick);
2655+
}
2656+
params.push_back({dynamicSelfType, ParameterConvention::Direct_Unowned});
2657+
}
2658+
2659+
auto mapTypeOutOfContext = [&](CanType type) -> CanType {
2660+
return type->mapTypeOutOfContext()->getCanonicalType(genericSig);
2661+
};
2662+
2663+
// Map the parameter and expected types out of context to get the interface
2664+
// type of the thunk.
2665+
SmallVector<SILParameterInfo, 4> interfaceParams;
2666+
interfaceParams.reserve(params.size());
2667+
for (auto &param : params) {
2668+
auto interfaceParam = param.map(mapTypeOutOfContext);
2669+
interfaceParams.push_back(interfaceParam);
2670+
}
2671+
2672+
SmallVector<SILYieldInfo, 4> interfaceYields;
2673+
for (auto &yield : expectedType->getYields()) {
2674+
auto interfaceYield = yield.map(mapTypeOutOfContext);
2675+
interfaceYields.push_back(interfaceYield);
2676+
}
2677+
2678+
SmallVector<SILResultInfo, 4> interfaceResults;
2679+
for (auto &result : expectedType->getResults()) {
2680+
auto interfaceResult = result.map(mapTypeOutOfContext);
2681+
interfaceResults.push_back(interfaceResult);
2682+
}
2683+
2684+
Optional<SILResultInfo> interfaceErrorResult;
2685+
if (expectedType->hasErrorResult()) {
2686+
auto errorResult = expectedType->getErrorResult();
2687+
interfaceErrorResult = errorResult.map(mapTypeOutOfContext);;
2688+
}
2689+
2690+
// The type of the thunk function.
2691+
return SILFunctionType::get(
2692+
genericSig, extInfoBuilder.build(), expectedType->getCoroutineKind(),
2693+
ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields,
2694+
interfaceResults, interfaceErrorResult,
2695+
expectedType->getPatternSubstitutions(), SubstitutionMap(),
2696+
fn->getASTContext());
2697+
2698+
}
2699+
24102700
//===----------------------------------------------------------------------===//
24112701
// Foreign SILFunctionTypes
24122702
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)