]> git.sondrewold.no Git - reference_transformer.git/commitdiff
Add validation loop main
authorSondre Wold <[email protected]>
Wed, 19 Jun 2024 09:09:51 +0000 (11:09 +0200)
committerSondre Wold <[email protected]>
Wed, 19 Jun 2024 09:09:51 +0000 (11:09 +0200)
reference_transformer/reference_transformer/train.py

index cf7625754721c150ae8d95487b7d8f1d984abc2b..afa92afe6ca205da46c009ff28cd6f39100c3656 100644 (file)
@@ -15,7 +15,7 @@ def parse_args():
         epilog="Intended for educational purpose",
     )
     parser.add_argument("--train_data_path", type=Path, default=None)
-    parser.add_argument("--validation_data_path", type=Path, default=None)
+    parser.add_argument("--val_data_path", type=Path, default=None)
     parser.add_argument("--sequence_length", type=int, default=64)
     parser.add_argument("--hidden_size", type=int, default=256)
     parser.add_argument("--decoder_layers", type=int, default=2)
@@ -28,9 +28,10 @@ def parse_args():
     return args
 
 
-def train(
+def run(
     model,
     train_data: TextDataset,
+    val_data: TextDataset,
     max_iter: int,
     batch_size: int,
     optimizer,
@@ -38,27 +39,45 @@ def train(
     scheduler,
 ):
     model.train()
-    for i in tqdm(range(0, max_iter)):
+    for i in tqdm(range(0, max_iter), leave=False):
         optimizer.zero_grad()
         source, target = train_data.get_batch(batch_size)
         B, T = source.shape
         logits = model(source)
         logits = logits.transpose(-2, -1)
         loss = criterion(logits, target)
-        if i % 100 == 0:
-            print(f"Loss: {loss}")
-
+        if i % 100 == 0 and i != 0:
+            val_loss = val(model, val_data, criterion, batch_size)
+            print(f"-Training loss: {loss}, Validation loss: {val_loss}")
         loss.backward()
         optimizer.step()
         scheduler.step()
 
 
+def val(model, val_data: TextDataset, criterion, batch_size: int):
+    model.eval()
+    val_loss = 0.0
+    val_iters = 100
+    for i in tqdm(range(0, val_iters), leave=False):
+        source, target = val_data.get_batch(batch_size)
+        logits = model(source)
+        logits = logits.transpose(-2, -1)
+        loss = criterion(logits, target)
+        val_loss += loss.item()
+    model.train()
+    return val_loss / val_iters
+
+
 def main(args):
     with open(args.train_data_path, "r", encoding="utf8") as f:
         raw_train_data = f.read()
+    with open(args.val_data_path, "r", encoding="utf8") as f:
+        raw_val_data = f.read()
     tokenizer = StupidTokenizer(raw_train_data)
     max_steps = args.max_iter // args.batch_size
     train_data = TextDataset(raw_train_data, tokenizer, args.sequence_length)
+    val_data = TextDataset(raw_val_data, tokenizer, args.sequence_length)
     model = GPT(
         args.hidden_size,
         tokenizer.vocab_size,
@@ -70,9 +89,10 @@ def main(args):
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_steps)
-    train(
+    run(
         model,
         train_data,
+        val_data,
         args.max_iter,
         args.batch_size,
         optimizer,