From 124d56203cd253ada3be54d2c6482bc28eacd239 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 5 Jan 2024 12:10:39 -0800 Subject: [PATCH] Bump StableHLO to f8dcebfa1ec166806974f6ae0dfb902d36b47238 (#16049) Updated the serialized models too, but just manually, so used naming convention to capture that it wasn't newly generated/manually done. --- .../e2e_test_framework/models/jax_models.py | 2 +- .../e2e_test_framework/models/tf_models.py | 10 +- .../Conversion/ConvertCollectives.cpp | 9 +- .../Conversion/LegalizeCHLO.cpp | 12 +- .../Preprocessing/Canonicalization.cpp | 49 ++++---- .../Preprocessing/DotGeneralToDot.cpp | 7 +- .../Preprocessing/EinsumToDotGeneral.cpp | 2 +- .../Preprocessing/StableHLOToStableHLO.cpp | 63 +++++----- .../test/stablehlo_to_stablehlo.mlir | 14 +-- .../StableHLOToIREEInputDialects.cpp | 3 +- .../Conversion/StableHLOToLinalg.cpp | 111 ++++++++---------- .../StableHLOToLinalgConvolution.cpp | 14 +-- .../Conversion/StableHLOToLinalgExt.cpp | 6 +- .../Conversion/StableHLOToLinalgRandom.cpp | 30 +++-- .../Conversion/StableHLOToLinalgReduce.cpp | 11 +- .../Conversion/test/stablehlo_to_linalg.mlir | 82 ++++++------- .../test/stablehlo_to_linalg_ext.mlir | 10 +- .../mnist_train_test/mnist_train_test.py | 2 +- .../stablehlo_models/unidirectional_lstm.mlir | 8 +- tests/e2e/stablehlo_ops/broadcast.mlir | 4 +- tests/e2e/stablehlo_ops/dynamic_slice.mlir | 6 +- tests/e2e/stablehlo_ops/pad.mlir | 8 +- tests/e2e/stablehlo_ops/reverse.mlir | 6 +- tests/e2e/stablehlo_ops/slice.mlir | 24 ++-- tests/e2e/stablehlo_ops/transpose.mlir | 4 +- .../generated_e2e_test_fetch_models.cmake | 34 +++--- .../generated_e2e_test_iree_artifacts.cmake | 68 +++++------ tests/microbenchmarks/stablehlo_fft_abs.mlir | 2 +- third_party/stablehlo | 2 +- 29 files changed, 290 insertions(+), 313 deletions(-) diff --git a/build_tools/python/e2e_test_framework/models/jax_models.py b/build_tools/python/e2e_test_framework/models/jax_models.py index 874d8943950e..6613683c5179 100644 --- a/build_tools/python/e2e_test_framework/models/jax_models.py +++ b/build_tools/python/e2e_test_framework/models/jax_models.py @@ -11,7 +11,7 @@ from e2e_test_framework.definitions import common_definitions import e2e_test_framework.models.utils as model_utils -GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180" +GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180j" ID_FORMAT = string.Template("${model_id}-batch${batch_size}") NAME_FORMAT = string.Template("${name}_BATCH${batch_size}") diff --git a/build_tools/python/e2e_test_framework/models/tf_models.py b/build_tools/python/e2e_test_framework/models/tf_models.py index 071aec90e389..3c81790c8ddc 100644 --- a/build_tools/python/e2e_test_framework/models/tf_models.py +++ b/build_tools/python/e2e_test_framework/models/tf_models.py @@ -21,7 +21,7 @@ tags=["int32", "seqlen128"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec - source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc", entry_function="predict", input_types=["1x128xi32", "1x128xi32", "1x128xi32"], ) @@ -32,7 +32,7 @@ tags=["fp32", "seqlen512", "tensorflow"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#tfbertformaskedlm - source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc", entry_function="forward", input_types=["1x512xi32", "1x512xi32"], ) @@ -43,7 +43,7 @@ tags=["fp32", "cnn", "tensorflow"], source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Converted from https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/efficientnet_v2.py - source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc", entry_function="forward", input_types=["1x384x384x3xf32"], ) @@ -56,7 +56,7 @@ source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR, # Derived from https://github.com/mlcommons/inference/tree/master/language/bert # Instructions on how to regenerate the model: https://gist.github.com/mariecwhite/e61ccebd979d98d097946ac7725bcc29 - source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc", + source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc", entry_function="serving_default", input_types=["1x384xi32", "1x384xi32", "1x384xi32"], ) @@ -81,7 +81,7 @@ input_types=["1x1xi32", "12x2x1x12x4x64xf32"], ) -TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975" +TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j" ID_FORMAT = string.Template("${model_id}-batch-${batch_size}") NAME_FORMAT = string.Template("${name}Batch${batch_size}") diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp index a4ec6a8e63bb..2d5726df1990 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/ConvertCollectives.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Utils/IndexSet.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo-iree/Conversion/Rewriters.h" @@ -448,7 +449,7 @@ static Value emitTranspose(ConversionPatternRewriter &rewriter, Location loc, llvm::to_vector(llvm::seq(0, inputShape.size())); std::swap(permutation[srcDim], permutation[dstDim]); std::swap(inputShape[srcDim], inputShape[dstDim]); - DenseIntElementsAttr permutationAttr = rewriter.getI64VectorAttr(permutation); + auto permutationAttr = rewriter.getDenseI64ArrayAttr(permutation); return rewriter.create( loc, RankedTensorType::get(inputShape, inputType.getElementType()), input, permutationAttr); @@ -705,7 +706,7 @@ Value splitAndConcatForAllToAll(ConversionPatternRewriter &rewriter, result = rewriter.create( loc, RankedTensorType::get(transposeResultShape, inputType.getElementType()), - result, rewriter.getI64VectorAttr(permutation)); + result, rewriter.getDenseI64ArrayAttr(permutation)); // Reshape llvm::SmallVector finalShape(inputShape); @@ -852,7 +853,7 @@ struct ReduceScatterOpConversion final auto inputType = cast(op.getOperand().getType()); SmallVector reduceInputShape(inputType.getShape()); Value reduceInput = adaptor.getOperand(); - DenseIntElementsAttr permutationAttr; + DenseI64ArrayAttr permutationAttr; SmallVector scatterResultShape(resultType.getShape()); auto elemType = getElementTypeOrSelf(reduceInput.getType()); @@ -861,7 +862,7 @@ struct ReduceScatterOpConversion final auto permutation = llvm::to_vector(llvm::seq(0, scatterResultShape.size())); std::swap(permutation[0], permutation[scatterDim]); - permutationAttr = rewriter.getI64VectorAttr(permutation); + permutationAttr = rewriter.getDenseI64ArrayAttr(permutation); std::swap(reduceInputShape[0], reduceInputShape[scatterDim]); std::swap(scatterResultShape[0], scatterResultShape[scatterDim]); // Transpose the input. diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp index b22f318b21f5..b26589ab958f 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/LegalizeCHLO.cpp @@ -2144,14 +2144,14 @@ struct ConvertTopKOp final : OpConversionPattern { } else { values = rewriter.create( op.getLoc(), tupleFirstElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); + rewriter.getDenseI64ArrayAttr(beginIndices), + rewriter.getDenseI64ArrayAttr(endIndices), + rewriter.getDenseI64ArrayAttr(strides)); indices = rewriter.create( op.getLoc(), tupleSecondElement, - DenseIntElementsAttr::get(indicesTy, beginIndices), - DenseIntElementsAttr::get(indicesTy, endIndices), - DenseIntElementsAttr::get(indicesTy, strides)); + rewriter.getDenseI64ArrayAttr(beginIndices), + rewriter.getDenseI64ArrayAttr(endIndices), + rewriter.getDenseI64ArrayAttr(strides)); } rewriter.replaceOp(op, {values, indices}); diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp index 63ce62e6180d..160c0481f7fb 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/Canonicalization.cpp @@ -41,6 +41,16 @@ namespace { // allowed to materialize as new constants. constexpr int64_t kFoldOpEltLimit = 65536; +static bool isIotaRange(ArrayRef dims) { + for (auto [idx, value] : llvm::enumerate(dims)) { + if (idx != value) { + return false; + } + } + + return true; +} + static bool isIotaRange(ElementsAttr attr) { auto elems = attr.tryGetValues(); if (!elems) @@ -469,7 +479,7 @@ struct BroadcastInDimOpCanon final return failure(); // Fold when broadcast is a noop. - DenseIntElementsAttr dims = op.getBroadcastDimensions(); + auto dims = op.getBroadcastDimensions(); bool isDimsIota = isIotaRange(dims); if (type == operandTy && isDimsIota) { rewriter.replaceOp(op, operand); @@ -485,7 +495,7 @@ struct BroadcastInDimOpCanon final return success(); } - auto bsDimIndices = dims.getValues(); + auto bsDimIndices = dims; if (operandTy.hasStaticShape() && type.hasStaticShape() && type.getNumElements() == operandTy.getNumElements()) { // BroadcastInDim equivalent to reshape. @@ -505,12 +515,10 @@ struct BroadcastInDimOpCanon final // Eliminate redundant nested BroadcastInDim. if (auto broadcastInDimOp = operand.getDefiningOp()) { - auto newIndices = cast( - broadcastInDimOp.getBroadcastDimensions().mapValues( - dims.getElementType(), [&bsDimIndices](const APInt &dim) { - return APInt(dim.getBitWidth(), - bsDimIndices[dim.getSExtValue()], true); - })); + auto newIndices = + rewriter.getDenseI64ArrayAttr(llvm::to_vector(llvm::map_range( + broadcastInDimOp.getBroadcastDimensions(), + [&bsDimIndices](int64_t dim) { return bsDimIndices[dim]; }))); rewriter.replaceOpWithNewOp( op, type, broadcastInDimOp.getOperand(), newIndices); return success(); @@ -631,7 +639,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final // output has static shape, replace with broadcast_in_dim if (type.hasStaticShape()) { rewriter.replaceOpWithNewOp( - op, type, op.getOperand(), op.getBroadcastDimensions()); + op, type, op.getOperand(), op.getBroadcastDimensionsAttr()); return success(); } @@ -648,7 +656,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final refineOpWithNewOp( rewriter, op, RankedTensorType::get(outputShape, type.getElementType()), - op.getOperand(), op.getBroadcastDimensions()); + op.getOperand(), op.getBroadcastDimensionsAttr()); return success(); } } @@ -670,16 +678,11 @@ struct ChainedDynamicBroadcastInDimCanonicalization final return failure(); // Compose broadcast dimensions. - DenseIntElementsAttr precedingBcastDims = - precedingBcast.getBroadcastDimensions(); - DenseIntElementsAttr bcastDims = bcast.getBroadcastDimensions(); - SmallVector composition; - for (APInt precedingDim : precedingBcastDims) { - composition.push_back( - *(bcastDims.value_begin() + precedingDim.getZExtValue())); + SmallVector composition; + for (int64_t precedingDim : precedingBcast.getBroadcastDimensions()) { + composition.push_back(bcast.getBroadcastDimensions()[precedingDim]); } - auto composedBcastDims = - DenseIntElementsAttr::get(precedingBcastDims.getType(), composition); + auto composedBcastDims = rewriter.getDenseI64ArrayAttr(composition); rewriter.replaceOpWithNewOp( bcast, bcast.getType(), precedingBcast.getOperand(), @@ -928,9 +931,9 @@ struct GatherOpCanon final : OpRewritePattern { auto sliceType = RankedTensorType::get(sliceShape, elementType); Value result = rewriter.create( gather.getLoc(), sliceType, gather.getOperand(), - rewriter.getI64TensorAttr(sliceStart), - rewriter.getI64TensorAttr(sliceEnd), - rewriter.getI64TensorAttr(sliceStride)); + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceEnd), + rewriter.getDenseI64ArrayAttr(sliceStride)); ArrayRef collapsedSliceDims = dnums.getCollapsedSliceDims(); if (!collapsedSliceDims.empty()) { @@ -1030,7 +1033,7 @@ struct TransposeIsReshape final "tensor type"); } - SmallVector permValues(permutation.getValues()); + SmallVector permValues(permutation); SmallVector nonZeroPerms; nonZeroPerms.reserve(permValues.size()); diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp index ed2832198043..50e8e8fbecb7 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/DotGeneralToDot.cpp @@ -47,13 +47,8 @@ Value transposeReshape(Value arg, Location loc, auto transposePermutation = llvm::to_vector<5>(llvm::concat(leftDims, rightDims)); - TensorType transposePermutationType = - RankedTensorType::get({static_cast(transposePermutation.size())}, - rewriter.getIntegerType(64)); - auto transposePermutationAttr = - llvm::cast(DenseIntElementsAttr::get( - transposePermutationType, llvm::ArrayRef(transposePermutation))); + rewriter.getDenseI64ArrayAttr(transposePermutation); // Compute the resulting shape. llvm::SmallVector transposedShape; diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp index e671fed59d08..e33b75b0f6ba 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/EinsumToDotGeneral.cpp @@ -149,7 +149,7 @@ struct EinsumToDotGeneralPattern final } else { // Generate a transpose. rewriter.replaceOpWithNewOp( - einsum, dotGeneralOp, rewriter.getI64TensorAttr(resultPerms)); + einsum, dotGeneralOp, rewriter.getDenseI64ArrayAttr(resultPerms)); } return success(); } diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp index 53321b69edeb..fff4726950e3 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/StableHLOToStableHLO.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" @@ -39,11 +40,8 @@ bool isIota(ArrayRef array) { return true; } -DenseIntElementsAttr make1DElementsAttr(OpBuilder &b, - ArrayRef integers) { - auto type = RankedTensorType::get({static_cast(integers.size())}, - b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, integers); +DenseI64ArrayAttr make1DElementsAttr(OpBuilder &b, ArrayRef integers) { + return b.getDenseI64ArrayAttr(integers); } Value getF32Const(ImplicitLocOpBuilder b, ArrayRef shapes, @@ -90,7 +88,7 @@ struct ReorderConvOpInputDimensions final auto transposed = rewriter.create( op.getLoc(), RankedTensorType::get(transposeShape, lhsType.getElementType()), - op.getLhs(), rewriter.getI64TensorAttr(permutations)); + op.getLhs(), rewriter.getDenseI64ArrayAttr(permutations)); llvm::SmallVector newSpatialDimensions(spatialDims.size()); std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1); @@ -158,7 +156,7 @@ struct ReorderConvOpKernelDimensions final auto transposeKernel = rewriter.create( op.getLoc(), RankedTensorType::get(transposeShape, kernelType.getElementType()), - kernel, rewriter.getI64TensorAttr(permutation)); + kernel, rewriter.getDenseI64ArrayAttr(permutation)); auto newDimensionNumbers = mlir::stablehlo::ConvDimensionNumbersAttr::get( op.getContext(), dimensionNumbers.getInputBatchDimension(), @@ -246,7 +244,7 @@ struct ReorderConvOpOutputDimensions final auto transposed = rewriter.create( op.getLoc(), resultType, newConv, - rewriter.getI64TensorAttr(invertPermutation)); + rewriter.getDenseI64ArrayAttr(invertPermutation)); rewriter.replaceOp(op, transposed.getResult()); return success(); @@ -286,7 +284,7 @@ struct TransposeReshapeGenericDotGeneral final } return b.create( loc, RankedTensorType::get(transposeShape, type.getElementType()), src, - b.getI64TensorAttr(targetOrder)); + b.getDenseI64ArrayAttr(targetOrder)); } Value ReshapeIfNonStandard(OpBuilder &b, Location loc, Value src, @@ -748,7 +746,8 @@ struct ScatterBatchFirst final : OpRewritePattern { } indices = builder.create( - indicesTy.clone(newShape), indices, builder.getI64TensorAttr(perm)); + indicesTy.clone(newShape), indices, + builder.getDenseI64ArrayAttr(perm)); indicesTy = llvm::cast(indices.getType()); indexVectorDim = indicesTy.getRank() - 1; } @@ -792,7 +791,7 @@ struct ScatterBatchFirst final : OpRewritePattern { newShape.push_back(updateTy.getDimSize(updatePerm[i])); update = builder.create( updateTy.clone(newShape), update, - builder.getI64TensorAttr(updatePerm)); + builder.getDenseI64ArrayAttr(updatePerm)); } } @@ -1025,7 +1024,7 @@ struct MulCastOfBool final : OpRewritePattern { llvm::seq(resultRank - valueTy.getRank(), resultRank)); return rewriter.create( op.getLoc(), newTy, value, lhsShape, - rewriter.getI64TensorAttr(dimensions)); + rewriter.getDenseI64ArrayAttr(dimensions)); }; zero = broadcast(zero); @@ -1184,7 +1183,7 @@ struct ReorderBroadcastInDimOpAndElementwiseOp final Value result = rewriter.create(op.getLoc(), resultType, bcastOperands); rewriter.replaceOpWithNewOp( - op, op.getType(), result, bcastOps[0].getBroadcastDimensions()); + op, op.getType(), result, bcastOps[0].getBroadcastDimensionsAttr()); for (auto bcastOp : bcastOps) { if (bcastOp.getOperation()->use_empty()) { @@ -1283,11 +1282,11 @@ struct DotToMul final : OpRewritePattern { lhs = rewriter.create( op.getLoc(), resultTy.clone(lhsTy.getElementType()), lhs, outSize, - rewriter.getI64TensorAttr({0, 1})); + rewriter.getDenseI64ArrayAttr({0, 1})); rhs = rewriter.create( op.getLoc(), resultTy.clone(rhsTy.getElementType()), rhs, outSize, - rewriter.getI64TensorAttr({0, 1})); + rewriter.getDenseI64ArrayAttr({0, 1})); auto computeETy = lhsTy.getElementType(); if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth()) @@ -1451,12 +1450,12 @@ struct DotGeneralIsMul final : OpRewritePattern { // Transpose the left hand side and the right hand side. lhs = builder.create( RankedTensorType::get(lhsTransposeShape, lhsTy.getElementType()), lhs, - builder.getI64TensorAttr(permLhs)); + builder.getDenseI64ArrayAttr(permLhs)); lhsTy = llvm::cast(lhs.getType()); rhs = builder.create( RankedTensorType::get(rhsTransposeShape, rhsTy.getElementType()), rhs, - builder.getI64TensorAttr(permRhs)); + builder.getDenseI64ArrayAttr(permRhs)); rhsTy = llvm::cast(rhs.getType()); auto dimI32Ty = RankedTensorType::get({1}, builder.getI32Type()); @@ -1512,7 +1511,7 @@ struct DotGeneralIsMul final : OpRewritePattern { RankedTensorType::get(resultTy.getShape(), lhsTy.getElementType()); lhs = builder.createOrFold( lhsBroadcastTy, lhs, outputShape, - rewriter.getI64TensorAttr(lhsDimMapping)); + rewriter.getDenseI64ArrayAttr(lhsDimMapping)); // Broadcast the right hand side to match the expected output shape. llvm::SmallVector rhsDimMapping(rhsTy.getRank()); @@ -1524,7 +1523,7 @@ struct DotGeneralIsMul final : OpRewritePattern { RankedTensorType::get(resultTy.getShape(), rhsTy.getElementType()); rhs = builder.createOrFold( rhsBroadcastTy, rhs, outputShape, - rewriter.getI64TensorAttr(rhsDimMapping)); + rewriter.getDenseI64ArrayAttr(rhsDimMapping)); lhs = builder.createOrFold(resultTy, lhs); rhs = builder.createOrFold(resultTy, rhs); @@ -1651,7 +1650,7 @@ bool isIotaOrIotaBroadcast(PatternRewriter &rewriter, Value input) { return true; } - (void)rewriter.notifyMatchFailure(iotaOp, "Iota must be on last dimension"); + (void)rewriter.notifyMatchFailure(iotaOp, "iota must be on last dimension"); return false; } @@ -1659,11 +1658,9 @@ bool isIotaOrIotaBroadcast(PatternRewriter &rewriter, Value input) { input.getDefiningOp())) { auto broadcastLastDim = cast(broadcastOp.getType()).getRank() - 1; - SmallVector broadcastDimensions = llvm::to_vector( - broadcastOp.getBroadcastDimensions().getValues()); - if (broadcastDimensions.back() != broadcastLastDim) { + if (broadcastOp.getBroadcastDimensions().back() != broadcastLastDim) { (void)rewriter.notifyMatchFailure( - broadcastOp, "Last dimension must be maintained in broadcast"); + broadcastOp, "last dimension must be maintained in broadcast"); return false; } return isIotaOrIotaBroadcast(rewriter, broadcastOp.getOperand()); @@ -1682,7 +1679,7 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { Value topKInput; if (opOperands.size() != 2 || opResults.size() != 2) { return rewriter.notifyMatchFailure( - op, "Slice that maps to TopK must have exactly two inputs/outputs"); + op, "slice that maps to TopK must have exactly two inputs/outputs"); } Value inputIota; @@ -1697,7 +1694,7 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { } if (!inputIota) { - return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota."); + return rewriter.notifyMatchFailure(op, "sort isn't called from Iota"); } Block &block = op.getRegion().front(); @@ -1713,7 +1710,7 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { if (!getTop) { return rewriter.notifyMatchFailure(op, - "Unsupported comparison direction"); + "unsupported comparison direction"); } Value topV, topI; @@ -1722,27 +1719,25 @@ struct IotaSortSliceIsTopK final : OpRewritePattern { for (auto [idx, result] : llvm::enumerate(opResults)) { if (result.getUsers().empty()) return rewriter.notifyMatchFailure( - op, "Sort isn't calling into a slice op."); + op, "sort isn't calling into a slice op"); auto sliceOp = dyn_cast(*result.getUsers().begin()); if (!sliceOp) { return rewriter.notifyMatchFailure( - op, "Sort isn't calling into a slice op."); + op, "sort isn't calling into a slice op"); } - for (auto stride : sliceOp.getStrides().getValues()) { + for (auto stride : sliceOp.getStrides()) { if (stride != 1) { return rewriter.notifyMatchFailure( - op, "All slice strides must be 1 in order to match to TopK."); + op, "all slice strides must be 1 in order to match to TopK"); } } // Treat the first slice as inputs, the second as indices. if (idx == 0) { topV = sliceOp.getResult(); - SmallVector limitIndices = - llvm::to_vector(sliceOp.getLimitIndices().getValues()); - k = limitIndices.back(); + k = sliceOp.getLimitIndices().back(); } else { topI = sliceOp.getResult(); } diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir index 8e36b844d9f9..2ebad087ba81 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/Preprocessing/test/stablehlo_to_stablehlo.mlir @@ -355,7 +355,7 @@ func.func @convolution(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf padding = dense<0> : tensor<1x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = dense<1> : tensor<1xi64>, - window_strides = dense<1> : tensor<1xi64> + window_strides = dense<[1]> : tensor<1xi64> } : (tensor<16x32x256xf32>, tensor<1x256x256xbf16>) -> tensor<16x32x256xf32> // CHECK: return %[[CONV]] func.return %0 : tensor<16x32x256xf32> @@ -413,8 +413,8 @@ func.func @iota_sort_slice_is_topk(%in : tensor<16x16xf32>) -> (tensor<16x8xf32> %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor "stablehlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) - %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32> - %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32> + %1 = "stablehlo.slice"(%0#0) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16xf32>) -> tensor<16x8xf32> + %2 = "stablehlo.slice"(%0#1) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16xi32>) -> tensor<16x8xi32> return %1, %2 : tensor<16x8xf32>, tensor<16x8xi32> } @@ -434,8 +434,8 @@ func.func @broadcast_iota_sort_slice_is_topk(%in : tensor<16x16x16xf32>) -> (ten %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor "stablehlo.return"(%7) : (tensor) -> () }) {dimension = 2 : i64, is_stable = true} : (tensor<16x16x16xf32>, tensor<16x16x16xi32>) -> (tensor<16x16x16xf32>, tensor<16x16x16xi32>) - %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32> - %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32> + %1 = "stablehlo.slice"(%0#0) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32> + %2 = "stablehlo.slice"(%0#1) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32> return %1, %2 : tensor<16x16x8xf32>, tensor<16x16x8xi32> } @@ -455,8 +455,8 @@ func.func @broadcast_iota_sort_slice_incorrect_dims(%in : tensor<16x16x16xf32>) %7 = "stablehlo.compare"(%arg0, %arg1) {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor "stablehlo.return"(%7) : (tensor) -> () }) {dimension = 2 : i64, is_stable = true} : (tensor<16x16x16xf32>, tensor<16x16x16xi32>) -> (tensor<16x16x16xf32>, tensor<16x16x16xi32>) - %1 = "stablehlo.slice"(%0#0) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32> - %2 = "stablehlo.slice"(%0#1) { start_indices = dense<[0, 0, 0]> : tensor<3xi64>, limit_indices = dense<[16, 16, 8]> : tensor<3xi64>, strides = dense<[1, 1, 1]> : tensor<3xi64> } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32> + %1 = "stablehlo.slice"(%0#0) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16x16xf32>) -> tensor<16x16x8xf32> + %2 = "stablehlo.slice"(%0#1) { start_indices = array, limit_indices = array, strides = array } : (tensor<16x16x16xi32>) -> tensor<16x16x8xi32> return %1, %2 : tensor<16x16x8xf32>, tensor<16x16x8xi32> } diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp index a6d8cfa1bbd8..f82fb9a1417a 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToIREEInputDialects.cpp @@ -178,8 +178,7 @@ struct FftOpConversion final : OpConversionPattern { int64_t rank = inputType.getRank(); int64_t n = inputType.getDimSize(rank - 1); - int64_t fftLength = - op.getFftLength().getSplatValue().getInt() / 2 + 1; + int64_t fftLength = op.getFftLength().front() / 2 + 1; Location loc = op.getLoc(); auto matrixType = diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp index d26afec70a08..ff36eb135306 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalg.cpp @@ -492,15 +492,12 @@ struct HloBroadcastInDimConverter final SmallVector dimExprs; dimExprs.reserve(nloops); - if (broadcastOp.getBroadcastDimensions()) { - for (auto [idx, broadcastDim] : llvm::enumerate( - broadcastOp.getBroadcastDimensions().getValues())) { - int size = broadcastDim.getSExtValue(); - bool expansionNeeded = - operandShape[idx] == 1 && resultType.getShape()[size] != 1; - dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size)); - } + for (auto [idx, size] : + llvm::enumerate(broadcastOp.getBroadcastDimensions())) { + bool expansionNeeded = + operandShape[idx] == 1 && resultType.getShape()[size] != 1; + dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) + : b->getAffineDimExpr(size)); } return { AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), @@ -577,7 +574,7 @@ Value transposeBroadcastOperand(PatternRewriter &rewriter, Location loc, return rewriter.create( loc, RankedTensorType::get(transposedOperandShape, operandTy.getElementType()), - operand, rewriter.getI64VectorAttr(permutation)); + operand, rewriter.getDenseI64ArrayAttr(permutation)); } struct BroadcastInDimOpToBroadcastConverter final @@ -589,8 +586,7 @@ struct BroadcastInDimOpToBroadcastConverter final ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions().getValues()); + SmallVector broadcastDimensions = op.getBroadcastDimensions(); Value operand = adaptor.getOperand(); auto operandTy = llvm::cast(operand.getType()); @@ -658,9 +654,8 @@ struct HloDynamicBroadcastInDimConverter final // Use static type info. auto bcastDims = - llvm::map_to_vector(op.getBroadcastDimensions(), [](const APInt &d) { - return static_cast(d.getLimitedValue()); - }); + llvm::map_to_vector(op.getBroadcastDimensions(), + [](int64_t d) { return static_cast(d); }); for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) { if (ShapedType::isDynamic(dim)) continue; @@ -671,17 +666,13 @@ struct HloDynamicBroadcastInDimConverter final } // Use annotated expansion behavior, if available. - if (op.getKnownExpandingDimensions()) { - for (const auto &it : - op.getKnownExpandingDimensions()->getValues()) { - auto i = it.getLimitedValue(); + if (auto dims = op.getKnownExpandingDimensions()) { + for (int i : *dims) { dimExprs[i] = rewriter.getAffineConstantExpr(0); } } - if (op.getKnownNonexpandingDimensions()) { - for (const auto &it : - op.getKnownNonexpandingDimensions()->getValues()) { - auto i = it.getLimitedValue(); + if (auto dims = op.getKnownNonexpandingDimensions()) { + for (int i : *dims) { dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]); } } @@ -730,8 +721,7 @@ struct DynamicBroadcastInDimOpToBroadcastConverter final if (!resultTy) return failure(); - SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions().getValues()); + SmallVector broadcastDimensions = op.getBroadcastDimensions(); SmallVector> expansionBehavior( broadcastDimensions.size()); @@ -745,14 +735,14 @@ struct DynamicBroadcastInDimOpToBroadcastConverter final // Use annotated expansion behavior, if available. if (op.getKnownExpandingDimensions()) { - for (const auto &it : - op.getKnownExpandingDimensions()->getValues()) { + auto dims = op.getKnownExpandingDimensions().value(); + for (int it : dims) { expansionBehavior[it] = true; } } if (op.getKnownNonexpandingDimensions()) { - for (const auto &it : - op.getKnownNonexpandingDimensions()->getValues()) { + auto dims = op.getKnownNonexpandingDimensions().value(); + for (int it : dims) { expansionBehavior[it] = false; } } @@ -853,7 +843,7 @@ struct TransposeConverter final SmallVector inputExprs; inputExprs.resize(resultType.getRank()); for (auto [idx, value] : llvm::enumerate(op.getPermutation())) { - inputExprs[value.getZExtValue()] = b->getAffineDimExpr(idx); + inputExprs[value] = b->getAffineDimExpr(idx); } return { AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), @@ -876,8 +866,7 @@ struct TransposeOpToTransposeConverter final Value emptyTensor = getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); - auto permutation = rewriter.getDenseI64ArrayAttr( - llvm::to_vector(op.getPermutation().getValues())); + auto permutation = op.getPermutationAttr(); rewriter.replaceOpWithNewOp( op, adaptor.getOperand(), emptyTensor, permutation, @@ -1453,8 +1442,7 @@ struct ReverseConverter final inputExprs.reserve(nloops); for (int64_t i = 0; i < nloops; ++i) inputExprs.push_back(b->getAffineDimExpr(i)); - for (const APInt &dim : op.getDimensions()) { - int i = dim.getZExtValue(); + for (int i : op.getDimensions()) { if (resultType.isDynamicDim(i)) return {}; int n = resultType.getShape()[i]; @@ -1479,9 +1467,9 @@ struct SliceConverter final : OpConversionPattern { } SmallVector offsets, sizes, strides; - auto startIndices = sliceOp.getStartIndices().getValues(); - auto limitIndices = sliceOp.getLimitIndices().getValues(); - auto sliceStrides = sliceOp.getStrides().getValues(); + auto startIndices = sliceOp.getStartIndices(); + auto limitIndices = sliceOp.getLimitIndices(); + auto sliceStrides = sliceOp.getStrides(); for (int64_t i = 0, e = argType.getRank(); i < e; ++i) { int64_t start = startIndices[i]; @@ -1526,9 +1514,8 @@ struct DynamicSliceConverter final SmallVector startIndices, sizes; auto originalStartIndexType = llvm::cast( dynamicSliceOp.getStartIndices().front().getType()); - for (auto [idx, start, size] : - llvm::enumerate(adaptor.getStartIndices(), - dynamicSliceOp.getSliceSizes().getValues())) { + for (auto [idx, start, size] : llvm::enumerate( + adaptor.getStartIndices(), dynamicSliceOp.getSliceSizes())) { sizes.push_back(rewriter.getI64IntegerAttr(size)); // By stablehlo.DynamicSlice definition: @@ -2305,7 +2292,7 @@ struct PadOpNegativePaddingConversion final SmallVector sliceStarts; bool hasNegativePadding = false; - for (int64_t low : op.getEdgePaddingLow().getValues()) { + for (int64_t low : op.getEdgePaddingLow()) { if (low >= 0) { padLow.push_back(low); sliceStarts.push_back(rewriter.getIndexAttr(0)); @@ -2316,7 +2303,7 @@ struct PadOpNegativePaddingConversion final } } - for (int64_t high : op.getEdgePaddingHigh().getValues()) { + for (int64_t high : op.getEdgePaddingHigh()) { if (high >= 0) { padHigh.push_back(high); } else { @@ -2332,8 +2319,8 @@ struct PadOpNegativePaddingConversion final // Create a new pad op with the positive values. Value pad = rewriter.create( op.getLoc(), adaptor.getOperand(), adaptor.getPaddingValue(), - rewriter.getI64TensorAttr(padLow), rewriter.getI64TensorAttr(padHigh), - op.getInteriorPadding()); + rewriter.getDenseI64ArrayAttr(padLow), + rewriter.getDenseI64ArrayAttr(padHigh), op.getInteriorPadding()); // Then slice according to the negative edge padding. Static shapes only for // now. @@ -2365,24 +2352,26 @@ struct PadOpConversion final : OpConversionPattern { return rewriter.notifyMatchFailure(op, "type conversion failed"); // Negative edge padding is decomposed separately. - auto isNegative = [](const APInt &intVal) { return intVal.isNegative(); }; - if (llvm::any_of(op.getEdgePaddingLow().getValues(), isNegative) || - llvm::any_of(op.getEdgePaddingHigh().getValues(), isNegative)) + auto isNegative = [](int64_t intVal) { return intVal < 0; }; + if (llvm::any_of(op.getEdgePaddingLow(), isNegative) || + llvm::any_of(op.getEdgePaddingHigh(), isNegative)) return failure(); Value paddingVal = rewriter.createOrFold( loc, adaptor.getPaddingValue()); - SmallVector low( - op.getEdgePaddingLow().getValues()); + auto i64ToFoldResult = [&](const int64_t &i) -> OpFoldResult { + return rewriter.getIntegerAttr(rewriter.getI64Type(), i); + }; // If there is no interior padding lower to tensor.pad directly. - if (llvm::all_of(op.getInteriorPadding().getValues(), - [](const APInt &intVal) { return intVal.isZero(); })) { - SmallVector high( - op.getEdgePaddingHigh().getValues()); + if (llvm::all_of(op.getInteriorPadding(), + [](const int64_t &i) { return i == 0; })) { auto padTensorOp = rewriter.create( - loc, resultType, adaptor.getOperand(), low, high, paddingVal); + loc, resultType, adaptor.getOperand(), + llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), + llvm::map_to_vector(op.getEdgePaddingHigh(), i64ToFoldResult), + paddingVal); rewriter.replaceOp(op, padTensorOp.getResult()); return success(); } @@ -2405,15 +2394,15 @@ struct PadOpConversion final : OpConversionPattern { .getResult(); }); // Map interior padding to strides. - auto strides = - llvm::map_to_vector(op.getInteriorPadding().getValues(), - [&](IntegerAttr stride) -> OpFoldResult { - return rewriter.getIntegerAttr( - stride.getType(), stride.getValue() + 1); - }); + auto strides = llvm::map_to_vector( + op.getInteriorPadding(), [&](const int64_t &stride) -> OpFoldResult { + return rewriter.getIntegerAttr(rewriter.getI64Type(), stride + 1); + }); rewriter.replaceOpWithNewOp( - op, adaptor.getOperand(), fill, low, sizes, strides); + op, adaptor.getOperand(), fill, + llvm::map_to_vector(op.getEdgePaddingLow(), i64ToFoldResult), sizes, + strides); return success(); } }; diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp index 971115b3c05c..353852dcd5e8 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgConvolution.cpp @@ -55,9 +55,6 @@ Value applyConvolutionPadding(Location loc, Value input, } } - IntegerType indexType = rewriter.getIntegerType(64); - auto attrType = RankedTensorType::get({rank}, indexType); - Value zero; if (auto complexType = dyn_cast(inputType.getElementType())) { auto zeroElement = rewriter.getZeroAttr(complexType.getElementType()); @@ -72,9 +69,9 @@ Value applyConvolutionPadding(Location loc, Value input, } return rewriter.create( - loc, input, zero, DenseIntElementsAttr::get(attrType, padLow), - DenseIntElementsAttr::get(attrType, padHigh), - DenseIntElementsAttr::get(attrType, padInterior)); + loc, input, zero, rewriter.getDenseI64ArrayAttr(padLow), + rewriter.getDenseI64ArrayAttr(padHigh), + rewriter.getDenseI64ArrayAttr(padInterior)); } /// If the ConvolutionOp has a window reversal, applies it to the filter. @@ -95,10 +92,7 @@ Value applyConvolutionReversal(Location loc, OpBuilder &b, } return b.create( - loc, filter, - mlir::DenseIntElementsAttr::get( - RankedTensorType::get(reversedDims.size(), b.getI64Type()), - reversedDims)); + loc, filter, b.getDenseI64ArrayAttr(reversedDims)); } /// Returns true if the given `dimensionNumbers` from a stablehlo.convolution op diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp index e924809c5e95..cc27b85dc021 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgExt.cpp @@ -374,10 +374,10 @@ struct FftOpConversion final : OpConversionPattern { if (!operandType || !operandType.hasStaticShape()) { return failure(); } - if (!op.getFftLength().isSplat()) { + if (!llvm::all_equal(op.getFftLength())) { return rewriter.notifyMatchFailure(op, "non-splat length"); } - int fftLength = op.getFftLength().getSplatValue().getInt(); + int fftLength = op.getFftLength().front(); if (fftLength & (fftLength - 1)) { return rewriter.notifyMatchFailure( op, "expected FFT length to be a power of two"); @@ -442,7 +442,7 @@ struct ReverseOpConversion final rewriter.create(loc, mixedSizes, ty.getElementType()); rewriter.replaceOpWithNewOp( op, typeConverter->convertType(op.getType()), adaptor.getOperands(), - emptyTensor, op.getDimensions()); + emptyTensor, rewriter.getI64TensorAttr(op.getDimensions())); return success(); } }; diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp index 592d6d5c19eb..033dd69b3258 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgRandom.cpp @@ -426,9 +426,9 @@ LogicalResult generateLinalgThreeFry32(OpBuilder &builder, Location loc, llvm::SmallVector offset(resultTy.getRank(), 0); llvm::SmallVector stride(resultTy.getRank(), 1); Value slice = builder.create( - loc, resultTy, reshape, builder.getI64TensorAttr(offset), - builder.getI64TensorAttr(resultTy.getShape()), - builder.getI64TensorAttr(stride)); + loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset), + builder.getDenseI64ArrayAttr(resultTy.getShape()), + builder.getDenseI64ArrayAttr(stride)); // Set the new tensor values. store = setState64(builder, loc, store, newState); @@ -636,12 +636,14 @@ LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc, // Slice to only the required results. collapseShape[0] = resultTy.getNumElements(); - llvm::SmallVector offset(resultTy.getRank(), 0); - llvm::SmallVector stride(resultTy.getRank(), 1); + auto sliceResultTy = intermediateType.clone(collapseShape); + llvm::SmallVector offset(sliceResultTy.getRank(), 0); + llvm::SmallVector stride(sliceResultTy.getRank(), 1); Value slice = builder.create( - loc, intermediateType.clone(collapseShape), reshapeIntermediate, - builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape), - builder.getI64TensorAttr(stride)); + loc, sliceResultTy, reshapeIntermediate, + builder.getDenseI64ArrayAttr(offset), + builder.getDenseI64ArrayAttr(collapseShape), + builder.getDenseI64ArrayAttr(stride)); Value reshapeResult = builder.create(loc, resultTy, slice); @@ -727,12 +729,14 @@ LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc, // Slice to only the required results. collapseShape[0] = resultTy.getNumElements(); - llvm::SmallVector offset(resultTy.getRank(), 0); - llvm::SmallVector stride(resultTy.getRank(), 1); + auto sliceResultTy = intermediateType.clone(collapseShape); + llvm::SmallVector offset(sliceResultTy.getRank(), 0); + llvm::SmallVector stride(sliceResultTy.getRank(), 1); Value slice = builder.create( - loc, intermediateType.clone(collapseShape), reshapeIntermediate, - builder.getI64TensorAttr(offset), builder.getI64TensorAttr(collapseShape), - builder.getI64TensorAttr(stride)); + loc, sliceResultTy, reshapeIntermediate, + builder.getDenseI64ArrayAttr(offset), + builder.getDenseI64ArrayAttr(collapseShape), + builder.getDenseI64ArrayAttr(stride)); Value reshapeResult = builder.create(loc, resultTy, slice); diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp index a579fe7f59ed..c261d039b30e 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/StableHLOToLinalgReduce.cpp @@ -404,7 +404,7 @@ struct ReduceWindowOpOnTensorsGenericConversion final if (!resultTy.hasStaticShape()) return failure(); - auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape()); + auto broadcastSizes = rewriter.getDenseI64ArrayAttr(resultTy.getShape()); broadcastValues.push_back(rewriter.create( loc, resultTy, initValue, broadcastSizes)); } @@ -426,12 +426,9 @@ struct ReduceWindowOpOnTensorsGenericConversion final staticInteriors[idx] = dilation - 1; } - auto padAttrType = - RankedTensorType::get({rank}, rewriter.getIntegerType(64)); - auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows); - auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs); - auto padInteriors = - DenseIntElementsAttr::get(padAttrType, staticInteriors); + auto padLows = rewriter.getDenseI64ArrayAttr(staticLows); + auto padHighs = rewriter.getDenseI64ArrayAttr(staticHighs); + auto padInteriors = rewriter.getDenseI64ArrayAttr(staticInteriors); for (auto [input, initValue] : llvm::zip(inputs, initValues)) { input = rewriter.create( diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir index f8b5c103e2b4..bcf0c43cfd06 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg.mlir @@ -417,7 +417,7 @@ func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x func.func @broadcast_in_dim_with_one_to_one( %operand: tensor<1xf32>) -> tensor<1x5xf32> { %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + {broadcast_dimensions = array} : (tensor<1xf32>) -> tensor<1x5xf32> func.return %0 : tensor<1x5xf32> } @@ -464,7 +464,7 @@ func.func @broadcast_in_dim_with_transpose( // CHECK: func @broadcast_in_dim_scalar func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> { %0 = "stablehlo.broadcast_in_dim"(%operand) - {broadcast_dimensions = dense<[]> : tensor<0xi64>} + {broadcast_dimensions = array} : (tensor) -> tensor<7x10x6xf32> func.return %0 : tensor<7x10x6xf32> } @@ -484,7 +484,7 @@ func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @broadcast_scalar func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor) -> tensor<4x2x1xf32> func.return %0: tensor<4x2x1xf32> } // CHECK: tensor.empty() : tensor<4x2x1xf32> @@ -505,7 +505,7 @@ func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK: func @broadcast func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { - %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + %0 = "stablehlo.broadcast"(%arg) {broadcast_sizes = array} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> func.return %0: tensor<4x2x1x4x?x16xf32> } // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -655,7 +655,7 @@ func.func @map_mixed(%arg0: tensor, ^bb0(%arg2: tensor, %arg3: tensor): %1 = stablehlo.add %arg2, %arg3 : tensor "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor, tensor<4xf32>) -> tensor func.return %0 : tensor } @@ -675,7 +675,7 @@ func.func @map_one_arg(%arg0: tensor) -> tensor { ^bb0(%arg2: tensor): %1 = stablehlo.add %arg2, %arg2 : tensor "stablehlo.return"(%1) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor) -> tensor func.return %0 : tensor } @@ -706,7 +706,7 @@ func.func @map_compare(%arg0: tensor>, {comparison_direction = #stablehlo} : (tensor, tensor) -> tensor "stablehlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor>, tensor>) -> tensor func.return %0 : tensor } @@ -741,9 +741,9 @@ func.func @map_compare(%arg0: tensor>, func.func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { %0 = arith.constant dense<0.0> : tensor %1 = "stablehlo.pad"(%arg0, %0) { - edge_padding_high = dense<[2, 3]> : tensor<2xi64>, - edge_padding_low = dense<[4, 5]> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> func.return %1 : tensor<18x12xf32> } @@ -758,9 +758,9 @@ func.func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> { func.func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x12xf32> { %0 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = dense<[2, 3]> : tensor<2xi64>, - edge_padding_low = dense<[4, 5]> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array } : (tensor<12x4xf32>, tensor) -> tensor<18x12xf32> func.return %0 : tensor<18x12xf32> } @@ -777,9 +777,9 @@ func.func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor) -> tensor<18x func.func @pad_interior(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<29x15xui32> { %0 = arith.constant dense<0> : tensor %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = dense<[2, 3]> : tensor<2xi64>, - edge_padding_low = dense<[4, 5]> : tensor<2xi64>, - interior_padding = dense<[1, 1]> : tensor<2xi64> + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array } : (tensor<12x4xui32>, tensor) -> tensor<29x15xui32> func.return %1 : tensor<29x15xui32> } @@ -798,9 +798,9 @@ func.func @pad_interior(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor func.func @pad_interior_negative(%arg0: tensor<12x4xui32>, %arg1: tensor) -> tensor<25x9xui32> { %0 = arith.constant dense<0> : tensor %1 = "stablehlo.pad"(%arg0, %arg1) { - edge_padding_high = dense<[-2, 3]> : tensor<2xi64>, - edge_padding_low = dense<[4, -1]> : tensor<2xi64>, - interior_padding = dense<[1, 1]> : tensor<2xi64> + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array } : (tensor<12x4xui32>, tensor) -> tensor<25x9xui32> func.return %1 : tensor<25x9xui32> } @@ -1066,7 +1066,7 @@ func.func @reshape_empty(%arg0: tensor<7x0xf64>) -> tensor<0x42x101xf64> { // CHECK: func @reverse func.func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { %result = "stablehlo.reverse"(%input) { - dimensions = dense<1> : tensor<1xi64>, someattr + dimensions = array, someattr } : (tensor<2x3xf32>) -> tensor<2x3xf32> func.return %result : tensor<2x3xf32> } @@ -1332,9 +1332,9 @@ func.func @torch_index_select_dynamic(%input: tensor, // CHECK: tensor.extract_slice %{{.*}}[1, 0] [1, 4] [1, 1] : tensor<3x4xi32> to tensor<1x4xi32> func.func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { %0 = "stablehlo.slice"(%arg0) { - start_indices = dense<[1, 0]> : tensor<2xi64>, - limit_indices = dense<[2, 4]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<1x4xi32> func.return %0 : tensor<1x4xi32> } @@ -1345,9 +1345,9 @@ func.func @slice_whole_stride(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { // CHECK: tensor.extract_slice %{{.*}}[1, 1] [1, 2] [1, 1] : tensor<3x4xi32> to tensor<1x2xi32> func.func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { %0 = "stablehlo.slice"(%arg0) { - start_indices = dense<[1, 1]> : tensor<2xi64>, - limit_indices = dense<[2, 3]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<1x2xi32> func.return %0 : tensor<1x2xi32> } @@ -1358,9 +1358,9 @@ func.func @slice_stride_part(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { // CHECK: tensor.extract_slice %{{.*}}[0] [6] [2] : tensor<13xi32> to tensor<6xi32> func.func @slice_with_strides(%arg0: tensor<13xi32>) -> tensor<6xi32> { %0 = "stablehlo.slice"(%arg0) { - limit_indices = dense<12> : tensor<1xi64>, - start_indices = dense<0> : tensor<1xi64>, - strides = dense<2> : tensor<1xi64> + limit_indices = array, + start_indices = array, + strides = array } : (tensor<13xi32>) -> tensor<6xi32> func.return %0 : tensor<6xi32> } @@ -1371,9 +1371,9 @@ func.func @slice_with_strides(%arg0: tensor<13xi32>) -> tensor<6xi32> { // CHECK: tensor.extract_slice %{{.*}}[0] [3] [2] : tensor<6xi32> to tensor<3xi32> func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> { %0 = "stablehlo.slice"(%arg0) { - limit_indices = dense<5> : tensor<1xi64>, - start_indices = dense<0> : tensor<1xi64>, - strides = dense<2> : tensor<1xi64> + limit_indices = array, + start_indices = array, + strides = array } : (tensor<6xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1384,9 +1384,9 @@ func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> { // CHECK: tensor.extract_slice %{{.*}}[0, 2, 0] [3, 0, 5] [1, 2, 1] : tensor<3x3x5xf64> to tensor<3x0x5xf64> func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64> { %0 = "stablehlo.slice"(%arg0) { - limit_indices = dense<[3, 2, 5]> : tensor<3xi64>, - start_indices = dense<[0, 2, 0]> : tensor<3xi64>, - strides = dense<[1, 2, 1]> : tensor<3xi64> + limit_indices = array, + start_indices = array, + strides = array } : (tensor<3x3x5xf64>) -> tensor<3x0x5xf64> func.return %0 : tensor<3x0x5xf64> } @@ -1399,7 +1399,7 @@ func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor) -> tensor<1x4xf32> { %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = dense<[1, 4]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x4xf32>, tensor, tensor) -> tensor<1x4xf32> func.return %0 : tensor<1x4xf32> } @@ -1422,7 +1422,7 @@ func.func @dynamic_slice_unsigned_index( %arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) -> tensor<1x4xui32> { %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = dense<[1, 4]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x4xui32>, tensor, tensor) -> tensor<1x4xui32> func.return %0 : tensor<1x4xui32> } @@ -1438,7 +1438,7 @@ func.func @dynamic_slice_unsigned_index( // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] func.func @dynamic_slice_unsigned(%arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) -> tensor<1x4xui32> { %0 = "stablehlo.dynamic_slice"(%arg, %start1, %start2) { - slice_sizes = dense<[1, 4]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x4xui32>, tensor, tensor) -> tensor<1x4xui32> func.return %0 : tensor<1x4xui32> } @@ -1559,7 +1559,7 @@ func.func @dynamic_update_slice_float(%target: tensor<3x3xf32>, // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @transpose func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { - %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} + %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> func.return %0 : tensor<3x2x5x9xi32> } @@ -1574,7 +1574,7 @@ func.func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @transpose_dynamic func.func @transpose_dynamic(%arg0: tensor) -> tensor { - %0 = "stablehlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>, someattr} + %0 = "stablehlo.transpose"(%arg0) {permutation = array, someattr} : (tensor) -> tensor func.return %0 : tensor } @@ -1607,7 +1607,7 @@ func.func @transpose_dynamic(%arg0: tensor) -> tensor func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> { %0 = "stablehlo.transpose"(%arg0) { - permutation = dense<[1, 0]> : tensor<2xi64>, + permutation = array, result_layout = dense<[0, 1]> : tensor<2xindex> } : (tensor<2x2xui32>) -> tensor<2x2xui32> return %0 : tensor<2x2xui32> diff --git a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir index 17270e2861b9..5f4b338ea61a 100644 --- a/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir +++ b/compiler/plugins/input/StableHLO/stablehlo-iree/Conversion/test/stablehlo_to_linalg_ext.mlir @@ -394,7 +394,7 @@ func.func @scatter_ui32(%arg0: tensor<1xui32>, %arg1: tensor<1x1xi32>, %arg2: te // CHECK: func.func @rfft_1d func.func @rfft_1d(%input: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) { %0 = "stablehlo.fft"(%input) { - fft_length = dense<8> : tensor<1xi64>, fft_type = #stablehlo + fft_length = array, fft_type = #stablehlo } : (tensor<8xf32>) -> tensor<5xcomplex> %1 = "stablehlo.real"(%0) : (tensor<5xcomplex>) -> tensor<5xf32> %2 = "stablehlo.imag"(%0) : (tensor<5xcomplex>) -> tensor<5xf32> @@ -442,7 +442,7 @@ func.func @rfft_1d(%input: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) { // CHECK: func.func @rfft_2d func.func @rfft_2d(%input: tensor<4x8xf32>) -> (tensor<4x5xf32>, tensor<4x5xf32>) { %0 = "stablehlo.fft"(%input) { - fft_length = dense<8> : tensor<1xi64>, fft_type = #stablehlo + fft_length = array, fft_type = #stablehlo } : (tensor<4x8xf32>) -> tensor<4x5xcomplex> %1 = "stablehlo.real"(%0) : (tensor<4x5xcomplex>) -> tensor<4x5xf32> %2 = "stablehlo.imag"(%0) : (tensor<4x5xcomplex>) -> tensor<4x5xf32> @@ -490,7 +490,7 @@ func.func @rfft_2d(%input: tensor<4x8xf32>) -> (tensor<4x5xf32>, tensor<4x5xf32> // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { %0 = "stablehlo.reverse"(%arg0) { - dimensions = dense<1> : tensor<1xi64> + dimensions = array } : (tensor<3x5xi32>) -> tensor<3x5xi32> return %0 : tensor<3x5xi32> } @@ -505,7 +505,7 @@ func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { %0 = "stablehlo.reverse"(%arg0) { - dimensions = dense<1> : tensor<1xi64> + dimensions = array } : (tensor<3x5xui32>) -> tensor<3x5xui32> return %0 : tensor<3x5xui32> } @@ -526,7 +526,7 @@ func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] func.func @reverse_multi_dim(%arg0: tensor) -> tensor { %0 = "stablehlo.reverse"(%arg0) { - dimensions = dense<[0, 1]> : tensor<2xi64> + dimensions = array } : (tensor) -> tensor return %0 : tensor } diff --git a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py index 47334ac6fe52..2e5bcf959bc4 100644 --- a/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py +++ b/tests/e2e/stablehlo_models/mnist_train_test/mnist_train_test.py @@ -19,7 +19,7 @@ from iree.compiler.tools import InputType, compile_file from iree.runtime import load_vm_flatbuffer_file -MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.a49ba1535a45ac0f3e6be22a7ed5dddf4a53cd1f41126af938f0667b998f8e11.tar" +MODEL_ARTIFACTS_URL = "https://storage.googleapis.com/iree-model-artifacts/mnist_train.45208053dcd69ebd7428fe5b785249a7bdff2d62d55fb81b815889c4e1b993bb.tar" Tensor = TypeVar("Tensor") diff --git a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir index e16251c4c7ad..a2a378cf98c2 100644 --- a/tests/e2e/stablehlo_models/unidirectional_lstm.mlir +++ b/tests/e2e/stablehlo_models/unidirectional_lstm.mlir @@ -79,18 +79,18 @@ func.func private @Forward_o16DF3vQKaI__disable_call_shape_inference_true_.189(% %62 = stablehlo.dot %61, %45, precision = [DEFAULT] : (tensor<1x74xf32>, tensor<74x40xf32>) -> tensor<1x40xf32> %63 = stablehlo.reshape %43 : (tensor<40xf32>) -> tensor<1x40xf32> %64 = stablehlo.add %62, %63 : tensor<1x40xf32> - %65 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 30]> : tensor<2xi64>, start_indices = dense<[0, 20]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32> + %65 = "stablehlo.slice"(%64) {limit_indices = array, start_indices = array, strides = array} : (tensor<1x40xf32>) -> tensor<1x10xf32> %66 = stablehlo.multiply %65, %8 : tensor<1x10xf32> %67 = stablehlo.tanh %66 : tensor<1x10xf32> %68 = stablehlo.multiply %67, %8 : tensor<1x10xf32> %69 = stablehlo.add %68, %8 : tensor<1x10xf32> %70 = stablehlo.multiply %69, %47 : tensor<1x10xf32> - %71 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 20]> : tensor<2xi64>, start_indices = dense<[0, 10]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32> + %71 = "stablehlo.slice"(%64) {limit_indices = array, start_indices = array, strides = array} : (tensor<1x40xf32>) -> tensor<1x10xf32> %72 = stablehlo.multiply %71, %8 : tensor<1x10xf32> %73 = stablehlo.tanh %72 : tensor<1x10xf32> %74 = stablehlo.multiply %73, %8 : tensor<1x10xf32> %75 = stablehlo.add %74, %8 : tensor<1x10xf32> - %76 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 10]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32> + %76 = "stablehlo.slice"(%64) {limit_indices = array, start_indices = array, strides = array} : (tensor<1x40xf32>) -> tensor<1x10xf32> %77 = stablehlo.tanh %76 : tensor<1x10xf32> %78 = stablehlo.multiply %75, %77 : tensor<1x10xf32> %79 = stablehlo.add %70, %78 : tensor<1x10xf32> @@ -100,7 +100,7 @@ func.func private @Forward_o16DF3vQKaI__disable_call_shape_inference_true_.189(% %83 = stablehlo.reshape %56 : (tensor<1x1xf32>) -> tensor<1xf32> %84 = stablehlo.broadcast_in_dim %83, dims = [0] : (tensor<1xf32>) -> tensor<1x10xf32> %85 = stablehlo.compare GT, %84, %7 : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xi1> - %86 = "stablehlo.slice"(%64) {limit_indices = dense<[1, 40]> : tensor<2xi64>, start_indices = dense<[0, 30]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x40xf32>) -> tensor<1x10xf32> + %86 = "stablehlo.slice"(%64) {limit_indices = array, start_indices = array, strides = array} : (tensor<1x40xf32>) -> tensor<1x10xf32> %87 = stablehlo.multiply %86, %8 : tensor<1x10xf32> %88 = stablehlo.tanh %87 : tensor<1x10xf32> %89 = stablehlo.multiply %88, %8 : tensor<1x10xf32> diff --git a/tests/e2e/stablehlo_ops/broadcast.mlir b/tests/e2e/stablehlo_ops/broadcast.mlir index b72346636c4f..ad4e72d927b5 100644 --- a/tests/e2e/stablehlo_ops/broadcast.mlir +++ b/tests/e2e/stablehlo_ops/broadcast.mlir @@ -1,7 +1,7 @@ func.func @broadcast_2D_3D() { %input = util.unfoldable_constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32> - %result = "stablehlo.broadcast"(%input) {broadcast_sizes = dense<3> : tensor<1xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32> + %result = "stablehlo.broadcast"(%input) {broadcast_sizes = array} : (tensor<2x4xi32>) -> tensor<3x2x4xi32> check.expect_eq_const(%result, dense<[ [[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]], @@ -11,7 +11,7 @@ func.func @broadcast_2D_3D() { func.func @broadcast_3D_scalar() { %input = util.unfoldable_constant dense<42> : tensor - %result = "stablehlo.broadcast"(%input) {broadcast_sizes = dense<[3, 2, 4]> : tensor<3xi64>} : (tensor) -> tensor<3x2x4xi32> + %result = "stablehlo.broadcast"(%input) {broadcast_sizes = array} : (tensor) -> tensor<3x2x4xi32> check.expect_eq_const(%result, dense<[ [[42, 42, 42, 42], [42, 42, 42, 42]], [[42, 42, 42, 42], [42, 42, 42, 42]], diff --git a/tests/e2e/stablehlo_ops/dynamic_slice.mlir b/tests/e2e/stablehlo_ops/dynamic_slice.mlir index dc0a20110ef7..a423b2395bbe 100644 --- a/tests/e2e/stablehlo_ops/dynamic_slice.mlir +++ b/tests/e2e/stablehlo_ops/dynamic_slice.mlir @@ -6,7 +6,7 @@ func.func @dynamic_slice() { %start1 = util.unfoldable_constant dense<1> : tensor %start2 = util.unfoldable_constant dense<2> : tensor %result = "stablehlo.dynamic_slice"(%input, %start1, %start2) { - slice_sizes = dense<[2, 2]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x4xi32>, tensor, tensor) -> tensor<2x2xi32> check.expect_eq_const(%result, dense<[ [7, 8], @@ -22,7 +22,7 @@ func.func @dynamic_unit_slice() { %start1 = util.unfoldable_constant dense<1> : tensor %start2 = util.unfoldable_constant dense<2> : tensor %result = "stablehlo.dynamic_slice"(%input, %start1, %start2) { - slice_sizes = dense<[1, 2]> : tensor<2xi64> + slice_sizes = array } : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x2xi32> check.expect_eq_const(%result, dense<[ [7, 8]]> : tensor<1x2xi32>) : tensor<1x2xi32> @@ -33,7 +33,7 @@ func.func @dynamic_1d_slice() { %input = util.unfoldable_constant dense<[1, 2, 3, 4]> : tensor<4xi32> %start1 = util.unfoldable_constant dense<1> : tensor %result = "stablehlo.dynamic_slice"(%input, %start1) { - slice_sizes = dense<[2]> : tensor<1xi64> + slice_sizes = array } : (tensor<4xi32>, tensor) -> tensor<2xi32> check.expect_eq_const(%result, dense<[2, 3]> : tensor<2xi32>) : tensor<2xi32> return diff --git a/tests/e2e/stablehlo_ops/pad.mlir b/tests/e2e/stablehlo_ops/pad.mlir index 9774bbc7e5a1..18ccf7873b09 100644 --- a/tests/e2e/stablehlo_ops/pad.mlir +++ b/tests/e2e/stablehlo_ops/pad.mlir @@ -2,9 +2,9 @@ func.func @pad_test() { %input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> %c0 = arith.constant dense<0> : tensor %res = "stablehlo.pad"(%input, %c0) { - edge_padding_low = dense<[0, 1]> : tensor<2xi64>, - edge_padding_high = dense<[1, 5]> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> + edge_padding_low = array, + edge_padding_high = array, + interior_padding = array } : (tensor<2x3xi32>, tensor) -> tensor<3x9xi32> check.expect_eq_const(%res, dense<[ [0, 1, 2, 3, 0, 0, 0, 0, 0], @@ -16,7 +16,7 @@ func.func @pad_test() { func.func @pad_no_op() { %input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> %c0 = arith.constant dense<0> : tensor - %res = "stablehlo.pad"(%input, %c0) {edge_padding_high = dense<[0, 0]> : tensor<2xi64>, edge_padding_low = dense<[0, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<2x3xi32>, tensor) -> tensor<2x3xi32> + %res = "stablehlo.pad"(%input, %c0) {edge_padding_high = array, edge_padding_low = array, interior_padding = array} : (tensor<2x3xi32>, tensor) -> tensor<2x3xi32> check.expect_eq(%res, %input) : tensor<2x3xi32> return } diff --git a/tests/e2e/stablehlo_ops/reverse.mlir b/tests/e2e/stablehlo_ops/reverse.mlir index 11d53e6e1e08..336065c4fa11 100644 --- a/tests/e2e/stablehlo_ops/reverse.mlir +++ b/tests/e2e/stablehlo_ops/reverse.mlir @@ -1,19 +1,19 @@ func.func @xla_reverse() { %t1 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - %dim0 = "stablehlo.reverse"(%t1) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %dim0 = "stablehlo.reverse"(%t1) {dimensions = array} : (tensor<2x3xf32>) -> tensor<2x3xf32> check.expect_almost_eq_const( %dim0, dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32> ) : tensor<2x3xf32> - %dim1 = "stablehlo.reverse"(%t1) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %dim1 = "stablehlo.reverse"(%t1) {dimensions = array} : (tensor<2x3xf32>) -> tensor<2x3xf32> check.expect_almost_eq_const( %dim1, dense<[[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]> : tensor<2x3xf32> ) : tensor<2x3xf32> - %both_dims = "stablehlo.reverse"(%t1) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %both_dims = "stablehlo.reverse"(%t1) {dimensions = array} : (tensor<2x3xf32>) -> tensor<2x3xf32> check.expect_almost_eq_const( %both_dims, dense<[[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]> : tensor<2x3xf32> diff --git a/tests/e2e/stablehlo_ops/slice.mlir b/tests/e2e/stablehlo_ops/slice.mlir index 2f0120dd3e75..f52c5a64d5e9 100644 --- a/tests/e2e/stablehlo_ops/slice.mlir +++ b/tests/e2e/stablehlo_ops/slice.mlir @@ -4,9 +4,9 @@ func.func @slice_whole_buffer() { [05, 06, 07, 08], [09, 10, 11, 12]]> : tensor<3x4xi32> %result = "stablehlo.slice"(%input) { - start_indices = dense<[0, 0]> : tensor<2xi64>, - limit_indices = dense<[3, 4]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<3x4xi32> check.expect_eq_const(%result, dense<[ [1, 2, 3, 4], @@ -21,9 +21,9 @@ func.func @slice_whole_stride() { [05, 06, 07, 08], [09, 10, 11, 12]]> : tensor<3x4xi32> %result = "stablehlo.slice"(%input) { - start_indices = dense<[1, 0]> : tensor<2xi64>, - limit_indices = dense<[2, 4]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<1x4xi32> check.expect_eq_const(%result, dense<[[5, 6, 7, 8]]> : tensor<1x4xi32>) : tensor<1x4xi32> return @@ -35,9 +35,9 @@ func.func @slice_stride_part() { [05, 06, 07, 08], [09, 10, 11, 12]]> : tensor<3x4xi32> %result = "stablehlo.slice"(%input) { - start_indices = dense<[1, 1]> : tensor<2xi64>, - limit_indices = dense<[2, 3]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<1x2xi32> check.expect_eq_const(%result, dense<[[6, 7]]> : tensor<1x2xi32>) : tensor<1x2xi32> return @@ -49,9 +49,9 @@ func.func @slice_multi_stride() { [05, 06, 07, 08], [09, 10, 11, 12]]> : tensor<3x4xi32> %result = "stablehlo.slice"(%input) { - start_indices = dense<[1, 0]> : tensor<2xi64>, - limit_indices = dense<[3, 4]> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64> + start_indices = array, + limit_indices = array, + strides = array } : (tensor<3x4xi32>) -> tensor<2x4xi32> check.expect_eq_const(%result, dense<[ [5, 6, 7, 8], diff --git a/tests/e2e/stablehlo_ops/transpose.mlir b/tests/e2e/stablehlo_ops/transpose.mlir index d709f2e5b43c..8b9b2389cae1 100644 --- a/tests/e2e/stablehlo_ops/transpose.mlir +++ b/tests/e2e/stablehlo_ops/transpose.mlir @@ -2,7 +2,7 @@ func.func @transpose_2d() { %input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> %0 = "stablehlo.transpose"(%input) { - permutation = dense<[1, 0]> : tensor<2xi64> + permutation = array } : (tensor<2x3xi32>) -> tensor<3x2xi32> check.expect_eq_const(%0, dense<[[1, 4], [2, 5], @@ -16,7 +16,7 @@ func.func @transpose_3d() { [[ 7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> %0 = "stablehlo.transpose"(%input) { - permutation = dense<[0, 2, 1]> : tensor<3xi64> + permutation = array } : (tensor<2x2x3xi32>) -> tensor<2x3x2xi32> check.expect_eq_const(%0, dense<[ [[ 1, 4], diff --git a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake index 334b996898f3..64a78f0d44b4 100644 --- a/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake +++ b/tests/e2e/test_artifacts/generated_e2e_test_fetch_models.cmake @@ -84,29 +84,29 @@ iree_fetch_artifact( iree_fetch_artifact( NAME "model-EfficientNetV2STF" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc" - OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc" + OUTPUT "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-MiniLML12H384Uncased" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc" - OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc" + OUTPUT "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-BertForMaskedLMTF" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc" - OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc" + OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-BertLargeTF" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc" - OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/manual/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc" + OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" UNPACK ) @@ -140,63 +140,63 @@ iree_fetch_artifact( iree_fetch_artifact( NAME "model-BertLargeTFBatch1" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH1/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH1/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch1.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-BertLargeTFBatch32" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH32/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH32/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch32.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-BertLargeTFBatch64" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/BERT_LARGE_FP32_TF_384XI32_BATCH64/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/BERT_LARGE_FP32_TF_384XI32_BATCH64/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_BertLargeTFBatch64.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-Resnet50TFBatch1" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH1/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH1/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch1.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-Resnet50TFBatch64" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH64/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH64/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch64.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-Resnet50TFBatch128" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/RESNET50_FP32_TF_224X224X3XF32_BATCH128/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/RESNET50_FP32_TF_224X224X3XF32_BATCH128/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_Resnet50TFBatch128.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-T5LargeTFBatch1" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH1/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH1/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch1.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-T5LargeTFBatch16" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH16/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH16/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch16.mlirbc" UNPACK ) iree_fetch_artifact( NAME "model-T5LargeTFBatch32" - SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975/T5_LARGE_FP32_TF_512XI32_BATCH32/stablehlo.mlirbc" + SOURCE_URL "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j/T5_LARGE_FP32_TF_512XI32_BATCH32/stablehlo.mlirbc" OUTPUT "${ROOT_ARTIFACTS_DIR}/model_T5LargeTFBatch32.mlirbc" UNPACK ) diff --git a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake index 1cbbde3e43f8..dafca14c0fe0 100644 --- a/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake +++ b/tests/e2e/test_artifacts/generated_e2e_test_iree_artifacts.cmake @@ -246,7 +246,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -260,7 +260,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -274,7 +274,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -288,7 +288,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -538,7 +538,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -553,7 +553,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -568,7 +568,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -583,7 +583,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -838,7 +838,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -853,7 +853,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -868,7 +868,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -883,7 +883,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -1093,7 +1093,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -1105,7 +1105,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -1117,7 +1117,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -1129,7 +1129,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -1275,7 +1275,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2120,7 +2120,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2138,7 +2138,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2156,7 +2156,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2174,7 +2174,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_no-dt_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2492,7 +2492,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2511,7 +2511,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2530,7 +2530,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2549,7 +2549,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__experimental-flags_dt-only_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2872,7 +2872,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2891,7 +2891,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2910,7 +2910,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -2929,7 +2929,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_dt-uk_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" @@ -3195,7 +3195,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_EfficientNetV2STF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_EfficientNetV2STF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -3211,7 +3211,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -3227,7 +3227,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertForMaskedLMTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertForMaskedLMTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -3243,7 +3243,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_BertLargeTF.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_BertLargeTF_stablehlo___cuda-sm_80-linux_gnu-cuda__default-flags_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=cuda" @@ -3433,7 +3433,7 @@ iree_bytecode_module( iree_bytecode_module( NAME "iree-module-MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_" - SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734.mlirbc" + SRC "${ROOT_ARTIFACTS_DIR}/model_MiniLML12H384Uncased.timestamp_1683504734j.mlirbc" MODULE_FILE_NAME "${ROOT_ARTIFACTS_DIR}/iree_module_MiniLML12H384Uncased_stablehlo___riscv_64-generic-linux_gnu-llvm_cpu__default-flags_compile-stats_/module.vmfb" FLAGS "--iree-hal-target-backends=llvm-cpu" diff --git a/tests/microbenchmarks/stablehlo_fft_abs.mlir b/tests/microbenchmarks/stablehlo_fft_abs.mlir index 3ba9274439f5..cdb5f7c1d933 100644 --- a/tests/microbenchmarks/stablehlo_fft_abs.mlir +++ b/tests/microbenchmarks/stablehlo_fft_abs.mlir @@ -5,7 +5,7 @@ func.func @rfft_abs_6x1024() -> tensor<6x513xf32> { %input = util.unfoldable_constant dense<1.0> : tensor<6x1024xf32> %0 = "stablehlo.fft"(%input) { - fft_length = dense<1024> : tensor<1xi64>, + fft_length = array, fft_type = #stablehlo } : (tensor<6x1024xf32>) -> tensor<6x513xcomplex> %1 = "stablehlo.abs"(%0) : (tensor<6x513xcomplex>) -> tensor<6x513xf32> diff --git a/third_party/stablehlo b/third_party/stablehlo index 6b1ebdbfa70e..f8dcebfa1ec1 160000 --- a/third_party/stablehlo +++ b/third_party/stablehlo @@ -1 +1 @@ -Subproject commit 6b1ebdbfa70ef9ce794f41e4fd1b0839191164c9 +Subproject commit f8dcebfa1ec166806974f6ae0dfb902d36b47238