"Papper link https://arxiv.org/abs/1706.03762"
import tensorflow as tf
from tensorflow.keras.layers import *
batch_size = 32
seq_len = 128
d_k = 256
d_v = 256
n_heads = 12
ff_dim = 256


class SingleAttention(Layer):
    """A single scaled dot-product attention head."""

    def __init__(self, d_k, d_v):
        super(SingleAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v

    def build(self, input_shape):
        # Linear projections for queries, keys and values
        self.query = Dense(self.d_k,
                           kernel_initializer='glorot_uniform',
                           bias_initializer='glorot_uniform')
        self.key = Dense(self.d_k,
                         kernel_initializer='glorot_uniform',
                         bias_initializer='glorot_uniform')
        self.value = Dense(self.d_v,
                           kernel_initializer='glorot_uniform',
                           bias_initializer='glorot_uniform')

    def call(self, inputs):  # inputs = (in_seq, in_seq, in_seq)
        q = self.query(inputs[0])
        k = self.key(inputs[1])

        # Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v
        attn_weights = tf.matmul(q, k, transpose_b=True)
        attn_weights = attn_weights / np.sqrt(self.d_k)
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        v = self.value(inputs[2])
        attn_out = tf.matmul(attn_weights, v)
        return attn_out
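

# Illustrative sketch (not in the original file): a quick shape check for a
# single attention head on random data, assuming an input feature size of 7
# (the value noted in MultiAttention.build below).
def _single_attention_shape_check():
    dummy = tf.random.normal((batch_size, seq_len, 7))
    out = SingleAttention(d_k, d_v)((dummy, dummy, dummy))
    # q k^T / sqrt(d_k) -> (32, 128, 128); softmax(...) @ v -> (32, 128, 256)
    return out.shape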


class MultiAttention(Layer):
    """Multi-head attention: n_heads SingleAttention heads plus a final linear projection."""

    def __init__(self, d_k, d_v, n_heads):
        super(MultiAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.attn_heads = list()

    def build(self, input_shape):
        for n in range(self.n_heads):
            self.attn_heads.append(SingleAttention(self.d_k, self.d_v))

        # Project the concatenated heads back to the input feature size.
        # input_shape[0] = (batch, seq_len, 7), so input_shape[0][-1] = 7
        self.linear = Dense(input_shape[0][-1],
                            kernel_initializer='glorot_uniform',
                            bias_initializer='glorot_uniform')

    def call(self, inputs):
        attn = [self.attn_heads[i](inputs) for i in range(self.n_heads)]
        concat_attn = tf.concat(attn, axis=-1)
        multi_linear = self.linear(concat_attn)
        return multi_linear
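

# Usage sketch (illustrative, not part of the original file): run random data
# through one multi-head attention block, assuming an input feature dimension
# of 7 as noted in MultiAttention.build above.
if __name__ == "__main__":
    in_seq = tf.random.normal((batch_size, seq_len, 7))
    mha = MultiAttention(d_k, d_v, n_heads)
    out_seq = mha((in_seq, in_seq, in_seq))
    print("MultiAttention output shape:", out_seq.shape)  # (32, 128, 7)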