diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 98b412c8ec9eb..b1b1d67ac2d42 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc, // This code deals with permutations as well as non-permutations that // arise from rank changing blocking. const auto dimToLvl = stt.getDimToLvl(); + const auto lvlToDim = stt.getLvlToDim(); SmallVector dim2lvlValues(lvlRank); // for each lvl, expr in dim vars SmallVector lvl2dimValues(dimRank); // for each dim, expr in lvl vars SmallVector lvlSizesValues(lvlRank); @@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc, Dimension d = 0; uint64_t cf = 0, cm = 0; switch (exp.getKind()) { - case AffineExprKind::DimId: + case AffineExprKind::DimId: { d = exp.cast().getPosition(); break; - case AffineExprKind::FloorDiv: - d = exp.cast() - .getLHS() - .cast() - .getPosition(); - cf = exp.cast() - .getRHS() - .cast() - .getValue(); + } + case AffineExprKind::FloorDiv: { + auto floor = exp.cast(); + d = floor.getLHS().cast().getPosition(); + cf = floor.getRHS().cast().getValue(); break; - case AffineExprKind::Mod: - d = exp.cast() - .getLHS() - .cast() - .getPosition(); - cm = exp.cast() - .getRHS() - .cast() - .getValue(); + } + case AffineExprKind::Mod: { + auto mod = exp.cast(); + d = mod.getLHS().cast().getPosition(); + cm = mod.getRHS().cast().getValue(); break; + } default: llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type"); } dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm)); - lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim // Compute the level sizes. // (1) l = d : size(d) // (2) l = d / c : size(d) / c @@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc, } lvlSizesValues[l] = lvlSz; } + // Generate lvl2dim. + assert(dimRank == lvlToDim.getNumResults()); + for (Dimension d = 0; d < dimRank; d++) { + AffineExpr exp = lvlToDim.getResult(d); + // We expect: + // (1) d = l + // (2) d = l' * c + l + Level l = 0, ll = 0; + uint64_t c = 0; + switch (exp.getKind()) { + case AffineExprKind::DimId: { + l = exp.cast().getPosition(); + break; + } + case AffineExprKind::Add: { + // Always mul on lhs, symbol/constant on rhs. + auto add = exp.cast(); + assert(add.getLHS().getKind() == AffineExprKind::Mul); + auto mul = add.getLHS().cast(); + ll = mul.getLHS().cast().getPosition(); + c = mul.getRHS().cast().getValue(); + l = add.getRHS().cast().getPosition(); + break; + } + default: + llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type"); + } + lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll)); + } // Return buffers. dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues); lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);