Skip to content

Commit 8cbe371

Browse files
committed
[llvm][STLExtras] Add various type_trait utilities currently present in MLIR
This revision moves several type_trait utilities from MLIR into LLVM. Namely, this revision adds: is_detected - This matches the experimental std::is_detected is_invocable - This matches the c++17 std::is_invocable function_traits - A utility traits class for getting the argument and result types of a callable type Differential Revision: https://reviews.llvm.org/D78059
1 parent f52ec5d commit 8cbe371

File tree

11 files changed

+196
-108
lines changed

11 files changed

+196
-108
lines changed

llvm/include/llvm/ADT/STLExtras.h

+73
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,79 @@ template <typename T> struct make_const_ref {
7575
typename std::add_const<T>::type>::type;
7676
};
7777

78+
/// Utilities for detecting if a given trait holds for some set of arguments
79+
/// 'Args'. For example, the given trait could be used to detect if a given type
80+
/// has a copy assignment operator:
81+
/// template<class T>
82+
/// using has_copy_assign_t = decltype(std::declval<T&>()
83+
/// = std::declval<const T&>());
84+
/// bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
85+
namespace detail {
86+
template <typename...> using void_t = void;
87+
template <class, template <class...> class Op, class... Args> struct detector {
88+
using value_t = std::false_type;
89+
};
90+
template <template <class...> class Op, class... Args>
91+
struct detector<void_t<Op<Args...>>, Op, Args...> {
92+
using value_t = std::true_type;
93+
};
94+
} // end namespace detail
95+
96+
template <template <class...> class Op, class... Args>
97+
using is_detected = typename detail::detector<void, Op, Args...>::value_t;
98+
99+
/// Check if a Callable type can be invoked with the given set of arg types.
100+
namespace detail {
101+
template <typename Callable, typename... Args>
102+
using is_invocable =
103+
decltype(std::declval<Callable &>()(std::declval<Args>()...));
104+
} // namespace detail
105+
106+
template <typename Callable, typename... Args>
107+
using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
108+
109+
/// This class provides various trait information about a callable object.
110+
/// * To access the number of arguments: Traits::num_args
111+
/// * To access the type of an argument: Traits::arg_t<i>
112+
/// * To access the type of the result: Traits::result_t
113+
template <typename T, bool isClass = std::is_class<T>::value>
114+
struct function_traits : public function_traits<decltype(&T::operator())> {};
115+
116+
/// Overload for class function types.
117+
template <typename ClassType, typename ReturnType, typename... Args>
118+
struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
119+
/// The number of arguments to this function.
120+
enum { num_args = sizeof...(Args) };
121+
122+
/// The result type of this function.
123+
using result_t = ReturnType;
124+
125+
/// The type of an argument to this function.
126+
template <size_t i>
127+
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
128+
};
129+
/// Overload for class function types.
130+
template <typename ClassType, typename ReturnType, typename... Args>
131+
struct function_traits<ReturnType (ClassType::*)(Args...), false>
132+
: function_traits<ReturnType (ClassType::*)(Args...) const> {};
133+
/// Overload for non-class function types.
134+
template <typename ReturnType, typename... Args>
135+
struct function_traits<ReturnType (*)(Args...), false> {
136+
/// The number of arguments to this function.
137+
enum { num_args = sizeof...(Args) };
138+
139+
/// The result type of this function.
140+
using result_t = ReturnType;
141+
142+
/// The type of an argument to this function.
143+
template <size_t i>
144+
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
145+
};
146+
/// Overload for non-class function type references.
147+
template <typename ReturnType, typename... Args>
148+
struct function_traits<ReturnType (&)(Args...), false>
149+
: public function_traits<ReturnType (*)(Args...)> {};
150+
78151
//===----------------------------------------------------------------------===//
79152
// Extra additions to <functional>
80153
//===----------------------------------------------------------------------===//

llvm/unittests/ADT/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ add_llvm_unittest(ADTTests
7373
TinyPtrVectorTest.cpp
7474
TripleTest.cpp
7575
TwineTest.cpp
76+
TypeTraitsTest.cpp
7677
WaymarkingTest.cpp
7778
)
7879

llvm/unittests/ADT/TypeTraitsTest.cpp

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//===- TypeTraitsTest.cpp - type_traits unit tests ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/ADT/STLExtras.h"
10+
#include "gtest/gtest.h"
11+
12+
using namespace llvm;
13+
14+
//===----------------------------------------------------------------------===//
15+
// function_traits
16+
//===----------------------------------------------------------------------===//
17+
18+
namespace {
19+
/// Check a callable type of the form `bool(const int &)`.
20+
template <typename CallableT> struct CheckFunctionTraits {
21+
static_assert(
22+
std::is_same<typename function_traits<CallableT>::result_t, bool>::value,
23+
"expected result_t to be `bool`");
24+
static_assert(
25+
std::is_same<typename function_traits<CallableT>::template arg_t<0>,
26+
const int &>::value,
27+
"expected arg_t<0> to be `const int &`");
28+
static_assert(function_traits<CallableT>::num_args == 1,
29+
"expected num_args to be 1");
30+
};
31+
32+
/// Test function pointers.
33+
using FuncType = bool (*)(const int &);
34+
struct CheckFunctionPointer : CheckFunctionTraits<FuncType> {};
35+
36+
static bool func(const int &v);
37+
struct CheckFunctionPointer2 : CheckFunctionTraits<decltype(&func)> {};
38+
39+
/// Test method pointers.
40+
struct Foo {
41+
bool func(const int &v);
42+
};
43+
struct CheckMethodPointer : CheckFunctionTraits<decltype(&Foo::func)> {};
44+
45+
/// Test lambda references.
46+
auto lambdaFunc = [](const int &v) -> bool { return true; };
47+
struct CheckLambda : CheckFunctionTraits<decltype(lambdaFunc)> {};
48+
49+
} // end anonymous namespace
50+
51+
//===----------------------------------------------------------------------===//
52+
// is_detected
53+
//===----------------------------------------------------------------------===//
54+
55+
namespace {
56+
struct HasFooMethod {
57+
void foo() {}
58+
};
59+
struct NoFooMethod {};
60+
61+
template <class T> using has_foo_method_t = decltype(std::declval<T &>().foo());
62+
63+
static_assert(is_detected<has_foo_method_t, HasFooMethod>::value,
64+
"expected foo method to be detected");
65+
static_assert(!is_detected<has_foo_method_t, NoFooMethod>::value,
66+
"expected no foo method to be detected");
67+
} // end anonymous namespace
68+
69+
//===----------------------------------------------------------------------===//
70+
// is_invocable
71+
//===----------------------------------------------------------------------===//
72+
73+
static void invocable_fn(int) {}
74+
75+
static_assert(is_invocable<decltype(invocable_fn), int>::value,
76+
"expected function to be invocable");
77+
static_assert(!is_invocable<decltype(invocable_fn), void *>::value,
78+
"expected function not to be invocable");
79+
static_assert(!is_invocable<decltype(invocable_fn), int, int>::value,
80+
"expected function not to be invocable");

mlir/include/mlir/ADT/TypeSwitch.h

+11-9
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
4646
/// Note: This inference rules for this overload are very simple: strip
4747
/// pointers and references.
4848
template <typename CallableT> DerivedT &Case(CallableT &&caseFn) {
49-
using Traits = FunctionTraits<std::decay_t<CallableT>>;
49+
using Traits = llvm::function_traits<std::decay_t<CallableT>>;
5050
using CaseT = std::remove_cv_t<std::remove_pointer_t<
5151
std::remove_reference_t<typename Traits::template arg_t<0>>>>;
5252

@@ -64,20 +64,22 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
6464
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
6565
/// selected if `value` already has a suitable dyn_cast method.
6666
template <typename CastT, typename ValueT>
67-
static auto castValue(
68-
ValueT value,
69-
typename std::enable_if_t<
70-
is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
67+
static auto
68+
castValue(ValueT value,
69+
typename std::enable_if_t<
70+
llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
71+
nullptr) {
7172
return value.template dyn_cast<CastT>();
7273
}
7374

7475
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
7576
/// selected if llvm::dyn_cast should be used.
7677
template <typename CastT, typename ValueT>
77-
static auto castValue(
78-
ValueT value,
79-
typename std::enable_if_t<
80-
!is_detected<has_dyn_cast_t, ValueT, CastT>::value> * = nullptr) {
78+
static auto
79+
castValue(ValueT value,
80+
typename std::enable_if_t<
81+
!llvm::is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
82+
nullptr) {
8183
return dyn_cast<CastT>(value);
8284
}
8385

mlir/include/mlir/IR/Matchers.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,20 @@ using has_operation_or_value_matcher_t =
140140

141141
/// Statically switch to a Value matcher.
142142
template <typename MatcherClass>
143-
typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
144-
MatcherClass, Value>::value,
145-
bool>
143+
typename std::enable_if_t<
144+
llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
145+
Value>::value,
146+
bool>
146147
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
147148
return matcher.match(op->getOperand(idx));
148149
}
149150

150151
/// Statically switch to an Operation matcher.
151152
template <typename MatcherClass>
152-
typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
153-
MatcherClass, Operation *>::value,
154-
bool>
153+
typename std::enable_if_t<
154+
llvm::is_detected<detail::has_operation_or_value_matcher_t, MatcherClass,
155+
Operation *>::value,
156+
bool>
155157
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
156158
if (auto defOp = op->getOperand(idx).getDefiningOp())
157159
return matcher.match(defOp);

mlir/include/mlir/IR/OpDefinition.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -1298,16 +1298,16 @@ class Op : public OpState,
12981298
/// If 'T' is the same interface as 'interfaceID' return the concept
12991299
/// instance.
13001300
template <typename T>
1301-
static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
1302-
void *>::type
1301+
static typename std::enable_if<
1302+
llvm::is_detected<has_get_interface_id, T>::value, void *>::type
13031303
lookup(TypeID interfaceID) {
13041304
return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
13051305
}
13061306

13071307
/// 'T' is known to not be an interface, return nullptr.
13081308
template <typename T>
1309-
static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
1310-
void *>::type
1309+
static typename std::enable_if<
1310+
!llvm::is_detected<has_get_interface_id, T>::value, void *>::type
13111311
lookup(TypeID) {
13121312
return nullptr;
13131313
}

mlir/include/mlir/Pass/AnalysisManager.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ using has_is_invalidated = decltype(std::declval<T &>().isInvalidated(
7171

7272
/// Implementation of 'isInvalidated' if the analysis provides a definition.
7373
template <typename AnalysisT>
74-
std::enable_if_t<is_detected<has_is_invalidated, AnalysisT>::value, bool>
74+
std::enable_if_t<llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
7575
isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
7676
return analysis.isInvalidated(pa);
7777
}
7878
/// Default implementation of 'isInvalidated'.
7979
template <typename AnalysisT>
80-
std::enable_if_t<!is_detected<has_is_invalidated, AnalysisT>::value, bool>
80+
std::enable_if_t<!llvm::is_detected<has_is_invalidated, AnalysisT>::value, bool>
8181
isInvalidated(AnalysisT &analysis, const PreservedAnalyses &pa) {
8282
return !pa.isPreserved<AnalysisT>();
8383
}

mlir/include/mlir/Support/STLExtras.h

-72
Original file line numberDiff line numberDiff line change
@@ -88,37 +88,6 @@ inline void interleaveComma(const Container &c, raw_ostream &os) {
8888
interleaveComma(c, os, [&](const T &a) { os << a; });
8989
}
9090

91-
/// Utilities for detecting if a given trait holds for some set of arguments
92-
/// 'Args'. For example, the given trait could be used to detect if a given type
93-
/// has a copy assignment operator:
94-
/// template<class T>
95-
/// using has_copy_assign_t = decltype(std::declval<T&>()
96-
/// = std::declval<const T&>());
97-
/// bool fooHasCopyAssign = is_detected<has_copy_assign_t, FooClass>::value;
98-
namespace detail {
99-
template <typename...> using void_t = void;
100-
template <class, template <class...> class Op, class... Args> struct detector {
101-
using value_t = std::false_type;
102-
};
103-
template <template <class...> class Op, class... Args>
104-
struct detector<void_t<Op<Args...>>, Op, Args...> {
105-
using value_t = std::true_type;
106-
};
107-
} // end namespace detail
108-
109-
template <template <class...> class Op, class... Args>
110-
using is_detected = typename detail::detector<void, Op, Args...>::value_t;
111-
112-
/// Check if a Callable type can be invoked with the given set of arg types.
113-
namespace detail {
114-
template <typename Callable, typename... Args>
115-
using is_invocable =
116-
decltype(std::declval<Callable &>()(std::declval<Args>()...));
117-
} // namespace detail
118-
119-
template <typename Callable, typename... Args>
120-
using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
121-
12291
//===----------------------------------------------------------------------===//
12392
// Extra additions to <iterator>
12493
//===----------------------------------------------------------------------===//
@@ -356,47 +325,6 @@ template <typename ContainerTy> bool has_single_element(ContainerTy &&c) {
356325
return it != e && std::next(it) == e;
357326
}
358327

359-
//===----------------------------------------------------------------------===//
360-
// Extra additions to <type_traits>
361-
//===----------------------------------------------------------------------===//
362-
363-
/// This class provides various trait information about a callable object.
364-
/// * To access the number of arguments: Traits::num_args
365-
/// * To access the type of an argument: Traits::arg_t<i>
366-
/// * To access the type of the result: Traits::result_t<i>
367-
template <typename T, bool isClass = std::is_class<T>::value>
368-
struct FunctionTraits : public FunctionTraits<decltype(&T::operator())> {};
369-
370-
/// Overload for class function types.
371-
template <typename ClassType, typename ReturnType, typename... Args>
372-
struct FunctionTraits<ReturnType (ClassType::*)(Args...) const, false> {
373-
/// The number of arguments to this function.
374-
enum { num_args = sizeof...(Args) };
375-
376-
/// The result type of this function.
377-
using result_t = ReturnType;
378-
379-
/// The type of an argument to this function.
380-
template <size_t i>
381-
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
382-
};
383-
/// Overload for non-class function types.
384-
template <typename ReturnType, typename... Args>
385-
struct FunctionTraits<ReturnType (*)(Args...), false> {
386-
/// The number of arguments to this function.
387-
enum { num_args = sizeof...(Args) };
388-
389-
/// The result type of this function.
390-
using result_t = ReturnType;
391-
392-
/// The type of an argument to this function.
393-
template <size_t i>
394-
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
395-
};
396-
/// Overload for non-class function type references.
397-
template <typename ReturnType, typename... Args>
398-
struct FunctionTraits<ReturnType (&)(Args...), false>
399-
: public FunctionTraits<ReturnType (*)(Args...)> {};
400328
} // end namespace mlir
401329

402330
#endif // MLIR_SUPPORT_STLEXTRAS_H

mlir/include/mlir/Support/StorageUniquer.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ class StorageUniquer {
215215
/// 'ImplTy::getKey' function for the provided arguments.
216216
template <typename ImplTy, typename... Args>
217217
static typename std::enable_if<
218-
is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
218+
llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
219219
typename ImplTy::KeyTy>::type
220220
getKey(Args &&... args) {
221221
return ImplTy::getKey(args...);
@@ -224,7 +224,7 @@ class StorageUniquer {
224224
/// the 'ImplTy::KeyTy' with the provided arguments.
225225
template <typename ImplTy, typename... Args>
226226
static typename std::enable_if<
227-
!is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
227+
!llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
228228
typename ImplTy::KeyTy>::type
229229
getKey(Args &&... args) {
230230
return typename ImplTy::KeyTy(args...);
@@ -238,17 +238,17 @@ class StorageUniquer {
238238
/// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
239239
template <typename ImplTy, typename DerivedKey>
240240
static typename std::enable_if<
241-
is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
241+
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
242242
::llvm::hash_code>::type
243243
getHash(unsigned kind, const DerivedKey &derivedKey) {
244244
return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
245245
}
246246
/// If there is no 'ImplTy::hashKey' default to using the
247247
/// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
248248
template <typename ImplTy, typename DerivedKey>
249-
static typename std::enable_if<
250-
!is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
251-
::llvm::hash_code>::type
249+
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
250+
ImplTy, DerivedKey>::value,
251+
::llvm::hash_code>::type
252252
getHash(unsigned kind, const DerivedKey &derivedKey) {
253253
return llvm::hash_combine(
254254
kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));

0 commit comments

Comments
 (0)