Bump StableHLO to f8dcebfa1ec166806974f6ae0dfb902d36b47238 (#16049)
Also updated the serialized models, but only by hand, so a naming convention is used to capture that they were manually updated rather than newly generated.
jpienaar authored Jan 5, 2024
1 parent d6dad12 commit 124d562
Showing 29 changed files with 290 additions and 313 deletions.
build_tools/python/e2e_test_framework/models/jax_models.py (1 addition, 1 deletion)
@@ -11,7 +11,7 @@
from e2e_test_framework.definitions import common_definitions
import e2e_test_framework.models.utils as model_utils

-GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180"
+GCS_ARTIFACT_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.14_1691969180j"

ID_FORMAT = string.Template("${model_id}-batch${batch_size}")
NAME_FORMAT = string.Template("${name}_BATCH${batch_size}")
build_tools/python/e2e_test_framework/models/tf_models.py (5 additions, 5 deletions)
@@ -21,7 +21,7 @@
tags=["int32", "seqlen128"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/microsoft/MiniLM-L12-H384-uncased/commit/44acabbec0ef496f6dbc93adadea57f376b7c0ec
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734.mlirbc",
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/MiniLML12H384Uncased_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="predict",
input_types=["1x128xi32", "1x128xi32", "1x128xi32"],
)
@@ -32,7 +32,7 @@
tags=["fp32", "seqlen512", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#tfbertformaskedlm
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734.mlirbc",
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertForMaskedLMTF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="forward",
input_types=["1x512xi32", "1x512xi32"],
)
@@ -43,7 +43,7 @@
tags=["fp32", "cnn", "tensorflow"],
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Converted from https://github.com/keras-team/keras/blob/v2.10.0/keras/applications/efficientnet_v2.py
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734.mlirbc",
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/EfficientNetV2STF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="forward",
input_types=["1x384x384x3xf32"],
)
@@ -56,7 +56,7 @@
source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
# Derived from https://github.com/mlcommons/inference/tree/master/language/bert
# Instructions on how to regenerate the model: https://gist.github.com/mariecwhite/e61ccebd979d98d097946ac7725bcc29
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734.mlirbc",
source_url=f"{TF_MODELS_MANUAL_ROOT_DIR}/BertLargeTF_2023-05-07.timestamp_1683504734j.mlirbc",
entry_function="serving_default",
input_types=["1x384xi32", "1x384xi32", "1x384xi32"],
)
@@ -81,7 +81,7 @@
input_types=["1x1xi32", "12x2x1x12x4x64xf32"],
)

-TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975"
+TF_MODELS_ROOT_DIR = "https://storage.googleapis.com/iree-model-artifacts/tensorflow/tf_models_2.15.0.dev20230817_1692333975j"

ID_FORMAT = string.Template("${model_id}-batch-${batch_size}")
NAME_FORMAT = string.Template("${name}Batch${batch_size}")
@@ -13,6 +13,7 @@
#include "iree/compiler/Utils/IndexSet.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo-iree/Conversion/Rewriters.h"
@@ -448,7 +449,7 @@ static Value emitTranspose(ConversionPatternRewriter &rewriter, Location loc,
llvm::to_vector(llvm::seq<int64_t>(0, inputShape.size()));
std::swap(permutation[srcDim], permutation[dstDim]);
std::swap(inputShape[srcDim], inputShape[dstDim]);
-DenseIntElementsAttr permutationAttr = rewriter.getI64VectorAttr(permutation);
+auto permutationAttr = rewriter.getDenseI64ArrayAttr(permutation);
return rewriter.create<mlir::stablehlo::TransposeOp>(
loc, RankedTensorType::get(inputShape, inputType.getElementType()), input,
permutationAttr);
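For context on the recurring change in these hunks (a sketch, not part of the diff): the bumped StableHLO revision takes permutations and similar dimension lists as DenseI64ArrayAttr rather than a tensor-typed DenseIntElementsAttr. The helper name makePermutationAttr below is hypothetical; only the two builder calls are taken from the diff.

    #include "llvm/ADT/ArrayRef.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinAttributes.h"

    // Before the bump: a tensor-typed i64 elements attribute.
    //   mlir::DenseIntElementsAttr attr = builder.getI64VectorAttr(perm);
    // After the bump: a flat i64 array attribute with no tensor type attached.
    static mlir::DenseI64ArrayAttr
    makePermutationAttr(mlir::OpBuilder &builder, llvm::ArrayRef<int64_t> perm) {
      return builder.getDenseI64ArrayAttr(perm);
    }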
@@ -705,7 +706,7 @@ Value splitAndConcatForAllToAll(ConversionPatternRewriter &rewriter,
result = rewriter.create<mlir::stablehlo::TransposeOp>(
loc,
RankedTensorType::get(transposeResultShape, inputType.getElementType()),
-result, rewriter.getI64VectorAttr(permutation));
+result, rewriter.getDenseI64ArrayAttr(permutation));

// Reshape
llvm::SmallVector<int64_t> finalShape(inputShape);
@@ -852,7 +853,7 @@ struct ReduceScatterOpConversion final
auto inputType = cast<RankedTensorType>(op.getOperand().getType());
SmallVector<int64_t> reduceInputShape(inputType.getShape());
Value reduceInput = adaptor.getOperand();
-DenseIntElementsAttr permutationAttr;
+DenseI64ArrayAttr permutationAttr;

SmallVector<int64_t> scatterResultShape(resultType.getShape());
auto elemType = getElementTypeOrSelf(reduceInput.getType());
@@ -861,7 +862,7 @@
auto permutation =
llvm::to_vector(llvm::seq<int64_t>(0, scatterResultShape.size()));
std::swap(permutation[0], permutation[scatterDim]);
-permutationAttr = rewriter.getI64VectorAttr(permutation);
+permutationAttr = rewriter.getDenseI64ArrayAttr(permutation);
std::swap(reduceInputShape[0], reduceInputShape[scatterDim]);
std::swap(scatterResultShape[0], scatterResultShape[scatterDim]);
// Transpose the input.
@@ -2144,14 +2144,14 @@ struct ConvertTopKOp final : OpConversionPattern<mlir::chlo::TopKOp> {
} else {
values = rewriter.create<mlir::stablehlo::SliceOp>(
op.getLoc(), tupleFirstElement,
-DenseIntElementsAttr::get(indicesTy, beginIndices),
-DenseIntElementsAttr::get(indicesTy, endIndices),
-DenseIntElementsAttr::get(indicesTy, strides));
+rewriter.getDenseI64ArrayAttr(beginIndices),
+rewriter.getDenseI64ArrayAttr(endIndices),
+rewriter.getDenseI64ArrayAttr(strides));
indices = rewriter.create<mlir::stablehlo::SliceOp>(
op.getLoc(), tupleSecondElement,
-DenseIntElementsAttr::get(indicesTy, beginIndices),
-DenseIntElementsAttr::get(indicesTy, endIndices),
-DenseIntElementsAttr::get(indicesTy, strides));
+rewriter.getDenseI64ArrayAttr(beginIndices),
+rewriter.getDenseI64ArrayAttr(endIndices),
+rewriter.getDenseI64ArrayAttr(strides));
}

rewriter.replaceOp(op, {values, indices});
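To make the new SliceOp signature concrete (an illustrative sketch, not part of the diff; rewriter, loc, and a 2-D tensor value input are assumed in scope, and the extents are made up): start, limit, and stride indices are now passed as DenseI64ArrayAttr built directly from int64_t lists, with the result type inferred.

    // Slice the leading 4x8 block of `input` with unit strides, mirroring the
    // calls the TopK lowering above issues for the values and indices.
    mlir::Value slice = rewriter.create<mlir::stablehlo::SliceOp>(
        loc, input,
        rewriter.getDenseI64ArrayAttr({0, 0}),   // start_indices
        rewriter.getDenseI64ArrayAttr({4, 8}),   // limit_indices
        rewriter.getDenseI64ArrayAttr({1, 1}));  // strides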
@@ -41,6 +41,16 @@ namespace {
// allowed to materialize as new constants.
constexpr int64_t kFoldOpEltLimit = 65536;

+static bool isIotaRange(ArrayRef<int64_t> dims) {
+  for (auto [idx, value] : llvm::enumerate(dims)) {
+    if (idx != value) {
+      return false;
+    }
+  }
+
+  return true;
+}
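The new overload exists because dimension lists now arrive as plain ArrayRef<int64_t> (handed out by DenseI64ArrayAttr) instead of ElementsAttr. Illustrative checks of its behavior (hypothetical test code, not part of the diff):

    assert(isIotaRange({0, 1, 2}));   // the identity mapping is an iota range
    assert(!isIotaRange({0, 2, 1}));  // any reordering breaks the iota property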

static bool isIotaRange(ElementsAttr attr) {
auto elems = attr.tryGetValues<APInt>();
if (!elems)
@@ -469,7 +479,7 @@ struct BroadcastInDimOpCanon final
return failure();

// Fold when broadcast is a noop.
-DenseIntElementsAttr dims = op.getBroadcastDimensions();
+auto dims = op.getBroadcastDimensions();
bool isDimsIota = isIotaRange(dims);
if (type == operandTy && isDimsIota) {
rewriter.replaceOp(op, operand);
@@ -485,7 +495,7 @@
return success();
}

-auto bsDimIndices = dims.getValues<int64_t>();
+auto bsDimIndices = dims;
if (operandTy.hasStaticShape() && type.hasStaticShape() &&
type.getNumElements() == operandTy.getNumElements()) {
// BroadcastInDim equivalent to reshape.
@@ -505,12 +515,10 @@
// Eliminate redundant nested BroadcastInDim.
if (auto broadcastInDimOp =
operand.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>()) {
-auto newIndices = cast<DenseIntElementsAttr>(
-    broadcastInDimOp.getBroadcastDimensions().mapValues(
-        dims.getElementType(), [&bsDimIndices](const APInt &dim) {
-          return APInt(dim.getBitWidth(),
-                       bsDimIndices[dim.getSExtValue()], true);
-        }));
+auto newIndices =
+    rewriter.getDenseI64ArrayAttr(llvm::to_vector(llvm::map_range(
+        broadcastInDimOp.getBroadcastDimensions(),
+        [&bsDimIndices](int64_t dim) { return bsDimIndices[dim]; })));
rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
op, type, broadcastInDimOp.getOperand(), newIndices);
return success();
@@ -631,7 +639,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final
// output has static shape, replace with broadcast_in_dim
if (type.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
-op, type, op.getOperand(), op.getBroadcastDimensions());
+op, type, op.getOperand(), op.getBroadcastDimensionsAttr());
return success();
}

@@ -648,7 +656,7 @@ struct DynamicBroadcastInDimOpNotActuallyDynamic final
refineOpWithNewOp<mlir::stablehlo::BroadcastInDimOp>(
rewriter, op,
RankedTensorType::get(outputShape, type.getElementType()),
-op.getOperand(), op.getBroadcastDimensions());
+op.getOperand(), op.getBroadcastDimensionsAttr());
return success();
}
}
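A note on the accessor switch in the two hunks above (a sketch based on the tablegen-generated StableHLO op API; the surrounding pattern state is assumed): with broadcast_dimensions stored as a DenseI64ArrayAttr, the plain accessor returns the element values, while the Attr accessor returns the attribute object, which is what op builders such as replaceOpWithNewOp expect.

    // Value view: iterate or index the dimensions directly.
    llvm::ArrayRef<int64_t> dims = op.getBroadcastDimensions();
    // Attribute view: pass through unchanged when building the new op.
    mlir::DenseI64ArrayAttr dimsAttr = op.getBroadcastDimensionsAttr();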
@@ -670,16 +678,11 @@ struct ChainedDynamicBroadcastInDimCanonicalization final
return failure();

// Compose broadcast dimensions.
-DenseIntElementsAttr precedingBcastDims =
-    precedingBcast.getBroadcastDimensions();
-DenseIntElementsAttr bcastDims = bcast.getBroadcastDimensions();
-SmallVector<APInt> composition;
-for (APInt precedingDim : precedingBcastDims) {
-  composition.push_back(
-      *(bcastDims.value_begin<APInt>() + precedingDim.getZExtValue()));
+SmallVector<int64_t> composition;
+for (int64_t precedingDim : precedingBcast.getBroadcastDimensions()) {
+  composition.push_back(bcast.getBroadcastDimensions()[precedingDim]);
}
-auto composedBcastDims =
-    DenseIntElementsAttr::get(precedingBcastDims.getType(), composition);
+auto composedBcastDims = rewriter.getDenseI64ArrayAttr(composition);

rewriter.replaceOpWithNewOp<mlir::stablehlo::DynamicBroadcastInDimOp>(
bcast, bcast.getType(), precedingBcast.getOperand(),
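The loop above composes the two dimension maps: if the preceding broadcast sends its input dim i to precedingDims[i], and the outer broadcast sends dim j to bcastDims[j], the fused op must send i to bcastDims[precedingDims[i]]. A standalone sketch with made-up values (the helper name composeBroadcastDims is hypothetical):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    // composeBroadcastDims({0, 2}, {1, 2, 3}) == {1, 3}: input dim 0 lands at
    // outer dim 1, and input dim 1 lands at outer dim 3.
    static llvm::SmallVector<int64_t>
    composeBroadcastDims(llvm::ArrayRef<int64_t> preceding,
                         llvm::ArrayRef<int64_t> outer) {
      llvm::SmallVector<int64_t> result;
      result.reserve(preceding.size());
      for (int64_t dim : preceding)
        result.push_back(outer[dim]);
      return result;
    }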
@@ -928,9 +931,9 @@ struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> {
auto sliceType = RankedTensorType::get(sliceShape, elementType);
Value result = rewriter.create<mlir::stablehlo::SliceOp>(
gather.getLoc(), sliceType, gather.getOperand(),
-rewriter.getI64TensorAttr(sliceStart),
-rewriter.getI64TensorAttr(sliceEnd),
-rewriter.getI64TensorAttr(sliceStride));
+rewriter.getDenseI64ArrayAttr(sliceStart),
+rewriter.getDenseI64ArrayAttr(sliceEnd),
+rewriter.getDenseI64ArrayAttr(sliceStride));

ArrayRef<int64_t> collapsedSliceDims = dnums.getCollapsedSliceDims();
if (!collapsedSliceDims.empty()) {
@@ -1030,7 +1033,7 @@ struct TransposeIsReshape final
"tensor type");
}

-SmallVector<int64_t> permValues(permutation.getValues<int64_t>());
+SmallVector<int64_t> permValues(permutation);

SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -47,13 +47,8 @@ Value transposeReshape(Value arg, Location loc,
auto transposePermutation =
llvm::to_vector<5>(llvm::concat<const int64_t>(leftDims, rightDims));

-TensorType transposePermutationType =
-    RankedTensorType::get({static_cast<int64_t>(transposePermutation.size())},
-                          rewriter.getIntegerType(64));

auto transposePermutationAttr =
-    llvm::cast<DenseIntElementsAttr>(DenseIntElementsAttr::get(
-        transposePermutationType, llvm::ArrayRef(transposePermutation)));
+    rewriter.getDenseI64ArrayAttr(transposePermutation);

// Compute the resulting shape.
llvm::SmallVector<int64_t, 5> transposedShape;
Expand Down
@@ -149,7 +149,7 @@ struct EinsumToDotGeneralPattern final
} else {
// Generate a transpose.
rewriter.replaceOpWithNewOp<mlir::stablehlo::TransposeOp>(
-einsum, dotGeneralOp, rewriter.getI64TensorAttr(resultPerms));
+einsum, dotGeneralOp, rewriter.getDenseI64ArrayAttr(resultPerms));
}
return success();
}