diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
index 617789754900..1ad3ab1fad4f 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp
@@ -146,25 +146,29 @@ static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc,
 
 namespace {
 
-/// Rewrites the matmul op to work on tensors with encoding. Optionally
+/// Rewrites contraction ops to work on tensors with encoding. Optionally
 /// also pads the operands.
-struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
-  SetMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<linalg::MatmulOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+/// Currently works on matmul, batch_matmul, vecmat, matvec and batch_matvec.
+struct SetContractionOpEncoding
    : public OpInterfaceRewritePattern<linalg::ContractionOpInterface> {
+  using OpInterfaceRewritePattern<
+      linalg::ContractionOpInterface>::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(linalg::ContractionOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    if (!matmulOp.hasTensorSemantics())
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+    if (!linalgOp || !linalgOp.hasTensorSemantics())
       return failure();
-    auto inputs = matmulOp.getDpsInputs();
-    auto outputs = matmulOp.getDpsInits();
+
+    auto inputs = linalgOp.getDpsInputs();
+    auto outputs = linalgOp.getDpsInits();
     auto hasEncoding = [](Value operand) -> bool {
       auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
       return type && type.getEncoding();
     };
     if (llvm::any_of(inputs, hasEncoding) ||
         llvm::any_of(outputs, hasEncoding)) {
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "input/output already has encoding");
     }
 
     Value origLhs = inputs[0];
@@ -182,85 +186,26 @@ struct SetMatmulEncoding : public OpRewritePattern<linalg::MatmulOp> {
     Type outElemType = getElemType(origOut);
 
     if (!lhsElemType || !rhsElemType || !outElemType) {
-      return failure();
+      return rewriter.notifyMatchFailure(op, "invalid input/output");
     }
 
-    IREE::LinalgExt::EncodingUser user = IREE::LinalgExt::EncodingUser::MATMUL;
-    Location loc = matmulOp.getLoc();
-    TypeRange operandTypes = matmulOp->getOperandTypes();
-    Value encodedLhs =
-        padAndSetEncoding(rewriter, loc, origLhs, user,
-                          IREE::LinalgExt::EncodingRole::LHS, operandTypes);
-    Value encodedRhs =
-        padAndSetEncoding(rewriter, loc, origRhs, user,
-                          IREE::LinalgExt::EncodingRole::RHS, operandTypes);
-    Value encodedOut =
-        padAndSetEncoding(rewriter, loc, origOut, user,
-                          IREE::LinalgExt::EncodingRole::RESULT, operandTypes);
-
-    Value matmulTiled = rewriter
-                            .create<linalg::MatmulOp>(
-                                loc, encodedOut.getType(),
-                                ValueRange{encodedLhs, encodedRhs}, encodedOut)
-                            .getResult(0);
-
-    // Sizes are computed by original output size.
-    FailureOr<SmallVector<OpFoldResult>> origOutSizes =
-        IREE::LinalgExt::getDims(rewriter, loc, origOut);
-    if (failed(origOutSizes)) {
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "failed to get shape of result");
+    IREE::LinalgExt::EncodingUser user;
+    if (op.isRowMajorMatmul() || op.isColumnMajorMatmul()) {
+      user = IREE::LinalgExt::EncodingUser::MATMUL;
+    } else if (op.isRowMajorBatchMatmul()) {
+      user = IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
+    } else if (op.isVecmat()) {
+      user = IREE::LinalgExt::EncodingUser::VECMAT;
+    } else if (op.isMatvec()) {
+      user = IREE::LinalgExt::EncodingUser::MATVEC;
+    } else if (op.isBatchMatvec()) {
+      user = IREE::LinalgExt::EncodingUser::BATCH_MATVEC;
+    } else {
+      return rewriter.notifyMatchFailure(op, "unsupported contraction op");
     }
-    Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
-                                                origOutSizes.value());
-
-    rewriter.replaceOp(matmulOp, result);
-    return success();
-  }
-};
-
-struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
-  SetBatchMatmulEncoding(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern<linalg::BatchMatmulOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(linalg::BatchMatmulOp matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (!matmulOp.hasTensorSemantics())
-      return failure();
-    auto inputs = matmulOp.getDpsInputs();
-    auto outputs = matmulOp.getDpsInits();
-    auto hasEncoding = [](Value operand) -> bool {
-      auto type = llvm::dyn_cast<RankedTensorType>(operand.getType());
-      return type && type.getEncoding();
-    };
-    if (llvm::any_of(inputs, hasEncoding) ||
-        llvm::any_of(outputs, hasEncoding)) {
-      return failure();
-    }
-
-    Value origLhs = inputs[0];
-    Value origRhs = inputs[1];
-    Value origOut = outputs[0];
-
-    auto getElemType = [](Value v) -> Type {
-      if (auto tensorType = llvm::dyn_cast<RankedTensorType>(v.getType())) {
-        return tensorType.getElementType();
-      }
-      return {};
-    };
-    Type lhsElemType = getElemType(origLhs);
-    Type rhsElemType = getElemType(origRhs);
-    Type outElemType = getElemType(origOut);
-
-    if (!lhsElemType || !rhsElemType || !outElemType) {
-      return failure();
-    }
-
-    IREE::LinalgExt::EncodingUser user =
-        IREE::LinalgExt::EncodingUser::BATCH_MATMUL;
-    Location loc = matmulOp.getLoc();
-    TypeRange operandTypes = matmulOp->getOperandTypes();
+    Location loc = op.getLoc();
+    TypeRange operandTypes = op->getOperandTypes();
     Value encodedLhs =
         padAndSetEncoding(rewriter, loc, origLhs, user,
                           IREE::LinalgExt::EncodingRole::LHS, operandTypes);
@@ -271,24 +216,56 @@ struct SetBatchMatmulEncoding : public OpRewritePattern<linalg::BatchMatmulOp> {
         padAndSetEncoding(rewriter, loc, origOut, user,
                           IREE::LinalgExt::EncodingRole::RESULT, operandTypes);
 
-    Value matmulTiled = rewriter
-                            .create<linalg::BatchMatmulOp>(
-                                loc, encodedOut.getType(),
-                                ValueRange{encodedLhs, encodedRhs}, encodedOut)
-                            .getResult(0);
+    Value opTiled;
+    switch (user) {
+    case IREE::LinalgExt::EncodingUser::MATMUL:
+      opTiled = rewriter
+                    .create<linalg::MatmulOp>(
+                        loc, encodedOut.getType(),
+                        ValueRange{encodedLhs, encodedRhs}, encodedOut)
+                    .getResult(0);
+      break;
+
+    case IREE::LinalgExt::EncodingUser::BATCH_MATMUL:
+      opTiled = rewriter
+                    .create<linalg::BatchMatmulOp>(
+                        loc, encodedOut.getType(),
+                        ValueRange{encodedLhs, encodedRhs}, encodedOut)
+                    .getResult(0);
+      break;
+    case IREE::LinalgExt::EncodingUser::VECMAT:
+      opTiled = rewriter
+                    .create<linalg::VecmatOp>(
+                        loc, encodedOut.getType(),
+                        ValueRange{encodedLhs, encodedRhs}, encodedOut)
+                    .getResult(0);
+      break;
+    case IREE::LinalgExt::EncodingUser::MATVEC:
+      opTiled = rewriter
+                    .create<linalg::MatvecOp>(
+                        loc, encodedOut.getType(),
+                        ValueRange{encodedLhs, encodedRhs}, encodedOut)
+                    .getResult(0);
+      break;
+    case IREE::LinalgExt::EncodingUser::BATCH_MATVEC:
+      opTiled = rewriter
+                    .create<linalg::BatchMatvecOp>(
+                        loc, encodedOut.getType(),
+                        ValueRange{encodedLhs, encodedRhs}, encodedOut)
+                    .getResult(0);
+      break;
+    }
 
     // Sizes are computed by original output size.
     FailureOr<SmallVector<OpFoldResult>> origOutSizes =
         IREE::LinalgExt::getDims(rewriter, loc, origOut);
     if (failed(origOutSizes)) {
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "failed to get shape of result");
+      return rewriter.notifyMatchFailure(op, "failed to get shape of result");
    }
 
-    Value result = unsetEncodingAndExtractSlice(rewriter, loc, matmulTiled,
+    Value result = unsetEncodingAndExtractSlice(rewriter, loc, opTiled,
                                                 origOutSizes.value());
-
-    rewriter.replaceOp(matmulOp, result);
+    rewriter.replaceOp(linalgOp, result);
     return success();
   }
 };
@@ -332,7 +309,7 @@ void SetEncodingPass::runOnOperation() {
   MLIRContext *context = &getContext();
   {
     RewritePatternSet patterns(context);
-    patterns.insert<SetMatmulEncoding, SetBatchMatmulEncoding>(context);
+    patterns.insert<SetContractionOpEncoding>(context);
     linalg::FillOp::getCanonicalizationPatterns(patterns, context);
     patterns.insert<FoldFillWithSetEncoding>(context);
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
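For reference before the updated tests: the rewrite produced by the new SetContractionOpEncoding pattern has the same shape for every supported contraction op. Below is a hand-written MLIR sketch of the vecmat case, not compiler output; %lhs_pad, %rhs_pad and %acc_pad stand for the tensor.pad results whose padding amounts come from iree_linalg_ext.upper_bound_tile_size queries, the 100-element result size is illustrative, and the encoding attribute is abbreviated to the user/role fields that padAndSetEncoding sets:

#lhs = #iree_linalg_ext.encoding<user = VECMAT, role = LHS>
#rhs = #iree_linalg_ext.encoding<user = VECMAT, role = RHS>
#out = #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>
func.func @vecmat_encoded(%lhs_pad : tensor<?xf32>, %rhs_pad : tensor<?x?xf32>,
    %acc_pad : tensor<?xf32>) -> tensor<100xf32> {
  // Tag each padded operand with its role in the contraction.
  %lhs_e = iree_linalg_ext.set_encoding %lhs_pad : tensor<?xf32> -> tensor<?xf32, #lhs>
  %rhs_e = iree_linalg_ext.set_encoding %rhs_pad : tensor<?x?xf32> -> tensor<?x?xf32, #rhs>
  %acc_e = iree_linalg_ext.set_encoding %acc_pad : tensor<?xf32> -> tensor<?xf32, #out>
  // The contraction itself is recreated 1:1 on the encoded tensors.
  %res_e = linalg.vecmat ins(%lhs_e, %rhs_e : tensor<?xf32, #lhs>, tensor<?x?xf32, #rhs>)
      outs(%acc_e : tensor<?xf32, #out>) -> tensor<?xf32, #out>
  // Decode and slice back to the original result size.
  %res = iree_linalg_ext.unset_encoding %res_e : tensor<?xf32, #out> -> tensor<?xf32>
  %result = tensor.extract_slice %res[0] [100] [1] : tensor<?xf32> to tensor<100xf32>
  return %result : tensor<100xf32>
}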
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
index e51ec26772ea..ed8a1d803ad8 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir
@@ -603,6 +603,284 @@ func.func @batch_matmul_i8i8i32(%arg0 : tensor<64x100x250xi8>, %arg1 : tensor<64
 
 // -----
 
+func.func @vecmat_f32f32f32(%arg0 : tensor<250xf32>, %arg1 : tensor<250x100xf32>,
+    %arg2 : tensor<100xf32>) -> tensor<100xf32> {
+  %0 = linalg.vecmat ins(%arg0, %arg1 : tensor<250xf32>, tensor<250x100xf32>)
+      outs(%arg2 : tensor<100xf32>) -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @vecmat_f32f32f32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<250xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<250x100xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// CHECK-DAG: %[[C250:.+]] = arith.constant 250 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index
+// CHECK: %[[LHS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<250xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>> -> index
+// CHECK: %[[LHS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]], %[[C250]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0] high[%[[LHS_PADDING_SIZE]]]
+// CHECK: tensor<250xf32> to tensor<?xf32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<250x100xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>> -> index, index
+// CHECK: %[[RHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#0, %[[C250]]]
+// CHECK: %[[RHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#1, %[[C100]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[RHS_PADDING_SIZE0]], %[[RHS_PADDING_SIZE1]]]
+// CHECK: tensor<250x100xf32> to tensor<?x?xf32>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<100xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> index
+// CHECK: %[[OUTS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]], %[[C100]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0] high[%[[OUTS_PADDING_SIZE]]]
+// CHECK: tensor<100xf32> to tensor<?xf32>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+// CHECK: %[[VECMAT:.+]] = linalg.vecmat
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[VECMAT]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0] [100] [1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @vecmat_f32f32f32_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.vecmat ins(%arg0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @vecmat_f32f32f32_dynamic(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[LHS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>> -> index
+// CHECK: %[[LHS_DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[LHS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]], %[[LHS_DIM0]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0] high[%[[LHS_PADDING_SIZE]]]
+// CHECK: tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>> -> index, index
+// CHECK: %[[RHS_DIM0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[RHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#0, %[[RHS_DIM0]]]
+// CHECK: %[[RHS_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[RHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#1, %[[RHS_DIM1]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[RHS_PADDING_SIZE0]], %[[RHS_PADDING_SIZE1]]]
+// CHECK: tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> index
+// CHECK: %[[OUTS_DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[OUTS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]], %[[OUTS_DIM0]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0] high[%[[OUTS_PADDING_SIZE]]]
+// CHECK: tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+// CHECK: %[[VECMAT:.+]] = linalg.vecmat
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[VECMAT]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0] [%[[OUTS_DIM0]]] [1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matvec_bf16bf16f32(%arg0 : tensor<100x250xbf16>, %arg1 : tensor<250xbf16>,
+    %arg2 : tensor<100xf32>) -> tensor<100xf32> {
+  %0 = linalg.matvec ins(%arg0, %arg1 : tensor<100x250xbf16>, tensor<250xbf16>)
+      outs(%arg2 : tensor<100xf32>) -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @matvec_bf16bf16f32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xbf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<250xbf16>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index
+// CHECK-DAG: %[[C250:.+]] = arith.constant 250 : index
+// CHECK-DAG: %[[BF16_0:.+]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG: %[[F32_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[LHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<100x250xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>> -> index, index
+// CHECK: %[[LHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#0, %[[C100]]]
+// CHECK: %[[LHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#1, %[[C250]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[LHS_PADDING_SIZE0]], %[[LHS_PADDING_SIZE1]]]
+// CHECK: tensor<100x250xbf16> to tensor<?x?xbf16>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?x?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<250xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>> -> index
+// CHECK: %[[RHS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]], %[[C250]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0] high[%[[RHS_PADDING_SIZE]]]
+// CHECK: tensor<250xbf16> to tensor<?xbf16>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<100xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>> -> index
+// CHECK: %[[OUTS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]], %[[C100]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0] high[%[[OUTS_PADDING_SIZE]]]
+// CHECK: tensor<100xf32> to tensor<?xf32>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+// CHECK: %[[MATVEC:.+]] = linalg.matvec
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[MATVEC]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0] [100] [1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @matvec_bf16bf16f32_dynamic(%arg0 : tensor<?x?xbf16>, %arg1 : tensor<?xbf16>,
+    %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.matvec ins(%arg0, %arg1 : tensor<?x?xbf16>, tensor<?xbf16>)
+      outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @matvec_bf16bf16f32_dynamic(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xbf16>, %[[ARG1:.+]]: tensor<?xbf16>, %[[ARG2:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[BF16_0:.+]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG: %[[F32_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[LHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>> -> index, index
+// CHECK: %[[LHS_DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xbf16>
+// CHECK: %[[LHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#0, %[[LHS_DIM0]]]
+// CHECK: %[[LHS_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xbf16>
+// CHECK: %[[LHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#1, %[[LHS_DIM1]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[LHS_PADDING_SIZE0]], %[[LHS_PADDING_SIZE1]]]
+// CHECK: tensor<?x?xbf16> to tensor<?x?xbf16>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?x?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>> -> index
+// CHECK: %[[RHS_DIM0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xbf16>
+// CHECK: %[[RHS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]], %[[RHS_DIM0]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0] high[%[[RHS_PADDING_SIZE]]]
+// CHECK: tensor<?xbf16> to tensor<?xbf16>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?xbf16, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]] = iree_linalg_ext.upper_bound_tile_size tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>> -> index
+// CHECK: %[[OUTS_DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[OUTS_PADDING_SIZE:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]], %[[OUTS_DIM0]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0] high[%[[OUTS_PADDING_SIZE]]]
+// CHECK: tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+// CHECK: %[[MATVEC:.+]] = linalg.matvec
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[MATVEC]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0] [%[[OUTS_DIM0]]] [1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @batch_matvec_i8i8i32(%arg0 : tensor<3x100x250xi8>, %arg1 : tensor<3x250xi8>,
+    %arg2 : tensor<3x100xi32>) -> tensor<3x100xi32> {
+  %0 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<3x100x250xi8>, tensor<3x250xi8>)
+      outs(%arg2 : tensor<3x100xi32>) -> tensor<3x100xi32>
+  return %0 : tensor<3x100xi32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @batch_matvec_i8i8i32(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x100x250xi8>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<3x250xi8>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<3x100xi32>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C100:.+]] = arith.constant 100 : index
+// CHECK-DAG: %[[C250:.+]] = arith.constant 250 : index
+// CHECK-DAG: %[[I8_0:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[I32_0:.+]] = arith.constant 0 : i32
+// CHECK: %[[LHS_TILE_SIZE:.+]]:3 = iree_linalg_ext.upper_bound_tile_size tensor<3x100x250xi8, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>> -> index, index, index
+// CHECK: %[[LHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#0, %[[C3]]]
+// CHECK: %[[LHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#1, %[[C100]]]
+// CHECK: %[[LHS_PADDING_SIZE2:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#2, %[[C250]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0] high[%[[LHS_PADDING_SIZE0]], %[[LHS_PADDING_SIZE1]], %[[LHS_PADDING_SIZE2]]]
+// CHECK: tensor<3x100x250xi8> to tensor<?x?x?xi8>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<3x250xi8, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>> -> index, index
+// CHECK: %[[RHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#0, %[[C3]]]
+// CHECK: %[[RHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#1, %[[C250]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[RHS_PADDING_SIZE0]], %[[RHS_PADDING_SIZE1]]]
+// CHECK: tensor<3x250xi8> to tensor<?x?xi8>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<3x100xi32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>> -> index, index
+// CHECK: %[[OUTS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#0, %[[C3]]]
+// CHECK: %[[OUTS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#1, %[[C100]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0, 0] high[%[[OUTS_PADDING_SIZE0]], %[[OUTS_PADDING_SIZE1]]]
+// CHECK: tensor<3x100xi32> to tensor<?x?xi32>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+// CHECK: %[[BATCH_MATVEC:.+]] = linalg.batch_matvec
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[BATCH_MATVEC]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0, 0] [3, 100] [1, 1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @batch_matvec_f16f16f16_dynamic(%arg0 : tensor<?x?x?xf16>, %arg1 : tensor<?x?xf16>,
+    %arg2 : tensor<?x?xf16>) -> tensor<?x?xf16> {
+  %0 = linalg.batch_matvec ins(%arg0, %arg1 : tensor<?x?x?xf16>, tensor<?x?xf16>)
+      outs(%arg2 : tensor<?x?xf16>) -> tensor<?x?xf16>
+  return %0 : tensor<?x?xf16>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
+// CHECK: func @batch_matvec_f16f16f16_dynamic(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf16>, %[[ARG1:.+]]: tensor<?x?xf16>, %[[ARG2:.+]]: tensor<?x?xf16>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK: %[[LHS_TILE_SIZE:.+]]:3 = iree_linalg_ext.upper_bound_tile_size tensor<?x?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>> -> index, index, index
+// CHECK: %[[LHS_DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf16>
+// CHECK: %[[LHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#0, %[[LHS_DIM0]]]
+// CHECK: %[[LHS_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf16>
+// CHECK: %[[LHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#1, %[[LHS_DIM1]]]
+// CHECK: %[[LHS_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf16>
+// CHECK: %[[LHS_PADDING_SIZE2:.+]] = affine.apply #[[MAP]]()[%[[LHS_TILE_SIZE]]#2, %[[LHS_DIM2]]]
+// CHECK: %[[LHS_PAD:.+]] = tensor.pad %[[ARG0]] low[0, 0, 0] high[%[[LHS_PADDING_SIZE0]], %[[LHS_PADDING_SIZE1]], %[[LHS_PADDING_SIZE2]]]
+// CHECK: tensor<?x?x?xf16> to tensor<?x?x?xf16>
+// CHECK: %[[LHS:.+]] = iree_linalg_ext.set_encoding %[[LHS_PAD]]
+// CHECK-SAME: tensor<?x?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>
+// CHECK: %[[RHS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>> -> index, index
+// CHECK: %[[RHS_DIM0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf16>
+// CHECK: %[[RHS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#0, %[[RHS_DIM0]]]
+// CHECK: %[[RHS_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf16>
+// CHECK: %[[RHS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[RHS_TILE_SIZE]]#1, %[[RHS_DIM1]]]
+// CHECK: %[[RHS_PAD:.+]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[RHS_PADDING_SIZE0]], %[[RHS_PADDING_SIZE1]]]
+// CHECK: tensor<?x?xf16> to tensor<?x?xf16>
+// CHECK: %[[RHS:.+]] = iree_linalg_ext.set_encoding %[[RHS_PAD]]
+// CHECK-SAME: tensor<?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>
+// CHECK: %[[OUTS_TILE_SIZE:.+]]:2 = iree_linalg_ext.upper_bound_tile_size tensor<?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>> -> index, index
+// CHECK: %[[OUTS_DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf16>
+// CHECK: %[[OUTS_PADDING_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#0, %[[OUTS_DIM0]]]
+// CHECK: %[[OUTS_DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf16>
+// CHECK: %[[OUTS_PADDING_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[OUTS_TILE_SIZE]]#1, %[[OUTS_DIM1]]]
+// CHECK: %[[OUTS_PAD:.+]] = tensor.pad %[[ARG2]] low[0, 0] high[%[[OUTS_PADDING_SIZE0]], %[[OUTS_PADDING_SIZE1]]]
+// CHECK: tensor<?x?xf16> to tensor<?x?xf16>
+// CHECK: %[[OUTS:.+]] = iree_linalg_ext.set_encoding %[[OUTS_PAD]]
+// CHECK-SAME: tensor<?x?xf16, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+// CHECK: %[[BATCH_MATVEC:.+]] = linalg.batch_matvec
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[OUTS]] :
+// CHECK: %[[RESULT_PADDED:.+]] = iree_linalg_ext.unset_encoding %[[BATCH_MATVEC]]
+// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[RESULT_PADDED]][0, 0] [%[[OUTS_DIM0]], %[[OUTS_DIM1]]] [1, 1]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @fold_fill_with_set_encoding(%arg0 : index, %arg1 : index)
     -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATMUL, role = RESULT>> {
   %cst = arith.constant 0.0 : f32
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
index ee4f8fc3133f..0d71a76f4631 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -57,11 +57,17 @@ class IREELinalgExt_EnumAttr
 
 def MATMUL : I32EnumAttrCase<"MATMUL", 0>;
 def BATCH_MATMUL : I32EnumAttrCase<"BATCH_MATMUL", 1>;
+def VECMAT : I32EnumAttrCase<"VECMAT", 2>;
+def MATVEC : I32EnumAttrCase<"MATVEC", 3>;
+def BATCH_MATVEC : I32EnumAttrCase<"BATCH_MATVEC", 4>;
 
 def EncodingUser : IREELinalgExt_I32EnumAttr<"EncodingUser",
    "Describes the operation that a tensor is an operand or a result of.", [
      MATMUL,
      BATCH_MATMUL,
+     VECMAT,
+     MATVEC,
+     BATCH_MATVEC,
    ]>;
 
 def EncodingUserAttr :
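A note on the new enum cases before their uses below: in MaterializeEncoding.cpp, VECMAT gets tile parameters {M=1, K=4, N=8} and MATVEC/BATCH_MATVEC get {M=8, K=4, N=1}, i.e. the dimension that the 1-D vector operand lacks is given a unit tile. The padding amounts checked in the set_encoding tests above all come from the round-up map affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>. A self-contained worked instance follows; the tile size 8 and dim size 100 are illustrative values matching the matvec M dimension above:

#round_up = affine_map<()[s0, s1] -> (-s1 + (s1 ceildiv s0) * s0)>
func.func @pad_amount() -> index {
  // (100 ceildiv 8) * 8 - 100 = 13 * 8 - 100 = 4 trailing elements of padding.
  %tile = arith.constant 8 : index
  %dim = arith.constant 100 : index
  %pad = affine.apply #round_up()[%tile, %dim]
  return %pad : index
}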
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
index 889e0a7ce5fc..93285894d16b 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Utils/EncodingUtils.h
@@ -9,6 +9,7 @@
 
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -21,6 +22,41 @@ bool isMatmulEncodingUser(EncodingUser user);
 
 // Check if encoding user is one of batch matmul encodings.
 bool isBatchMatmulEncodingUser(EncodingUser user);
 
+// Check if encoding user is one of vecmat encodings.
+bool isVecmatEncodingUser(EncodingUser user);
+
+// Check if encoding user is one of matvec encodings.
+bool isMatvecEncodingUser(EncodingUser user);
+
+// Check if encoding user is one of batch matvec encodings.
+bool isBatchMatvecEncodingUser(EncodingUser user);
+
+// Check if encoding belongs to a vector in a matrix/vector operation.
+bool isVectorEncoding(EncodingAttr encoding);
+
+// Check if the encoding belongs to the vector in a vecmat operation.
+bool isVecmatVector(EncodingAttr encoding);
+
+// Check if the encoding belongs to the vector in a matvec operation.
+bool isMatvecVector(EncodingAttr encoding);
+
+// Check if the encoding belongs to the vector in a batch_matvec operation.
+bool isBatchMatvecVector(EncodingAttr encoding);
+
+// Get the dimension that is being expanded when provided a vector/matrix
+// operation encoding.
+int64_t getExpandedDimIndex(EncodingAttr encoding);
+
+// Get the reassociation maps for expanding/collapsing vectors in vector/matrix
+// operations based on their encoding.
+SmallVector<ReassociationIndices>
+getReassociationMapsForVectors(EncodingAttr encoding);
+
+// Based on the encoding, deduce the new type of a vector after
+// expanding/collapsing it in a vector/matrix operation.
+RankedTensorType createNewTypeForVectors(RankedTensorType inputType,
+                                         EncodingAttr encoding, bool expanding);
+
 struct MatmulTileParams {
   int64_t M = 1;
   int64_t K = 1;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 520716396573..1186ffcb631a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -62,6 +62,13 @@ getMaterializedType(RankedTensorType tensorType,
   if (failed(materializeEncodingInfo)) {
     return tensorType;
   }
+
+  auto encoding = getEncodingAttr(tensorType);
+  if (isVectorEncoding(encoding)) {
+    tensorType =
+        createNewTypeForVectors(tensorType, encoding, /*expanding=*/true);
+  }
+
   return tensor::PackOp::inferPackedType(
      getOriginalTypeWithEncoding(tensorType),
      materializeEncodingInfo->innerTileSizes,
@@ -97,6 +104,11 @@ chooseEncodingInfo(RankedTensorType tensorType) {
   case EncodingUser::MATMUL:
   case EncodingUser::BATCH_MATMUL:
     return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 8});
+  case EncodingUser::VECMAT:
+    return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{1, 4, 8});
+  case EncodingUser::MATVEC:
+  case EncodingUser::BATCH_MATVEC:
+    return chooseEncodingInfoForMatmul(user, role, /*tileParams=*/{8, 4, 1});
   }
   llvm_unreachable("unhandled EncodingUser case");
 }
@@ -135,8 +147,22 @@ static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
   }
-  // Create `tensor.empty` operation for the result of the pack operation.
+
   Location loc = encodingOp.getLoc();
+
+  // Handle tensor expansion for Vecmat/Matvec
+  auto encoding = getEncodingAttr(resultType);
+  if (isVectorEncoding(encoding)) {
+    SmallVector<ReassociationIndices> ri =
+        getReassociationMapsForVectors(encoding);
+    if (!ri.empty()) {
+      resultType = createNewTypeForVectors(resultType, encoding,
+                                           /*expanding=*/true);
+      source =
+          rewriter.create<tensor::ExpandShapeOp>(loc, resultType, source, ri);
+    }
+  }
+
+  // Create `tensor.empty` operation for the result of the pack operation.
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
       getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
                           materializeEncodingValueFn);
@@ -144,9 +170,6 @@ static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
     return rewriter.notifyMatchFailure(
        encodingOp, "failed to generate runtime tile size query");
   }
-  auto encoding = getEncodingAttr(resultType);
-  if (!encoding)
-    return failure();
   std::optional<Value> paddingValue = getPaddingValue(source);
   SmallVector<OpFoldResult> sourceDims = getDims(rewriter, loc, source);
   SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
@@ -160,7 +183,7 @@ static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
       *innerTileSizesOfr, paddingValue, materializeEncodingInfo->outerDimsPerm);
 }
 
-/// Utility method to convert from `set_encoding` op to `pack` operation.
+/// Utility method to convert from `unset_encoding` op to `unpack` operation.
 /// The source is taken as input so that these could be used with
 /// `OpConversionPatterns`.
 static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
@@ -177,6 +200,16 @@ static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
   Location loc = encodingOp.getLoc();
   SmallVector<OpFoldResult> resultDims =
       getDims(rewriter, loc, encodingOp.getSource());
+
+  auto source = encodingOp.getSource();
+  // Handle tensor expansion for Vecmat/Matvec
+  auto encoding = getEncodingAttr(sourceType);
+  if (isVectorEncoding(encoding)) {
+    sourceType =
+        createNewTypeForVectors(sourceType, encoding, /*expanding=*/true);
+    resultDims.insert(resultDims.begin() + getExpandedDimIndex(encoding),
+                      rewriter.getI64IntegerAttr(1));
+  }
   auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
                                                   sourceType.getElementType());
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
@@ -221,83 +254,67 @@ static FailureOr<SmallVector<Value>> lowerUpperBoundTileSizeOpToConstants(
   }
   return results;
 }
-
-/// Utility method to convert from `linalg.matmul` with
+/// Utility method to convert from linalg contraction ops with
 /// - lhs encoding with role=LHS
 /// - rhs encoding with role=RHS
 /// - result encoding with role=RESULT
-/// to linalg.mmt4d op.
+/// to linalg.mmt4d or linalg.batch_mmt4d ops.
 static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
-                    ValueRange convertedInputOperands,
-                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
-                    MaterializeEncodingValueFn) {
-  if (!matmulOp.hasTensorSemantics())
+lowerOpWithEncoding(RewriterBase &rewriter,
+                    mlir::linalg::ContractionOpInterface op,
+                    ArrayRef<Value> operands) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
+  if (!linalgOp || !linalgOp.hasTensorSemantics())
     return failure();
-  auto inputs = matmulOp.getDpsInputOperands();
-  auto outputs = matmulOp.getDpsInits();
-  auto lhsEncoding =
-      getEncodingAttr(inputs[0]->get().getType().cast<RankedTensorType>());
-  auto rhsEncoding =
-      getEncodingAttr(inputs[1]->get().getType().cast<RankedTensorType>());
-  auto resultEncoding =
-      getEncodingAttr(outputs[0].getType().cast<RankedTensorType>());
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-  if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
-      !isMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
-      !isMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
-      lhsEncoding.getRole().getValue() !=
-          mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
-      rhsEncoding.getRole().getValue() !=
-          mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
-      resultEncoding.getRole().getValue() !=
-          mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) {
-    return failure();
-  }
-  Operation *mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
-      matmulOp.getLoc(), convertedOutputOperands[0].getType(),
-      convertedInputOperands, convertedOutputOperands);
-  return mmt4DOp;
-}
+  auto inputs = linalgOp.getDpsInputs();
+  auto outputs = linalgOp.getDpsInits();
 
-/// Utility method to convert from `linalg.batch_matmul` with
-/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS
-/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
-/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
-/// to linalg.batch_mmt4d op.
-static FailureOr<Operation *>
-lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
-                    ValueRange convertedInputOperands,
-                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
-                    MaterializeEncodingValueFn) {
-  if (!batchMatmulOp.hasTensorSemantics())
-    return failure();
-  auto inputs = batchMatmulOp.getDpsInputOperands();
-  auto outputs = batchMatmulOp.getDpsInits();
   auto lhsEncoding =
-      getEncodingAttr(inputs[0]->get().getType().cast<RankedTensorType>());
+      getEncodingAttr(inputs[0].getType().cast<RankedTensorType>());
   auto rhsEncoding =
-      getEncodingAttr(inputs[1]->get().getType().cast<RankedTensorType>());
+      getEncodingAttr(inputs[1].getType().cast<RankedTensorType>());
   auto resultEncoding =
       getEncodingAttr(outputs[0].getType().cast<RankedTensorType>());
   if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
     return failure();
   }
-  if (!isBatchMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
-      !isBatchMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
-      !isBatchMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
-      lhsEncoding.getRole().getValue() != EncodingRole::LHS ||
+  if (lhsEncoding.getRole().getValue() != EncodingRole::LHS ||
      rhsEncoding.getRole().getValue() != EncodingRole::RHS ||
      resultEncoding.getRole().getValue() != EncodingRole::RESULT) {
-    return failure();
+    return rewriter.notifyMatchFailure(op, "incorrect encoding role");
   }
-  Operation *batchMmt4DOp = rewriter.create<linalg::BatchMmt4DOp>(
-      batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
-      convertedInputOperands, convertedOutputOperands);
-  return batchMmt4DOp;
+
+  if (lhsEncoding.getUser().getValue() != rhsEncoding.getUser().getValue() ||
+      lhsEncoding.getUser().getValue() != resultEncoding.getUser().getValue() ||
+      rhsEncoding.getUser().getValue() != resultEncoding.getUser().getValue()) {
+    return rewriter.notifyMatchFailure(
+        op, "encoding for all elements of operation must match");
+  }
+
+  EncodingUser resultUser = resultEncoding.getUser().getValue();
+  if (!isVecmatEncodingUser(resultUser) && !isMatvecEncodingUser(resultUser) &&
+      !isMatmulEncodingUser(resultUser) &&
+      !isBatchMatvecEncodingUser(resultUser) &&
+      !isBatchMatmulEncodingUser(resultUser)) {
+    return rewriter.notifyMatchFailure(op, "unsupported encoding type");
+  }
+
+  auto outType = operands[2].getType().cast<RankedTensorType>();
+
+  auto loc = op.getLoc();
+  Operation *resultOp;
+  if (isBatchMatvecEncodingUser(resultUser) ||
+      isBatchMatmulEncodingUser(resultUser)) {
+    resultOp = rewriter.create<linalg::BatchMmt4DOp>(
+        loc, outType, ValueRange{operands[0], operands[1]},
+        ValueRange{operands[2]});
+  } else {
+    resultOp = rewriter.create<linalg::Mmt4DOp>(
+        loc, outType, ValueRange{operands[0], operands[1]},
+        ValueRange{operands[2]});
+  }
+  return resultOp;
 }
 
 /// Utility method to convert from `linalg.fill` on `tensor` type with
@@ -330,6 +347,12 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
     return rewriter.notifyMatchFailure(emptyOp, "unhandled result encoding");
   }
   Location loc = emptyOp.getLoc();
+  // Handle tensor expansion for Vecmat/Matvec
+  auto encoding = getEncodingAttr(resultType);
+  if (isVectorEncoding(encoding)) {
+    resultType = createNewTypeForVectors(resultType, encoding,
+                                         /*expanding=*/true);
+  }
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
      getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
                           materializeEncodingValueFn);
@@ -399,7 +422,24 @@ struct UnsetEncodingOpToPackOpConversion
     if (failed(unpackOp))
       return rewriter.notifyMatchFailure(encodingOp,
                                          "failed to convert to unpack op");
-    rewriter.replaceOp(encodingOp, unpackOp->getResult());
+    Value unpacked = unpackOp->getResult();
+    // Handle tensor collapsing for Vecmat/Matvec
+    auto encoding = getEncodingAttr(encodingOp.getSourceType());
+    if (isVectorEncoding(encoding)) {
+      SmallVector<ReassociationIndices> ri =
+          getReassociationMapsForVectors(encoding);
+      if (!ri.empty()) {
+        auto unpackedType = unpacked.getType().cast<RankedTensorType>();
+        RankedTensorType sourceType =
+            createNewTypeForVectors(unpackedType, encoding,
+                                    /*expanding=*/false);
+        unpacked = rewriter
+                       .create<tensor::CollapseShapeOp>(
+                           encodingOp.getLoc(), sourceType, unpacked, ri)
+                       .getResult();
+      }
+    }
+    rewriter.replaceOp(encodingOp, unpacked);
     return success();
   }
 };
@@ -450,6 +490,26 @@ struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
   }
 };
 
+/// Generic pattern to convert a contraction operation that is in destination
+/// passing style.
+struct MaterializeContractionOp : public OpInterfaceConversionPattern<
+                                      mlir::linalg::ContractionOpInterface> {
+  using OpInterfaceConversionPattern<
+      mlir::linalg::ContractionOpInterface>::OpInterfaceConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(mlir::linalg::ContractionOpInterface op,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> convertedOp =
+        lowerOpWithEncoding(rewriter, op, operands);
+    if (failed(convertedOp)) {
+      return failure();
+    }
+    rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResults());
+    return success();
+  }
+};
+
 /// Generic pattern to convert an operation.
 template <typename OpTy>
 struct MaterializeOperation : public OpMaterializeEncodingPattern<OpTy> {
@@ -556,12 +616,12 @@ void populateMaterializeEncodingPatterns(
 
  // Add all patterns for converting from encoded type to the materialized
  // type
  patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
-                 MaterializeDPSOperation<linalg::MatmulOp>,
-                 MaterializeDPSOperation<linalg::BatchMatmulOp>,
                  MaterializeOperation<tensor::EmptyOp>,
                  SetEncodingOpToPackOpConversion,
                  UnsetEncodingOpToPackOpConversion>(
      patterns.getContext(), typeConverter, materializeEncodingValueFn);
+  patterns.insert<MaterializeContractionOp>(typeConverter,
+                                            patterns.getContext());
 
  ::mlir::memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
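Between the type materialization above and the utilities in EncodingUtils.cpp below, the path for an encoded vector operand composes expand_shape + pack on the way in, and unpack + collapse_shape on the way out. A minimal sketch of the set_encoding side for a vecmat LHS under the {1, 4, 8} tile parameters; the function name is hypothetical and the 252-element length is chosen so the tiles divide evenly:

func.func @vecmat_lhs_pack(%lhs : tensor<252xf32>) -> tensor<1x63x1x4xf32> {
  // Rank-expand the K-vector to a 1xK matrix so the matmul tiling applies.
  %e = tensor.expand_shape %lhs [[0, 1]] : tensor<252xf32> into tensor<1x252xf32>
  %d = tensor.empty() : tensor<1x63x1x4xf32>
  // Unit tile on the M dimension, K tile of 4: 252 / 4 = 63 outer K tiles.
  %p = tensor.pack %e inner_dims_pos = [0, 1] inner_tiles = [1, 4]
      into %d : tensor<1x252xf32> -> tensor<1x63x1x4xf32>
  return %p : tensor<1x63x1x4xf32>
}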
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
index 4b3e2f813f7f..1cce69e182e8 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Utils/EncodingUtils.cpp
@@ -19,11 +19,112 @@ bool isBatchMatmulEncodingUser(EncodingUser user) {
   return user == EncodingUser::BATCH_MATMUL;
 }
 
+bool isVecmatEncodingUser(EncodingUser user) {
+  return user == EncodingUser::VECMAT;
+}
+
+bool isMatvecEncodingUser(EncodingUser user) {
+  return user == EncodingUser::MATVEC;
+}
+
+bool isBatchMatvecEncodingUser(EncodingUser user) {
+  return user == EncodingUser::BATCH_MATVEC;
+}
+
+bool isVectorEncoding(EncodingAttr encoding) {
+  return (isVecmatVector(encoding) || isMatvecVector(encoding) ||
+          isBatchMatvecVector(encoding));
+}
+
+bool isVecmatVector(EncodingAttr encoding) {
+  if (!encoding)
+    return false;
+  auto user = encoding.getUser().getValue();
+  auto role = encoding.getRole().getValue();
+  if (user == EncodingUser::VECMAT &&
+      (role == EncodingRole::LHS || role == EncodingRole::RESULT)) {
+    return true;
+  }
+  return false;
+}
+
+bool isMatvecVector(EncodingAttr encoding) {
+  if (!encoding)
+    return false;
+  auto user = encoding.getUser().getValue();
+  auto role = encoding.getRole().getValue();
+  if (user == EncodingUser::MATVEC &&
+      (role == EncodingRole::RHS || role == EncodingRole::RESULT)) {
+    return true;
+  }
+  return false;
+}
+
+bool isBatchMatvecVector(EncodingAttr encoding) {
+  if (!encoding)
+    return false;
+  auto user = encoding.getUser().getValue();
+  auto role = encoding.getRole().getValue();
+  if (user == EncodingUser::BATCH_MATVEC &&
+      (role == EncodingRole::RHS || role == EncodingRole::RESULT)) {
+    return true;
+  }
+  return false;
+}
+
+int64_t getExpandedDimIndex(EncodingAttr encoding) {
+  if (isVecmatVector(encoding))
+    return 0;
+  if (isMatvecVector(encoding))
+    return 1;
+  if (isBatchMatvecVector(encoding))
+    return 2;
+  return -1;
+}
+
+SmallVector<ReassociationIndices>
+getReassociationMapsForVectors(EncodingAttr encoding) {
+  SmallVector<ReassociationIndices> ri = {};
+  if (isVecmatVector(encoding) || isMatvecVector(encoding))
+    ri = {{0, 1}};
+  else if (isBatchMatvecVector(encoding))
+    ri = {{0}, {1, 2}};
+  return ri;
+}
+
+RankedTensorType createNewTypeForVectors(RankedTensorType inputType,
+                                         EncodingAttr encoding,
+                                         bool expanding) {
+  RankedTensorType newType = inputType;
+  Type eType = inputType.getElementType();
+  if (isVecmatVector(encoding)) {
+    if (expanding)
+      newType = RankedTensorType::get({1, inputType.getDimSize(0)}, eType);
+    else
+      newType = RankedTensorType::get({inputType.getDimSize(1)}, eType);
+  } else if (isMatvecVector(encoding)) {
+    if (expanding)
+      newType = RankedTensorType::get({inputType.getDimSize(0), 1}, eType);
+    else
+      newType = RankedTensorType::get({inputType.getDimSize(0)}, eType);
+  } else if (isBatchMatvecVector(encoding)) {
+    if (expanding)
+      newType = RankedTensorType::get(
+          {inputType.getDimSize(0), inputType.getDimSize(1), 1}, eType);
+    else
+      newType = RankedTensorType::get(
+          {inputType.getDimSize(0), inputType.getDimSize(1)}, eType);
+  }
+  return newType;
+}
+
 MaterializeEncodingInfo
 chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
                             MatmulTileParams tileParams) {
   // Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
-  int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;
+  int64_t matmulDimBase =
+      (isBatchMatmulEncodingUser(user) || isBatchMatvecEncodingUser(user)) ? 1
+                                                                           : 0;
 
   MaterializeEncodingInfo encodingInfo;
   encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
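To make the helpers above concrete: for the vector operand and result of a batch_matvec, getExpandedDimIndex returns 2 and getReassociationMapsForVectors returns [[0], [1, 2]], so unset_encoding ends by collapsing away the unit N dimension. An illustrative instance, with a hypothetical function name and the 128x80 shape taken from the pack_batch_matvec test below:

func.func @collapse_batch_matvec_result(%v : tensor<128x80x1xf32>) -> tensor<128x80xf32> {
  // Drop the unit dimension introduced so batch_matvec could reuse batch_mmt4d.
  %c = tensor.collapse_shape %v [[0], [1, 2]] : tensor<128x80x1xf32> into tensor<128x80xf32>
  return %c : tensor<128x80xf32>
}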
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 9c87e957dfd8..8cb8ee11fc60 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -338,3 +338,397 @@ func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : ten
 // CHECK-SAME: outs(%[[FILL]] :
 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
 // CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_unpack_vecmat_lhs(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>> -> tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_vecmat_lhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<1x?x1x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED]] inner_dims_pos = [0, 1] inner_tiles = [1, 4] into %[[PACK_DEST]]
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]]) : tensor<1x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [1, 4] into %[[UNPACK_DEST]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]] {{\[}}[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_unpack_vecmat_rhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>> -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK: func @pack_unpack_vecmat_rhs(
+// CHECK: tensor.pack
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
+// CHECK: tensor.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
+
+// -----
+
+func.func @pack_unpack_vecmat_result(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK: func @pack_unpack_vecmat_result(
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %{{.+}} {{\[}}[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+// CHECK: tensor.pack
+// CHECK-SAME: %[[EXPANDED]] inner_dims_pos = [0, 1] inner_tiles = [1, 8]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [1, 8]
+// CHECK: tensor.collapse_shape %[[UNPACK]] {{\[}}[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+
+// -----
+
+func.func @pack_vecmat(%arg0: tensor<250xf32>, %arg1: tensor<250x100xf32>, %arg2: tensor<100xf32>) -> tensor<100xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad_lhs = tensor.pad %arg0 low[0] high[2] {
+    ^bb0(%arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<250xf32> to tensor<252xf32>
+  %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<252xf32> -> tensor<252xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+  %pad_rhs = tensor.pad %arg1 low[0, 0] high[2, 4] {
+    ^bb0(%arg3: index, %arg4: index):
+      tensor.yield %cst : f32
+  } : tensor<250x100xf32> to tensor<252x104xf32>
+  %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252x104xf32> -> tensor<252x104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+  %pad_output = tensor.pad %arg2 low[0] high[4] {
+    ^bb0(%arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<100xf32> to tensor<104xf32>
+  %output = iree_linalg_ext.set_encoding %pad_output : tensor<104xf32> -> tensor<104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %vecmat_packed = linalg.vecmat ins(%lhs, %rhs : tensor<252xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>, tensor<252x104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>)
+      outs(%output : tensor<104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>) -> tensor<104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %vecmat = iree_linalg_ext.unset_encoding %vecmat_packed : tensor<104xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> tensor<104xf32>
+  %extracted_slice = tensor.extract_slice %vecmat[0] [100] [1] : tensor<104xf32> to tensor<100xf32>
+  return %extracted_slice : tensor<100xf32>
+}
+// CHECK: func @pack_vecmat(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<250xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<250x100xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.0
+// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[ARG0]]
+// CHECK: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[PADDED_LHS]]
+// CHECK: %[[INIT_LHS:.+]] = tensor.empty() : tensor<1x63x1x4xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_LHS]] inner_dims_pos = [0, 1] inner_tiles = [1, 4] into %[[INIT_LHS]]
+// CHECK: %[[INIT_RHS:.+]] = tensor.empty() : tensor<13x63x8x4xf32>
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG1]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: into %[[INIT_RHS]]
+// CHECK: %[[PADDED_RESULT:.+]] = tensor.pad %[[ARG2]]
+// CHECK: %[[EXPANDED_RESULT:.+]] = tensor.expand_shape %[[PADDED_RESULT]]
+// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() : tensor<1x13x1x8xf32>
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RESULT]] inner_dims_pos = [0, 1] inner_tiles = [1, 8] into %[[INIT_RESULT]]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSED]][0] [100] [1]
+// CHECK: return %[[SLICE]]
+
+// -----
+
+func.func @pack_vecmat_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %3 = linalg.vecmat ins(%0, %1 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>)
+      outs(%2 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>) -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_vecmat_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK: %[[EXPANDED_LHS:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_LHS]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG1]]
+// CHECK: %[[EXPANDED_RESULT:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RESULT]]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_vecmat_fill_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?xf32> {
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>
+  %2 = tensor.empty(%d1) : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>)
+      -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %4 = linalg.vecmat ins(%0, %1 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RHS>>)
+      outs(%3 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>) -> tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>>
+  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?xf32, #iree_linalg_ext.encoding<user = VECMAT, role = RESULT>> -> tensor<?xf32>
+  return %5 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_vecmat_fill_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK-DAG: %[[OUT_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack {{.*}}%[[EXPANDED]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG1]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]]) : tensor<1x?x1x8xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_unpack_matvec_lhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>> -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @pack_unpack_matvec_lhs(
+// CHECK: tensor.pack
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+// CHECK: tensor.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+
+// -----
+
+func.func @pack_unpack_matvec_rhs(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>> -> tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_matvec_rhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<1x?x1x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED]] outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [1, 4] into %[[PACK_DEST]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D1]]) : tensor<?x1xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [1, 4] into %[[UNPACK_DEST]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]] {{\[}}[0, 1]] : tensor<?x1xf32> into tensor<?xf32>
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_matvec(%arg0: tensor<100x250xf32>, %arg1: tensor<250xf32>, %arg2: tensor<100xf32>) -> tensor<100xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pad_lhs = tensor.pad %arg0 low[0, 0] high[4, 2] {
+    ^bb0(%arg3: index, %arg4: index):
+      tensor.yield %cst : f32
+  } : tensor<100x250xf32> to tensor<104x252xf32>
+  %lhs = iree_linalg_ext.set_encoding %pad_lhs : tensor<104x252xf32> -> tensor<104x252xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>
+  %pad_rhs = tensor.pad %arg1 low[0] high[2] {
+    ^bb0(%arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<250xf32> to tensor<252xf32>
+  %rhs = iree_linalg_ext.set_encoding %pad_rhs : tensor<252xf32> -> tensor<252xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>
+  %pad_output = tensor.pad %arg2 low[0] high[4] {
+    ^bb0(%arg3: index):
+      tensor.yield %cst : f32
+  } : tensor<100xf32> to tensor<104xf32>
+  %output = iree_linalg_ext.set_encoding %pad_output : tensor<104xf32> -> tensor<104xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+  %matvec_packed = linalg.matvec ins(%lhs, %rhs : tensor<104x252xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>, tensor<252xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>)
+      outs(%output : tensor<104xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>) -> tensor<104xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+  %matvec = iree_linalg_ext.unset_encoding %matvec_packed : tensor<104xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>> -> tensor<104xf32>
+  %extracted_slice = tensor.extract_slice %matvec[0] [100] [1] : tensor<104xf32> to tensor<100xf32>
+  return %extracted_slice : tensor<100xf32>
+}
+// CHECK: func @pack_matvec(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<250xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<100xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.0
+// CHECK: %[[INIT_LHS:.+]] = tensor.empty() : tensor<13x63x8x4xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: into %[[INIT_LHS]]
+// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[ARG1]]
+// CHECK: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[PADDED_RHS]]
+// CHECK: %[[INIT_RHS:.+]] = tensor.empty() : tensor<1x63x1x4xf32>
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RHS]] outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [1, 4] into %[[INIT_RHS]]
+// CHECK: %[[PADDED_RESULT:.+]] = tensor.pad %[[ARG2]]
+// CHECK: %[[EXPANDED_RESULT:.+]] = tensor.expand_shape %[[PADDED_RESULT]]
+// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() : tensor<13x1x8x1xf32>
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RESULT]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[INIT_RESULT]]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSED]][0] [100] [1]
+// CHECK: return %[[SLICE]]
+
+// -----
+
+func.func @pack_matvec_dynamic(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?xf32> -> tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+  %3 = linalg.matvec ins(%0, %1 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = LHS>>, tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RHS>>)
+      outs(%2 : tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>) -> tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?xf32, #iree_linalg_ext.encoding<user = MATVEC, role = RESULT>> -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_matvec_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]]
+// CHECK: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RHS]]
+// CHECK: %[[EXPANDED_RESULT:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RESULT]]
+// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_unpack_batch_matvec_lhs(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>> -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_batch_matvec_lhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+// CHECK-DAG: %[[OUTER_D2:.+]] = affine.apply #[[MAP1]]()[%[[D2]]]
+// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[D0]], %[[OUTER_D1]], %[[OUTER_D2]]) : tensor<?x?x?x8x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[PACK_DEST]]
+// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @pack_unpack_batch_matvec_rhs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>
+  %1 = iree_linalg_ext.unset_encoding %0 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>> -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_unpack_batch_matvec_rhs(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
+// CHECK: tensor.pack %[[EXPANDED]]
+// CHECK: tensor.unpack
+// CHECK: tensor.collapse_shape
+
+// -----
+
+func.func @pack_batch_matvec(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32xf32>, %arg2 : tensor<128x80xf32>) -> tensor<128x80xf32> {
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32xf32> -> tensor<128x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>
+  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80xf32> -> tensor<128x80xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+  %3 = linalg.batch_matvec ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>, tensor<128x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>)
+      outs(%2 : tensor<128x80xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>) -> tensor<128x80xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+  %4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>> -> tensor<128x80xf32>
+  return %4 : tensor<128x80xf32>
+}
+// CHECK: func @pack_batch_matvec(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80xf32>
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]]
+// CHECK: %[[EXPANDED_RHS:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RHS]]
+// CHECK: %[[EXPANDED_RESULT:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
+// CHECK-SAME: %[[EXPANDED_RESULT]]
+// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
+// CHECK-SAME: outs(%[[PACK_RESULT]] :
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[UNPACK]]
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @pack_batch_matvec_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>
+  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>
+  %2 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>)
+      -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+  %4 = linalg.batch_matvec ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = LHS>>, tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RHS>>)
+      outs(%3 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>) -> tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>>
+  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATVEC, role = RESULT>> -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK: func @pack_batch_matvec_fill_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK: tensor.pack %[[ARG0]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
+// CHECK: tensor.pack %[[EXPANDED]]
+// CHECK: linalg.batch_mmt4d
+// CHECK: tensor.unpack
+// CHECK: tensor.collapse_shape
\ No newline at end of file