--- /dev/null
+import einops
+import math
+import torch
+import torch.nn as nn
+
+
+class GPT(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ vocab_size: int,
+ num_layers: int,
+ heads: int,
+ dropout: float,
+ sequence_length: int,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.heads = heads
+ self.sequence_length = sequence_length
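+ # Token embeddings plus learned positional embeddings feed a stack of decoder blocks and a linear LM head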
+ self.embedding = nn.Embedding(vocab_size, self.hidden_size)
+ self.pos_embedding = nn.Embedding(sequence_length, self.hidden_size)
+ self.layers = nn.Sequential(
+ *[
+ DecoderBlock(self.hidden_size, self.heads, self.dropout)
+ for _ in range(self.num_layers)
+ ]
+ )
+ self.dropout = nn.Dropout(self.dropout)
+ self.head = nn.Linear(self.hidden_size, self.vocab_size)
+
+ def forward(self, x):
+ B, T = x.shape
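+ # x holds token indices of shape (B, T); embed them and add a positional embedding for each position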
+ x = self.embedding(x)
+ x = x + self.pos_embedding(torch.arange(T, device=x.device))
+ x = self.dropout(x)
+ x = self.layers(x)
+ x = self.head(x)
+ return x
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, hidden_size: int, heads: int, dropout: float):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.heads = heads
+ self.layer_norm1 = nn.LayerNorm(self.hidden_size)
+ self.layer_norm2 = nn.LayerNorm(self.hidden_size)
+ self.dropout = nn.Dropout(dropout)
+ self.ff = FeedForward(self.hidden_size, dropout)
+ self.mha = MultiHeadedAttention(self.hidden_size, self.heads)
+
+ def forward(self, x):
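+ # Pre-norm residual blocks: LayerNorm feeds each sub-layer and the sub-layer output is added back to the stream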
+ x = x + self.dropout(self.mha(self.layer_norm1(x)))
+ x = x + self.dropout(self.ff(self.layer_norm2(x)))
+ return x
+
+
+class MultiHeadedAttention(nn.Module):
+ def __init__(self, hidden_size: int, heads: int):
+ super().__init__()
+ assert hidden_size % heads == 0, "Wrong number of heads given the hidden size"
+ self.hidden_size = hidden_size
+ self.heads = heads
+ self.query = nn.Linear(hidden_size, hidden_size)
+ self.key = nn.Linear(hidden_size, hidden_size)
+ self.value = nn.Linear(hidden_size, hidden_size)
+ self.output = nn.Linear(hidden_size, hidden_size)
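+ # Scores are scaled by 1/sqrt(head_dim) to keep the softmax logits well-conditioned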
+ self.scale = 1 / math.sqrt(hidden_size // heads)
+
+ def forward(self, x):
+ B, T, C = x.shape
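+ # Project every token into query, key, and value vectors (each still B, T, C)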
+ queries = self.query(x)
+ keys = self.key(x)
+ values = self.value(x) # B, T, C
+
+ # Reshape to multi-headed
+ queries = queries.view(B, T, self.heads, self.hidden_size // self.heads)
+ keys = keys.view(B, T, self.heads, self.hidden_size // self.heads)
+ values = values.view(B, T, self.heads, self.hidden_size // self.heads)
+
+ # Calculate query-key product: each token in query against each token in key
+ attention_weights = torch.einsum(
+ "bqhc,bkhc->bhqk", queries, keys
+ ) # B, heads, T, T
+ attention_weights = attention_weights * self.scale
+ # Causal mask: positions after the query token get -inf so softmax assigns them zero weight
+ autoregressive_mask = torch.tril(torch.ones(T, T, device=x.device))
+ attention_weights = attention_weights.masked_fill(
+ autoregressive_mask == 0.0, float("-inf")
+ )
+ attention_weights = torch.softmax(attention_weights, dim=-1)
+
+ # Multiply with values
+ scaled_values = torch.einsum("bhqk,bkhc->bqhc", attention_weights, values)
+
+ # Concat heads back into a single hidden dimension
+ flattened = einops.rearrange(scaled_values, "b q h c -> b q (h c)")
+ output = self.output(flattened)
+ return output
+
+
+class FeedForward(nn.Module):
+ def __init__(self, hidden_size: int, dropout: float):
+ super().__init__()
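+ # Position-wise MLP: expand to 4x the hidden size, apply ReLU, project back, then dropout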
+ self.layers = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.ReLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
from pathlib import Path
from reference_transformer.tokenizer import StupidTokenizer
from reference_transformer.dataset import TextDataset
+from reference_transformer.gpt import GPT
from tqdm import tqdm
)
parser.add_argument("--train_data_path", type=Path, default=None)
parser.add_argument("--validation_data_path", type=Path, default=None)
- parser.add_argument("--sequence_length", type=int, default=128)
- parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--sequence_length", type=int, default=64)
+ parser.add_argument("--hidden_size", type=int, default=256)
+ parser.add_argument("--decoder_layers", type=int, default=2)
+ parser.add_argument("--heads", type=int, default=4)
+ parser.add_argument("--dropout", type=float, default=0.1)
+ parser.add_argument("--lr", type=float, default=0.01)
+ parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--max_iter", type=int, default=10000)
args = parser.parse_args()
return args
def train(
- model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer, criterion
+ model,
+ train_data: TextDataset,
+ max_iter: int,
+ batch_size: int,
+ optimizer,
+ criterion,
+ scheduler,
):
model.train()
for i in tqdm(range(0, max_iter)):
optimizer.zero_grad()
source, target = train_data.get_batch(batch_size)
- break
+ B, T = source.shape
+ logits = model(source)
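+ # CrossEntropyLoss expects logits as (B, vocab, T) against integer targets of shape (B, T)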
+ logits = logits.transpose(-2, -1)
+ loss = criterion(logits, target)
+ if i % 100 == 0:
+ print(f"Step {i}: loss {loss.item():.4f}")
+
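+ # Backpropagate, update the weights, and advance the learning-rate schedule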
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
def main(args):
with open(args.train_data_path, "r", encoding="utf8") as f:
raw_train_data = f.read()
tokenizer = StupidTokenizer(raw_train_data)
+ # The scheduler is stepped once per training iteration, so anneal over max_iter steps
+ max_steps = args.max_iter
train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
- criterion = torch.nn.CrossEntropy()
- train(None, train_data, args.max_iter, args.batch_size)
+ model = GPT(
+ args.hidden_size,
+ tokenizer.vocab_size,
+ args.decoder_layers,
+ args.heads,
+ args.dropout,
+ args.sequence_length,
+ )
+ criterion = torch.nn.CrossEntropyLoss()
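+ # AdamW optimizer with cosine annealing of the learning rate over the full run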
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_steps)
+ train(
+ model,
+ train_data,
+ args.max_iter,
+ args.batch_size,
+ optimizer,
+ criterion,
+ scheduler,
+ )
if __name__ == "__main__":