From: Sondre Wold
Date: Sat, 18 Oct 2025 06:45:50 +0000 (+0200)
Subject: [model] add a simple linear model impl
X-Git-Url: https://letsjmore.com/?a=commitdiff_plain;h=930007e1020a22922a9b1bd4665514b8f35a7dba;p=complexity-regularizer.git

[model] add a simple linear model impl
---

diff --git a/complexity_regularizer/models.py b/complexity_regularizer/models.py
new file mode 100644
index 0000000..15bd52f
--- /dev/null
+++ b/complexity_regularizer/models.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+
+class LinearRegressionModel(nn.Module):
+    """A plain linear layer: y = x W^T + b."""
+    def __init__(self, n_features: int, bias: bool = True) -> None:
+        super().__init__()
+        self.n_features = n_features
+        # nn.Parameter tracks gradients by default, so requires_grad is implied.
+        self.weights = nn.Parameter(torch.empty(1, n_features))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(1))
+        else:
+            self.register_parameter("bias", None)
+        self.initialize_parameters()
+
+    def initialize_parameters(self) -> None:
+        nn.init.kaiming_uniform_(self.weights)
+        if self.bias is not None:
+            nn.init.uniform_(self.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Only add the bias term when the module was constructed with bias=True.
+        out = x @ self.weights.T
+        if self.bias is not None:
+            out = out + self.bias
+        return out
+
+
+def main() -> None:
+    n_features = 3
+    batch_size = 5
+    mock_data = torch.randn((batch_size, n_features))
+    model = LinearRegressionModel(n_features)
+    model(mock_data)
+
+
+if __name__ == "__main__":
+    main()
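
Note (not part of the patch): a minimal sketch of how this model might be fitted, assuming standard SGD with an MSE loss on a synthetic target; the import path follows the file added above, and the learning rate, step count, and noise level are illustrative, not taken from the repository.

    import torch
    from complexity_regularizer.models import LinearRegressionModel

    n_features, batch_size = 3, 32
    x = torch.randn(batch_size, n_features)
    true_w = torch.randn(1, n_features)
    # Synthetic targets with a little noise (hypothetical setup for the sketch).
    y = x @ true_w.T + 0.1 * torch.randn(batch_size, 1)

    model = LinearRegressionModel(n_features)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # hyperparameters are placeholders
    loss_fn = torch.nn.MSELoss()

    for _ in range(200):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()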