Generative Pretrained Transformer (GPT)¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
Generative Pretrained Transformer, also known as GPT, has demonstrated significant success on a wide range of language problems. GPT is a transformer-based generative model for the language modeling task. Despite being trained only on language modeling, GPT's capability extends well beyond that scope, and it performs well on almost any language task. Furthermore, researchers found that as the model and data size increase, GPT gains a foundational understanding of human language, such that any downstream language task can be achieved with little to no extra training. This finding has started a new "gold rush" in the field of AI - the pursuit of foundation models.
In this apphub, we will implement the GPT model in FastEstimator using the PyTorch backend. We will showcase the building blocks of GPT along with its training details. The model we'll be training in this example won't be a billion- or trillion-parameter model, but with a slight adjustment of model parameters, our code base can be used to produce billion-parameter large language models with state-of-the-art performance.
Getting Things Ready¶
First, let's get the imports out of the way.
import random
import tempfile
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import fastestimator as fe
from fastestimator.dataset.data import wikitext_103
from fastestimator.op.numpyop import LambdaOp, NumpyOp
from fastestimator.op.tensorop import LambdaOp as TLambdaOp
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver
Next, we define some parameters that we will use in this notebook:
data_dir=None
epochs=50
batch_size=32
context_len=512
num_blocks=6
em_dim=1024
ffwd_dim=4096
num_heads=16
save_dir=tempfile.mkdtemp()
train_steps_per_epoch=3000
eval_steps_per_epoch=500
Loading Data¶
For training we will use the wikitext-103 dataset, which contains 28475 wiki articles and 103 million tokens. Although our training data contains over a hundred million tokens, it is nothing compared with what people train on nowadays - the entire internet. Still, it serves the demonstration purpose of this apphub well.
class TextDataset(Dataset):
    def __init__(self, file_path, num_chars=5000):
        super().__init__()
        self.texts = self._read_file(file_path)
        self.num_chars = num_chars

    @staticmethod
    def _read_file(path):
        text = ''.join(pd.read_parquet(path, engine='fastparquet')['text'].to_list())
        return text

    def __len__(self):
        # this is just a placeholder, we use 'train_steps_per_epoch' to control training length
        return 10000

    def __getitem__(self, idx):
        start_idx = random.randint(0, len(self.texts) - self.num_chars - 1)
        random_text = self.texts[start_idx:start_idx + self.num_chars]
        return {"x": random_text[random_text.index(" ") + 1:]}  # always start from a new word
train_data, eval_data, test_data = wikitext_103.load_data(data_dir)
train_data, eval_data, test_data = TextDataset(train_data), TextDataset(eval_data), TextDataset(test_data)
Here we are reading the whole dataset into memory, since the overall text file size is only ~200MB. In large-scale training, you will need to make sure the data is read on the fly (for example, with a NumpyOp). For each training sample, our dataset class randomly extracts 5000 consecutive characters from the entire pool of articles. One sample's text looks like the following:
sample_text = train_data[0]['x'][:1000]
print(sample_text)
Two on November 12 , 1997 . The episode 's initial broadcast was viewed by approximately 16 @.@ 59 million people , which represented 16 % of the viewing audience during that time . Both Gillian Anderson and David Duchovny consider this among the best episodes of the fourth season . Composer Mark Snow was nominated for an Emmy Award for the music he produced for this episode . He said of the episode 's music , " It was a different kind of texture for the show . Light , magic , nothing terribly threatening " . Snow received many requests for a recording of the music used at the end of the episode . Website IGN named " Paper Hearts " their sixth favorite " standalone " episode of the show , calling it " creepy and unsettling " , and claiming Noonan 's character was " one of the most disturbing villains to make an appearance in the series " . Noonan 's acting has also been praised by Vince Gilligan , who says the " understated " manner in which Roche is portrayed " sends chills down [
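As noted above, for truly large-scale training you would avoid holding all of the text in memory. Below is a minimal sketch (not part of this apphub) of a dataset that reads on the fly, assuming the corpus has already been tokenized into a hypothetical binary file tokens.bin of uint16 token ids:

class LazyTokenDataset(Dataset):
    def __init__(self, token_file, seq_len=513):
        super().__init__()
        # np.memmap keeps the data on disk and only reads the slices we index
        self.tokens = np.memmap(token_file, dtype=np.uint16, mode='r')
        self.seq_len = seq_len

    def __len__(self):
        return 10000  # placeholder, same trick as above

    def __getitem__(self, idx):
        start = random.randint(0, len(self.tokens) - self.seq_len - 1)
        return {"x": np.array(self.tokens[start:start + self.seq_len], dtype=np.int64)}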
Tokenization¶
AI models deal with numbers, not text. Therefore, we need a process to convert text into numbers, and this process is called tokenization. Tokenization can take place at many levels: character level, word level, sentence level, and so on. The most popular approach used by current large language models operates at the word and sub-word level.
The naive way of doing word-level tokenization is to simply store a vocabulary of all possible words and map each word to an index. However, the limitation of word mapping is that people constantly coin out-of-vocabulary words to express specific meanings, which makes it almost impossible for a pre-defined vocabulary to capture all of them. For example, the word FastEstimator technically doesn't exist in English, but we all know what it means. :)
To overcome the above challenge, people created sub-word tokenization. This is more consistent with how English grammar works, and it significantly reduces the chance of encountering unknown words. More importantly, it also increases the reusability of tokens and shrinks the problem space considerably.
There are many great toolkits that provide ready-to-use tokenization functionality; this time we will use HuggingFace's GPT-2 tokenization scheme.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
Now let's play with tokenization on some sample text:
original_sample_text = sample_text[:50]
tokens = tokenizer.encode(original_sample_text)
decoded_text = tokenizer.decode(tokens)
print("Original Text: {}".format(original_sample_text))
print("Encoded Tokens: {}".format(tokens))
print("Decoded Text: {}".format(decoded_text))
Original Text: Two on November 12 , 1997 . The episode 's initial Encoded Tokens: [7571, 319, 3389, 1105, 837, 8309, 764, 383, 4471, 705, 82, 4238] Decoded Text: Two on November 12, 1997. The episode's initial
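We can also check how the tokenizer handles an out-of-vocabulary word like FastEstimator - it simply breaks the word into sub-word pieces (the exact split depends on the GPT-2 vocabulary):

oov_word = "FastEstimator"
oov_tokens = tokenizer.encode(oov_word)
print(tokenizer.convert_ids_to_tokens(oov_tokens))  # the unseen word becomes several sub-word tokens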
Data Pipeline¶
Now we are ready to construct the Pipeline to feed the data iteratively during training. We will make the encoding process happen on-the-fly during training. In addition, after encoding we will truncate the tokens to a maximum length.
class Encode(NumpyOp):
    def __init__(self, tokenizer, inputs, outputs, mode=None):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.tokenizer = tokenizer

    def forward(self, data, state):
        return np.array(self.tokenizer(data, truncation=True)['input_ids'])
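As a quick sanity check, we can also invoke the op directly on a single string (during training, FastEstimator calls forward for us):

encode_op = Encode(inputs="x", outputs="x", tokenizer=tokenizer)
print(encode_op.forward("FastEstimator is fast.", state={}))  # a NumPy array of token ids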
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    test_data=test_data,
    batch_size=batch_size,
    ops=[
        Encode(inputs="x", outputs="x", tokenizer=tokenizer),
        LambdaOp(fn=lambda x: x[:context_len + 1], inputs="x", outputs="x")  # get 1 more token for next word prediction's target
    ])
Now let's take a look at the pipeline's output. We defined the dataset and pipeline in a way that the batch only contains the key x.
batch_result = pipeline.get_results()
print(batch_result['x'].shape)
torch.Size([32, 513])
As we can see above, the value of key x contains the batched tokens in the shape (batch, max_length). The max_length is defined to be the maximum number of input tokens allowed by the model plus one. We will explain the plus-one part later in the network section.
Define GPT Model¶
Now we are ready to define the GPT model architecture. The model architecture is a transformer-based model built with attention blocks. We will define each of the components from the lower level up.
Multi-head Attention¶
The attention block is the building block of any transformer-based network. We described its working mechanism in this example. To summarize briefly, there are 3 vectors: Query, Key, and Value. Given a Query, attention performs an interpolated version of a table look-up, where the table is defined by the Key and Value.
Multi-head attention is basically multiple attention blocks working in parallel with their results concatenated. One advantage of multi-head attention is that we can split the embedding dimension among multiple heads such that the computational complexity is reduced. This is similar to how group convolution reduces computation compared with full convolution.
For the language modeling task, a look-ahead mask is applied such that tokens can only attend to tokens before them, not after. This ensures we don't give away the answer for next-word prediction.
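To make the look-ahead mask concrete, here is a small illustration (using a toy sequence length of 4) of the same tril-based masking that the class below applies to its attention scores:

T = 4
scores = torch.zeros(T, T)  # pretend attention scores for one head
tril = torch.tril(torch.ones(T, T))  # lower-triangular look-ahead mask
masked = scores.masked_fill(tril == 0, float('-inf'))
print(masked.softmax(dim=-1))  # row i only assigns probability to positions <= i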
class MultiHeadAttention(nn.Module):
    # Multi-head attention is like group convolution, but for attention.
    def __init__(self, context_len, em_dim, num_heads=4, p_drop=0.2, use_mask=True):
        super().__init__()
        self.num_heads = num_heads
        self.use_mask = use_mask
        self.key = nn.Linear(em_dim, em_dim, bias=False)
        self.query = nn.Linear(em_dim, em_dim, bias=False)
        self.value = nn.Linear(em_dim, em_dim, bias=False)
        self.projection = nn.Linear(em_dim, em_dim)
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))  # lookahead mask
        self.dropout_attn = nn.Dropout(p_drop)
        self.dropout_proj = nn.Dropout(p_drop)

    def forward(self, x):
        B, T, _ = x.shape  # input shape: B, seq, em_dim
        k, q, v = self.key(x), self.query(x), self.value(x)  # B, seq, em_dim
        # split the head and move the head dimension next to batch so heads are independent
        k = k.reshape(B, T, self.num_heads, -1).permute(0, 2, 1, 3)  # B, head, seq, em_dim//head
        q = q.reshape(B, T, self.num_heads, -1).permute(0, 2, 1, 3)  # B, head, seq, em_dim//head
        v = v.reshape(B, T, self.num_heads, -1).permute(0, 2, 1, 3)  # B, head, seq, em_dim//head
        # attention
        attention = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # B, head, seq, seq
        if self.use_mask:
            attention = attention.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # apply lookahead mask
        attention = attention.softmax(dim=-1)
        attention = self.dropout_attn(attention)
        x = (attention @ v).permute(0, 2, 1, 3)  # B, seq, head, em_dim//head
        x = x.reshape(B, T, -1)  # B, seq, em_dim
        # projection
        x = self.projection(x)
        x = self.dropout_proj(x)
        return x
Combining Multi-head Attention with Feed Forward Blocks¶
We are now ready to define a higher-level basic unit - AttentionBlock, which consists of one multi-head attention block and one feed-forward block connected in a residual manner. Note that LayerNorm is used here, such that each individual sample is normalized within itself.
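To see what "normalized within itself" means, here is a tiny illustration of LayerNorm applied to a (batch, seq, em_dim) tensor - every token's embedding is normalized independently of the rest of the batch:

layer_norm = nn.LayerNorm(8)
dummy = torch.randn(2, 3, 8)  # (batch, seq, em_dim)
normed = layer_norm(dummy)
print(normed.mean(dim=-1))  # roughly zero mean per token
print(normed.std(dim=-1))   # roughly unit std per token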
class AttentionBlock(nn.Module):
    """multi-attention + feedforward skip connection"""
    def __init__(self, context_len, em_dim, num_heads, ffwd_dim, p_drop=0.2, use_mask=True):
        super().__init__()
        self.self_attention = MultiHeadAttention(context_len,
                                                 em_dim,
                                                 num_heads=num_heads,
                                                 p_drop=p_drop,
                                                 use_mask=use_mask)
        self.ffwd = nn.Sequential(nn.Linear(em_dim, ffwd_dim),
                                  nn.ReLU(),
                                  nn.Linear(ffwd_dim, em_dim),
                                  nn.Dropout(p_drop))
        self.norm1 = nn.LayerNorm(em_dim)
        self.norm2 = nn.LayerNorm(em_dim)

    def forward(self, x):
        x = x + self.self_attention(self.norm1(x))
        x = x + self.ffwd(self.norm2(x))
        return x
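Because of the residual connections, an AttentionBlock preserves the shape of its input. A quick shape check with toy dimensions:

block = AttentionBlock(context_len=16, em_dim=64, num_heads=4, ffwd_dim=128)
dummy = torch.randn(2, 16, 64)  # (batch, seq, em_dim)
print(block(dummy).shape)  # torch.Size([2, 16, 64])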
GPT Model¶
The final GPT model contains 3 basic components:
- Embeddings: both token embedding and position embedding. For the position embedding, the original idea of using sine/cosine functions is not necessary when we simply make the embedding trainable.
- Attention Blocks: several AttentionBlocks in a row.
- Language Prediction Head: simply a fully connected dense layer.
class GPT(nn.Module):
    def __init__(self, num_blocks, vocab_size, context_len, em_dim, num_heads, ffwd_dim, p_drop=0.2, use_mask=True):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, em_dim)
        self.position_embedding = nn.Embedding(context_len, em_dim)
        self.blocks = nn.Sequential(
            *[AttentionBlock(context_len, em_dim, num_heads, ffwd_dim, p_drop, use_mask) for _ in range(num_blocks)])
        self.final_norm = nn.LayerNorm(em_dim)
        self.lm_head = nn.Linear(em_dim, vocab_size)
        self.register_buffer('pos_idx', torch.arange(context_len))  # position index
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        token_em = self.token_embedding(x)
        position_em = self.position_embedding(self.pos_idx[:x.shape[-1]])
        x = token_em + position_em
        x = self.blocks(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits
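The model maps a batch of token ids to one logit vector over the vocabulary per position. A quick shape check with toy settings:

toy_gpt = GPT(num_blocks=2, vocab_size=100, context_len=16, em_dim=64, num_heads=4, ffwd_dim=128)
dummy_tokens = torch.randint(0, 100, (2, 16))  # (batch, seq)
print(toy_gpt(dummy_tokens).shape)  # torch.Size([2, 16, 100])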
Tensor Operations¶
During training, here are the 4 events that will happen in sequence:
- We split the x coming from the pipeline into input tokens and next-word tokens. There is a one-to-one mapping between their positions (see the toy sketch after this list).
- We send the input tokens to the GPT model to get predictions.
- We calculate the cross entropy between the predicted tokens and the next-word tokens as the loss.
- We update the model based on the gradients of the loss.
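To make the first step concrete, here is a toy version of the same slicing that the TLambdaOp below performs:

toy_tokens = np.array([11, 22, 33, 44, 55])  # pretend these are 5 tokens from the pipeline
toy_input, toy_target = toy_tokens[:-1], toy_tokens[1:]
print(toy_input)   # [11 22 33 44] -> fed into the model
print(toy_target)  # [22 33 44 55] -> what the model should predict at each position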
model = fe.build(
    model_fn=lambda: GPT(num_blocks=num_blocks,
                         vocab_size=tokenizer.vocab_size,
                         context_len=context_len,
                         em_dim=em_dim,
                         num_heads=num_heads,
                         ffwd_dim=ffwd_dim,
                         p_drop=0.3),
    optimizer_fn=lambda x: torch.optim.AdamW(x, lr=3e-4))
class CrossEntropy(TensorOp):
    def forward(self, data, state):
        logits, targets = data
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return loss
network = fe.Network(ops=[
    TLambdaOp(fn=lambda x: (x[..., :-1], x[..., 1:]), inputs="x", outputs=("input", "target")),
    ModelOp(model=model, inputs="input", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "target"), outputs="ce"),
    UpdateOp(model=model, loss_name="ce")
])
Putting Everything Together¶
Now that we have defined everything related to the training, we can use the Estimator class to put everything together. We save the model every time we observe a new best validation loss.
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=epochs,
                         traces=BestModelSaver(model=model, save_dir=save_dir),
                         train_steps_per_epoch=train_steps_per_epoch,
                         eval_steps_per_epoch=eval_steps_per_epoch)
Start Training¶
With the default parameters, our GPT model contains 179M trainable parameters and uses 512 tokens as the context window. We need about 40GB of GPU memory to train the model with the default batch size. The total training takes around 40 hours on a single A100 GPU.
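If you want to verify the parameter count before kicking off training, a quick check on the built model:

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters: {:.1f}M".format(num_params / 1e6))  # ~179M with the default settings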
estimator.fit(warmup=False)
Inferencing: Auto-Regressive Generation with GPT¶
Once a model is trained, it can be used to generate text based on whatever it is given as context. We can then append the previous prediction to the input, generate the next prediction, and keep going. This behavior is referred to as being auto-regressive.
Let's define a function that can generate responses like that:
def generate_response(prompt, model, tokenizer, max_response_token=128, context_len=512):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    tokens = torch.Tensor(tokenizer.encode(prompt)).long().to(device)
    num_input_tokens = tokens.shape[0]
    assert num_input_tokens <= context_len, "prompt exceeding maximum input tokens"
    tokens = tokens[None, ...]  # add batch dimension
    responses = None
    for _ in range(max_response_token):
        input_tokens = tokens[:, -context_len:]
        # get prediction
        logits = model(input_tokens)
        # focus only on the last time step
        logits = logits[:, -1, :]  # becomes (B, C)
        probs = F.softmax(logits, dim=-1)  # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        if responses is None:
            responses = idx_next
        else:
            responses = torch.cat((responses, idx_next), dim=1)  # (B, T+1)
        tokens = torch.cat((tokens, idx_next), dim=1)  # (B, T+1)
        if idx_next[0, 0] == 102:
            break
    responses = responses.to('cpu').numpy()
    responses = tokenizer.decode(responses[0])
    return responses
Let's take a look at two sample responses to see how the model completes our prompts in a style similar to its training data.
Sample 1¶
prompt = "Computer Science is not really science"
response = generate_response(prompt, model, tokenizer, context_len=context_len)
print("Prompt: {}".format(prompt))
print("Response: {}".format(response))
Prompt: Computer Science is not really science Response: fiction. " Chris Schilling of The Guardian described it as " essentially an online horror game, offering something interesting and exciting. It doesn 't quite make any sense. " A reviewer from Computer and Video Games criticised the interface, interface and camera controls, but pointed out that the game's artificial intelligence and perspective were good. The reviewer found that at the start of the fourth game, " the sequel will strangle up the RPG elements and made sure that IGN's reviewer, Lucas Touch, would proudly announce that the good " as the story progresses ". In contrast to its predecessor, the reviewer found that Lucasfilm Games was not
Sample 2¶
prompt = "Artificial Intelligence is dangerous to human"
response = generate_response(prompt, model, tokenizer, context_len=context_len)
print("Prompt: {}".format(prompt))
print("Response: {}".format(response))
Prompt: Artificial Intelligence is dangerous to human Response: beings, while in other ways humans can withstand painful mental and physical abuse. = = Controlling = = The brain can control impulses in an action that also affects normal brain activity, including sleep, inanimate objects, and inanimate objects. However, individuals possess various brain characteristics that may serve to combat and carry information about it, such as improvements in physical and mental strength and behavior. Rather than killing innocents, the brain can essentially do so as it does for anyone or objects, but it becomes more common and more efficient when actively defending itself. The brain is usually led by a brain
Going Forward¶
For large language models, pre-training is only the first step towards making something useful. People usually take the pre-trained model and fine-tune it on different tasks to create more specialized models. For example, if we want to build a chatbot, we can fine-tune the model on instructions (also known as instruction tuning) such that the model can chat with humans. Instruction tuning is an emerging field that evolves very quickly.
To conclude, with pre-training, a model understands language and the world. With instruction tuning, a model becomes usable.