import logging
from abc import ABC, abstractmethod


class Tokenizer(ABC):
    """Minimal tokenizer contract: map text to integer ids and back."""

    @abstractmethod
    def encode(self, text: str):
        """Convert *text* into a list of integer token ids."""
        pass

    @abstractmethod
    def decode(self, ids: list):
        """Convert a list of integer token ids back into text."""
        pass


class StupidTokenizer(Tokenizer):
    """Character-level tokenizer whose vocabulary is every distinct
    character observed in the construction corpus, id-ordered by the
    characters' sort order."""

    def __init__(self, data: str):
        super().__init__()
        # Deterministic vocabulary: sorted unique characters of the corpus.
        vocab = sorted(set(data))
        self.c2i = {char: idx for idx, char in enumerate(vocab)}
        self.i2c = {idx: char for idx, char in enumerate(vocab)}

    def encode(self, text: str):
        """Map each character of *text* to its id (KeyError on unseen chars)."""
        lookup = self.c2i
        return [lookup[char] for char in text]

    def decode(self, ids: list[int]):
        """Map each id in *ids* back to its character and join into a string."""
        lookup = self.i2c
        return ''.join([lookup[token_id] for token_id in ids])
import pytest
from reference_transformer.tokenizer import StupidTokenizer


@pytest.fixture
def test_data():
    """Load the Shakespeare corpus used as the tokenizer fixture text."""
    with open("./data/shakespeare.txt", 'r', encoding="utf8") as corpus:
        return corpus.read()


def test_stupid_encode_decode(test_data):
    """encode followed by decode must be the identity on the corpus."""
    tokenizer = StupidTokenizer(test_data)
    round_tripped = tokenizer.decode(tokenizer.encode(test_data))
    assert test_data == round_tripped, "EncodeDecode does not reproduce original data..."