[mlir][ArmSME] Add tile slice to vector intrinsics #66910
Conversation
Add support for the following tile slice to vector (MOVA) intrinsics in the ArmSME dialect:

llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz

This also slightly updates ArmSME_IntrOp to support return values.

Full diff: https://github.com/llvm/llvm-project/pull/66910.diff

4 Files Affected:
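For reference, here is one of the new intrinsics in use, taken from the tests added below. The operands are the vector operand, the predicate, the virtual tile ID, and the tile slice index:

```mlir
// Read horizontal tile slice %tileslice of tile %tile into a vector
// (from the arm-sme.mlir test in this patch).
%res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
  : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
```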
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
- list<Trait> traits = []>
+ list<Trait> traits = [], int numResults = 0,
+ list<int> overloadedResults = []>
: LLVM_IntrOpBase<
/*Dialect dialect=*/ArmSME_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/[],
+ /*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/0>;
+ /*int numResults=*/numResults>;
// Zero
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-// Vector to tile
+// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<SVEVector, "Vector operand">:$vector)>;
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+ : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ [AllShapesMatch<["vector", "pg", "res"]>,
+ AllElementTypesMatch<["vector", "res"]>],
+ /*numResults*/1, /*overloadedResults*/[0]>,
+ Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+ %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+ (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+ llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+ %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[3]xi32> {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+ %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+ (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+ %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+ %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+ %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+ %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+ %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+ %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+ %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+ %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+ %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+ %nxv16i1 : vector<[16]xi1>,
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>,
+ %nxv16i8 : vector<[16]xi8>,
+ %nxv8i16 : vector<[8]xi16>,
+ %nxv4i32 : vector<[4]xi32>,
+ %nxv2i64 : vector<[2]xi64>,
+ %nxv1i128 : vector<[1]xi128>,
+ %nxv8f16 : vector<[8]xf16>,
+ %nxv8bf16 : vector<[8]xbf16>,
+ %nxv4f32 : vector<[4]xf32>,
+ %nxv2f64 : vector<[2]xf64>) {
+ %tile = llvm.mlir.constant(0 : index) : i32
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+ %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+ : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+ %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+ %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+ %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+ // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+ %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+ : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+ %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+ %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+ %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+ : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+ %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+ : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+ llvm.return
+}
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
is this required for the AllElementTypesMatch constraint?
Yep
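(Context: the verifier that ODS generates for AllElementTypesMatch compares element types via getElementTypeOrSelf(), which is declared in mlir/IR/TypeUtilities.h. A rough sketch of the check it boils down to, not the exact generated code:)

```cpp
// Rough sketch of the AllElementTypesMatch<["vector", "res"]> check; the
// real ODS-generated verifier is structured differently, but it is what
// pulls getElementTypeOrSelf() in from mlir/IR/TypeUtilities.h.
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

static mlir::LogicalResult verifyElementTypesMatch(mlir::Operation *op,
                                                   mlir::Value vec,
                                                   mlir::Value res) {
  // getElementTypeOrSelf returns the element type of a shaped type,
  // or the type itself otherwise.
  if (mlir::getElementTypeOrSelf(vec.getType()) !=
      mlir::getElementTypeOrSelf(res.getType()))
    return op->emitOpError(
        "failed to verify that all of {vector, res} have same element type");
  return mlir::success();
}
```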
Fixes: llvm#66910 (comment) Fixes: llvm#66910 (comment)
LGTM cheers
LGTM, thanks!
Force-pushed from 4816091 to 62c8311.
Follow-up (…6691): This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print. This makes writing SME integration tests easier and less verbose. Depends on: llvm#66910, llvm#66911
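A hypothetical sketch of the shape of that lowering, using the read intrinsic from this PR (names such as %num_rows, %init, %all_active, and %tile_id are illustrative; the actual pass emits different IR):

```mlir
// Illustrative only: loop over the rows of an f32 tile, extract each row
// with a horizontal MOVA (arm_sme.intr.read.horiz), and print it as a
// normal 1D scalable vector.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %row = %c0 to %num_rows step %c1 {
  %row_i32 = arith.index_cast %row : index to i32
  %slice = "arm_sme.intr.read.horiz"(%init, %all_active, %tile_id, %row_i32)
      : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
  vector.print %slice : vector<[4]xf32>
}
```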