Shuan Chen

PhD Student in KAIST CBE

Attention Mechanism

Attention is all you need, and it's still valid!

Attention is all you need

The attention mechanism rose to fame through a legendary paper, Attention Is All You Need, published by Google Brain at NIPS 2017. Since then, people in machine learning have used it in almost every field, and the paper has been cited over 30K times so far. So what is this attention, really?

Intuition of attention

Let’s start from an example:

Shuan is writing a paper in a KAIST office.

About this sentence we may ask several questions:

  1. Where is Shuan?
  2. What is Shuan doing?
  3. Who is Shuan?

To answer these three questions, we focus on different parts of the sentence. Let’s try to highlight the important terms for each question.

  1. Shuan is writing a paper in a KAIST office. —> Shuan is in a KAIST office.
  2. Shuan is writing a paper in a KAIST office. —> Shuan is writing paper.
  3. Shuan is writing a paper in a KAIST office. —> Maybe a graduate student in KAIST?

For a specific question, we focus on a specific part of the text instead of reading the whole sentence. This act of focusing on, or assigning importance to, part of the input is what is called attention here.

Attention in machine learning

It may sound impossible for a machine to understand such a concept as ATTENTION, but the team at Google Brain actually did it, and that’s why the paper is so legendary.
So how does the machine understand ATTENTION?

In machine learning (or PyTorch), we always represent words as tensors through tokenization and word embedding.
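As a rough sketch of what tokenization and word embedding mean, the toy example below turns the example sentence into integer ids and then into small vectors. It uses plain Python lists instead of PyTorch tensors so every value can be inspected. The whitespace tokenizer, the vocabulary, the embedding size `d_model = 4` and the embedding values are all assumptions made for illustration; a real model learns the embedding values, and in PyTorch the lookup table would typically be an `nn.Embedding` layer.

```python
# Toy tokenizer and embedding table (all values are made up for illustration).
sentence = "Shuan is writing a paper in a KAIST office"

# Tokenization: split on whitespace and give each distinct word an integer id.
tokens = sentence.split()
vocab = {}
for word in tokens:
    if word not in vocab:
        vocab[word] = len(vocab)
token_ids = [vocab[word] for word in tokens]

# Word embedding: one small vector per vocabulary entry.
# A real model learns these numbers; here they are deterministic placeholders.
d_model = 4
embedding_table = [
    [0.1 * (i + 1) * (k + 1) for k in range(d_model)]
    for i in range(len(vocab))
]

# Look up one vector per token: the sentence becomes a 9 x 4 list of lists.
x = [embedding_table[i] for i in token_ids]

print(token_ids)          # [0, 1, 2, 3, 4, 5, 3, 6, 7]
print(len(x), len(x[0]))  # 9 4
```

Note that both occurrences of "a" share the id 3, so they start from the same vector.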

Now, we want the tensor of “Shuan” (Tensor1) to carry important information such as “writing” (Tensor3) or “KAIST” (Tensor7), so we add the values of “writing” and “KAIST” to the tensor of “Shuan”.

So the tensor of Shuan now has the information from writing and KAIST, which makes it possible to answer the above three questions.

But how does the machine know which tensors’ values should go to Shuan (such as writing or KAIST) and which should not (such as is, a, in)? The answer is to use keys and queries.

Key and Query

To measure how important one word (tensor) is to another, the attention mechanism computes an importance score (e) as the dot product of two tensors, a query tensor and a key tensor.

$$e_{ij} = Q_i \cdot K_j$$

where the queries and keys are usually obtained by a linear transformation of the original tensor:

$$Q_i = \mathrm{Linear}_Q(x_i)$$

$$K_i = \mathrm{Linear}_K(x_i)$$
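To make the query and key computation concrete, here is a small worked example in plain Python. The three token vectors and the weight matrices `W_q` and `W_k` are hypothetical values standing in for learned parameters, and the helpers `linear` and `dot` are written only for this illustration.

```python
def linear(weight, vec):
    """Bias-free linear layer: multiply a list-of-lists matrix by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(p * q for p, q in zip(a, b))


# Three toy token vectors x_1..x_3 with d_model = 2 (made-up values).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Hypothetical weights standing in for the learned query and key layers.
W_q = [[1.0, 0.0], [0.0, 2.0]]
W_k = [[0.0, 1.0], [1.0, 0.0]]

Q = [linear(W_q, xi) for xi in x]  # [[1, 0], [0, 2], [1, 2]]
K = [linear(W_k, xi) for xi in x]  # [[0, 1], [1, 0], [1, 1]]

# e[i][j] = Q_i . K_j: how strongly token i should look at token j.
e = [[dot(qi, kj) for kj in K] for qi in Q]
for row in e:
    print(row)
# [0.0, 1.0, 1.0]
# [2.0, 0.0, 2.0]
# [2.0, 1.0, 3.0]
```

Row i of `e` holds the raw scores of token i against every token, including itself.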

The model learns the importance by training two linear layers, the query layer and the key layer.
Next, the attention score is computed from e through a softmax function:

$$a_{ij} = \frac{\exp(e_{ij})}{\sum_{k} \exp(e_{ik})}$$
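As a quick numeric check of the softmax step, the sketch below normalizes one row of importance scores in plain Python. The input row `e_i` is a made-up example, and subtracting the maximum before exponentiating is a common numerical-stability habit that does not change the result.

```python
import math


def softmax(scores):
    """Turn a list of raw scores into positive weights that sum to 1."""
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [v / total for v in exps]


e_i = [2.0, 1.0, 3.0]  # importance of tokens 1..3 for token i (toy values)
a_i = softmax(e_i)
print([round(a, 3) for a in a_i])  # [0.245, 0.09, 0.665]
```

The largest raw score gets the largest weight, but every token still keeps a small, non-zero share.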

The attention scores are then obtained, as shown below.

Finally, the tensor is updated by adding the sum of the values weighted by the attention scores:

$$x'_i = x_i + \sum_{j} a_{ij} V_j, \quad \text{where } V_j = \mathrm{Linear}_V(x_j)$$
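The update rule can be traced by hand with a few numbers. In this sketch the attention weights `a_i`, the value vectors `V` and the token vector `x_i` are all made-up values; in a real model `V` comes from the learned value layer.

```python
# Attention weights of token i over three tokens (toy values that sum to 1).
a_i = [0.2, 0.1, 0.7]

# Value vectors V_1..V_3 (in a real model, V_j comes from a learned linear layer).
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

# Current representation of token i.
x_i = [0.5, 0.5]

# Weighted sum of the values: sum_j a_ij * V_j, computed per dimension.
context = [sum(a * v[k] for a, v in zip(a_i, V)) for k in range(len(x_i))]

# Residual update: x'_i = x_i + sum_j a_ij * V_j.
x_new = [xk + ck for xk, ck in zip(x_i, context)]

print(context)  # approximately [1.6, 1.5]
print(x_new)    # approximately [2.1, 2.0]
```

Because the third token has the largest weight, its value vector dominates the information added to token i.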

Attention in pytorch

It’s actually quite easy to implement in PyTorch:

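```python
import torch
from torch import nn


class Attention(nn.Module):
    """Single-head self-attention with a residual connection."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        # Bias-free projections that produce queries, values and keys
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        k = self.k_linear(x)
        q = self.q_linear(x)
        v = self.v_linear(x)
        # e_ij = Q_i . K_j for every pair of positions: (batch, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-1, -2))
        # a_ij: softmax over the key index j
        scores = torch.softmax(scores, dim=-1)
        # Weighted sum of the values, then the residual x_i + sum_j a_ij V_j
        output = torch.matmul(self.dropout(scores), v)
        return output + x
```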

Most people add a few more tricks, such as multi-head attention and layer normalization, to make it perform better:

```python
import math

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a residual connection and layer norm."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.d_model = d_model
        self.d_k = d_model // heads  # size of each head
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def reset_parameters(self):
        # Optional: Xavier initialization for all weight matrices
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def attention(self, q, k, v):
        # Scale by sqrt(d_k) so the scores do not grow with the head size
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = torch.softmax(scores, dim=-1)
        output = torch.matmul(self.dropout(scores), v)
        return scores, output

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        bs = x.size(0)
        # Split d_model into h heads of size d_k: (batch, seq_len, h, d_k)
        k = self.k_linear(x).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(x).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(x).view(bs, -1, self.h, self.d_k)
        # Move the head axis forward: (batch, h, seq_len, d_k)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        scores, output = self.attention(q, k, v)
        # Merge the heads back: (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = output + x
        output = self.layer_norm(output)
        return output
```
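To see what the head split in the forward pass does to a single token, the plain-Python sketch below cuts one projected vector into consecutive chunks of size `d_k` and computes one scaled score per head. The sizes (`d_model = 8`, `heads = 2`), the vectors `q` and `k`, and the helper `split_heads` are assumptions chosen for illustration only.

```python
import math

d_model, heads = 8, 2
d_k = d_model // heads  # each head works on 4 of the 8 dimensions

# One token's projected query and key vectors (made-up values).
q = [float(i) for i in range(d_model)]  # [0.0, 1.0, ..., 7.0]
k = [1.0] * d_model


def split_heads(vec, heads):
    """Cut a d_model-sized vector into `heads` consecutive chunks of size d_k."""
    size = len(vec) // heads
    return [vec[h * size:(h + 1) * size] for h in range(heads)]


q_heads = split_heads(q, heads)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
k_heads = split_heads(k, heads)

# Each head computes its own scaled dot-product score.
scores = [
    sum(a * b for a, b in zip(qh, kh)) / math.sqrt(d_k)
    for qh, kh in zip(q_heads, k_heads)
]
print(scores)  # [3.0, 11.0]
```

Each head sees a different slice of the representation, so the two heads can assign different importance to the same pair of tokens.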

External Link

I strongly recommend this video and this blog for a deeper understanding of the attention mechanism and the transformer architecture, both with nicely made animations!