Skip to content

Commit

Permalink
[mlir][ArmSME] Add custom vector.print lowering for SME tiles (#66691)
Browse files Browse the repository at this point in the history
This adds a custom lowering for SME that loops over each row of the
tile, extracting it via an SME MOVA, then printing with a normal 1D
vector.print.

This makes writing SME integration tests easier and less verbose.

Depends on: #66910, #66911
  • Loading branch information
MacDue authored Sep 26, 2023
1 parent 9555736 commit 174cd61
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 134 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);

/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
/// integer, to an i32 that can be passed as the `tile` parameter to the SME
/// intrinsics. Returns `tile` unchanged if it is already i32.
///
/// Tile IDs narrower than 32 bits are zero-extended with `arith::ExtUIOp` and
/// wider ones are truncated with `arith::TruncIOp`; any such cast is created
/// through `rewriter` at location `loc`. Asserts that `tile` is defined by one
/// of the two ops named above.
Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);

} // namespace arm_sme
} // namespace mlir

Expand Down
93 changes: 91 additions & 2 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,94 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};

/// Rewrites a `vector.print` whose operand is an SME tile into an `scf.for`
/// that walks the rows of the tile. Every iteration reads one horizontal tile
/// slice with a MOVA (the `arm_sme.intr.read.horiz` intrinsic) and prints that
/// slice with an ordinary 1D `vector.print`.
///
/// BEFORE:
/// ```mlir
/// vector.print %tile : vector<[4]x[4]xf32>
/// ```
/// AFTER:
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
/// %vscale = vector.vscale
/// %svl_s = arith.muli %c4, %vscale : index
/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
/// scf.for %i = %c0 to %svl_s step %c1 {
///   %slice_idx = arith.index_cast %i : index to i32
///   %tile_slice = "arm_sme.intr.read.horiz"
///       (%cst, %ptrue, %tile_id, %slice_idx)
///     : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
///   vector.print %tile_slice : vector<[4]xf32>
/// }
/// ```
struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    // A print without a source operand has nothing to lower.
    if (!printOp.getSource())
      return failure();

    // Only prints of valid SME tile vector types are handled here.
    auto tileType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!(tileType && arm_sme::isValidSMETileVectorType(tileType)))
      return failure();

    Location loc = printOp.getLoc();
    Type elementType = tileType.getElementType();
    const int64_t minRows = tileType.getDimSize(0);
    const int64_t minCols = tileType.getDimSize(1);

    // 1D scalable types describing one row of the tile and its predicate.
    auto rowType = VectorType::get(minCols, elementType, true);
    auto rowMaskType = VectorType::get(minCols, rewriter.getI1Type(), true);

    // 'All true' predicate covering a whole tile row.
    auto allLanesActive = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(rowMaskType, true));

    // The read intrinsic takes the tile as an i32 tile ID.
    auto tileId =
        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
    Value tileIdI32 = arm_sme::castTileIDToI32(tileId, loc, rewriter);

    // Zero vector passed as the destination/fallback operand of the read.
    auto zeroRow = rewriter.create<arith::ConstantOp>(
        loc, rowType, rewriter.getZeroAttr(rowType));

    // Loop `for row in [0, minRows * vscale) step 1`.
    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto minRowCount = rewriter.create<arith::ConstantIndexOp>(loc, minRows);
    auto firstRow = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto rowCount = rewriter.create<arith::MulIOp>(loc, minRowCount, vscale);
    auto unitStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto rowLoop =
        rewriter.create<scf::ForOp>(loc, firstRow, rowCount, unitStep);

    // Everything below is emitted into the loop body.
    rewriter.setInsertionPointToStart(rowLoop.getBody());

    // The intrinsic expects the slice index as an i32.
    Value currentRow = rowLoop.getInductionVar();
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), currentRow);

    // Read the current row out of the tile, then print it as a 1D vector
    // using the punctuation of the original print.
    auto rowSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, rowType, zeroRow, allLanesActive, tileIdI32, sliceIndexI32);
    rewriter.create<vector::PrintOp>(loc, rowSlice, printOp.getPunctuation());

    rewriter.eraseOp(printOp);
    return success();
  }
};

} // namespace

/// Adds the ArmSME-to-SCF rewrite patterns defined in this file (tile load,
/// tile store, and tile `vector.print` lowering) to `patterns`.
///
/// Each pattern is registered exactly once: a separate `add` of
/// TileLoadOpConversion/TileStoreOpConversion alongside this call would only
/// insert redundant duplicates of the same patterns.
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
               TileVectorPrintOpConversion>(patterns.getContext());
}

namespace {
Expand All @@ -208,6 +291,12 @@ struct ConvertArmSMEToSCFPass
target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
arith::ArithDialect, scf::SCFDialect>();
target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
if (!op.getSource())
return true;
VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
Expand Down
17 changes: 0 additions & 17 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
}
};

/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
/// integer, to an i32 that can be passed as the `tile` parameter to the SME
/// intrinsics. Returns `tile` unchanged if it is already i32.
///
/// Narrower tile IDs are zero-extended with `arith::ExtUIOp`, wider ones are
/// truncated with `arith::TruncIOp`; the cast is created via `rewriter` at
/// `loc`.
Value castTileIDToI32(Value tile, Location loc,
                      ConversionPatternRewriter &rewriter) {
  // A Value without a defining op (e.g. a block argument) yields a null
  // pointer here. Check for that explicitly so the message below is reported
  // instead of `isa<>` being invoked on a null pointer.
  assert(tile.getDefiningOp() &&
         (isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
             tile.getDefiningOp())) &&
         "expected ArmSME GetTileID or CastVectorToTile op!");
  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
  if (tileElementWidth < 32)
    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
  if (tileElementWidth > 32)
    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
  // Already 32 bits wide: no cast needed.
  return tile;
}

/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "mlir/Dialect/ArmSME/Utils/Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"

using namespace mlir;
Expand Down Expand Up @@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {

return true;
}

/// Returns `tile` as an i32 suitable for the `tile` parameter of the SME
/// intrinsics. `tile` is expected to be the result of an
/// `arm_sme::GetTileID` or `arm_sme::CastVectorToTile` op. Tile IDs narrower
/// than 32 bits are zero-extended with `arith::ExtUIOp`, wider ones are
/// truncated with `arith::TruncIOp`, and an i32 `tile` is returned unchanged.
/// Any cast is created via `rewriter` at `loc`.
Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
                                     RewriterBase &rewriter) {
  // A Value without a defining op (e.g. a block argument) yields a null
  // pointer here. Check for that explicitly so the message below is reported
  // instead of `isa<>` being invoked on a null pointer.
  assert(tile.getDefiningOp() &&
         (isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
             tile.getDefiningOp())) &&
         "expected ArmSME GetTileID or CastVectorToTile op!");
  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
  if (tileElementWidth < 32)
    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
  if (tileElementWidth > 32)
    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
  // Already 32 bits wide: no cast needed.
  return tile;
}
22 changes: 22 additions & 0 deletions mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,25 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

// -----

// Checks that `vector.print` of a 2-D scalable SME tile is rewritten into an
// `scf.for` over the tile's rows, where each row is read with
// "arm_sme.intr.read.horiz" and printed as a 1-D scalable vector.
func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
{
vector.print %tile : vector<[4]x[4]xf32>
return
}
// CHECK-LABEL: func.func @arm_sme_tile_print(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
// CHECK-DAG: %[[ZERO_VECTOR:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_cast %[[TILE_SLICE_INDEX]] : index to i32
// CHECK-NEXT: %[[TILE_SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VECTOR]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -44,7 +44,6 @@ func.func @entry() {

// Allocate memory.
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
%mem2 = memref.alloca(%za_s_size) : memref<?xi32>

// Fill each "row" of "mem1" with row number.
//
Expand All @@ -66,11 +65,6 @@ func.func @entry() {
// Load tile from "mem1" vertically.
%0 = arm_sme.tile_load %mem1[%c0, %c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>

// Store tile back to "mem2" to print.
// TODO: Support vector.print for 2-D scalable vectors so don't have to spill
// to memory and reload to print.
vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

// 1. ORIGINAL HORIZONTAL LAYOUT
// Dump "mem1". The smallest SVL is 128 bits, so the tile will be at least
// 4x4xi32.
Expand Down Expand Up @@ -99,10 +93,7 @@ func.func @entry() {
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
vector.print %tileslice : vector<[4]xi32>
}
vector.print %0 : vector<[4]x[4]xi32>
func.call @printTileEnd() : () -> ()

return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -25,7 +25,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
// Print the tile. The smallest SVL is 128 bits, so the tile will be at least
// 4x4xf32.
//
// WITHOUT-ACC: TILE BEGIN
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
Expand All @@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
// WITHOUT-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()

return
Expand All @@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>

// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
%vscale = vector.vscale
%min_elts_s = arith.constant 4 : index
%svl_s = arith.muli %min_elts_s, %vscale : index
%za_s_size = arith.muli %svl_s, %svl_s : index

// Allocate memory.
%mem = memref.alloca(%za_s_size) : memref<?xf32>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 4x4xf32.
// Print the tile. The smallest SVL is 128 bits, so the tile will be at least
// 4x4xf32.
//
// WITH-ACC: TILE BEGIN
// WITH-ACC-NEXT: ( 10, 10, 10, 10
Expand All @@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
// WITH-ACC-NEXT: ( 10, 13, 16, 19
// WITH-ACC: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_s_size step %svl_s {
%tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
vector.print %tileslice : vector<[4]xf32>
}
vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()

return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

llvm.func @printCString(!llvm.ptr<i8>)

func.func @printTileBegin() {
func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}

func.func @printTileEnd() {
func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
Expand All @@ -32,7 +32,6 @@ func.func @printTileEnd() {
}

func.func @test_outerproduct_with_accumulator_2x2xf64() {
%c0 = arith.constant 0 : index
%f1 = arith.constant 1.0 : f64
%f2 = arith.constant 2.0 : f64
%f10 = arith.constant 10.0 : f64
Expand All @@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {

%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>

// Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
%vscale = vector.vscale
%min_elts_d = arith.constant 2 : index
%svl_d = arith.muli %min_elts_d, %vscale : index
%za_d_size = arith.muli %svl_d, %svl_d : index

// Allocate memory.
%mem = memref.alloca(%za_d_size) : memref<?xf64>

// Store the tile to memory.
vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

// Reload and print. The smallest SVL is 128-bits so the tile will be at
// least 2x2xf64.
// Print the tile. The smallest SVL is 128 bits, so the tile will be at least
// 2x2xf64.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 12, 12
// CHECK-NEXT: ( 12, 12
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
scf.for %i = %c0 to %za_d_size step %svl_d {
%tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
vector.print %tileslice : vector<[2]xf64>
}
vector.print %tile : vector<[2]x[2]xf64>
func.call @printTileEnd() : () -> ()

return
Expand Down
Loading

0 comments on commit 174cd61

Please sign in to comment.