git.sondrewold.no Git - reference_transformer.git/commitdiff
Add GPT-style implementation
author Sondre Wold <[email protected]>
Tue, 18 Jun 2024 07:50:38 +0000 (09:50 +0200)
committer Sondre Wold <[email protected]>
Tue, 18 Jun 2024 07:50:38 +0000 (09:50 +0200)
reference_transformer/pyproject.toml
reference_transformer/reference_transformer/gpt.py [new file with mode: 0644]
reference_transformer/reference_transformer/tokenizer.py
reference_transformer/reference_transformer/train.py

index f493bf1f5d10def08093337bf65d32d3a792ff1d..fe48bac199f10ff20acda2b5c7c56f85982e6048 100644 (file)
@@ -12,6 +12,7 @@ numpy = "^2.0.0"
 pytest = "^8.2.2"
 pre-commit = "^3.7.1"
 tqdm = "^4.66.4"
+einops = "^0.8.0"
 
 
 [build-system]
diff --git a/reference_transformer/reference_transformer/gpt.py b/reference_transformer/reference_transformer/gpt.py
new file mode 100644 (file)
index 0000000..86919de
--- /dev/null
@@ -0,0 +1,117 @@
+import einops
+import math
+import torch
+import torch.nn as nn
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        vocab_size: int,
+        num_layers: int,
+        heads: int,
+        dropout: float,
+        sequence_length: int,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.num_layers = num_layers
+        self.dropout_p = dropout
+        self.heads = heads
+        self.sequence_length = sequence_length
+        self.embedding = nn.Embedding(vocab_size, self.hidden_size)
+        self.pos_embedding = nn.Embedding(sequence_length, self.hidden_size)
+        self.layers = nn.Sequential(
+            *[
+                DecoderBlock(self.hidden_size, self.heads, self.dropout_p)
+                for _ in range(self.num_layers)
+            ]
+        )
+        self.dropout = nn.Dropout(self.dropout_p)
+        self.head = nn.Linear(self.hidden_size, self.vocab_size)
+
+    def forward(self, x):
+        B, T = x.shape
+        x = self.embedding(x)
+        x = x + self.pos_embedding(torch.arange(T, device=x.device))
+        x = self.dropout(x)
+        x = self.layers(x)
+        x = self.head(x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, hidden_size: int, heads: int, dropout: float):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.heads = heads
+        self.layer_norm1 = nn.LayerNorm(self.hidden_size)
+        self.layer_norm2 = nn.LayerNorm(self.hidden_size)
+        self.dropout = nn.Dropout(dropout)
+        self.ff = FeedForward(self.hidden_size, dropout)
+        self.mha = MultiHeadedAttention(self.hidden_size, self.heads)
+
+    def forward(self, x):
+        x = x + self.dropout(self.mha(self.layer_norm1(x)))
+        x = x + self.dropout(self.ff(self.layer_norm2(x)))
+        return x
+
+
+class MultiHeadedAttention(nn.Module):
+    def __init__(self, hidden_size: int, heads: int):
+        super().__init__()
+        assert hidden_size % heads == 0, "hidden_size must be divisible by the number of heads"
+        self.hidden_size = hidden_size
+        self.heads = heads
+        self.query = nn.Linear(hidden_size, hidden_size)
+        self.key = nn.Linear(hidden_size, hidden_size)
+        self.value = nn.Linear(hidden_size, hidden_size)
+        self.output = nn.Linear(hidden_size, hidden_size)
+        self.scale = 1 / math.sqrt(hidden_size // heads)
+
+    def forward(self, x):
+        B, T, C = x.shape
+        queries = self.query(x)
+        keys = self.key(x)
+        values = self.value(x)  # B, T, C
+
+        # Reshape to multi-headed
+        queries = queries.view(B, T, self.heads, self.hidden_size // self.heads)
+        keys = keys.view(B, T, self.heads, self.hidden_size // self.heads)
+        values = values.view(B, T, self.heads, self.hidden_size // self.heads)
+
+        # Calculate query-key product: each token in query against each token in key
+        attention_weights = torch.einsum(
+            "bqhc,bkhc->bhqk", queries, keys
+        )  # B, heads, T, T
+        attention_weights = attention_weights * self.scale
+        # Causal mask: each position may only attend to itself and earlier positions
+        autoregressive_mask = torch.tril(torch.ones((T, T), device=x.device))
+        attention_weights = attention_weights.masked_fill(
+            autoregressive_mask == 0.0, float("-inf")
+        )
+        attention_weights = torch.softmax(attention_weights, dim=-1)
+
+        # Multiply with values: weight each value by the attention over the keys
+        scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values)
+
+        # Concat heads
+        flattened = einops.rearrange(scaled_values, "b q h c -> b q (h c)")
+        output = self.output(flattened)
+        return output
+
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_size: int, dropout: float):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size * 4),
+            nn.ReLU(),
+            nn.Linear(hidden_size * 4, hidden_size),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
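
A quick way to sanity-check the new module is a single forward pass on random token ids. The snippet below is a minimal sketch with illustrative toy sizes (they are not values taken from this commit); it only verifies that the output has shape (batch, sequence_length, vocab_size):

    import torch
    from reference_transformer.gpt import GPT

    # Toy configuration: hidden_size=32, vocab_size=50, num_layers=2, heads=4,
    # dropout=0.1, sequence_length=16 (illustrative values only)
    model = GPT(32, 50, 2, 4, 0.1, 16)
    x = torch.randint(0, 50, (8, 16))  # (batch, sequence_length) of token ids
    logits = model(x)
    assert logits.shape == (8, 16, 50)  # (batch, sequence_length, vocab_size)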
index 86610333748c977fb1823dae050544123086283f..091bbae90fff6d74e4ab0d06ce9fd87ba81d9eca 100644 (file)
@@ -24,10 +24,12 @@ class Tokenizer(ABC):
 class StupidTokenizer(Tokenizer):
     def __init__(self, data: str):
         super().__init__()
+        self.symbols = sorted(list(set(data)))
+        self.vocab_size = len(self.symbols)
         self.train(data)
 
     def train(self, data: str):
-        self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
+        self.c2i = {c: i for i, c in enumerate(self.symbols)}
         self.i2c = {i: c for c, i in self.c2i.items()}
 
     def encode(self, text: str):
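
Since the tokenizer is character-level, vocab_size is simply the number of distinct characters in the training text. A usage sketch, assuming decode is the inverse of encode via i2c (decode is not shown in this hunk):

    from reference_transformer.tokenizer import StupidTokenizer

    tokenizer = StupidTokenizer("hello world")
    print(tokenizer.vocab_size)      # 8 distinct characters: ' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w'
    ids = tokenizer.encode("hello")  # integer ids via c2i
    print(tokenizer.decode(ids))     # assumed inverse via i2c -> "hello"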
index 577176f7eb1df5b0076bb8391db95a46261d9295..cf7625754721c150ae8d95487b7d8f1d984abc2b 100644 (file)
@@ -4,6 +4,7 @@ import numpy as np
 from pathlib import Path
 from reference_transformer.tokenizer import StupidTokenizer
 from reference_transformer.dataset import TextDataset
+from reference_transformer.gpt import GPT
 from tqdm import tqdm
 
 
@@ -15,30 +16,69 @@ def parse_args():
     )
     parser.add_argument("--train_data_path", type=Path, default=None)
     parser.add_argument("--validation_data_path", type=Path, default=None)
-    parser.add_argument("--sequence_length", type=int, default=128)
-    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--sequence_length", type=int, default=64)
+    parser.add_argument("--hidden_size", type=int, default=256)
+    parser.add_argument("--decoder_layers", type=int, default=2)
+    parser.add_argument("--heads", type=int, default=4)
+    parser.add_argument("--dropout", type=float, default=0.1)
+    parser.add_argument("--lr", type=float, default=0.01)
+    parser.add_argument("--batch_size", type=int, default=16)
     parser.add_argument("--max_iter", type=int, default=10000)
     args = parser.parse_args()
     return args
 
 
 def train(
-    model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer, criterion
+    model,
+    train_data: TextDataset,
+    max_iter: int,
+    batch_size: int,
+    optimizer,
+    criterion,
+    scheduler,
 ):
     model.train()
     for i in tqdm(range(0, max_iter)):
         optimizer.zero_grad()
         source, target = train_data.get_batch(batch_size)
-        break
+        logits = model(source)  # (B, T, vocab_size)
+        # CrossEntropyLoss expects class scores in dim 1: (B, vocab_size, T)
+        logits = logits.transpose(-2, -1)
+        loss = criterion(logits, target)
+        if i % 100 == 0:
+            print(f"Loss: {loss}")
+
+        loss.backward()
+        optimizer.step()
+        scheduler.step()
 
 
 def main(args):
     with open(args.train_data_path, "r", encoding="utf8") as f:
         raw_train_data = f.read()
     tokenizer = StupidTokenizer(raw_train_data)
+    max_steps = args.max_iter  # the scheduler is stepped once per training iteration
     train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
-    criterion = torch.nn.CrossEntropy()
-    train(None, train_data, args.max_iter, args.batch_size)
+    model = GPT(
+        args.hidden_size,
+        tokenizer.vocab_size,
+        args.decoder_layers,
+        args.heads,
+        args.dropout,
+        args.sequence_length,
+    )
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_steps)
+    train(
+        model,
+        train_data,
+        args.max_iter,
+        args.batch_size,
+        optimizer,
+        criterion,
+        scheduler,
+    )
 
 
 if __name__ == "__main__":
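
With these defaults, training can be started directly from the command line; a hypothetical invocation (the data path is illustrative):

    python reference_transformer/reference_transformer/train.py \
        --train_data_path data/tiny_corpus.txt \
        --sequence_length 64 --hidden_size 256 --decoder_layers 2 \
        --heads 4 --dropout 0.1 --lr 0.01 --batch_size 16 --max_iter 10000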