Skip to content

Commit bb1052c

Browse files
authored
[AutoDiff upstream] Upstream @derivative attribute type-checking. (#28738)
The `@derivative` attribute registers a function as a derivative of another function-like declaration: a `func`, `init`, `subscript`, or `var` computed property declaration. The `@derivative` attribute also has an optional `wrt:` clause specifying the parameters that are differentiated "with respect to", i.e. the differentiation parameters. The differentiation parameters must conform to the `Differentiable` protocol. If the `wrt:` clause is unspecified, the differentiation parameters are inferred to be all parameters that conform to `Differentiable`. `@derivative` attribute type-checking verifies that the type of the derivative function declaration is consistent with the type of the referenced original declaration and the differentiation parameters. The `@derivative` attribute is gated by the `-enable-experimental-differentiable-programming` flag. Resolves TF-829.
1 parent d91a675 commit bb1052c

11 files changed

+1518
-8
lines changed

include/swift/AST/ASTContext.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ namespace swift {
6262
class Decl;
6363
class DeclContext;
6464
class DefaultArgumentInitializer;
65+
class DerivativeAttr;
6566
class ExtensionDecl;
6667
class ForeignRepresentationInfo;
6768
class FuncDecl;
@@ -282,6 +283,16 @@ class ASTContext final {
282283
/// across invocations of both the parser and the type-checker.
283284
unsigned NextAutoClosureDiscriminator = 0;
284285

286+
/// Cache of `@derivative` attributes keyed by parameter indices and
287+
/// derivative function kind. Used to diagnose duplicate `@derivative`
288+
/// attributes for the same key.
289+
// TODO(TF-1042): remove `DerivativeAttrs` from `ASTContext`. Serialize
290+
// derivative function configurations per original `AbstractFunctionDecl`.
291+
llvm::DenseMap<
292+
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
293+
DerivativeAttr *>
294+
DerivativeAttrs;
295+
285296
private:
286297
/// The current generation number, which reflects the number of
287298
/// times that external modules have been loaded.

include/swift/AST/Attr.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,21 @@ class DifferentiableAttr final
17651765
}
17661766
};
17671767

1768-
/// Attribute that registers a function as a derivative of another function.
1768+
/// The `@derivative` attribute registers a function as a derivative of another
1769+
/// function-like declaration: a 'func', 'init', 'subscript', or 'var' computed
1770+
/// property declaration.
1771+
///
1772+
/// The `@derivative` attribute also has an optional `wrt:` clause specifying
1773+
/// the parameters that are differentiated "with respect to", i.e. the
1774+
/// differentiation parameters. The differentiation parameters must conform to
1775+
/// the `Differentiable` protocol.
1776+
///
1777+
/// If the `wrt:` clause is unspecified, the differentiation parameters are
1778+
/// inferred to be all parameters that conform to `Differentiable`.
1779+
///
1780+
/// `@derivative` attribute type-checking verifies that the type of the
1781+
/// derivative function declaration is consistent with the type of the
1782+
/// referenced original declaration and the differentiation parameters.
17691783
///
17701784
/// Examples:
17711785
/// @derivative(of: sin(_:))

include/swift/AST/AutoDiff.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--- AutoDiff.h - Swift Automatic Differentiation ---------------------===//
1+
//===--- AutoDiff.h - Swift automatic differentiation utilities -----------===//
22
//
33
// This source file is part of the Swift.org open source project
44
//
@@ -10,7 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// This file defines AST support for automatic differentiation.
13+
// This file defines utilities for automatic differentiation.
1414
//
1515
//===----------------------------------------------------------------------===//
1616

@@ -21,11 +21,14 @@
2121

2222
#include "swift/AST/Identifier.h"
2323
#include "swift/AST/IndexSubset.h"
24-
#include "swift/Basic/SourceLoc.h"
24+
#include "swift/AST/Type.h"
2525
#include "swift/Basic/Range.h"
26+
#include "swift/Basic/SourceLoc.h"
2627

2728
namespace swift {
2829

30+
class AnyFunctionType;
31+
2932
/// A function type differentiability kind.
3033
enum class DifferentiabilityKind : uint8_t {
3134
NonDifferentiable = 0,
@@ -130,6 +133,46 @@ class ParsedAutoDiffParameter {
130133
}
131134
};
132135

136+
/// Automatic differentiation utility namespace.
137+
namespace autodiff {
138+
139+
/// Appends the subset's parameter's types to `results`, in the order in
140+
/// which they appear in the function type.
141+
void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
142+
SmallVectorImpl<Type> &results,
143+
bool reverseCurryLevels = false);
144+
145+
} // end namespace autodiff
146+
133147
} // end namespace swift
134148

149+
namespace llvm {
150+
151+
using swift::AutoDiffDerivativeFunctionKind;
152+
153+
template <typename T> struct DenseMapInfo;
154+
155+
template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
156+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
157+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
158+
DenseMapInfo<unsigned>::getEmptyKey());
159+
}
160+
161+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
162+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
163+
DenseMapInfo<unsigned>::getTombstoneKey());
164+
}
165+
166+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
167+
return DenseMapInfo<unsigned>::getHashValue(Val);
168+
}
169+
170+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
171+
const AutoDiffDerivativeFunctionKind &RHS) {
172+
return LHS == RHS;
173+
}
174+
};
175+
176+
} // end namespace llvm
177+
135178
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/DiagnosticsSema.def

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2756,6 +2756,9 @@ ERROR(attr_not_on_variadic_parameters,none,
27562756
ERROR(attr_not_on_subscript_parameters,none,
27572757
"'%0' must not be used on subscript parameters", (StringRef))
27582758

2759+
ERROR(attr_ambiguous_reference_to_decl,none,
2760+
"ambiguous reference to %0 in '@%1' attribute", (DeclNameRef, StringRef))
2761+
27592762
ERROR(override_final,none,
27602763
"%0 overrides a 'final' %1", (DescriptiveDeclKind, DescriptiveDeclKind))
27612764

@@ -2891,6 +2894,68 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
28912894
"containing type %0 does not conform to protocol %1",
28922895
(DeclName, DeclName))
28932896

2897+
// @derivative
2898+
ERROR(derivative_attr_expected_result_tuple,none,
2899+
"'@derivative(of:)' attribute requires function to return a two-element "
2900+
"tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' "
2901+
"or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
2902+
ERROR(derivative_attr_invalid_result_tuple_value_label,none,
2903+
"'@derivative(of:)' attribute requires function to return a two-element "
2904+
"tuple (first element must have label 'value:')", ())
2905+
ERROR(derivative_attr_invalid_result_tuple_func_label,none,
2906+
"'@derivative(of:)' attribute requires function to return a two-element "
2907+
"tuple (second element must have label 'pullback:' or 'differential:')",
2908+
())
2909+
ERROR(derivative_attr_result_value_not_differentiable,none,
2910+
"'@derivative(of:)' attribute requires function to return a two-element "
2911+
"tuple (first element type %0 must conform to 'Differentiable')", (Type))
2912+
ERROR(derivative_attr_result_func_type_mismatch,none,
2913+
"function result's %0 type does not match %1", (Identifier, DeclName))
2914+
NOTE(derivative_attr_result_func_type_mismatch_note,none,
2915+
"%0 does not have expected type %1", (Identifier, Type))
2916+
NOTE(derivative_attr_result_func_original_note,none,
2917+
"%0 defined here", (DeclName))
2918+
ERROR(derivative_attr_not_in_same_file_as_original,none,
2919+
"derivative not in the same file as the original function", ())
2920+
ERROR(derivative_attr_original_stored_property_unsupported,none,
2921+
"cannot register derivative for stored property %0", (DeclNameRef))
2922+
ERROR(derivative_attr_original_already_has_derivative,none,
2923+
"a derivative already exists for %0", (DeclName))
2924+
NOTE(derivative_attr_duplicate_note,none,
2925+
"other attribute declared here", ())
2926+
2927+
// Automatic differentiation attributes
2928+
ERROR(autodiff_attr_original_decl_invalid_kind,none,
2929+
"%0 is not a 'func', 'init', 'subscript', or 'var' computed property "
2930+
"declaration", (DeclNameRef))
2931+
ERROR(autodiff_attr_original_decl_none_valid_found,none,
2932+
"could not find function %0 with expected type %1", (DeclNameRef, Type))
2933+
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
2934+
"%0 is not defined in the current type context", (DeclNameRef))
2935+
2936+
// differentiation `wrt` parameters clause
2937+
ERROR(diff_function_no_parameters,none,
2938+
"%0 has no parameters to differentiate with respect to", (DeclName))
2939+
ERROR(diff_params_clause_param_name_unknown,none,
2940+
"unknown parameter name %0", (Identifier))
2941+
ERROR(diff_params_clause_self_instance_method_only,none,
2942+
"'self' parameter is only applicable to instance methods", ())
2943+
ERROR(diff_params_clause_self_must_be_first,none,
2944+
"'self' parameter must come first in the parameter list", ())
2945+
ERROR(diff_params_clause_params_not_original_order,none,
2946+
"parameters must be specified in original order", ())
2947+
ERROR(diff_params_clause_param_index_out_of_range,none,
2948+
"parameter index is larger than total number of parameters", ())
2949+
ERROR(diff_params_clause_no_inferred_parameters,PointsToFirstBadToken,
2950+
"no differentiation parameters could be inferred; must differentiate "
2951+
"with respect to at least one parameter conforming to 'Differentiable'",
2952+
())
2953+
ERROR(diff_params_clause_cannot_diff_wrt_inout_parameter,none,
2954+
"cannot differentiate with respect to 'inout' parameter (%0)", (Type))
2955+
ERROR(diff_params_clause_param_not_differentiable,none,
2956+
"can only differentiate with respect to parameters that conform to "
2957+
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2958+
28942959
//------------------------------------------------------------------------------
28952960
// MARK: Type Check Expressions
28962961
//------------------------------------------------------------------------------

include/swift/AST/KnownIdentifiers.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ IDENTIFIER_(nsError)
197197
// Custom string interpolation type used by os log APIs.
198198
IDENTIFIER(OSLogMessage)
199199

200+
// Differentiable programming
201+
IDENTIFIER(TangentVector)
202+
200203
#undef IDENTIFIER
201204
#undef IDENTIFIER_
202205
#undef IDENTIFIER_WITH_NAME

lib/AST/AutoDiff.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
//===--- AutoDiff.cpp - Swift automatic differentiation utilities ---------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "swift/AST/AutoDiff.h"
14+
#include "swift/AST/Types.h"
15+
16+
using namespace swift;
17+
18+
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
19+
// most once (for curried method types) is sufficient.
20+
static void unwrapCurryLevels(AnyFunctionType *fnTy,
21+
SmallVectorImpl<AnyFunctionType *> &results) {
22+
while (fnTy != nullptr) {
23+
results.push_back(fnTy);
24+
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
25+
}
26+
}
27+
28+
static unsigned countNumFlattenedElementTypes(Type type) {
29+
if (auto *tupleTy = type->getCanonicalType()->getAs<TupleType>())
30+
return accumulate(tupleTy->getElementTypes(), 0,
31+
[&](unsigned num, Type type) {
32+
return num + countNumFlattenedElementTypes(type);
33+
});
34+
return 1;
35+
}
36+
37+
// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
38+
// See TF-874 for WIP.
39+
void autodiff::getSubsetParameterTypes(IndexSubset *subset,
40+
AnyFunctionType *type,
41+
SmallVectorImpl<Type> &results,
42+
bool reverseCurryLevels) {
43+
SmallVector<AnyFunctionType *, 2> curryLevels;
44+
unwrapCurryLevels(type, curryLevels);
45+
46+
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
47+
unsigned currentOffset = 0;
48+
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
49+
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
50+
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
51+
}
52+
53+
// If `reverseCurryLevels` is true, reverse the curry levels and offsets.
54+
if (reverseCurryLevels) {
55+
std::reverse(curryLevels.begin(), curryLevels.end());
56+
std::reverse(curryLevelParameterIndexOffsets.begin(),
57+
curryLevelParameterIndexOffsets.end());
58+
}
59+
60+
for (unsigned curryLevelIndex : indices(curryLevels)) {
61+
auto *curryLevel = curryLevels[curryLevelIndex];
62+
unsigned parameterIndexOffset =
63+
curryLevelParameterIndexOffsets[curryLevelIndex];
64+
for (unsigned paramIndex : range(curryLevel->getNumParams()))
65+
if (subset->contains(parameterIndexOffset + paramIndex))
66+
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
67+
}
68+
}

lib/AST/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ add_swift_host_library(swiftAST STATIC
2929
ASTVerifier.cpp
3030
ASTWalker.cpp
3131
Attr.cpp
32-
IndexSubset.cpp
32+
AutoDiff.cpp
3333
Availability.cpp
3434
AvailabilitySpec.cpp
3535
Builtins.cpp
@@ -55,6 +55,7 @@ add_swift_host_library(swiftAST STATIC
5555
Identifier.cpp
5656
ImportCache.cpp
5757
IncrementalRanges.cpp
58+
IndexSubset.cpp
5859
InlinableText.cpp
5960
LayoutConstraint.cpp
6061
Module.cpp

0 commit comments

Comments
 (0)