git.sondrewold.no Git - reference_transformer.git/commitdiff
Updates
author Sondre Wold <[email protected]>
Mon, 17 Jun 2024 14:55:17 +0000 (16:55 +0200)
committer Sondre Wold <[email protected]>
Mon, 17 Jun 2024 14:55:17 +0000 (16:55 +0200)
reference_transformer/pyproject.toml
reference_transformer/reference_transformer/dataset.py
reference_transformer/reference_transformer/tokenizer.py
reference_transformer/reference_transformer/train.py [new file with mode: 0644]
reference_transformer/tests/test_dataset.py [new file with mode: 0644]

index 7acc369f23f9024b2ee2679bfb5e7d2152fc12ef..f493bf1f5d10def08093337bf65d32d3a792ff1d 100644 (file)
@@ -11,6 +11,7 @@ torch = "^2.3.1"
 numpy = "^2.0.0"
 pytest = "^8.2.2"
 pre-commit = "^3.7.1"
+tqdm = "^4.66.4"
 
 
 [build-system]
index 2fee5cdf3d4b87c132a2b1318d048d43b2ff5482..06efbb607ba23745b2feb4f1c3ce1451a4db041b 100644 (file)
@@ -1,7 +1,22 @@
-import torch
 from torch.utils.data import Dataset
+import torch
 
 
 class TextDataset(Dataset):
-    def __init__(self, path: str, tokenizer):
+    def __init__(self, data: str, tokenizer, sequence_length: int, device: str = "cpu"):
+        super().__init__()
+        self.token_ids = torch.tensor(tokenizer.encode(data)).to(device)
+        self.sequence_length = sequence_length
+
+    def __len__(self):
+        return len(self.token_ids)
 
+    def get_batch(self, batch_size):
+        idx = torch.randint(len(self) - self.sequence_length, (batch_size,))
+        source_ids = torch.stack(
+            [self.token_ids[i : i + self.sequence_length] for i in idx]
+        )
+        target_ids = torch.stack(
+            [self.token_ids[i + 1 : i + self.sequence_length + 1] for i in idx]
+        )
+        return source_ids, target_ids
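For context, a minimal sketch of how the new TextDataset is intended to be used (the example text and variable names below are illustrative, not part of this commit):

import torch
from reference_transformer.tokenizer import StupidTokenizer
from reference_transformer.dataset import TextDataset

text = "to be or not to be, that is the question"
tokenizer = StupidTokenizer(text)
dataset = TextDataset(text, tokenizer, sequence_length=8)
source, target = dataset.get_batch(batch_size=4)  # both of shape (4, 8)
# target is source shifted one token to the right within each sampled window:
assert torch.equal(source[:, 1:], target[:, :-1])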
index 60e505c5c73787799e357362b0001203c7247f11..86610333748c977fb1823dae050544123086283f 100644 (file)
@@ -2,7 +2,16 @@ import logging
 from abc import ABC, abstractmethod
 
 
+UNKNOWN_TOKEN = "<UNK>"
+UNKNOWN_TOKEN_ID = -1
+
+
 class Tokenizer(ABC):
+
+    @abstractmethod
+    def train(self, text: str):
+        pass
+
     @abstractmethod
     def encode(self, text: str):
         pass
@@ -15,11 +24,14 @@ class Tokenizer(ABC):
 class StupidTokenizer(Tokenizer):
     def __init__(self, data: str):
         super().__init__()
+        self.train(data)
+
+    def train(self, data: str):
         self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
         self.i2c = {i: c for c, i in self.c2i.items()}
-        
+
     def encode(self, text: str):
-        return [self.c2i[c] for c in text]
+        return [self.c2i.get(c, UNKNOWN_TOKEN_ID) for c in text]
 
     def decode(self, ids: list[int]):
-        return ''.join(self.i2c[i] for i in ids)
+        return "".join(self.i2c.get(i, UNKNOWN_TOKEN) for i in ids)
diff --git a/reference_transformer/reference_transformer/train.py b/reference_transformer/reference_transformer/train.py
new file mode 100644 (file)
index 0000000..577176f
--- /dev/null
@@ -0,0 +1,46 @@
+import argparse
+import torch
+import numpy as np
+from pathlib import Path
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+from tqdm import tqdm
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        prog="TrainingLoop",
+        description="Trains a specified Transformer model using autoregressive language modeling",
+        epilog="Intended for educational purposes",
+    )
+    parser.add_argument("--train_data_path", type=Path, default=None)
+    parser.add_argument("--validation_data_path", type=Path, default=None)
+    parser.add_argument("--sequence_length", type=int, default=128)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--max_iter", type=int, default=10000)
+    args = parser.parse_args()
+    return args
+
+
+def train(
+    model, train_data: TextDataset, max_iter: int, batch_size: int, optimizer=None, criterion=None
+):
+    # Placeholder loop: the forward pass, loss computation, and optimizer step
+    # are not wired up yet; only batch sampling is exercised.
+    for i in tqdm(range(0, max_iter)):
+        source, target = train_data.get_batch(batch_size)
+        break
+
+
+def main(args):
+    with open(args.train_data_path, "r", encoding="utf8") as f:
+        raw_train_data = f.read()
+    tokenizer = StupidTokenizer(raw_train_data)
+    train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
+    criterion = torch.nn.CrossEntropyLoss()
+    train(None, train_data, args.max_iter, args.batch_size, criterion=criterion)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
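The training loop added in this commit is still a stub: it fetches a single batch and breaks. A hedged sketch of what one optimization step could look like once a model exists (the model, its output shape, and the helper name below are assumptions, not part of this commit):

def train_step(model, train_data, batch_size, optimizer, criterion):
    # Hypothetical single step of autoregressive language-model training.
    optimizer.zero_grad()
    source, target = train_data.get_batch(batch_size)
    logits = model(source)  # assumed shape: (batch, sequence_length, vocab_size)
    loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
    loss.backward()
    optimizer.step()
    return loss.item()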
diff --git a/reference_transformer/tests/test_dataset.py b/reference_transformer/tests/test_dataset.py
new file mode 100644 (file)
index 0000000..142b749
--- /dev/null
@@ -0,0 +1,20 @@
+import torch
+import pytest
+from reference_transformer.tokenizer import StupidTokenizer
+from reference_transformer.dataset import TextDataset
+
+
[email protected]
+def test_data():
+    with open("./data/shakespeare.txt", "r", encoding="utf8") as f:
+        data = f.read()
+    return data
+
+
+def test_source_target_align(test_data):
+    tokenizer = StupidTokenizer(test_data)
+    dataset = TextDataset(test_data, tokenizer, 128)
+    source, target = dataset.get_batch(32)
+    for s, t in zip(source, target):
+        for i in range(0, len(t) - 1):
+            assert t[i] == s[i + 1], "Failed to align source and target ids"
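Assuming pytest is invoked from a directory where the relative path ./data/shakespeare.txt resolves, the new test can be run with:

pytest reference_transformer/tests/test_dataset.py -q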