Files
COGMOD-HWI/Project/plot_diagnostics.py
T
2025-04-24 17:45:54 -04:00

81 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt
# BASE_DIR = "./on_GPU"
# training diagnostics
histories = []
for BASE_DIR in ['./on_CPU', './on_GPU']:
for csv in glob.glob(f"{BASE_DIR}/analytics/history_*.csv"):
df = pd.read_csv(csv) # epochs x metrics
tag = "_".join(csv.split("_")[1:4]) # like sgd_none_42.csv
df['condition'] = tag.replace(".csv", "")
histories.append(df)
logs = pd.concat(histories, ignore_index=True)
# average across seeds, optimiser, augmentation for clarity
mean_log = logs.groupby('epoch').mean(numeric_only=True)
plt.figure()
plt.plot(mean_log.index, mean_log['train_acc'],
color='#228B22', label='train')
plt.plot(mean_log.index, mean_log['test_acc'],
color='#800080', label='validation')
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.grid(True)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig(f"train_val_accuracy.png")
plt.figure()
plt.plot(mean_log.index, mean_log['train_loss'],
color='#008080', label='train') # teal
plt.plot(mean_log.index, mean_log['test_loss'],
color='#FF00FF', label='validation') # magenta
plt.xlabel("epoch")
plt.ylabel("crossentropy loss")
plt.legend()
plt.grid(True)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig(f"train_val_loss.png")
# robustness curves
with open(f"./combined_results.json") as f:
res = json.load(f)
df = pd.DataFrame(res)
records = [] # one row per sigma
for _, row in df.iterrows():
for sigma, acc in row['robustness'].items():
records.append({
'optimizer': row['optimizer'],
'augmentation': row['augmentation'],
'sigma': float(sigma),
'acc': acc,
})
rob_df = pd.DataFrame(records)
pivot = rob_df.groupby(['optimizer', 'sigma']).acc.mean().unstack(0)
ax = pivot.plot(marker='o', linestyle='-', color=['#FF0000', '#00008B'])
plt.xlabel("Gaussian noise sigma")
plt.ylabel("accuracy")
plt.title("Noise robustness")
plt.grid(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.savefig(f"./robustness_curve.png")
print("saved robustness_curve.png")