Neural Machine Translation Using Transformer¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
In this tutorial we will look at a sequence-to-sequence task: translating one language into another. The architecture used for the task is the famous Transformer.
The central idea behind the transformer architecture is the attention mechanism, which re-weights features throughout the network. Another advantage of the transformer architecture is that it breaks the temporal dependency of the data, allowing more efficient parallelization of training. We will implement every detail of the transformer in this tutorial. Let's get started!
First let's define some hyper-parameters that we will use later.
data_dir = None
epochs = 20
em_dim = 128
batch_size = 64
train_steps_per_epoch = None
eval_steps_per_epoch = None
Dataset¶
In this machine translation task, we will use the TED translation dataset. The dataset consists of 14 different translation tasks, such as Portuguese to English (`pt_to_en`), Russian to English (`ru_to_en`), and many others. In this tutorial, we will translate Portuguese to English. You can access this dataset through our dataset API - `tednmt`. Feel free to check the docstring of the API for other translation options.
from fastestimator.dataset.data import tednmt
train_ds, eval_ds, test_ds = tednmt.load_data(data_dir, translate_option="pt_to_en")
Now that the dataset is downloaded, let's check what the dataset looks like:
print("example source language:")
print(train_ds[0]["source"])
print("")
print("example target language:")
print(train_ds[0]["target"])
example source language:
entre todas as grandes privações com que nos debatemos hoje — pensamos em financeiras e económicas primeiro — aquela que mais me preocupa é a falta de diálogo político — a nossa capacidade de abordar conflitos modernos como eles são , de ir à raiz do que eles são e perceber os agentes-chave e lidar com eles .

example target language:
amongst all the troubling deficits we struggle with today — we think of financial and economic primarily — the ones that concern me most is the deficit of political dialogue — our ability to address modern conflicts as they are , to go to the source of what they 're all about and to understand the key players and to deal with them .
Preprocessing the languages¶
Since the text by itself cannot be recognized by computers, we need to perform a series of transformations on the text. Here are the steps:
- Split the sentence into words or sub-words. For example, "I love apple" can be split into ["I", "love", "apple"]. Sometimes, in order to represent more words, a word is further reduced into sub-words. For example, `tokenization` can be split into `token` and `_ization`. As a result, a word like "civilization" doesn't require extra space in the dictionary when both `civil` and `_ization` are already there.
- Map the tokens into discrete indices according to the dictionary. In this task, we load a pretrained tokenizer that comes with a built-in dictionary.
- Add a [start] and [end] token around every encoded sentence. This mainly helps the network identify the beginning and end of each sentence.
- When creating a batch of multiple sentences, pad the shorter sentences with 0 so that each sentence in the batch has the same length.
import fastestimator as fe
from transformers import BertTokenizer
from fastestimator.op.numpyop import Batch, NumpyOp
import numpy as np
class Encode(NumpyOp):
def __init__(self, tokenizer, inputs, outputs, mode=None):
super().__init__(inputs=inputs, outputs=outputs, mode=mode)
self.tokenizer = tokenizer
def forward(self, data, state):
return np.array(self.tokenizer.encode(data))
pt_tokenizer = BertTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
pipeline = fe.Pipeline(
train_data=train_ds,
eval_data=eval_ds,
test_data=test_ds,
ops=[
Encode(inputs="source", outputs="source", tokenizer=pt_tokenizer),
Encode(inputs="target", outputs="target", tokenizer=en_tokenizer),
Batch(batch_size=batch_size, pad_value=0)
])
In the above code, `tokenizer.encode` takes a sentence and executes steps 1 - 3. The padding step is handled by providing `pad_value=0` in the `Batch` Op.
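As an optional sanity check, you can call the tokenizer directly on a short sentence. The exact ids depend on the pretrained vocabulary, but every encoded sentence starts with the [CLS] id (101 for `bert-base-uncased`) and ends with the [SEP] id (102):
ids = en_tokenizer.encode("I love apple")
print(ids)                       # a list of token ids, e.g. [101, ..., 102]
print(en_tokenizer.decode(ids))  # "[CLS] i love apple [SEP]"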
Preprocessing Results¶
data = pipeline.get_results()
print("source after processing:")
print(data["source"])
print("source batch shape:")
print(data["source"].shape)
print("---------------------------------------------------")
print("target after processing:")
print(data["target"])
print("target batch shape:")
print(data["target"].shape)
source after processing:
tensor([[  101,   420,  1485,  ...,  1061,   119,   102],
        [  101,   538,   179,  ...,     0,     0,     0],
        [  101,   122, 21174,  ...,     0,     0,     0],
        ...,
        [  101,   607,   230,  ...,     0,     0,     0],
        [  101,   123, 10186,  ...,     0,     0,     0],
        [  101, 11865,  3072,  ...,     0,     0,     0]])
source batch shape:
torch.Size([64, 72])
---------------------------------------------------
target after processing:
tensor([[ 101, 5921, 2035,  ..., 2068, 1012,  102],
        [ 101, 2057, 2040,  ...,    0,    0,    0],
        [ 101, 1998, 1045,  ...,    0,    0,    0],
        ...,
        [ 101, 2045, 1005,  ...,    0,    0,    0],
        [ 101, 1996, 5424,  ...,    0,    0,    0],
        [ 101, 2009, 2097,  ...,    0,    0,    0]])
target batch shape:
torch.Size([64, 70])
Transformer Architecture¶
Attention Unit¶
The basic form of the attention unit is defined in `scaled_dot_product_attention`. Given a set of queries (Q), keys (K), and values (V), it first performs the matrix multiplication of Q and K. The output of this multiplication gives the matching score between elements of Q and K. The scores are then scaled by the square root of the key dimension and normalized (via softmax) across the keys dimension. Finally, the normalized scores are multiplied by V to get the final result. The intuition behind the attention unit is essentially a dictionary look-up with interpolation.
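In equation form, this is the standard scaled dot-product attention (the code below additionally adds a mask term to the logits before the softmax):
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$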
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask):
matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
num_heads, inp_length = tf.shape(scaled_attention_logits)[1], tf.shape(scaled_attention_logits)[2]
num_heads_mask, inp_length_mask = tf.shape(mask)[1], tf.shape(mask)[2]
    # This manual tiling works around an auto-broadcasting issue with TensorFlow
scaled_attention_logits += tf.tile(mask * -1e9, [1, num_heads // num_heads_mask, inp_length // inp_length_mask, 1])
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
output = tf.matmul(attention_weights, v)
return output
def point_wise_feed_forward_network(em_dim, dff):
return tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff)
tf.keras.layers.Dense(em_dim) # (batch_size, seq_len, em_dim)
])
Multi-head Attention¶
There are two drawbacks of the attention unit above:
- The complexity of matrix multiplication is O(N^3); as the batch size or embedding dimension increases, the computation will not scale well.
- A single attention head is limited in expressing local correlation between two words, because it calculates correlation by normalizing over all embedding dimensions. Sometimes this overall normalization removes interesting local patterns. A good analogy is to think of a single attention unit as globally averaging a signal, whereas a moving average would preserve more local information.
Multi-head attention is used to overcome the issues above. It splits the embedding dimension into multiple heads. As a result, each head works with an embedding dimension that is the full dimension divided by the number of heads, reducing the computational complexity. Moreover, each head only takes a fraction of the embedding and can be viewed as a specialized expert for a specific context. The final results are combined using another dense layer.
from tensorflow.keras import layers
class MultiHeadAttention(layers.Layer):
def __init__(self, em_dim, num_heads):
super().__init__()
        assert em_dim % num_heads == 0, "model dimension must be a multiple of the number of heads"
self.num_heads = num_heads
self.em_dim = em_dim
self.depth = em_dim // self.num_heads
self.wq = layers.Dense(em_dim)
self.wk = layers.Dense(em_dim)
self.wv = layers.Dense(em_dim)
self.dense = layers.Dense(em_dim)
def split_heads(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3]) # B, num_heads, seq_len, depth
def call(self, v, k, q, mask):
batch_size = tf.shape(q)[0]
q = self.wq(q) # B, seq_len, em_dim
k = self.wk(k) # B, seq_len, em_dim
v = self.wv(v) # B, seq_len, em_dim
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
scaled_attention = scaled_dot_product_attention(q, k, v, mask)
scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) #B, seq_len, num_heads, depth
concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.em_dim)) # B, seq_len, em_dim
output = self.dense(concat_attention)
return output
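As a quick sanity check (optional), we can run a random tensor through the layer defined above, using the hyper-parameters from this tutorial (em_dim=128, num_heads=8, so each head has depth 16), and verify that the output shape matches the input shape:
mha = MultiHeadAttention(em_dim=128, num_heads=8)
x = tf.random.uniform((2, 10, 128))  # (batch, seq_len, em_dim)
mask = tf.zeros((2, 1, 1, 10))       # nothing masked
print(mha(x, x, x, mask).shape)      # (2, 10, 128)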
Encoder and Decoder layer¶
Both the encoder and decoder layers go through multi-head attention. The decoder layer uses a second multi-head attention module to bridge the encoder outputs and the targets. Specifically, in the decoder's second multi-head attention module, the encoder output is used as both the values and the keys, whereas the target embedding is used as the query to "look up" encoder information. At the end of each layer, a feed-forward network transforms the looked-up values into something useful.
class EncoderLayer(layers.Layer):
def __init__(self, em_dim, num_heads, dff, rate=0.1):
super().__init__()
self.mha = MultiHeadAttention(em_dim, num_heads)
self.ffn = point_wise_feed_forward_network(em_dim, dff)
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = layers.Dropout(rate)
self.dropout2 = layers.Dropout(rate)
def call(self, x, training, mask):
attn_output = self.mha(x, x, x, mask)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(x + attn_output)
ffn_output = self.ffn(out1)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = self.layernorm2(out1 + ffn_output)
return out2
class DecoderLayer(layers.Layer):
    def __init__(self, em_dim, num_heads, dff, rate=0.1):
super().__init__()
self.mha1 = MultiHeadAttention(em_dim, num_heads)
self.mha2 = MultiHeadAttention(em_dim, num_heads)
        self.ffn = point_wise_feed_forward_network(em_dim, dff)
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = layers.Dropout(rate)
self.dropout2 = layers.Dropout(rate)
self.dropout3 = layers.Dropout(rate)
def call(self, x, enc_out, training, decode_mask, padding_mask):
attn1 = self.mha1(x, x, x, decode_mask)
attn1 = self.dropout1(attn1, training=training)
out1 = self.layernorm1(attn1 + x)
attn2 = self.mha2(enc_out, enc_out, out1, padding_mask)
attn2 = self.dropout2(attn2, training=training)
out2 = self.layernorm2(attn2 + out1)
ffn_output = self.ffn(out2)
ffn_output = self.dropout3(ffn_output, training=training)
out3 = self.layernorm3(ffn_output + out2)
return out3
Putting Everything Together¶
A transformer consists of an Encoder and a Decoder, which in turn consist of multiple stacked encoder/decoder layers. One interesting property of transformers is that they have no intrinsic awareness of position. Therefore, a positional encoding is usually added to the embedding matrix to give it position context. A nice tutorial about positional encoding can be found here.
def get_angles(pos, i, em_dim):
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(em_dim))
return pos * angle_rates
def positional_encoding(position, em_dim):
angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(em_dim)[np.newaxis, :], em_dim)
# apply sin to even indices in the array; 2i
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
# apply cos to odd indices in the array; 2i+1
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
pos_encoding = angle_rads[np.newaxis, ...]
return tf.cast(pos_encoding, dtype=tf.float32)
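As a quick check (optional), the positional encoding produces one row per position and one column per embedding dimension, with a leading broadcast axis for the batch:
pe = positional_encoding(position=50, em_dim=128)
print(pe.shape)  # (1, 50, 128)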
class Encoder(layers.Layer):
def __init__(self, num_layers, em_dim, num_heads, dff, input_vocab, max_pos_enc, rate=0.1):
super().__init__()
self.em_dim = em_dim
self.num_layers = num_layers
self.embedding = layers.Embedding(input_vocab, em_dim)
self.pos_encoding = positional_encoding(max_pos_enc, self.em_dim)
self.enc_layers = [EncoderLayer(em_dim, num_heads, dff, rate) for _ in range(num_layers)]
self.dropout = layers.Dropout(rate)
def call(self, x, mask, training=None):
seq_len = tf.shape(x)[1]
x = self.embedding(x)
x *= tf.math.sqrt(tf.cast(self.em_dim, tf.float32))
x += self.pos_encoding[:, :seq_len, :]
x = self.dropout(x, training=training)
for i in range(self.num_layers):
x = self.enc_layers[i](x, training, mask)
return x
class Decoder(layers.Layer):
def __init__(self, num_layers, em_dim, num_heads, dff, target_vocab, max_pos_enc, rate=0.1):
super().__init__()
self.em_dim = em_dim
self.num_layers = num_layers
self.embedding = layers.Embedding(target_vocab, em_dim)
self.pos_encoding = positional_encoding(max_pos_enc, em_dim)
self.dec_layers = [DecoderLayer(em_dim, num_heads, dff, rate) for _ in range(num_layers)]
self.dropout = layers.Dropout(rate)
def call(self, x, enc_output, decode_mask, padding_mask, training=None):
seq_len = tf.shape(x)[1]
x = self.embedding(x)
x *= tf.math.sqrt(tf.cast(self.em_dim, tf.float32))
x += self.pos_encoding[:, :seq_len, :]
x = self.dropout(x, training=training)
for i in range(self.num_layers):
x = self.dec_layers[i](x, enc_output, training, decode_mask, padding_mask)
return x
def transformer(num_layers, em_dim, num_heads, dff, input_vocab, target_vocab, max_pos_enc, max_pos_dec, rate=0.1):
inputs = layers.Input(shape=(None, ))
targets = layers.Input(shape=(None, ))
encode_mask = layers.Input(shape=(None, None, None))
decode_mask = layers.Input(shape=(None, None, None))
x = Encoder(num_layers, em_dim, num_heads, dff, input_vocab, max_pos_enc, rate=rate)(inputs, encode_mask)
x = Decoder(num_layers, em_dim, num_heads, dff, target_vocab, max_pos_dec, rate=rate)(targets,
x,
decode_mask,
encode_mask)
x = layers.Dense(target_vocab)(x)
model = tf.keras.Model(inputs=[inputs, targets, encode_mask, decode_mask], outputs=x)
return model
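Before wiring the model into FastEstimator, it can be helpful to smoke-test the assembled graph with a small dummy configuration (the sizes below are arbitrary and used only for this check):
tmp_model = transformer(num_layers=1, em_dim=32, num_heads=4, dff=64,
                        input_vocab=100, target_vocab=100, max_pos_enc=50, max_pos_dec=50)
dummy_src = tf.random.uniform((2, 7), maxval=100)  # 2 source sentences of length 7 (the Embedding layer casts the ids to int)
dummy_tgt = tf.random.uniform((2, 5), maxval=100)  # 2 target sentences of length 5
dummy_enc_mask = tf.zeros((2, 1, 1, 7))            # nothing masked
dummy_dec_mask = tf.zeros((2, 1, 5, 5))
print(tmp_model([dummy_src, dummy_tgt, dummy_enc_mask, dummy_dec_mask]).shape)  # (2, 5, 100)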
model = fe.build(
model_fn=lambda: transformer(num_layers=4,
em_dim=em_dim,
num_heads=8,
dff=512,
input_vocab=pt_tokenizer.vocab_size,
target_vocab=en_tokenizer.vocab_size,
max_pos_enc=1000,
max_pos_dec=1000),
optimizer_fn="adam")
Network Operations¶
Now that we have defined the transformer architecture, another thing worth mentioning is the mask. A mask is a boolean array we create to tell the network to ignore certain positions within a sentence. A padding mask tells the network to ignore the positions we padded with 0. A look-ahead mask is also needed so that, when predicting a given word, the network cannot peek at that word or anything after it (a concrete example is shown after the network definition below).
The loss function of the transformer is simply a masked cross-entropy loss: it only considers predictions at positions that are not padded.
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.loss import LossOp
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
class CreateMasks(TensorOp):
def forward(self, data, state):
inp, tar = data
encode_mask = self.create_padding_mask(inp)
dec_look_ahead_mask = self.create_look_ahead_mask(tf.shape(tar)[1])
dec_target_padding_mask = self.create_padding_mask(tar)
decode_mask = tf.maximum(dec_target_padding_mask, dec_look_ahead_mask)
return encode_mask, decode_mask
@staticmethod
def create_padding_mask(seq):
seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len)
@staticmethod
def create_look_ahead_mask(size):
mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
return mask # (seq_len, seq_len)
class ShiftData(TensorOp):
    def forward(self, data, state):
        # Teacher forcing: the decoder input drops the last token and the prediction
        # target drops the first token, so position t is trained to predict token t+1.
        target = data
        return target[:, :-1], target[:, 1:]
class MaskedCrossEntropy(LossOp):
def __init__(self, inputs, outputs, mode=None):
super().__init__(inputs=inputs, outputs=outputs, mode=mode)
self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def forward(self, data, state):
y_pred, y_true = data
mask = tf.cast(tf.math.logical_not(tf.math.equal(y_true, 0)), tf.float32)
loss = self.loss_fn(y_true, y_pred) * mask
loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
return loss
network = fe.Network(ops=[
ShiftData(inputs="target", outputs=("target_inp", "target_real")),
CreateMasks(inputs=("source", "target_inp"), outputs=("encode_mask", "decode_mask")),
ModelOp(model=model, inputs=("source", "target_inp", "encode_mask", "decode_mask"), outputs="pred"),
MaskedCrossEntropy(inputs=("pred", "target_real"), outputs="ce"),
UpdateOp(model=model, loss_name="ce")
])
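To make the masking concrete, here is the look-ahead mask for a sequence of length 4 (1 means "ignore"): each position is only allowed to attend to itself and the positions before it.
print(CreateMasks.create_look_ahead_mask(4))
# tf.Tensor(
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]], shape=(4, 4), dtype=float32)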
Metrics and Learning Rate Scheduling¶
The metric used to evaluate the model is a masked accuracy: accuracy computed only over positions that are not padded. The learning rate scheduler uses a warm-up phase followed by an inverse square-root decay, as in the original Transformer paper.
import tempfile
from fastestimator.trace.adapt import LRScheduler
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric.bleu_score import BleuScore
from fastestimator.trace.trace import Trace
model_dir=tempfile.mkdtemp()
def lr_fn(step, em_dim, warmupstep=4000):
lr = em_dim**-0.5 * min(step**-0.5, step * warmupstep**-1.5)
return lr
class MaskedAccuracy(Trace):
def on_epoch_begin(self, data):
self.correct = 0
self.total = 0
def on_batch_end(self, data):
y_pred, y_true = data["pred"].numpy(), data["target_real"].numpy()
mask = np.logical_not(y_true == 0)
matches = np.logical_and(y_true == np.argmax(y_pred, axis=2), mask)
self.correct += np.sum(matches)
self.total += np.sum(mask)
def on_epoch_end(self, data):
data.write_with_log(self.outputs[0], self.correct / self.total)
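To get a feel for the schedule (optional), you can evaluate `lr_fn` at a few steps with `em_dim=128`; the learning rate ramps up until the warm-up step and then decays with the inverse square root of the step. The values in the comment are approximate.
for step in [1, 1000, 4000, 20000, 100000]:
    print(step, "%.2e" % lr_fn(step, em_dim))
# ramps up to a peak of roughly 1.4e-03 at step 4000, then decays to roughly 2.8e-04 by step 100000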
traces = [
MaskedAccuracy(inputs=("pred", "target_real"), outputs="masked_acc", mode="!train"),
    BleuScore(true_key="target_real", pred_key="pred", output_name="bleu_score", n_gram=2, mode="!train"),
BestModelSaver(model=model, save_dir=model_dir, metric="masked_acc", save_best_mode="max"),
LRScheduler(model=model, lr_fn=lambda step: lr_fn(step, em_dim))]
estimator = fe.Estimator(pipeline=pipeline,
network=network,
traces=traces,
epochs=epochs,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch)
Start the training¶
The training takes around 30 minutes on a single V100 GPU.
estimator.fit()
Let's translate something!¶
def token_to_words(sample, tokenizer):
words = tokenizer.decode(sample)
if '[CLS]' in words:
words = words[words.index('[CLS]')+5:]
if '[SEP]' in words:
words = words[:words.index('[SEP]')]
return words
sample_test_data = pipeline.get_results(mode="test")
sample_test_data = network.transform(data=sample_test_data, mode="test")
source = sample_test_data["source"].numpy()
predicted = sample_test_data["pred"].numpy()
predicted = np.argmax(predicted, axis=-1)
ground_truth = sample_test_data["target_real"].numpy()
index = np.random.randint(0, source.shape[0])
sample_source, sample_predicted, sample_ground_truth = source[index], predicted[index], ground_truth[index]
print("Source Language: ")
print(token_to_words(sample_source, pt_tokenizer))
print("")
print("Translation Ground Truth: ")
print(token_to_words(sample_ground_truth, en_tokenizer))
print("")
print("Machine Translation: ")
print(token_to_words(sample_predicted, en_tokenizer))
Source Language:
muito obrigado.

Translation Ground Truth:
thank you very much.

Machine Translation:
thank you very much.
You are welcome.