]> git.sondrewold.no Git - complexity-regularizer.git/commitdiff
[model] add a simple linear model impl
author: Sondre Wold <[email protected]>
Sat, 18 Oct 2025 06:45:50 +0000 (08:45 +0200)
committer: Sondre Wold <[email protected]>
Sat, 18 Oct 2025 06:45:50 +0000 (08:45 +0200)
complexity_regularizer/models.py [new file with mode: 0644]

diff --git a/complexity_regularizer/models.py b/complexity_regularizer/models.py
new file mode 100644 (file)
index 0000000..15bd52f
--- /dev/null
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+
class LinearRegressionModel(nn.Module):
    """A plain linear model: y = x @ W^T + b.

    Args:
        n_features: Number of input features per sample.
        bias: If ``True``, adds a learnable scalar bias term.
    """

    def __init__(self, n_features: int, bias: bool = True) -> None:
        super().__init__()
        self.n_features = n_features
        # nn.Parameter already sets requires_grad=True; no flag needed.
        self.weights = nn.Parameter(torch.empty((1, n_features)))
        if bias:
            self.bias = nn.Parameter(torch.empty(1))
        else:
            # Registering None keeps `self.bias` a valid attribute lookup.
            self.register_parameter("bias", None)
        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        """Initialize weights (Kaiming uniform) and, if present, bias (uniform [0, 1))."""
        nn.init.kaiming_uniform_(self.weights)
        if self.bias is not None:
            nn.init.uniform_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map.

        Args:
            x: Input of shape (batch_size, n_features).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        # BUG FIX: the original added self.bias unconditionally, which
        # raised a TypeError when constructed with bias=False (bias is None).
        out = x @ self.weights.T
        if self.bias is not None:
            out = out + self.bias
        return out
+
+
def main() -> None:
    """Smoke-test: run one forward pass on a random batch."""
    num_features = 3
    num_samples = 5
    sample_batch = torch.randn((num_samples, num_features))
    regressor = LinearRegressionModel(num_features)
    regressor(sample_batch)
+
+
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()