13
13
#include " mlir/Analysis/AffineStructures.h"
14
14
#include " mlir/Analysis/LinearTransform.h"
15
15
#include " mlir/Analysis/Presburger/Simplex.h"
16
+ #include " mlir/Analysis/Presburger/Utils.h"
16
17
#include " mlir/Dialect/Affine/IR/AffineOps.h"
17
18
#include " mlir/Dialect/Affine/IR/AffineValueMap.h"
18
19
#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -700,14 +701,13 @@ void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
700
701
// Searches for a constraint with a non-zero coefficient at `colIdx` in
701
702
// equality (isEq=true) or inequality (isEq=false) constraints.
702
703
// Returns true and sets row found in search in `rowIdx`, false otherwise.
703
- static bool findConstraintWithNonZeroAt (const FlatAffineConstraints &cst,
704
- unsigned colIdx, bool isEq,
705
- unsigned *rowIdx) {
706
- assert (colIdx < cst.getNumCols () && " position out of bounds" );
704
+ bool FlatAffineConstraints::findConstraintWithNonZeroAt (
705
+ unsigned colIdx, bool isEq, unsigned *rowIdx) const {
706
+ assert (colIdx < getNumCols () && " position out of bounds" );
707
707
auto at = [&](unsigned rowIdx) -> int64_t {
708
- return isEq ? cst. atEq (rowIdx, colIdx) : cst. atIneq (rowIdx, colIdx);
708
+ return isEq ? atEq (rowIdx, colIdx) : atIneq (rowIdx, colIdx);
709
709
};
710
- unsigned e = isEq ? cst. getNumEqualities () : cst. getNumInequalities ();
710
+ unsigned e = isEq ? getNumEqualities () : getNumInequalities ();
711
711
for (*rowIdx = 0 ; *rowIdx < e; ++(*rowIdx)) {
712
712
if (at (*rowIdx) != 0 ) {
713
713
return true ;
@@ -1203,145 +1203,6 @@ bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
1203
1203
return true ;
1204
1204
}
1205
1205
1206
- // / Check if the pos^th identifier can be represented as a division using upper
1207
- // / bound inequality at position `ubIneq` and lower bound inequality at position
1208
- // / `lbIneq`.
1209
- // /
1210
- // / Let `id` be the pos^th identifier, then `id` is equivalent to
1211
- // / `expr floordiv divisor` if there are constraints of the form:
1212
- // / 0 <= expr - divisor * id <= divisor - 1
1213
- // / Rearranging, we have:
1214
- // / divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id'
1215
- // / -divisor * id + expr >= 0 <-- Upper bound for 'id'
1216
- // /
1217
- // / For example:
1218
- // / 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
1219
- // / 32*k <= 16*i + j <-- Upper bound for 'k'
1220
- // / expr = 16*i + j, divisor = 32
1221
- // / k = ( 16*i + j ) floordiv 32
1222
- // /
1223
- // / 4q >= i + j - 2 <-- Lower bound for 'q'
1224
- // / 4q <= i + j + 1 <-- Upper bound for 'q'
1225
- // / expr = i + j + 1, divisor = 4
1226
- // / q = (i + j + 1) floordiv 4
1227
- //
1228
- // / This function also supports detecting divisions from bounds that are
1229
- // / strictly tighter than the division bounds described above, since tighter
1230
- // / bounds imply the division bounds. For example:
1231
- // / 4q - i - j + 2 >= 0 <-- Lower bound for 'q'
1232
- // / -4q + i + j >= 0 <-- Tight upper bound for 'q'
1233
- // /
1234
- // / To extract floor divisions with tighter bounds, we assume that that the
1235
- // / constraints are of the form:
1236
- // / c <= expr - divisior * id <= divisor - 1, where 0 <= c <= divisor - 1
1237
- // / Rearranging, we have:
1238
- // / divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id'
1239
- // / -divisor * id + expr - c >= 0 <-- Upper bound for 'id'
1240
- // /
1241
- // / If successful, `expr` is set to dividend of the division and `divisor` is
1242
- // / set to the denominator of the division.
1243
- static LogicalResult getDivRepr (const FlatAffineConstraints &cst, unsigned pos,
1244
- unsigned ubIneq, unsigned lbIneq,
1245
- SmallVector<int64_t , 8 > &expr,
1246
- unsigned &divisor) {
1247
-
1248
- assert (pos <= cst.getNumIds () && " Invalid identifier position" );
1249
- assert (ubIneq <= cst.getNumInequalities () &&
1250
- " Invalid upper bound inequality position" );
1251
- assert (lbIneq <= cst.getNumInequalities () &&
1252
- " Invalid upper bound inequality position" );
1253
-
1254
- // Extract divisor from the lower bound.
1255
- divisor = cst.atIneq (lbIneq, pos);
1256
-
1257
- // First, check if the constraints are opposite of each other except the
1258
- // constant term.
1259
- unsigned i = 0 , e = 0 ;
1260
- for (i = 0 , e = cst.getNumIds (); i < e; ++i)
1261
- if (cst.atIneq (ubIneq, i) != -cst.atIneq (lbIneq, i))
1262
- break ;
1263
-
1264
- if (i < e)
1265
- return failure ();
1266
-
1267
- // Then, check if the constant term is of the proper form.
1268
- // Due to the form of the upper/lower bound inequalities, the sum of their
1269
- // constants is `divisor - 1 - c`. From this, we can extract c:
1270
- int64_t constantSum = cst.atIneq (lbIneq, cst.getNumCols () - 1 ) +
1271
- cst.atIneq (ubIneq, cst.getNumCols () - 1 );
1272
- int64_t c = divisor - 1 - constantSum;
1273
-
1274
- // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also
1275
- // implictly checks that `divisor` is positive.
1276
- if (!(c >= 0 && c <= divisor - 1 ))
1277
- return failure ();
1278
-
1279
- // The inequality pair can be used to extract the division.
1280
- // Set `expr` to the dividend of the division except the constant term, which
1281
- // is set below.
1282
- expr.resize (cst.getNumCols (), 0 );
1283
- for (i = 0 , e = cst.getNumIds (); i < e; ++i)
1284
- if (i != pos)
1285
- expr[i] = cst.atIneq (ubIneq, i);
1286
-
1287
- // From the upper bound inequality's form, its constant term is equal to the
1288
- // constant term of `expr`, minus `c`. From this,
1289
- // constant term of `expr` = constant term of upper bound + `c`.
1290
- expr.back () = cst.atIneq (ubIneq, cst.getNumCols () - 1 ) + c;
1291
-
1292
- return success ();
1293
- }
1294
-
1295
- // / Check if the pos^th identifier can be expressed as a floordiv of an affine
1296
- // / function of other identifiers (where the divisor is a positive constant).
1297
- // / `foundRepr` contains a boolean for each identifier indicating if the
1298
- // / explicit representation for that identifier has already been computed.
1299
- // / Returns the upper and lower bound inequalities using which the floordiv can
1300
- // / be computed. If the representation could be computed, `dividend` and
1301
- // / `denominator` are set. If the representation could not be computed,
1302
- // / `llvm::None` is returned.
1303
- static Optional<std::pair<unsigned , unsigned >>
1304
- computeSingleVarRepr (const FlatAffineConstraints &cst,
1305
- const SmallVector<bool , 8 > &foundRepr, unsigned pos,
1306
- SmallVector<int64_t , 8 > ÷nd, unsigned &divisor) {
1307
- assert (pos < cst.getNumIds () && " invalid position" );
1308
- assert (foundRepr.size () == cst.getNumIds () &&
1309
- " Size of foundRepr does not match total number of variables" );
1310
-
1311
- SmallVector<unsigned , 4 > lbIndices, ubIndices;
1312
- cst.getLowerAndUpperBoundIndices (pos, &lbIndices, &ubIndices);
1313
-
1314
- for (unsigned ubPos : ubIndices) {
1315
- for (unsigned lbPos : lbIndices) {
1316
- // Attempt to get divison representation from ubPos, lbPos.
1317
- if (failed (getDivRepr (cst, pos, ubPos, lbPos, dividend, divisor)))
1318
- continue ;
1319
-
1320
- // Check if the inequalities depend on a variable for which
1321
- // an explicit representation has not been found yet.
1322
- // Exit to avoid circular dependencies between divisions.
1323
- unsigned c, f;
1324
- for (c = 0 , f = cst.getNumIds (); c < f; ++c) {
1325
- if (c == pos)
1326
- continue ;
1327
- if (!foundRepr[c] && dividend[c] != 0 )
1328
- break ;
1329
- }
1330
-
1331
- // Expression can't be constructed as it depends on a yet unknown
1332
- // identifier.
1333
- // TODO: Visit/compute the identifiers in an order so that this doesn't
1334
- // happen. More complex but much more efficient.
1335
- if (c < f)
1336
- continue ;
1337
-
1338
- return std::make_pair (ubPos, lbPos);
1339
- }
1340
- }
1341
-
1342
- return llvm::None;
1343
- }
1344
-
1345
1206
void FlatAffineConstraints::getLocalReprs (
1346
1207
std::vector<llvm::Optional<std::pair<unsigned , unsigned >>> &repr) const {
1347
1208
std::vector<SmallVector<int64_t , 8 >> dividends (getNumLocalIds ());
@@ -1378,8 +1239,9 @@ void FlatAffineConstraints::getLocalReprs(
1378
1239
changed = false ;
1379
1240
for (unsigned i = 0 , e = getNumLocalIds (); i < e; ++i) {
1380
1241
if (!foundRepr[i + divOffset]) {
1381
- if (auto res = computeSingleVarRepr (*this , foundRepr, divOffset + i,
1382
- dividends[i], denominators[i])) {
1242
+ if (auto res = presburger_utils::computeSingleVarRepr (
1243
+ *this , foundRepr, divOffset + i, dividends[i],
1244
+ denominators[i])) {
1383
1245
foundRepr[i + divOffset] = true ;
1384
1246
repr[i] = res;
1385
1247
changed = true ;
@@ -1437,11 +1299,9 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
1437
1299
for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
1438
1300
// Find a row which has a non-zero coefficient in column 'j'.
1439
1301
unsigned pivotRow;
1440
- if (!findConstraintWithNonZeroAt (*this , pivotCol, /* isEq=*/ true ,
1441
- &pivotRow)) {
1302
+ if (!findConstraintWithNonZeroAt (pivotCol, /* isEq=*/ true , &pivotRow)) {
1442
1303
// No pivot row in equalities with non-zero at 'pivotCol'.
1443
- if (!findConstraintWithNonZeroAt (*this , pivotCol, /* isEq=*/ false ,
1444
- &pivotRow)) {
1304
+ if (!findConstraintWithNonZeroAt (pivotCol, /* isEq=*/ false , &pivotRow)) {
1445
1305
// If inequalities are also non-zero in 'pivotCol', it can be
1446
1306
// eliminated.
1447
1307
continue ;
@@ -1596,60 +1456,6 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
1596
1456
return false ;
1597
1457
}
1598
1458
1599
- // / Gather all lower and upper bounds of the identifier at `pos`, and
1600
- // / optionally any equalities on it. In addition, the bounds are to be
1601
- // / independent of identifiers in position range [`offset`, `offset` + `num`).
1602
- void FlatAffineConstraints::getLowerAndUpperBoundIndices (
1603
- unsigned pos, SmallVectorImpl<unsigned > *lbIndices,
1604
- SmallVectorImpl<unsigned > *ubIndices, SmallVectorImpl<unsigned > *eqIndices,
1605
- unsigned offset, unsigned num) const {
1606
- assert (pos < getNumIds () && " invalid position" );
1607
- assert (offset + num < getNumCols () && " invalid range" );
1608
-
1609
- // Checks for a constraint that has a non-zero coeff for the identifiers in
1610
- // the position range [offset, offset + num) while ignoring `pos`.
1611
- auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
1612
- unsigned c, f;
1613
- auto cst = isEq ? getEquality (r) : getInequality (r);
1614
- for (c = offset, f = offset + num; c < f; ++c) {
1615
- if (c == pos)
1616
- continue ;
1617
- if (cst[c] != 0 )
1618
- break ;
1619
- }
1620
- return c < f;
1621
- };
1622
-
1623
- // Gather all lower bounds and upper bounds of the variable. Since the
1624
- // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1625
- // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1626
- for (unsigned r = 0 , e = getNumInequalities (); r < e; r++) {
1627
- // The bounds are to be independent of [offset, offset + num) columns.
1628
- if (containsConstraintDependentOnRange (r, /* isEq=*/ false ))
1629
- continue ;
1630
- if (atIneq (r, pos) >= 1 ) {
1631
- // Lower bound.
1632
- lbIndices->push_back (r);
1633
- } else if (atIneq (r, pos) <= -1 ) {
1634
- // Upper bound.
1635
- ubIndices->push_back (r);
1636
- }
1637
- }
1638
-
1639
- // An equality is both a lower and upper bound. Record any equalities
1640
- // involving the pos^th identifier.
1641
- if (!eqIndices)
1642
- return ;
1643
-
1644
- for (unsigned r = 0 , e = getNumEqualities (); r < e; r++) {
1645
- if (atEq (r, pos) == 0 )
1646
- continue ;
1647
- if (containsConstraintDependentOnRange (r, /* isEq=*/ true ))
1648
- continue ;
1649
- eqIndices->push_back (r);
1650
- }
1651
- }
1652
-
1653
1459
// / Check if the pos^th identifier can be expressed as a floordiv of an affine
1654
1460
// / function of other identifiers (where the divisor is a positive constant)
1655
1461
// / given the initial set of expressions in `exprs`. If it can be, the
@@ -1670,7 +1476,8 @@ static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
1670
1476
1671
1477
SmallVector<int64_t , 8 > dividend;
1672
1478
unsigned divisor;
1673
- auto ulPair = computeSingleVarRepr (cst, foundRepr, pos, dividend, divisor);
1479
+ auto ulPair = presburger_utils::computeSingleVarRepr (cst, foundRepr, pos,
1480
+ dividend, divisor);
1674
1481
1675
1482
// No upper-lower bound pair found for this var.
1676
1483
if (!ulPair)
@@ -2109,7 +1916,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
2109
1916
2110
1917
// Detect an identifier as an expression of other identifiers.
2111
1918
unsigned idx;
2112
- if (!findConstraintWithNonZeroAt (* this , pos, /* isEq=*/ true , &idx)) {
1919
+ if (!findConstraintWithNonZeroAt (pos, /* isEq=*/ true , &idx)) {
2113
1920
continue ;
2114
1921
}
2115
1922
@@ -3447,12 +3254,10 @@ void FlatAffineValueConstraints::getIneqAsAffineValueMap(
3447
3254
vmap.reset (AffineMap::get (numDims - 1 , numSyms, boundExpr), operands);
3448
3255
}
3449
3256
3450
- // / Returns true if the pos^th column is all zero for both inequalities and
3451
- // / equalities..
3452
- static bool isColZero (const FlatAffineConstraints &cst, unsigned pos) {
3257
+ bool FlatAffineConstraints::isColZero (unsigned pos) const {
3453
3258
unsigned rowPos;
3454
- return !findConstraintWithNonZeroAt (cst, pos, /* isEq=*/ false , &rowPos) &&
3455
- !findConstraintWithNonZeroAt (cst, pos, /* isEq=*/ true , &rowPos);
3259
+ return !findConstraintWithNonZeroAt (pos, /* isEq=*/ false , &rowPos) &&
3260
+ !findConstraintWithNonZeroAt (pos, /* isEq=*/ true , &rowPos);
3456
3261
}
3457
3262
3458
3263
IntegerSet FlatAffineConstraints::getAsIntegerSet (MLIRContext *context) const {
@@ -3471,7 +3276,7 @@ IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
3471
3276
SmallVector<unsigned > noLocalRepVars;
3472
3277
unsigned numDimsSymbols = getNumDimAndSymbolIds ();
3473
3278
for (unsigned i = numDimsSymbols, e = getNumIds (); i < e; ++i) {
3474
- if (!memo[i] && !isColZero (* this , /* pos=*/ i))
3279
+ if (!memo[i] && !isColZero (/* pos=*/ i))
3475
3280
noLocalRepVars.push_back (i - numDimsSymbols);
3476
3281
}
3477
3282
if (!noLocalRepVars.empty ()) {
0 commit comments