]> git.sondrewold.no Git - reference_transformer.git/commitdiff
Add StupidTokenizer and test
authorSondre Wold <[email protected]>
Mon, 17 Jun 2024 11:45:29 +0000 (13:45 +0200)
committerSondre Wold <[email protected]>
Mon, 17 Jun 2024 11:45:29 +0000 (13:45 +0200)
reference_transformer/reference_transformer/__init__.py [new file with mode: 0644]
reference_transformer/reference_transformer/dataset.py [new file with mode: 0644]
reference_transformer/reference_transformer/tokenizer.py [new file with mode: 0644]
reference_transformer/tests/.test_tokenizers.py.swp [new file with mode: 0644]
reference_transformer/tests/__init__.py [new file with mode: 0644]
reference_transformer/tests/test_tokenizers.py [new file with mode: 0644]

diff --git a/reference_transformer/reference_transformer/__init__.py b/reference_transformer/reference_transformer/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/reference_transformer/reference_transformer/dataset.py b/reference_transformer/reference_transformer/dataset.py
new file mode 100644 (file)
index 0000000..2fee5cd
--- /dev/null
@@ -0,0 +1,7 @@
+import torch
+from torch.utils.data import Dataset
+
+
+class TextDataset(Dataset):
+    def __init__(self, path: str, tokenizer):
+
diff --git a/reference_transformer/reference_transformer/tokenizer.py b/reference_transformer/reference_transformer/tokenizer.py
new file mode 100644 (file)
index 0000000..60e505c
--- /dev/null
@@ -0,0 +1,25 @@
+import logging
+from abc import ABC, abstractmethod
+
+
+class Tokenizer(ABC):
+    @abstractmethod
+    def encode(self, text: str):
+        pass
+
+    @abstractmethod
+    def decode(self, ids: list):
+        pass
+
+
+class StupidTokenizer(Tokenizer):
+    def __init__(self, data: str):
+        super().__init__()
+        self.c2i = {c: i for i, c in enumerate(sorted(list(set(data))))}
+        self.i2c = {i: c for c, i in self.c2i.items()}
+        
+    def encode(self, text: str):
+        return [self.c2i[c] for c in text]
+
+    def decode(self, ids: list[int]):
+        return ''.join(self.i2c[i] for i in ids)
diff --git a/reference_transformer/tests/.test_tokenizers.py.swp b/reference_transformer/tests/.test_tokenizers.py.swp
new file mode 100644 (file)
index 0000000..bf60e2e
Binary files /dev/null and b/reference_transformer/tests/.test_tokenizers.py.swp differ
diff --git a/reference_transformer/tests/__init__.py b/reference_transformer/tests/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/reference_transformer/tests/test_tokenizers.py b/reference_transformer/tests/test_tokenizers.py
new file mode 100644 (file)
index 0000000..f66c929
--- /dev/null
@@ -0,0 +1,19 @@
+import pytest
+from reference_transformer.tokenizer import StupidTokenizer
+
+
+
+def test_data():
+    with open("./data/shakespeare.txt", 'r', encoding="utf8") as f:
+        data = f.read()
+    return data
+
+
+def test_stupid_encode_decode(test_data):
+    tokenizer = StupidTokenizer(test_data)
+    encoded = tokenizer.encode(test_data)
+    decoded = tokenizer.decode(encoded)
+    assert test_data == decoded, "EncodeDecode does not reproduce original data..."
+
+