Self Attention block

Self-attention layer

In self-attention from "Attention is all you need":

  • [B] Batch dimension = 64

  • [T] Sequence length = 32 (to be confirmed)

  • [De] Embedding length (or d_model) = 512

  • [heads] = 8

  • key_dim = value_dim = De/heads = 64
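
For reference, each head computes the scaled dot-product attention from the paper:

Attention(Q, K, V) = softmax(Q K^T / sqrt(key_dim)) V

With key_dim = 64 the scaling factor is 1/sqrt(64) = 0.125, which shows up as a constant in the IR below.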

Starting point, Python code:

import tensorflow as tf
from tensorflow.python.pywrap_mlir import import_graphdef
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

class MultiHeadAttention(tf.Module):
    def __init__(self):
        super().__init__()
        # 8 heads with key_dim = 64; heads * key_dim = De = 512.
        self.model = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64, use_bias=False)

    @tf.function(jit_compile=True)
    def forward(self, query, key, value):
        return self.model(query, key, value, training=False)

T = 32    # sequence length
de = 512  # embedding length (d_model)
dv = 64   # key_dim = value_dim = de / heads

query = tf.keras.Input(shape=[T, de], batch_size=64, dtype=tf.float32)
value = query
key = query

model = MultiHeadAttention()
# Trace to a concrete function and freeze the variables so the weights
# become constants in the exported GraphDef.
concrete_func = model.forward.get_concrete_function(tf.TensorSpec(shape=query.shape, dtype=tf.float32),
                                                    tf.TensorSpec(shape=value.shape, dtype=tf.float32),
                                                    tf.TensorSpec(shape=key.shape, dtype=tf.float32))
concrete_func = convert_variables_to_constants_v2(concrete_func)
graph = concrete_func.graph.as_graph_def()

# Import the frozen GraphDef into the MLIR tf dialect.
mlir_tf = import_graphdef(
        graph,
        "tf-standard-pipeline",
        False,
        input_names=["query", "value", "key"],
        input_data_types=["DT_FLOAT", "DT_FLOAT", "DT_FLOAT"],
        input_data_shapes=["64,32,512", "64,32,512", "64,32,512"],
        output_names=["Identity:0"]
    )
print(mlir_tf)

The resulting IR in the tf dialect (excerpt):

%0 = "tf.Identity"(%cst_3) {_has_manual_control_dependencies = true, device = ""} : (tensor<8x64x512xf32>) -> tensor<8x64x512xf32>
%1 = "tf.Identity"(%cst_1) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>
%2 = "tf.Identity"(%cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>
%3 = "tf.Identity"(%cst) {_has_manual_control_dependencies = true, device = ""} : (tensor<512x8x64xf32>) -> tensor<512x8x64xf32>

// projection Wv
// V matrix: [B x T x De] [De x heads x value_dim] => [B x T x heads x value_dim]
%4 = "tf.Einsum"(%arg2, %3) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>

// projection Wq
// Q matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim]
%5 = "tf.Einsum"(%arg0, %2) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>

// Scaling factor on Q: https://github.com/keras-team/keras/blob/v2.12.0/keras/layers/attention/multi_head_attention.py#L523
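// %cst_2 = 1 / sqrt(key_dim) = 1 / sqrt(64) = 0.125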
%6 = "tf.Mul"(%5, %cst_2) {device = ""} : (tensor<64x32x8x64xf32>, tensor<f32>) -> tensor<64x32x8x64xf32>

// projection Wk
// K matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim]
%7 = "tf.Einsum"(%arg1, %1) {device = "", equation = "abc,cde->abde"} : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>

// Dot product between query and key
// [B x T x heads x key_dim] [B x T x heads x key_dim] => [B x heads x T x T]
%8 = "tf.Einsum"(%7, %6) {device = "", equation = "aecd,abcd->acbe"} : (tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) -> tensor<64x8x32x32xf32>
%9 = "tf.Softmax"(%8) {device = ""} : (tensor<64x8x32x32xf32>) -> tensor<64x8x32x32xf32>
%10 = "tf.Identity"(%9) {device = ""} : (tensor<64x8x32x32xf32>) -> tensor<64x8x32x32xf32>

// Dot product between softmax result and value
// [B x heads x T x T] [B x T x heads x value_dim] => [B x T x heads x value_dim]
%11 = "tf.Einsum"(%10, %4) {device = "", equation = "acbe,aecd->abcd"} : (tensor<64x8x32x32xf32>, tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>

// projection Wo
// [B x T x heads x value_dim] [heads x value_dim x De] => [B x T x De]
%12 = "tf.Einsum"(%11, %0) {device = "", equation = "abcd,cde->abe"} : (tensor<64x32x8x64xf32>, tensor<8x64x512xf32>) -> tensor<64x32x512xf32>
%13 = "tf.Identity"(%12) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
return %13 : tensor<64x32x512xf32>

Full code with weights: https://gist.github.com/chelini/6e5aabc3b4dd5bac4a0c66bb96b4be2e
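
To double-check the einsum equations, the same chain can be reproduced in plain NumPy (a minimal sketch with random weights; the names Wq, Wk, Wv, Wo are ours, not from the Keras code):

import numpy as np

B, T, De, heads, key_dim = 64, 32, 512, 8, 64
rng = np.random.default_rng(0)

x = rng.standard_normal((B, T, De)).astype(np.float32)
Wq = rng.standard_normal((De, heads, key_dim)).astype(np.float32)
Wk = rng.standard_normal((De, heads, key_dim)).astype(np.float32)
Wv = rng.standard_normal((De, heads, key_dim)).astype(np.float32)
Wo = rng.standard_normal((heads, key_dim, De)).astype(np.float32)

v = np.einsum("abc,cde->abde", x, Wv)          # [B, T, heads, value_dim]
q = np.einsum("abc,cde->abde", x, Wq) / 8.0    # [B, T, heads, key_dim], scaled by 1/sqrt(key_dim)
k = np.einsum("abc,cde->abde", x, Wk)          # [B, T, heads, key_dim]

scores = np.einsum("aecd,abcd->acbe", k, q)    # [B, heads, T, T]
scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax over keys
probs = np.exp(scores)
probs /= probs.sum(axis=-1, keepdims=True)

ctx = np.einsum("acbe,aecd->abcd", probs, v)   # [B, T, heads, value_dim]
out = np.einsum("abcd,cde->abe", ctx, Wo)      # [B, T, De]
assert out.shape == (B, T, De)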

The previous code converted to StableHLO using tf-opt -tf-lower-to-mlprogram-and-hlo (excerpt):

%7 = stablehlo.einsum %arg2, %3, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%8 = stablehlo.einsum %arg0, %4, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%9 = stablehlo.multiply %8, %0 : tensor<64x32x8x64xf32>
%10 = stablehlo.einsum %arg1, %5, config = "abc,cde->abde" : (tensor<64x32x512xf32>, tensor<512x8x64xf32>) -> tensor<64x32x8x64xf32>
%11 = stablehlo.einsum %10, %9, config = "aecd,abcd->acbe" : (tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) -> tensor<64x8x32x32xf32>
%12 = stablehlo.reduce(%11 init: %2) across dimensions = [3] : (tensor<64x8x32x32xf32>, tensor<f32>) -> tensor<64x8x32xf32>
    reducer(%arg3: tensor<f32>, %arg4: tensor<f32>)  {
    %23 = stablehlo.maximum %arg3, %arg4 : tensor<f32>
    stablehlo.return %23 : tensor<f32>
}
%13 = stablehlo.reshape %12 : (tensor<64x8x32xf32>) -> tensor<64x8x32x1xf32>
%14 = stablehlo.broadcast_in_dim %13, dims = [0, 1, 2, 3] : (tensor<64x8x32x1xf32>) -> tensor<64x8x32x32xf32>
%15 = stablehlo.subtract %11, %14 : tensor<64x8x32x32xf32>
%16 = stablehlo.exponential %15 : tensor<64x8x32x32xf32>
%17 = stablehlo.reduce(%16 init: %1) across dimensions = [3] : (tensor<64x8x32x32xf32>, tensor<f32>) -> tensor<64x8x32xf32>
     reducer(%arg3: tensor<f32>, %arg4: tensor<f32>)  {
      %23 = stablehlo.add %arg3, %arg4 : tensor<f32>
      stablehlo.return %23 : tensor<f32>
}
%18 = stablehlo.reshape %17 : (tensor<64x8x32xf32>) -> tensor<64x8x32x1xf32>
%19 = stablehlo.broadcast_in_dim %18, dims = [0, 1, 2, 3] : (tensor<64x8x32x1xf32>) -> tensor<64x8x32x32xf32>
%20 = stablehlo.divide %16, %19 : tensor<64x8x32x32xf32>
%21 = stablehlo.einsum %20, %7, config = "acbe,aecd->abcd" : (tensor<64x8x32x32xf32>, tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>
%22 = stablehlo.einsum %21, %6, config = "abcd,cde->abe" : (tensor<64x32x8x64xf32>, tensor<8x64x512xf32>) -> tensor<64x32x512xf32>
return %22 : tensor<64x32x512xf32>
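
The reduce / subtract / exponential / reduce / divide sequence above is the numerically stable softmax over the last (key) dimension:

softmax(x)_i = exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j)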

IR emitted by TF when converting to Linalg: https://gist.github.com/chelini/393ee9334a1a21a1e5781ddafcd003fa

See the full IR here: https://gist.github.com/chelini/358f4fabf2c7683184cce9a329513ca2. The Linalg IR below was printed using IREE:

// -----// IR Dump After ConvertStableHloToIreeInputDialects (iree-stablehlo-to-iree-input) //----- //
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
module {
  func.func @main(%arg0: tensor<64x32x512xf32>, %arg1: tensor<64x32x512xf32>, %arg2: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> {
    %cst = arith.constant -0.000000e+00 : f32
    %cst_0 = arith.constant 0xFF800000 : f32
    %cst_1 = arith.constant 0.000000e+00 : f32
    %cst_2 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
    %cst_3 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
    %cst_4 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
    %cst_5 = arith.constant dense_resource<__elided__> : tensor<512x512xf32>
    %cst_6 = arith.constant dense<1.250000e-01> : tensor<64x32x8x64xf32>
    
    // heads * key_dim = De, key_dim = value_dim
    // V matrix: [B x T x De] [De x heads x value_dim] => [B x T x heads x value_dim]
    // V matrix: [[B x T] x De] [De x [heads x value_dim]] = [[B x T] x [heads x value_dim]]
    // V matrix: [B x T x heads x value_dim]
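    // Since B*T = 64*32 = 2048 and heads * value_dim = 8*64 = 512, the projection is a plain [2048 x 512] x [512 x 512] matmul: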
    %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
    %0 = tensor.empty() : tensor<2048x512xf32>
    %1 = linalg.fill ins(%cst_1 : f32) outs(%0 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %2 = linalg.matmul ins(%collapsed, %cst_2 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%1 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %expanded = tensor.expand_shape %2 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
    
    // Q matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim] heads * key_dim = De
    // Q matrix: [[B x T] x De] [De x [heads x key_dim]] = [[B x T] x [heads x key_dim]]
    // Q matrix: [B x T x heads x key_dim]
    %collapsed_7 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
    %3 = tensor.empty() : tensor<2048x512xf32>
    %4 = linalg.fill ins(%cst_1 : f32) outs(%3 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %5 = linalg.matmul ins(%collapsed_7, %cst_3 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%4 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %expanded_8 = tensor.expand_shape %5 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
    
    // Scaling factor on Q
    %6 = tensor.empty() : tensor<64x32x8x64xf32>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_8, %cst_6 : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) outs(%6 : tensor<64x32x8x64xf32>) {
    ^bb0(%in: f32, %in_21: f32, %out: f32):
      %46 = arith.mulf %in, %in_21 : f32
      linalg.yield %46 : f32
    } -> tensor<64x32x8x64xf32>
    
    // K matrix: [B x T x De] [De x heads x key_dim] => [B x T x heads x key_dim] heads * key_dim = De
    // K matrix: [[B x T] x De] [De x [heads x key_dim]] = [[B x T]x[heads x key_dim]]
    // K matrix: [B x T x heads x key_dim]
    %collapsed_9 = tensor.collapse_shape %arg1 [[0, 1], [2]] : tensor<64x32x512xf32> into tensor<2048x512xf32>
    %8 = tensor.empty() : tensor<2048x512xf32>
    %9 = linalg.fill ins(%cst_1 : f32) outs(%8 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %10 = linalg.matmul ins(%collapsed_9, %cst_4 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%9 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %expanded_10 = tensor.expand_shape %10 [[0, 1], [2, 3]] : tensor<2048x512xf32> into tensor<64x32x8x64xf32>
    
    // K matrix: [B x T x heads x key_dim] =>  [B x heads x T x key_dim]
    %11 = tensor.empty() : tensor<64x8x32x64xf32>
    %12 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_10 : tensor<64x32x8x64xf32>) outs(%11 : tensor<64x8x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x32x64xf32>
    
    // Q matrix: [B x T x heads x key_dim] => [B x heads x key_dim x T]
    %13 = tensor.empty() : tensor<64x8x64x32xf32>
    %14 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<64x32x8x64xf32>) outs(%13 : tensor<64x8x64x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x64x32xf32>

    // K matrix: [B x heads x T x key_dim] => [[B x heads] x T x key_dim]
    // Q matrix: [B x heads x key_dim x T] => [[B x heads] x key_dim x T]
    %collapsed_11 = tensor.collapse_shape %12 [[0, 1], [2], [3]] : tensor<64x8x32x64xf32> into tensor<512x32x64xf32>
    %collapsed_12 = tensor.collapse_shape %14 [[0, 1], [2], [3]] : tensor<64x8x64x32xf32> into tensor<512x64x32xf32>
    
    // [[B x heads] x T x key_dim] * [[B x heads] x key_dim x T] => [[B x heads] x T x T]
    %15 = tensor.empty() : tensor<512x32x32xf32>
    %16 = linalg.fill ins(%cst_1 : f32) outs(%15 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
    %17 = linalg.batch_matmul ins(%collapsed_11, %collapsed_12 : tensor<512x32x64xf32>, tensor<512x64x32xf32>) outs(%16 : tensor<512x32x32xf32>) -> tensor<512x32x32xf32>
    
    // [[B x heads] x T x T] => [B x heads x T x T], then transpose the two T dimensions
    %expanded_13 = tensor.expand_shape %17 [[0, 1], [2], [3]] : tensor<512x32x32xf32> into tensor<64x8x32x32xf32>
    %18 = tensor.empty() : tensor<64x8x32x32xf32>
    %19 = linalg.generic {indexing_maps = [#map3, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_13 : tensor<64x8x32x32xf32>) outs(%18 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x32x32xf32>
    
    // SOFTMAX
    %20 = tensor.empty() : tensor<64x8x32xf32>
    %21 = linalg.fill ins(%cst_0 : f32) outs(%20 : tensor<64x8x32xf32>) -> tensor<64x8x32xf32>
    %22 = linalg.generic {indexing_maps = [#map, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%19 : tensor<64x8x32x32xf32>) outs(%21 : tensor<64x8x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %46 = arith.maxf %out, %in : f32
      linalg.yield %46 : f32
    } -> tensor<64x8x32xf32>
    %expanded_14 = tensor.expand_shape %22 [[0], [1], [2, 3]] : tensor<64x8x32xf32> into tensor<64x8x32x1xf32>
    %23 = tensor.empty() : tensor<64x8x32x32xf32>
    %24 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_14 : tensor<64x8x32x1xf32>) outs(%23 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x32x32xf32>
    %25 = tensor.empty() : tensor<64x8x32x32xf32>
    %26 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%19, %24 : tensor<64x8x32x32xf32>, tensor<64x8x32x32xf32>) outs(%25 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %in_21: f32, %out: f32):
      %46 = arith.subf %in, %in_21 : f32
      linalg.yield %46 : f32
    } -> tensor<64x8x32x32xf32>
    %27 = tensor.empty() : tensor<64x8x32x32xf32>
    %28 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%26 : tensor<64x8x32x32xf32>) outs(%27 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %46 = math.exp %in : f32
      linalg.yield %46 : f32
    } -> tensor<64x8x32x32xf32>
    %29 = tensor.empty() : tensor<64x8x32xf32>
    %30 = linalg.fill ins(%cst : f32) outs(%29 : tensor<64x8x32xf32>) -> tensor<64x8x32xf32>
    %31 = linalg.generic {indexing_maps = [#map, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%28 : tensor<64x8x32x32xf32>) outs(%30 : tensor<64x8x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      %46 = arith.addf %out, %in : f32
      linalg.yield %46 : f32
    } -> tensor<64x8x32xf32>
    %expanded_15 = tensor.expand_shape %31 [[0], [1], [2, 3]] : tensor<64x8x32xf32> into tensor<64x8x32x1xf32>
    %32 = tensor.empty() : tensor<64x8x32x32xf32>
    %33 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_15 : tensor<64x8x32x1xf32>) outs(%32 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x32x32xf32>
    %34 = tensor.empty() : tensor<64x8x32x32xf32>
    %35 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%28, %33 : tensor<64x8x32x32xf32>, tensor<64x8x32x32xf32>) outs(%34 : tensor<64x8x32x32xf32>) {
    ^bb0(%in: f32, %in_21: f32, %out: f32):
      %46 = arith.divf %in, %in_21 : f32
      linalg.yield %46 : f32
    } -> tensor<64x8x32x32xf32>

    // V matrix: [B x T x heads x value_dim] => [B x heads x T x value_dim]
    %36 = tensor.empty() : tensor<64x8x32x64xf32>
    %37 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<64x32x8x64xf32>) outs(%36 : tensor<64x8x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x32x64xf32>
    
    // [B x heads x T x T] => [[B x heads] x T x T]
    %collapsed_16 = tensor.collapse_shape %35 [[0, 1], [2], [3]] : tensor<64x8x32x32xf32> into tensor<512x32x32xf32>
    // [B x heads x T x value_dim] => [[B x heads] x T x value_dim]
    %collapsed_17 = tensor.collapse_shape %37 [[0, 1], [2], [3]] : tensor<64x8x32x64xf32> into tensor<512x32x64xf32>
    %38 = tensor.empty() : tensor<512x32x64xf32>
    %39 = linalg.fill ins(%cst_1 : f32) outs(%38 : tensor<512x32x64xf32>) -> tensor<512x32x64xf32>
    
    // [[B x heads] x T x T] * [[B x heads] x T x value_dim] => [[B x heads] x T x value_dim]
    %40 = linalg.batch_matmul ins(%collapsed_16, %collapsed_17 : tensor<512x32x32xf32>, tensor<512x32x64xf32>) outs(%39 : tensor<512x32x64xf32>) -> tensor<512x32x64xf32>
    
    // [[B x heads] x T x value_dim] => [B x heads x T x value_dim]
    %expanded_18 = tensor.expand_shape %40 [[0, 1], [2], [3]] : tensor<512x32x64xf32> into tensor<64x8x32x64xf32>
   
    // [B x heads x T x value_dim] => [B x T x heads x value_dim]
    %41 = tensor.empty() : tensor<64x32x8x64xf32>
    %42 = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_18 : tensor<64x8x32x64xf32>) outs(%41 : tensor<64x32x8x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x32x8x64xf32>

    // [B x T x heads x value_dim] => [[B x T] x [heads x value_dim]]
    %collapsed_19 = tensor.collapse_shape %42 [[0, 1], [2, 3]] : tensor<64x32x8x64xf32> into tensor<2048x512xf32>
    
    // [[B x T] x [heads x value_dim]] * Wo [[heads x value_dim] x De] => [[B x T] x De]
    %43 = tensor.empty() : tensor<2048x512xf32>
    %44 = linalg.fill ins(%cst_1 : f32) outs(%43 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    %45 = linalg.matmul ins(%collapsed_19, %cst_5 : tensor<2048x512xf32>, tensor<512x512xf32>) outs(%44 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
    
    // [[B x T] x De] => [B x T x De]
    %expanded_20 = tensor.expand_shape %45 [[0, 1], [2]] : tensor<2048x512xf32> into tensor<64x32x512xf32>
    return %expanded_20 : tensor<64x32x512xf32>
  }
}

Note:

linalg.matmul and linalg.batch_matmul are emitted by IREE; see these patterns: https://github.com/openxla/iree/blob/11388b8f54620c1f960aa9eb7a8d573e3b6d1335/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgDotProd.cpp#L274. The transformations are sane, since we want two levels of parallelism: heads and batch. But rewriting to named ops introduces collapse and expand ops, which are difficult to propagate through. We decided to use the input from TensorFlow instead.

Optimization, high-level idea

  • Tile and fuse along the batch and the head dimensions.
  • Map contractions to GEMMs by mapping the innermost m, n and k dimensions (see the sketch below).
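
A runnable NumPy illustration of the target structure (a sketch, not generated code; Q, K and V are already-projected per-head tensors in [B, heads, T, key_dim] layout):

import numpy as np

B, heads, T, key_dim = 64, 8, 32, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((B, heads, T, key_dim)).astype(np.float32)
K = rng.standard_normal((B, heads, T, key_dim)).astype(np.float32)
V = rng.standard_normal((B, heads, T, key_dim)).astype(np.float32)
O = np.empty_like(V)

for b in range(B):          # parallel
    for h in range(heads):  # parallel
        S = Q[b, h] @ K[b, h].T             # [T x T] GEMM: m = n = T, k = key_dim
        S -= S.max(axis=-1, keepdims=True)  # row-wise stable softmax
        P = np.exp(S)
        P /= P.sum(axis=-1, keepdims=True)
        O[b, h] = P @ V[b, h]               # [T x value_dim] GEMM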

Packing

Assume we pack each linalg.matmul using a blocking factor of 32. The code after packing is shown below:

// unpack operation right after the first packed matmul
%unpack = tensor.unpack %5 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<8x16x32x32xf32> -> tensor<256x512xf32>
// expand shape in the original layout
%expanded = tensor.expand_shape %unpack [[0, 1], [2, 3]] : tensor<256x512xf32> into tensor<64x4x8x64xf32>
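
For intuition, tensor.pack with inner_dims_pos = [0, 1] and inner_tiles = [32, 32] is just a reshape plus transpose. A minimal NumPy sketch (assuming the tile sizes divide the dimensions evenly):

import numpy as np

def pack_2d(a, t0=32, t1=32):
    # [M, N] -> [M/t0, N/t1, t0, t1]: outer block indices first, tiles last.
    M, N = a.shape
    return a.reshape(M // t0, t0, N // t1, t1).transpose(0, 2, 1, 3)

def unpack_2d(p):
    # Inverse of pack_2d: [Mb, Nb, t0, t1] -> [Mb*t0, Nb*t1].
    Mb, Nb, t0, t1 = p.shape
    return p.transpose(0, 2, 1, 3).reshape(Mb * t0, Nb * t1)

a = np.arange(256 * 512, dtype=np.float32).reshape(256, 512)
p = pack_2d(a)  # shape (8, 16, 32, 32), matching the tensor.unpack source above
assert np.array_equal(unpack_2d(p), a)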

It is unclear how to propagate through the tensor.unpack. A possible solution is to try to sink the expand_shape. Is this the right abstraction for running the packing propagation? The expand_shape ops are a consequence of how IREE rewrites the IR; perhaps we should avoid applying these patterns in the first place?

Packing overhead

----------------------------------------------------------------
Benchmark                      Time             CPU   Iterations
----------------------------------------------------------------
BM_transpose_tpp          173069 ns       173066 ns         4046
BM_pack_tpp               173082 ns       173079 ns         4044
BM_unpack_tpp             591061 ns       591031 ns         1110
BM_mha_projection_tpp   12095330 ns     12095010 ns           58
BM_brgemm_tpp            9632356 ns      9631671 ns           72

BM_transpose_tpp is a linalg.transpose lowered using tpp.identity, and BM_pack_tpp is a pack operation lowered using tpp.identity. BM_unpack_tpp is an unpack operation lowered using a tpp.identity and a linalg.copy; the latter is needed to preserve DPS (destination-passing style). BM_brgemm_tpp is a BRGEMM operation lowered using tpp.brgemm, while BM_mha_projection_tpp is an entire projection (BRGEMM + tpp.zero as init + pack on the A operand and unpack on the output).

Additional sources:

Layer norm from torch:

https://gist.github.com/silvasean/b9c5f60dfbe3f51cf079bad3c76d095a

Gelu

import tensorflow as tf
from tensorflow.python.pywrap_mlir import import_graphdef
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

class SimpleGelu(tf.Module):
    def __init__(self):
        super().__init__()

    @tf.function(jit_compile=True)
    def forward(self, x):
        # approximate=False selects the exact erf-based GELU (no tanh approximation).
        return tf.keras.activations.gelu(x, approximate=False)

x = tf.keras.Input(shape=[32, 512], batch_size=64, dtype=tf.float32)

model = SimpleGelu()
concrete_func = model.forward.get_concrete_function(tf.TensorSpec(shape=x.shape, dtype=tf.float32))
concrete_func = convert_variables_to_constants_v2(concrete_func)
graph = concrete_func.graph.as_graph_def()

mlir_tf = import_graphdef(
        graph,
        "tf-standard-pipeline",
        False,
        input_names=["x"],
        input_data_types=["DT_FLOAT"],
        input_data_shapes=["64,32,512"],
        output_names=["Identity:0"]
    )
print(mlir_tf)

Translates to:

module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1395 : i32}} {
  func.func @main(%arg0: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "x", outputs = "Identity:0"}} {
    %cst = "tf.Const"() {value = dense<0.707106769> : tensor<f32>} : () -> tensor<f32>
    %cst_0 = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
    %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
    %0 = "tf.Mul"(%arg0, %cst_1) {device = ""} : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
    %1 = "tf.Mul"(%arg0, %cst) : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
    %2 = "tf.Erf"(%1) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
    %3 = "tf.AddV2"(%2, %cst_0) {device = ""} : (tensor<64x32x512xf32>, tensor<f32>) -> tensor<64x32x512xf32>
    %4 = "tf.Mul"(%0, %3) {device = ""} : (tensor<64x32x512xf32>, tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
    %5 = "tf.Identity"(%4) {device = ""} : (tensor<64x32x512xf32>) -> tensor<64x32x512xf32>
    return %5 : tensor<64x32x512xf32>
  }
}
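
These ops compute the exact (non-approximate) GELU:

gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), with 1/sqrt(2) ≈ 0.707106769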

Lowering to Linalg using: ./tf-opt gelu.mlir -tf-lower-to-mlprogram-and-hlo -stablehlo-legalize-to-hlo -hlo-legalize-to-linalg -canonicalize -cse

We get the following (the lowering expands tf.Erf into a clamped rational polynomial approximation, hence the long chain of elementwise ops):

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  func.func @main(%arg0: tensor<64x32x512xf32>) -> tensor<64x32x512xf32> {
    %cst = arith.constant dense<1.000000e+00> : tensor<64x32x512xf32>
    %cst_0 = arith.constant dense<0.707106769> : tensor<64x32x512xf32>
    %cst_1 = arith.constant dense<5.000000e-01> : tensor<64x32x512xf32>
    %cst_2 = arith.constant dense<-1.000000e+00> : tensor<64x32x512xf32>
    %cst_3 = arith.constant dense<-0.0142647391> : tensor<64x32x512xf32>
    %cst_4 = arith.constant dense<-0.00737332925> : tensor<64x32x512xf32>
    %cst_5 = arith.constant dense<-0.00168282702> : tensor<64x32x512xf32>
    %cst_6 = arith.constant dense<-2.13374049E-4> : tensor<64x32x512xf32>
    %cst_7 = arith.constant dense<-1.45660715E-5> : tensor<64x32x512xf32>
    %cst_8 = arith.constant dense<-0.0160960332> : tensor<64x32x512xf32>
    %cst_9 = arith.constant dense<-2.954600e-03> : tensor<64x32x512xf32>
    %cst_10 = arith.constant dense<-7.34990637E-4> : tensor<64x32x512xf32>
    %cst_11 = arith.constant dense<-5.69250624E-5> : tensor<64x32x512xf32>
    %cst_12 = arith.constant dense<-2.10102394E-6> : tensor<64x32x512xf32>
    %cst_13 = arith.constant dense<2.77068146E-8> : tensor<64x32x512xf32>
    %cst_14 = arith.constant dense<-2.72614237E-10> : tensor<64x32x512xf32>
    %cst_15 = arith.constant dense<4.000000e+00> : tensor<64x32x512xf32>
    %cst_16 = arith.constant dense<-4.000000e+00> : tensor<64x32x512xf32>
    %0 = tensor.empty() : tensor<64x32x512xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst_1 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst_0 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_16, %2, %cst_15 : tensor<64x32x512xf32>, tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %in_18: f32, %out: f32):
      %30 = arith.maxf %in, %in_17 : f32
      %31 = arith.minf %30, %in_18 : f32
      linalg.yield %31 : f32
    } -> tensor<64x32x512xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %3 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %cst_14 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5, %cst_13 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %cst_12 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%9, %cst_11 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %11 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%11, %cst_10 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %13 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %14 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%13, %cst_9 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%14, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %16 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %cst_8 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %17 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %cst_7 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %18 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%17, %cst_6 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%19, %cst_5 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %21 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%20, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %22 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%21, %cst_4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %23 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%22, %4 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %24 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%23, %cst_3 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %25 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%3, %16 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %26 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%25, %24 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.divf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %27 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_2, %26, %cst : tensor<64x32x512xf32>, tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %in_18: f32, %out: f32):
      %30 = arith.maxf %in, %in_17 : f32
      %31 = arith.minf %30, %in_18 : f32
      linalg.yield %31 : f32
    } -> tensor<64x32x512xf32>
    %28 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%27, %cst : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.addf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    %29 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %28 : tensor<64x32x512xf32>, tensor<64x32x512xf32>) outs(%0 : tensor<64x32x512xf32>) {
    ^bb0(%in: f32, %in_17: f32, %out: f32):
      %30 = arith.mulf %in, %in_17 : f32
      linalg.yield %30 : f32
    } -> tensor<64x32x512xf32>
    return %29 : tensor<64x32x512xf32>
  }
}