import matplotlib.pyplot as plt
 import numpy as np
 
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters")
     parser.add_argument("--n_samples", type=int, default=100)
     parser.add_argument("--trials", type=int, default=4)
     return parser.parse_args()
 
+
 @dataclass
 class TrainConfig:
     n_samples: int
     args: argparse.Namespace = parse_args()
     seed_everything(args.seed)
     
-    # Store results for all polynomial degrees
     all_results = {}
-    
-    # Run experiments for each polynomial degree
     for model_degree in args.model_degrees:
-        print(f"\n{'='*60}")
         print(f"Running experiments for polynomial degree: {model_degree}")
-        print(f"{'='*60}")
         
-        # Create config for this specific degree
         config: TrainConfig = TrainConfig(
             n_samples=args.n_samples,
             n_latents=args.n_latents,
             'std': std_val_loss
         }
     
-    # Create comparison plot
-    plt.figure(figsize=(12, 7))
-    
-    #colors = plt.cm.viridis(np.linspace(0, 0.9, len(args.model_degrees)))
-    
     for idx, model_degree in enumerate(args.model_degrees):
         mean_loss = all_results[model_degree]['mean']
         std_loss = all_results[model_degree]['std']
-        
         iterations = range(len(mean_loss))
-        
-        # Plot mean line
         plt.plot(iterations, mean_loss, 
                 label=f'Degree {model_degree}', 
                 linewidth=2)
         
-        # Plot confidence interval
         plt.fill_between(iterations,
                         mean_loss - std_loss,
                         mean_loss + std_loss,
                         )
     
     plt.xlabel('Training Iteration', fontsize=12)
-    plt.ylabel('Validation RMSE Loss', fontsize=12)
-    plt.title(f'Polynomial Degree Comparison (Data generated with degree {args.n_latents})', 
-              fontsize=14, fontweight='bold')
-    plt.legend(loc='best', fontsize=10)
-    plt.grid(True, alpha=0.3)
+    plt.ylabel('Test RMSE', fontsize=12)
+    plt.title(f'Degree {args.n_latents} Polynomial)')
+    plt.legend(loc='best')
+    plt.grid(True, alpha=0.1)
     plt.tight_layout()
-    
-    # Print final results summary
-    print(f"\n{'='*60}")
-    print("FINAL RESULTS SUMMARY")
-    print(f"{'='*60}")
-    print(f"Data generated with polynomial degree: {args.n_latents}")
-    print(f"\nFinal validation losses (mean ± std):")
-    for model_degree in args.model_degrees:
-        final_mean = all_results[model_degree]['mean'][-1]
-        final_std = all_results[model_degree]['std'][-1]
-        print(f"  Degree {model_degree:2d}: {final_mean:.6f} ± {final_std:.6f}")
-    
     plt.show()
     
 if __name__ == "__main__":