diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 9ca029b489ad1..44e82f452b3ce 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -86,6 +86,39 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> { /*methodBody=*/[{ return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps()); }]>, + InterfaceMethod< + /*desc=*/[{ + Returns whether the given op has indexing maps that correspond to a + vector-matrix multiplication. + }], + /*retTy=*/"bool", + /*methodName=*/"isVecmat", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::isVecmat($_op.getIndexingMaps()); + }]>, + InterfaceMethod< + /*desc=*/[{ + Returns whether the given op has indexing maps that correspond to a + matrix-vector multiplication. + }], + /*retTy=*/"bool", + /*methodName=*/"isMatvec", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::isMatvec($_op.getIndexingMaps()); + }]>, + InterfaceMethod< + /*desc=*/[{ + Returns whether the given op has indexing maps that correspond to a + batched matrix-vector multiplication. + }], + /*retTy=*/"bool", + /*methodName=*/"isBatchMatvec", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::isBatchMatvec($_op.getIndexingMaps()); + }]>, ]; } diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index dab24bd930326..225b9f287d340 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -49,6 +49,24 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps); /// the reduction. bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); +/// Tests whether the given maps describe a vector matrix multiplication. The +/// test is permutation-invariant. Note that this only checks the affine maps +/// from an operation, so does not perform any checks on the math being +/// performed within the reduction. +bool isVecmat(ArrayAttr indexingMaps); + +/// Tests whether the given maps describe a matrix vector multiplication. The +/// test is permutation-invariant. Note that this only checks the affine maps +/// from an operation, so does not perform any checks on the math being +/// performed within the reduction. +bool isMatvec(ArrayAttr indexingMaps); + +/// Tests whether the given maps describe a batch matrix vector multiplication. +/// The test is permutation-invariant. Note that this only checks the affine +/// maps from an operation, so does not perform any checks on the math being +/// performed within the reduction. +bool isBatchMatvec(ArrayAttr indexingMaps); + /// Return positions in `iteratorTypes` that match `iteratorTypeName`. inline void findPositionsOfType(ArrayRef iteratorTypes, utils::IteratorType iteratorTypeName, diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp index a2977901f4751..641ddf3f91cb2 100644 --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = cast(indexingMaps[0]).getValue(); - auto map1 = cast(indexingMaps[1]).getValue(); - auto map2 = cast(indexingMaps[2]).getValue(); + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = cast(indexingMaps[0]).getValue(); - auto map1 = cast(indexingMaps[1]).getValue(); - auto map2 = cast(indexingMaps[2]).getValue(); + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = cast(indexingMaps[0]).getValue(); - auto map1 = cast(indexingMaps[1]).getValue(); - auto map2 = cast(indexingMaps[2]).getValue(); + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 3 || map1.getNumResults() != 3 || map2.getNumResults() != 3 || map0.getNumInputs() != 4 || @@ -96,6 +96,79 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) { return indexingMaps == maps; } +bool mlir::isVecmat(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); + + if (map0.getNumResults() != 1 || map1.getNumResults() != 2 || + map2.getNumResults() != 1 || map0.getNumInputs() != 2 || + map1.getNumInputs() != 2 || map2.getNumInputs() != 2) { + return false; + } + + // Extract dimensions for K * KxN -> N + AffineExpr k = map0.getResult(0); + AffineExpr n = map2.getResult(0); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + +bool mlir::isMatvec(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); + + if (map0.getNumResults() != 2 || map1.getNumResults() != 1 || + map2.getNumResults() != 1 || map0.getNumInputs() != 2 || + map1.getNumInputs() != 2 || map2.getNumInputs() != 2) { + return false; + } + + // Extract dimensions for N*K * K -> N + AffineExpr k = map1.getResult(0); + AffineExpr n = map2.getResult(0); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + +bool mlir::isBatchMatvec(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + AffineMap map0 = cast(indexingMaps[0]).getValue(); + AffineMap map1 = cast(indexingMaps[1]).getValue(); + AffineMap map2 = cast(indexingMaps[2]).getValue(); + + if (map0.getNumResults() != 3 || map1.getNumResults() != 2 || + map2.getNumResults() != 2 || map0.getNumInputs() != 3 || + map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { + return false; + } + + // Extract dimensions for B*N*K * B*K -> B*N + AffineExpr b = map0.getResult(0); + AffineExpr k = map1.getResult(1); + AffineExpr n = map2.getResult(1); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, n, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + Operation *mlir::clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands) { IRMapping bvm; diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp index 583dbd463b911..3f576bacebf6a 100644 --- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp @@ -240,4 +240,134 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) { EXPECT_THAT(maps, Not(Truly(isRowMajorBatchMatmul))); } +TEST(isVecmat, Simple) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isVecmat)); +} + +TEST(isVecmat, BindingSwapped) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, n, k); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isVecmat)); +} + +TEST(isVecmat, WrongDimOrderMatrix) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isVecmat))); +} + +TEST(isMatvec, Simple) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isMatvec)); +} + +TEST(isMatvec, BindingSwapped) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, n, k); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isMatvec)); +} + +TEST(isMatvec, WrongDimOrderMatrix) { + MLIRContext context; + + AffineExpr k, n; + bindDims(&context, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isMatvec))); +} + +TEST(isBatchMatvec, Simple) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isBatchMatvec)); +} + +TEST(isBatchMatvec, BindingSwapped) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, n, k); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isBatchMatvec)); +} + +TEST(isBatchMatvec, Matmul) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isBatchMatvec))); +} + +TEST(isBatchMatvec, WrongDimOrderMatrix) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isBatchMatvec))); +} + } // namespace