Skip to content

Commit

Permalink
Fix tests and add type clarity in util function
Browse files Browse the repository at this point in the history
  • Loading branch information
NatashaKnk committed Oct 16, 2023
1 parent f2c6334 commit 58cda1d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
36 changes: 18 additions & 18 deletions mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;

auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
Expand All @@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;

auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
Expand All @@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;

auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
Expand All @@ -99,9 +99,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
bool mlir::isVecmat(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 1 || map1.getNumResults() != 2 ||
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
Expand All @@ -123,9 +123,9 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) {
bool mlir::isMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 2 || map1.getNumResults() != 1 ||
map2.getNumResults() != 1 || map0.getNumInputs() != 2 ||
Expand All @@ -147,9 +147,9 @@ bool mlir::isMatvec(ArrayAttr indexingMaps) {
bool mlir::isBatchMatvec(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();

if (map0.getNumResults() != 3 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
Expand Down
6 changes: 3 additions & 3 deletions mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ TEST(isVecmat, BindingSwapped) {
MLIRContext context;

AffineExpr k, n;
bindDims(&context, k, n); // bind in different order
bindDims(&context, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k, n}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
Expand Down Expand Up @@ -296,7 +296,7 @@ TEST(isMatvec, BindingSwapped) {
MLIRContext context;

AffineExpr k, n;
bindDims(&context, k, n); // bind in different order
bindDims(&context, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(2, 0, {n, k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(2, 0, {k}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(2, 0, {n}, &context));
Expand Down Expand Up @@ -335,7 +335,7 @@ TEST(isBatchMatvec, BindingSwapped) {
MLIRContext context;

AffineExpr batch, k, n;
bindDims(&context, batch, k, n); // bind in different order
bindDims(&context, batch, n, k); // bind in different order
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context));
Expand Down

0 comments on commit 58cda1d

Please sign in to comment.