--- /dev/null
+import torch
+from torch.utils.data import Dataset
+
+
+class TextDataset(Dataset):
+ def __init__(self, path: str, tokenizer):
+
--- /dev/null
+import logging
+from abc import ABC, abstractmethod
+
+
class Tokenizer(ABC):
    """Abstract interface for converting text to token ids and back."""

    @abstractmethod
    def encode(self, text: str) -> list[int]:
        """Convert ``text`` into a list of integer token ids."""
        ...

    @abstractmethod
    def decode(self, ids: list[int]) -> str:
        """Convert a list of token ids back into a string."""
        ...


class StupidTokenizer(Tokenizer):
    """Character-level tokenizer.

    Each distinct character seen in the constructor data gets one id,
    assigned in sorted character order, so the mapping is deterministic
    for a given corpus.
    """

    def __init__(self, data: str):
        super().__init__()
        # sorted(set(...)) — no intermediate list() needed; sorted()
        # accepts any iterable and already returns a list.
        self.c2i = {c: i for i, c in enumerate(sorted(set(data)))}
        self.i2c = {i: c for c, i in self.c2i.items()}

    def encode(self, text: str) -> list[int]:
        """Map each character of ``text`` to its id.

        Raises:
            KeyError: if ``text`` contains a character that was not
                present in the data given to the constructor.
        """
        return [self.c2i[c] for c in text]

    def decode(self, ids: list[int]) -> str:
        """Map each id back to its character and join into a string."""
        return ''.join(self.i2c[i] for i in ids)
--- /dev/null
+import pytest
+from reference_transformer.tokenizer import StupidTokenizer
+
+
+
@pytest.fixture
def test_data():
    """Load the Shakespeare corpus used by the tokenizer tests.

    Must be declared as a fixture: without ``@pytest.fixture`` pytest
    both collects this function as a test (warning that a test returns a
    value) and fails the consumer test with "fixture 'test_data' not
    found" when resolving its ``test_data`` argument.
    """
    with open("./data/shakespeare.txt", 'r', encoding="utf8") as f:
        return f.read()
+
+
def test_stupid_encode_decode(test_data):
    """Round-trip check: decode(encode(x)) must reproduce x exactly."""
    tokenizer = StupidTokenizer(test_data)
    roundtripped = tokenizer.decode(tokenizer.encode(test_data))
    assert test_data == roundtripped, "EncodeDecode does not reproduce original data..."
+
+