]> git.sondrewold.no Git - complexity-regularizer.git/commitdiff
[train] add train loop
authorSondre Wold <[email protected]>
Sat, 18 Oct 2025 07:27:37 +0000 (09:27 +0200)
committerSondre Wold <[email protected]>
Sat, 18 Oct 2025 07:27:37 +0000 (09:27 +0200)
complexity_regularizer/models.py
complexity_regularizer/train.py [new file with mode: 0644]
pyproject.toml
uv.lock

index 15bd52f0e53e01ce4160b76d4789f88395cb6280..8b12ac69d34162726af1381ce223f0534dc4d0cd 100644 (file)
@@ -8,7 +8,7 @@ class LinearRegressionModel(nn.Module):
         self.n_features = n_features
         self.weights = nn.Parameter(torch.empty((1, n_features), requires_grad=True))
         if bias:
-            self.bias = nn.Parameter(torch.empty((1), requires_grad=True))
+            self.bias = nn.Parameter(torch.empty(1, requires_grad=True))
         else:
             self.register_parameter("bias", None)
         self.initialize_parameters()
@@ -16,7 +16,7 @@ class LinearRegressionModel(nn.Module):
     def initialize_parameters(self) -> None:
         nn.init.kaiming_uniform_(self.weights)
         if self.bias is not None:
-            nn.init.uniform_(self.bias)
+            nn.init.zeros_(self.bias)
 
     def forward(self, x) -> torch.Tensor:
         return x @ self.weights.T + self.bias
diff --git a/complexity_regularizer/train.py b/complexity_regularizer/train.py
new file mode 100644 (file)
index 0000000..54b6fe9
--- /dev/null
@@ -0,0 +1,86 @@
+from complexity_regularizer.models import LinearRegressionModel
+from torch.utils.data.dataset import TensorDataset
+import torch
+import argparse
+from dataclasses import dataclass
+from complexity_regularizer.dataset_builder import create_regression_data
+from torch.utils.data import random_split, DataLoader
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters")
+    parser.add_argument("--n_samples", type=int, default=100)
+    parser.add_argument("--n_latents", type=int, default=5)
+    parser.add_argument("--model_degree", type=int, default=2)
+    parser.add_argument("--noise_factor", type=float, default=0.05)
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--train_batch_size", type=int, default=8)
+    parser.add_argument("--val_batch_size", type=int, default=4)
+    parser.add_argument("--epochs", type=int, default=1)
+    return parser.parse_args()
+
+@dataclass
+class TrainConfig:
+    n_samples: int
+    n_latents: int
+    noise_factor: float
+    model_degree: int
+    lr: float
+    train_batch_size: int
+    val_batch_size: int
+    epochs: int
+
+
+class Trainer():
+    """A simple trainer class to keep track of the training"""
+    def __init__(self, config: TrainConfig):
+        data: TensorDataset = create_regression_data(
+                n_samples=config.n_samples,
+                n_latents=config.n_latents,
+                noise_factor=config.noise_factor,
+                model_degree=config.model_degree
+        )
+        self.config = config
+        train_dataset, val_dataset = random_split(data, lengths=[0.8, 0.2])
+        self.train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True)
+        self.val_loader = DataLoader(val_dataset, batch_size=config.val_batch_size, shuffle=False)
+        self.model = LinearRegressionModel(n_features=config.model_degree + 1)
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config.lr)
+        self.criterion = torch.nn.MSELoss()
+
+    def train(self) -> None:
+        self.model.train()
+        running_loss: float = 0.0
+        for (X, y) in self.train_loader:
+            self.optimizer.zero_grad()
+            y_hat: torch.Tensor = self.model(X)
+            loss: torch.Tensor = self.criterion(y_hat.view(-1), y)
+            running_loss += loss.item()
+            loss.backward()
+            self.optimizer.step()
+
+        print(f"Train loss: {(running_loss / len(self.train_loader)):.3f}")
+
+    @torch.no_grad()
+    def val(self) -> None:
+        self.model.eval()
+        running_loss: float = 0.0
+        for (X, y) in self.val_loader:
+            y_hat: torch.Tensor = self.model(X)
+            loss: torch.Tensor = self.criterion(y_hat.view(-1), y)
+            running_loss += loss.item()
+
+        print(f"Validation loss: {(running_loss / len(self.val_loader)):.3f}")
+
+def main() -> None:
+    args: argparse.Namespace = parse_args()
+    config: TrainConfig = TrainConfig(**vars(args))
+    trainer: Trainer = Trainer(config)
+    for epoch in range(config.epochs):
+        print(f"Starting training for epoch: {epoch + 1}")
+        trainer.train()
+        trainer.val()
+    
+if __name__ == "__main__":
+    main()
+
+
index b0c638e8893bd34b08c6de265c89d754a02900d2..33e4bd7e48c9d768b90f5545839c019e28cbfdfd 100644 (file)
@@ -7,6 +7,7 @@ requires-python = ">=3.13"
 dependencies = [
     "numpy>=2.3.4",
     "torch>=2.9.0",
+    "tqdm>=4.67.1",
 ]
 
 [dependency-groups]
diff --git a/uv.lock b/uv.lock
index 655b89d376a464ba516441aff4f61d3551f4c17a..9d4d2950081d2a10f1d841f9a3b3333f480547e4 100644 (file)
--- a/uv.lock
+++ b/uv.lock
@@ -18,6 +18,7 @@ source = { virtual = "." }
 dependencies = [
     { name = "numpy" },
     { name = "torch" },
+    { name = "tqdm" },
 ]
 
 [package.dev-dependencies]
@@ -31,6 +32,7 @@ dev = [
 requires-dist = [
     { name = "numpy", specifier = ">=2.3.4" },
     { name = "torch", specifier = ">=2.9.0" },
+    { name = "tqdm", specifier = ">=4.67.1" },
 ]
 
 [package.metadata.requires-dev]
@@ -489,6 +491,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
 ]
 
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
+]
+
 [[package]]
 name = "triton"
 version = "3.5.0"