from abc import ABC, abstractmethod


class Tokenizer(ABC):
    """Abstract interface for text tokenizers.

    Implementations map text to integer token ids (`encode`) and ids back
    to text (`decode`).
    """

    @abstractmethod
    def encode(self, text: str) -> list[int]:
        """Return the list of token ids for *text*."""

    @abstractmethod
    def decode(self, ids: list[int]) -> str:
        """Return the string reconstructed from token *ids*."""


class StupidTokenizer(Tokenizer):
    """Character-level tokenizer whose vocabulary is the characters of *data*.

    Ids are assigned in sorted character order, so the char<->id mapping is
    deterministic for a given corpus. `encode` raises KeyError for characters
    that did not occur in *data*; `decode` raises KeyError for unknown ids.
    """

    def __init__(self, data: str):
        super().__init__()
        # sorted(set(...)) — a set is already iterable, no intermediate list
        # needed; sorting makes the id assignment deterministic.
        self.c2i = {c: i for i, c in enumerate(sorted(set(data)))}
        # Inverse table for decoding.
        self.i2c = {i: c for c, i in self.c2i.items()}

    def encode(self, text: str) -> list[int]:
        """Map each character of *text* to its id (KeyError on unseen chars)."""
        return [self.c2i[c] for c in text]

    def decode(self, ids: list[int]) -> str:
        """Concatenate the characters corresponding to *ids*."""
        return ''.join(self.i2c[i] for i in ids)