From 53fdc1c7ac1d854ad8b8e9953f442b90794e2673 Mon Sep 17 00:00:00 2001 From: Sondre Wold Date: Sat, 18 Oct 2025 09:27:37 +0200 Subject: [PATCH] [train] add train loop --- complexity_regularizer/models.py | 4 +- complexity_regularizer/train.py | 86 ++++++++++++++++++++++++++++++++ pyproject.toml | 1 + uv.lock | 14 ++++++ 4 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 complexity_regularizer/train.py diff --git a/complexity_regularizer/models.py b/complexity_regularizer/models.py index 15bd52f..8b12ac6 100644 --- a/complexity_regularizer/models.py +++ b/complexity_regularizer/models.py @@ -8,7 +8,7 @@ class LinearRegressionModel(nn.Module): self.n_features = n_features self.weights = nn.Parameter(torch.empty((1, n_features), requires_grad=True)) if bias: - self.bias = nn.Parameter(torch.empty((1), requires_grad=True)) + self.bias = nn.Parameter(torch.empty(1, requires_grad=True)) else: self.register_parameter("bias", None) self.initialize_parameters() @@ -16,7 +16,7 @@ class LinearRegressionModel(nn.Module): def initialize_parameters(self) -> None: nn.init.kaiming_uniform_(self.weights) if self.bias is not None: - nn.init.uniform_(self.bias) + nn.init.zeros_(self.bias) def forward(self, x) -> torch.Tensor: return x @ self.weights.T + self.bias diff --git a/complexity_regularizer/train.py b/complexity_regularizer/train.py new file mode 100644 index 0000000..54b6fe9 --- /dev/null +++ b/complexity_regularizer/train.py @@ -0,0 +1,86 @@ +from complexity_regularizer.models import LinearRegressionModel +from torch.utils.data.dataset import TensorDataset +import torch +import argparse +from dataclasses import dataclass +from complexity_regularizer.dataset_builder import create_regression_data +from torch.utils.data import random_split, DataLoader + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters") + parser.add_argument("--n_samples", type=int, default=100) + parser.add_argument("--n_latents", type=int, default=5) + parser.add_argument("--model_degree", type=int, default=2) + parser.add_argument("--noise_factor", type=float, default=0.05) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--val_batch_size", type=int, default=4) + parser.add_argument("--epochs", type=int, default=1) + return parser.parse_args() + +@dataclass +class TrainConfig: + n_samples: int + n_latents: int + noise_factor: float + model_degree: int + lr: float + train_batch_size: int + val_batch_size: int + epochs: int + + +class Trainer(): + """A simple trainer class to keep track of the training""" + def __init__(self, config: TrainConfig): + data: TensorDataset = create_regression_data( + n_samples=config.n_samples, + n_latents=config.n_latents, + noise_factor=config.noise_factor, + model_degree=config.model_degree + ) + self.config = config + train_dataset, val_dataset = random_split(data, lengths=[0.8, 0.2]) + self.train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True) + self.val_loader = DataLoader(val_dataset, batch_size=config.val_batch_size, shuffle=False) + self.model = LinearRegressionModel(n_features=config.model_degree + 1) + self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config.lr) + self.criterion = torch.nn.MSELoss() + + def train(self) -> None: + self.model.train() + running_loss: float = 0.0 + for (X, y) in self.train_loader: + self.optimizer.zero_grad() + y_hat: torch.Tensor = self.model(X) + loss: torch.Tensor = self.criterion(y_hat.view(-1), y) + running_loss += loss.item() + loss.backward() + self.optimizer.step() + + print(f"Train loss: {(running_loss / len(self.train_loader)):.3f}") + + @torch.no_grad() + def val(self) -> None: + self.model.eval() + running_loss: float = 0.0 + for (X, y) in self.val_loader: + y_hat: torch.Tensor = self.model(X) + loss: torch.Tensor = self.criterion(y_hat.view(-1), y) + running_loss += loss.item() + + print(f"Validation loss: {(running_loss / len(self.val_loader)):.3f}") + +def main() -> None: + args: argparse.Namespace = parse_args() + config: TrainConfig = TrainConfig(**vars(args)) + trainer: Trainer = Trainer(config) + for epoch in range(config.epochs): + print(f"Starting training for epoch: {epoch + 1}") + trainer.train() + trainer.val() + +if __name__ == "__main__": + main() + + diff --git a/pyproject.toml b/pyproject.toml index b0c638e..33e4bd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.13" dependencies = [ "numpy>=2.3.4", "torch>=2.9.0", + "tqdm>=4.67.1", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index 655b89d..9d4d295 100644 --- a/uv.lock +++ b/uv.lock @@ -18,6 +18,7 @@ source = { virtual = "." } dependencies = [ { name = "numpy" }, { name = "torch" }, + { name = "tqdm" }, ] [package.dev-dependencies] @@ -31,6 +32,7 @@ dev = [ requires-dist = [ { name = "numpy", specifier = ">=2.3.4" }, { name = "torch", specifier = ">=2.9.0" }, + { name = "tqdm", specifier = ">=4.67.1" }, ] [package.metadata.requires-dev] @@ -489,6 +491,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" }, ] +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + [[package]] name = "triton" version = "3.5.0" -- 2.39.5