diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 04d9ddf2183f8..6e63d52d22a1f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator return v; Type promotedType = dstElementType; if (vecType) - promotedType = VectorType::get(vecType.getShape(), promotedType); + promotedType = vecType.clone(promotedType); if (isa<FloatType>(dstElementType)) return rewriter.create<arith::ExtFOp>(loc, promotedType, v); return rewriter.create<arith::ExtSIOp>(loc, promotedType, v); diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir index deea7747f3679..3746897bcd864 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -34,16 +34,16 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, // CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>, // CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>, -// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> // CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK0]] { vector.outerproduct +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> // CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK1]] { vector.outerproduct +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> // CHECK: %[[MASK2:.*]] 
= vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1> -// CHECK: vector.mask %[[MASK2]] { vector.outerproduct +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>, @@ -54,22 +54,46 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, return %0 : vector<2xf32> } + +// CHECK-LABEL: func.func @masked_extract_contract2_scalable( +// CHECK-SAME: %{{.*}}: vector<[2]x[3]xf32>, +// CHECK-SAME: %{{.*}}: vector<[3]xf32>, +// CHECK-SAME: %{{.*}}: vector<[2]xf32>, +// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x[3]xi1>) -> vector<[2]xf32> +// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x[3]xi1> to vector<[3]x[2]xi1> +// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<[3]x[2]xi1> +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> + +// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<[3]x[2]xi1> +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> + +// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1> +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> +func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>, + %arg1: vector<[3]xf32>, + %arg2: vector<[2]xf32>, + %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32> + return %0 : vector<[2]xf32> +} + // CHECK-LABEL: func.func 
@masked_extract_contract4( -// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, -// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { -// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> -// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1> -// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1> -// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1> -// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1> -// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1> -// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK-SAME: %{{.*}}: vector<3x5xf32>, +// CHECK-SAME: %{{.*}}: vector<5x7xf32>, +// CHECK-SAME: %{{.*}}: vector<3x7xf32>, +// CHECK-SAME: %[[M:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { +// CHECK: %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> +// CHECK: 
%[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<3x7xi1> from vector<5x3x7xi1> +// CHECK: %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<3x7xi1> from vector<5x3x7xi1> +// CHECK: %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<3x7xi1> from vector<5x3x7xi1> +// CHECK: %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<3x7xi1> from vector<5x3x7xi1> +// CHECK: %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<3x7xi1> from vector<5x3x7xi1> +// CHECK: %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, %arg1: vector<5x7xf32>, @@ -80,10 +104,36 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, return %0 : vector<3x7xf32> } +// CHECK-LABEL: func.func @masked_extract_contract4_scalable( +// CHECK-SAME: %{{.*}}: vector<[3]x[5]xf32>, +// CHECK-SAME: %{{.*}}: vector<[5]x[7]xf32>, +// CHECK-SAME: %{{.*}}: vector<[3]x[7]xf32>, +// CHECK-SAME: %[[M:.*]]: vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> { +// CHECK: %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<[3]x[7]x[5]xi1> to vector<[5]x[3]x[7]xi1> +// CHECK: %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1> +// CHECK: %{{.*}} = vector.mask 
%[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32> +// CHECK: %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1> +// CHECK: %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32> +// CHECK: %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1> +// CHECK: %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32> +// CHECK: %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1> +// CHECK: %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32> +// CHECK: %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1> +// CHECK: %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32> + +func.func @masked_extract_contract4_scalable(%arg0: vector<[3]x[5]xf32>, + %arg1: vector<[5]x[7]xf32>, + %arg2: vector<[3]x[7]xf32>, + %m : vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> { + %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<[3]x[5]xf32>, vector<[5]x[7]xf32> into vector<[3]x[7]xf32> } : vector<[3]x[7]x[5]xi1> -> vector<[3]x[7]xf32> + return %0 : vector<[3]x[7]xf32> +} + // CHECK-LABEL: func @matmul -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:.*]]: vector<4x3xf32>, +// CHECK-SAME: %[[C:.*]]: 
vector<2x3xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32> // @@ -116,6 +166,42 @@ func.func @matmul(%arg0: vector<2x4xf32>, return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[4]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[4]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-SAME: : vector<[2]x[4]xf32> to vector<[4]x[2]xf32> +// +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[4]x[2]xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[4]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32> +// +// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<[2]xf32> from vector<[4]x[2]xf32> +// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<[4]x[3]xf32> +// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32> +// +// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<[2]xf32> from vector<[4]x[2]xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<[4]x[3]xf32> +// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32> +// +// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<[2]xf32> from vector<[4]x[2]xf32> +// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<[4]x[3]xf32> +// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// CHECK-SAME: : vector<[2]xf32>, vector<[3]xf32> +// +// CHECK: return %[[c3]] : vector<[2]x[3]xf32> +func.func @matmul_scalable(%arg0: vector<[2]x[4]xf32>, + %arg1: vector<[4]x[3]xf32>, + %arg2: vector<[2]x[3]xf32>) -> vector<[2]x[3]xf32> { + %0 = vector.contract #matmat_trait %arg0, 
%arg1, %arg2 + : vector<[2]x[4]xf32>, vector<[4]x[3]xf32> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + // CHECK-LABEL: func @matmul_0 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, @@ -133,6 +219,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_0_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<[2]x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<[2]x[3]xf32> +func.func @matmul_0_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>) +-> vector<[2]x[3]xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + // CHECK-LABEL: func @matmul_0_mixed // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, @@ -152,6 +255,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_0_mixed_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf16>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf16>, +// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf16> from vector<[1]x[2]xf16> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<[1]x[3]xf16> +// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<[2]xf16> to vector<[2]xf32> 
+// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] +// CHECK: return %[[c0]] : vector<[2]x[3]xf32> +func.func @matmul_0_mixed_scalable(%arg0: vector<[2]x[1]xf16>, %arg1: vector<[1]x[3]xf16>, %arg2: vector<[2]x[3]xf32>) +-> vector<[2]x[3]xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf16>, vector<[1]x[3]xf16> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + #matmat_accesses_1 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, @@ -163,9 +285,9 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: } // CHECK-LABEL: func @matmul_1 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:.*]]: vector<2x3xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> @@ -180,6 +302,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_1_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<[2]x[3]xf32> +func.func @matmul_1_scalable(%arg0: 
vector<[2]x[1]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>) +-> vector<[2]x[3]xf32> +{ + %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + #matmat_accesses_2 = [ affine_map<(m, n, k) -> (k, m)>, affine_map<(m, n, k) -> (k, n)>, @@ -191,9 +331,9 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_2 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:.*]]: vector<2x3xf32> // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32> // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] @@ -206,6 +346,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_2_scalable +// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<[2]x[3]xf32> +func.func @matmul_2_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>) +-> vector<[2]x[3]xf32> +{ + %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 + : vector<[1]x[2]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + #matmat_accesses_3 = [ affine_map<(m, n, k) -> 
(k, m)>, affine_map<(m, n, k) -> (n, k)>, @@ -217,9 +373,9 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_3 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:.*]]: vector<2x3xf32> // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32> // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32> @@ -233,6 +389,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_3_scalable +// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<[2]x[3]xf32> +func.func @matmul_3_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>) +-> vector<[2]x[3]xf32> +{ + %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 + : vector<[1]x[2]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32> + return %0 : vector<[2]x[3]xf32> +} + #matmat_accesses_4 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, @@ -244,9 +417,9 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_4 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: 
vector<1x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x2xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> @@ -260,6 +433,23 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @matmul_4_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<[3]x[2]xf32> +func.func @matmul_4_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>) +-> vector<[3]x[2]xf32> +{ + %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32> + return %0 : vector<[3]x[2]xf32> +} + #matmat_accesses_5 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, @@ -271,9 +461,9 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_5 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x2xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK-DAG: 
%[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> // CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> @@ -287,6 +477,23 @@ func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @matmul_5_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<[3]x[2]xf32> +func.func @matmul_5_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>) +-> vector<[3]x[2]xf32> +{ + %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32> + return %0 : vector<[3]x[2]xf32> +} + #matmat_accesses_6 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, @@ -298,9 +505,9 @@ func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_6 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x2xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> // CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> @@ -314,6 +521,23 @@ func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, 
%arg2: vecto return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @matmul_6_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<[3]x[2]xf32> +func.func @matmul_6_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>) +-> vector<[3]x[2]xf32> +{ + %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32> + return %0 : vector<[3]x[2]xf32> +} + #matmat_accesses_7 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, @@ -325,9 +549,9 @@ func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto } // CHECK-LABEL: func @matmul_7 -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x2xf32> // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] // CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> // CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> @@ -341,6 +565,23 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @matmul_7_scalable +// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>, +// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32> +// CHECK: %[[At:.*]] = 
vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<[3]x[2]xf32> +func.func @matmul_7_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>) +-> vector<[3]x[2]xf32> +{ + %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 + : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32> + return %0 : vector<[3]x[2]xf32> +} + // CHECK-LABEL: @masked_matvec_mk_k_m // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -362,6 +603,27 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a return %res : vector<4xf32> } +// CHECK-LABEL: @masked_matvec_mk_k_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[4]x[2]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[4]x[2]xi1> +func.func @masked_matvec_mk_k_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MASK]] + // CHECK: vector.transpose %[[MAT]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(m, k) -> (m, k)>, + affine_map<(m, k) -> (k)>, + affine_map<(m, k) -> (m)>], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<[4]x[2]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[4]x[2]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_matvec_km_k_m // CHECK-SAME: %[[MAT:.+]]: 
vector<2x4xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -383,6 +645,27 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a return %res : vector<4xf32> } +// CHECK-LABEL: @masked_matvec_km_k_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[2]x[4]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[4]x[2]xi1> +func.func @masked_matvec_km_k_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MASK]] + // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(m, k) -> (k, m)>, + affine_map<(m, k) -> (k)>, + affine_map<(m, k) -> (m)>], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<[2]x[4]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[4]x[2]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_matvec_k_mk_m // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -404,6 +687,27 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a return %res : vector<4xf32> } +// CHECK-LABEL: @masked_matvec_k_mk_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[4]x[2]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[4]x[2]xi1> +func.func @masked_matvec_k_mk_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MASK]] + // CHECK: vector.transpose %[[MAT]] + // CHECK-COUNT-2: vector.mask %{{.*}} { 
vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(m, k) -> (k)>, + affine_map<(m, k) -> (m, k)>, + affine_map<(m, k) -> (m)>], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind + } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[4]x[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[4]x[2]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_matvec_k_km_m // CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -425,6 +729,27 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a return %res : vector<4xf32> } +// CHECK-LABEL: @masked_matvec_k_km_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[2]x[4]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[4]x[2]xi1> +func.func @masked_matvec_k_km_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MASK]] + // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(m, k) -> (k)>, + affine_map<(m, k) -> (k, m)>, + affine_map<(m, k) -> (m)>], + iterator_types = ["parallel", "reduction"], + kind = #vector.kind + } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[2]x[4]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[4]x[2]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_tmatvec_mk_k_m // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -446,6 +771,27 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: 
vector<2xf32>, % return %res : vector<4xf32> } +// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[4]x[2]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[2]x[4]xi1> +func.func @masked_tmatvec_mk_k_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[MASK]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(k, m) -> (m, k)>, + affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m)>], + iterator_types = ["reduction", "parallel"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<[4]x[2]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[2]x[4]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_tmatvec_km_k_m // CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -467,6 +813,27 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, % return %res : vector<4xf32> } +// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[2]x[4]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[2]x[4]xi1> +func.func @masked_tmatvec_km_k_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> { + // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[MASK]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + 
vector.contract { + indexing_maps = [affine_map<(k, m) -> (k, m)>, + affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m)>], + iterator_types = ["reduction", "parallel"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<[2]x[4]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[2]x[4]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_tmatvec_k_mk_m // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -488,6 +855,27 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, % return %res : vector<4xf32> } +// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable +// CHECK-SAME: %[[MAT:.+]]: vector<[4]x[2]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[2]x[4]xi1> +func.func @masked_tmatvec_k_mk_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[MASK]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m, k)>, + affine_map<(k, m) -> (m)>], + iterator_types = ["reduction", "parallel"], + kind = #vector.kind + } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[4]x[2]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[2]x[4]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + // CHECK-LABEL: @masked_tmatvec_k_km_m // CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> // CHECK-SAME: %[[VEC:.+]]: vector<2xf32> @@ -509,6 +897,27 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, % return %res : vector<4xf32> } +// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable +// CHECK-SAME: %[[MAT:.+]]: 
vector<[2]x[4]xf32> +// CHECK-SAME: %[[VEC:.+]]: vector<[2]xf32> +// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[MASK:.+]]: vector<[2]x[4]xi1> +func.func @masked_tmatvec_k_km_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> { + // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[MASK]] + // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind} : vector<[4]xf32>, f32 } + %res = vector.mask %mask { + vector.contract { + indexing_maps = [affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (k, m)>, + affine_map<(k, m) -> (m)>], + iterator_types = ["reduction", "parallel"], + kind = #vector.kind + } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[2]x[4]xf32>, vector<[4]xf32> into vector<[4]xf32> + } : vector<[2]x[4]xi1> -> vector<[4]xf32> + return %res : vector<[4]xf32> +} + transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op):