diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td index 7f02e723f3d91..1ca284a3e70dc 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2], def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>; class ArmSME_IntrOp overloadedOperands = [], - list traits = []> + list traits = [], int numResults = 0, + list overloadedResults = []> : LLVM_IntrOpBase< /*Dialect dialect=*/ArmSME_Dialect, /*string opName=*/"intr." # mnemonic, /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic), - /*list overloadedResults=*/[], + /*list overloadedResults=*/overloadedResults, /*list overloadedOperands=*/overloadedOperands, /*list traits=*/traits, - /*int numResults=*/0>; + /*int numResults=*/numResults>; // Zero def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">, @@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str Arguments<(ins Arg, Arg)>; -// Vector to tile +// Vector to tile slice class LLVM_aarch64_sme_write : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3], [AllShapesMatch<["pg", "vector"]>]>, @@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write Arg:$pg, Arg:$vector)>; +// Tile slice to vector +class LLVM_aarch64_sme_read + : ArmSME_IntrOp<"read." 
# direction, /*overloadedOperands=*/[], + [AllShapesMatch<["vector", "pg", "res"]>, + AllElementTypesMatch<["vector", "res"]>], + /*numResults=*/1, /*overloadedResults=*/[0]>, + Arguments<(ins Arg:$vector, + Arg:$pg, + Arg, + Arg)>; + def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">; def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">; +def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">; +def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">; + def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">; def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">; diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp index 750627421215d..7cbc382b0050a 100644 --- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/TypeUtilities.h" using namespace mlir; using namespace mlir::arm_sme; diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir index e119e1f1a4044..ae99ac5e02d62 100644 --- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir @@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32, (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> () llvm.return } + +// ----- + +llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes( + %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8> +) -> vector<[3]xf32> { + %tile = llvm.mlir.constant(0 : index) : i32 + // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}} + %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) : + (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32> + llvm.return %res : vector<[3]xf32> +} + +// ----- + +llvm.func 
@arm_sme_tile_slice_to_vector_invalid_element_types( + %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32> +) -> vector<[4]xi32> { + %tile = llvm.mlir.constant(0 : index) : i32 + // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}} + %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) : + (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32> + llvm.return %res : vector<[4]xi32> +} diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir index 9bb6b0c6574fc..628d7ba4b649e 100644 --- a/mlir/test/Target/LLVMIR/arm-sme.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme.mlir @@ -334,3 +334,100 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32, (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> () llvm.return } + +// ----- + + +llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32, + %nxv16i1 : vector<[16]xi1>, + %nxv8i1 : vector<[8]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv2i1 : vector<[2]xi1>, + %nxv1i1 : vector<[1]xi1>, + %nxv16i8 : vector<[16]xi8>, + %nxv8i16 : vector<[8]xi16>, + %nxv4i32 : vector<[4]xi32>, + %nxv2i64 : vector<[2]xi64>, + %nxv1i128 : vector<[1]xi128>, + %nxv8f16 : vector<[8]xf16>, + %nxv8bf16 : vector<[8]xbf16>, + %nxv4f32 : vector<[4]xf32>, + %nxv2f64 : vector<[2]xf64>) { + %tile = llvm.mlir.constant(0 : index) : i32 + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv16i8 + %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice) + : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv8i16 + %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv4i32 + %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice) + : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32> + // CHECK: call
@llvm.aarch64.sme.read.horiz.nxv2i64 + %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice) + : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv1i128 + %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice) + : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv8f16 + %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv8bf16 + %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv4f32 + %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) + : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32> + // CHECK: call @llvm.aarch64.sme.read.horiz.nxv2f64 + %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice) + : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64> + llvm.return +} + +// ----- + +llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32, + %nxv16i1 : vector<[16]xi1>, + %nxv8i1 : vector<[8]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv2i1 : vector<[2]xi1>, + %nxv1i1 : vector<[1]xi1>, + %nxv16i8 : vector<[16]xi8>, + %nxv8i16 : vector<[8]xi16>, + %nxv4i32 : vector<[4]xi32>, + %nxv2i64 : vector<[2]xi64>, + %nxv1i128 : vector<[1]xi128>, + %nxv8f16 : vector<[8]xf16>, + %nxv8bf16 : vector<[8]xbf16>, + %nxv4f32 : vector<[4]xf32>, + %nxv2f64 : vector<[2]xf64>) { + %tile = llvm.mlir.constant(0 : index) : i32 + // CHECK: call @llvm.aarch64.sme.read.vert.nxv16i8 + %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice) + : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv8i16 + %res1 = 
"arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv4i32 + %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice) + : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv2i64 + %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice) + : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv1i128 + %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice) + : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv8f16 + %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv8bf16 + %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice) + : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv4f32 + %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice) + : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32> + // CHECK: call @llvm.aarch64.sme.read.vert.nxv2f64 + %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice) + : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64> + llvm.return +}