Skip to content

Commit 204c3b5

Browse files
committed
[llvm][STLExtras] Move various iterator/range utilities from MLIR to LLVM
This revision moves the various range utilities present in MLIR to LLVM to enable greater reuse. This revision moves the following utilities: * indexed_accessor_* This is set of utility iterator/range base classes that allow for building a range class where the iterators are represented by an object+index pair. * make_second_range Given a range of pairs, returns a range iterating over the `second` elements. * hasSingleElement Returns if the given range has 1 element. size() == 1 checks end up being very common, but size() is not always O(1) (e.g., ilist). This method provides O(1) checks for those cases. Differential Revision: https://reviews.llvm.org/D78064
1 parent 8cbe371 commit 204c3b5

File tree

24 files changed

+275
-303
lines changed

24 files changed

+275
-303
lines changed

llvm/include/llvm/ADT/STLExtras.h

+213
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,12 @@ constexpr bool empty(const T &RangeOrContainer) {
263263
return adl_begin(RangeOrContainer) == adl_end(RangeOrContainer);
264264
}
265265

266+
/// Returns true of the given range only contains a single element.
267+
template <typename ContainerTy> bool hasSingleElement(ContainerTy &&c) {
268+
auto it = std::begin(c), e = std::end(c);
269+
return it != e && std::next(it) == e;
270+
}
271+
266272
/// Return a range covering \p RangeOrContainer with the first N elements
267273
/// excluded.
268274
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N) {
@@ -1017,6 +1023,213 @@ detail::concat_range<ValueT, RangeTs...> concat(RangeTs &&... Ranges) {
10171023
std::forward<RangeTs>(Ranges)...);
10181024
}
10191025

1026+
/// A utility class used to implement an iterator that contains some base object
1027+
/// and an index. The iterator moves the index but keeps the base constant.
1028+
template <typename DerivedT, typename BaseT, typename T,
1029+
typename PointerT = T *, typename ReferenceT = T &>
1030+
class indexed_accessor_iterator
1031+
: public llvm::iterator_facade_base<DerivedT,
1032+
std::random_access_iterator_tag, T,
1033+
std::ptrdiff_t, PointerT, ReferenceT> {
1034+
public:
1035+
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
1036+
assert(base == rhs.base && "incompatible iterators");
1037+
return index - rhs.index;
1038+
}
1039+
bool operator==(const indexed_accessor_iterator &rhs) const {
1040+
return base == rhs.base && index == rhs.index;
1041+
}
1042+
bool operator<(const indexed_accessor_iterator &rhs) const {
1043+
assert(base == rhs.base && "incompatible iterators");
1044+
return index < rhs.index;
1045+
}
1046+
1047+
DerivedT &operator+=(ptrdiff_t offset) {
1048+
this->index += offset;
1049+
return static_cast<DerivedT &>(*this);
1050+
}
1051+
DerivedT &operator-=(ptrdiff_t offset) {
1052+
this->index -= offset;
1053+
return static_cast<DerivedT &>(*this);
1054+
}
1055+
1056+
/// Returns the current index of the iterator.
1057+
ptrdiff_t getIndex() const { return index; }
1058+
1059+
/// Returns the current base of the iterator.
1060+
const BaseT &getBase() const { return base; }
1061+
1062+
protected:
1063+
indexed_accessor_iterator(BaseT base, ptrdiff_t index)
1064+
: base(base), index(index) {}
1065+
BaseT base;
1066+
ptrdiff_t index;
1067+
};
1068+
1069+
namespace detail {
1070+
/// The class represents the base of a range of indexed_accessor_iterators. It
1071+
/// provides support for many different range functionalities, e.g.
1072+
/// drop_front/slice/etc.. Derived range classes must implement the following
1073+
/// static methods:
1074+
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
1075+
/// - Dereference an iterator pointing to the base object at the given
1076+
/// index.
1077+
/// * BaseT offset_base(const BaseT &base, ptrdiff_t index)
1078+
/// - Return a new base that is offset from the provide base by 'index'
1079+
/// elements.
1080+
template <typename DerivedT, typename BaseT, typename T,
1081+
typename PointerT = T *, typename ReferenceT = T &>
1082+
class indexed_accessor_range_base {
1083+
public:
1084+
using RangeBaseT =
1085+
indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>;
1086+
1087+
/// An iterator element of this range.
1088+
class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
1089+
PointerT, ReferenceT> {
1090+
public:
1091+
// Index into this iterator, invoking a static method on the derived type.
1092+
ReferenceT operator*() const {
1093+
return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
1094+
}
1095+
1096+
private:
1097+
iterator(BaseT owner, ptrdiff_t curIndex)
1098+
: indexed_accessor_iterator<iterator, BaseT, T, PointerT, ReferenceT>(
1099+
owner, curIndex) {}
1100+
1101+
/// Allow access to the constructor.
1102+
friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
1103+
ReferenceT>;
1104+
};
1105+
1106+
indexed_accessor_range_base(iterator begin, iterator end)
1107+
: base(DerivedT::offset_base(begin.getBase(), begin.getIndex())),
1108+
count(end.getIndex() - begin.getIndex()) {}
1109+
indexed_accessor_range_base(const iterator_range<iterator> &range)
1110+
: indexed_accessor_range_base(range.begin(), range.end()) {}
1111+
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
1112+
: base(base), count(count) {}
1113+
1114+
iterator begin() const { return iterator(base, 0); }
1115+
iterator end() const { return iterator(base, count); }
1116+
ReferenceT operator[](unsigned index) const {
1117+
assert(index < size() && "invalid index for value range");
1118+
return DerivedT::dereference_iterator(base, index);
1119+
}
1120+
1121+
/// Compare this range with another.
1122+
template <typename OtherT> bool operator==(const OtherT &other) {
1123+
return size() == std::distance(other.begin(), other.end()) &&
1124+
std::equal(begin(), end(), other.begin());
1125+
}
1126+
1127+
/// Return the size of this range.
1128+
size_t size() const { return count; }
1129+
1130+
/// Return if the range is empty.
1131+
bool empty() const { return size() == 0; }
1132+
1133+
/// Drop the first N elements, and keep M elements.
1134+
DerivedT slice(size_t n, size_t m) const {
1135+
assert(n + m <= size() && "invalid size specifiers");
1136+
return DerivedT(DerivedT::offset_base(base, n), m);
1137+
}
1138+
1139+
/// Drop the first n elements.
1140+
DerivedT drop_front(size_t n = 1) const {
1141+
assert(size() >= n && "Dropping more elements than exist");
1142+
return slice(n, size() - n);
1143+
}
1144+
/// Drop the last n elements.
1145+
DerivedT drop_back(size_t n = 1) const {
1146+
assert(size() >= n && "Dropping more elements than exist");
1147+
return DerivedT(base, size() - n);
1148+
}
1149+
1150+
/// Take the first n elements.
1151+
DerivedT take_front(size_t n = 1) const {
1152+
return n < size() ? drop_back(size() - n)
1153+
: static_cast<const DerivedT &>(*this);
1154+
}
1155+
1156+
/// Take the last n elements.
1157+
DerivedT take_back(size_t n = 1) const {
1158+
return n < size() ? drop_front(size() - n)
1159+
: static_cast<const DerivedT &>(*this);
1160+
}
1161+
1162+
/// Allow conversion to any type accepting an iterator_range.
1163+
template <typename RangeT, typename = std::enable_if_t<std::is_constructible<
1164+
RangeT, iterator_range<iterator>>::value>>
1165+
operator RangeT() const {
1166+
return RangeT(iterator_range<iterator>(*this));
1167+
}
1168+
1169+
protected:
1170+
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
1171+
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
1172+
indexed_accessor_range_base &
1173+
operator=(const indexed_accessor_range_base &) = default;
1174+
1175+
/// The base that owns the provided range of values.
1176+
BaseT base;
1177+
/// The size from the owning range.
1178+
ptrdiff_t count;
1179+
};
1180+
} // end namespace detail
1181+
1182+
/// This class provides an implementation of a range of
1183+
/// indexed_accessor_iterators where the base is not indexable. Ranges with
1184+
/// bases that are offsetable should derive from indexed_accessor_range_base
1185+
/// instead. Derived range classes are expected to implement the following
1186+
/// static method:
1187+
/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index)
1188+
/// - Dereference an iterator pointing to a parent base at the given index.
1189+
template <typename DerivedT, typename BaseT, typename T,
1190+
typename PointerT = T *, typename ReferenceT = T &>
1191+
class indexed_accessor_range
1192+
: public detail::indexed_accessor_range_base<
1193+
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
1194+
public:
1195+
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
1196+
: detail::indexed_accessor_range_base<
1197+
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
1198+
std::make_pair(base, startIndex), count) {}
1199+
using detail::indexed_accessor_range_base<
1200+
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT,
1201+
ReferenceT>::indexed_accessor_range_base;
1202+
1203+
/// Returns the current base of the range.
1204+
const BaseT &getBase() const { return this->base.first; }
1205+
1206+
/// Returns the current start index of the range.
1207+
ptrdiff_t getStartIndex() const { return this->base.second; }
1208+
1209+
/// See `detail::indexed_accessor_range_base` for details.
1210+
static std::pair<BaseT, ptrdiff_t>
1211+
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
1212+
// We encode the internal base as a pair of the derived base and a start
1213+
// index into the derived base.
1214+
return std::make_pair(base.first, base.second + index);
1215+
}
1216+
/// See `detail::indexed_accessor_range_base` for details.
1217+
static ReferenceT
1218+
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
1219+
ptrdiff_t index) {
1220+
return DerivedT::dereference(base.first, base.second + index);
1221+
}
1222+
};
1223+
1224+
/// Given a container of pairs, return a range over the second elements.
1225+
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
1226+
return llvm::map_range(
1227+
std::forward<ContainerTy>(c),
1228+
[](decltype((*std::begin(c))) elt) -> decltype((elt.second)) {
1229+
return elt.second;
1230+
});
1231+
}
1232+
10201233
//===----------------------------------------------------------------------===//
10211234
// Extra additions to <utility>
10221235
//===----------------------------------------------------------------------===//

llvm/unittests/Support/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_llvm_unittest(SupportTests
4040
FormatVariadicTest.cpp
4141
GlobPatternTest.cpp
4242
Host.cpp
43+
IndexedAccessorTest.cpp
4344
ItaniumManglingCanonicalizerTest.cpp
4445
JSONTest.cpp
4546
KnownBitsTest.cpp

mlir/unittests/Support/IndexedAccessorTest.cpp renamed to llvm/unittests/Support/IndexedAccessorTest.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Support/STLExtras.h"
109
#include "llvm/ADT/ArrayRef.h"
10+
#include "llvm/ADT/STLExtras.h"
1111
#include "gmock/gmock.h"
1212

13-
using namespace mlir;
14-
using namespace mlir::detail;
13+
using namespace llvm;
14+
using namespace llvm::detail;
1515

1616
namespace {
1717
/// Simple indexed accessor range that wraps an array.
@@ -24,7 +24,7 @@ struct ArrayIndexedAccessorRange
2424
using indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *,
2525
T>::indexed_accessor_range;
2626

27-
/// See `indexed_accessor_range` for details.
27+
/// See `llvm::indexed_accessor_range` for details.
2828
static T &dereference(T *data, ptrdiff_t index) { return data[index]; }
2929
};
3030
} // end anonymous namespace

mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,16 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
290290

291291
/// Range class for element types.
292292
class ElementTypeRange
293-
: public ::mlir::detail::indexed_accessor_range_base<
293+
: public ::llvm::detail::indexed_accessor_range_base<
294294
ElementTypeRange, const Type *, Type, Type, Type> {
295295
private:
296296
using RangeBaseT::RangeBaseT;
297297

298-
/// See `mlir::detail::indexed_accessor_range_base` for details.
298+
/// See `llvm::detail::indexed_accessor_range_base` for details.
299299
static const Type *offset_base(const Type *object, ptrdiff_t index) {
300300
return object + index;
301301
}
302-
/// See `mlir::detail::indexed_accessor_range_base` for details.
302+
/// See `llvm::detail::indexed_accessor_range_base` for details.
303303
static Type dereference_iterator(const Type *object, ptrdiff_t index) {
304304
return object[index];
305305
}

mlir/include/mlir/IR/Attributes.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -648,13 +648,14 @@ using DenseIterPtrAndSplat =
648648
template <typename ConcreteT, typename T, typename PointerT = T *,
649649
typename ReferenceT = T &>
650650
class DenseElementIndexedIteratorImpl
651-
: public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
652-
PointerT, ReferenceT> {
651+
: public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
652+
PointerT, ReferenceT> {
653653
protected:
654654
DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
655655
size_t dataIndex)
656-
: indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
657-
ReferenceT>({data, isSplat}, dataIndex) {}
656+
: llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
657+
PointerT, ReferenceT>({data, isSplat},
658+
dataIndex) {}
658659

659660
/// Return the current index for this iterator, adjusted for the case of a
660661
/// splat.
@@ -746,8 +747,9 @@ class DenseElementsAttr
746747
/// A utility iterator that allows walking over the internal Attribute values
747748
/// of a DenseElementsAttr.
748749
class AttributeElementIterator
749-
: public indexed_accessor_iterator<AttributeElementIterator, const void *,
750-
Attribute, Attribute, Attribute> {
750+
: public llvm::indexed_accessor_iterator<AttributeElementIterator,
751+
const void *, Attribute,
752+
Attribute, Attribute> {
751753
public:
752754
/// Accesses the Attribute value at this iterator position.
753755
Attribute operator*() const;

mlir/include/mlir/IR/BlockSupport.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,19 @@ class PredecessorIterator final
5454

5555
/// This class implements the successor iterators for Block.
5656
class SuccessorRange final
57-
: public detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
58-
Block *, Block *, Block *> {
57+
: public llvm::detail::indexed_accessor_range_base<
58+
SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
5959
public:
6060
using RangeBaseT::RangeBaseT;
6161
SuccessorRange(Block *block);
6262
SuccessorRange(Operation *term);
6363

6464
private:
65-
/// See `detail::indexed_accessor_range_base` for details.
65+
/// See `llvm::detail::indexed_accessor_range_base` for details.
6666
static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
6767
return object + index;
6868
}
69-
/// See `detail::indexed_accessor_range_base` for details.
69+
/// See `llvm::detail::indexed_accessor_range_base` for details.
7070
static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) {
7171
return object[index].get();
7272
}

mlir/include/mlir/IR/OpImplementation.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class OpAsmPrinter {
113113
void printArrowTypeList(TypeRange &&types) {
114114
auto &os = getStream() << " -> ";
115115

116-
bool wrapped = !has_single_element(types) ||
116+
bool wrapped = !llvm::hasSingleElement(types) ||
117117
(*types.begin()).template isa<FunctionType>();
118118
if (wrapped)
119119
os << '(';

0 commit comments

Comments
 (0)