-
Notifications
You must be signed in to change notification settings - Fork 31
Self Attention block
In self-attention from "Attention is all you need":
-
[B] Batch dimension = 64
-
[T] Sequence length = 32 (to be confirmed)
-
[De] Embedding length (or d_model) = 512
-
[heads] = 8
-
key_dim = value_dim = De/heads = 64
Starting point, Python code:
import tensorflow as tf
import keras_nlp
from tensorflow.python.pywrap_mlir import import_graphdef
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
class MultiHeadAttention(tf.Module):
    """Wrap a Keras ``MultiHeadAttention`` layer (8 heads, key_dim 64, no bias).

    ``forward`` is decorated with ``tf.function(jit_compile=True)`` so that a
    concrete function can be traced and exported as a GraphDef below.
    """

    def __init__(self):
        super().__init__()
        self.model = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64, use_bias=False)

    @tf.function(jit_compile=True)
    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value`` (inference mode).

        Keras' ``MultiHeadAttention.call`` takes ``(query, value, key)`` in that
        positional order, so the tensors are passed by keyword: passing them
        positionally in ``(query, key, value)`` order would silently swap the
        key and value tensors.
        """
        return self.model(query=query, value=value, key=key, training=False)


# Problem sizes: sequence length T, embedding size De (d_model); batch is 64.
T = 32
de = 512
dv = 64  # per-head value size (De / heads); kept for reference, unused below.
query = tf.keras.Input(shape=[T, de], batch_size=64, dtype=tf.float32)
# Self-attention: query, key and value are the same tensor.
value = query
key = query
model = MultiHeadAttention()
# Specs are listed in the same order as forward's parameters (query, key, value).
concrete_func = model.forward.get_concrete_function(
    tf.TensorSpec(shape=query.shape, dtype=tf.float32),
    tf.TensorSpec(shape=key.shape, dtype=tf.float32),
    tf.TensorSpec(shape=value.shape, dtype=tf.float32),
)
# Fold the layer's variables into constants so the GraphDef is self-contained.
concrete_func = convert_variables_to_constants_v2(concrete_func)
graph = concrete_func.graph.as_graph_def()
mlir_tf = import_graphdef(
    graph,
    "tf-standard-pipeline",
    False,
    # Placeholders are presumably named after forward's parameters; this list is
    # kept in that same order so %arg0/%arg1/%arg2 map to query/key/value.
    input_names=["query", "key", "value"],
    input_data_types=["DT_FLOAT", "DT_FLOAT", "DT_FLOAT"],
    input_data_shapes=["64,32,512", "64,32,512", "64,32,512"],
    output_names=["Identity:0"],
)
print(mlir_tf)
The same code in the TF dialect:
%0 = "tf.Identity"(%cst_3) {_has_manual_control_dependencies = true, device = ""} : (tensor<8x64x512xf32>) -> tensor<8x64x512xf32>
%1 = "tf.Identity"(%cst_1) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>
%2 = "tf.Identity"(%cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>
%3 = "tf.Identity"(%cst) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>
// projection Wv
// V matrix: [B x T x De] [De x heads x value_dim] => [B x T x heads x value_dim]
%4 = "tf.Einsum"(%arg2, %3) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
// projection Wq
// Q matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim]
%5 = "tf.Einsum"(%arg0, %2) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
// Scaling factor on Q: https://github.com/keras-team/keras/blob/v2.12.0/keras/layers/attention/multi_head_attention.py#L523
%6 = "tf.Mul"(%5, %cst_2) {device = ""} : (tensor<64x32x8x64xf32>, tensor<f32>) -> tensor<64x32x8x64xf32>
// projection Wk
// K matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim]
%7 = "tf.Einsum"(%arg1, %1) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
// Dot product between query and key
// [B x T x heads x key_dim] [B x T x heads x key_dim] => [B x heads x T x T]
%8 = "tf.Einsum"(%7, %6) {device = "", equation = "aecd,abcd->acbe"} : (tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) -> tensor<64x8x32x32xf32>
%9 = "tf.Softmax"(%8) {device = ""} : (tensor<64x8x32x32xf32>) -> tensor<64x8x32x32xf32>
%10 = "tf.Identity"(%9) {device = ""} : (tensor<64x8x32x32xf32>) -> tensor<64x8x32x32xf32>
// Dot product between softmax result and value
// [B x heads x T x T] [B x T x heads x value_dim] => [B x T x heads x value_dim]
%11 = "tf.Einsum"(%10, %4) {device = "", equation = "acbe,aecd->abcd"} : (tensor<64x8x32x32xf32>, tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>
// projection Wo
// [B x T x heads x value_dim] [heads x value_dim x De] => [B x T x De]
%12 = "tf.Einsum"(%11, %0) {device = "", equation = "abcd,cde->abe"} : (tensor<64x32x8x64xf32>, tensor<8x64x512xf32>) -> tensor<64x32x512xf32>
%13 = "tf.Identity"(%12) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
return %13 : tensor<64x32x512xf32>
full code with weights: https://gist.github.com/chelini/6e5aabc3b4dd5bac4a0c66bb96b4be2e
The previous code converted to StableHLO using `tf-opt -tf-lower-to-mlprogram-and-hlo`:
%7 = stablehlo.einsum %arg2, %3, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%8 = stablehlo.einsum %arg0, %4, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%9 = stablehlo.multiply %8, %0 : tensor<64x32x8x64xf32>
%10 = stablehlo.einsum %arg1, %5, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%11 = stablehlo.einsum %10, %9, config = "aecd,abcd->acbe" : (tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) -> tensor<64x8x32x32xf32>
%12 = stablehlo.reduce(%11 init: %2) across dimensions = [3] : (tensor<64x8x32x32xf32>, tensor<f32>) -> tensor<64x8x32xf32>
reducer(%arg3: tensor<f32>, %arg4: tensor<f32>) {
%23 = stablehlo.maximum %arg3, %arg4 : tensor<f32>
stablehlo.return %23 : tensor<f32>
}
%13 = stablehlo.reshape %12 : (tensor<64x8x32xf32>) -> tensor<64x8x32x1xf32>
%14 = stablehlo.broadcast_in_dim %13, dims = [0, 1, 2, 3] : (tensor<64x8x32x1xf32>) -> tensor<64x8x32x32xf32>
%15 = stablehlo.subtract %11, %14 : tensor<64x8x32x32xf32>
%16 = stablehlo.exponential %15 : tensor<64x8x32x32xf32>
%17 = stablehlo.reduce(%16 init: %1) across dimensions = [3] : (tensor<64x8x32x32xf32>, tensor<f32>) -> tensor<64x8x32xf32>
reducer(%arg3: tensor<f32>, %arg4: tensor<f32>) {
%23 = stablehlo.add %arg3, %arg4 : tensor<f32>
stablehlo.return %23 : tensor<f32>
}
%18 = stablehlo.reshape %17 : (tensor<64x8x32xf32>) -> tensor<64x8x32x1xf32>
%19 = stablehlo.broadcast_in_dim %18, dims = [0, 1, 2, 3] : (tensor<64x8x32x1xf32>) -> tensor<64x8x32x32xf32>
%20 = stablehlo.divide %16, %19 : tensor<64x8x32x32xf32>
%21 = stablehlo.einsum %20, %7, config = "acbe,aecd->abcd" : (tensor<64x8x32x32xf32>, tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>
%22 = stablehlo.einsum %21, %6, config = "abcd,cde->abe" : (tensor<64x32x8x64xf32>, tensor<8x64x512xf32>) -> tensor<64x32x512xf32>
return %22 : tensor<64x32x512xf32>
IR emitted by TF converting to linalg: https://gist.github.com/chelini/393ee9334a1a21a1e5781ddafcd003fa
See the full IR here: https://gist.github.com/chelini/358f4fabf2c7683184cce9a329513ca2
Linalg IR printed using IREE:
// -----// IR Dump After ConvertStableHloToIreeInputDialects (iree-stablehlo-to-iree-input) //----- //
// Attention block after IREE's StableHLO -> Linalg conversion.
// The Q/K/V and output projections become 2-D linalg.matmul ops on
// [[B x T] x De] views. The two attention contractions (Q.K^T and
// softmax.V) become linalg.batch_matmul ops on [[B x heads] x ...] views.
// The tensor.collapse_shape / tensor.expand_shape ops and the copying
// linalg.generic ops (with permuted input maps) implement those reshapes
// and transposes.
// Shapes: B = 64, T = 32, De = 512, heads = 8, key_dim = value_dim = 64.
//
// Indexing-map legend:
//   #map:  identity.
//   #map1: swap dims 1 and 2.
//   #map2: read the input at (d0, d3, d1, d2).
//   #map3: swap the last two dims.
//   #map4: drop the last (reduced) dim.
//   #map5: read a size-1 last dim, i.e. broadcast along d3.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
module {
func.func @main(%arg0: tensor<64x32x512xf32>, %arg1: tensor<64x32x512xf32>, %arg2: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> {
// Reduction/accumulator init values used below:
//   %cst   = -0.0: init of the softmax sum.
//   %cst_0 = 0xFF800000 (f32 -inf): init of the softmax max.
//   %cst_1 = 0.0: init of every matmul output.
%cst = arith.constant -0.000000e+00 : f32
%cst_0 = arith.constant 0xFF800000 : f32
%cst_1 = arith.constant 0.000000e+00 : f32
// Projection weights, each [De x (heads * key_dim)] = 512x512 (values elided):
//   %cst_2: V projection (applied to %arg2).
//   %cst_3: Q projection (applied to %arg0).
//   %cst_4: K projection (applied to %arg1).
//   %cst_5: output projection Wo.
%cst_2 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
%cst_3 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
%cst_4 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
%cst_5 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
// Q scaling factor 0.125 = 1/sqrt(key_dim), materialized at the full Q shape.
%cst_6 = arith.constant dense<1.250000e-01> : tensor<64x32x8x64xf32>
// heads * key_dim = De, key_dim = value_dim
// V matrix: [B x T x De] [De x heads x value_dim] => [B x T x heads x value_dim]
// V matrix: [[B x T] x De] [De x [heads x value_dim]] = [[B x T] x [heads x value_dim]]
// V matrix: [B x T x heads x value_dim]
%collapsed = tensor.collapse_shape %arg2 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
%0 = tensor.empty() : tensor<2048x512xf32>
%1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%2 = linalg.matmul ins(%collapsed, %cst_2 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%1 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%expanded = tensor.expand_shape %2 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
// Q matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim] heads * key_dim = De
// Q matrix: [[B x T] x De] [De x [heads x key_dim]] = [[B x T] x [heads x key_dim]]
// Q matrix: [B x T x heads x key_dim]
%collapsed_7 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
%3 = tensor.empty() : tensor<2048x512xf32>
%4 = linalg.fill ins(%cst_1 : f32) outs(%3 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%5 = linalg.matmul ins(%collapsed_7, %cst_3 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%4 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%expanded_8 = tensor.expand_shape %5 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
// Scaling factor on Q: elementwise Q * 0.125
%6 = tensor.empty() : tensor<64x32x8x64xf32>
%7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_8, %cst_6 : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) outs(%6 : tensor<64x32x8x64xf32>) {
^bb0(%in: f32, %in_21: f32, %out: f32):
%46 = arith.mulf %in, %in_21 : f32
linalg.yield %46 : f32
} -> tensor<64x32x8x64xf32>
// K matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim] heads * key_dim = De
// K matrix: [[B x T] x De] [De x [heads x key_dim]] = [[B x T]x[heads x key_dim]]
// K matrix: [B x T x heads x key_dim]
%collapsed_9 = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
%8 = tensor.empty() : tensor<2048x512xf32>
%9 = linalg.fill ins(%cst_1 : f32) outs(%8 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%10 = linalg.matmul ins(%collapsed_9, %cst_4 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%9 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%expanded_10 = tensor.expand_shape %10 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
// K matrix: [B x T x heads x key_dim] => [B x heads x T x key_dim]
%11 = tensor.empty() : tensor<64x8x32x64xf32>
%12 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_10 : tensor<64x32x8x64xf32>) outs(%11 : tensor<64x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x32x64xf32>
// Q matrix: [B x T x heads x key_dim] => [B x heads x key_dim x T]
%13 = tensor.empty() : tensor<64x8x64x32xf32>
%14 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<64x32x8x64xf32>) outs(%13 : tensor<64x8x64x32xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x64x32xf32>
// K matrix: [B x heads x T x key_dim] => [[B x heads] x T x key_dim]
// Q matrix: [B x heads x key_dim x T] => [[B x heads] x key_dim x T]
%collapsed_11 = tensor.collapse_shape %12 [[0, 1], [2], [3]] : tensor<64x8x32x64xf32> into tensor<512x32x64xf32>
%collapsed_12 = tensor.collapse_shape %14 [[0, 1], [2], [3]] : tensor<64x8x64x32xf32> into tensor<512x64x32xf32>
// [[B x heads] x T x key_dim] * [[B x heads] x key_dim x T] => [[B x heads] x T x T]
%15 = tensor.empty() : tensor<512x32x32xf32>
%16 = linalg.fill ins(%cst_1 : f32) outs(%15 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
%17 = linalg.batch_matmul ins(%collapsed_11, %collapsed_12 : tensor<512x32x64xf32>, tensor<512x64x32xf32>) outs(%16 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
// Expand [[B x heads] x T_k x T_q] => [B x heads x T_k x T_q]. Then swap the
// last two dims (#map3), since K is the left batch_matmul operand. The result
// %19 is [B x heads x T_q x T_k].
%expanded_13 = tensor.expand_shape %17 [[0, 1], [2], [3]] : tensor<512x32x32xf32> into tensor<64x8x32x32xf32>
%18 = tensor.empty() : tensor<64x8x32x32xf32>
%19 = linalg.generic {indexing_maps = [#map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_13 : tensor<64x8x32x32xf32>) outs(%18 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x32x32xf32>
// SOFTMAX over the last dim of %19. Steps:
//   1. Row max (%22, init -inf).
//   2. Broadcast the max (%24) and subtract it (%26).
//   3. exp (%28).
//   4. Row sum (%31, init -0.0).
//   5. Broadcast the sum (%33) and divide by it (%35).
%20 = tensor.empty() : tensor<64x8x32xf32>
%21 = linalg.fill ins(%cst_0 : f32) outs(%20 : tensor<64x8x32xf32>) -> tensor<64x8x32xf32>
%22 = linalg.generic {indexing_maps = [#map, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%19 : tensor<64x8x32x32xf32>) outs(%21 : tensor<64x8x32xf32>) {
^bb0(%in: f32, %out: f32):
%46 = arith.maxf %out, %in : f32
linalg.yield %46 : f32
} -> tensor<64x8x32xf32>
%expanded_14 = tensor.expand_shape %22 [[0], [1], [2, 3]] : tensor<64x8x32xf32> into tensor<64x8x32x1xf32>
%23 = tensor.empty() : tensor<64x8x32x32xf32>
%24 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_14 : tensor<64x8x32x1xf32>) outs(%23 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x32x32xf32>
%25 = tensor.empty() : tensor<64x8x32x32xf32>
%26 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%19, %24 : tensor<64x8x32x32xf32>, tensor<64x8x32x32xf32>) outs(%25 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %in_21: f32, %out: f32):
%46 = arith.subf %in, %in_21 : f32
linalg.yield %46 : f32
} -> tensor<64x8x32x32xf32>
%27 = tensor.empty() : tensor<64x8x32x32xf32>
%28 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%26 : tensor<64x8x32x32xf32>) outs(%27 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %out: f32):
%46 = math.exp %in : f32
linalg.yield %46 : f32
} -> tensor<64x8x32x32xf32>
%29 = tensor.empty() : tensor<64x8x32xf32>
%30 = linalg.fill ins(%cst : f32) outs(%29 : tensor<64x8x32xf32>) -> tensor<64x8x32xf32>
%31 = linalg.generic {indexing_maps = [#map, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%28 : tensor<64x8x32x32xf32>) outs(%30 : tensor<64x8x32xf32>) {
^bb0(%in: f32, %out: f32):
%46 = arith.addf %out, %in : f32
linalg.yield %46 : f32
} -> tensor<64x8x32xf32>
%expanded_15 = tensor.expand_shape %31 [[0], [1], [2, 3]] : tensor<64x8x32xf32> into tensor<64x8x32x1xf32>
%32 = tensor.empty() : tensor<64x8x32x32xf32>
%33 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_15 : tensor<64x8x32x1xf32>) outs(%32 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x32x32xf32>
%34 = tensor.empty() : tensor<64x8x32x32xf32>
%35 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%28, %33 : tensor<64x8x32x32xf32>, tensor<64x8x32x32xf32>) outs(%34 : tensor<64x8x32x32xf32>) {
^bb0(%in: f32, %in_21: f32, %out: f32):
%46 = arith.divf %in, %in_21 : f32
linalg.yield %46 : f32
} -> tensor<64x8x32x32xf32>
// V matrix: [B x T x heads x value_dim] => [B x heads x T x value_dim]
%36 = tensor.empty() : tensor<64x8x32x64xf32>
%37 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<64x32x8x64xf32>) outs(%36 : tensor<64x8x32x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x8x32x64xf32>
// softmax result: [B x heads x T x T] => [[B x heads] x T x T]
%collapsed_16 = tensor.collapse_shape %35 [[0, 1], [2], [3]] : tensor<64x8x32x32xf32> into tensor<512x32x32xf32>
// V matrix: [B x heads x T x value_dim] => [[B x heads] x T x value_dim]
%collapsed_17 = tensor.collapse_shape %37 [[0, 1], [2], [3]] : tensor<64x8x32x64xf32> into tensor<512x32x64xf32>
%38 = tensor.empty() : tensor<512x32x64xf32>
%39 = linalg.fill ins(%cst_1 : f32) outs(%38 : tensor<512x32x64xf32>) -> tensor<512x32x64xf32>
// [[B x heads] x T x T] * [[B x heads] x T x value_dim] => [[B x heads] x T x value_dim]
%40 = linalg.batch_matmul ins(%collapsed_16, %collapsed_17 : tensor<512x32x32xf32>, tensor<512x32x64xf32>) outs(%39 : tensor<512x32x64xf32>) -> tensor<512x32x64xf32>
// [[B x heads] x T x value_dim] => [B x heads x T x value_dim]
%expanded_18 = tensor.expand_shape %40 [[0, 1], [2], [3]] : tensor<512x32x64xf32> into tensor<64x8x32x64xf32>
// [B x heads x T x value_dim] => [B x T x heads x value_dim]
%41 = tensor.empty() : tensor<64x32x8x64xf32>
%42 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_18 : tensor<64x8x32x64xf32>) outs(%41 : tensor<64x32x8x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<64x32x8x64xf32>
// [B x T x heads x value_dim] => [[B x T] x [heads x value_dim]]
%collapsed_19 = tensor.collapse_shape %42 [[0, 1], [2, 3]] : tensor<64x32x8x64xf32> into tensor<2048x512xf32>
// [[B x T] x [heads x value_dim]] * Wo[[heads x value_dim] x De] => [[B x T] x De]
%43 = tensor.empty() : tensor<2048x512xf32>
%44 = linalg.fill ins(%cst_1 : f32) outs(%43 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%45 = linalg.matmul ins(%collapsed_19, %cst_5 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%44 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
// [[B x T] x De] => [B x T x De]
%expanded_20 = tensor.expand_shape %45 [[0, 1], [2]] : tensor<2048x512xf32> into tensor<64x32x512xf32>
return %expanded_20 : tensor<64x32x512xf32>
}
}
Note:
`linalg.matmul` and `linalg.batch_matmul` are emitted by IREE; see these patterns: https://github.com/openxla/iree/blob/11388b8f54620c1f960aa9eb7a8d573e3b6d1335/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgDotProd.cpp#L274
These transformations are sensible, as we want two levels of parallelism: heads and batch. However, rewriting to named ops introduces
collapse and expand ops, which are difficult to propagate through. We decided to use the input from TensorFlow instead.
- Tile and fuse along the batch and the head dimension.
- Map contractions to GEMMs by mapping the most minor m, n and k dimensions.
Assuming we pack each `linalg.matmul` using a blocking factor of 32, the code after packing is shown below:
// unpack operation right after the first packed matmul
%unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<8x16x32x32xf32> -> tensor<256x512xf32>
// expand shape in the original layout
%expanded = tensor.expand_shape %unpack [[0, 1], [2, 3]] : tensor<256x512xf32> into tensor<64x4x8x64xf32>
It is unclear how to propagate through the `tensor.unpack`. A possible solution is to try to sink the `expand_shape`. Is
this the right abstraction for running the packing propagation? The `expand_shape` ops
are a consequence of how IREE rewrites the IR; perhaps we should avoid applying these patterns in the first place?
----------------------------------------------------------------
Benchmark Time CPU Iterations
----------------------------------------------------------------
BM_transpose_tpp 173069 ns 173066 ns 4046
BM_pack_tpp 173082 ns 173079 ns 4044
BM_unpack_tpp 591061 ns 591031 ns 1110
BM_mha_projection_tpp 12095330 ns 12095010 ns 58
BM_brgemm_tpp 9632356 ns 9631671 ns 72
`BM_transpose_tpp` is a `linalg.transpose` lowered using `tpp.identity`, and `BM_pack_tpp` is a pack operation lowered using `tpp.identity`.
`BM_unpack_tpp` is an unpack operation lowered using a `tpp.identity` and a `linalg.copy`; the latter is
needed to preserve destination-passing style (DPS). `BM_brgemm_tpp` is a BRGEMM operation lowered using `tpp.brgemm`,
while `BM_mha_projection_tpp` is an entire projection (BRGEMM + `tpp.zero` as init + pack on the A operand and unpack on the output).
-
ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs
https://gist.github.com/silvasean/b9c5f60dfbe3f51cf079bad3c76d095a
import tensorflow as tf
import keras_nlp
from tensorflow.python.pywrap_mlir import import_graphdef
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
class SimpleGelu(tf.Module):
    """Stateless module exposing the exact (erf-based) GELU as a jit-compiled function."""

    def __init__(self):
        super(SimpleGelu, self).__init__()

    @tf.function(jit_compile=True)
    def forward(self, x):
        """Return GELU(x) using the exact erf formulation (``approximate=False``)."""
        gelu = tf.keras.activations.gelu
        return gelu(x, approximate=False)


# Symbolic Keras input; only its static shape [64, 32, 512] is used below.
x = tf.keras.Input(shape=[32, 512], batch_size=64, dtype=tf.float32)
model = SimpleGelu()
x_spec = tf.TensorSpec(shape=x.shape, dtype=tf.float32)
concrete_func = convert_variables_to_constants_v2(
    model.forward.get_concrete_function(x_spec)
)
graph = concrete_func.graph.as_graph_def()
import_kwargs = dict(
    input_names=["x"],
    input_data_types=["DT_FLOAT"],
    input_data_shapes=["64,32,512"],
    output_names=["Identity:0"],
)
mlir_tf = import_graphdef(graph, "tf-standard-pipeline", False, **import_kwargs)
print(mlir_tf)
Translates to:
// TF-dialect import of SimpleGelu.forward. It computes the exact GELU
//   gelu(x) = 0.5 * x * (1 + erf(x * 0.707106769)),
// where 0.707106769 is 1/sqrt(2) rounded to f32.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1395 : i32}} {
func.func @main(%arg0: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "x", outputs = "Identity:0"}} {
%cst = "tf.Const"() {value = dense<0.707106769> : tensor<f32>} : () -> tensor<f32>
%cst_0 = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
// 0.5 * x
%0 = "tf.Mul"(%arg0, %cst_1) {device = ""} : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
// x / sqrt(2)
%1 = "tf.Mul"(%arg0, %cst) : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
// erf(x / sqrt(2))
%2 = "tf.Erf"(%1) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
// 1 + erf(x / sqrt(2))
%3 = "tf.AddV2"(%2, %cst_0) {device = ""} : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
// 0.5 * x * (1 + erf(x / sqrt(2)))
%4 = "tf.Mul"(%0, %3) {device = ""} : (tensor<64x32x512xf32>, tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
%5 = "tf.Identity"(%4) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
return %5 : tensor<64x32x512xf32>
}
}
Lowering to Linalg using: ./tf-opt gelu.mlir -tf-lower-to-mlprogram-and-hlo -stablehlo-legalize-to-hlo -hlo-legalize-to-linalg -canonicalize -cse
We get:
// Linalg lowering of the GELU above. Every elementwise step is a separate
// linalg.generic (nothing is fused yet), and scalar constants are
// materialized as full 64x32x512 tensors.
// tf.Erf is expanded into a clamped rational polynomial. With
//   t = clamp(x / sqrt(2), -4, 4)  and  z = t * t,
// it computes
//   erf ~= clamp(t * P(z) / Q(z), -1, 1).
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
func.func @main(%arg0: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> {
%cst = arith.constant dense<1.000000e+00> : tensor<64x32x512xf32>
%cst_0 = arith.constant dense<0.707106769> : tensor<64x32x512xf32>
%cst_1 = arith.constant dense<5.000000e-01> : tensor<64x32x512xf32>
%cst_2 = arith.constant dense<-1.000000e+00> : tensor<64x32x512xf32>
// %cst_3 .. %cst_7: coefficients of the denominator Q(z).
// %cst_8 .. %cst_14: coefficients of the numerator P(z).
%cst_3 = arith.constant dense<-0.0142647391> : tensor<64x32x512xf32>
%cst_4 = arith.constant dense<-0.00737332925> : tensor<64x32x512xf32>
%cst_5 = arith.constant dense<-0.00168282702> : tensor<64x32x512xf32>
%cst_6 = arith.constant dense<-2.13374049E-4> : tensor<64x32x512xf32>
%cst_7 = arith.constant dense<-1.45660715E-5> : tensor<64x32x512xf32>
%cst_8 = arith.constant dense<-0.0160960332> : tensor<64x32x512xf32>
%cst_9 = arith.constant dense<-2.954600e-03> : tensor<64x32x512xf32>
%cst_10 = arith.constant dense<-7.34990637E-4> : tensor<64x32x512xf32>
%cst_11 = arith.constant dense<-5.69250624E-5> : tensor<64x32x512xf32>
%cst_12 = arith.constant dense<-2.10102394E-6> : tensor<64x32x512xf32>
%cst_13 = arith.constant dense<2.77068146E-8> : tensor<64x32x512xf32>
%cst_14 = arith.constant dense<-2.72614237E-10> : tensor<64x32x512xf32>
// Clamp bounds for t.
%cst_15 = arith.constant dense<4.000000e+00> : tensor<64x32x512xf32>
%cst_16 = arith.constant dense<-4.000000e+00> : tensor<64x32x512xf32>
%0 = tensor.empty() : tensor<64x32x512xf32>
// %1 = 0.5 * x
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %2 = x * 0.707106769 (x / sqrt(2))
%2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst_0 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %3 = t = min(max(-4, %2), 4), i.e. clamp(x / sqrt(2), -4, 4)
%3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16, %2, %cst_15 : tensor<64x32x512xf32>, tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %in_18: f32, %out: f32):
%30 = arith.maxf %in, %in_17 : f32
%31 = arith.minf %30, %in_18 : f32
linalg.yield %31 : f32
} -> tensor<64x32x512xf32>
// %4 = z = t * t
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %3 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %5 .. %16: numerator P(z), evaluated with Horner's scheme (cN = %cst_N):
//   P(z) = c14*z^6 + c13*z^5 + c12*z^4 + c11*z^3 + c10*z^2 + c9*z + c8
%5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %cst_14 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %cst_13 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%8 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %cst_12 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %cst_11 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%11, %cst_10 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%13 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%14 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13, %cst_9 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%14, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%16 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %cst_8 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %17 .. %24: denominator Q(z), evaluated with Horner's scheme:
//   Q(z) = c7*z^4 + c6*z^3 + c5*z^2 + c4*z + c3
%17 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %cst_7 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%18 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%17, %cst_6 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%19, %cst_5 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%21 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%20, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%22 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%21, %cst_4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%23 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%22, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
%24 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%23, %cst_3 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %25 = t * P(z)
%25 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %16 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %26 = t * P(z) / Q(z): the erf(x / sqrt(2)) approximation
%26 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%25, %24 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.divf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %27 = min(max(-1, %26), 1): clamp the erf approximation to [-1, 1]
%27 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_2, %26, %cst : tensor<64x32x512xf32>, tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %in_18: f32, %out: f32):
%30 = arith.maxf %in, %in_17 : f32
%31 = arith.minf %30, %in_18 : f32
linalg.yield %31 : f32
} -> tensor<64x32x512xf32>
// %28 = 1 + erf(x / sqrt(2))
%28 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%27, %cst : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.addf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
// %29 = 0.5 * x * (1 + erf(x / sqrt(2))) = GELU(x)
%29 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %28 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
^bb0(%in: f32, %in_17: f32, %out: f32):
%30 = arith.mulf %in, %in_17 : f32
linalg.yield %30 : f32
} -> tensor<64x32x512xf32>
return %29 : tensor<64x32x512xf32>
}
}