diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h index b27ceca215dad..fe1f9062a37ef 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_ARMSME_IR_ARMSME_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index feeac3b8a0355..df837ebcf23b3 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -38,16 +38,16 @@ class ArmSME_IntrOp overloadedOperands = [], // Zero def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">, - Arguments<(ins Arg)>; + Arguments<(ins Arg:$tile_mask)>; // MOP's class ArmSME_IntrMopOverloadedOp : ArmSME_IntrOp, - Arguments<(ins Arg, - Arg, - Arg, - Arg, - Arg)>; + Arguments<(ins Arg:$tile_id, + Arg:$lhs_predicate, + Arg:$rhs_predicate, + Arg:$lhs_vector, + Arg:$rhs_vector)>; def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">; def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">; @@ -65,10 +65,10 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">; // Loads class ArmSME_IntrLoadOp : ArmSME_IntrOp, - Arguments<(ins Arg, - Arg, - Arg, - Arg)>; + Arguments<(ins Arg:$predicate, + Arg:$load_address, + Arg:$tile_id, + Arg:$tile_slice_index)>; def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">; def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">; @@ -84,10 +84,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">; // Stores class ArmSME_IntrStoreOp : ArmSME_IntrOp, - Arguments<(ins Arg, - Arg, - Arg, - Arg)>; + Arguments<(ins Arg:$predicate, + Arg:$store_address, + Arg:$tild_id, + Arg:$tile_slice_index)>; def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">; def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">; @@ -102,28 +102,28 @@ def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">; def LLVM_aarch64_sme_str : ArmSME_IntrOp<"str">, - Arguments<(ins Arg, - Arg)>; + Arguments<(ins Arg:$index, + Arg:$store_address)>; // Vector to tile slice class LLVM_aarch64_sme_write : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3], - [AllShapesMatch<["pg", "vector"]>]>, - Arguments<(ins Arg, - Arg, - Arg:$pg, + [AllShapesMatch<["predicate", "vector"]>]>, + Arguments<(ins Arg:$tile_id, + Arg:$tile_slice_index, + Arg:$predicate, Arg:$vector)>; // Tile slice to vector class LLVM_aarch64_sme_read : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[], - [AllShapesMatch<["vector", "pg", "res"]>, + [AllShapesMatch<["vector", "predicate", "res"]>, AllElementTypesMatch<["vector", "res"]>], /*numResults=*/1, /*overloadedResults=*/[0]>, Arguments<(ins Arg:$vector, - Arg:$pg, - Arg, - Arg)>; + Arg:$predicate, + Arg:$tile_id, + Arg:$tile_slice_index)>; def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">; def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">; diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir index ae99ac5e02d62..b3202b26f8e1e 100644 --- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir @@ -5,7 +5,7 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>) { %tile = llvm.mlir.constant(0 : index) : i32 - // expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}} + // expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}} "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) : (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> () llvm.return @@ -17,7 +17,7 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes( %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8> ) -> vector<[3]xf32> { %tile = llvm.mlir.constant(0 : index) : i32 - // expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}} + // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}} %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) : (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32> llvm.return %res : vector<[3]xf32>