
[mlir][ArmSME] Add tile slice to vector intrinsics #66910

Merged 3 commits into llvm:main, Sep 22, 2023

Conversation

@MacDue (Member) commented Sep 20, 2023

Add support for the following tile slice to vector (MOVA) intrinsics to the ArmSME dialect:

llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz

This also slightly updates ArmSME_IntrOp to support return values.
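
For reference, a usage sketch of one of the new ops (this mirrors the tests added in this patch; the value names `%vector`, `%pg`, `%tile`, and `%tileslice` are placeholders):

```mlir
// Read a horizontal slice of virtual tile %tile at row %tileslice into an
// SVE vector; %vector supplies the values of inactive lanes and %pg is the
// governing predicate.
%res = "arm_sme.intr.read.horiz"(%vector, %pg, %tile, %tileslice)
  : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
```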

@llvmbot (Member) commented Sep 20, 2023

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Changes

Add support for the following tile slice to vector (MOVA) intrinsics to the ArmSME dialect:

llvm.aarch64.sme.read.vert
llvm.aarch64.sme.read.horiz

This also slightly updates ArmSME_IntrOp to support return values.


Full diff: https://github.com/llvm/llvm-project/pull/66910.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+19-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+1)
  • (modified) mlir/test/Target/LLVMIR/arm-sme-invalid.mlir (+24)
  • (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+134)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..00e1fefc0521a78 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
 def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
 
 class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = []>
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
           /*Dialect dialect=*/ArmSME_Dialect,
           /*string opName=*/"intr." # mnemonic,
           /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/0>;
+          /*int numResults=*/numResults>;
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
-// Vector to tile
+// Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                     [AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["vector", "pg", "res"]>,
+                     AllElementTypesMatch<["vector", "res"]>],
+                    /*numResults*/1, /*overloadedResults*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..7cbc382b0050a6e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index e119e1f1a404416..ae99ac5e02d62f0 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
       (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
+) -> vector<[3]xf32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
+      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  llvm.return %res : vector<[3]xf32>
+}
+
+// -----
+
+llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
+  %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
+) -> vector<[3]xi32> {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
+      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.return %res : vector<[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 9bb6b0c6574fcdb..c318e6d2d37f7fd 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -334,3 +334,137 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
       (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @prevent_dce.nxv16i8(vector<[16]xi8>)
+llvm.func @prevent_dce.nxv8i16(vector<[8]xi16>)
+llvm.func @prevent_dce.nxv4i32(vector<[4]xi32>)
+llvm.func @prevent_dce.nxv2i64(vector<[2]xi64>)
+llvm.func @prevent_dce.nxv1i128(vector<[1]xi128>)
+llvm.func @prevent_dce.nxv8f16(vector<[8]xf16>)
+llvm.func @prevent_dce.nxv8bf16(vector<[8]xbf16>)
+llvm.func @prevent_dce.nxv4f32(vector<[4]xf32>)
+llvm.func @prevent_dce.nxv2f64(vector<[2]xf64>)
+
+llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
+                                              %nxv16i1 : vector<[16]xi1>,
+                                              %nxv8i1 : vector<[8]xi1>,
+                                              %nxv4i1 : vector<[4]xi1>,
+                                              %nxv2i1 : vector<[2]xi1>,
+                                              %nxv1i1 : vector<[1]xi1>,
+                                              %nxv16i8 : vector<[16]xi8>,
+                                              %nxv8i16 : vector<[8]xi16>,
+                                              %nxv4i32 : vector<[4]xi32>,
+                                              %nxv2i64 : vector<[2]xi64>,
+                                              %nxv1i128 : vector<[1]xi128>,
+                                              %nxv8f16 : vector<[8]xf16>,
+                                              %nxv8bf16 : vector<[8]xbf16>,
+                                              %nxv4f32 : vector<[4]xf32>,
+                                              %nxv2f64 : vector<[2]xf64>) {
+  %tile = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
+    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  llvm.call @prevent_dce.nxv16i8(%res0) : (vector<[16]xi8>) -> ()
+  // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  llvm.call @prevent_dce.nxv8i16(%res1) : (vector<[8]xi16>) -> ()
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  llvm.call @prevent_dce.nxv4i32(%res2) : (vector<[4]xi32>) -> ()
+  // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  llvm.call @prevent_dce.nxv2i64(%res3) : (vector<[2]xi64>) -> ()
+  // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
+    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  llvm.call @prevent_dce.nxv1i128(%res4) : (vector<[1]xi128>) -> ()
+  // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  llvm.call @prevent_dce.nxv8f16(%res5) : (vector<[8]xf16>) -> ()
+  // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  llvm.call @prevent_dce.nxv8bf16(%res6) : (vector<[8]xbf16>) -> ()
+  // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
+    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  llvm.call @prevent_dce.nxv4f32(%res7) : (vector<[4]xf32>) -> ()
+  // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
+    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  llvm.call @prevent_dce.nxv2f64(%res8) : (vector<[2]xf64>) -> ()
+  llvm.return
+}

@@ -12,6 +12,7 @@

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/TypeUtilities.h"
Collaborator
is this required for the AllElementTypesMatch constraint?

MacDue (Member, Author)
Yep

@c-rhodes (Collaborator) left a comment
LGTM cheers

@banach-space (Contributor) left a comment

LGTM, thanks!

@llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir:ods labels Sep 21, 2023
@MacDue force-pushed the mova_tile_slice_to_vector branch from 4816091 to 62c8311 on September 21, 2023, 19:43
@MacDue MacDue merged commit cb3a394 into llvm:main Sep 22, 2023
MacDue added a commit that referenced this pull request Sep 26, 2023
This adds a custom lowering for SME that loops over each row of the
tile, extracting it via an SME MOVA, then printing with a normal 1D
vector.print.

This makes writing SME integration tests easier and less verbose.

Depends on: #66910, #66911
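
The row-by-row lowering described in that commit might look roughly like the following (a hypothetical sketch, not the actual patch; `%num_rows`, `%init`, `%pg`, and `%tile` are assumed placeholders):

```mlir
// For each row of the tile, extract a horizontal tile slice with a MOVA
// read (the intrinsic added in this PR), then print it as an ordinary
// 1-D scalable vector.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %row = %c0 to %num_rows step %c1 {
  %row_i32 = arith.index_cast %row : index to i32
  %slice = "arm_sme.intr.read.horiz"(%init, %pg, %tile, %row_i32)
    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
  vector.print %slice : vector<[4]xi32>
}
```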
legrosbuffle pushed a commit to legrosbuffle/llvm-project that referenced this pull request Sep 29, 2023
…6691)

Depends on: llvm#66910, llvm#66911