Skip to content

Commit

Permalink
[mlir][sparse] fix logical error when generating sort_coo. (#66690)
Browse files Browse the repository at this point in the history
To fix issue: #66664
  • Loading branch information
PeimingLiu authored Sep 18, 2023
1 parent cacdb90 commit 4176ce6
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 380 deletions.
67 changes: 45 additions & 22 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// p = (lo+hi)/2 // pivot index
// i = lo
// j = hi-1
// while (i < j) do {
// while (true) do {
// while (xs[i] < xs[p]) i ++;
// i_eq = (xs[i] == xs[p]);
// while (xs[j] > xs[p]) j --;
// j_eq = (xs[j] == xs[p]);
//
// if (i >= j) return j + 1;
//
// if (i < j) {
// swap(xs[i], xs[j])
// if (i == p) {
Expand All @@ -581,8 +584,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// }
// }
// }
// return p
// }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
Expand All @@ -605,22 +607,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
trueVal.getType()};
scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

// The before-region of the WhileOp.
Block *before =
builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
{loc, loc, loc, loc});
builder.setInsertionPointToEnd(before);
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
before->getArgument(0),
before->getArgument(1));
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
builder.create<scf::ConditionOp>(loc, before->getArgument(3),
before->getArguments());

// The after-region of the WhileOp.
Block *after =
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
builder.setInsertionPointToEnd(after);
i = after->getArgument(0);
j = after->getArgument(1);
Expand All @@ -637,7 +639,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
j = jresult;

// If i < j:
cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> swapOperands{i, j};
Expand Down Expand Up @@ -675,11 +678,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointAfter(ifOp2);
builder.create<scf::YieldOp>(
loc,
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
/*cont=*/constantI1(builder, loc, true)});

// False branch for if i < j:
// False branch for if i < j (i.e., i >= j):
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
p = builder.create<arith::AddIOp>(loc, j,
constantOne(builder, loc, j.getType()));
builder.create<scf::YieldOp>(
loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

// Return for the whileOp.
builder.setInsertionPointAfter(ifOp);
Expand Down Expand Up @@ -927,6 +934,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
Location loc = func.getLoc();
Value lo = args[loIdx];
Value hi = args[hiIdx];
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.

FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
Expand All @@ -935,14 +944,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
TypeRange{IndexType::get(context)},
args.drop_back(nTrailingP))
.getResult(0);
Value pP1 =
builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));

Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
// Partition already sorts array with len <= 2
Value c2 = constantIndex(builder, loc, 2);
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
Value lenGtTwo =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
scf::IfOp ifLenGtTwo =
builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
// Returns an empty range to mark the entire region is fully sorted.
builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

// Else len > 2, need recursion.
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
lenLow, lenHigh);

SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);

Value c0 = constantIndex(builder, loc, 0);
Expand All @@ -961,14 +981,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
// the bigger partition to be processed by the enclosed while-loop.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
mayRecursion(lo, p, lenLow);
builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
builder.create<scf::YieldOp>(loc, ValueRange{p, hi});

builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
mayRecursion(pP1, hi, lenHigh);
mayRecursion(p, hi, lenHigh);
builder.create<scf::YieldOp>(loc, ValueRange{lo, p});

builder.setInsertionPointAfter(ifOp);
return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
builder.create<scf::YieldOp>(loc, ifOp.getResults());

builder.setInsertionPointAfter(ifLenGtTwo);
return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
}

/// Creates a function to perform insertion sort on the values in the range of
Expand Down
Loading

0 comments on commit 4176ce6

Please sign in to comment.