From 9fb152583e8728872d37bbf992204ac8a6097d69 Mon Sep 17 00:00:00 2001
From: Sondre Wold
Date: Wed, 19 Jun 2024 11:09:51 +0200
Subject: [PATCH] Add validation loop

---
 .../reference_transformer/train.py | 34 +++++++++++++++----
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/reference_transformer/reference_transformer/train.py b/reference_transformer/reference_transformer/train.py
index cf76257..afa92af 100644
--- a/reference_transformer/reference_transformer/train.py
+++ b/reference_transformer/reference_transformer/train.py
@@ -15,7 +15,7 @@ def parse_args():
         epilog="Intended for educational purpose",
     )
     parser.add_argument("--train_data_path", type=Path, default=None)
-    parser.add_argument("--validation_data_path", type=Path, default=None)
+    parser.add_argument("--val_data_path", type=Path, default=None)
     parser.add_argument("--sequence_length", type=int, default=64)
     parser.add_argument("--hidden_size", type=int, default=256)
     parser.add_argument("--decoder_layers", type=int, default=2)
@@ -28,9 +28,10 @@ def parse_args():
     return args
 
 
-def train(
+def run(
     model,
     train_data: TextDataset,
+    val_data: TextDataset,
     max_iter: int,
     batch_size: int,
     optimizer,
@@ -38,27 +39,45 @@
     scheduler,
 ):
     model.train()
-    for i in tqdm(range(0, max_iter)):
+    for i in tqdm(range(0, max_iter), leave=False):
         optimizer.zero_grad()
         source, target = train_data.get_batch(batch_size)
         B, T = source.shape
         logits = model(source)
         logits = logits.transpose(-2, -1)
         loss = criterion(logits, target)
-        if i % 100 == 0:
-            print(f"Loss: {loss}")
-
+        if i % 100 == 0 and i != 0:
+            val_loss = val(model, val_data, criterion, batch_size)
+            print(f"-Training loss: {loss}, Validation loss: {val_loss}")
         loss.backward()
         optimizer.step()
         scheduler.step()
 
 
+@torch.no_grad()
+def val(model, val_data: TextDataset, criterion, batch_size: int):
+    model.eval()
+    val_loss = 0.0
+    val_iters = 100
+    for i in tqdm(range(0, val_iters), leave=False):
+        source, target = val_data.get_batch(batch_size)
+        logits = model(source)
+        logits = logits.transpose(-2, -1)
+        loss = criterion(logits, target)
+        val_loss += loss.item()
+    model.train()
+    return val_loss / val_iters
+
+
 def main(args):
     with open(args.train_data_path, "r", encoding="utf8") as f:
         raw_train_data = f.read()
+    with open(args.val_data_path, "r", encoding="utf8") as f:
+        raw_val_data = f.read()
     tokenizer = StupidTokenizer(raw_train_data)
     max_steps = args.max_iter // args.batch_size
     train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
+    val_data = TextDataset(raw_val_data, tokenizer, args.sequence_length)
     model = GPT(
         args.hidden_size,
         tokenizer.vocab_size,
@@ -70,9 +89,10 @@ def main(args):
     criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_steps)
-    train(
+    run(
        model,
        train_data,
+       val_data,
        args.max_iter,
        args.batch_size,
        optimizer,
--
2.39.5
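
Note (not part of the commit): both the training loop in run() and the new val() helper transpose the logits before passing them to CrossEntropyLoss. A minimal, self-contained sketch of why, assuming the model returns scores of shape (batch, time, vocab) as the transpose suggests; the sizes below are made up for illustration:

    import torch

    B, T, V = 2, 8, 50                      # illustrative batch size, sequence length, vocab size
    logits = torch.randn(B, T, V)           # stand-in for model(source) in train.py
    target = torch.randint(0, V, (B, T))    # stand-in for the next-token targets
    criterion = torch.nn.CrossEntropyLoss()

    # CrossEntropyLoss expects the class dimension second, i.e. (B, V, T) scores
    # against (B, T) integer targets, which is exactly what
    # logits.transpose(-2, -1) produces in run() and val().
    loss = criterion(logits.transpose(-2, -1), target)
    print(loss.item())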