Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ArmSME] Name arguments of SME intrinsics (NFC) #69608

Merged
merged 3 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down
48 changes: 24 additions & 24 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],

// Zero
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
Arguments<(ins Arg<I32, "Tile mask">)>;
Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;

// MOP's
class ArmSME_IntrMopOverloadedOp<string mnemonic>
: ArmSME_IntrOp<mnemonic, [4]>,
Arguments<(ins Arg<I32, "Virtual tile ID">,
Arg<MOPPredicate, "LHS predicate">,
Arg<MOPPredicate, "RHS predicate">,
Arg<MOPVector, "LHS vector operand">,
Arg<MOPVector, "RHS vector operand">)>;
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
Arg<MOPVector, "LHS vector operand">:$lhs_vector,
Arg<MOPVector, "RHS vector operand">:$rhs_vector)>;

def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
Expand All @@ -65,10 +65,10 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
// Loads
class ArmSME_IntrLoadOp<string mnemonic>
: ArmSME_IntrOp<mnemonic>,
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
Arg<LLVM_AnyPointer, "Load address">,
Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">)>;
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Load address">:$load_address,
Arg<I32, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;

def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
Expand All @@ -84,10 +84,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
// Stores
class ArmSME_IntrStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic>,
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">)>;
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
Arg<I32, "Virtual tile ID">:$tild_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;

def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
Expand All @@ -102,28 +102,28 @@ def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;

def LLVM_aarch64_sme_str
: ArmSME_IntrOp<"str">,
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
Arguments<(ins Arg<I32, "Index">:$index,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address)>;

// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
Arguments<(ins Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">,
Arg<SVEPredicate, "Vector predicate">:$pg,
[AllShapesMatch<["predicate", "vector"]>]>,
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index,
Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<SVEVector, "Vector operand">:$vector)>;

// Tile slice to vector
class LLVM_aarch64_sme_read<string direction>
: ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
[AllShapesMatch<["vector", "pg", "res"]>,
[AllShapesMatch<["vector", "predicate", "res"]>,
AllElementTypesMatch<["vector", "res"]>],
/*numResults=*/1, /*overloadedResults=*/[0]>,
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">)>;
Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<I32, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;

def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
%nxv4i1 : vector<[4]xi1>,
%nxv16i8 : vector<[16]xi8>) {
%tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
// expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
Expand All @@ -17,7 +17,7 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
) -> vector<[3]xf32> {
%tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
// expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
%res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
(vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
llvm.return %res : vector<[3]xf32>
Expand Down