import matplotlib.pyplot as plt
import numpy as np
+
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser("Simple regression problem with an exponential regularization of the parameters")
parser.add_argument("--n_samples", type=int, default=100)
parser.add_argument("--trials", type=int, default=4)
return parser.parse_args()
+
@dataclass
class TrainConfig:
n_samples: int
args: argparse.Namespace = parse_args()
seed_everything(args.seed)
- # Store results for all polynomial degrees
all_results = {}
-
- # Run experiments for each polynomial degree
for model_degree in args.model_degrees:
- print(f"\n{'='*60}")
print(f"Running experiments for polynomial degree: {model_degree}")
- print(f"{'='*60}")
- # Create config for this specific degree
config: TrainConfig = TrainConfig(
n_samples=args.n_samples,
n_latents=args.n_latents,
'std': std_val_loss
}
- # Create comparison plot
- plt.figure(figsize=(12, 7))
-
- #colors = plt.cm.viridis(np.linspace(0, 0.9, len(args.model_degrees)))
-
for idx, model_degree in enumerate(args.model_degrees):
mean_loss = all_results[model_degree]['mean']
std_loss = all_results[model_degree]['std']
-
iterations = range(len(mean_loss))
-
- # Plot mean line
plt.plot(iterations, mean_loss,
label=f'Degree {model_degree}',
linewidth=2)
- # Plot confidence interval
plt.fill_between(iterations,
mean_loss - std_loss,
mean_loss + std_loss,
)
plt.xlabel('Training Iteration', fontsize=12)
- plt.ylabel('Validation RMSE Loss', fontsize=12)
- plt.title(f'Polynomial Degree Comparison (Data generated with degree {args.n_latents})',
- fontsize=14, fontweight='bold')
- plt.legend(loc='best', fontsize=10)
- plt.grid(True, alpha=0.3)
+ plt.ylabel('Test RMSE', fontsize=12)
+ plt.title(f'Degree {args.n_latents} Polynomial)')
+ plt.legend(loc='best')
+ plt.grid(True, alpha=0.1)
plt.tight_layout()
-
- # Print final results summary
- print(f"\n{'='*60}")
- print("FINAL RESULTS SUMMARY")
- print(f"{'='*60}")
- print(f"Data generated with polynomial degree: {args.n_latents}")
- print(f"\nFinal validation losses (mean ± std):")
- for model_degree in args.model_degrees:
- final_mean = all_results[model_degree]['mean'][-1]
- final_std = all_results[model_degree]['std'][-1]
- print(f" Degree {model_degree:2d}: {final_mean:.6f} ± {final_std:.6f}")
-
plt.show()
if __name__ == "__main__":