Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation #69540

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
// This code deals with permutations as well as non-permutations that
// arise from rank changing blocking.
const auto dimToLvl = stt.getDimToLvl();
const auto lvlToDim = stt.getLvlToDim();
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
SmallVector<Value> lvlSizesValues(lvlRank);
Expand All @@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
Dimension d = 0;
uint64_t cf = 0, cm = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId:
case AffineExprKind::DimId: {
d = exp.cast<AffineDimExpr>().getPosition();
break;
case AffineExprKind::FloorDiv:
d = exp.cast<AffineBinaryOpExpr>()
.getLHS()
.cast<AffineDimExpr>()
.getPosition();
cf = exp.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
}
case AffineExprKind::FloorDiv: {
auto floor = exp.cast<AffineBinaryOpExpr>();
d = floor.getLHS().cast<AffineDimExpr>().getPosition();
cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
break;
case AffineExprKind::Mod:
d = exp.cast<AffineBinaryOpExpr>()
.getLHS()
.cast<AffineDimExpr>()
.getPosition();
cm = exp.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
}
case AffineExprKind::Mod: {
auto mod = exp.cast<AffineBinaryOpExpr>();
d = mod.getLHS().cast<AffineDimExpr>().getPosition();
cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
break;
}
default:
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
}
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
// Compute the level sizes.
// (1) l = d : size(d)
// (2) l = d / c : size(d) / c
Expand All @@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
}
lvlSizesValues[l] = lvlSz;
}
// Generate lvl2dim.
assert(dimRank == lvlToDim.getNumResults());
for (Dimension d = 0; d < dimRank; d++) {
AffineExpr exp = lvlToDim.getResult(d);
// We expect:
// (1) d = l
// (2) d = l' * c + l
Level l = 0, ll = 0;
uint64_t c = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
l = exp.cast<AffineDimExpr>().getPosition();
break;
}
case AffineExprKind::Add: {
// Always mul on lhs, symbol/constant on rhs.
auto add = exp.cast<AffineBinaryOpExpr>();
assert(add.getLHS().getKind() == AffineExprKind::Mul);
auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
c = mul.getRHS().cast<AffineConstantExpr>().getValue();
l = add.getRHS().cast<AffineDimExpr>().getPosition();
break;
}
default:
llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
}
lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
}
// Return buffers.
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
Expand Down