self.n_features = n_features
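+        # Weights form a (1, n_features) row vector; forward computes x @ W.T -> (batch, 1).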
self.weights = nn.Parameter(torch.empty((1, n_features), requires_grad=True))
if bias:
- self.bias = nn.Parameter(torch.empty((1), requires_grad=True))
+ self.bias = nn.Parameter(torch.empty(1, requires_grad=True))
else:
self.register_parameter("bias", None)
self.initialize_parameters()
def initialize_parameters(self) -> None:
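+        # Kaiming-uniform draws from U(-b, b) with b = sqrt(6 / fan_in); fan_in = n_features here.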
nn.init.kaiming_uniform_(self.weights)
if self.bias is not None:
- nn.init.uniform_(self.bias)
+ nn.init.zeros_(self.bias)
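+            # Zero-initializing the bias (rather than uniform noise) starts the fit unbiased.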
def forward(self, x) -> torch.Tensor:
-        return x @ self.weights.T + self.bias
+        # self.bias is registered as None when bias=False, so only add it when present.
+        out = x @ self.weights.T
+        return out if self.bias is None else out + self.bias
--- /dev/null
+import argparse
+from dataclasses import dataclass
+
+import torch
+from torch.utils.data import DataLoader, TensorDataset, random_split
+
+from complexity_regularizer.dataset_builder import create_regression_data
+from complexity_regularizer.models import LinearRegressionModel
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Simple regression problem with an exponential regularization of the parameters")
+ parser.add_argument("--n_samples", type=int, default=100)
+ parser.add_argument("--n_latents", type=int, default=5)
+ parser.add_argument("--model_degree", type=int, default=2)
+ parser.add_argument("--noise_factor", type=float, default=0.05)
+ parser.add_argument("--lr", type=float, default=1e-3)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--val_batch_size", type=int, default=4)
+ parser.add_argument("--epochs", type=int, default=1)
+ return parser.parse_args()
+
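+# NB: field names mirror the argparse flags so that TrainConfig(**vars(args)) in main() just works.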
+@dataclass
+class TrainConfig:
+ n_samples: int
+ n_latents: int
+ noise_factor: float
+ model_degree: int
+ lr: float
+ train_batch_size: int
+ val_batch_size: int
+ epochs: int
+
+
+class Trainer:
+    """Wraps dataset construction, the train/validation split, and the optimization loop."""
+ def __init__(self, config: TrainConfig):
+ data: TensorDataset = create_regression_data(
+ n_samples=config.n_samples,
+ n_latents=config.n_latents,
+ noise_factor=config.noise_factor,
+ model_degree=config.model_degree
+ )
+ self.config = config
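+        # Fractional lengths let random_split derive an 80/20 train/validation split.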
+ train_dataset, val_dataset = random_split(data, lengths=[0.8, 0.2])
+ self.train_loader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True)
+ self.val_loader = DataLoader(val_dataset, batch_size=config.val_batch_size, shuffle=False)
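+        # A degree-d polynomial fit has d + 1 coefficients, hence n_features = model_degree + 1.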
+ self.model = LinearRegressionModel(n_features=config.model_degree + 1)
+ self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config.lr)
+ self.criterion = torch.nn.MSELoss()
+
+ def train(self) -> None:
+ self.model.train()
+ running_loss: float = 0.0
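+        # (Optional) the tqdm dependency added in this change could wrap the loader for a
+        # progress bar, e.g. `for X, y in tqdm(self.train_loader)` after `from tqdm import tqdm`.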
+ for (X, y) in self.train_loader:
+ self.optimizer.zero_grad()
+ y_hat: torch.Tensor = self.model(X)
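+            # Flatten (batch, 1) predictions to (batch,) so MSELoss does not broadcast against y.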
+ loss: torch.Tensor = self.criterion(y_hat.view(-1), y)
+ running_loss += loss.item()
+ loss.backward()
+ self.optimizer.step()
+
+ print(f"Train loss: {(running_loss / len(self.train_loader)):.3f}")
+
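+    # Gradients are not needed for evaluation, so autograd is disabled for the whole method.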
+ @torch.no_grad()
+ def val(self) -> None:
+ self.model.eval()
+ running_loss: float = 0.0
+ for (X, y) in self.val_loader:
+ y_hat: torch.Tensor = self.model(X)
+ loss: torch.Tensor = self.criterion(y_hat.view(-1), y)
+ running_loss += loss.item()
+
+ print(f"Validation loss: {(running_loss / len(self.val_loader)):.3f}")
+
+def main() -> None:
+ args: argparse.Namespace = parse_args()
+ config: TrainConfig = TrainConfig(**vars(args))
+ trainer: Trainer = Trainer(config)
+ for epoch in range(config.epochs):
+ print(f"Starting training for epoch: {epoch + 1}")
+ trainer.train()
+ trainer.val()
+
+if __name__ == "__main__":
+ main()
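+
+# Example invocation (assuming this file is saved as train.py; adjust to its real path):
+#   python train.py --n_samples 500 --model_degree 3 --epochs 20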
+
+
dependencies = [
{ name = "numpy" },
{ name = "torch" },
+ { name = "tqdm" },
]
[package.dev-dependencies]

[package.metadata]
requires-dist = [
{ name = "numpy", specifier = ">=2.3.4" },
{ name = "torch", specifier = ">=2.9.0" },
+ { name = "tqdm", specifier = ">=4.67.1" },
]
[package.metadata.requires-dev]
{ url = "https://files.pythonhosted.org/packages/fc/29/bd361e0cbb2c79ce6450f42643aaf6919956f89923a50571b0ebfe92d142/torch-2.9.0-cp314-cp314t-win_amd64.whl", hash = "sha256:695ba920f234ad4170c9c50e28d56c848432f8f530e6bc7f88fcb15ddf338e75", size = 109503850, upload-time = "2025-10-15T15:50:24.118Z" },
]
+[[package]]
+name = "tqdm"
+version = "4.67.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "colorama", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
+]
+
[[package]]
name = "triton"
version = "3.5.0"