Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Update v.contract -> v.outerproduct tests (1/N) #70379

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 26, 2023

Tests for conversions from vector.contract to vector.outerproduct
for matvec operations are updated with cases for scalable vectors.

This patch updates one specific test file (there might be similar
tests elsewhere):

  • vector-contract-to-outerproduct-transforms.mlir.

Only the parallel dimension is made scalable. Making the reduction
dimension scalable would lead to different patterns without
vector.outerproduct (that would need to be added to some other file).

One duplicate test for matvec is removed.

@llvmbot
Copy link
Member

llvmbot commented Oct 26, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Tests for conversions from vector.contract to vector.outerproduct
are updated with cases for scalable vectors. This patch updates one
specific test files:

  • vector-contract-to-outerproduct-transforms.mlir,

and only updates tests for matvec operations (the remaining matmul
operations have been updated in previous patches). For consistency with
the existing tests, only the parallel dimension is made scalable. Making
the reduction dimension scalable would lead to different patterns
without vector.outerproduct.


Full diff: https://github.com/llvm/llvm-project/pull/70379.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+224-56)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..ec88759cd4927cb 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -313,6 +313,16 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+#matvec_accesses_1 = [
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -323,17 +333,38 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_2 = [
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+  indexing_maps = #matvec_accesses_2,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -344,17 +375,38 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_3 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+  indexing_maps = #matvec_accesses_3,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -365,17 +417,54 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_4 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+  indexing_maps = #matvec_accesses_4,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -386,17 +475,22 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -407,17 +501,38 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -428,17 +543,38 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -449,17 +585,38 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_8 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+  indexing_maps = #matvec_accesses_8,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -470,17 +627,28 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {

@banach-space banach-space changed the title [mlir][vector] Update v.contract -> v.outerproduct tests [mlir][vector] Update v.contract -> v.outerproduct tests (1/N) Oct 27, 2023
@banach-space banach-space force-pushed the andrzej/add_more_tests_for_scaalble branch from e3aa2ca to c262dec Compare October 27, 2023 08:02
banach-space added a commit to banach-space/llvm-project that referenced this pull request Oct 27, 2023
The remaining tests for conversions from vector.contract to
vector.outerproduct for _matmul_ operations in:

  * "vector-contract-to-outerproduct-transforms.mlir"

are updated with cases for scalable vectors. One duplicated test is
removed.

In addition:

  * tests are re-organised so that _matvec_ tests and _matmul_ tests are
    "clustered" together,
  * one duplicate case for _matvec_ is removed,
  * function formatting is unified,
  * added comments to document and to seperate different cases,
  * unified the naming for matrix/vector dimensions: (i, j, k) -> (m, n,
    k),

While this does add a bit of noise to this patch, I wanted to avoid
sending seperate patches to refactor this file.

Depends on llvm#70379
@banach-space banach-space force-pushed the andrzej/add_more_tests_for_scaalble branch from c262dec to 4d3f6ba Compare October 27, 2023 08:48
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These mostly make sense to me, I think you've got all the duplicates (but I find them hard to spot 😅)

@banach-space
Copy link
Contributor Author

These mostly make sense to me, I think you've got all the duplicates (but I find them hard to spot 😅)

Yeah, that's why I have been updating:

#matvec_accesses_1 = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]

to

#matvec_accesses_1 = [
  affine_map<(m, k) -> (m, k)>,
  affine_map<(m, k) -> (k)>,
  affine_map<(m, k) -> (m)>
]

so that all examples are consistent. Then it's easy to spot duplicates.

I've updated the summary to capture the changes from the fixup and will merge this shortly. Thanks for taking a look!

Tests for conversions from `vector.contract` to `vector.outerproduct`
for _matvec_ operations are updated with cases for scalable vectors.

This patch updates one specific test file (there might be similar
tests elsewhere):

   * vector-contract-to-outerproduct-transforms.mlir.

Only the parallel dimension is made scalable. Making the reduction
dimension scalable would lead to different patterns without
`vector.outerproduct` (that would need to be added to some other file).

One duplicate test for _matvec_ is removed.
@banach-space banach-space force-pushed the andrzej/add_more_tests_for_scaalble branch from c45f574 to be6d364 Compare October 27, 2023 11:59
@banach-space banach-space merged commit 8e0b3a8 into llvm:main Oct 27, 2023
@banach-space banach-space deleted the andrzej/add_more_tests_for_scaalble branch October 27, 2023 12:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants