numpy = "^2.0.0"
pytest = "^8.2.2"
pre-commit = "^3.7.1"
+tqdm = "^4.66.4"
[build-system]
-import torch
from torch.utils.data import Dataset
+import torch
class TextDataset(Dataset):
- def __init__(self, path: str, tokenizer):
+ def __init__(self, data: str, tokenizer, sequence_length: int, device: str = "cpu"):
+ super().__init__()
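+        # Encode the full text once up front and keep the token ids on the requested device.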
+ self.token_ids = torch.tensor(tokenizer.encode(data)).to(device)
+ self.sequence_length = sequence_length
+
+ def __len__(self):
+        return len(self.token_ids)
+
+ def get_batch(self, batch_size):
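+        # Sample random window start offsets; each target sequence is the
+        # corresponding source sequence shifted one token to the right.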
+ idx = torch.randint(len(self) - self.sequence_length, (batch_size,))
+ source_ids = torch.stack(
+ [self.token_ids[i : i + self.sequence_length] for i in idx]
+ )
+ target_ids = torch.stack(
+ [self.token_ids[i + 1 : i + self.sequence_length + 1] for i in idx]
+ )
+ return source_ids, target_ids
from abc import ABC, abstractmethod
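+
+# Fallbacks returned by encode/decode for characters or ids outside the learned vocabulary.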
+UNKNOWN_TOKEN = "<UNK>"
+UNKNOWN_TOKEN_ID = -1
+
+
class Tokenizer(ABC):
+
+ @abstractmethod
+ def train(self, text: str):
+ pass
+
@abstractmethod
def encode(self, text: str):
pass
class StupidTokenizer(Tokenizer):
def __init__(self, data: str):
super().__init__()
+ self.train(data)
+
+ def train(self, data: str):
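+        # Build the character-to-id vocabulary from the unique characters in the data.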
self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
self.i2c = {i: c for c, i in self.c2i.items()}
-
+
def encode(self, text: str):
- return [self.c2i[c] for c in text]
+ return [self.c2i.get(c, UNKNOWN_TOKEN_ID) for c in text]
def decode(self, ids: list[int]):
- return ''.join(self.i2c[i] for i in ids)
+ return "".join(self.i2c.get(i, UNKNOWN_TOKEN) for i in ids)
--- /dev/null
+import argparse
+import torch
+import numpy as np
+from pathlib import Path
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+from tqdm import tqdm
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+        prog="TrainingLoop",
+ description="Trains a specified Transformer model using autoregressive language modeling",
+ epilog="Intended for educational purpose",
+ )
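+    # Paths to the raw text files plus the basic training hyperparameters.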
+    parser.add_argument("--train_data_path", type=Path, required=True)
+ parser.add_argument("--validation_data_path", type=Path, default=None)
+ parser.add_argument("--sequence_length", type=int, default=128)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--max_iter", type=int, default=10000)
+ args = parser.parse_args()
+ return args
+
+
+def train(
+    model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer, criterion
+):
+    # Training-loop skeleton: the model and optimizer are not wired in yet,
+    # so the loop only fetches a single batch and exits early.
+    if model is not None:
+        model.train()
+    for _ in tqdm(range(max_iter)):
+        if optimizer is not None:
+            optimizer.zero_grad()
+        source, target = train_data.get_batch(batch_size)
+        # TODO: forward pass, compute loss with criterion, backward pass, optimizer step
+        break
+
+
+def main(args):
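+    # Read the raw training text, fit the character tokenizer on it, wrap it in a
+    # TextDataset, and run the (still skeletal) training loop.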
+ with open(args.train_data_path, "r", encoding="utf8") as f:
+ raw_train_data = f.read()
+ tokenizer = StupidTokenizer(raw_train_data)
+ train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
+    criterion = torch.nn.CrossEntropyLoss()
+    train(None, train_data, args.max_iter, args.batch_size, None, criterion)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
--- /dev/null
+import torch
+import pytest
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+
+
+@pytest.fixture
+def test_data():
+ with open("./data/shakespeare.txt", "r", encoding="utf8") as f:
+ data = f.read()
+ return data
+
+
+def test_source_target_align(test_data):
+ tokenizer = StupidTokenizer(test_data)
+ dataset = TextDataset(test_data, tokenizer, 128)
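+    # Draw one batch and check that every target sequence is its source shifted by one token.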
+ source, target = dataset.get_batch(32)
+ for s, t in zip(source, target):
+        for i in range(len(s) - 1):
+ assert t[i] == s[i + 1], "Failed to align source and target ids"