Add support of vecmat/matvec in SetEncoding and MaterializeEncoding #15257
Conversation
The PR is quite large. Can we break it down? E.g., I think at least we could have one PR adding the materialization patterns and another for set_encoding.
```cpp
case EncodingUser::VECMAT:
  return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{1, 4, 8});
case EncodingUser::MATVEC:
case EncodingUser::BATCH_MATVEC:
  return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 1});
```
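For context, a hedged reading of those tile parameters, assuming they follow the {M0, K0, N0} ordering used by the existing matmul cases (an assumption from this diff, not stated in the patch):

```cpp
// Hypothetical annotation, assuming tileParams = {M0, K0, N0}:
//  - VECMAT keeps M0 == 1: the LHS is logically a 1xK row vector, so there
//    is nothing to tile along M.
//  - MATVEC and BATCH_MATVEC keep N0 == 1: the RHS is logically a Kx1 column
//    vector, so there is nothing to tile along N.
```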
This part is mostly for testing. To enable it on the CPU backend, you need to add configurations to CPUMaterializeEncodingPass. (I'm sorry about the structure; I will centralize the code.)
```cpp
switch (user) {
case IREE::LinalgExt::EncodingUser::MATMUL:
  opTiled = rewriter
                .create<linalg::MatmulOp>(
                    loc, encodedOut.getType(),
                    ValueRange{encodedLhs, encodedRhs}, encodedOut)
                .getResult(0);
  break;
case IREE::LinalgExt::EncodingUser::BATCH_MATMUL:
  opTiled = rewriter
                .create<linalg::BatchMatmulOp>(
                    loc, encodedOut.getType(),
                    ValueRange{encodedLhs, encodedRhs}, encodedOut)
                .getResult(0);
  break;
case IREE::LinalgExt::EncodingUser::VECMAT:
  opTiled = rewriter
                .create<linalg::VecmatOp>(
                    loc, encodedOut.getType(),
                    ValueRange{encodedLhs, encodedRhs}, encodedOut)
                .getResult(0);
  break;
case IREE::LinalgExt::EncodingUser::MATVEC:
  opTiled = rewriter
                .create<linalg::MatvecOp>(
                    loc, encodedOut.getType(),
                    ValueRange{encodedLhs, encodedRhs}, encodedOut)
                .getResult(0);
  break;
case IREE::LinalgExt::EncodingUser::BATCH_MATVEC:
  opTiled = rewriter
                .create<linalg::BatchMatvecOp>(
                    loc, encodedOut.getType(),
                    ValueRange{encodedLhs, encodedRhs}, encodedOut)
                .getResult(0);
  break;
}
```
```cpp
EncodingUser resultUser = resultEncoding.getUser().getValue();
if (!isVecmatEncodingUser(resultUser) && !isMatvecEncodingUser(resultUser) &&
    !isMatmulEncodingUser(resultUser) &&
    !isBatchMatvecEncodingUser(resultUser) &&
    !isBatchMatmulEncodingUser(resultUser)) {
  return rewriter.notifyMatchFailure(op, "unsupported encoding type");
}

auto outType = operands[2].getType().cast<RankedTensorType>();

auto loc = op.getLoc();
Operation *resultOp;
if (isBatchMatvecEncodingUser(resultUser) ||
    isBatchMatmulEncodingUser(resultUser)) {
  resultOp = rewriter.create<linalg::BatchMmt4DOp>(
      loc, outType, ValueRange{operands[0], operands[1]},
      ValueRange{operands[2]});
} else {
  resultOp = rewriter.create<linalg::Mmt4DOp>(
      loc, outType, ValueRange{operands[0], operands[1]},
      ValueRange{operands[2]});
}
return resultOp;
```
Can we use a switch-case here? It's cleaner and easier to maintain.
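For illustration, a minimal sketch of the suggested switch-case, reusing the names from the snippet above (the set of enum values handled and the failure path are assumptions based on this diff, not the final patch):

```cpp
// Sketch only: fold the chain of negated helper predicates into one switch.
Operation *resultOp;
switch (resultUser) {
case IREE::LinalgExt::EncodingUser::BATCH_MATMUL:
case IREE::LinalgExt::EncodingUser::BATCH_MATVEC:
  resultOp = rewriter.create<linalg::BatchMmt4DOp>(
      loc, outType, ValueRange{operands[0], operands[1]},
      ValueRange{operands[2]});
  break;
case IREE::LinalgExt::EncodingUser::MATMUL:
case IREE::LinalgExt::EncodingUser::MATVEC:
case IREE::LinalgExt::EncodingUser::VECMAT:
  resultOp = rewriter.create<linalg::Mmt4DOp>(
      loc, outType, ValueRange{operands[0], operands[1]},
      ValueRange{operands[2]});
  break;
default:
  return rewriter.notifyMatchFailure(op, "unsupported encoding type");
}
return resultOp;
```

This also replaces the five predicate calls with a single `default` case, so adding a new user only touches one spot.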
```cpp
  return false;
}

int64_t getExpandedDimIndex(EncodingAttr encoding) {
```
Please add a doc comment. What happens if it returns -1?
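A possible shape for that doc comment (a sketch; the -1 semantics here are inferred from the question above, not taken from the patch):

```cpp
/// Returns the index of the unit dimension that is introduced when the vector
/// operand of a vecmat/matvec op is expanded to a matrix, or -1 if `encoding`
/// does not correspond to a vecmat/matvec user. Callers must check for -1
/// before using the result as an index.
int64_t getExpandedDimIndex(EncodingAttr encoding);
```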
Sorry that I failed to mention this possibility earlier, but: one simple approach would be to rewrite vecmat/matvec as matmul by expanding the vector operands, so the encoding step never has to deal with vectors at all.
@bjacob I think this is on me: I had considered it, but didn't advocate for it harder once I saw the amount of work the current approach took. Where do you think this would fit best, while setting the encoding or elsewhere?
I think it definitely is nice to have it be its own separate pattern+pass -- this is quite generally useful to have, and reviewers tend to like reusable composable patterns that do one thing each. You could initially create it as its own new pattern/pass in GlobalOptimization/ side by side with SetEncoding. Then if a reviewer thinks it belongs elsewhere, that should be easy to move.
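For concreteness, a minimal sketch of what such a standalone pattern could look like for the vecmat case, assuming static shapes for brevity. This is illustrative only, not the code that landed in #15273:

```cpp
// Sketch: rewrite linalg.vecmat as linalg.matmul by expanding the 1-D
// operands to 2-D and collapsing the result back afterwards.
struct RewriteVecmatAsMatmul : public OpRewritePattern<linalg::VecmatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::VecmatOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getDpsInputOperand(0)->get(); // tensor<K>
    Value rhs = op.getDpsInputOperand(1)->get(); // tensor<KxN>
    Value out = op.getDpsInitOperand(0)->get();  // tensor<N>
    auto lhsType = lhs.getType().cast<RankedTensorType>();
    auto outType = out.getType().cast<RankedTensorType>();
    if (!lhsType.hasStaticShape() || !outType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected static shapes");

    // tensor<K> -> tensor<1xK> and tensor<N> -> tensor<1xN>.
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    auto expandedLhsType = RankedTensorType::get(
        {1, lhsType.getDimSize(0)}, lhsType.getElementType());
    auto expandedOutType = RankedTensorType::get(
        {1, outType.getDimSize(0)}, outType.getElementType());
    Value expandedLhs = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandedLhsType, lhs, reassoc);
    Value expandedOut = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandedOutType, out, reassoc);

    Value matmul = rewriter
                       .create<linalg::MatmulOp>(
                           loc, expandedOutType,
                           ValueRange{expandedLhs, rhs}, expandedOut)
                       .getResult(0);

    // tensor<1xN> -> tensor<N>.
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, outType, matmul,
                                                         reassoc);
    return success();
  }
};
```

The matvec case would be symmetric, expanding the RHS and the result instead of the LHS.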
…ces to enable tiling (iree-org#15273)

As discussed in [issue#15053](iree-org#15053) and [PR#15257](iree-org#15257), expanding the vectors in vecmat/matvec operations and avoiding them in the encoding step entirely seems like the best approach.
This PR is currently broken as it relies on 68945.
This implementation is by no means elegant. The vectors being passed into `MaterializeEncoding` without being expanded first require several plugs in multiple operations, since most of them try to pack the tensors directly (or at least infer the shape after packing). Thus, the post-expansion shapes need to be inferred outside the `vecmat`/`matvec` op lowering itself, which also requires a surprising number of utility ops to compute. In addition, because this approach first pads, then expands, and then packs, the pad and pack ops don't get folded into a single pack.
Expanding the tensors in `SetEncoding` instead would eliminate most of that inelegance (while adding some of its own, so it is not a perfect solution by any means), but it introduces unnecessary expansions earlier in the stack, as discussed in 15053. There might be a different approach that I haven't considered, so feedback is appreciated.