git.sondrewold.no Git - complexity-regularizer.git/commitdiff
[plot] add AI generated plot for visualization
author Sondre Wold <[email protected]>
Sat, 18 Oct 2025 08:58:33 +0000 (10:58 +0200)
committer Sondre Wold <[email protected]>
Sat, 18 Oct 2025 08:58:33 +0000 (10:58 +0200)
complexity_regularizer/train.py

index fe79d6c2a83dc71dc21cc825bd15f9a75ce45700..cad472d052bbe11fa42c58e3f7ee27b86cef48a1 100644
@@ -13,7 +13,8 @@ def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters")
     parser.add_argument("--n_samples", type=int, default=100)
     parser.add_argument("--n_latents", type=int, default=5)
-    parser.add_argument("--model_degree", type=int, default=2)
+    parser.add_argument("--model_degrees", type=int, nargs='+', default=[1, 2, 3, 5, 7], 
+                        help="List of polynomial degrees to test")
     parser.add_argument("--noise_factor", type=float, default=0.05)
     parser.add_argument("--lr", type=float, default=1e-3)
     parser.add_argument("--gamma", type=float, default=1.5)
@@ -99,26 +100,88 @@ class Trainer():
 
 def main() -> None:
     args: argparse.Namespace = parse_args()
-    config: TrainConfig = TrainConfig(**vars(args))
-    seed_everything(config.seed)
-    test_losses: list[list[float]] = []
-    for trial in range(config.trials):
-        print(f"Trial: {trial}")
-        trainer: Trainer = Trainer(config)
-        test_losses.append(trainer.train())
-
-    mean_val_loss = np.mean(test_losses, axis=0)
-    std_val_loss = np.std(test_losses, axis=0)
+    seed_everything(args.seed)
+    
+    # Store results for all polynomial degrees
+    all_results = {}
+    
+    # Run experiments for each polynomial degree
+    for model_degree in args.model_degrees:
+        print(f"\n{'='*60}")
+        print(f"Running experiments for polynomial degree: {model_degree}")
+        print(f"{'='*60}")
+        
+        # Create config for this specific degree
+        config: TrainConfig = TrainConfig(
+            n_samples=args.n_samples,
+            n_latents=args.n_latents,
+            noise_factor=args.noise_factor,
+            model_degree=model_degree,
+            lr=args.lr,
+            train_batch_size=args.train_batch_size,
+            val_batch_size=args.val_batch_size,
+            gamma=args.gamma,
+            seed=args.seed,
+            trials=args.trials
+        )
+        
+        test_losses: list[list[float]] = []
+        for trial in range(config.trials):
+            print(f"  Trial: {trial + 1}/{config.trials}")
+            trainer: Trainer = Trainer(config)
+            test_losses.append(trainer.train())
 
-    plt.plot(mean_val_loss, label='Mean Loss')
-    plt.fill_between(range(len(mean_val_loss)),
-                                      mean_val_loss - std_val_loss,
-                                      mean_val_loss + std_val_loss,
-                                      alpha=0.1, label='±1 Std Dev')
-    plt.legend()
+        mean_val_loss = np.mean(test_losses, axis=0)
+        std_val_loss = np.std(test_losses, axis=0)
+        
+        all_results[model_degree] = {
+            'mean': mean_val_loss,
+            'std': std_val_loss
+        }
+    
+    # Create comparison plot
+    plt.figure(figsize=(12, 7))
+    
+    #colors = plt.cm.viridis(np.linspace(0, 0.9, len(args.model_degrees)))
+    
+    for idx, model_degree in enumerate(args.model_degrees):
+        mean_loss = all_results[model_degree]['mean']
+        std_loss = all_results[model_degree]['std']
+        
+        iterations = range(len(mean_loss))
+        
+        # Plot mean line
+        plt.plot(iterations, mean_loss, 
+                label=f'Degree {model_degree}', 
+                linewidth=2)
+        
+        # Plot confidence interval
+        plt.fill_between(iterations,
+                        mean_loss - std_loss,
+                        mean_loss + std_loss,
+                        alpha=0.2,
+                        )
+    
+    plt.xlabel('Training Iteration', fontsize=12)
+    plt.ylabel('Validation RMSE Loss', fontsize=12)
+    plt.title(f'Polynomial Degree Comparison (Data generated with degree {args.n_latents})', 
+              fontsize=14, fontweight='bold')
+    plt.legend(loc='best', fontsize=10)
+    plt.grid(True, alpha=0.3)
+    plt.tight_layout()
+    
+    # Print final results summary
+    print(f"\n{'='*60}")
+    print("FINAL RESULTS SUMMARY")
+    print(f"{'='*60}")
+    print(f"Data generated with polynomial degree: {args.n_latents}")
+    print(f"\nFinal validation losses (mean ± std):")
+    for model_degree in args.model_degrees:
+        final_mean = all_results[model_degree]['mean'][-1]
+        final_std = all_results[model_degree]['std'][-1]
+        print(f"  Degree {model_degree:2d}: {final_mean:.6f} ± {final_std:.6f}")
+    
     plt.show()
     
 if __name__ == "__main__":
     main()
-
-
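The first hunk replaces the single --model_degree argument with a --model_degrees list (argparse nargs='+'), which the second hunk then iterates over to train and plot one curve per degree. A minimal sketch of how the new flag parses, mirroring the argparse setup from the diff (the values passed to parse_args here are illustrative, not part of the commit):

    import argparse

    # Mirrors the flag added in this commit: nargs='+' collects one or
    # more integers from the command line into a Python list.
    parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters")
    parser.add_argument("--model_degrees", type=int, nargs='+', default=[1, 2, 3, 5, 7],
                        help="List of polynomial degrees to test")

    # e.g. invoking train.py with `--model_degrees 1 3 7`
    args = parser.parse_args(["--model_degrees", "1", "3", "7"])
    print(args.model_degrees)  # [1, 3, 7]

Run without the flag, the script falls back to the default list [1, 2, 3, 5, 7]; seed_everything(args.seed) is applied once before the per-degree loop, and each degree gets its own TrainConfig and comparison curve in the final plot.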