From: Sondre Wold
Date: Tue, 18 Jun 2024 07:50:38 +0000 (+0200)
Subject: Add GPT-style implementation
X-Git-Url: https://letsjmore.com/?a=commitdiff_plain;h=3d3f61b17f9cb6187bc67e5ca8c54128cd999c69;p=reference_transformer.git

Add GPT-style implementation
---

diff --git a/reference_transformer/pyproject.toml b/reference_transformer/pyproject.toml
index f493bf1..fe48bac 100644
--- a/reference_transformer/pyproject.toml
+++ b/reference_transformer/pyproject.toml
@@ -12,6 +12,7 @@
 numpy = "^2.0.0"
 pytest = "^8.2.2"
 pre-commit = "^3.7.1"
 tqdm = "^4.66.4"
+einops = "^0.8.0"
 
 [build-system]
diff --git a/reference_transformer/reference_transformer/gpt.py b/reference_transformer/reference_transformer/gpt.py
new file mode 100644
index 0000000..86919de
--- /dev/null
+++ b/reference_transformer/reference_transformer/gpt.py
@@ -0,0 +1,117 @@
+import einops
+import math
+import torch
+import torch.nn as nn
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        vocab_size: int,
+        num_layers: int,
+        heads: int,
+        dropout: float,
+        sequence_length: int,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.heads = heads
+        self.sequence_length = sequence_length
+        self.embedding = nn.Embedding(vocab_size, self.hidden_size)
+        self.pos_embedding = nn.Embedding(sequence_length, self.hidden_size)
+        self.layers = nn.Sequential(
+            *[
+                DecoderBlock(self.hidden_size, self.heads, self.dropout)
+                for _ in range(self.num_layers)
+            ]
+        )
+        self.dropout = nn.Dropout(self.dropout)
+        self.head = nn.Linear(self.hidden_size, self.vocab_size)
+
+    def forward(self, x):
+        B, T = x.shape
+        x = self.embedding(x)
+        x = x + self.pos_embedding(torch.arange(T, device=x.device))
+        x = self.dropout(x)
+        x = self.layers(x)
+        x = self.head(x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, hidden_size: int, heads: int, dropout: float):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.heads = heads
+        self.layer_norm1 = nn.LayerNorm(self.hidden_size)
+        self.layer_norm2 = nn.LayerNorm(self.hidden_size)
+        self.dropout = nn.Dropout(dropout)
+        self.ff = FeedForward(self.hidden_size, dropout)
+        self.mha = MultiHeadedAttention(self.hidden_size, self.heads)
+
+    def forward(self, x):
+        # Pre-norm residual connections around attention and feed-forward
+        x = x + self.dropout(self.mha(self.layer_norm1(x)))
+        x = x + self.dropout(self.ff(self.layer_norm2(x)))
+        return x
+
+
+class MultiHeadedAttention(nn.Module):
+    def __init__(self, hidden_size: int, heads: int):
+        super().__init__()
+        assert hidden_size % heads == 0, "hidden_size must be divisible by heads"
+        self.hidden_size = hidden_size
+        self.heads = heads
+        self.query = nn.Linear(hidden_size, hidden_size)
+        self.key = nn.Linear(hidden_size, hidden_size)
+        self.value = nn.Linear(hidden_size, hidden_size)
+        self.output = nn.Linear(hidden_size, hidden_size)
+        self.scale = 1 / math.sqrt(hidden_size // heads)
+
+    def forward(self, x):
+        B, T, C = x.shape
+        queries = self.query(x)
+        keys = self.key(x)
+        values = self.value(x)  # B, T, C
+
+        # Reshape to multi-headed
+        queries = queries.view(B, T, self.heads, self.hidden_size // self.heads)
+        keys = keys.view(B, T, self.heads, self.hidden_size // self.heads)
+        values = values.view(B, T, self.heads, self.hidden_size // self.heads)
+
+        # Query-key product: each query token against each key token
+        attention_weights = torch.einsum(
+            "bqhc,bkhc->bhqk", queries, keys
+        )  # B, heads, T, T
+        autoregressive_mask = torch.tril(torch.ones((T, T), device=x.device))
+        attention_weights = attention_weights.masked_fill(
+            autoregressive_mask == 0.0, float("-inf")
+        )
+        attention_weights = attention_weights * self.scale
+        attention_weights = torch.softmax(attention_weights, dim=-1)
+
+        # Weighted sum of the values for each query position
+        scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values)
+
+        # Concat heads
+        flattened = einops.rearrange(scaled_values, "b q h c -> b q (h c)")
+        output = self.output(flattened)
+        return output
+
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_size: int, dropout: float):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size * 4),
+            nn.ReLU(),
+            nn.Linear(hidden_size * 4, hidden_size),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
diff --git a/reference_transformer/reference_transformer/tokenizer.py b/reference_transformer/reference_transformer/tokenizer.py
index 8661033..091bbae 100644
--- a/reference_transformer/reference_transformer/tokenizer.py
+++ b/reference_transformer/reference_transformer/tokenizer.py
@@ -24,10 +24,12 @@ class Tokenizer(ABC):
 class StupidTokenizer(Tokenizer):
     def __init__(self, data: str):
         super().__init__()
+        self.symbols = sorted(list(set(data)))
+        self.vocab_size = len(self.symbols)
         self.train(data)
 
     def train(self, data: str):
-        self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
+        self.c2i = {c: i for i, c in enumerate(self.symbols)}
         self.i2c = {i: c for c, i in self.c2i.items()}
 
     def encode(self, text: str):
diff --git a/reference_transformer/reference_transformer/train.py b/reference_transformer/reference_transformer/train.py
index 577176f..cf76257 100644
--- a/reference_transformer/reference_transformer/train.py
+++ b/reference_transformer/reference_transformer/train.py
@@ -4,6 +4,7 @@ import numpy as np
 from pathlib import Path
 from reference_transformer.tokenizer import StupidTokenizer
 from reference_transformer.dataset import TextDataset
+from reference_transformer.gpt import GPT
 from tqdm import tqdm
 
 
@@ -15,30 +16,69 @@ def parse_args():
     )
     parser.add_argument("--train_data_path", type=Path, default=None)
     parser.add_argument("--validation_data_path", type=Path, default=None)
-    parser.add_argument("--sequence_length", type=int, default=128)
-    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--sequence_length", type=int, default=64)
+    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--decoder_layers", type=int, default=2)
+    parser.add_argument("--heads", type=int, default=4)
+    parser.add_argument("--dropout", type=float, default=0.1)
+    parser.add_argument("--lr", type=float, default=0.01)
+    parser.add_argument("--batch_size", type=int, default=16)
     parser.add_argument("--max_iter", type=int, default=10000)
     args = parser.parse_args()
     return args
 
 
 def train(
-    model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer, criterion
+    model,
+    train_data: TextDataset,
+    max_iter: int,
+    batch_size: int,
+    optimizer,
+    criterion,
+    scheduler,
 ):
     model.train()
     for i in tqdm(range(0, max_iter)):
         optimizer.zero_grad()
         source, target = train_data.get_batch(batch_size)
-        break
+        B, T = source.shape
+        logits = model(source)
+        logits = logits.transpose(-2, -1)
+        loss = criterion(logits, target)
+        if i % 100 == 0:
+            print(f"Loss: {loss}")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
 
 
 def main(args):
     with open(args.train_data_path, "r", encoding="utf8") as f:
         raw_train_data = f.read()
     tokenizer = StupidTokenizer(raw_train_data)
+    max_steps = args.max_iter // args.batch_size
     train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
-    criterion = torch.nn.CrossEntropy()
-    train(None, train_data, args.max_iter, args.batch_size)
+    model = GPT(
+        args.hidden_size,
+        tokenizer.vocab_size,
+        args.decoder_layers,
+        args.heads,
+        args.dropout,
+        args.sequence_length,
+    )
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_steps)
+    train(
+        model,
+        train_data,
+        args.max_iter,
+        args.batch_size,
+        optimizer,
+        criterion,
+        scheduler,
+    )
 
 
 if __name__ == "__main__":