Skip to content

Commit 9c13426

Browse files
committed
Collapse the differentiation-specific thunk type generation code into the general version
We had two copies of this code that had drifted apart. Bring them back together so there is just one place where we compute the type of a reabstraction thunk.
1 parent 3cb7af0 commit 9c13426

File tree

4 files changed

+48
-254
lines changed

4 files changed

+48
-254
lines changed

include/swift/SIL/TypeLowering.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,26 @@ CanSILFunctionType getNativeSILFunctionType(
12041204
Optional<SubstitutionMap> reqtSubs = None,
12051205
ProtocolConformanceRef witnessMethodConformance = ProtocolConformanceRef());
12061206

1207+
/// The thunk kinds used in the differentiation transform.
1208+
enum class DifferentiationThunkKind {
1209+
/// A reabstraction thunk.
1210+
///
1211+
/// Reabstraction thunks transform a function-typed value to another one with
1212+
/// different parameter/result abstraction patterns. This is identical to the
1213+
/// thunks generated by SILGen.
1214+
Reabstraction,
1215+
1216+
/// An index subset thunk.
1217+
///
1218+
/// An index subset thunk is used transform JVP/VJPs into a version that is
1219+
/// "wrt" fewer differentiation parameters.
1220+
/// - Differentials of thunked JVPs use zero for non-requested differentiation
1221+
/// parameters.
1222+
/// - Pullbacks of thunked VJPs discard results for non-requested
1223+
/// differentiation parameters.
1224+
IndexSubset
1225+
};
1226+
12071227
/// Build the type of a function transformation thunk.
12081228
CanSILFunctionType buildSILFunctionThunkType(
12091229
SILFunction *fn,
@@ -1214,7 +1234,8 @@ CanSILFunctionType buildSILFunctionThunkType(
12141234
GenericEnvironment *&genericEnv,
12151235
SubstitutionMap &interfaceSubs,
12161236
CanType &dynamicSelfType,
1217-
bool withoutActuallyEscaping);
1237+
bool withoutActuallyEscaping,
1238+
Optional<DifferentiationThunkKind> differentiationThunkKind = None);
12181239

12191240
} // namespace swift
12201241

include/swift/SILOptimizer/Differentiation/Thunk.h

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,41 +47,6 @@ class ADContext;
4747
// moved to a shared location.
4848
//===----------------------------------------------------------------------===//
4949

50-
/// The thunk kinds used in the differentiation transform.
51-
enum class DifferentiationThunkKind {
52-
/// A reabstraction thunk.
53-
///
54-
/// Reabstraction thunks transform a function-typed value to another one with
55-
/// different parameter/result abstraction patterns. This is identical to the
56-
/// thunks generated by SILGen.
57-
Reabstraction,
58-
59-
/// An index subset thunk.
60-
///
61-
/// An index subset thunk is used transform JVP/VJPs into a version that is
62-
/// "wrt" fewer differentiation parameters.
63-
/// - Differentials of thunked JVPs use zero for non-requested differentiation
64-
/// parameters.
65-
/// - Pullbacks of thunked VJPs discard results for non-requested
66-
/// differentiation parameters.
67-
IndexSubset
68-
};
69-
70-
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
71-
OpenedArchetypeType *openedExistential,
72-
GenericEnvironment *&genericEnv,
73-
SubstitutionMap &contextSubs,
74-
SubstitutionMap &interfaceSubs,
75-
ArchetypeType *&newArchetype);
76-
77-
/// Build the type of a function transformation thunk.
78-
CanSILFunctionType buildThunkType(SILFunction *fn,
79-
CanSILFunctionType &sourceType,
80-
CanSILFunctionType &expectedType,
81-
GenericEnvironment *&genericEnv,
82-
SubstitutionMap &interfaceSubs,
83-
bool withoutActuallyEscaping,
84-
DifferentiationThunkKind thunkKind);
8550

8651
/// Get or create a reabstraction thunk from `fromType` to `toType`, to be
8752
/// called in `caller`.

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2509,23 +2509,28 @@ CanSILFunctionType swift::buildSILFunctionThunkType(
25092509
GenericEnvironment *&genericEnv,
25102510
SubstitutionMap &interfaceSubs,
25112511
CanType &dynamicSelfType,
2512-
bool withoutActuallyEscaping) {
2512+
bool withoutActuallyEscaping,
2513+
Optional<DifferentiationThunkKind> differentiationThunkKind) {
25132514
// We shouldn't be thunking generic types here, and substituted function types
25142515
// ought to have their substitutions applied before we get here.
25152516
assert(!expectedType->isPolymorphic() &&
25162517
!expectedType->getCombinedSubstitutions());
25172518
assert(!sourceType->isPolymorphic() &&
25182519
!sourceType->getCombinedSubstitutions());
25192520

2520-
// Can't build a thunk without context, so we require ownership semantics
2521-
// on the result type.
2522-
assert(expectedType->getExtInfo().hasContext());
2523-
25242521
// This may inherit @noescape from the expectedType. The @noescape attribute
25252522
// is only stripped when using this type to materialize a new decl.
2526-
auto extInfoBuilder =
2527-
expectedType->getExtInfo().intoBuilder().withRepresentation(
2523+
auto extInfoBuilder = expectedType->getExtInfo().intoBuilder();
2524+
if (!differentiationThunkKind ||
2525+
*differentiationThunkKind == DifferentiationThunkKind::Reabstraction ||
2526+
extInfoBuilder.hasContext()) {
2527+
// Can't build a reabstraction thunk without context, so we require
2528+
// ownership semantics on the result type.
2529+
assert(expectedType->getExtInfo().hasContext());
2530+
2531+
extInfoBuilder = extInfoBuilder.withRepresentation(
25282532
SILFunctionType::Representation::Thin);
2533+
}
25292534

25302535
if (withoutActuallyEscaping)
25312536
extInfoBuilder = extInfoBuilder.withNoEscape(false);
@@ -2631,10 +2636,14 @@ CanSILFunctionType swift::buildSILFunctionThunkType(
26312636
SmallVector<SILParameterInfo, 4> params;
26322637
params.append(expectedType->getParameters().begin(),
26332638
expectedType->getParameters().end());
2634-
params.push_back({sourceType,
2635-
sourceType->getExtInfo().hasContext()
2636-
? contextConvention
2637-
: ParameterConvention::Direct_Unowned});
2639+
2640+
if (!differentiationThunkKind ||
2641+
*differentiationThunkKind == DifferentiationThunkKind::Reabstraction) {
2642+
params.push_back({sourceType,
2643+
sourceType->getExtInfo().hasContext()
2644+
? contextConvention
2645+
: ParameterConvention::Direct_Unowned});
2646+
}
26382647

26392648
// If this thunk involves DynamicSelfType in any way, add a capture for it
26402649
// in case we need to recover metadata.

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 6 additions & 207 deletions
Original file line numberDiff line numberDiff line change
@@ -36,220 +36,19 @@ namespace autodiff {
3636
// moved to a shared location.
3737
//===----------------------------------------------------------------------===//
3838

39-
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
40-
OpenedArchetypeType *openedExistential,
41-
GenericEnvironment *&genericEnv,
42-
SubstitutionMap &contextSubs,
43-
SubstitutionMap &interfaceSubs,
44-
ArchetypeType *&newArchetype) {
45-
// If there's no opened existential, we just inherit the generic environment
46-
// from the parent function.
47-
if (openedExistential == nullptr) {
48-
auto genericSig = fn->getLoweredFunctionType()->getInvocationGenericSignature();
49-
genericEnv = fn->getGenericEnvironment();
50-
interfaceSubs = fn->getForwardingSubstitutionMap();
51-
contextSubs = interfaceSubs;
52-
return genericSig;
53-
}
54-
55-
auto &ctx = fn->getASTContext();
56-
57-
// Add the existing generic signature.
58-
GenericSignature baseGenericSig;
59-
int depth = 0;
60-
if (inheritGenericSig) {
61-
baseGenericSig = fn->getLoweredFunctionType()->getInvocationGenericSignature();
62-
if (baseGenericSig)
63-
depth = baseGenericSig.getGenericParams().back()->getDepth() + 1;
64-
}
65-
66-
// Add a new generic parameter to replace the opened existential.
67-
auto *newGenericParam =
68-
GenericTypeParamType::get(/*type sequence*/ false, depth, 0, ctx);
69-
70-
assert(openedExistential->isRoot());
71-
auto constraint = openedExistential->getExistentialType();
72-
if (auto existential = constraint->getAs<ExistentialType>())
73-
constraint = existential->getConstraintType();
74-
75-
Requirement newRequirement(RequirementKind::Conformance, newGenericParam,
76-
constraint);
77-
78-
auto genericSig = buildGenericSignature(ctx, baseGenericSig,
79-
{ newGenericParam },
80-
{ newRequirement });
81-
genericEnv = genericSig.getGenericEnvironment();
82-
83-
newArchetype =
84-
genericEnv->mapTypeIntoContext(newGenericParam)->castTo<ArchetypeType>();
85-
86-
// Calculate substitutions to map the caller's archetypes to the thunk's
87-
// archetypes.
88-
if (auto calleeGenericSig =
89-
fn->getLoweredFunctionType()->getSubstGenericSignature()) {
90-
contextSubs = SubstitutionMap::get(
91-
calleeGenericSig,
92-
[&](SubstitutableType *type) -> Type {
93-
return genericEnv->mapTypeIntoContext(type);
94-
},
95-
MakeAbstractConformanceForGenericType());
96-
}
97-
98-
// Calculate substitutions to map interface types to the caller's archetypes.
99-
interfaceSubs = SubstitutionMap::get(
100-
genericSig,
101-
[&](SubstitutableType *type) -> Type {
102-
if (type->isEqual(newGenericParam))
103-
return openedExistential;
104-
return fn->mapTypeIntoContext(type);
105-
},
106-
MakeAbstractConformanceForGenericType());
107-
108-
return genericSig.getCanonicalSignature();
109-
}
110-
11139
CanSILFunctionType buildThunkType(SILFunction *fn,
11240
CanSILFunctionType &sourceType,
11341
CanSILFunctionType &expectedType,
11442
GenericEnvironment *&genericEnv,
11543
SubstitutionMap &interfaceSubs,
11644
bool withoutActuallyEscaping,
11745
DifferentiationThunkKind thunkKind) {
118-
assert(!expectedType->isPolymorphic() &&
119-
!expectedType->getCombinedSubstitutions());
120-
assert(!sourceType->isPolymorphic() &&
121-
!sourceType->getCombinedSubstitutions());
122-
123-
// Cannot build a reabstraction thunk without context. Ownership semantics
124-
// on the result type are required.
125-
if (thunkKind == DifferentiationThunkKind::Reabstraction)
126-
assert(expectedType->getExtInfo().hasContext());
127-
128-
// This may inherit @noescape from the expected type. The `@noescape`
129-
// attribute is only stripped when using this type to materialize a new decl.
130-
// Use `@convention(thin)` if:
131-
// - Building a reabstraction thunk type.
132-
// - Building an index subset thunk type, where the expected type has context
133-
// (i.e. is `@convention(thick)`).
134-
auto extInfoBuilder = expectedType->getExtInfo().intoBuilder();
135-
if (thunkKind == DifferentiationThunkKind::Reabstraction ||
136-
extInfoBuilder.hasContext()) {
137-
extInfoBuilder = extInfoBuilder.withRepresentation(
138-
SILFunctionType::Representation::Thin);
139-
}
140-
if (withoutActuallyEscaping)
141-
extInfoBuilder = extInfoBuilder.withNoEscape(false);
142-
143-
// Does the thunk type involve archetypes other than opened existentials?
144-
bool hasArchetypes = false;
145-
// Does the thunk type involve an open existential type?
146-
CanOpenedArchetypeType openedExistential;
147-
auto archetypeVisitor = [&](CanType t) {
148-
if (auto archetypeTy = dyn_cast<ArchetypeType>(t)) {
149-
if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) {
150-
const auto root = cast<OpenedArchetypeType>(CanType(opened->getRoot()));
151-
assert((openedExistential == CanArchetypeType() ||
152-
openedExistential == root) &&
153-
"one too many open existentials");
154-
openedExistential = root;
155-
} else {
156-
hasArchetypes = true;
157-
}
158-
}
159-
};
160-
161-
// Use the generic signature from the context if the thunk involves
162-
// generic parameters.
163-
CanGenericSignature genericSig;
164-
SubstitutionMap contextSubs;
165-
ArchetypeType *newArchetype = nullptr;
166-
167-
if (expectedType->hasArchetype() || sourceType->hasArchetype()) {
168-
expectedType.visit(archetypeVisitor);
169-
sourceType.visit(archetypeVisitor);
170-
genericSig =
171-
buildThunkSignature(fn, hasArchetypes, openedExistential, genericEnv,
172-
contextSubs, interfaceSubs, newArchetype);
173-
}
174-
175-
auto substTypeHelper = [&](SubstitutableType *type) -> Type {
176-
if (CanType(type) == openedExistential)
177-
return newArchetype;
178-
return Type(type).subst(contextSubs);
179-
};
180-
auto substConformanceHelper = LookUpConformanceInSubstitutionMap(contextSubs);
181-
182-
// Utility function to apply contextSubs, and also replace the
183-
// opened existential with the new archetype.
184-
auto substLoweredTypeIntoThunkContext =
185-
[&](CanSILFunctionType t) -> CanSILFunctionType {
186-
return SILType::getPrimitiveObjectType(t)
187-
.subst(fn->getModule(), substTypeHelper, substConformanceHelper)
188-
.castTo<SILFunctionType>();
189-
};
190-
191-
sourceType = substLoweredTypeIntoThunkContext(sourceType);
192-
expectedType = substLoweredTypeIntoThunkContext(expectedType);
193-
194-
// If our parent function was pseudogeneric, this thunk must also be
195-
// pseudogeneric, since we have no way to pass generic parameters.
196-
if (genericSig)
197-
if (fn->getLoweredFunctionType()->isPseudogeneric())
198-
extInfoBuilder = extInfoBuilder.withIsPseudogeneric();
199-
200-
// Add the function type as the parameter.
201-
auto contextConvention =
202-
SILType::getPrimitiveObjectType(sourceType).isTrivial(*fn)
203-
? ParameterConvention::Direct_Unowned
204-
: ParameterConvention::Direct_Guaranteed;
205-
SmallVector<SILParameterInfo, 4> params;
206-
params.append(expectedType->getParameters().begin(),
207-
expectedType->getParameters().end());
208-
// Add reabstraction function parameter only if building a reabstraction thunk
209-
// type.
210-
if (thunkKind == DifferentiationThunkKind::Reabstraction)
211-
params.push_back({sourceType, sourceType->getExtInfo().hasContext()
212-
? contextConvention
213-
: ParameterConvention::Direct_Unowned});
214-
215-
auto mapTypeOutOfContext = [&](CanType type) -> CanType {
216-
return type->mapTypeOutOfContext()->getCanonicalType(genericSig);
217-
};
218-
219-
// Map the parameter and expected types out of context to get the interface
220-
// type of the thunk.
221-
SmallVector<SILParameterInfo, 4> interfaceParams;
222-
interfaceParams.reserve(params.size());
223-
for (auto &param : params) {
224-
auto interfaceParam = param.map(mapTypeOutOfContext);
225-
interfaceParams.push_back(interfaceParam);
226-
}
227-
228-
SmallVector<SILYieldInfo, 4> interfaceYields;
229-
for (auto &yield : expectedType->getYields()) {
230-
auto interfaceYield = yield.map(mapTypeOutOfContext);
231-
interfaceYields.push_back(interfaceYield);
232-
}
233-
234-
SmallVector<SILResultInfo, 4> interfaceResults;
235-
for (auto &result : expectedType->getResults()) {
236-
auto interfaceResult = result.map(mapTypeOutOfContext);
237-
interfaceResults.push_back(interfaceResult);
238-
}
239-
240-
Optional<SILResultInfo> interfaceErrorResult;
241-
if (expectedType->hasErrorResult()) {
242-
auto errorResult = expectedType->getErrorResult();
243-
interfaceErrorResult = errorResult.map(mapTypeOutOfContext);
244-
}
245-
246-
// The type of the thunk function.
247-
return SILFunctionType::get(
248-
genericSig, extInfoBuilder.build(), expectedType->getCoroutineKind(),
249-
ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields,
250-
interfaceResults, interfaceErrorResult,
251-
expectedType->getPatternSubstitutions(), SubstitutionMap(),
252-
fn->getASTContext());
46+
CanType inputSubstType;
47+
CanType outputSubstType;
48+
CanType dynamicSelfType;
49+
return buildSILFunctionThunkType(
50+
fn, sourceType, expectedType, inputSubstType, outputSubstType, genericEnv,
51+
interfaceSubs, dynamicSelfType, withoutActuallyEscaping, thunkKind);
25352
}
25453

25554
/// Forward function arguments, handling ownership convention mismatches.

0 commit comments

Comments
 (0)