From: Sondre Wold Date: Sat, 18 Oct 2025 09:18:46 +0000 (+0200) Subject: [plot] clean messy plot code X-Git-Url: https://letsjmore.com/?a=commitdiff_plain;h=HEAD;p=complexity-regularizer.git [plot] clean messy plot code --- diff --git a/complexity_regularizer/train.py b/complexity_regularizer/train.py index cad472d..6af958b 100644 --- a/complexity_regularizer/train.py +++ b/complexity_regularizer/train.py @@ -9,6 +9,7 @@ from complexity_regularizer.utils import seed_everything import matplotlib.pyplot as plt import numpy as np + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters") parser.add_argument("--n_samples", type=int, default=100) @@ -24,6 +25,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--trials", type=int, default=4) return parser.parse_args() + @dataclass class TrainConfig: n_samples: int @@ -102,16 +104,10 @@ def main() -> None: args: argparse.Namespace = parse_args() seed_everything(args.seed) - # Store results for all polynomial degrees all_results = {} - - # Run experiments for each polynomial degree for model_degree in args.model_degrees: - print(f"\n{'='*60}") print(f"Running experiments for polynomial degree: {model_degree}") - print(f"{'='*60}") - # Create config for this specific degree config: TrainConfig = TrainConfig( n_samples=args.n_samples, n_latents=args.n_latents, @@ -139,23 +135,14 @@ def main() -> None: 'std': std_val_loss } - # Create comparison plot - plt.figure(figsize=(12, 7)) - - #colors = plt.cm.viridis(np.linspace(0, 0.9, len(args.model_degrees))) - for idx, model_degree in enumerate(args.model_degrees): mean_loss = all_results[model_degree]['mean'] std_loss = all_results[model_degree]['std'] - iterations = range(len(mean_loss)) - - # Plot mean line plt.plot(iterations, mean_loss, label=f'Degree {model_degree}', linewidth=2) - # Plot confidence interval plt.fill_between(iterations, mean_loss - std_loss, 
mean_loss + std_loss, @@ -163,24 +150,11 @@ def main() -> None: ) plt.xlabel('Training Iteration', fontsize=12) - plt.ylabel('Validation RMSE Loss', fontsize=12) - plt.title(f'Polynomial Degree Comparison (Data generated with degree {args.n_latents})', - fontsize=14, fontweight='bold') - plt.legend(loc='best', fontsize=10) - plt.grid(True, alpha=0.3) + plt.ylabel('Test RMSE', fontsize=12) + plt.title(f'Degree {args.n_latents} Polynomial') + plt.legend(loc='best') + plt.grid(True, alpha=0.1) plt.tight_layout() - - # Print final results summary - print(f"\n{'='*60}") - print("FINAL RESULTS SUMMARY") - print(f"{'='*60}") - print(f"Data generated with polynomial degree: {args.n_latents}") - print(f"\nFinal validation losses (mean ± std):") - for model_degree in args.model_degrees: - final_mean = all_results[model_degree]['mean'][-1] - final_std = all_results[model_degree]['std'][-1] - print(f" Degree {model_degree:2d}: {final_mean:.6f} ± {final_std:.6f}") - plt.show() if __name__ == "__main__":