Overview
We’ve seen that RNNs and CNNs have longer maximum path lengths than self-attention. CNNs can have better computational complexity than self-attention on long sequences, but overall, self-attention is the best fit for deep architectures. The Transformer depends solely on self-attention and has no convolutional or recurrent layers, unlike its predecessors such as seq2seq [1].
The Transformer was originally proposed for sequence-to-sequence learning on text data, but it has since gained popularity in speech, vision, and reinforcement learning tasks as well.
The Transformer has an encoder-decoder architecture.
- Unlike Bahdanau attention, the input embeddings are added to positional encodings before being fed into the encoder and the decoder.
A good custom implementation is here
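The encoder and decoder code below calls an OGPositionalEncoder module that this section does not define. As a reference point, here is a minimal sketch of the fixed sinusoidal encoding from the original Transformer paper, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). The function name is hypothetical, it assumes an even embedding size, and it is not necessarily what OGPositionalEncoder does.

```python
import math

import torch


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Hypothetical helper: a [max_len, d_model] table of sinusoidal encodings.

    Assumes d_model is even.
    """
    positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
    # 1 / 10000^(2i / d_model) for i = 0 .. d_model / 2 - 1
    inv_freq = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )  # [d_model / 2]
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(positions * inv_freq)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(positions * inv_freq)  # odd dimensions use cosine
    return pe


# Adding the encoding to a batch of embeddings: [batch_size, seq_len, d_model]
embeddings = torch.randn(2, 10, 16)
pe = sinusoidal_positional_encoding(max_len=50, d_model=16)
encoded = embeddings + pe[: embeddings.shape[1]]  # broadcasts over the batch
print(encoded.shape)  # torch.Size([2, 10, 16])
```

Because the table is fixed, it can be computed once for max_sentence_length and sliced to each batch's sequence length.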
Encoder
Each encoder layer has two modules: a multi-head self-attention pooling module and a positionwise feed-forward network (FFN). Some highlights are:
- In the multi-head self-attention pooling, the queries, keys, and values all come from the output of the previous encoder layer.
- Inspired by ResNet, each module is wrapped in a residual connection (or skip connection) that adds the module's input to its output. Because the two are added elementwise, no module in the encoder changes the shape of its input.
- The positionwise FFN transforms the representation at every time step with the same multi-layer perceptron (MLP). It therefore performs no operations across time steps.
- The positionwise FFN may use a different hidden-layer dimension internally, as shown below; only its output dimension has to match its input dimension.
Now, let’s enjoy some code.
Positionwise FFN
```python
import torch


class PositionwiseFFN(torch.nn.Module):
    def __init__(self, hidden_dim, output_dim) -> None:
        super().__init__()
        # LazyLinear infers its input dimension from the first forward pass
        self.dense1 = torch.nn.LazyLinear(hidden_dim)
        self.relu = torch.nn.ReLU()
        self.dense2 = torch.nn.LazyLinear(output_dim)

    def forward(self, X):
        # (batch size, number of time steps, output_dim).
        return self.dense2(self.relu(self.dense1(X)))


ffn = PositionwiseFFN(4, 8)
ffn.eval()
# see (2, 3, 8)
print(ffn(torch.ones((2, 3, 4))).shape)
```
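To make "no time-wise operations" concrete, the sketch below checks that applying the FFN to the whole sequence gives the same result as applying it to each time step separately. It uses a plain torch.nn.Sequential stack as a stand-in for PositionwiseFFN (same Linear, ReLU, Linear structure), with a hidden size of 16 that differs from the input and output size of 4.

```python
import torch

torch.manual_seed(0)
# Stand-in for PositionwiseFFN: Linear -> ReLU -> Linear, hidden size 16
ffn = torch.nn.Sequential(
    torch.nn.Linear(4, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)
X = torch.randn(2, 3, 4)  # [batch_size, num_steps, embedding_dim]

# nn.Linear acts on the last dimension, so every time step is transformed at once
whole = ffn(X)
# Apply the FFN to each time step on its own, then stack the results back
per_step = torch.stack([ffn(X[:, t]) for t in range(X.shape[1])], dim=1)

print(torch.allclose(whole, per_step, atol=1e-6))  # True
```

Since the same weights are reused at every position, the FFN mixes information only across the feature dimension; mixing across time steps is left entirely to attention.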
Encoder Layer
The encoder layer has one multi-head attention module (used as self-attention) and two Add & Norm layers, each of which takes a skip connection. Every module in this layer keeps the last dimension at embedding_dim.
```python
import torch

# MultiHeadAttention is not defined in this section; it is assumed to follow the
# torch.nn.MultiheadAttention interface (sequence-first inputs, returns output and weights).


class EncoderLayer(torch.nn.Module):
    def __init__(
        self,
        embedding_dim,
        num_heads,
        dropout_rate=0.1,
    ) -> None:
        super().__init__()
        # need dropout. The torch implementation already has it
        self.mha = MultiHeadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
        )
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.ffn = PositionwiseFFN(hidden_dim=embedding_dim, output_dim=embedding_dim)
        self.layernorm1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.layernorm2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, X, attn_mask, key_padding_mask):
        # Self-attention: queries, keys and values are all X
        # X: (input_seq_len, batch_size, embedding_dim)
        self_attn_output, self_attn_weight = self.mha(
            X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask
        )
        # Apply dropout to the self-attention output
        self_attn_output = self.dropout1(self_attn_output)
        # Add & Norm: skip connection, then layer normalization
        mult_attn_out = self.layernorm1(
            X + self_attn_output
        )  # (input_seq_len, batch_size, embedding_dim)
        ffn_output = self.ffn(
            mult_attn_out
        )  # (input_seq_len, batch_size, embedding_dim)
        ffn_output = self.dropout2(ffn_output)
        # Add & Norm: skip connection, then layer normalization
        encoder_layer_out = self.layernorm2(
            ffn_output + mult_attn_out
        )  # (input_seq_len, batch_size, embedding_dim)
        return encoder_layer_out
```
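Since MultiHeadAttention is not shown here, one way to sanity-check the expected shapes is PyTorch's built-in torch.nn.TransformerEncoderLayer, which by default (norm_first=False) uses the same post-norm layout: Add & Norm after each module. This is a comparison sketch, not the custom layer above; dim_feedforward=embedding_dim mirrors the FFN hidden size chosen in EncoderLayer.

```python
import torch

embedding_dim, num_heads = 16, 4
# Built-in encoder layer with post-norm ("Add & Norm" after each module)
layer = torch.nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=num_heads,
    dim_feedforward=embedding_dim,
    dropout=0.1,
    norm_first=False,
)
X = torch.randn(5, 2, embedding_dim)  # [input_seq_len, batch_size, embedding_dim]
# True marks padding: the last two tokens of the second sentence
key_padding_mask = torch.tensor([[False] * 5, [False, False, False, True, True]])
out = layer(X, src_key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([5, 2, 16])
```

The output shape equals the input shape, which is what lets encoder layers be stacked in a ModuleList without any reshaping in between.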
Full Encoder
```python
import math

import torch


class Encoder(torch.nn.Module):
    def __init__(
        self,
        embedding_dim,
        input_vocab_dim,
        encoder_layer_num,
        num_heads,
        max_sentence_length,
        dropout_rate=0.1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        # OGPositionalEncoder (positional encoding module) is not defined in this section
        self.positional_encoder = OGPositionalEncoder(
            max_sentence_length=max_sentence_length, embedding_size=self.embedding_dim
        )
        self.embedding_converter = torch.nn.Embedding(
            num_embeddings=input_vocab_dim, embedding_dim=self.embedding_dim
        )
        self.dropout_pre_encoder = torch.nn.Dropout(p=dropout_rate)
        self.encoder_layers = torch.nn.ModuleList(
            [
                EncoderLayer(
                    embedding_dim=self.embedding_dim,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                )
                for _ in range(encoder_layer_num)
            ]
        )

    def forward(self, X, enc_padding_mask):
        # X: [batch_size, input_seq_len] token IDs
        X = self.embedding_converter(X)  # [batch_size, input_seq_len, embedding_dim]
        # Scale the embeddings before adding positional encodings
        X = X * math.sqrt(float(self.embedding_dim))
        X = self.positional_encoder(X)  # adds positional encodings
        X = self.dropout_pre_encoder(X)
        X = X.permute(1, 0, 2)  # [input_seq_len, batch_size, qk_dim]
        for encoder_layer in self.encoder_layers:
            X = encoder_layer(X, attn_mask=None, key_padding_mask=enc_padding_mask)
        X = X.permute(1, 0, 2)  # [batch_size, input_seq_len, qk_dim]
        return X
```
- Scaling: the embeddings are multiplied by $\sqrt{\text{embedding\_dim}}$ before the positional encodings are added, so that the two have comparable magnitudes. There’s a StackExchange thread on why exactly this is needed; however, some have also questioned whether it is necessary at all.
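Whether the scaling keeps magnitudes comparable depends on how the embedding weights are initialized, which may explain why its necessity is debated. The sketch below compares PyTorch's default nn.Embedding initialization, N(0, 1), with an assumed initialization of standard deviation embedding_dim ** -0.5 (chosen here purely for illustration). Sinusoidal positional encodings stay within [-1, 1].

```python
import math

import torch

torch.manual_seed(0)
vocab_size, embedding_dim = 1000, 512
scale = math.sqrt(embedding_dim)

# PyTorch's default: nn.Embedding weights are drawn from N(0, 1)
default_emb = torch.nn.Embedding(vocab_size, embedding_dim)
# Assumption for illustration: weights drawn with std = embedding_dim ** -0.5
small_emb = torch.nn.Embedding(vocab_size, embedding_dim)
torch.nn.init.normal_(small_emb.weight, mean=0.0, std=embedding_dim ** -0.5)

with torch.no_grad():
    default_std = (default_emb.weight * scale).std().item()  # roughly sqrt(512), about 22.6
    small_std = (small_emb.weight * scale).std().item()  # roughly 1.0

# Only the second case leaves the scaled embeddings on the same scale as
# positional encodings bounded in [-1, 1]
print(round(default_std, 1), round(small_std, 1))
```

With the default N(0, 1) initialization, the scaled embeddings dwarf the positional encodings, so the multiplication mainly makes sense when embeddings start out small.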
Decoder Layer
A decoder layer has one multi-head self-attention module, one encoder-decoder attention module, and one positionwise FFN. Like the encoder layer, it does not change the embedding dimension.
```python
import torch

# MultiHeadAttention is not defined in this section; it is assumed to follow the
# torch.nn.MultiheadAttention interface (sequence-first inputs, returns output and weights).


class DecoderLayer(torch.nn.Module):
    def __init__(
        self,
        embedding_dim,
        num_heads,
        dropout_rate=0.1,
    ) -> None:
        super().__init__()
        # need dropout. The torch implementation already has it
        self.mha1 = MultiHeadAttention(  # masked self-attention
            embed_dim=embedding_dim,
            num_heads=num_heads,
        )
        self.mha2 = MultiHeadAttention(  # encoder-decoder attention
            embed_dim=embedding_dim,
            num_heads=num_heads,
        )
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)
        self.dropout3 = torch.nn.Dropout(p=dropout_rate)
        self.ffn = PositionwiseFFN(hidden_dim=embedding_dim, output_dim=embedding_dim)
        self.layernorm1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.layernorm2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.layernorm3 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, X, enc_output, attn_mask, key_padding_mask):
        """
        Args:
            X : embedding from output sequence [output_seq_len, batch_size, qk_dim]
            enc_output : embedding from encoder [input_seq_len, batch_size, qk_dim]
            attn_mask : Boolean mask for the target input to ensure autoregression
            key_padding_mask : Boolean mask for the second multihead attention layer
        Returns:
            decoder output [output_seq_len, batch_size, qk_dim], followed by the
            self-attention and encoder-decoder attention weights
        """
        # Masked self-attention: queries, keys and values all come from the decoder
        self_attn_output, decoder_self_attn_weight = self.mha1(
            X, X, X, attn_mask=attn_mask, key_padding_mask=None
        )
        self_attn_output = self.dropout1(self_attn_output)
        # Add & Norm
        out1 = self.layernorm1(
            X + self_attn_output
        )  # (output_seq_len, batch_size, embedding_dim)
        # Encoder-decoder attention: queries from the decoder, keys/values from the encoder
        cross_attn_output, decoder_encoder_attn_weight = self.mha2(
            out1,
            enc_output,
            enc_output,
            attn_mask=None,
            key_padding_mask=key_padding_mask,
        )
        cross_attn_output = self.dropout2(cross_attn_output)
        # Add & Norm
        out2 = self.layernorm2(
            out1 + cross_attn_output
        )  # (output_seq_len, batch_size, embedding_dim)
        ffn_output = self.ffn(out2)  # (output_seq_len, batch_size, embedding_dim)
        ffn_output = self.dropout3(ffn_output)
        # Add & Norm with the third LayerNorm
        out3 = self.layernorm3(
            ffn_output + out2
        )  # (output_seq_len, batch_size, embedding_dim)
        return out3, decoder_self_attn_weight, decoder_encoder_attn_weight
```
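For comparison, PyTorch's built-in torch.nn.TransformerDecoderLayer has the same three-module layout: a self-attention module (self_attn), an encoder-decoder attention module (multihead_attn), and a separate LayerNorm for each Add & Norm (norm1, norm2, norm3). The sketch below inspects those attributes and checks shapes; it uses the built-in layer as a stand-in, not the DecoderLayer above.

```python
import torch

embedding_dim, num_heads = 16, 4
layer = torch.nn.TransformerDecoderLayer(
    d_model=embedding_dim, nhead=num_heads, dim_feedforward=embedding_dim
)
# One self-attention and one encoder-decoder attention module
print(type(layer.self_attn).__name__, type(layer.multihead_attn).__name__)
# One LayerNorm per Add & Norm step, each a distinct module
norms = [layer.norm1, layer.norm2, layer.norm3]
print(len({id(n) for n in norms}))  # 3

tgt = torch.randn(4, 2, embedding_dim)     # [output_seq_len, batch_size, embedding_dim]
memory = torch.randn(6, 2, embedding_dim)  # [input_seq_len, batch_size, embedding_dim]
out = layer(tgt, memory)
print(out.shape)  # torch.Size([4, 2, 16])
```

The output length follows the target sequence, not the encoder output, because the queries of both attention modules come from the decoder.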
Decoder
Like the encoder, the decoder is a stack of layers, and each decoder layer has residual connections, layer normalizations, two attention pooling modules, and one positionwise FFN module.
```python
import math

import torch


class Decoder(torch.nn.Module):
    def __init__(
        self,
        embedding_dim,
        num_heads,
        target_vocab_dim,
        decoder_layer_num,
        max_sentence_length,
        dropout_rate=0.1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        # OGPositionalEncoder (positional encoding module) is not defined in this section
        self.positional_encoder = OGPositionalEncoder(
            max_sentence_length=max_sentence_length, embedding_size=self.embedding_dim
        )
        self.embedding_converter = torch.nn.Embedding(
            num_embeddings=target_vocab_dim, embedding_dim=self.embedding_dim
        )
        self.dropout_pre_decoder = torch.nn.Dropout(p=dropout_rate)
        self.dec_layers = torch.nn.ModuleList(
            [
                DecoderLayer(
                    embedding_dim=self.embedding_dim,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                )
                for _ in range(decoder_layer_num)
            ]
        )

    def forward(self, X, enc_output, lookahead_mask, key_padding_mask):
        """
        Args:
            X : [batch_size, output_seq_len] token IDs
            enc_output : [batch_size, input_seq_len, qk_dim].
                TODO: This might be a small discrepancy from the torch implementation,
                which is [input_seq_len, batch_size, qk_dim]
            lookahead_mask : [num_queries, num_keys]
            key_padding_mask : [batch_size, num_keys]
        """
        X = self.embedding_converter(X)  # [batch_size, output_seq_len, embedding_dim]
        X = X * math.sqrt(float(self.embedding_dim))
        X = self.positional_encoder(X)  # adds positional encodings
        X = self.dropout_pre_decoder(X)
        X = X.permute(1, 0, 2)  # [output_seq_len, batch_size, qk_dim]
        enc_output = enc_output.permute(1, 0, 2)  # [num_keys, batch_size, qk_dim]
        decoder_self_attns, decoder_encoder_attns = [], []
        for decoder_layer in self.dec_layers:
            X, decoder_self_attn, decoder_encoder_attn = decoder_layer(
                X,
                enc_output,
                attn_mask=lookahead_mask,
                key_padding_mask=key_padding_mask,
            )
            decoder_self_attns.append(decoder_self_attn)
            decoder_encoder_attns.append(decoder_encoder_attn)
        X = X.permute(1, 0, 2)  # [batch_size, output_seq_len, qk_dim]
        return X, decoder_self_attns, decoder_encoder_attns
```
- The first attention module is a self-attention module.
  - Its queries, keys, and values all come from the decoder.
  - It uses a lookahead mask (or attention mask), which preserves the autoregressive property: each prediction depends only on output tokens that have already been generated.
- The attention module between the first self-attention module and the positionwise FFN module is called “encoder-decoder attention”.
  - It uses a padding mask, so that padded encoder positions are ignored.
  - Its queries come from the decoder's self-attention sublayer (after Add & Norm).
  - Its keys and values come from the encoder output.
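The two masks can be built in a couple of lines. The sketch below uses hypothetical helpers, make_lookahead_mask and make_padding_mask, assumes the padding token ID is 0, and uses torch.nn.MultiheadAttention as a stand-in for the custom MultiHeadAttention. It follows PyTorch's boolean convention, where True means "not allowed to attend", and checks that masked positions receive exactly zero attention weight.

```python
import torch


def make_lookahead_mask(seq_len: int) -> torch.Tensor:
    # [num_queries, num_keys]; True above the diagonal blocks attention to future tokens
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)


def make_padding_mask(tokens: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # [batch_size, num_keys]; True marks padding tokens to ignore as keys
    return tokens == pad_id


torch.manual_seed(0)
embedding_dim, num_heads = 8, 2
mha = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads)

# Decoder self-attention over 4 target positions (batch of 1)
tgt = torch.randn(4, 1, embedding_dim)  # [output_seq_len, batch_size, embedding_dim]
_, self_weights = mha(tgt, tgt, tgt, attn_mask=make_lookahead_mask(4))
# self_weights: [batch_size, num_queries, num_keys], averaged over heads
print(self_weights[0])  # the upper triangle is exactly 0

# Encoder-decoder attention: the last two source tokens are padding (ID 0)
src_tokens = torch.tensor([[5, 9, 3, 0, 0]])  # [batch_size, input_seq_len]
memory = torch.randn(5, 1, embedding_dim)     # encoder output
_, cross_weights = mha(tgt, memory, memory, key_padding_mask=make_padding_mask(src_tokens))
print(cross_weights[0, :, 3:])  # weights on padded keys are exactly 0
```

Masked scores are set to negative infinity before the softmax, so the corresponding weights come out as exact zeros rather than small numbers.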
All Together
Phew, what a journey! Good job making it this far. Let’s now put all these pieces together. All code snippets in this post have been tested against their PyTorch counterparts.
```python
import torch


class Transformer(torch.nn.Module):
    def __init__(
        self,
        embedding_dim,
        input_vocab_dim,
        target_vocab_dim,
        layer_num,
        num_heads,
        max_sentence_length,
        dropout_rate=0.1,
    ) -> None:
        super().__init__()
        self.encoder = Encoder(
            embedding_dim=embedding_dim,
            input_vocab_dim=input_vocab_dim,
            encoder_layer_num=layer_num,
            num_heads=num_heads,
            max_sentence_length=max_sentence_length,
            dropout_rate=dropout_rate,
        )
        self.decoder = Decoder(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            target_vocab_dim=target_vocab_dim,
            decoder_layer_num=layer_num,
            max_sentence_length=max_sentence_length,
            dropout_rate=dropout_rate,
        )
        # Maps decoder outputs to target-vocabulary logits. No ReLU or softmax
        # follows, since the raw logits are fed to a cross-entropy loss.
        self.final_dense_layer = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=target_vocab_dim,
            bias=False,
        )

    def forward(
        self,
        input_sentences,
        output_sentences,
        enc_padding_mask,
        attn_mask,
        dec_padding_mask,
    ):
        # input_sentences: [batch_size, input_seq_len]
        # enc_output: [batch_size, input_seq_len, qk_dim]
        enc_output = self.encoder(X=input_sentences, enc_padding_mask=enc_padding_mask)
        # dec_output: [batch_size, output_seq_len, qk_dim]
        dec_output, decoder_self_attns, decoder_encoder_attns = self.decoder(
            X=output_sentences,
            enc_output=enc_output,
            lookahead_mask=attn_mask,
            key_padding_mask=dec_padding_mask,
        )
        # Raw logits, assuming a cross-entropy loss:
        # [batch_size, output_seq_len, target_vocab_dim]
        logits = self.final_dense_layer(dec_output)
        return logits, decoder_self_attns, decoder_encoder_attns
```
- At the end, we want probabilities over the target-language vocabulary, so a softmax is needed for training. However, a ReLU is not advised here: by setting negative logits to 0, it would distort the relative differences between logits. The standard practice is no ReLU between the final Linear layer and the softmax.
- Also, we are assuming a cross-entropy loss. PyTorch's torch.nn.CrossEntropyLoss applies log-softmax internally, so we do not add a softmax layer here and return raw logits instead.
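A minimal sketch of how the raw logits could feed the loss, assuming padding token ID 0 in the targets (a hypothetical choice). torch.nn.functional.cross_entropy expects the class dimension second, so the [batch_size, output_seq_len, target_vocab_dim] logits are permuted; ignore_index keeps padded target positions out of the loss. The second half recomputes the same value by hand to show the built-in log-softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, output_seq_len, target_vocab_dim = 2, 5, 11
PAD_ID = 0  # hypothetical padding token ID

logits = torch.randn(batch_size, output_seq_len, target_vocab_dim)  # model output
targets = torch.randint(1, target_vocab_dim, (batch_size, output_seq_len))
targets[1, 3:] = PAD_ID  # the second target sentence is padded

# Class dimension must come second: [batch_size, target_vocab_dim, output_seq_len]
loss = F.cross_entropy(logits.permute(0, 2, 1), targets, ignore_index=PAD_ID)

# Same value by hand: log-softmax over the vocabulary, then mean NLL over non-pad tokens
log_probs = F.log_softmax(logits, dim=-1)
picked = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # [batch_size, seq_len]
not_pad = targets != PAD_ID
manual = -(picked * not_pad).sum() / not_pad.sum()

print(torch.allclose(loss, manual))  # True
```

Flattening with logits.reshape(-1, target_vocab_dim) and targets.reshape(-1) is an equivalent alternative to the permute.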
Advantages and Disadvantages of Transformer
Advantages:
- Parallel computing. The Transformer abandoned the CNN and RNN architectures that previously dominated sequence modeling.
  - The input is [batch_size, input_seq_len] (token IDs) and the output is [batch_size, output_seq_len, target_vocab_dim]. Unlike RNN architectures, which parse a sequence step by step, the Transformer processes all time steps at once, and attention pooling runs its multiple heads (or partitions of attention) in parallel.
Disadvantages:
- It lacks the local feature extraction that CNNs provide.
Tasks and Data
It’s common practice to pad input sequences to MAX_SENTENCE_LENGTH. Therefore,
- the input is always [batch_size, max_sentence_length], and
- NUM_KEYS = NUM_QUERIES = max_sentence_length, since neither the encoder nor the decoder changes the max_sentence_length dimension.
In practice, the following methods can reduce padding:
- Bucketing: group sentences of similar lengths into the same batch, so that each batch only needs to be padded to the length of its own longest sentence.
- Packed sequences: PyTorch’s pack_padded_sequence and pad_packed_sequence utilities (more common with RNNs) handle variable-length sequences.
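Bucketing can be sketched with torch.nn.utils.rnn.pad_sequence: sort sentences by length, slice the sorted order into batches, and pad each batch only to its own longest sentence. The helper name bucket_batches and padding ID 0 are hypothetical; the sentences are lists of token IDs.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def bucket_batches(sentences, batch_size, pad_id=0):
    """Hypothetical helper: batch token-ID lists so similar lengths share a batch."""
    # Sort sentence indices by length so that neighbors have similar lengths
    order = sorted(range(len(sentences)), key=lambda i: len(sentences[i]))
    batches = []
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        seqs = [torch.tensor(sentences[i]) for i in idx]
        # Pad only to the longest sentence in this batch, not to MAX_SENTENCE_LENGTH
        padded = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        batches.append((idx, padded))  # keep idx to realign targets later
    return batches


sentences = [[1, 2, 3, 4, 5, 6], [7], [8, 9], [10, 11, 12, 13, 14]]
for idx, batch in bucket_batches(sentences, batch_size=2):
    print(idx, batch.shape)
```

Here the two batches have shapes 2 × 2 and 2 × 6 (16 slots), whereas padding every sentence to length 6 would need 24. Because sorting reorders the data, the returned indices are needed to pair each sentence with its target.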
Applications
- Machine translation (e.g., on the WMT datasets from the Conference on Machine Translation)
- Named entity recognition (e.g., extracting phone numbers from resumes)