From 643c8b95f0dca5e5c906de0822465f67bc3c5bfa Mon Sep 17 00:00:00 2001
From: Sondre Wold
Date: Mon, 17 Jun 2024 16:55:17 +0200
Subject: [PATCH] Updates

---
 reference_transformer/pyproject.toml          |  1 +
 .../reference_transformer/dataset.py          | 19 +++++++-
 .../reference_transformer/tokenizer.py        | 18 ++++++--
 .../reference_transformer/train.py            | 46 +++++++++++++++++++
 reference_transformer/tests/test_dataset.py   | 20 ++++++++
 5 files changed, 99 insertions(+), 5 deletions(-)
 create mode 100644 reference_transformer/reference_transformer/train.py
 create mode 100644 reference_transformer/tests/test_dataset.py

diff --git a/reference_transformer/pyproject.toml b/reference_transformer/pyproject.toml
index 7acc369..f493bf1 100644
--- a/reference_transformer/pyproject.toml
+++ b/reference_transformer/pyproject.toml
@@ -11,6 +11,7 @@ torch = "^2.3.1"
 numpy = "^2.0.0"
 pytest = "^8.2.2"
 pre-commit = "^3.7.1"
+tqdm = "^4.66.4"
 
 
 [build-system]
diff --git a/reference_transformer/reference_transformer/dataset.py b/reference_transformer/reference_transformer/dataset.py
index 2fee5cd..06efbb6 100644
--- a/reference_transformer/reference_transformer/dataset.py
+++ b/reference_transformer/reference_transformer/dataset.py
@@ -1,7 +1,22 @@
-import torch
 from torch.utils.data import Dataset
+import torch
 
 
 class TextDataset(Dataset):
-    def __init__(self, path: str, tokenizer):
+    def __init__(self, data: str, tokenizer, sequence_length: int, device: str = "cpu"):
+        super().__init__()
+        self.token_ids = torch.tensor(tokenizer.encode(data)).to(device)
+        self.sequence_length = sequence_length
+
+    def __len__(self):
+        return len(self.token_ids)
+
+    def get_batch(self, batch_size):
+        idx = torch.randint(len(self) - self.sequence_length, (batch_size,))
+        source_ids = torch.stack(
+            [self.token_ids[i : i + self.sequence_length] for i in idx]
+        )
+        target_ids = torch.stack(
+            [self.token_ids[i + 1 : i + self.sequence_length + 1] for i in idx]
+        )
+        return source_ids, target_ids
diff --git a/reference_transformer/reference_transformer/tokenizer.py b/reference_transformer/reference_transformer/tokenizer.py
index 60e505c..8661033 100644
--- a/reference_transformer/reference_transformer/tokenizer.py
+++ b/reference_transformer/reference_transformer/tokenizer.py
@@ -2,7 +2,16 @@ import logging
 from abc import ABC, abstractmethod
 
 
+UNKNOWN_TOKEN = ""
+UNKNOWN_TOKEN_ID = -1
+
+
 class Tokenizer(ABC):
+
+    @abstractmethod
+    def train(self, text: str):
+        pass
+
     @abstractmethod
     def encode(self, text: str):
         pass
@@ -15,11 +24,14 @@
 class StupidTokenizer(Tokenizer):
     def __init__(self, data: str):
         super().__init__()
+        self.train(data)
+
+    def train(self, data: str):
         self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
         self.i2c = {i: c for c, i in self.c2i.items()}
-
+
     def encode(self, text: str):
-        return [self.c2i[c] for c in text]
+        return [self.c2i.get(c, UNKNOWN_TOKEN_ID) for c in text]
 
     def decode(self, ids: list[int]):
-        return ''.join(self.i2c[i] for i in ids)
+        return "".join(self.i2c.get(i, UNKNOWN_TOKEN) for i in ids)
diff --git a/reference_transformer/reference_transformer/train.py b/reference_transformer/reference_transformer/train.py
new file mode 100644
index 0000000..577176f
--- /dev/null
+++ b/reference_transformer/reference_transformer/train.py
@@ -0,0 +1,46 @@
+import argparse
+import torch
+import numpy as np
+from pathlib import Path
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+from tqdm import tqdm
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        prog="TrainingLoop",
+        description="Trains a specified Transformer model using autoregressive language modeling",
+        epilog="Intended for educational purposes",
+    )
+    parser.add_argument("--train_data_path", type=Path, default=None)
+    parser.add_argument("--validation_data_path", type=Path, default=None)
+    parser.add_argument("--sequence_length", type=int, default=128)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--max_iter", type=int, default=10000)
+    args = parser.parse_args()
+    return args
+
+
+def train(
+    model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer, criterion
+):
+    model.train()
+    for i in tqdm(range(0, max_iter)):
+        optimizer.zero_grad()
+        source, target = train_data.get_batch(batch_size)
+        break
+
+
+def main(args):
+    with open(args.train_data_path, "r", encoding="utf8") as f:
+        raw_train_data = f.read()
+    tokenizer = StupidTokenizer(raw_train_data)
+    train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
+    criterion = torch.nn.CrossEntropyLoss()
+    train(None, train_data, args.max_iter, args.batch_size, None, criterion)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/reference_transformer/tests/test_dataset.py b/reference_transformer/tests/test_dataset.py
new file mode 100644
index 0000000..142b749
--- /dev/null
+++ b/reference_transformer/tests/test_dataset.py
@@ -0,0 +1,20 @@
+import torch
+import pytest
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+
+
+@pytest.fixture
+def test_data():
+    with open("./data/shakespeare.txt", "r", encoding="utf8") as f:
+        data = f.read()
+    return data
+
+
+def test_source_target_align(test_data):
+    tokenizer = StupidTokenizer(test_data)
+    dataset = TextDataset(test_data, tokenizer, 128)
+    source, target = dataset.get_batch(32)
+    for s, t in zip(source, target):
+        for i in range(0, len(t) - 1):
+            assert t[i] == s[i + 1], "Failed to align source and target ids"
-- 
2.39.5