Skip to content

Commit

Permalink
[mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computa…
Browse files Browse the repository at this point in the history
…tion (#69540)

This makes sure

- GEN MAP dim=2 lvl=4
  (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
--
  (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
  d2l = [ d0/2 d1/2 d0%2 d1%2 ]
  ld2 = [ l2+2*l0 l3+2*l1 ]
  • Loading branch information
aartbik authored Oct 19, 2023
1 parent e103515 commit f16cb0e
Showing 1 changed file with 42 additions and 20 deletions.
62 changes: 42 additions & 20 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
// This code deals with permutations as well as non-permutations that
// arise from rank changing blocking.
const auto dimToLvl = stt.getDimToLvl();
const auto lvlToDim = stt.getLvlToDim();
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
SmallVector<Value> lvlSizesValues(lvlRank);
Expand All @@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
Dimension d = 0;
uint64_t cf = 0, cm = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId:
case AffineExprKind::DimId: {
d = exp.cast<AffineDimExpr>().getPosition();
break;
case AffineExprKind::FloorDiv:
d = exp.cast<AffineBinaryOpExpr>()
.getLHS()
.cast<AffineDimExpr>()
.getPosition();
cf = exp.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
}
case AffineExprKind::FloorDiv: {
auto floor = exp.cast<AffineBinaryOpExpr>();
d = floor.getLHS().cast<AffineDimExpr>().getPosition();
cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
break;
case AffineExprKind::Mod:
d = exp.cast<AffineBinaryOpExpr>()
.getLHS()
.cast<AffineDimExpr>()
.getPosition();
cm = exp.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
}
case AffineExprKind::Mod: {
auto mod = exp.cast<AffineBinaryOpExpr>();
d = mod.getLHS().cast<AffineDimExpr>().getPosition();
cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
break;
}
default:
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
}
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
// Compute the level sizes.
// (1) l = d : size(d)
// (2) l = d / c : size(d) / c
Expand All @@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
}
lvlSizesValues[l] = lvlSz;
}
// Generate lvl2dim.
assert(dimRank == lvlToDim.getNumResults());
for (Dimension d = 0; d < dimRank; d++) {
AffineExpr exp = lvlToDim.getResult(d);
// We expect:
// (1) d = l
// (2) d = l' * c + l
Level l = 0, ll = 0;
uint64_t c = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
l = exp.cast<AffineDimExpr>().getPosition();
break;
}
case AffineExprKind::Add: {
// Always mul on lhs, symbol/constant on rhs.
auto add = exp.cast<AffineBinaryOpExpr>();
assert(add.getLHS().getKind() == AffineExprKind::Mul);
auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
c = mul.getRHS().cast<AffineConstantExpr>().getValue();
l = add.getRHS().cast<AffineDimExpr>().getPosition();
break;
}
default:
llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
}
lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
}
// Return buffers.
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
Expand Down

0 comments on commit f16cb0e

Please sign in to comment.